diff --git a/3rdparty/cutlass b/3rdparty/cutlass index e51efbfe18..a49a78ffef 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit e51efbfe18fe4f4cbb66ab814c55bf4aa0185491 +Subproject commit a49a78ffefc86a87160dfe0ccc3a3a2d1622c918 diff --git a/benchmarks/README.md b/benchmarks/README.md index 10b8d3ee37..4a16f37599 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -152,6 +152,7 @@ The output CSV will contain detailed metrics including: | `--ep_size` | Expert-parallel world size | | `--ep_rank` | Expert-parallel rank | | `--gated_act` | Gated activation function: `swiglu` (default) or `geglu` | +| `--autotune` | Enable autotune for supported operation | ### MOE Routing Method Compatibility diff --git a/benchmarks/bench_append_paged_kv_cache.py b/benchmarks/bench_append_paged_kv_cache.py index a7ba8b6b6f..52d59486c5 100644 --- a/benchmarks/bench_append_paged_kv_cache.py +++ b/benchmarks/bench_append_paged_kv_cache.py @@ -1,9 +1,13 @@ +import sys + + import argparse import dataclasses from typing import Tuple import numpy as np -import torch +import paddle +from flashinfer.paddle_utils import * import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -17,34 +21,18 @@ class ModelConfig: def _make_70b(tp: int) -> ModelConfig: - return ModelConfig( - num_kv_heads=8 // tp, - num_layers=80, - head_dim=128, - ) + return ModelConfig(num_kv_heads=8 // tp, num_layers=80, head_dim=128) MODELS = { - "l1b": ModelConfig( - num_kv_heads=8, - num_layers=16, - head_dim=64, - ), - "l3b": ModelConfig( - num_kv_heads=8, - num_layers=28, - head_dim=128, - ), - "l8b": ModelConfig( - num_kv_heads=8, - num_layers=32, - head_dim=128, - ), + "l1b": ModelConfig(num_kv_heads=8, num_layers=16, head_dim=64), + "l3b": ModelConfig(num_kv_heads=8, num_layers=28, head_dim=128), + "l8b": ModelConfig(num_kv_heads=8, num_layers=32, head_dim=128), "l70b-tp8": _make_70b(8), } -@torch.inference_mode() +@paddle.no_grad() def main(): parser = argparse.ArgumentParser() parser.add_argument("--seqlen", type=int, default=5000) @@ -52,7 +40,6 @@ def main(): parser.add_argument("--page-len", type=int, default=16) parser.add_argument("--dtype", type=str, default="float16") args = parser.parse_args() - seqlens_ = [ [1] * args.batch_size, [args.seqlen - args.batch_size + 1] + [1] * (args.batch_size - 1), @@ -62,28 +49,22 @@ def main(): seqlen_strlen = max(len(str(seqlens)) for seqlens in seqlens_) page_len = int(args.page_len) dtype = getattr(torch, args.dtype) - assert isinstance(dtype, torch.dtype) - device = torch.device("cuda:0") + assert isinstance(dtype, paddle.dtype) + device = device2str("cuda:0") total_pages = int(256000 / page_len) - - torch.cuda.profiler.start() - +>>>>>> torch.cuda.profiler.start() for model_name, model in MODELS.items(): - page_shape = (2, page_len, model.num_kv_heads, model.head_dim) - layer_buf = torch.empty((total_pages,) + page_shape, dtype=dtype, device=device) + page_shape = 2, page_len, model.num_kv_heads, model.head_dim + layer_buf = paddle.empty(shape=(total_pages,) + page_shape, dtype=dtype) for seqlens in seqlens_: - k = torch.rand( - (sum(seqlens), model.num_kv_heads, model.head_dim), - dtype=dtype, - device=device, + k = paddle.rand( + shape=(sum(seqlens), model.num_kv_heads, model.head_dim), dtype=dtype ) - v = torch.rand( - (sum(seqlens), model.num_kv_heads, model.head_dim), - dtype=dtype, - device=device, + v = paddle.rand( + shape=(sum(seqlens), model.num_kv_heads, model.head_dim), dtype=dtype ) - x_indptr = torch.tensor([0] + 
seqlens, device=device, dtype=torch.int32) - x_indptr = torch.cumsum(x_indptr, 0, dtype=torch.int32) + x_indptr = paddle.to_tensor(data=[0] + seqlens, dtype="int32", place=device) + x_indptr = paddle.cumsum(x=x_indptr, axis=0, dtype="int32") kv_indices_host = [] kv_indptr_host = [0] next_page_id = 0 @@ -92,27 +73,31 @@ def main(): kv_indices_host.extend(range(next_page_id, next_page_id + npages)) next_page_id += npages kv_indptr_host.append(len(kv_indices_host)) - kv_indices = torch.tensor(kv_indices_host, device=device, dtype=torch.int32) - kv_indptr = torch.tensor(kv_indptr_host, device=device, dtype=torch.int32) - kv_last_page_len = torch.tensor( - [(seqlen - 1) % page_len + 1 for seqlen in seqlens], - device=device, - dtype=torch.int32, + kv_indices = paddle.to_tensor( + data=kv_indices_host, dtype="int32", place=device + ) + kv_indptr = paddle.to_tensor( + data=kv_indptr_host, dtype="int32", place=device + ) + kv_last_page_len = paddle.to_tensor( + data=[((seqlen - 1) % page_len + 1) for seqlen in seqlens], + dtype="int32", + place=device, ) - @torch.cuda.nvtx.range(f"convert model={model_name}, seqlens={seqlens}") - def fn_convert() -> Tuple[torch.Tensor, torch.Tensor]: +>>>>>> @torch.cuda.nvtx.range(f"convert model={model_name}, seqlens={seqlens}") + def fn_convert() -> Tuple[paddle.Tensor, paddle.Tensor]: return flashinfer.get_batch_indices_positions( x_indptr, flashinfer.get_seq_lens(kv_indptr, kv_last_page_len, page_len), - k.shape[0], + tuple(k.shape)[0], ) batch_indices, positions = fn_convert() convert_latencies = bench_gpu_time(fn_convert) convert_latency_ms = np.median(convert_latencies) - @torch.cuda.nvtx.range(f"append model={model_name}, seqlens={seqlens}") +>>>>>> @torch.cuda.nvtx.range(f"append model={model_name}, seqlens={seqlens}") def fn() -> None: flashinfer.append_paged_kv_cache( k, @@ -130,23 +115,22 @@ def fn() -> None: latency_ms = np.median(latencies) all_layers_latency_ms = convert_latency_ms + latency_ms * model.num_layers throughput = ( - k.numel() + k.size * k.element_size() * sum(1 for _ in ["k", "v"]) * sum(1 for _ in ["read", "write"]) - / (latency_ms * 1e-3) + / (latency_ms * 0.001) ) print( f"model: {model_name:8}", f"seqlens: {seqlens!r:{seqlen_strlen}}", - f"convert: {convert_latency_ms * 1e3:2.0f}us", - f"1layer: {latency_ms * 1e3:2.0f}us", - f"{model.num_layers}layers: {all_layers_latency_ms * 1e3:3.0f}us", - f"throughput: {throughput * 1e-9:8.3f}GB/s", + f"convert: {convert_latency_ms * 1000.0:2.0f}us", + f"1layer: {latency_ms * 1000.0:2.0f}us", + f"{model.num_layers}layers: {all_layers_latency_ms * 1000.0:3.0f}us", + f"throughput: {throughput * 1e-09:8.3f}GB/s", ) print("---") - - torch.cuda.profiler.stop() +>>>>>> torch.cuda.profiler.stop() if __name__ == "__main__": diff --git a/benchmarks/bench_append_paged_mla_kv_cache.py b/benchmarks/bench_append_paged_mla_kv_cache.py index f1355213d7..7c817a156c 100644 --- a/benchmarks/bench_append_paged_mla_kv_cache.py +++ b/benchmarks/bench_append_paged_mla_kv_cache.py @@ -1,9 +1,13 @@ +import sys + + import argparse import dataclasses from typing import Tuple import numpy as np -import torch +import paddle +from flashinfer.paddle_utils import * import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -22,7 +26,7 @@ class ModelConfig: } -@torch.inference_mode() +@paddle.no_grad() def main(): parser = argparse.ArgumentParser() parser.add_argument("--seqlen", type=int, default=5000) @@ -30,7 +34,6 @@ def main(): parser.add_argument("--page-len", type=int, default=16) 
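Reviewer sketch (illustration only, not part of the patch): the hunks in these benchmark files repeat one tensor-construction conversion, so a single side-by-side example may help review. It assumes that device2str from the new flashinfer.paddle_utils helper returns a Paddle place spec such as "gpu:0"; that helper is defined elsewhere in this PR and is not reproduced here.

import paddle

device = "gpu:0"  # assumed result of device2str("cuda:0")
seqlens = [7, 1, 1]

# torch: x_indptr = torch.tensor([0] + seqlens, device=device, dtype=torch.int32)
x_indptr = paddle.to_tensor(data=[0] + seqlens, dtype="int32", place=device)

# torch: x_indptr = torch.cumsum(x_indptr, 0, dtype=torch.int32)
x_indptr = paddle.cumsum(x=x_indptr, axis=0, dtype="int32")

# torch: @torch.inference_mode() on main() becomes @paddle.no_grad(), the closest
# Paddle decorator for turning off autograd in benchmark-only code.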
parser.add_argument("--dtype", type=str, default="float16") args = parser.parse_args() - seqlens_ = [ [1] * args.batch_size, [args.seqlen - args.batch_size + 1] + [1] * (args.batch_size - 1), @@ -40,34 +43,20 @@ def main(): seqlen_strlen = max(len(str(seqlens)) for seqlens in seqlens_) page_len = int(args.page_len) dtype = getattr(torch, args.dtype) - assert isinstance(dtype, torch.dtype) - device = torch.device("cuda:0") + assert isinstance(dtype, paddle.dtype) + device = device2str("cuda:0") total_pages = int(25600 / page_len) - - torch.cuda.profiler.start() - +>>>>>> torch.cuda.profiler.start() for model_name, model in MODELS.items(): - ckv_page_shape = (page_len, model.ckv_dim) - kpe_page_shape = (page_len, model.kpe_dim) - ckv_layer_buf = torch.empty( - (total_pages,) + ckv_page_shape, dtype=dtype, device=device - ) - kpe_layer_buf = torch.empty( - (total_pages,) + kpe_page_shape, dtype=dtype, device=device - ) + ckv_page_shape = page_len, model.ckv_dim + kpe_page_shape = page_len, model.kpe_dim + ckv_layer_buf = paddle.empty(shape=(total_pages,) + ckv_page_shape, dtype=dtype) + kpe_layer_buf = paddle.empty(shape=(total_pages,) + kpe_page_shape, dtype=dtype) for seqlens in seqlens_: - ckv = torch.rand( - (sum(seqlens), model.ckv_dim), - dtype=dtype, - device=device, - ) - kpe = torch.rand( - (sum(seqlens), model.kpe_dim), - dtype=dtype, - device=device, - ) - x_indptr = torch.tensor([0] + seqlens, device=device, dtype=torch.int32) - x_indptr = torch.cumsum(x_indptr, 0, dtype=torch.int32) + ckv = paddle.rand(shape=(sum(seqlens), model.ckv_dim), dtype=dtype) + kpe = paddle.rand(shape=(sum(seqlens), model.kpe_dim), dtype=dtype) + x_indptr = paddle.to_tensor(data=[0] + seqlens, dtype="int32", place=device) + x_indptr = paddle.cumsum(x=x_indptr, axis=0, dtype="int32") kv_indices_host = [] kv_indptr_host = [0] next_page_id = 0 @@ -76,27 +65,31 @@ def main(): kv_indices_host.extend(range(next_page_id, next_page_id + npages)) next_page_id += npages kv_indptr_host.append(len(kv_indices_host)) - kv_indices = torch.tensor(kv_indices_host, device=device, dtype=torch.int32) - kv_indptr = torch.tensor(kv_indptr_host, device=device, dtype=torch.int32) - kv_last_page_len = torch.tensor( - [(seqlen - 1) % page_len + 1 for seqlen in seqlens], - device=device, - dtype=torch.int32, + kv_indices = paddle.to_tensor( + data=kv_indices_host, dtype="int32", place=device + ) + kv_indptr = paddle.to_tensor( + data=kv_indptr_host, dtype="int32", place=device + ) + kv_last_page_len = paddle.to_tensor( + data=[((seqlen - 1) % page_len + 1) for seqlen in seqlens], + dtype="int32", + place=device, ) - @torch.cuda.nvtx.range(f"convert model={model_name}, seqlens={seqlens}") - def fn_convert() -> Tuple[torch.Tensor, torch.Tensor]: +>>>>>> @torch.cuda.nvtx.range(f"convert model={model_name}, seqlens={seqlens}") + def fn_convert() -> Tuple[paddle.Tensor, paddle.Tensor]: return flashinfer.get_batch_indices_positions( x_indptr, flashinfer.get_seq_lens(kv_indptr, kv_last_page_len, page_len), - ckv.shape[0], + tuple(ckv.shape)[0], ) batch_indices, positions = fn_convert() convert_latencies = bench_gpu_time(fn_convert) convert_latency_ms = np.median(convert_latencies) - @torch.cuda.nvtx.range(f"append model={model_name}, seqlens={seqlens}") +>>>>>> @torch.cuda.nvtx.range(f"append model={model_name}, seqlens={seqlens}") def fn() -> None: flashinfer.append_paged_mla_kv_cache( ckv, @@ -114,22 +107,21 @@ def fn() -> None: latency_ms = np.median(latencies) all_layers_latency_ms = convert_latency_ms + latency_ms * model.num_layers 
throughput = ( - (ckv.numel() + kpe.numel()) + (ckv.size + kpe.size) * ckv.element_size() * sum(1 for _ in ["read", "write"]) - / (latency_ms * 1e-3) + / (latency_ms * 0.001) ) print( f"model: {model_name:8}", f"seqlens: {seqlens!r:{seqlen_strlen}}", - f"convert: {convert_latency_ms * 1e3:2.0f}us", - f"1layer: {latency_ms * 1e3:2.0f}us", - f"{model.num_layers}layers: {all_layers_latency_ms * 1e3:3.0f}us", - f"throughput: {throughput * 1e-9:8.3f}GB/s", + f"convert: {convert_latency_ms * 1000.0:2.0f}us", + f"1layer: {latency_ms * 1000.0:2.0f}us", + f"{model.num_layers}layers: {all_layers_latency_ms * 1000.0:3.0f}us", + f"throughput: {throughput * 1e-09:8.3f}GB/s", ) print("---") - - torch.cuda.profiler.stop() +>>>>>> torch.cuda.profiler.stop() if __name__ == "__main__": diff --git a/benchmarks/bench_attention_sink_triton_sgl_context.py b/benchmarks/bench_attention_sink_triton_sgl_context.py index e7d7457852..4f18188558 100644 --- a/benchmarks/bench_attention_sink_triton_sgl_context.py +++ b/benchmarks/bench_attention_sink_triton_sgl_context.py @@ -1,27 +1,23 @@ -# bench: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +import paddle """ Memory-efficient attention for prefill. It supports page size = 1 and prefill with KV cache (i.e. extend). """ - -import torch +import numpy as np import triton import triton.language as tl -import numpy as np from flashinfer.testing.utils import bench_gpu_time _is_cuda = True if _is_cuda: - CUDA_CAPABILITY = torch.cuda.get_device_capability() - + CUDA_CAPABILITY = paddle.device.cuda.get_device_capability() _is_hip = False @triton.jit def tanh(x): - # Tanh is just a scaled sigmoid return 2 * tl.sigmoid(2 * x) - 1 @@ -72,34 +68,26 @@ def _fwd_kernel( cur_head = tl.program_id(1) cur_block_m = tl.program_id(2) cur_kv_head = cur_head // kv_group_num - cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq) cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq) cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend - if USE_CUSTOM_MASK: cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq) - offs_d = tl.arange(0, BLOCK_DMODEL) offs_dv = tl.arange(0, BLOCK_DV) offs_m = tl.arange(0, BLOCK_M) - mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend - + mask_m = cur_block_m * BLOCK_M + offs_m < cur_seq_len_extend mask_d = offs_d < Lq mask_dv = offs_dv < Lv - offs_q = ( (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] ) - q = tl.load( - Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0 - ) - + q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None] & mask_d[None, :], other=0.0) if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) offs_qpe = ( @@ -109,18 +97,13 @@ def _fwd_kernel( + offs_dpe[None, :] ) qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0) - - # stage 1: compute scores with prefix offs_n = tl.arange(0, BLOCK_N) - acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) deno = tl.zeros([BLOCK_M], dtype=tl.float32) e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - for start_n in range(0, cur_seq_len_prefix, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - mask_n = (start_n + offs_n) < cur_seq_len_prefix - + mask_n = start_n + offs_n < cur_seq_len_prefix final_mask = mask_m[:, None] & 
mask_n[None, :] if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK: custom_mask = tl.load( @@ -129,41 +112,33 @@ def _fwd_kernel( + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len + start_n + offs_n[None, :], - mask=(mask_m[:, None] & mask_n[None, :]), + mask=mask_m[:, None] & mask_n[None, :], other=0, ) final_mask &= custom_mask if SLIDING_WINDOW_SIZE > 0: - # Add mask where q_id <= kv_id + sliding_window_size - # q_id = prefix_len + cur_m, kv_id = cur_n window_mask = ( cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None] - ) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE) + <= start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE + ) final_mask &= window_mask - SKIP_TILE = False - if (USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK) or SLIDING_WINDOW_SIZE > 0: + if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0: SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0 - if not SKIP_TILE: offs_kv_loc = tl.load( kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0, ) - - # load k in transposed way offs_buf_k = ( offs_kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[:, None] ) k = tl.load( - K_Buffer + offs_buf_k, - mask=(mask_n[None, :]) & (mask_d[:, None]), - other=0.0, + K_Buffer + offs_buf_k, mask=mask_n[None, :] & mask_d[:, None], other=0.0 ) - qk = tl.dot(q.to(k.dtype), k) if BLOCK_DPE > 0: offs_kpe = ( @@ -171,27 +146,18 @@ def _fwd_kernel( + cur_kv_head * stride_buf_kh + offs_dpe[:, None] ) - kpe = tl.load( - K_Buffer + offs_kpe, - mask=mask_n[None, :], - other=0.0, - ) + kpe = tl.load(K_Buffer + offs_kpe, mask=mask_n[None, :], other=0.0) qk += tl.dot(qpe.to(kpe.dtype), kpe) qk *= sm_scale - if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) - qk = tl.where(final_mask, qk, float("-inf")) - row_max = tl.max(qk, 1) row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max) n_e_max = tl.maximum(row_max_fixed, e_max) - re_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max[:, None]) deno = deno * re_scale + tl.sum(p, 1) - offs_buf_v = ( offs_kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh @@ -204,11 +170,7 @@ def _fwd_kernel( ) p = p.to(v.dtype) acc = acc * re_scale[:, None] + tl.dot(p, v) - e_max = n_e_max - - # stage 2: compute the triangle part - cur_block_m_end = ( cur_seq_len_extend if not IS_CAUSAL @@ -216,8 +178,7 @@ def _fwd_kernel( ) for start_n in range(0, cur_block_m_end, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - mask_n = (start_n + offs_n) < cur_block_m_end - + mask_n = start_n + offs_n < cur_block_m_end final_mask = mask_m[:, None] & mask_n[None, :] if USE_CUSTOM_MASK: custom_mask = tl.load( @@ -227,43 +188,38 @@ def _fwd_kernel( + cur_seq_len_prefix + start_n + offs_n[None, :], - mask=(mask_m[:, None] & mask_n[None, :]), + mask=mask_m[:, None] & mask_n[None, :], other=0, ) custom_mask &= mask_m[:, None] & mask_n[None, :] final_mask &= custom_mask elif IS_CAUSAL: - mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= ( - start_n + offs_n[None, :] + mask_causual = ( + cur_block_m * BLOCK_M + offs_m[:, None] >= start_n + offs_n[None, :] ) mask_causual &= mask_m[:, None] & mask_n[None, :] final_mask &= mask_causual else: mask_non_causal = mask_m[:, None] & mask_n[None, :] final_mask &= mask_non_causal - if SLIDING_WINDOW_SIZE > 0: - # Add mask where q_id <= kv_id + sliding_window_size - window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= ( - start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE + window_mask = ( + 
cur_block_m * BLOCK_M + offs_m[:, None] + <= start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE ) final_mask &= window_mask - SKIP_TILE = False if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0: SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0 - if not SKIP_TILE: - # load k in transposed way offs_k = ( (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] ) k = tl.load( - K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + K_Extend + offs_k, mask=mask_n[None, :] & mask_d[:, None], other=0.0 ) - qk = tl.dot(q, k, out_dtype=tl.float32) if BLOCK_DPE > 0: offs_kpe = ( @@ -271,28 +227,18 @@ def _fwd_kernel( + cur_kv_head * stride_kh + offs_dpe[:, None] ) - kpe = tl.load( - K_Extend + offs_kpe, - mask=mask_n[None, :], - other=0.0, - ) + kpe = tl.load(K_Extend + offs_kpe, mask=mask_n[None, :], other=0.0) qk += tl.dot(qpe, kpe) - qk *= sm_scale - if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) - qk = tl.where(final_mask, qk, float("-inf")) - row_max = tl.max(qk, 1) row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max) n_e_max = tl.maximum(row_max_fixed, e_max) - re_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max[:, None]) deno = deno * re_scale + tl.sum(p, 1) - offs_v = ( (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs + cur_kv_head * stride_vh @@ -303,13 +249,10 @@ def _fwd_kernel( ) p = p.to(v.dtype) acc = acc * re_scale[:, None] + tl.dot(p, v) - e_max = n_e_max - if HAS_SINK: cur_sink = tl.load(sink_ptr + cur_head) deno += tl.exp(cur_sink - e_max) - offs_o = ( (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_obs @@ -356,11 +299,10 @@ def extend_attention_fwd( k_buffer, v_buffer: (prefix + extend) tensors in mem_manager """ Lq, Lk, Lv = ( - q_extend.shape[-1], - k_extend.shape[-1], - v_extend.shape[-1], + tuple(q_extend.shape)[-1], + tuple(k_extend.shape)[-1], + tuple(v_extend.shape)[-1], ) - if Lq == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 @@ -374,55 +316,43 @@ def extend_attention_fwd( BLOCK_DMODEL = triton.next_power_of_2(Lq) BLOCK_DPE = 0 BLOCK_DV = triton.next_power_of_2(Lv) - if _is_hip: - BLOCK_M, BLOCK_N = (64, 64) + BLOCK_M, BLOCK_N = 64, 64 num_warps = 4 - else: if _is_cuda and CUDA_CAPABILITY[0] >= 9: if Lq <= 256: - BLOCK_M, BLOCK_N = (128, 64) + BLOCK_M, BLOCK_N = 128, 64 else: - BLOCK_M, BLOCK_N = (32, 64) + BLOCK_M, BLOCK_N = 32, 64 elif _is_cuda and CUDA_CAPABILITY[0] >= 8: - # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K) if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6: if Lq <= 128: - BLOCK_M, BLOCK_N = (64, 128) + BLOCK_M, BLOCK_N = 64, 128 elif Lq <= 256: - BLOCK_M, BLOCK_N = (64, 64) + BLOCK_M, BLOCK_N = 64, 64 else: - BLOCK_M, BLOCK_N = (32, 32) + BLOCK_M, BLOCK_N = 32, 32 + elif Lq <= 128: + BLOCK_M, BLOCK_N = 128, 128 + elif Lq <= 256: + BLOCK_M, BLOCK_N = 64, 64 else: - if Lq <= 128: - BLOCK_M, BLOCK_N = (128, 128) - elif Lq <= 256: - BLOCK_M, BLOCK_N = (64, 64) - else: - BLOCK_M, BLOCK_N = (32, 64) + BLOCK_M, BLOCK_N = 32, 64 else: BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) - num_warps = 4 if Lk <= 64 else 8 - - sm_scale = sm_scale or 1.0 / (Lq**0.5) - batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1] - kv_group_num = q_extend.shape[1] // k_extend.shape[1] - + sm_scale = sm_scale or 1.0 / Lq**0.5 + batch_size, head_num = tuple(qo_indptr.shape)[0] - 1, tuple(q_extend.shape)[1] + kv_group_num = tuple(q_extend.shape)[1] // 
tuple(k_extend.shape)[1] USE_CUSTOM_MASK = custom_mask is not None - # Skip custom mask for prefix part SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask - HAS_SINK = sinks is not None - - grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) + grid = batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M) num_stages = 1 - extra_kargs = {} if _is_hip: extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} - _fwd_kernel[grid]( q_extend, k_extend, @@ -438,18 +368,18 @@ def extend_attention_fwd( sinks, sm_scale, kv_group_num, - q_extend.stride(0), - q_extend.stride(1), - k_extend.stride(0), - k_extend.stride(1), - v_extend.stride(0), - v_extend.stride(1), - o_extend.stride(0), - o_extend.stride(1), - k_buffer.stride(0), - k_buffer.stride(1), - v_buffer.stride(0), - v_buffer.stride(1), + q_extend.get_strides()[0], + q_extend.get_strides()[1], + k_extend.get_strides()[0], + k_extend.get_strides()[1], + v_extend.get_strides()[0], + v_extend.get_strides()[1], + o_extend.get_strides()[0], + o_extend.get_strides()[1], + k_buffer.get_strides()[0], + k_buffer.get_strides()[1], + v_buffer.get_strides()[0], + v_buffer.get_strides()[1], SLIDING_WINDOW_SIZE=sliding_window_size, logit_cap=logit_cap, BLOCK_DMODEL=BLOCK_DMODEL, @@ -473,58 +403,39 @@ def extend_attention_fwd( def bench_extend_attention_sink_triton_sgl( batch_size, seq_len, head_qo_num, head_kv_num, head_dim, bench_with_sink ): - torch.manual_seed(42) - dtype = torch.bfloat16 + paddle.seed(seed=42) + dtype = "bfloat16" device = "cuda:0" - - # Split S into prefix and extend lengths - prefill_len = seq_len // 2 # Similar to test's N_CTX // 2 - extend_len = seq_len // 4 # Make extend length smaller than prefix - - # Calculate total tokens and extend tokens + prefill_len = seq_len // 2 + extend_len = seq_len // 4 total_extend_tokens = batch_size * extend_len total_prefix_tokens = batch_size * prefill_len - - # Create query, key, value tensors for extension - q_extend = torch.randn( - total_extend_tokens, head_qo_num, head_dim, dtype=dtype, device=device + q_extend = paddle.randn( + shape=[total_extend_tokens, head_qo_num, head_dim], dtype=dtype ) - k_extend = torch.randn( - total_extend_tokens, head_kv_num, head_dim, dtype=dtype, device=device + k_extend = paddle.randn( + shape=[total_extend_tokens, head_kv_num, head_dim], dtype=dtype ) - v_extend = torch.randn( - total_extend_tokens, head_kv_num, head_dim, dtype=dtype, device=device + v_extend = paddle.randn( + shape=[total_extend_tokens, head_kv_num, head_dim], dtype=dtype ) - o_extend = torch.empty_like(q_extend) - - # Create key-value buffers for prefix - k_buffer = torch.randn( - total_prefix_tokens, head_kv_num, head_dim, dtype=dtype, device=device + o_extend = paddle.empty_like(x=q_extend) + k_buffer = paddle.randn( + shape=[total_prefix_tokens, head_kv_num, head_dim], dtype=dtype ) - v_buffer = torch.randn( - total_prefix_tokens, head_kv_num, head_dim, dtype=dtype, device=device + v_buffer = paddle.randn( + shape=[total_prefix_tokens, head_kv_num, head_dim], dtype=dtype ) - - # Create index pointers - qo_indptr = torch.arange( - 0, (batch_size + 1) * extend_len, extend_len, device=device - ).to(torch.int32) - kv_indptr = torch.arange( - 0, (batch_size + 1) * prefill_len, prefill_len, device=device - ).to(torch.int32) - kv_indices = torch.arange(0, total_prefix_tokens, device=device).to(torch.int32) - - sm_scale = 1.0 / (head_dim**0.5) - # sliding_window = 128 # From GPT-OSS config, skip for now + qo_indptr = paddle.arange( + start=0, end=(batch_size + 
1) * extend_len, step=extend_len + ).to("int32") + kv_indptr = paddle.arange( + start=0, end=(batch_size + 1) * prefill_len, step=prefill_len + ).to("int32") + kv_indices = paddle.arange(start=0, end=total_prefix_tokens).to("int32") + sm_scale = 1.0 / head_dim**0.5 sliding_window = -1 - - sink = ( - torch.randn(head_qo_num, device=device, dtype=torch.float32) - if bench_with_sink - else None - ) - - # warmup + sink = paddle.randn(shape=head_qo_num, dtype="float32") if bench_with_sink else None for _ in range(5): extend_attention_fwd( q_extend, @@ -544,9 +455,7 @@ def bench_extend_attention_sink_triton_sgl( sliding_window_size=sliding_window, sinks=sink, ) - - # benchmark - torch.cuda.synchronize() + paddle.device.synchronize() measurements = bench_gpu_time( lambda: extend_attention_fwd( q_extend, @@ -570,9 +479,9 @@ def bench_extend_attention_sink_triton_sgl( repeat_time_ms=1000, ) ms = np.median(measurements) - kv_cache_numel = k_buffer.numel() + v_buffer.numel() + kv_cache_numel = k_buffer.size + v_buffer.size io = ( - q_extend.numel() * q_extend.element_size() + q_extend.size * q_extend.element_size() + kv_cache_numel * k_buffer.element_size() ) print( @@ -582,10 +491,6 @@ def bench_extend_attention_sink_triton_sgl( print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s") -# gpt oss -# head_num = 64 -# head_dim = 64 -# head_kv_num = 8 if __name__ == "__main__": import argparse @@ -599,10 +504,7 @@ def bench_extend_attention_sink_triton_sgl( "--head_kv_num", type=int, default=8, help="Number of key/value heads" ) parser.add_argument( - "--head_qo_num", - type=int, - default=64, - help="Number of query heads", + "--head_qo_num", type=int, default=64, help="Number of query heads" ) parser.add_argument("--sink", action="store_true", help="Whether to test with sink") parser.add_argument( @@ -619,9 +521,7 @@ def bench_extend_attention_sink_triton_sgl( default=[1024, 4096, 8192], help="List of sequence lengths to test", ) - args = parser.parse_args() - for batch_size in args.batch_sizes: for seq_len in args.seq_lens: bench_extend_attention_sink_triton_sgl( diff --git a/benchmarks/bench_attention_sink_triton_sgl_decode.py b/benchmarks/bench_attention_sink_triton_sgl_decode.py index 08ada939b2..bbbef1914a 100644 --- a/benchmarks/bench_attention_sink_triton_sgl_decode.py +++ b/benchmarks/bench_attention_sink_triton_sgl_decode.py @@ -1,26 +1,25 @@ -# bench: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/triton_ops/decode_attention.py -# mypy: disable-error-code="no-redef" +import sys + + +import paddle +from flashinfer.paddle_utils import * """ Memory-efficient attention for decoding. It supports page size = 1. 
""" - -import torch +import numpy as np import triton import triton.language as tl -import numpy as np from flashinfer.testing.utils import bench_gpu_time _is_hip = False - _MIN_BLOCK_KV = 32 @triton.jit def tanh(x): - # Tanh is just a scaled sigmoid return 2 * tl.sigmoid(2 * x) - 1 @@ -56,30 +55,23 @@ def _fwd_kernel_stage1( cur_batch = tl.program_id(0) cur_head = tl.program_id(1) split_kv_id = tl.program_id(2) - cur_kv_head = cur_head // kv_group_num - offs_d = tl.arange(0, BLOCK_DMODEL) offs_dv = tl.arange(0, BLOCK_DV) mask_d = offs_d < Lk mask_dv = offs_dv < Lv - cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx kv_splits = tl.load(num_kv_splits + cur_batch) - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - kv_len_per_split = ( tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV ) split_kv_start = kv_len_per_split * split_kv_id split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) - e_max = -float("inf") e_sum = 0.0 acc = tl.zeros([BLOCK_DV], dtype=tl.float32) - if split_kv_end > split_kv_start: q = tl.load(Q + off_q, mask=mask_d, other=0.0) for start_n in range(split_kv_start, split_kv_end, BLOCK_N): @@ -96,17 +88,14 @@ def _fwd_kernel_stage1( ) k = tl.load( K_Buffer + offs_buf_k, - mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + mask=(offs_n[:, None] < split_kv_end) & mask_d[None, :], other=0.0, ) qk = tl.sum(q[None, :] * k, 1) qk *= sm_scale - if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) - qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) - offs_buf_v = ( kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh @@ -114,42 +103,29 @@ def _fwd_kernel_stage1( ) v = tl.load( V_Buffer + offs_buf_v, - mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + mask=(offs_n[:, None] < split_kv_end) & mask_dv[None, :], other=0.0, ) - n_e_max = tl.maximum(tl.max(qk, 0), e_max) re_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max) acc *= re_scale acc += tl.sum(p[:, None] * v, 0) - e_sum = e_sum * re_scale + tl.sum(p, 0) e_max = n_e_max - offs_mid_o = ( cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv ) - - tl.store( - Att_Out + offs_mid_o, - acc / e_sum, - mask=(mask_dv), - ) - + tl.store(Att_Out + offs_mid_o, acc / e_sum, mask=mask_dv) offs_mid_o_1 = ( cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os ) // Lv - - tl.store( - Att_Lse + offs_mid_o_1, - e_max + tl.log(e_sum), - ) + tl.store(Att_Lse + offs_mid_o_1, e_max + tl.log(e_sum)) def _decode_att_m_fwd( @@ -166,28 +142,22 @@ def _decode_att_m_fwd( logit_cap, ): BLOCK = 64 - # [TODO] work around SGPR limit on MI3xx if _is_hip: BLOCK = 8 MAX_KV_SPLITS = max_kv_splits - Lk = k_buffer.shape[-1] - Lv = v_buffer.shape[-1] - - batch, head_num = kv_indptr.shape[0] - 1, q.shape[1] - - grid = (batch, head_num, MAX_KV_SPLITS) - kv_group_num = q.shape[1] // k_buffer.shape[1] - + Lk = tuple(k_buffer.shape)[-1] + Lv = tuple(v_buffer.shape)[-1] + batch, head_num = tuple(kv_indptr.shape)[0] - 1, tuple(q.shape)[1] + grid = batch, head_num, MAX_KV_SPLITS + kv_group_num = tuple(q.shape)[1] // tuple(k_buffer.shape)[1] if kv_group_num == 1: num_warps = 4 else: num_warps = 2 if _is_hip: num_warps = 1 - BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DV = triton.next_power_of_2(Lv) - _fwd_kernel_stage1[grid]( q, k_buffer, @@ -198,15 +168,15 @@ def _decode_att_m_fwd( att_out, att_lse, 
num_kv_splits, - q.stride(0), - q.stride(1), - k_buffer.stride(0), - k_buffer.stride(1), - v_buffer.stride(0), - v_buffer.stride(1), - att_out.stride(0), - att_out.stride(1), - att_out.stride(2), + q.get_strides()[0], + q.get_strides()[1], + k_buffer.get_strides()[0], + k_buffer.get_strides()[1], + v_buffer.get_strides()[0], + v_buffer.get_strides()[1], + att_out.get_strides()[0], + att_out.get_strides()[1], + att_out.get_strides()[2], kv_group_num=kv_group_num, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DV=BLOCK_DV, @@ -256,8 +226,6 @@ def _fwd_grouped_kernel_stage1( cur_head_id = tl.program_id(1) cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) split_kv_id = tl.program_id(2) - - # ruff: noqa: SIM300 if BLOCK_H < kv_group_num: VALID_BLOCK_H: tl.constexpr = BLOCK_H else: @@ -265,40 +233,33 @@ def _fwd_grouped_kernel_stage1( cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H mask_h = mask_h & (cur_head < q_head_num) - offs_d = tl.arange(0, BLOCK_DMODEL) offs_dv = tl.arange(0, BLOCK_DV) mask_d = offs_d < Lk mask_dv = offs_dv < Lv - cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx kv_splits = tl.load(num_kv_splits + cur_batch) - offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] - if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) mask_dpe = offs_dpe < Lk off_qpe = ( cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] ) - kv_len_per_split = ( tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV ) split_kv_start = kv_len_per_split * split_kv_id split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) - e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) - if split_kv_end > split_kv_start: - q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) + q = tl.load(Q + offs_q, mask=mask_h[:, None] & mask_d[None, :], other=0.0) if BLOCK_DPE > 0: qpe = tl.load( - Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 + Q + off_qpe, mask=mask_h[:, None] & mask_dpe[None, :], other=0.0 ) for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) @@ -314,7 +275,7 @@ def _fwd_grouped_kernel_stage1( ) k = tl.load( K_Buffer + offs_buf_k, - mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + mask=(offs_n[None, :] < split_kv_end) & mask_d[:, None], other=0.0, ) qk = tl.dot(q, k.to(q.dtype)) @@ -326,19 +287,16 @@ def _fwd_grouped_kernel_stage1( ) kpe = tl.load( K_Buffer + offs_buf_kpe, - mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), + mask=(offs_n[None, :] < split_kv_end) & mask_dpe[:, None], other=0.0, ) qk += tl.dot(qpe, kpe.to(qpe.dtype)) qk *= sm_scale - if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) - qk = tl.where( mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf") ) - offs_buf_v = ( kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh @@ -346,43 +304,33 @@ def _fwd_grouped_kernel_stage1( ) v = tl.load( V_Buffer + offs_buf_v, - mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + mask=(offs_n[:, None] < split_kv_end) & mask_dv[None, :], other=0.0, ) - n_e_max = tl.maximum(tl.max(qk, 1), e_max) re_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max[:, None]) acc *= re_scale[:, 
None] acc += tl.dot(p.to(v.dtype), v) - e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_mid_o = ( cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv[None, :] ) - tl.store( Att_Out + offs_mid_o, acc / e_sum[:, None], - mask=(mask_h[:, None]) & (mask_dv[None, :]), + mask=mask_h[:, None] & mask_dv[None, :], ) - offs_mid_o_1 = ( cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os ) // Lv - - tl.store( - Att_Lse + offs_mid_o_1, - e_max + tl.log(e_sum), - mask=mask_h, - ) + tl.store(Att_Lse + offs_mid_o_1, e_max + tl.log(e_sum), mask=mask_h) def _decode_grouped_att_m_fwd( @@ -399,13 +347,10 @@ def _decode_grouped_att_m_fwd( logit_cap, ): BLOCK = 32 - Lk = k_buffer.shape[-1] - Lv = v_buffer.shape[-1] - - # [TODO] work around shmem limit on MI3xx + Lk = tuple(k_buffer.shape)[-1] + Lv = tuple(v_buffer.shape)[-1] if _is_hip and Lk >= 576: BLOCK = 16 - if Lk == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 @@ -416,26 +361,16 @@ def _decode_grouped_att_m_fwd( BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DPE = 0 BLOCK_DV = triton.next_power_of_2(Lv) - - batch, head_num = kv_indptr.shape[0] - 1, q.shape[1] - kv_group_num = q.shape[1] // k_buffer.shape[1] - + batch, head_num = tuple(kv_indptr.shape)[0] - 1, tuple(q.shape)[1] + kv_group_num = tuple(q.shape)[1] // tuple(k_buffer.shape)[1] BLOCK_H = 16 MAX_KV_SPLITS = max_kv_splits - grid = ( - batch, - triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), - MAX_KV_SPLITS, - ) - + grid = batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), MAX_KV_SPLITS extra_kargs = {} num_stages = 2 if _is_hip: - # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html - # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} num_stages = 1 - _fwd_grouped_kernel_stage1[grid]( q, k_buffer, @@ -446,15 +381,15 @@ def _decode_grouped_att_m_fwd( att_out, att_lse, num_kv_splits, - q.stride(0), - q.stride(1), - k_buffer.stride(0), - k_buffer.stride(1), - v_buffer.stride(0), - v_buffer.stride(1), - att_out.stride(0), - att_out.stride(1), - att_out.stride(2), + q.get_strides()[0], + q.get_strides()[1], + k_buffer.get_strides()[0], + k_buffer.get_strides()[1], + v_buffer.get_strides()[0], + v_buffer.get_strides()[1], + att_out.get_strides()[0], + att_out.get_strides()[1], + att_out.get_strides()[2], kv_group_num=kv_group_num, q_head_num=head_num, BLOCK_DMODEL=BLOCK_DMODEL, @@ -493,48 +428,38 @@ def _fwd_kernel_stage2( ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) - cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load( kv_indptr + cur_batch ) kv_splits = tl.load(num_kv_splits + cur_batch) - offs_d = tl.arange(0, BLOCK_DV) mask_d = offs_d < Lv - e_sum = 0.0 e_max = -float("inf") acc = tl.zeros([BLOCK_DV], dtype=tl.float32) - offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv kv_len_per_split = ( tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV ) - for split_kv_id in range(0, MAX_KV_SPLITS): split_kv_start = kv_len_per_split * split_kv_id split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) - if split_kv_end > split_kv_start: tv = tl.load( Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 ) tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * 
stride_mid_os // Lv) n_e_max = tl.maximum(tlogic, e_max) - old_scale = tl.exp(e_max - n_e_max) acc *= old_scale exp_logic = tl.exp(tlogic - n_e_max) acc += exp_logic * tv - e_sum = e_sum * old_scale + exp_logic e_max = n_e_max - if HAS_SINK: cur_sink = tl.load(sink_ptr + cur_head) e_sum += tl.exp(cur_sink - e_max) - tl.store( O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / e_sum, @@ -543,30 +468,17 @@ def _fwd_kernel_stage2( def _decode_softmax_reducev_fwd( - logits, - lse, - q, - o, - v_buffer, - kv_indptr, - num_kv_splits, - max_kv_splits, - sinks=None, + logits, lse, q, o, v_buffer, kv_indptr, num_kv_splits, max_kv_splits, sinks=None ): - batch, head_num = q.shape[0], q.shape[1] - Lv = v_buffer.shape[-1] + batch, head_num = tuple(q.shape)[0], tuple(q.shape)[1] + Lv = tuple(v_buffer.shape)[-1] BLOCK_DV = triton.next_power_of_2(Lv) - MAX_KV_SPLITS = max_kv_splits HAS_SINK = sinks is not None - extra_kargs = {} if _is_hip: - # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html - # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} - - grid = (batch, head_num) + grid = batch, head_num _fwd_kernel_stage2[grid]( logits, lse, @@ -574,11 +486,11 @@ def _decode_softmax_reducev_fwd( kv_indptr, num_kv_splits, sinks, - logits.stride(0), - logits.stride(1), - logits.stride(2), - o.stride(0), - o.stride(1), + logits.get_strides()[0], + logits.get_strides()[1], + logits.get_strides()[2], + o.get_strides()[0], + o.get_strides()[1], MAX_KV_SPLITS=MAX_KV_SPLITS, MIN_BLOCK_KV=_MIN_BLOCK_KV, BLOCK_DV=BLOCK_DV, @@ -687,14 +599,11 @@ def decode_attention_fwd( logit_cap=0.0, sinks=None, ): - assert max_kv_splits == attn_logits.shape[2] - assert q.shape[0] <= kv_indptr.shape[0] - 1 - assert q.shape[0] <= attn_logits.shape[0] - - kv_group_num = q.shape[1] // v_buffer.shape[1] - + assert max_kv_splits == tuple(attn_logits.shape)[2] + assert tuple(q.shape)[0] <= tuple(kv_indptr.shape)[0] - 1 + assert tuple(q.shape)[0] <= tuple(attn_logits.shape)[0] + kv_group_num = tuple(q.shape)[1] // tuple(v_buffer.shape)[1] if kv_group_num == 1: - # MHA decode_attention_fwd_normal( q, k_buffer, @@ -711,7 +620,6 @@ def decode_attention_fwd( sinks=sinks, ) else: - # GQA/MQA/MLA decode_attention_fwd_grouped( q, k_buffer, @@ -732,52 +640,29 @@ def decode_attention_fwd( def bench_decode_attention_sink_triton_sgl( batch_size, seq_len, head_qo_num, head_kv_num, head_dim, bench_with_sink ): - torch.manual_seed(42) + paddle.seed(seed=42) device = "cuda:0" - - dtype = torch.bfloat16 + dtype = "bfloat16" total_tokens = batch_size * seq_len - device = torch.device("cuda") - sm_scale = 1.0 / (head_dim**0.5) + device = device2str("cuda") + sm_scale = 1.0 / head_dim**0.5 max_kv_splits = 8 - num_kv_splits = torch.full((batch_size,), 4, dtype=torch.int32, device="cuda") - - # q represents the new token being generated, one per batch - q = torch.randn(batch_size, head_qo_num, head_dim, dtype=dtype, device="cuda") - - # k_buffer and v_buffer represent all previous tokens - k_buffer = torch.randn( - total_tokens, head_kv_num, head_dim, dtype=dtype, device="cuda" + num_kv_splits = paddle.full(shape=(batch_size,), fill_value=4, dtype="int32") + q = paddle.randn(shape=[batch_size, head_qo_num, head_dim], dtype=dtype) + k_buffer = paddle.randn(shape=[total_tokens, head_kv_num, head_dim], dtype=dtype) + v_buffer = paddle.randn(shape=[total_tokens, head_kv_num, head_dim], 
dtype=dtype) + o = paddle.zeros(shape=[batch_size, head_qo_num, head_dim], dtype=dtype) + b_seq_len = paddle.full(shape=(batch_size,), fill_value=seq_len) + kv_indptr = paddle.zeros(shape=(batch_size + 1,), dtype="int32") + kv_indptr[1 : batch_size + 1] = paddle.cumsum(x=b_seq_len, axis=0) + kv_indices = paddle.arange(end=total_tokens) + attn_logits1 = paddle.empty( + shape=(batch_size, head_qo_num, max_kv_splits, head_dim), dtype="float32" ) - v_buffer = torch.randn( - total_tokens, head_kv_num, head_dim, dtype=dtype, device="cuda" + attn_lse1 = paddle.empty( + shape=(batch_size, head_qo_num, max_kv_splits, head_dim), dtype="float32" ) - - o = torch.zeros(batch_size, head_qo_num, head_dim, dtype=dtype, device="cuda") - - b_seq_len = torch.full((batch_size,), seq_len, device="cuda") - - kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") - kv_indptr[1 : batch_size + 1] = torch.cumsum(b_seq_len, dim=0) - kv_indices = torch.arange(total_tokens, device="cuda") - - attn_logits1 = torch.empty( - (batch_size, head_qo_num, max_kv_splits, head_dim), - dtype=torch.float32, - device="cuda", - ) - attn_lse1 = torch.empty( - (batch_size, head_qo_num, max_kv_splits, head_dim), - dtype=torch.float32, - device="cuda", - ) - sink = ( - torch.randn(head_qo_num, device=device, dtype=torch.float32) - if bench_with_sink - else None - ) - - # warmup + sink = paddle.randn(shape=head_qo_num, dtype="float32") if bench_with_sink else None for _ in range(5): decode_attention_fwd_grouped( q, @@ -794,8 +679,6 @@ def bench_decode_attention_sink_triton_sgl( logit_cap=0.0, sinks=sink, ) - - # benchmark measurements = bench_gpu_time( lambda: decode_attention_fwd_grouped( q, @@ -816,8 +699,8 @@ def bench_decode_attention_sink_triton_sgl( repeat_time_ms=1000, ) ms = np.median(measurements) - kv_cache_numel = k_buffer.numel() + v_buffer.numel() - io = q.numel() * q.element_size() + kv_cache_numel * k_buffer.element_size() + kv_cache_numel = k_buffer.size + v_buffer.size + io = q.size * q.element_size() + kv_cache_numel * k_buffer.element_size() print( f"batch_size={batch_size}, seq_len={seq_len}, num_qo_heads={head_qo_num}, num_kv_heads={head_kv_num}, head_dim={head_dim}" ) @@ -825,10 +708,6 @@ def bench_decode_attention_sink_triton_sgl( print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s") -# gpt oss -# head_num = 64 -# head_dim = 64 -# head_kv_num = 8 if __name__ == "__main__": import argparse @@ -842,10 +721,7 @@ def bench_decode_attention_sink_triton_sgl( "--head_kv_num", type=int, default=8, help="Number of key/value heads" ) parser.add_argument( - "--head_qo_num", - type=int, - default=64, - help="Number of query heads", + "--head_qo_num", type=int, default=64, help="Number of query heads" ) parser.add_argument("--sink", action="store_true", help="Whether to test with sink") parser.add_argument( @@ -862,9 +738,7 @@ def bench_decode_attention_sink_triton_sgl( default=[1024, 4096, 8192, 16384], help="List of sequence lengths to test", ) - args = parser.parse_args() - for batch_size in args.batch_sizes: for seq_len in args.seq_lens: bench_decode_attention_sink_triton_sgl( diff --git a/benchmarks/bench_batch_attention.py b/benchmarks/bench_batch_attention.py index e9ef9f79ae..72bb8baede 100644 --- a/benchmarks/bench_batch_attention.py +++ b/benchmarks/bench_batch_attention.py @@ -4,8 +4,8 @@ from typing import List, Sequence, Tuple import numpy as np +import paddle import pandas as pd -import torch import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -20,34 +20,26 @@ def 
run_bench( num_qo_heads: int, head_dim: int, device: int = 0, - causal: bool = True, + causal: bool = True ) -> Tuple[float, float, float, float, float]: - seq_lens = torch.tensor(kv_lens, dtype=torch.int32) - q_lens = torch.tensor(qo_lens, dtype=torch.int32) - seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int() - - q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int() - kv_indptr = torch.cat( - [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0 - ).int() + seq_lens = paddle.to_tensor(data=kv_lens, dtype="int32") + q_lens = paddle.to_tensor(data=qo_lens, dtype="int32") + seq_lens_blocks = paddle.ceil(x=seq_lens / page_block_size).astype(dtype="int32") + q_indptr = paddle.concat( + x=[paddle.to_tensor(data=[0]), paddle.cumsum(x=q_lens, axis=0)], axis=0 + ).astype(dtype="int32") + kv_indptr = paddle.concat( + x=[paddle.to_tensor(data=[0]), paddle.cumsum(x=seq_lens_blocks, axis=0)], axis=0 + ).astype(dtype="int32") num_blocks = kv_indptr[-1].item() - - q = torch.rand( - q_indptr[-1].item(), num_qo_heads, head_dim, dtype=torch.bfloat16, device=device + q = paddle.rand( + shape=[q_indptr[-1].item(), num_qo_heads, head_dim], dtype="bfloat16" ) - kv_data = torch.randn( - num_blocks, - 2, - page_block_size, - num_kv_heads, - head_dim, - dtype=torch.bfloat16, - device=device, + kv_data = paddle.randn( + shape=[num_blocks, 2, page_block_size, num_kv_heads, head_dim], dtype="bfloat16" ) - - # old wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device), + paddle.empty(shape=128 * 1024 * 1024, dtype="uint8"), kv_layout="NHD", backend="fa2", ) @@ -55,25 +47,23 @@ def run_bench( wrapper_old.plan( q_indptr.to(device), kv_indptr.to(device), - torch.arange(num_blocks, dtype=torch.int32, device=device), + paddle.arange(dtype="int32", end=num_blocks), last_page_len, num_qo_heads, num_kv_heads, head_dim, page_block_size, causal=causal, - q_data_type=torch.bfloat16, - kv_data_type=torch.bfloat16, + q_data_type="bfloat16", + kv_data_type="bfloat16", ) measurements_old = bench_gpu_time(lambda: wrapper_old.run(q, kv_data)) ms_old = np.mean(measurements_old) - - # new wrapper = flashinfer.BatchAttention(kv_layout="NHD") wrapper.plan( q_indptr.to(device), kv_indptr.to(device), - torch.arange(num_blocks, dtype=torch.int32, device=device), + paddle.arange(dtype="int32", end=num_blocks), seq_lens.to(device), num_qo_heads, num_kv_heads, @@ -81,28 +71,24 @@ def run_bench( head_dim, page_block_size, causal=causal, - q_data_type=torch.bfloat16, - kv_data_type=torch.bfloat16, + q_data_type="bfloat16", + kv_data_type="bfloat16", ) measurements_new = bench_gpu_time(lambda: wrapper.run(q, kv_data)) ms_new = np.mean(measurements_new) - - total_bytes = ( - q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size() - ) + total_bytes = q.size * q.element_size() + kv_data.size * kv_data.element_size() mem_MB = total_bytes / 1024**2 - bw_old = total_bytes / (ms_old * 1e-3) / 1024**3 - bw_new = total_bytes / (ms_new * 1e-3) / 1024**3 - + bw_old = total_bytes / (ms_old * 0.001) / 1024**3 + bw_new = total_bytes / (ms_new * 0.001) / 1024**3 return ms_old, ms_new, mem_MB, bw_old, bw_new def synthesize_seq_len_configs() -> List[List[Tuple[int, int]]]: cfgs: List[List[Tuple[int, int]]] = [ - [(8192, 1)] * 128, # decode-only - [(4096, 128)] * 4, # prefill-only - [(600, 1)] * 122 + [(10_000, 17)] * 8, # hybird - [(8192, 1)] * 127 * 2 + [(8192, 4096)] * 1, # hybrid (chunked-prefill) + [(8192, 1)] * 128, + 
[(4096, 128)] * 4, + [(600, 1)] * 122 + [(10000, 17)] * 8, + [(8192, 1)] * 127 * 2 + [(8192, 4096)] * 1, ] def _rand_case(bsz: int, lo: int, hi: int) -> List[Tuple[int, int]]: @@ -117,25 +103,21 @@ def _rand_case(bsz: int, lo: int, hi: int) -> List[Tuple[int, int]]: return out cfgs.append(_rand_case(256, 1000, 8192)) - cfgs.append(_rand_case(128, 2000, 16_000)) + cfgs.append(_rand_case(128, 2000, 16000)) return cfgs def main() -> None: np.random.seed(42) - torch.random.manual_seed(42) - + paddle.seed(seed=42) seq_len_cfgs = synthesize_seq_len_configs() - sweep = { "page_block_size": (1, 8, 16), "head_dim": (64, 128), "num_kv_heads": (4,), "num_qo_heads": (28,), } - records = [] - for cfg_id, pairs in enumerate(seq_len_cfgs, start=1): kv_lens = [p[0] for p in pairs] qo_lens = [p[1] for p in pairs] @@ -181,7 +163,6 @@ def main() -> None: }, ] ) - df = pd.DataFrame( records, columns=[ diff --git a/benchmarks/bench_batch_decode.py b/benchmarks/bench_batch_decode.py index 81413e1a26..edd23ffc09 100644 --- a/benchmarks/bench_batch_decode.py +++ b/benchmarks/bench_batch_decode.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,9 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. """ - import numpy as np -import torch import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -37,27 +37,26 @@ def bench_batch_decode( kv_dtype, ): np.random.seed(42) - seq_lens = torch.full((batch_size,), seq_len) - seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int() - kv_indptr = torch.cat([torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0) - kv_indptr = kv_indptr.int() + seq_lens = paddle.full(shape=(batch_size,), fill_value=seq_len) + seq_lens_blocks = paddle.ceil(x=seq_lens / page_block_size).astype(dtype="int32") + kv_indptr = paddle.concat( + x=[paddle.to_tensor(data=[0]), paddle.cumsum(x=seq_lens_blocks, axis=0)], axis=0 + ) + kv_indptr = kv_indptr.astype(dtype="int32") last_page_len = seq_lens - (seq_lens_blocks - 1) * page_block_size - last_page_len = last_page_len.int() + last_page_len = last_page_len.astype(dtype="int32") num_blocks = kv_indptr[-1].item() - - q = torch.rand(batch_size, num_qo_heads, head_dim, dtype=q_dtype, device="cuda:0") - kv_data = torch.randn( - num_blocks, 2, page_block_size, num_kv_heads, head_dim, device="cuda:0" + q = paddle.rand(shape=[batch_size, num_qo_heads, head_dim], dtype=q_dtype) + kv_data = paddle.randn( + shape=[num_blocks, 2, page_block_size, num_kv_heads, head_dim] ).to(kv_dtype) - workspace_buffer = torch.empty( - 128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0" - ) + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="uint8") wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout="NHD", use_tensor_cores=True ) wrapper.plan( kv_indptr.to(0), - torch.arange(num_blocks).int().to(0), + paddle.arange(end=num_blocks).astype(dtype="int32").to(0), last_page_len.to(0), num_qo_heads, num_kv_heads, @@ -66,11 +65,9 @@ def bench_batch_decode( data_type=kv_dtype, q_data_type=q_dtype, ) - measurements = bench_gpu_time(lambda: wrapper.run(q, kv_data)) ms = np.median(measurements) - - io = q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size() + io = q.size * q.element_size() + kv_data.size * kv_data.element_size() print( f"batch_size={batch_size}, seq_len={seq_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, page_block_size={page_block_size}, 
q_dtype={q_dtype}, kv_dtype={kv_dtype}" ) @@ -79,8 +76,8 @@ def bench_batch_decode( if __name__ == "__main__": - for q_dtype in [torch.bfloat16]: - for kv_dtype in [torch.bfloat16, torch.float8_e4m3fn]: + for q_dtype in ["bfloat16"]: + for kv_dtype in ["bfloat16", paddle.float8_e4m3fn]: for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]: for seq_len in [512, 1024, 2048, 4096, 8192, 16384]: bench_batch_decode( diff --git a/benchmarks/bench_blackwell_attention.py b/benchmarks/bench_blackwell_attention.py index e263862e73..8f82b8c1dd 100644 --- a/benchmarks/bench_blackwell_attention.py +++ b/benchmarks/bench_blackwell_attention.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,40 +15,24 @@ See the License for the specific language governing permissions and limitations under the License. """ - import numpy as np -import torch import flashinfer from flashinfer.testing.utils import bench_gpu_time -def bench_fmha_blackwell( - batch_size, - qkv_len, - num_heads, - head_dim, - causal, - dtype, -): - q = torch.randn( - batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" - ) - k = torch.randn( - batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" - ) - v = torch.randn( - batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" - ) - +def bench_fmha_blackwell(batch_size, qkv_len, num_heads, head_dim, causal, dtype): + q = paddle.randn(shape=[batch_size * qkv_len, num_heads, head_dim], dtype=dtype) + k = paddle.randn(shape=[batch_size * qkv_len, num_heads, head_dim], dtype=dtype) + v = paddle.randn(shape=[batch_size * qkv_len, num_heads, head_dim], dtype=dtype) qo_segment_offsets = ( - torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len + paddle.arange(start=0, end=batch_size + 1, dtype="int32") * qkv_len ) kv_segment_offsets = ( - torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len + paddle.arange(start=0, end=batch_size + 1, dtype="int32") * qkv_len ) wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( - torch.empty(128 * 1024 * 1024, dtype=dtype, device="cuda"), + paddle.empty(shape=128 * 1024 * 1024, dtype=dtype), kv_layout="NHD", backend="cutlass", ) @@ -63,17 +49,33 @@ def bench_fmha_blackwell( ) _o = wrapper.run(q, k, v) measurements = bench_gpu_time( - lambda: wrapper.run(q, k, v), - dry_run_time_ms=100, - repeat_time_ms=1000, + lambda: wrapper.run(q, k, v), dry_run_time_ms=100, repeat_time_ms=1000 ) ms = np.median(measurements) def flops(ms): if causal: - return batch_size * qkv_len * qkv_len * num_heads * head_dim * 2 / ms / 1e9 + return ( + batch_size + * qkv_len + * qkv_len + * num_heads + * head_dim + * 2 + / ms + / 1000000000.0 + ) else: - return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9 + return ( + batch_size + * qkv_len + * qkv_len + * num_heads + * head_dim + * 4 + / ms + / 1000000000.0 + ) print( f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s" @@ -81,20 +83,19 @@ def flops(ms): if __name__ == "__main__": - bench_fmha_blackwell(128, 512, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(64, 1024, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(32, 2048, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(16, 4096, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(8, 8192, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(4, 16384, 32, 128, False, torch.bfloat16) - 
bench_fmha_blackwell(2, 32768, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(1, 65536, 32, 128, False, torch.bfloat16) - - bench_fmha_blackwell(128, 512, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(64, 1024, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(32, 2048, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(16, 4096, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(8, 8192, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(4, 16384, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(2, 32768, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(1, 65536, 32, 128, True, torch.bfloat16) + bench_fmha_blackwell(128, 512, 32, 128, False, "bfloat16") + bench_fmha_blackwell(64, 1024, 32, 128, False, "bfloat16") + bench_fmha_blackwell(32, 2048, 32, 128, False, "bfloat16") + bench_fmha_blackwell(16, 4096, 32, 128, False, "bfloat16") + bench_fmha_blackwell(8, 8192, 32, 128, False, "bfloat16") + bench_fmha_blackwell(4, 16384, 32, 128, False, "bfloat16") + bench_fmha_blackwell(2, 32768, 32, 128, False, "bfloat16") + bench_fmha_blackwell(1, 65536, 32, 128, False, "bfloat16") + bench_fmha_blackwell(128, 512, 32, 128, True, "bfloat16") + bench_fmha_blackwell(64, 1024, 32, 128, True, "bfloat16") + bench_fmha_blackwell(32, 2048, 32, 128, True, "bfloat16") + bench_fmha_blackwell(16, 4096, 32, 128, True, "bfloat16") + bench_fmha_blackwell(8, 8192, 32, 128, True, "bfloat16") + bench_fmha_blackwell(4, 16384, 32, 128, True, "bfloat16") + bench_fmha_blackwell(2, 32768, 32, 128, True, "bfloat16") + bench_fmha_blackwell(1, 65536, 32, 128, True, "bfloat16") diff --git a/benchmarks/bench_block_sparse_attention.py b/benchmarks/bench_block_sparse_attention.py index e2a51012f5..a4a261dfbe 100644 --- a/benchmarks/bench_block_sparse_attention.py +++ b/benchmarks/bench_block_sparse_attention.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,9 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import numpy as np -import torch import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -36,38 +36,30 @@ def bench_variable_block_sparse_attention( return if seq_len // num_blocks_col < 1: return - - # synthesize uniform block sz - block_row_sz = torch.ones(num_blocks_row, dtype=torch.int32) * ( + block_row_sz = paddle.ones(shape=num_blocks_row, dtype="int32") * ( seq_len // num_blocks_row ) - block_row_sz[-1] = seq_len - (seq_len // num_blocks_row) * (num_blocks_row - 1) - block_row_sz = block_row_sz.unsqueeze(0).repeat(num_kv_heads, 1) - - block_col_sz = torch.ones(num_blocks_col, dtype=torch.int32) * ( + block_row_sz[-1] = seq_len - seq_len // num_blocks_row * (num_blocks_row - 1) + block_row_sz = block_row_sz.unsqueeze(axis=0).tile(repeat_times=[num_kv_heads, 1]) + block_col_sz = paddle.ones(shape=num_blocks_col, dtype="int32") * ( seq_len // num_blocks_col ) - block_col_sz[-1] = seq_len - (seq_len // num_blocks_col) * (num_blocks_col - 1) - block_col_sz = block_col_sz.unsqueeze(0).repeat(num_kv_heads, 1) - + block_col_sz[-1] = seq_len - seq_len // num_blocks_col * (num_blocks_col - 1) + block_col_sz = block_col_sz.unsqueeze(axis=0).tile(repeat_times=[num_kv_heads, 1]) block_mask_map = ( - torch.rand(num_kv_heads, num_blocks_row, num_blocks_col) < block_density - ) - - q = torch.randn(num_qo_heads, seq_len, head_dim, dtype=torch.half, device="cuda") - k = torch.randn(num_kv_heads, seq_len, head_dim, dtype=torch.half, device="cuda") - v = torch.randn(num_kv_heads, seq_len, head_dim, dtype=torch.half, device="cuda") - - float_workspace_buffer = torch.empty( - 128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0" + paddle.rand(shape=[num_kv_heads, num_blocks_row, num_blocks_col]) + < block_density ) + q = paddle.randn(shape=[num_qo_heads, seq_len, head_dim], dtype="float16") + k = paddle.randn(shape=[num_kv_heads, seq_len, head_dim], dtype="float16") + v = paddle.randn(shape=[num_kv_heads, seq_len, head_dim], dtype="float16") + float_workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="uint8") sparse_wrapper_fa2 = flashinfer.sparse.VariableBlockSparseAttentionWrapper( float_workspace_buffer, backend="fa2" ) sparse_wrapper_fa3 = flashinfer.sparse.VariableBlockSparseAttentionWrapper( float_workspace_buffer, backend="fa3" ) - sparse_wrapper_fa2.plan( block_mask_map=block_mask_map, block_row_sz=block_row_sz, @@ -75,7 +67,7 @@ def bench_variable_block_sparse_attention( num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, - q_data_type=torch.half, + q_data_type="float16", ) sparse_wrapper_fa3.plan( block_mask_map=block_mask_map, @@ -84,28 +76,23 @@ def bench_variable_block_sparse_attention( num_qo_heads=num_qo_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, - q_data_type=torch.half, + q_data_type="float16", ) - - # Benchmark sparse attention with FA2 measurements_fa2 = bench_gpu_time( lambda: sparse_wrapper_fa2.run(q, k, v), dry_run_time_ms=100, repeat_time_ms=1000, ) sparse_ms_fa2 = np.median(measurements_fa2) - - # Benchmark sparse attention with FA3 measurements_fa3 = bench_gpu_time( lambda: sparse_wrapper_fa3.run(q, k, v), dry_run_time_ms=100, repeat_time_ms=1000, ) sparse_ms_fa3 = np.median(measurements_fa3) - - q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") - k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") + q = paddle.randn(shape=[seq_len, num_qo_heads, head_dim], 
dtype="float16") + k = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") dense_sm80_ms, dense_sm90_ms = ( np.median( bench_gpu_time( @@ -120,7 +107,7 @@ def bench_variable_block_sparse_attention( ) def flops(ms): - return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1000000000.0 print( f"bench_variable_block_sparse_attention (num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, seq_len={seq_len}, num_blocks_row={num_blocks_row}, num_blocks_col={num_blocks_col}, block_density={block_density}), sparse fa2-template: {flops(sparse_ms_fa2):.3f} TFLOPs/s, sparse fa3-template: {flops(sparse_ms_fa3):.3f} TFLOPs/s, dense fa2-template: {flops(dense_sm80_ms):.3f} TFLOPs/s, dense fa3-template: {flops(dense_sm90_ms):.3f} TFLOPs/s" diff --git a/benchmarks/bench_cute_dsl_blockscaled_gemm.py b/benchmarks/bench_cute_dsl_blockscaled_gemm.py index fb444b019d..d33ba39d80 100644 --- a/benchmarks/bench_cute_dsl_blockscaled_gemm.py +++ b/benchmarks/bench_cute_dsl_blockscaled_gemm.py @@ -1,22 +1,23 @@ +import sys + + import json import random + import cutlass -from flashinfer.cute_dsl.blockscaled_gemm import ( - create_scale_factor_tensor, - grouped_gemm_nt_masked, # deepgemm-like python interface for DLFW integration -) -import torch import cutlass.torch as cutlass_torch +import paddle +from flashinfer.paddle_utils import * + +from flashinfer.cute_dsl.blockscaled_gemm import (create_scale_factor_tensor, + grouped_gemm_nt_masked) from flashinfer.cute_dsl.utils import get_cutlass_dtype from flashinfer.testing.utils import bench_kineto, count_bytes - ab_dtype = "float4_e2m1fn" sf_dtype = "float8_e4m3fn" c_dtype = "bfloat16" sf_vec_size = 16 - -# DeepGEMM case a_major = "k" b_major = "k" c_major = "n" @@ -49,53 +50,37 @@ def test_func(): "Sm100BlockScaledPersistentDenseGemmKernel", suppress_kineto_output=True, ) - valid_m = data["masked_m"].sum().item() t_calibrated = t / valid_m * (expected_m_per_group * num_groups) - - tflops = 2 * valid_m * n * k / t / 1e12 + tflops = 2 * valid_m * n * k / t / 1000000000000.0 gb_per_s = ( ( count_bytes(data["a"], data["c"]) * valid_m / (max_m * num_groups) + count_bytes(data["b"]) ) - / 1e9 + / 1000000000.0 / t ) - print( - f" > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): " - f"{t * 1e6:4.0f} us | {tflops:4.0f} TFLOPS | {gb_per_s:4.0f} GB/s" + f" > Perf (num_groups={num_groups!r}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1000000.0:4.0f} us | {tflops:4.0f} TFLOPS | {gb_per_s:4.0f} GB/s" ) - metrics = dict( num_groups=num_groups, m_per_group=expected_m_per_group, valid_m=valid_m, n=n, k=k, - t_us_raw=t * 1e6, - t_us_calibrated=t_calibrated * 1e6, + t_us_raw=t * 1000000.0, + t_us_calibrated=t_calibrated * 1000000.0, tflops=tflops, gb_per_s=gb_per_s, ) print(f"MAIN_OUTPUT={json.dumps(metrics)}") -# ref: DeepGEMM def enumerate_m_grouped_masked(): max_m = 4096 - - cases = [ - # GB200 cases - (6, 1024), - (6, 512), - # DeepGEMM default cases - (1, 1024), - (2, 512), - (4, 256), - ] - # more GB200 cases + cases = [(6, 1024), (6, 512), (1, 1024), (2, 512), (4, 256)] num_experts = 288 num_experts_per_token = 8 for num_ranks in [4, 8, 16, 32, 36, 48, 72]: @@ -103,12 +88,8 @@ def enumerate_m_grouped_masked(): num_groups = num_experts // num_ranks expected_m_per_group = num_tokens * num_experts_per_token // num_groups 
cases.append((num_groups, expected_m_per_group)) - for num_groups, expected_m_per_group in cases: - for n, k in ( - (4096, 7168), - (7168, 2048), - ): + for n, k in ((4096, 7168), (7168, 2048)): yield dict( num_groups=num_groups, max_m=max_m, @@ -118,66 +99,48 @@ def enumerate_m_grouped_masked(): ) -# Copy and modified from test_cute_dsl_blockscaled_gemm.py, may extract common logic later if needed def create_data(num_groups, max_m, expected_m_per_group, n, k, device="cuda:0"): - device = torch.device(device) + device = device2str(device) l = num_groups m = max_m - a_ref = cutlass_torch.matrix(l, m, k, a_major == "m", cutlass.Float32) b_ref = cutlass_torch.matrix(l, n, k, b_major == "n", cutlass.Float32) c_ref = cutlass_torch.matrix(l, m, n, c_major == "m", cutlass.Float32) - a_tensor, a_torch = cutlass_torch.cute_tensor_like( - a_ref, - get_cutlass_dtype(ab_dtype), - is_dynamic_layout=True, - assumed_align=16, + a_ref, get_cutlass_dtype(ab_dtype), is_dynamic_layout=True, assumed_align=16 ) b_tensor, b_torch = cutlass_torch.cute_tensor_like( - b_ref, - get_cutlass_dtype(ab_dtype), - is_dynamic_layout=True, - assumed_align=16, + b_ref, get_cutlass_dtype(ab_dtype), is_dynamic_layout=True, assumed_align=16 ) c_tensor, c_torch = cutlass_torch.cute_tensor_like( - c_ref, - get_cutlass_dtype(c_dtype), - is_dynamic_layout=True, - assumed_align=16, + c_ref, get_cutlass_dtype(c_dtype), is_dynamic_layout=True, assumed_align=16 ) - - # for deepgemm-like python interface if ab_dtype == "float4_e2m1fn": - m, k, l = a_torch.shape - n, k, l = b_torch.shape - # slice into half after flatten - half_len_a = a_torch.numel() // 2 - half_len_b = b_torch.numel() // 2 + m, k, l = tuple(a_torch.shape) + n, k, l = tuple(b_torch.shape) + half_len_a = a_torch.size // 2 + half_len_b = b_torch.size // 2 a_torch = ( - a_torch.permute(2, 0, 1) + a_torch.transpose(perm=[2, 0, 1]) .flatten()[:half_len_a] .reshape(l, m, k // 2) - .permute(1, 2, 0) + .transpose(perm=[1, 2, 0]) ) b_torch = ( - b_torch.permute(2, 0, 1) + b_torch.transpose(perm=[2, 0, 1]) .flatten()[:half_len_b] .reshape(l, n, k // 2) - .permute(1, 2, 0) + .transpose(perm=[1, 2, 0]) ) - sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor( l, m, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device ) sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor( l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device ) - masked_m_tensor = create_masked_m( num_groups=num_groups, expected_m_per_group=expected_m_per_group, max_m=max_m ) - return dict( a=(a_torch, sfa_torch), b=(b_torch, sfb_torch), @@ -188,7 +151,7 @@ def create_data(num_groups, max_m, expected_m_per_group, n, k, device="cuda:0"): def create_masked_m(num_groups, expected_m_per_group, max_m): """Align with DeepGEMM :: generate_m_grouped_masked""" - masked_m = torch.empty((num_groups,), device="cuda", dtype=torch.int) + masked_m = paddle.empty(shape=(num_groups,), dtype="int32") for j in range(num_groups): masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) assert masked_m.amax().item() <= max_m @@ -196,7 +159,7 @@ def create_masked_m(num_groups, expected_m_per_group, max_m): if __name__ == "__main__": - torch.manual_seed(42) + paddle.seed(seed=42) random.seed(42) for config in enumerate_m_grouped_masked(): bench_one(**config) diff --git a/benchmarks/bench_cutlass_fused_moe.py b/benchmarks/bench_cutlass_fused_moe.py index e0dff8e215..dc5f599bf5 100644 --- a/benchmarks/bench_cutlass_fused_moe.py +++ b/benchmarks/bench_cutlass_fused_moe.py @@ -1,3 +1,9 @@ +import sys + + +import 
paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,13 +19,10 @@ See the License for the specific language governing permissions and limitations under the License. """ - import argparse import pprint import numpy as np -import torch -from torch.nn import functional as F import flashinfer.fused_moe as fused_moe from flashinfer import fp4_quantize @@ -27,28 +30,16 @@ from flashinfer.testing.utils import bench_gpu_time FLOAT4_E2M1_MAX = 6.0 -FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max - - +FLOAT8_E4M3_MAX = paddle.finfo(dtype=paddle.float8_e4m3fn).max test_configs = [ - { - "hidden_size": 7168, - "num_experts": 256, - "top_k": 8, - "intermediate_size": 256, - }, - { - "hidden_size": 7168, - "num_experts": 32, - "top_k": 8, - "intermediate_size": 2048, - }, + {"hidden_size": 7168, "num_experts": 256, "top_k": 8, "intermediate_size": 256}, + {"hidden_size": 7168, "num_experts": 32, "top_k": 8, "intermediate_size": 2048}, ] def compute_routing( - router_logits: torch.Tensor, top_k: int -) -> tuple[torch.Tensor, torch.Tensor]: + router_logits: paddle.Tensor, top_k: int +) -> tuple[paddle.Tensor, paddle.Tensor]: """ Compute routing weights and selected experts from router logits. @@ -61,112 +52,95 @@ def compute_routing( - routing_weights: Expert weights of shape [batch_size, top_k] - selected_experts: Expert indices of shape [batch_size, top_k] """ - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.float() + routing_weights = paddle.nn.functional.softmax( + x=router_logits, axis=1, dtype="float32" + ) + routing_weights, selected_experts = paddle.topk(x=routing_weights, k=top_k, axis=-1) + routing_weights /= routing_weights.sum(axis=-1, keepdim=True) + routing_weights = routing_weights.astype(dtype="float32") return routing_weights, selected_experts def bench_cutlass_fused_moe( - batch_size, - hidden_size, - num_experts, - top_k, - intermediate_size, - skip_autotune, + batch_size, hidden_size, num_experts, top_k, intermediate_size, skip_autotune ): - torch.manual_seed(42) + paddle.seed(seed=42) quant_blocksize = 16 round_up = lambda x, y: (x + y - 1) // y * y e = num_experts m = batch_size n = intermediate_size k = hidden_size - otype = torch.bfloat16 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10 - w1_cutlass = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous() - + otype = "bfloat16" + w1 = paddle.randn(shape=(e, 2 * n, k), dtype=otype) / 10 + w1_cutlass = paddle.concat(x=(w1[:, n:, :], w1[:, :n, :]), axis=1).contiguous() sf_w1_2n = round_up(2 * n, 128) sf_w1_k = round_up(k // quant_blocksize, 4) - w1_blockscale = torch.empty( - (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + w1_blockscale = paddle.empty( + shape=(e, sf_w1_2n, sf_w1_k), dtype=paddle.float8_e4m3fn ) - w1_blockscale_cutlass = torch.empty( - (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + w1_blockscale_cutlass = paddle.empty( + shape=(e, sf_w1_2n, sf_w1_k), dtype=paddle.float8_e4m3fn ) - - w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10 + w2 = paddle.randn(shape=(e, k, n), dtype=otype) / 10 sf_w2_k = round_up(k, 128) sf_w2_n = round_up(n // quant_blocksize, 4) - w2_blockscale = torch.empty( - (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn - ) - w1_q = torch.empty((e, 2 * n, k // 2), 
device="cuda", dtype=torch.uint8) - w1_q_cutlass = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) - w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) - w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32) - w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32) - + w2_blockscale = paddle.empty(shape=(e, sf_w2_k, sf_w2_n), dtype=paddle.float8_e4m3fn) + w1_q = paddle.empty(shape=(e, 2 * n, k // 2), dtype="uint8") + w1_q_cutlass = paddle.empty(shape=(e, 2 * n, k // 2), dtype="uint8") + w2_q = paddle.empty(shape=(e, k, n // 2), dtype="uint8") + w1_gs = paddle.empty(shape=(e,), dtype="float32") + w2_gs = paddle.empty(shape=(e,), dtype="float32") for expert in range(e): - w1_amax = torch.abs(w1).max().to(torch.float32) - w2_amax = torch.abs(w2).max().to(torch.float32) + w1_amax = paddle.abs(x=w1)._max().to("float32") + w2_amax = paddle.abs(x=w2)._max().to("float32") w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax - w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert]) - w1_q_cutlass[expert], w1_blockscale_cutlass[expert] = fp4_quantize( w1_cutlass[expert], w1_gs[expert] ) - w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert]) - - x = torch.randn(m, k, dtype=otype).cuda() - a1_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(x).max().to( - torch.float32 - ).cuda() - a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) - a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) - router_logits = torch.randn(m, e, dtype=otype).cuda() + x = paddle.randn(shape=[m, k], dtype=otype).cuda() + a1_gs = ( + FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / paddle.abs(x=x)._max().to("float32").cuda() + ) + a1_gs = paddle.to_tensor(data=1.0, dtype="float32", place="gpu") + a2_gs = paddle.to_tensor(data=1.0, dtype="float32", place="gpu") + router_logits = paddle.randn(shape=[m, e], dtype=otype).cuda() routing_weights, selected_experts = compute_routing(router_logits, top_k) - - flash_output = torch.zeros_like(x) - + flash_output = paddle.zeros_like(x=x) quant_scales = [ a1_gs, - w1_blockscale.view(torch.int32), + w1_blockscale.view("int32"), 1.0 / (a1_gs * w1_gs), a2_gs, - w2_blockscale.view(torch.int32), + w2_blockscale.view("int32"), 1.0 / (a2_gs * w2_gs), ] hidden_states = x hidden_states, input_sf = fp4_quantize(x, a1_gs) - - # Warmup for _ in range(3): _ = fused_moe.cutlass_fused_moe( hidden_states, - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, - w1_q.contiguous().view(torch.long), - w2_q.contiguous().view(torch.long), + w1_q.contiguous().view("int64"), + w2_q.contiguous().view("int64"), otype, quant_scales=quant_scales, input_sf=input_sf, output=flash_output, tune_max_num_tokens=16384, ) - if not skip_autotune: - with torch.inference_mode(), autotune(True): + with paddle.no_grad(), autotune(True): _ = fused_moe.cutlass_fused_moe( hidden_states, - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, - w1_q.contiguous().view(torch.long), - w2_q.contiguous().view(torch.long), + w1_q.contiguous().view("int64"), + w2_q.contiguous().view("int64"), otype, quant_scales=quant_scales, input_sf=input_sf, @@ -176,20 +150,20 @@ def bench_cutlass_fused_moe( ms_list = bench_gpu_time( lambda: fused_moe.cutlass_fused_moe( hidden_states, - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, - w1_q.contiguous().view(torch.long), - 
w2_q.contiguous().view(torch.long), + w1_q.contiguous().view("int64"), + w2_q.contiguous().view("int64"), otype, quant_scales=quant_scales, input_sf=input_sf, output=flash_output, - ), + ) ) median_ms = np.median(ms_list) print(f"{'input':<15} {'weight1':<20} {'weight2':<20} {'time(ms)'}") print( - f"{str(tuple(hidden_states.shape)):<15} {str(tuple(w1.shape)):<20} {str(tuple(w2.shape)):<20} {median_ms:.3f}" + f"{str(tuple(tuple(hidden_states.shape))):<15} {str(tuple(tuple(w1.shape))):<20} {str(tuple(tuple(w2.shape))):<20} {median_ms:.3f}" ) @@ -206,7 +180,6 @@ def bench_cutlass_fused_moe( parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning") args = parser.parse_args() AutoTuner.get().clear_cache() - for config in test_configs: bench_cutlass_fused_moe( args.num_tokens, @@ -216,11 +189,8 @@ def bench_cutlass_fused_moe( config["intermediate_size"], args.skip_autotune, ) - configs = AutoTuner.get().profiling_cache if args.update_config and configs: - # The original key contains a runner's hash in k[2] which might be different across machines. - # So, we remove it for now. v[0] and v[1] are the runner id and the tactic. converted = {str((k[0], k[1], k[3])): (v[0], v[1]) for k, v in configs.items()} config_path = get_config_path(is_module=False) with open(config_path, "w") as f: diff --git a/benchmarks/bench_deepgemm_blackwell.py b/benchmarks/bench_deepgemm_blackwell.py index ec66f22341..01d0206526 100644 --- a/benchmarks/bench_deepgemm_blackwell.py +++ b/benchmarks/bench_deepgemm_blackwell.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,41 +15,25 @@ See the License for the specific language governing permissions and limitations under the License. """ - import numpy as np -import torch -from flashinfer.gemm import ( - batch_deepgemm_fp8_nt_groupwise, - group_deepgemm_fp8_nt_groupwise, -) +from flashinfer.gemm import (batch_deepgemm_fp8_nt_groupwise, + group_deepgemm_fp8_nt_groupwise) from flashinfer.testing.utils import bench_gpu_time, quantize_fp8 def bench_deepgemm_grouped_fp8_blackwell(batch_size, m, n, k, in_dtype, out_dtype): """Benchmark DeepGEMM-based grouped GEMM with FP8 quantization.""" - - # Create float32 input tensors - a_f32 = torch.randn(batch_size * m, k, device="cuda", dtype=torch.float32) - b_f32 = torch.randn(batch_size, n, k, device="cuda", dtype=torch.float32) - - # Quantize tensor A using per-token quantization + a_f32 = paddle.randn(shape=[batch_size * m, k], dtype="float32") + b_f32 = paddle.randn(shape=[batch_size, n, k], dtype="float32") a_fp8, a_scale = quantize_fp8(a_f32, (batch_size * m, k // 128), (1, 128), "K") - - # Quantize tensor B using per-block quantization b_fp8, b_scale = quantize_fp8( b_f32, (batch_size, n // 128, k // 128), (1, 128, 128), "K" ) - - # Create group assignment indices - m_indices = torch.arange( - batch_size, device="cuda", dtype=torch.int32 - ).repeat_interleave(m) - - # Pre-allocate output tensor - out = torch.empty(batch_size * m, n, device="cuda", dtype=out_dtype) - - # Benchmark the DeepGEMM function + m_indices = paddle.arange(dtype="int32", end=batch_size).repeat_interleave( + repeats=m + ) + out = paddle.empty(shape=[batch_size * m, n], dtype=out_dtype) measurements = bench_gpu_time( lambda: group_deepgemm_fp8_nt_groupwise( a_fp8, b_fp8, a_scale, b_scale, m_indices, out=out, out_dtype=out_dtype @@ -56,41 +42,34 @@ def bench_deepgemm_grouped_fp8_blackwell(batch_size, m, n, k, in_dtype, out_dtyp repeat_time_ms=1000, ) ms = np.median(measurements) - tflops_per_second = 
2 * batch_size * m * n * k * 1e-9 / ms + tflops_per_second = 2 * batch_size * m * n * k * 1e-09 / ms memory_bandwidth_per_second = ( sum( [ - _.numel() * _.element_size() + (_.size * _.element_size()) for _ in [a_fp8, b_fp8, a_scale, b_scale, m_indices, out] ] ) - * 1e-9 + * 1e-09 / ms ) print( - f"group_deepgemm_fp8_nt_groupwise batch_size={batch_size} m={m} n={n} k={k} " - f"in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s" - f"memory_bandwidth: {memory_bandwidth_per_second:.2f} TB/s" + f"group_deepgemm_fp8_nt_groupwise batch_size={batch_size} m={m} n={n} k={k} in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/smemory_bandwidth: {memory_bandwidth_per_second:.2f} TB/s" ) - return tflops_per_second def bench_deepgemm_batch_fp8_blackwell(batch_size, m, n, k, in_dtype, out_dtype): """Benchmark DeepGEMM-based batch GEMM with FP8 quantization.""" - - a = torch.randn((batch_size, m, k), device="cuda", dtype=torch.float32) - b = torch.randn((batch_size, n, k), device="cuda", dtype=torch.float32) - masked_m = torch.randint(0, m, (batch_size,), device="cuda", dtype=torch.int32) + a = paddle.randn(shape=(batch_size, m, k), dtype="float32") + b = paddle.randn(shape=(batch_size, n, k), dtype="float32") + masked_m = paddle.randint(low=0, high=m, shape=(batch_size,), dtype="int32") a_fp8, a_scale = quantize_fp8(a, (batch_size, m, k // 128), (1, 1, 128), "K") b_fp8, b_scale = quantize_fp8( b, (batch_size, n // 128, k // 128), (1, 128, 128), "K" ) - expected_m = min(int(masked_m.float().mean()) + 1, m) - - out = torch.empty((batch_size, m, n), device="cuda", dtype=out_dtype) - - # Benchmark the DeepGEMM function + expected_m = min(int(masked_m.astype(dtype="float32").mean()) + 1, m) + out = paddle.empty(shape=(batch_size, m, n), dtype=out_dtype) measurements = bench_gpu_time( lambda: batch_deepgemm_fp8_nt_groupwise( a_fp8, @@ -106,44 +85,38 @@ def bench_deepgemm_batch_fp8_blackwell(batch_size, m, n, k, in_dtype, out_dtype) repeat_time_ms=1000, ) ms = np.median(measurements) - - tflops_per_second = 2 * batch_size * m * n * k * 1e-9 / ms + tflops_per_second = 2 * batch_size * m * n * k * 1e-09 / ms memory_bandwidth_per_second = ( sum( [ - _.numel() * _.element_size() + (_.size * _.element_size()) for _ in [a_fp8, b_fp8, a_scale, b_scale, masked_m, out] ] ) - * 1e-9 + * 1e-09 / ms ) print( - f"group_deepgemm_fp8_nt_groupwise batch_size={batch_size} m={m} n={n} k={k} " - f"in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s" - f"memory_bandwidth: {memory_bandwidth_per_second:.2f} TB/s" + f"group_deepgemm_fp8_nt_groupwise batch_size={batch_size} m={m} n={n} k={k} in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/smemory_bandwidth: {memory_bandwidth_per_second:.2f} TB/s" ) - return tflops_per_second if __name__ == "__main__": print("=== DeepGEMM Grouped FP8 GEMM Benchmark ===\n") - for batch_size in [1, 4, 8, 64, 128, 256]: for m in [128, 256, 1024, 8192, 16384]: for n, k in [(128, 512), (512, 128), (4096, 7168), (7168, 2048)]: if m // batch_size < 128: continue - if m * batch_size <= 16384: # Limit total problem size + if m * batch_size <= 16384: bench_deepgemm_grouped_fp8_blackwell( - batch_size, m, n, k, torch.float8_e4m3fn, torch.bfloat16 + batch_size, m, n, k, paddle.float8_e4m3fn, "bfloat16" ) - for batch_size in [1, 4, 8, 64, 128, 256]: for m in [128, 256, 1024, 8192, 16384]: for n, k in [(128, 512), (512, 128), (4096, 7168), (7168, 2048)]: - if m * batch_size <= 16384: # Limit total problem size + if m * 
batch_size <= 16384: bench_deepgemm_batch_fp8_blackwell( - batch_size, m, n, k, torch.float8_e4m3fn, torch.bfloat16 + batch_size, m, n, k, paddle.float8_e4m3fn, "bfloat16" ) diff --git a/benchmarks/bench_deepseek_mla.py b/benchmarks/bench_deepseek_mla.py index b13fc6c2fd..0c98e5db6d 100644 --- a/benchmarks/bench_deepseek_mla.py +++ b/benchmarks/bench_deepseek_mla.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,9 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. """ - import numpy as np -import torch import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -25,27 +25,27 @@ def bench_deepseek_mla_decode(batch_size, seq_len, num_heads, backend): head_dim_ckv = 512 head_dim_kpe = 64 page_size = 1 - q_nope = torch.randn( - batch_size * 1, num_heads, head_dim_ckv, dtype=torch.half, device="cuda" - ) - q_pe = torch.zeros( - batch_size * 1, num_heads, head_dim_kpe, dtype=torch.half, device="cuda" + q_nope = paddle.randn( + shape=[batch_size * 1, num_heads, head_dim_ckv], dtype="float16" ) - ckv = torch.randn( - batch_size * seq_len, 1, head_dim_ckv, dtype=torch.half, device="cuda" + q_pe = paddle.zeros( + shape=[batch_size * 1, num_heads, head_dim_kpe], dtype="float16" ) - kpe = torch.zeros( - batch_size * seq_len, 1, head_dim_kpe, dtype=torch.half, device="cuda" - ) - sm_scale = 1.0 / ((head_dim_ckv + head_dim_kpe) ** 0.5) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + ckv = paddle.randn(shape=[batch_size * seq_len, 1, head_dim_ckv], dtype="float16") + kpe = paddle.zeros(shape=[batch_size * seq_len, 1, head_dim_kpe], dtype="float16") + sm_scale = 1.0 / (head_dim_ckv + head_dim_kpe) ** 0.5 + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8").to(0) wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( workspace_buffer, backend=backend ) - q_indptr = torch.arange(0, batch_size + 1).to(0).int() - kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * seq_len - kv_indices = torch.arange(0, batch_size * seq_len).to(0).int() - kv_lens = torch.full((batch_size,), seq_len, dtype=torch.int32).to(0) + q_indptr = paddle.arange(start=0, end=batch_size + 1).to(0).astype(dtype="int32") + kv_indptr = ( + paddle.arange(start=0, end=batch_size + 1).to(0).astype(dtype="int32") * seq_len + ) + kv_indices = ( + paddle.arange(start=0, end=batch_size * seq_len).to(0).astype(dtype="int32") + ) + kv_lens = paddle.full(shape=(batch_size,), fill_value=seq_len, dtype="int32").to(0) wrapper.plan( q_indptr, kv_indptr, @@ -55,26 +55,23 @@ def bench_deepseek_mla_decode(batch_size, seq_len, num_heads, backend): head_dim_ckv, head_dim_kpe, page_size, - False, # causal + False, sm_scale, q_nope.dtype, ckv.dtype, ) o = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False) - measurements = bench_gpu_time( lambda: wrapper.run(q_nope, q_pe, ckv, kpe), dry_run_time_ms=100, repeat_time_ms=1000, ) ms = np.median(measurements) - - io = sum([_.numel() * _.element_size() for _ in [q_nope, q_pe, ckv, kpe, o]]) + io = sum([(_.size * _.element_size()) for _ in [q_nope, q_pe, ckv, kpe, o]]) flops = 2 * batch_size * num_heads * (2 * head_dim_ckv + head_dim_kpe) * seq_len - print(f"Config: batch_size={batch_size}, seq_len={seq_len}, num_heads={num_heads}") - print(f"Memory bandwidth: {io * 1e-6 / ms:.2f} GB/s") - print(f"FLOPs: {flops * 1e-9 / ms:.2f} TFLOPs") + print(f"Memory bandwidth: {io * 1e-06 / ms:.2f} GB/s") + print(f"FLOPs: {flops * 1e-09 / ms:.2f} TFLOPs") if __name__ == 
"__main__": diff --git a/benchmarks/bench_fused_add_rmsnorm.py b/benchmarks/bench_fused_add_rmsnorm.py index d03c7605ac..dec5ff8c53 100644 --- a/benchmarks/bench_fused_add_rmsnorm.py +++ b/benchmarks/bench_fused_add_rmsnorm.py @@ -1,13 +1,13 @@ import argparse import numpy as np -import torch +import paddle import flashinfer from flashinfer.testing.utils import bench_gpu_time -@torch.inference_mode() +@paddle.no_grad() def main(): parser = argparse.ArgumentParser() parser.add_argument("--batch-sizes", nargs="+", type=int, default=[1, 19, 99, 989]) @@ -21,45 +21,37 @@ def main(): "--dtypes", nargs="+", choices=["float16", "bfloat16"], default=["float16"] ) args = parser.parse_args() - - eps = 1e-6 - - # Loop over each combination of batch_size, hidden_size, and dtype + eps = 1e-06 for batch_size in args.batch_sizes: for hidden_size in args.hidden_sizes: for dtype_str in args.dtypes: dtype = getattr(torch, dtype_str) + x = paddle.randn(shape=(batch_size, hidden_size), dtype=dtype) + residual = paddle.randn(shape=x.shape, dtype=x.dtype) + weight = paddle.randn(shape=hidden_size, dtype=dtype) - # Define tensors with the correct dtype - x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda") - residual = torch.randn_like(x) - weight = torch.randn(hidden_size, dtype=dtype, device="cuda") - - @torch.cuda.nvtx.range( +>>>>>> @torch.cuda.nvtx.range( f"fused_add_rmsnorm batch_size={batch_size}, hidden_size={hidden_size}, dtype={dtype_str}" ) def fn() -> None: flashinfer.fused_add_rmsnorm(x, residual, weight, eps) - # Run benchmarking measurements = bench_gpu_time(fn) latency_ms = np.median(measurements) throughput = ( - x.numel() * x.element_size() * 2 - + residual.numel() * residual.element_size() * 2 - + weight.numel() * weight.element_size() - ) / (latency_ms * 1e-3) + x.size * x.element_size() * 2 + + residual.size * residual.element_size() * 2 + + weight.size * weight.element_size() + ) / (latency_ms * 0.001) print( f"batch_size: {batch_size:3},", f"hidden_size: {hidden_size:5},", f"dtype: {dtype_str:8},", - f"latency: {latency_ms * 1e3:2.0f}us,", - f"throughput: {throughput * 1e-9:7.3f}GB/s", + f"latency: {latency_ms * 1000.0:2.0f}us,", + f"throughput: {throughput * 1e-09:7.3f}GB/s", ) - print("---") - - torch.cuda.profiler.stop() +>>>>>> torch.cuda.profiler.stop() if __name__ == "__main__": diff --git a/benchmarks/bench_groupwise_gemm_fp8_blackwell.py b/benchmarks/bench_groupwise_gemm_fp8_blackwell.py index 3175d78a11..3eef2c2c6f 100644 --- a/benchmarks/bench_groupwise_gemm_fp8_blackwell.py +++ b/benchmarks/bench_groupwise_gemm_fp8_blackwell.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,9 +19,7 @@ See the License for the specific language governing permissions and limitations under the License. """ - import numpy as np -import torch import triton import triton.language as tl @@ -25,20 +29,16 @@ @triton.jit def _w8a8_block_fp8_matmul( - # Pointers to inputs and output A, B, C, As, Bs, - # Shape for matmul M, N, K, - # Block size for block-wise quantization group_n, group_k, - # Stride for inputs and output stride_am, stride_ak, stride_bk, @@ -49,7 +49,6 @@ def _w8a8_block_fp8_matmul( stride_As_k, stride_Bs_k, stride_Bs_n, - # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, @@ -59,7 +58,6 @@ def _w8a8_block_fp8_matmul( product) on input tensors `A` and `B` with block-wise quantization, and store the result in output tensor `C`. 
""" - pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -67,40 +65,33 @@ def _w8a8_block_fp8_matmul( group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - + pid_m = first_pid_m + pid % group_size_m + pid_n = pid % num_pid_in_group // group_size_m offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - As_ptrs = As + offs_am * stride_As_m offs_bsn = offs_bn // group_n Bs_ptrs = Bs + offs_bsn * stride_Bs_n - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k a_s = tl.load(As_ptrs + offs_ks * stride_As_k) b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) - accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk - if C.dtype.element_ty == tl.bfloat16: c = accumulator.to(tl.bfloat16) elif C.dtype.element_ty == tl.float16: c = accumulator.to(tl.float16) else: c = accumulator.to(tl.float32) - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] @@ -109,16 +100,15 @@ def _w8a8_block_fp8_matmul( def triton_w8a8_block_fp8_matmul( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - out: torch.Tensor, -) -> torch.Tensor: - M = A.shape[0] - N, K = B.shape + A: paddle.Tensor, + B: paddle.Tensor, + As: paddle.Tensor, + Bs: paddle.Tensor, + out: paddle.Tensor, +) -> paddle.Tensor: + M = tuple(A.shape)[0] + N, K = tuple(B.shape) block_n, block_k = 128, 128 - config = { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": block_n, @@ -144,48 +134,44 @@ def grid(META): K, block_n, block_k, - A.stride(-2), - A.stride(-1), - B.stride(1), - B.stride(0), - out.stride(-2), - out.stride(-1), - As.stride(-2), - As.stride(-1), - Bs.stride(1), - Bs.stride(0), + A.get_strides()[-2], + A.get_strides()[-1], + B.get_strides()[1], + B.get_strides()[0], + out.get_strides()[-2], + out.get_strides()[-1], + As.get_strides()[-2], + As.get_strides()[-1], + Bs.get_strides()[1], + Bs.get_strides()[0], **config, ) - return out def bench_groupwise_gemm_fp8_blackwell(m, n, k, in_dtype, out_dtype): - a = torch.randn((m, k), device="cuda").to(in_dtype) - b = torch.randn((n, k), device="cuda").to(in_dtype) - a_scale = torch.rand((k // 128, m), dtype=torch.float32, device="cuda") - b_scale = torch.rand((k // 128, n // 128), dtype=torch.float32, device="cuda") - - out = torch.empty((m, n), dtype=out_dtype, device="cuda") + a = paddle.randn(shape=(m, k)).to(in_dtype) + b = paddle.randn(shape=(n, k)).to(in_dtype) + a_scale = paddle.rand(shape=(k // 128, m), dtype="float32") + b_scale = paddle.rand(shape=(k // 128, n // 128), dtype="float32") + out = paddle.empty(shape=(m, n), dtype=out_dtype) gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out) - measurements = 
bench_gpu_time( lambda: gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out) ) ms = np.median(measurements) - tflops_per_second = 2 * m * n * k * 1e-9 / ms + tflops_per_second = 2 * m * n * k * 1e-09 / ms print( f"gemm_fp8_nt_groupwise {m} {n} {k} {in_dtype} {out_dtype}: {tflops_per_second:.2f} TFLOPs/s" ) - - tl_out = torch.empty((m, n), dtype=out_dtype, device="cuda") - a_scale = a_scale.transpose(0, 1).contiguous() - b_scale = b_scale.transpose(0, 1).contiguous() + tl_out = paddle.empty(shape=(m, n), dtype=out_dtype) + a_scale = a_scale.transpose(perm=dim2perm(a_scale.ndim, 0, 1)).contiguous() + b_scale = b_scale.transpose(perm=dim2perm(b_scale.ndim, 0, 1)).contiguous() measurements = bench_gpu_time( lambda: triton_w8a8_block_fp8_matmul(a, b, a_scale, b_scale, tl_out) ) ms = np.median(measurements) - tflops_per_second = 2 * m * n * k * 1e-9 / ms + tflops_per_second = 2 * m * n * k * 1e-09 / ms print( f"triton_gemm_fp8_nt_groupwise {m} {n} {k} {in_dtype} {out_dtype}: {tflops_per_second:.2f} TFLOPs/s" ) @@ -196,5 +182,5 @@ def bench_groupwise_gemm_fp8_blackwell(m, n, k, in_dtype, out_dtype): for n in [1024, 2048, 4096, 8192]: for k in [1024, 2048, 4096, 8192]: bench_groupwise_gemm_fp8_blackwell( - m, n, k, torch.float8_e5m2, torch.bfloat16 +>>>>>> m, n, k, paddle.float8_e5m2, "bfloat16" ) diff --git a/benchmarks/bench_groupwise_grouped_gemm_fp8_blackwell.py b/benchmarks/bench_groupwise_grouped_gemm_fp8_blackwell.py index 340c41a220..c1f15c0f11 100644 --- a/benchmarks/bench_groupwise_grouped_gemm_fp8_blackwell.py +++ b/benchmarks/bench_groupwise_grouped_gemm_fp8_blackwell.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,9 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import numpy as np -import torch import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -24,22 +24,15 @@ def bench_groupwise_grouped_gemm_fp8_blackwell( batch_size, m, n, k, in_dtype, out_dtype ): - torch.random.manual_seed(0) - a = torch.randn(batch_size * m, k, device="cuda:0").to(in_dtype) - b = torch.randn(batch_size, n, k, device="cuda:0").to(in_dtype) - out = torch.empty(batch_size * m, n, device="cuda:0", dtype=out_dtype) - - a_scale = torch.randn( - (k // 128, batch_size * m), dtype=torch.float32, device="cuda:0" - ) - b_scale = torch.randn( - (batch_size, k // 128, n // 128), dtype=torch.float32, device="cuda:0" + paddle.seed(seed=0) + a = paddle.randn(shape=[batch_size * m, k]).to(in_dtype) + b = paddle.randn(shape=[batch_size, n, k]).to(in_dtype) + out = paddle.empty(shape=[batch_size * m, n], dtype=out_dtype) + a_scale = paddle.randn(shape=(k // 128, batch_size * m), dtype="float32") + b_scale = paddle.randn(shape=(batch_size, k // 128, n // 128), dtype="float32") + segment_offsets = paddle.arange( + start=0, end=(batch_size + 1) * m, step=m, dtype="int32" ) - - segment_offsets = torch.arange( - 0, (batch_size + 1) * m, m, device="cuda:0", dtype=torch.int32 - ) - measurements = bench_gpu_time( lambda: flashinfer.gemm.group_gemm_fp8_nt_groupwise( a, b, a_scale, b_scale, segment_offsets, out=out, mma_sm=2 @@ -48,7 +41,7 @@ def bench_groupwise_grouped_gemm_fp8_blackwell( repeat_time_ms=1000, ) ms = np.median(measurements) - tflops_per_second = 2 * batch_size * m * n * k * 1e-9 / ms + tflops_per_second = 2 * batch_size * m * n * k * 1e-09 / ms print( f"group_gemm_fp8_nt_groupwise batch_size={batch_size} m={m} n={n} k={k} in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s" ) @@ -60,5 +53,5 @@ def bench_groupwise_grouped_gemm_fp8_blackwell( for n in [1024, 2048, 4096, 8192]: for k in [1024, 2048, 4096, 8192]: bench_groupwise_grouped_gemm_fp8_blackwell( - batch_size, m, n, k, torch.float8_e5m2, torch.bfloat16 +>>>>>> batch_size, m, n, k, paddle.float8_e5m2, "bfloat16" ) diff --git a/benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py b/benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py index c274e65592..95f3e38466 100644 --- a/benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py +++ b/benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,11 +15,9 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - from itertools import product import numpy as np -import torch import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -26,50 +26,43 @@ def bench_groupwise_grouped_gemm_mxfp4_blackwell( group_size, m, n, k, in_dtype, out_dtype ): - torch.random.manual_seed(0) + paddle.seed(seed=0) assert n % 8 == 0 assert k % 128 == 0 tile_size = 32 alignment_sf = 128 - fp8_info = torch.finfo(in_dtype) + fp8_info = paddle.finfo(dtype=in_dtype) a = ( - torch.empty(group_size * m, k, dtype=torch.float32, device="cuda:0") - .uniform_(-fp8_info.max, fp8_info.max) + paddle.empty(shape=[group_size * m, k], dtype="float32") + .uniform_(min=-fp8_info.max, max=fp8_info.max) .to(in_dtype) ) - b = torch.randint( - 0, 256, (group_size, n, k // 2), dtype=torch.uint8, device="cuda:0" - ) - out = torch.empty(group_size * m, n, dtype=out_dtype, device="cuda:0") - - a_scale = torch.randint( - 0, - 256, - ( + b = paddle.randint(low=0, high=256, shape=(group_size, n, k // 2), dtype="uint8") + out = paddle.empty(shape=[group_size * m, n], dtype=out_dtype) + a_scale = paddle.randint( + low=0, + high=256, + shape=( (group_size * m + (alignment_sf - 1) * group_size) // alignment_sf * alignment_sf, k // tile_size, ), - dtype=torch.uint8, - device="cuda:0", + dtype="uint8", ) - b_scale = torch.randint( - 0, - 256, - ( + b_scale = paddle.randint( + low=0, + high=256, + shape=( group_size, (n + alignment_sf - 1) // alignment_sf * alignment_sf, k // tile_size, ), - dtype=torch.uint8, - device="cuda:0", + dtype="uint8", ) - - segment_offsets = torch.arange( - 0, (group_size + 1) * m, m, device="cuda:0", dtype=torch.int32 + segment_offsets = paddle.arange( + start=0, end=(group_size + 1) * m, step=m, dtype="int32" ) - ms_best = float("inf") config_best = None mma_sm_list = [1, 2] @@ -107,7 +100,7 @@ def bench_groupwise_grouped_gemm_mxfp4_blackwell( "tile_k": tile_k, "swap_ab": swap_ab, } - tflops_per_second = 2 * group_size * m * n * k * 1e-9 / ms_best + tflops_per_second = 2 * group_size * m * n * k * 1e-09 / ms_best print( f"group_gemm_mxfp4_nt_groupwise group_size={group_size} m={m} n={n} k={k} in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s" ) @@ -121,5 +114,5 @@ def bench_groupwise_grouped_gemm_mxfp4_blackwell( for n in [1024, 2048, 4096, 8192]: for k in [1024, 2048, 4096, 8192]: bench_groupwise_grouped_gemm_mxfp4_blackwell( - group_size, m, n, k, torch.float8_e4m3fn, torch.bfloat16 + group_size, m, n, k, paddle.float8_e4m3fn, "bfloat16" ) diff --git a/benchmarks/bench_hopper_attention.py b/benchmarks/bench_hopper_attention.py index 6ad2fdaa1b..3e1b3841bc 100644 --- a/benchmarks/bench_hopper_attention.py +++ b/benchmarks/bench_hopper_attention.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,9 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import numpy as np -import torch import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -23,10 +23,9 @@ def bench_single_prefill(seq_len, num_heads, causal, head_dim): num_qo_heads = num_kv_heads = num_heads - q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") - k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - + q = paddle.randn(shape=[seq_len, num_qo_heads, head_dim], dtype="float16") + k = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") sm80_ms, sm90_ms = ( np.median( bench_gpu_time( @@ -42,9 +41,9 @@ def bench_single_prefill(seq_len, num_heads, causal, head_dim): def flops(ms): if causal: - return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 + return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1000000000.0 else: - return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1000000000.0 print( f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s" @@ -53,44 +52,37 @@ def flops(ms): def bench_batch_ragged_prefill(batch_size, num_heads, seq_len, causal, head_dim): num_qo_heads = num_kv_heads = num_heads - q = torch.randn( - batch_size * seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + q = paddle.randn( + shape=[batch_size * seq_len, num_qo_heads, head_dim], dtype="float16" ) - k = torch.randn( - batch_size * seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + k = paddle.randn( + shape=[batch_size * seq_len, num_kv_heads, head_dim], dtype="float16" ) - v = torch.randn( - batch_size * seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + v = paddle.randn( + shape=[batch_size * seq_len, num_kv_heads, head_dim], dtype="float16" ) - sm80_wrapper, sm90_wrapper = ( flashinfer.BatchPrefillWithRaggedKVCacheWrapper( - torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"), + paddle.empty(shape=256 * 1024 * 1024, dtype="uint8"), kv_layout="NHD", backend=backend, ) for backend in ["fa2", "fa3"] ) - - qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() - kv_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() - + qo_indptr = paddle.arange( + start=0, end=batch_size * seq_len + 1, step=seq_len + ).astype(dtype="int32") + kv_indptr = paddle.arange( + start=0, end=batch_size * seq_len + 1, step=seq_len + ).astype(dtype="int32") for wrapper in [sm80_wrapper, sm90_wrapper]: wrapper.plan( - qo_indptr, - kv_indptr, - num_qo_heads, - num_kv_heads, - head_dim, - causal=causal, + qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, causal=causal ) - sm80_ms, sm90_ms = ( np.median( bench_gpu_time( - lambda: wrapper.run(q, k, v), - dry_run_time_ms=100, - repeat_time_ms=1000, + lambda: wrapper.run(q, k, v), dry_run_time_ms=100, repeat_time_ms=1000 ) ) for wrapper in [sm80_wrapper, sm90_wrapper] @@ -99,11 +91,25 @@ def bench_batch_ragged_prefill(batch_size, num_heads, seq_len, causal, head_dim) def flops(ms): if causal: return ( - batch_size * seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 + batch_size + * seq_len + * seq_len + * num_qo_heads + * head_dim + * 2 + / ms + / 1000000000.0 ) else: return ( - 
batch_size * seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + batch_size + * seq_len + * seq_len + * num_qo_heads + * head_dim + * 4 + / ms + / 1000000000.0 ) print( @@ -115,42 +121,35 @@ def bench_batch_paged_prefill( page_size, batch_size, num_heads, seq_len, causal, head_dim ): num_qo_heads = num_kv_heads = num_heads - q = torch.randn( - batch_size * seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + q = paddle.randn( + shape=[batch_size * seq_len, num_qo_heads, head_dim], dtype="float16" ) - k = torch.randn( - batch_size * seq_len // page_size, - page_size, - num_kv_heads, - head_dim, - dtype=torch.half, - device="cuda", + k = paddle.randn( + shape=[batch_size * seq_len // page_size, page_size, num_kv_heads, head_dim], + dtype="float16", ) - v = torch.randn( - batch_size * seq_len // page_size, - page_size, - num_kv_heads, - head_dim, - dtype=torch.half, - device="cuda", + v = paddle.randn( + shape=[batch_size * seq_len // page_size, page_size, num_kv_heads, head_dim], + dtype="float16", ) - sm80_wrapper, sm90_wrapper = ( flashinfer.BatchPrefillWithPagedKVCacheWrapper( - torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"), + paddle.empty(shape=256 * 1024 * 1024, dtype="uint8"), kv_layout="NHD", backend=backend, ) for backend in ["fa2", "fa3"] ) - - qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() - kv_indptr = torch.arange( - 0, batch_size * (seq_len // page_size) + 1, (seq_len // page_size) - ).int() - kv_indices = torch.arange(0, batch_size * (seq_len // page_size)).int() - last_page_len = torch.ones(batch_size, dtype=torch.int32) * page_size - + qo_indptr = paddle.arange( + start=0, end=batch_size * seq_len + 1, step=seq_len + ).astype(dtype="int32") + kv_indptr = paddle.arange( + start=0, end=batch_size * (seq_len // page_size) + 1, step=seq_len // page_size + ).astype(dtype="int32") + kv_indices = paddle.arange(start=0, end=batch_size * (seq_len // page_size)).astype( + dtype="int32" + ) + last_page_len = paddle.ones(shape=batch_size, dtype="int32") * page_size for wrapper in [sm80_wrapper, sm90_wrapper]: wrapper.plan( qo_indptr, @@ -160,16 +159,13 @@ def bench_batch_paged_prefill( num_qo_heads, num_kv_heads, head_dim, - page_size, # page_size + page_size, causal=causal, ) - sm80_ms, sm90_ms = ( np.median( bench_gpu_time( - lambda: wrapper.run(q, (k, v)), - dry_run_time_ms=100, - repeat_time_ms=1000, + lambda: wrapper.run(q, (k, v)), dry_run_time_ms=100, repeat_time_ms=1000 ) ) for wrapper in [sm80_wrapper, sm90_wrapper] @@ -178,11 +174,25 @@ def bench_batch_paged_prefill( def flops(ms): if causal: return ( - batch_size * seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 + batch_size + * seq_len + * seq_len + * num_qo_heads + * head_dim + * 2 + / ms + / 1000000000.0 ) else: return ( - batch_size * seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + batch_size + * seq_len + * seq_len + * num_qo_heads + * head_dim + * 4 + / ms + / 1000000000.0 ) print( @@ -191,12 +201,11 @@ def flops(ms): if __name__ == "__main__": - device_capability = torch.cuda.get_device_capability() + device_capability = paddle.device.cuda.get_device_capability() if device_capability[0] != 9: print(f"Current device capability: {device_capability}.") print("Current benchmark targets capability (9, 0). 
Returning...") exit() - bench_batch_paged_prefill(1, 128, 32, 1024, True, 128) bench_batch_paged_prefill(1, 64, 32, 2048, True, 128) bench_batch_paged_prefill(1, 32, 32, 4096, True, 128) diff --git a/benchmarks/bench_hopper_fp8_attention.py b/benchmarks/bench_hopper_fp8_attention.py index 34d71d7f9e..1f328285c7 100644 --- a/benchmarks/bench_hopper_fp8_attention.py +++ b/benchmarks/bench_hopper_fp8_attention.py @@ -1,5 +1,5 @@ import numpy as np -import torch +import paddle import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -7,10 +7,9 @@ def bench_single_prefill(seq_len, num_heads, causal, head_dim): num_qo_heads = num_kv_heads = num_heads - q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") - k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - + q = paddle.randn(shape=[seq_len, num_qo_heads, head_dim], dtype="float16") + k = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") sm80_ms, sm90_ms = ( np.median( bench_gpu_time( @@ -23,21 +22,19 @@ def bench_single_prefill(seq_len, num_heads, causal, head_dim): ) for backend in ["fa2", "fa3"] ) - - q = torch.randn( - seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" - ).to(dtype=torch.float8_e4m3fn) - k = torch.randn( - seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" - ).to(dtype=torch.float8_e4m3fn) - v = torch.randn( - seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" - ).to(dtype=torch.float8_e4m3fn) - + q = paddle.randn(shape=[seq_len, num_qo_heads, head_dim], dtype="float16").to( + dtype=paddle.float8_e4m3fn + ) + k = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16").to( + dtype=paddle.float8_e4m3fn + ) + v = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16").to( + dtype=paddle.float8_e4m3fn + ) fp8_sm90_ms = np.median( bench_gpu_time( lambda: flashinfer.single_prefill_with_kv_cache_return_lse( - q, k, v, causal=causal, backend="fa3", o_dtype=torch.half + q, k, v, causal=causal, backend="fa3", o_dtype="float16" ), dry_run_time_ms=100, repeat_time_ms=1000, @@ -46,9 +43,9 @@ def bench_single_prefill(seq_len, num_heads, causal, head_dim): def flops(ms): if causal: - return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9 + return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1000000000.0 else: - return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9 + return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1000000000.0 print( f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s" @@ -56,12 +53,11 @@ def flops(ms): if __name__ == "__main__": - device_capability = torch.cuda.get_device_capability() + device_capability = paddle.device.cuda.get_device_capability() if device_capability[0] != 9: print(f"Current device capability: {device_capability}.") print("Current benchmark targets capability (9, 0). 
Returning...") exit() - for seq_len in [4096, 8192, 16384]: for num_heads in [24, 32]: for causal in [True, False]: diff --git a/benchmarks/bench_hopper_grouped_gemm.py b/benchmarks/bench_hopper_grouped_gemm.py index d4f314cfd3..91db1f82e8 100644 --- a/benchmarks/bench_hopper_grouped_gemm.py +++ b/benchmarks/bench_hopper_grouped_gemm.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,9 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. """ - import numpy as np -import torch import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -25,43 +25,38 @@ def bench_grouped_gemm( batch_size, num_tokens_per_group, d_in, d_out, dtype, output_dtype ): np.random.seed(42) - W = torch.randn(batch_size, d_out, d_in, device="cuda:0").to(dtype) - X = torch.randn(batch_size * num_tokens_per_group, d_in, device="cuda:0").to(dtype) - Y = torch.empty( - batch_size * num_tokens_per_group, d_out, dtype=output_dtype, device="cuda:0" + W = paddle.randn(shape=[batch_size, d_out, d_in]).to(dtype) + X = paddle.randn(shape=[batch_size * num_tokens_per_group, d_in]).to(dtype) + Y = paddle.empty( + shape=[batch_size * num_tokens_per_group, d_out], dtype=output_dtype ) - - workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + workspace_buffer = paddle.empty(shape=32 * 1024 * 1024, dtype="int8") segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend="auto") - seg_indptr = torch.arange( - 0, - (batch_size + 1) * num_tokens_per_group, - num_tokens_per_group, - dtype=torch.int64, - device="cuda:0", + seg_indptr = paddle.arange( + start=0, + end=(batch_size + 1) * num_tokens_per_group, + step=num_tokens_per_group, + dtype="int64", ) - measurements = bench_gpu_time( lambda: segment_gemm.run(X, W, batch_size, True, out=Y, seg_indptr=seg_indptr) ) ms = np.median(measurements) flops = 2 * batch_size * num_tokens_per_group * d_in * d_out - print( f"Config: batch_size={batch_size}, num_tokens_per_group={num_tokens_per_group}, d_in={d_in}, d_out={d_out}, dtype={dtype}, output_dtype={output_dtype}" ) - print(f"FLOPs: {flops / ms * 1e-9:.2f} TFLOPs/s") + print(f"FLOPs: {flops / ms * 1e-09:.2f} TFLOPs/s") if __name__ == "__main__": - device_capability = torch.cuda.get_device_capability() + device_capability = paddle.device.cuda.get_device_capability() if device_capability[0] != 9: print(f"Current device capability: {device_capability}.") print("Current benchmark targets capability (9, 0). 
Returning...") exit() - - for dtype_in in [torch.float8_e4m3fn, torch.bfloat16]: - for dtype_out in [torch.bfloat16]: + for dtype_in in [paddle.float8_e4m3fn, "bfloat16"]: + for dtype_out in ["bfloat16"]: for batch_size in [1, 3, 8, 16]: for num_tokens_per_group in [32, 64, 128, 256, 512]: for d_in in [4096, 8192]: diff --git a/benchmarks/bench_mixed_attention.py b/benchmarks/bench_mixed_attention.py index 85753a71f9..3797533203 100644 --- a/benchmarks/bench_mixed_attention.py +++ b/benchmarks/bench_mixed_attention.py @@ -1,5 +1,5 @@ import numpy as np -import torch +import paddle import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -10,82 +10,75 @@ def run_bench( p_kv_lens, d_qo_lens, d_kv_lens, - # page_block_size=1, num_kv_heads=4, num_qo_heads=28, head_dim=128, device=0, causal=True, ): - # POD Attention only supports page size = 1 due to use of single prefill kernel page_block_size = 1 - seq_lens = torch.tensor(d_kv_lens + p_kv_lens, dtype=torch.int32) - q_lens = torch.tensor(d_qo_lens + p_qo_lens, dtype=torch.int32) - - seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int() + seq_lens = paddle.to_tensor(data=d_kv_lens + p_kv_lens, dtype="int32") + q_lens = paddle.to_tensor(data=d_qo_lens + p_qo_lens, dtype="int32") + seq_lens_blocks = paddle.ceil(x=seq_lens / page_block_size).astype(dtype="int32") d_seq_lens_blocks = ( - torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size - ).int() - - q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int() - kv_indptr = torch.cat( - [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0 - ).int() - d_q_indptr = torch.cat( - [torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0 - ).int() - d_kv_indptr = torch.cat( - [torch.tensor([0]), torch.cumsum(d_seq_lens_blocks, 0)], dim=0 - ).int() + paddle.to_tensor(data=d_kv_lens, dtype="int32") / page_block_size + ).astype(dtype="int32") + q_indptr = paddle.concat( + x=[paddle.to_tensor(data=[0]), paddle.cumsum(x=q_lens, axis=0)], axis=0 + ).astype(dtype="int32") + kv_indptr = paddle.concat( + x=[paddle.to_tensor(data=[0]), paddle.cumsum(x=seq_lens_blocks, axis=0)], axis=0 + ).astype(dtype="int32") + d_q_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0]), + paddle.cumsum(x=paddle.to_tensor(data=d_qo_lens), axis=0), + ], + axis=0, + ).astype(dtype="int32") + d_kv_indptr = paddle.concat( + x=[paddle.to_tensor(data=[0]), paddle.cumsum(x=d_seq_lens_blocks, axis=0)], + axis=0, + ).astype(dtype="int32") num_blocks = kv_indptr[-1].item() - - q = torch.rand(q_indptr[-1].item(), num_qo_heads, head_dim).to( - device, dtype=torch.bfloat16 + q = paddle.rand(shape=[q_indptr[-1].item(), num_qo_heads, head_dim]).to( + device, dtype="bfloat16" ) - kv_data = torch.randn(num_blocks, 2, page_block_size, num_kv_heads, head_dim).to( - device, dtype=torch.bfloat16 - ) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + kv_data = paddle.randn( + shape=[num_blocks, 2, page_block_size, num_kv_heads, head_dim] + ).to(device, dtype="bfloat16") + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="uint8") kv_layout = "NHD" - wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout=kv_layout, - backend="fa2", + workspace_buffer, kv_layout=kv_layout, backend="fa2" ) last_page_len = (seq_lens - 1) % page_block_size + 1 wrapper_old.plan( q_indptr.to(device), kv_indptr.to(device), - torch.arange(num_blocks).int().to(device), + 
paddle.arange(end=num_blocks).astype(dtype="int32").to(device), last_page_len, num_qo_heads, num_kv_heads, head_dim, page_block_size, causal=causal, - q_data_type=torch.bfloat16, - kv_data_type=torch.bfloat16, + q_data_type="bfloat16", + kv_data_type="bfloat16", ) o = wrapper_old.run(q, kv_data) measurements = bench_gpu_time(lambda: wrapper_old.run(q, kv_data)) ms_old = np.median(measurements) - if len(p_kv_lens) == 1: q_d = q[: d_q_indptr[-1]] - kv_d = kv_data[: d_kv_indptr[-1]].unbind(1) + kv_d = kv_data[: d_kv_indptr[-1]].unbind(axis=1) q_p = q[d_q_indptr[-1] :] - k_p, v_p = kv_data[d_kv_indptr[-1] :].unbind(1) - k_p, v_p = k_p.squeeze(1), v_p.squeeze(1) - kv_indices_d = torch.arange( - 0, d_kv_indptr[-1], device=device, dtype=torch.int32 - ) - + k_p, v_p = kv_data[d_kv_indptr[-1] :].unbind(axis=1) + k_p, v_p = k_p.squeeze(axis=1), v_p.squeeze(axis=1) + kv_indices_d = paddle.arange(start=0, end=d_kv_indptr[-1], dtype="int32") last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1 wrapper_pod = flashinfer.PODWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout=kv_layout, + workspace_buffer, kv_layout=kv_layout ) wrapper_pod.plan( d_kv_indptr.to(device), @@ -95,66 +88,43 @@ def run_bench( num_kv_heads=num_kv_heads, head_dim=head_dim, page_size=page_block_size, - q_data_type=torch.bfloat16, - kv_data_type=torch.bfloat16, - ) - o_p, o_d = wrapper_pod.run( - q_p, - k_p, - v_p, - q_d, - kv_data, - causal_p=causal, - ) - o_pod = torch.cat([o_d, o_p], dim=0) - # Verify output matches - torch.testing.assert_close( - o, o_pod, rtol=1e-3, atol=1e-3, msg="POD-Attention output mismatch!" + q_data_type="bfloat16", + kv_data_type="bfloat16", ) + o_p, o_d = wrapper_pod.run(q_p, k_p, v_p, q_d, kv_data, causal_p=causal) + o_pod = paddle.concat(x=[o_d, o_p], axis=0) + assert paddle.allclose( + x=o, y=o_pod, rtol=0.001, atol=0.001 + ).item(), "POD-Attention output mismatch!" 
measurements = bench_gpu_time( lambda: wrapper_pod.run( - q_p, - k_p, - v_p, - q_d, - kv_d, - causal_p=causal, - causal_d=causal, + q_p, k_p, v_p, q_d, kv_d, causal_p=causal, causal_d=causal ) ) ms_pod = np.median(measurements) print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms") if len(p_kv_lens) == 1: print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms") - total_bytes = ( - q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size() - ) - print(f"Loading memory size (MB): {total_bytes / (1024**2):.2f} MB") - - bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3) / (1024**3) - + total_bytes = q.size * q.element_size() + kv_data.size * kv_data.element_size() + print(f"Loading memory size (MB): {total_bytes / 1024 ** 2:.2f} MB") + bandwidth_old_gb_s = total_bytes / (ms_old * 0.001) / 1024**3 print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s") if len(p_kv_lens) == 1: - bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3) + bandwidth_pod_gb_s = total_bytes / (ms_pod * 0.001) / 1024**3 print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s") if __name__ == "__main__": np.random.seed(42) - torch.random.manual_seed(42) - - # Irregular sequence lengths for prefill and decode + paddle.seed(seed=42) d_q_len_configs = [[1] * 122, [1] * 128, [1] * 242, [1] * 256] d_kv_len_configs = [[600] * 122, [10000] * 128, [400] * 242, [8192] * 256] p_q_configs = [[17] * 1, [10000], [17] * 1, []] p_kv_configs = [[10000] * 1, [10000], [8192] * 1, []] - - # construct random length testcases for _ in range(1): bsz = 256 stride = 16 sparsity = 0.05 - full_kv_len = np.random.randint(1000, 8192, size=bsz) p_q_lens = [] p_kv_lens = [] @@ -171,23 +141,19 @@ def run_bench( qo_len = 1 d_q_lens.append(qo_len) d_kv_lens.append(kv_len) - p_q_configs.append(p_q_lens) p_kv_configs.append(p_kv_lens) d_q_len_configs.append(d_q_lens) d_kv_len_configs.append(d_kv_lens) - for _ in range(1): bsz = 128 stride = 16 sparsity = 0.05 - full_kv_len = np.random.randint(2000, 16000, size=bsz) p_q_lens = [] p_kv_lens = [] d_q_lens = [] d_kv_lens = [] - for i in range(bsz): if i % stride == 0: kv_len = full_kv_len[i] @@ -199,17 +165,14 @@ def run_bench( qo_len = 1 d_q_lens.append(qo_len) d_kv_lens.append(kv_len) - p_q_configs.append(p_q_lens) p_kv_configs.append(p_kv_lens) d_q_len_configs.append(d_q_lens) d_kv_len_configs.append(d_kv_lens) - page_block_size = 1 num_kv_heads = 4 num_qo_heads = 28 head_dim = 128 - for idx, (p_q_lens, p_kv_lens, d_q_len, d_kv_len) in enumerate( zip(p_q_configs, p_kv_configs, d_q_len_configs, d_kv_len_configs) ): @@ -219,7 +182,6 @@ def run_bench( p_kv_lens, d_q_len, d_kv_len, - # page_block_size=page_block_size, num_kv_heads=num_kv_heads, num_qo_heads=num_qo_heads, head_dim=head_dim, diff --git a/benchmarks/bench_persistent_gemm.py b/benchmarks/bench_persistent_gemm.py index 98a5fb8ccf..f92d8e287a 100644 --- a/benchmarks/bench_persistent_gemm.py +++ b/benchmarks/bench_persistent_gemm.py @@ -1,5 +1,5 @@ import numpy as np -import torch +import paddle import triton import flashinfer @@ -12,14 +12,14 @@ def is_cuda(): def supports_tma(): - return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + return is_cuda() and paddle.device.cuda.get_device_capability()[0] >= 9 def bench_gemm_persistent(num_sms, dtype, M, N, K, reps=1000, warmup_reps=10000): measurements = bench_gpu_time( lambda: flashinfer.triton.sm_constraint_gemm.gemm_persistent( - a=torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype), - b=torch.randn((N, K), 
device="cuda", dtype=torch.float16).to(dtype), + a=paddle.randn(shape=(M, K), dtype="float16").to(dtype), + b=paddle.randn(shape=(N, K), dtype="float16").to(dtype), alpha=1.0, beta=0.0, num_sms=num_sms, @@ -28,10 +28,7 @@ def bench_gemm_persistent(num_sms, dtype, M, N, K, reps=1000, warmup_reps=10000) repeat_time_ms=reps, ) ms = np.median(measurements) - - # matmul: 2 * M * N * K - # scale and add: 3 * M * N - flops = (2 * M * N * K + 3 * M * N) / ms / 1e9 + flops = (2 * M * N * K + 3 * M * N) / ms / 1000000000.0 print( f"GEMM Persistent | num_sms: {num_sms}, M: {M}, N: {N}, K: {K}, {dtype}: {flops:.3f} TFLOPs/s" ) @@ -40,12 +37,12 @@ def bench_gemm_persistent(num_sms, dtype, M, N, K, reps=1000, warmup_reps=10000) def bench_gemm_descriptor_persistent( num_sms, dtype, M, N, K, reps=1000, warmup_reps=10000 ): - if dtype == torch.float32: + if dtype == "float32": return measurements = bench_gpu_time( lambda: flashinfer.triton.sm_constraint_gemm.gemm_descriptor_persistent( - a=torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype), - b=torch.randn((N, K), device="cuda", dtype=torch.float16).to(dtype), + a=paddle.randn(shape=(M, K), dtype="float16").to(dtype), + b=paddle.randn(shape=(N, K), dtype="float16").to(dtype), alpha=1.0, beta=0.0, num_sms=num_sms, @@ -54,10 +51,7 @@ def bench_gemm_descriptor_persistent( repeat_time_ms=reps, ) ms = np.median(measurements) - - # matmul: 2 * M * N * K - # scale and add: 3 * M * N - flops = (2 * M * N * K + 3 * M * N) / ms / 1e9 + flops = (2 * M * N * K + 3 * M * N) / ms / 1000000000.0 print( f"GEMM Descriptor | num_sms: {num_sms}, M: {M}, N: {N}, K: {K}, {dtype}: {flops:.3f} TFLOPs/s" ) @@ -65,14 +59,8 @@ def bench_gemm_descriptor_persistent( if __name__ == "__main__": assert supports_tma() - for M, N, K in [(4096, 4096, 4096), (8192, 8192, 8192)]: - for dtype in [ - torch.float8_e4m3fn, - torch.float16, - torch.bfloat16, - torch.float32, - ]: + for dtype in [paddle.float8_e4m3fn, "float16", "bfloat16", "float32"]: for num_sms in [1, 16, 32, 64, 128, 132, 133, 256]: bench_gemm_persistent(num_sms, dtype, M, N, K) bench_gemm_descriptor_persistent(num_sms, dtype, M, N, K) diff --git a/benchmarks/bench_renorm.py b/benchmarks/bench_renorm.py index aca54c3318..a760374032 100644 --- a/benchmarks/bench_renorm.py +++ b/benchmarks/bench_renorm.py @@ -1,5 +1,5 @@ import numpy as np -import torch +import paddle import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -7,7 +7,7 @@ def normal_distribution(std): def normal_noise(shape, device): - return torch.randn(shape, device=device) * std + return paddle.randn(shape=shape) * std normal_noise.__name__ = f"normal_distribution(std={std})" return normal_noise @@ -15,17 +15,17 @@ def normal_noise(shape, device): def gumbel_distribution(beta): def gumbel_noise(shape, device): - U = torch.rand(shape, device=device) + U = paddle.rand(shape=shape) eps = 1e-20 - return torch.log(-torch.log(U + eps) + eps) / beta + return paddle.log(x=-paddle.log(x=U + eps) + eps) / beta gumbel_noise.__name__ = f"gumbel_distribution(beta={beta})" return gumbel_noise -@torch.inference_mode() +@paddle.no_grad() def main(): - torch.manual_seed(42) + paddle.seed(seed=42) print("---") print("top-p renorm") for vocab_size in [128512]: @@ -38,20 +38,18 @@ def main(): ]: for p in [0.1, 0.5, 0.9, 1.0]: logits = distrib((batch_size, vocab_size), device="cuda") - probs = torch.softmax(logits, dim=-1) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) measurements = bench_gpu_time( lambda: 
flashinfer.sampling.top_p_renorm_probs(probs, p), dry_run_time_ms=100, repeat_time_ms=1000, ) ms = np.median(measurements) - - io = (probs.numel() * probs.element_size()) * 2 - bandwidth = io * 1e-6 / ms + io = probs.size * probs.element_size() * 2 + bandwidth = io * 1e-06 / ms print( - f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1000.0:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" ) - print("---") print("top-k renorm") for vocab_size in [128512]: @@ -64,20 +62,18 @@ def main(): ]: for k in [10, 100, 1000, 5000]: logits = distrib((batch_size, vocab_size), device="cuda") - probs = torch.softmax(logits, dim=-1) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) measurements = bench_gpu_time( lambda: flashinfer.sampling.top_k_renorm_probs(probs, k), dry_run_time_ms=100, repeat_time_ms=1000, ) ms = np.median(measurements) - - io = (probs.numel() * probs.element_size()) * 2 - bandwidth = io * 1e-6 / ms + io = probs.size * probs.element_size() * 2 + bandwidth = io * 1e-06 / ms print( - f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1000.0:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" ) - print("---") print("top-k mask logits") for vocab_size in [128512]: @@ -96,11 +92,10 @@ def main(): repeat_time_ms=1000, ) ms = np.median(measurements) - - io = (logits.numel() * logits.element_size()) * 2 - bandwidth = io * 1e-6 / ms + io = logits.size * logits.element_size() * 2 + bandwidth = io * 1e-06 / ms print( - f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1000.0:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" ) diff --git a/benchmarks/bench_rope.py b/benchmarks/bench_rope.py index 235788423a..b4ea69e277 100644 --- a/benchmarks/bench_rope.py +++ b/benchmarks/bench_rope.py @@ -1,3 +1,5 @@ +import paddle + """ Benchmark RoPE for flashinfer and vLLM. vLLM installation is required to run this benchmark. 
@@ -5,22 +7,18 @@ $ pip install vllm $ python bench_rope.py """ - from typing import Optional, Tuple, Union import numpy as np -import torch -import torch.nn as nn import triton -from vllm.model_executor.layers.rotary_embedding import ( - RotaryEmbedding as vLLMRotaryEmbedding, -) +from vllm.model_executor.layers.rotary_embedding import \ + RotaryEmbedding as vLLMRotaryEmbedding from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace from flashinfer.testing.utils import bench_gpu_time -class FlashInferRotaryEmbedding(nn.Module): +class FlashInferRotaryEmbedding(paddle.nn.Layer): def __init__( self, head_size: int, @@ -28,7 +26,7 @@ def __init__( max_position_embeddings: int, base: int, is_neox_style: bool, - dtype: torch.dtype, + dtype: paddle.dtype, ) -> None: super().__init__() self.head_size = head_size @@ -37,38 +35,34 @@ def __init__( self.base = base self.is_neox_style = is_neox_style self.dtype = dtype - cache = self._compute_cos_sin_cache() - self.cos_sin_cache: torch.Tensor - self.register_buffer("cos_sin_cache", cache, persistent=False) + self.cos_sin_cache: paddle.Tensor + self.register_buffer(name="cos_sin_cache", tensor=cache, persistable=False) - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: - inv_freq = 1.0 / ( - base - ** ( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim - ) + def _compute_inv_freq(self, base: Union[int, float]) -> paddle.Tensor: + inv_freq = 1.0 / base ** ( + paddle.arange(start=0, end=self.rotary_dim, step=2, dtype="float32") + / self.rotary_dim ) return inv_freq - def _compute_cos_sin_cache(self) -> torch.Tensor: + def _compute_cos_sin_cache(self) -> paddle.Tensor: """Compute the cos and sin cache.""" inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) - - freqs = torch.einsum("i,j -> ij", t, inv_freq) + t = paddle.arange(dtype="float32", end=self.max_position_embeddings) + freqs = paddle.einsum("i,j -> ij", t, inv_freq) cos = freqs.cos() sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) + cache = paddle.concat(x=(cos, sin), axis=-1) return cache def _apply_rotary_emb( self, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, + x: paddle.Tensor, + cos: paddle.Tensor, + sin: paddle.Tensor, is_neox_style: bool, - ) -> torch.Tensor: + ) -> paddle.Tensor: """ Args: x: [num_tokens, num_heads, head_size] @@ -77,27 +71,27 @@ def _apply_rotary_emb( is_neox_style: Whether to use the Neox-style or GPT-J-style rotary positional embeddings. 
""" - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) + cos = cos.unsqueeze(axis=-2).to(x.dtype) + sin = sin.unsqueeze(axis=-2).to(x.dtype) if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) + x1, x2 = paddle.chunk(x=x, chunks=2, axis=-1) else: x1 = x[..., ::2] x2 = x[..., 1::2] o1 = x1 * cos - x2 * sin o2 = x2 * cos + x1 * sin if is_neox_style: - return torch.cat((o1, o2), dim=-1) + return paddle.concat(x=(o1, o2), axis=-1) else: - return torch.stack((o1, o2), dim=-1).flatten(-2) + return paddle.stack(x=(o1, o2), axis=-1).flatten(start_axis=-2) def forward_cuda( self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + positions: paddle.Tensor, + query: paddle.Tensor, + key: paddle.Tensor, + offsets: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: apply_rope_with_cos_sin_cache_inplace( positions=positions, query=query, @@ -142,7 +136,7 @@ def forward_cuda( "max_position_embeddings": 65536, "base": 500000, "is_neox_style": True, - "dtype": torch.bfloat16, + "dtype": "bfloat16", "device": "cuda", "batch_size": 2, "num_q_heads": 32, @@ -167,9 +161,7 @@ def benchmark( print( f"provider: {provider}, head_size: {head_size}, rotary_dim: {rotary_dim}, max_position_embeddings: {max_position_embeddings}, base: {base}, is_neox_style: {is_neox_style}, dtype: {dtype}, device: {device}, batch_size: {batch_size}, seq_len: {seq_len}, num_q_heads: {num_q_heads}, num_kv_heads: {num_kv_heads}" ) - rope_forward = None - if provider == "vllm": rope = vLLMRotaryEmbedding( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype @@ -185,22 +177,17 @@ def benchmark( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ).to(device) rope_forward = rope.forward_native - - pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) - query = torch.randn( - batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device + pos_ids = paddle.arange(end=seq_len).tile(repeat_times=batch_size) + query = paddle.randn( + shape=[batch_size * seq_len, num_q_heads * head_size], dtype=dtype ) - key = torch.randn( - batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device + key = paddle.randn( + shape=[batch_size * seq_len, num_kv_heads * head_size], dtype=dtype ) - - # Get raw measurements measurements = bench_gpu_time(lambda: rope_forward(pos_ids, query, key)) - # Calculate statistics to match original return values ms = np.median(measurements) min_ms = np.percentile(measurements, 20) max_ms = np.percentile(measurements, 80) - return ms, min_ms, max_ms diff --git a/benchmarks/bench_sampling.py b/benchmarks/bench_sampling.py index 2eb2de3875..77a4b021a6 100644 --- a/benchmarks/bench_sampling.py +++ b/benchmarks/bench_sampling.py @@ -1,5 +1,5 @@ import numpy as np -import torch +import paddle import flashinfer from flashinfer.testing.utils import bench_gpu_time @@ -7,7 +7,7 @@ def normal_distribution(std): def normal_noise(shape, device): - return torch.randn(shape, device=device) * std + return paddle.randn(shape=shape) * std normal_noise.__name__ = f"normal_distribution(std={std})" return normal_noise @@ -15,42 +15,42 @@ def normal_noise(shape, device): def gumbel_distribution(beta): def gumbel_noise(shape, device): - U = torch.rand(shape, device=device) + U = paddle.rand(shape=shape) eps = 1e-20 - return torch.log(-torch.log(U + eps) + eps) / beta + return paddle.log(x=-paddle.log(x=U + eps) + 
eps) / beta gumbel_noise.__name__ = f"gumbel_distribution(beta={beta})" return gumbel_noise def init_seed_sampling(*args, **kwargs): - torch.manual_seed(42) + paddle.seed(seed=42) return flashinfer.sampling.sampling_from_probs(*args, **kwargs) def init_seed_sampling_from_logits(*args, **kwargs): - torch.manual_seed(42) + paddle.seed(seed=42) return flashinfer.sampling.sampling_from_logits(*args, **kwargs) def init_seed_sampling_from_softmax_logits(logits, *args, **kwargs): - torch.manual_seed(42) + paddle.seed(seed=42) return flashinfer.sampling.sampling_from_probs( - torch.softmax(logits, dim=-1), *args, **kwargs + paddle.nn.functional.softmax(x=logits, axis=-1), *args, **kwargs ) def init_seed_top_k_sampling(*args, **kwargs): - torch.manual_seed(42) + paddle.seed(seed=42) return flashinfer.sampling.top_k_sampling_from_probs(*args, **kwargs) def init_seed_top_p_sampling(*args, **kwargs): - torch.manual_seed(42) + paddle.seed(seed=42) return flashinfer.sampling.top_p_sampling_from_probs(*args, **kwargs) -@torch.inference_mode() +@paddle.no_grad() def main(): print("---") print("naive sampling") @@ -64,26 +64,22 @@ def main(): ]: for deterministic in [True, False]: logits = distrib((batch_size, vocab_size), device="cuda") - probs = torch.softmax(logits, dim=-1) - samples = torch.zeros( - batch_size, dtype=torch.int32, device=probs.device - ) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) + samples = paddle.zeros(shape=batch_size, dtype="int32") measurements = bench_gpu_time( lambda: init_seed_sampling(probs, deterministic=deterministic), dry_run_time_ms=100, repeat_time_ms=1000, ) ms = np.median(measurements) - io = ( - probs.numel() * probs.element_size() - + samples.numel() * samples.element_size() + probs.size * probs.element_size() + + samples.size * samples.element_size() ) - bandwidth = io * 1e-6 / ms + bandwidth = io * 1e-06 / ms print( - f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1000.0:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" ) - print("---") print("top-k sampling") for vocab_size in [128512]: @@ -97,10 +93,8 @@ def main(): for deterministic in [True, False]: for k in [10, 100, 1000, 5000]: logits = distrib((batch_size, vocab_size), device="cuda") - probs = torch.softmax(logits, dim=-1) - samples = torch.zeros( - batch_size, dtype=torch.int32, device=probs.device - ) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) + samples = paddle.zeros(shape=batch_size, dtype="int32") measurements = bench_gpu_time( lambda: init_seed_top_k_sampling( probs, k, deterministic=deterministic @@ -109,19 +103,16 @@ def main(): repeat_time_ms=1000, ) ms = np.median(measurements) - io = ( - probs.numel() * probs.element_size() - + samples.numel() * samples.element_size() + probs.size * probs.element_size() + + samples.size * samples.element_size() ) - bandwidth = io * 1e-6 / ms + bandwidth = io * 1e-06 / ms print( - f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, k: {k}, duration: {ms * 1000.0:.2f} us, effective bandwidth: {bandwidth:.2f} 
GB/s" ) - print("---") print("top-p sampling") - for vocab_size in [128512]: for batch_size in [1, 16, 32, 64, 128, 256, 512]: for distrib in [ @@ -133,10 +124,8 @@ def main(): for deterministic in [True, False]: for p in [0.1, 0.5, 0.9]: logits = distrib((batch_size, vocab_size), device="cuda") - probs = torch.softmax(logits, dim=-1) - samples = torch.zeros( - batch_size, dtype=torch.int32, device=probs.device - ) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) + samples = paddle.zeros(shape=batch_size, dtype="int32") measurements = bench_gpu_time( lambda: init_seed_top_p_sampling( probs, p, deterministic=deterministic @@ -145,16 +134,14 @@ def main(): repeat_time_ms=1000, ) ms = np.median(measurements) - io = ( - probs.numel() * probs.element_size() - + samples.numel() * samples.element_size() + probs.size * probs.element_size() + + samples.size * samples.element_size() ) - bandwidth = io * 1e-6 / ms + bandwidth = io * 1e-06 / ms print( - f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, p: {p}, duration: {ms * 1000.0:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" ) - print("---") print("sampling from softmax(logits)") for vocab_size in [128512]: @@ -167,9 +154,7 @@ def main(): ]: for deterministic in [True, False]: logits = distrib((batch_size, vocab_size), device="cuda") - samples = torch.zeros( - batch_size, dtype=torch.int32, device=logits.device - ) + samples = paddle.zeros(shape=batch_size, dtype="int32") measurements = bench_gpu_time( lambda: init_seed_sampling_from_softmax_logits( logits, samples, deterministic=deterministic @@ -179,14 +164,13 @@ def main(): ) ms = np.median(measurements) io = ( - logits.numel() * logits.element_size() - + samples.numel() * samples.element_size() + logits.size * logits.element_size() + + samples.size * samples.element_size() ) - bandwidth = io * 1e-6 / ms + bandwidth = io * 1e-06 / ms print( - f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1000.0:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" ) - print("---") print("sampling from logits") for vocab_size in [128512]: @@ -199,9 +183,7 @@ def main(): ]: for deterministic in [True, False]: logits = distrib((batch_size, vocab_size), device="cuda") - samples = torch.zeros( - batch_size, dtype=torch.int32, device=logits.device - ) + samples = paddle.zeros(shape=batch_size, dtype="int32") measurements = bench_gpu_time( lambda: init_seed_sampling_from_logits( logits, samples, deterministic=deterministic @@ -210,14 +192,13 @@ def main(): repeat_time_ms=1000, ) ms = np.median(measurements) - io = ( - logits.numel() * logits.element_size() - + samples.numel() * samples.element_size() + logits.size * logits.element_size() + + samples.size * samples.element_size() ) - bandwidth = io * 1e-6 / ms + bandwidth = io * 1e-06 / ms print( - f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + f"vocab_size: {vocab_size}, batch_size: 
{batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1000.0:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" ) diff --git a/benchmarks/bench_trtllm_fmha.py b/benchmarks/bench_trtllm_fmha.py index 1615503b6e..4634b0c62d 100644 --- a/benchmarks/bench_trtllm_fmha.py +++ b/benchmarks/bench_trtllm_fmha.py @@ -1,35 +1,31 @@ import numpy as np -import torch +import paddle import flashinfer -from flashinfer.testing.utils import bench_gpu_time, bench_gpu_time_with_cudagraph +from flashinfer.testing.utils import (bench_gpu_time, + bench_gpu_time_with_cudagraph) page_size = 16 num_kv_heads = 4 num_qo_heads = 32 head_dim = 128 - -workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") +workspace_buffer = paddle.empty(shape=1024 * 1024 * 1024, dtype="uint8") def bench_trtllm_fmha(batch_size, seq_len, kv_cache_dtype): - torch.manual_seed(42) - seq_lens = torch.full((batch_size,), seq_len, device="cuda:0", dtype=torch.int32) - seq_lens_blocks = torch.ceil(seq_lens / page_size).int() - kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int, device="cuda:0") - kv_indptr[1:] = torch.cumsum(seq_lens_blocks, dim=0) - last_page_len = (seq_lens - (seq_lens_blocks - 1) * page_size).int() + paddle.seed(seed=42) + seq_lens = paddle.full(shape=(batch_size,), fill_value=seq_len, dtype="int32") + seq_lens_blocks = paddle.ceil(x=seq_lens / page_size).astype(dtype="int32") + kv_indptr = paddle.zeros(shape=batch_size + 1, dtype="int32") + kv_indptr[1:] = paddle.cumsum(x=seq_lens_blocks, axis=0) + last_page_len = (seq_lens - (seq_lens_blocks - 1) * page_size).astype(dtype="int32") last_page_len[last_page_len == 0] = page_size num_blocks = kv_indptr[-1].item() - kv_indices = torch.arange(num_blocks, dtype=torch.int32, device="cuda:0") - - q = torch.rand(batch_size, num_qo_heads, head_dim, device="cuda:0").to( - torch.bfloat16 + kv_indices = paddle.arange(dtype="int32", end=num_blocks) + q = paddle.rand(shape=[batch_size, num_qo_heads, head_dim]).to("bfloat16") + kv_data = paddle.randn(shape=[num_blocks, 2, num_kv_heads, page_size, head_dim]).to( + paddle.float8_e4m3fn if kv_cache_dtype == "fp8" else "float16" ) - kv_data = torch.randn( - num_blocks, 2, num_kv_heads, page_size, head_dim, device="cuda:0" - ).to(torch.float8_e4m3fn if kv_cache_dtype == "fp8" else torch.float16) - wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, "HND", backend="trtllm-gen" ) @@ -45,13 +41,11 @@ def bench_trtllm_fmha(batch_size, seq_len, kv_cache_dtype): q_data_type=q.dtype, kv_data_type=kv_data.dtype, ) - # add one warmup here wrapper.run(q, kv_data) - torch.cuda.synchronize() - + paddle.device.synchronize() measurements = bench_gpu_time(lambda: wrapper.run(q, kv_data)) ms = np.median(measurements) - io = q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size() + io = q.size * q.element_size() + kv_data.size * kv_data.element_size() print( f"batch_size={batch_size}, seq_len={seq_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, page_size={page_size}" ) @@ -59,13 +53,18 @@ def bench_trtllm_fmha(batch_size, seq_len, kv_cache_dtype): print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s") -def to_float8(x, dtype=torch.float8_e4m3fn): - finfo = torch.finfo(dtype) - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) +def to_float8(x, dtype=paddle.float8_e4m3fn): + finfo = paddle.finfo(dtype=dtype) + min_val, max_val = tuple( + [ + 
paddle.amin(x, axis=None, keepdim=False), + paddle.max(x, axis=None, keepdim=False), + ] + ) + amax = paddle.maximum(x=min_val.abs(), y=max_val.abs()).clip(min=1e-12) scale = finfo.max / amax * 0.1 - x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) - return x_scl_sat.to(dtype), scale.float().reciprocal() + x_scl_sat = (x * scale).clip(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.astype(dtype="float32").reciprocal() def bench_trtllm_fmha_wrapper( @@ -81,68 +80,41 @@ def bench_trtllm_fmha_wrapper( window_left, bench_with_sink, ): - torch.manual_seed(42) + paddle.seed(seed=42) device = "cuda:0" num_qo_heads = num_kv_heads * head_grp_size batch_size = batch_size - - # Initialize tensors num_tokens = max_seq_len * batch_size num_blocks = (num_tokens + page_size - 1) // page_size - - dtype_map = { - "half": torch.float16, - "bf16": torch.bfloat16, - "fp8": torch.float8_e4m3fn, - } - - q = torch.randn(batch_size, num_qo_heads, head_dim, device=device).to( - dtype_map[q_dtype] - ) - - # Sequence lengths and block tables - seq_lens = torch.full((batch_size,), max_seq_len) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) - blocks_per_seq = [(seq_len + page_size - 1) // page_size for seq_len in seq_lens] - - # Generate random but unique block IDs for all sequences + dtype_map = {"half": "float16", "bf16": "bfloat16", "fp8": paddle.float8_e4m3fn} + q = paddle.randn(shape=[batch_size, num_qo_heads, head_dim]).to(dtype_map[q_dtype]) + seq_lens = paddle.full(shape=(batch_size,), fill_value=max_seq_len) + seq_lens_tensor = paddle.to_tensor(data=seq_lens, dtype="int32", place=device) + blocks_per_seq = [((seq_len + page_size - 1) // page_size) for seq_len in seq_lens] total_blocks_needed = sum(blocks_per_seq) - all_block_ids = torch.randperm( - total_blocks_needed, device=device - ) # Random permutation - - kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape).to(q.dtype) - + all_block_ids = paddle.randperm(n=total_blocks_needed) + kv_cache_shape = num_blocks, 2, num_kv_heads, page_size, head_dim + kv_cache = paddle.randn(shape=kv_cache_shape).to(q.dtype) if kv_cache_dtype.startswith("fp8") and q_dtype != "fp8": kv_cache, _ = to_float8(kv_cache) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size - sinks = ( - torch.randn(num_qo_heads, device=device, dtype=torch.float32) - if bench_with_sink - else None + paddle.randn(shape=num_qo_heads, dtype="float32") if bench_with_sink else None ) - - # Compute kv_indptr as cumulative sum of blocks per sequence kv_indptr = ( - torch.cat( - [torch.tensor([0], device=device), torch.cumsum(blocks_per_seq, dim=0)] + paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=blocks_per_seq, axis=0), + ] ) - .int() + .astype(dtype="int32") .to(device) ) - - kv_indices = all_block_ids.int() - - # Calculate last page lengths + kv_indices = all_block_ids.astype(dtype="int32") kv_last_page_len = seq_lens_tensor % page_size kv_last_page_len[kv_last_page_len == 0] = page_size - - # trtllm-gen wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, "HND", backend="trtllm-gen" ) @@ -159,18 +131,15 @@ def bench_trtllm_fmha_wrapper( q_data_type=q.dtype, window_left=window_left, ) - - # add one warmup here wrapper.run(q, kv_cache, sinks=sinks) - 
torch.cuda.synchronize() - + paddle.device.synchronize() measurements = bench_gpu_time_with_cudagraph( lambda: wrapper.run(q, kv_cache, sinks=sinks), dry_run_time_ms=100, repeat_time_ms=1000, ) ms = np.median(measurements) - io = q.numel() * q.element_size() + kv_cache.numel() * kv_cache.element_size() + io = q.size * q.element_size() + kv_cache.size * kv_cache.element_size() print( f"batch_size={batch_size}, seq_len={max_seq_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, page_size={page_size}" ) @@ -212,9 +181,7 @@ def bench_trtllm_fmha_wrapper( default=[1024, 4096, 8192, 16384], help="List of sequence lengths to test", ) - args = parser.parse_args() - for batch_size in args.batch_sizes: for seq_len in args.seq_lens: bench_trtllm_fmha_wrapper( diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 01ba8e1b91..c3f4d19c90 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -1,38 +1,26 @@ +import sys + + import argparse -from typing import Optional, Literal -import torch +from typing import Literal, Optional + import numpy as np -from flashinfer import ( - RoutingMethodType, - GatedActType, - fp4_quantize, - mxfp8_quantize, - next_positive_power_of_2, -) -from flashinfer.fused_moe import trtllm_fp4_block_scale_moe +import paddle +from flashinfer.paddle_utils import * + +from flashinfer import (GatedActType, RoutingMethodType, fp4_quantize, + mxfp8_quantize, next_positive_power_of_2) from flashinfer.autotuner import autotune +from flashinfer.fused_moe import trtllm_fp4_block_scale_moe from flashinfer.testing.utils import bench_gpu_time from flashinfer.utils import device_support_pdl def get_tile_tokens_dim(num_tokens, num_experts, top_k): - # Factor to account for the imbalance of the experts. - # factor equals to the - # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert - # - 1.0 means perfect expert distribution. - # - > 1.0 means some experts have more - # tokens than the perfect distribution. - # - < 1.0 does not make sense. imbalance_factor = 1.3 - # Calculate the number of tokens per expert - # assuming perfect distribution. - num_tokens_per_expert = (num_tokens * top_k) // num_experts - # Apply the imbalance factor. + num_tokens_per_expert = num_tokens * top_k // num_experts num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) - # And pad the number to the next power of 2. tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # Cap to 8-64 tokens per CTA tile - # as it's the range supported by the kernel. 
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) return tile_tokens_dim @@ -48,101 +36,101 @@ def bench_trtllm_gen_fused_moe_autotuner( warmups: int, iterations: int, ): - device = torch.device("cuda:0") + device = device2str("cuda:0") enable_pdl = device_support_pdl(device) - routing_logits = torch.rand(num_tokens, num_experts, device=device).to( - torch.bfloat16 - ) - hidden_states = torch.randn(num_tokens, hidden_size, device=device).to( - torch.bfloat16 - ) + routing_logits = paddle.rand(shape=[num_tokens, num_experts]).to("bfloat16") + hidden_states = paddle.randn(shape=[num_tokens, hidden_size]).to("bfloat16") if quant_mode == "NvFP4xNvFP4": hidden_states, hidden_states_scale = fp4_quantize( hidden_states, - torch.tensor([448.0 * 6.0], device=device), + paddle.to_tensor(data=[448.0 * 6.0], place=device), sf_vec_size=16, sf_use_ue8m0=False, ) - hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape( + hidden_states_scale = hidden_states_scale.view(paddle.float8_e4m3fn).reshape( num_tokens, -1 ) hidden_states_global_scale = 1.0 / 448.0 / 6.0 elif quant_mode == "MxFP4xMxFP8": hidden_states, hidden_states_scale = mxfp8_quantize(hidden_states, False) - hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape( + hidden_states_scale = hidden_states_scale.view(paddle.float8_e4m3fn).reshape( num_tokens, -1 ) hidden_states_global_scale = 1.0 - else: # MxFP4xBf16 + else: hidden_states_scale = None hidden_states_global_scale = 1.0 - - w13 = torch.randn( - num_experts, intermediate_size * 2, hidden_size, device=device - ).to(torch.bfloat16) - w2 = torch.randn(num_experts, hidden_size, intermediate_size, device=device).to( - torch.bfloat16 + w13 = paddle.randn(shape=[num_experts, intermediate_size * 2, hidden_size]).to( + "bfloat16" + ) + w2 = paddle.randn(shape=[num_experts, hidden_size, intermediate_size]).to( + "bfloat16" ) if quant_mode == "NvFP4xNvFP4": w13, w13_scale = fp4_quantize( w13, - torch.tensor([448.0 * 6.0], device=device), + paddle.to_tensor(data=[448.0 * 6.0], place=device), sf_vec_size=16, sf_use_ue8m0=False, ) - w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( + w13_scale = w13_scale.view(paddle.float8_e4m3fn).reshape( num_experts, intermediate_size * 2, -1 ) w2, w2_scale = fp4_quantize( w2, - torch.tensor([448.0 * 6.0], device=device), + paddle.to_tensor(data=[448.0 * 6.0], place=device), sf_vec_size=16, sf_use_ue8m0=False, ) - w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( + w2_scale = w2_scale.view(paddle.float8_e4m3fn).reshape( num_experts, hidden_size, -1 ) w13_global_scale = 1.0 / 448.0 / 6.0 w2_global_scale = 1.0 / 448.0 / 6.0 else: w13, w13_scale = fp4_quantize( - w13, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True + w13, + paddle.to_tensor(data=[1.0], place=device), + sf_vec_size=32, + sf_use_ue8m0=True, ) - w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( + w13_scale = w13_scale.view(paddle.float8_e4m3fn).reshape( num_experts, intermediate_size * 2, -1 ) w2, w2_scale = fp4_quantize( - w2, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True + w2, + paddle.to_tensor(data=[1.0], place=device), + sf_vec_size=32, + sf_use_ue8m0=True, ) - w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( + w2_scale = w2_scale.view(paddle.float8_e4m3fn).reshape( num_experts, hidden_size, -1 ) w13_global_scale = 1.0 w2_global_scale = 1.0 - bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10 - bias2 = torch.randn(num_experts, intermediate_size * 2, 
device=device) * 10 - + bias13 = paddle.randn(shape=[num_experts, intermediate_size * 2]) * 10 + bias2 = paddle.randn(shape=[num_experts, intermediate_size * 2]) * 10 tile_tokens_dim = get_tile_tokens_dim(num_tokens, num_experts, top_k) - output1_scale_scalar = torch.tensor( - [hidden_states_global_scale * w13_global_scale] * num_experts, device=device + output1_scale_scalar = paddle.to_tensor( + data=[hidden_states_global_scale * w13_global_scale] * num_experts, place=device ) - output1_scale_gate_scalar = torch.tensor( - [hidden_states_global_scale * w13_global_scale] * num_experts, device=device + output1_scale_gate_scalar = paddle.to_tensor( + data=[hidden_states_global_scale * w13_global_scale] * num_experts, place=device ) - output2_scale_scalar = torch.tensor( - [hidden_states_global_scale * w2_global_scale] * num_experts, device=device + output2_scale_scalar = paddle.to_tensor( + data=[hidden_states_global_scale * w2_global_scale] * num_experts, place=device ) fn = lambda: trtllm_fp4_block_scale_moe( routing_logits, - None, # routing_bias + None, hidden_states, hidden_states_scale, w13, w13_scale, bias13, - None, # gemm1_alpha - None, # gemm1_beta - None, # gemm1_clamp_limit + None, + None, + None, w2, w2_scale, bias2, @@ -151,30 +139,26 @@ def bench_trtllm_gen_fused_moe_autotuner( output2_scale_scalar, num_experts, top_k, - None, # n_group - None, # topk_group + None, + None, intermediate_size, - 0, # local_expert_offset + 0, num_experts, - None, # routed_scaling_factor + None, tile_tokens_dim, RoutingMethodType.Renormalize.value[0], True, enable_pdl, - GatedActType.SwiGlu.value, # gated_act_type + GatedActType.SwiGlu.value, None, num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, ) def bench(do_autotune): - # warmup with autotune(do_autotune): for _ in range(warmups): fn() - ms_list = bench_gpu_time( - fn, - repeat_iters=iterations, - ) + ms_list = bench_gpu_time(fn, repeat_iters=iterations) median_ms = np.median(ms_list) return median_ms diff --git a/benchmarks/bench_trtllm_gen_mla.py b/benchmarks/bench_trtllm_gen_mla.py index 3051608322..96367e4d0c 100644 --- a/benchmarks/bench_trtllm_gen_mla.py +++ b/benchmarks/bench_trtllm_gen_mla.py @@ -1,5 +1,9 @@ +import sys + + import numpy as np -import torch +import paddle +from flashinfer.paddle_utils import * import flashinfer from flashinfer.testing.utils import bench_gpu_time_with_cudagraph @@ -12,43 +16,33 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): - torch.manual_seed(42) + paddle.seed(seed=42) device = "cuda:0" - - # Initialize tensors - query = torch.randn( - batch_size, - q_len_per_request, - num_q_heads, - kv_lora_rank + qk_rope_head_dim, - device=device, + query = paddle.randn( + shape=[ + batch_size, + q_len_per_request, + num_q_heads, + kv_lora_rank + qk_rope_head_dim, + ] ).to(dtype) - num_tokens = seq_len * batch_size num_blocks = (num_tokens + page_size - 1) // page_size - - # Sequence lengths and block tables - seq_lens = [torch.randint(1, seq_len, (1,)).item() for _ in range(batch_size)] + seq_lens = [ + paddle.randint(low=1, high=seq_len, shape=(1,)).item() + for _ in range(batch_size) + ] seq_lens[-1] = seq_len max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) - + seq_lens_tensor = paddle.to_tensor(data=seq_lens, dtype="int32", place=device) blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size - max_num_blocks_per_seq = blocks_per_seq.max().item() - - # Generate random but unique block IDs for all 
sequences + max_num_blocks_per_seq = blocks_per_seq._max().item() total_blocks_needed = sum(blocks_per_seq) - all_block_ids = torch.randperm( - total_blocks_needed, device=device - ) # Random permutation - - # Generate unique block IDs for all sequences + all_block_ids = paddle.randperm(n=total_blocks_needed) block_id = 0 - block_tables = torch.zeros( - (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=device + block_tables = paddle.zeros( + shape=(batch_size, max_num_blocks_per_seq), dtype="int32" ) - - # Populate block tables and track block assignments block_id = 0 for i in range(batch_size): num_blocks_needed = blocks_per_seq[i] @@ -56,23 +50,13 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): block_id : block_id + num_blocks_needed ] block_id += num_blocks_needed - - # Create interleaved KV cache - # Allocate more than needed blocks, block_id is just enough, to mimick real-world cases - kv_cache = torch.randn( - size=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim), device=device + kv_cache = paddle.randn( + shape=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim) ).to(dtype) - # (num_blocks, 1, page_size, kv_lora_rank + qk_rope_head_dim) - - # Allocate workspace buffer - # todo(Yingyi): calculate the actual size of workspace buffer - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - - # Run decode-MLA - # warmup + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, - kv_cache=kv_cache.unsqueeze(1), + kv_cache=kv_cache.unsqueeze(axis=1), workspace_buffer=workspace_buffer, qk_nope_head_dim=qk_nope_head_dim, kv_lora_rank=kv_lora_rank, @@ -80,14 +64,13 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): block_tables=block_tables, seq_lens=seq_lens_tensor, max_seq_len=max_seq_len, - bmm1_scale=1.0 / ((128 + 64) ** 0.5), + bmm1_scale=1.0 / (128 + 64) ** 0.5, bmm2_scale=1.0, ) - # benchmark measurements = bench_gpu_time_with_cudagraph( lambda: flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, - kv_cache=kv_cache.unsqueeze(1), + kv_cache=kv_cache.unsqueeze(axis=1), workspace_buffer=workspace_buffer, qk_nope_head_dim=qk_nope_head_dim, kv_lora_rank=kv_lora_rank, @@ -95,16 +78,13 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): block_tables=block_tables, seq_lens=seq_lens_tensor, max_seq_len=max_seq_len, - bmm1_scale=1.0 / ((128 + 64) ** 0.5), + bmm1_scale=1.0 / (128 + 64) ** 0.5, bmm2_scale=1.0, ), dry_run_time_ms=100, repeat_time_ms=1000, ) - io = ( - query.numel() * query.element_size() - + kv_cache.numel() * kv_cache.element_size() - ) + io = query.size * query.element_size() + kv_cache.size * kv_cache.element_size() ms = np.median(measurements) flops = ( 2 @@ -119,11 +99,11 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): ) print(f"execution time: {ms} ms") print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s") - print(f"FLOPs: {flops * 1e-9 / ms:.2f} TFLOPs/s") + print(f"FLOPs: {flops * 1e-09 / ms:.2f} TFLOPs/s") if __name__ == "__main__": - for dtype in [torch.bfloat16, torch.float8_e4m3fn]: + for dtype in ["bfloat16", paddle.float8_e4m3fn]: for page_size in [32, 64]: for batch_size in [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024]: for seq_len in [1024, 4096, 8192]: diff --git a/benchmarks/flashinfer_benchmark.py b/benchmarks/flashinfer_benchmark.py index 86fcf39193..d8ff432018 
100644 --- a/benchmarks/flashinfer_benchmark.py +++ b/benchmarks/flashinfer_benchmark.py @@ -2,11 +2,9 @@ import sys from routines.attention import parse_attention_args, run_attention_test -from routines.flashinfer_benchmark_utils import ( - benchmark_apis, - full_output_columns, - output_column_dict, -) +from routines.flashinfer_benchmark_utils import (benchmark_apis, + full_output_columns, + output_column_dict) from routines.gemm import parse_gemm_args, run_gemm_test from routines.moe import parse_moe_args, run_moe_test @@ -18,8 +16,6 @@ def run_test(args): Args: args: Parsed command line arguments containing test configuration """ - - ## Depending on routine type, route to corresponding test routine if args.routine in benchmark_apis["attention"]: res = run_attention_test(args) elif args.routine in benchmark_apis["gemm"]: @@ -28,14 +24,11 @@ def run_test(args): res = run_moe_test(args) else: raise ValueError(f"Unsupported routine: {args.routine}") - - # Write results to output file if specified if args.output_path is not None: with open(args.output_path, "a") as fout: for cur_res in res: for key in output_column_dict["general"]: cur_res[key] = getattr(args, key) - output_line = ",".join( [str(cur_res[col]) for col in full_output_columns] ) @@ -55,8 +48,6 @@ def parse_args(line=sys.argv[1:]): Returns: Parsed argument namespace """ - - ## Shared arguments parser = argparse.ArgumentParser() parser.add_argument( "--routine", @@ -68,7 +59,6 @@ def parse_args(line=sys.argv[1:]): + list(benchmark_apis["moe"]), ) args, _ = parser.parse_known_args(line[:]) - parser.add_argument( "--no_cuda_graph", action="store_true", @@ -136,8 +126,6 @@ def parse_args(line=sys.argv[1:]): default="", help="Placeholder for generated reproducer command for the test case. Not to be used directly.", ) - - ## Check routine and pass on to routine-specific argument parser if args.routine in benchmark_apis["attention"]: args = parse_attention_args(line, parser) elif args.routine in benchmark_apis["gemm"]: @@ -146,14 +134,12 @@ def parse_args(line=sys.argv[1:]): args = parse_moe_args(line, parser) else: raise ValueError(f"Unsupported routine: {args.routine}") - if args.generate_repro_command: args.repro_command = "python3 flashinfer_benchmark.py " + " ".join(line) return args if __name__ == "__main__": - # Parse testlist argument first testlist_parser = argparse.ArgumentParser(add_help=False) testlist_parser.add_argument( "--testlist", @@ -170,15 +156,10 @@ def parse_args(line=sys.argv[1:]): help="Output path for results csv.", ) testlist_args, _ = testlist_parser.parse_known_args() - - # Setup output file if specified if testlist_args.output_path is not None: with open(testlist_args.output_path, "w") as fout: fout.write(",".join(full_output_columns) + "\n") - - # Process tests either from testlist file or command line arguments if testlist_args.testlist is not None: - # If testlist, run each test in the testlist with open(testlist_args.testlist, "r") as f: import shlex @@ -195,7 +176,6 @@ def parse_args(line=sys.argv[1:]): print(f"[ERROR] Error: {e}") continue else: - # If no testlist, just run the command args = parse_args() args.output_path = testlist_args.output_path run_test(args) diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py index a5caf3ad1f..4e41cc04c8 100644 --- a/benchmarks/routines/attention.py +++ b/benchmarks/routines/attention.py @@ -1,21 +1,20 @@ +import sys + + from collections import defaultdict import numpy as np -import torch +import paddle +from flashinfer.paddle_utils 
import * import flashinfer from flashinfer.testing.utils import ( attention_tb_per_sec_with_actual_seq_lens, - attention_tflops_per_sec_with_actual_seq_lens, - bench_gpu_time, - bench_gpu_time_with_cudagraph, -) + attention_tflops_per_sec_with_actual_seq_lens, bench_gpu_time, + bench_gpu_time_with_cudagraph) -from .flashinfer_benchmark_utils import ( - dtype_str_to_torch_dtype, - get_device, - print_perf_metrics, -) +from .flashinfer_benchmark_utils import (dtype_str_to_torch_dtype, get_device, + print_perf_metrics) def run_attention_test(args): @@ -147,10 +146,9 @@ def parse_attention_args(line, parser): default=False, help="Use random actual sequence lengths for the query and key and value. Random values are generated between 1 and maximum sequence length. If False, use maximum sequence length.", ) - args = parser.parse_args(line) if args.verbose >= 1: - print(f"[INFO] {args = }") + print(f"[INFO] args = {args!r}") return args @@ -170,12 +168,12 @@ def sample_actual_seq_lens(max_seqlen, batch_size, device, random_actual_seq_len actual_seq_lens: Actual sequence lengths for each batch. """ if random_actual_seq_len: - actual_seq_lens = torch.randint( - 1, max_seqlen + 1, (batch_size, 1, 1, 1), device=device, dtype=torch.int32 + actual_seq_lens = paddle.randint( + low=1, high=max_seqlen + 1, shape=(batch_size, 1, 1, 1), dtype="int32" ) else: - actual_seq_lens = torch.full( - (batch_size, 1, 1, 1), max_seqlen, device=device, dtype=torch.int32 + actual_seq_lens = paddle.full( + shape=(batch_size, 1, 1, 1), fill_value=max_seqlen, dtype="int32" ) return actual_seq_lens @@ -200,30 +198,21 @@ def testBatchDecodeWithPagedKVCacheWrapper(args): if args.verbose >= 1: print("[INFO] Running testBatchDecodeWithPagedKVCacheWrapper") print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - - # Basic setup device = get_device(args) if args.generate_repro_command: print( f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" ) - - q_init_dtype = torch.bfloat16 - kv_init_dtype = torch.bfloat16 - rtol = 2e-1 - atol = 1e-2 - - # Handle different query data types. + q_init_dtype = "bfloat16" + kv_init_dtype = "bfloat16" + rtol = 0.2 + atol = 0.01 q_dtype = dtype_str_to_torch_dtype(args.q_dtype) - if q_dtype not in [torch.bfloat16, torch.float8_e4m3fn]: + if q_dtype not in ["bfloat16", paddle.float8_e4m3fn]: raise ValueError(f"Unsupported q_dtype: {args.q_dtype}") - - # Handle different KV cache data types. kv_dtype = dtype_str_to_torch_dtype(args.kv_dtype) - if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn]: + if kv_dtype not in ["bfloat16", paddle.float8_e4m3fn]: raise ValueError(f"Unsupported kv_dtype: {args.kv_dtype}") - - # Parse and validate backend configurations backends = args.backends page_size = args.page_size batch_size = args.batch_size @@ -234,15 +223,10 @@ def testBatchDecodeWithPagedKVCacheWrapper(args): head_dim_qk = args.head_dim_qk head_dim_vo = args.head_dim_vo is_cuda_graph_compatible = not args.no_cuda_graph - # return_lse = not args.no_lse # TO-DO: Add support for this run_refcheck = args.refcheck - - # Derived parameters if "fa2" in backends: remove_fa2 = False - head_grp_size = ( - num_qo_heads // num_kv_heads - ) # If 5, FA2 backend is not supported. + head_grp_size = num_qo_heads // num_kv_heads if head_grp_size == 5: print( "[INFO] FA2 backend is not supported for this configuration. Skipping." 
@@ -250,81 +234,57 @@ def testBatchDecodeWithPagedKVCacheWrapper(args): remove_fa2 = True if remove_fa2: backends.remove("fa2") - if "fa2_tc" in backends: remove_fa2_tc = False - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, + if q_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2] or kv_dtype in [ + paddle.float8_e4m3fn, + paddle.float8_e5m2, ]: print("[INFO] FA2_TC backend does not support FP8. Skipping.") remove_fa2_tc = True if remove_fa2_tc: backends.remove("fa2_tc") - if "cudnn" in backends: remove_cudnn = False - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, + if q_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2] or kv_dtype in [ + paddle.float8_e4m3fn, + paddle.float8_e5m2, ]: print("[INFO] cuDNN backend does not support FP8. Skipping.") remove_cudnn = True if remove_cudnn: backends.remove("cudnn") - if len(backends) == 0: print("[ERROR] No backends to test. Exiting.") return - - # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} outputs = {} - - # Sample sequence lengths and create tensors actual_seq_lens_kv = sample_actual_seq_lens( s_kv, batch_size, device, args.random_actual_seq_len ) - sum_seq_kv = torch.sum(actual_seq_lens_kv).item() + sum_seq_kv = paddle.sum(x=actual_seq_lens_kv).item() avg_seq_len_kv = sum_seq_kv // batch_size - if args.verbose >= 1: print(f"[VERBOSE] Average actual seq len: {avg_seq_len_kv}") if args.verbose >= 2: - print(f"[VVERBOSE] {actual_seq_lens_kv.flatten() = }") - - # Create query tensor - q = torch.rand( - batch_size, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype - ) + print( + f"[VVERBOSE] actual_seq_lens_kv.flatten() = {actual_seq_lens_kv.flatten()!r}" + ) + q = paddle.rand(shape=[batch_size, num_qo_heads, head_dim_qk], dtype=q_init_dtype) if args.verbose >= 2: - print(f"[VVERBOSE] {q.shape = }") - - # Create KV cache + print(f"[VVERBOSE] q.shape = {tuple(q.shape)!r}") num_pages_per_seq = (s_kv + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size - if args.verbose >= 2: - print(f"[VVERBOSE] {num_pages_per_seq = }") - print(f"[VVERBOSE] {total_num_pages = }") - - # Initialize KV cache with appropriate shape and stride - kv_cache_shape = ( - total_num_pages, - 2, # 2 for key and value - num_kv_heads, - page_size, - head_dim_qk, - ) - kv_cache = torch.randn(size=kv_cache_shape, dtype=kv_init_dtype).to(device) - - # Keep a copy for TRT-LLM which uses different strides + print(f"[VVERBOSE] num_pages_per_seq = {num_pages_per_seq!r}") + print(f"[VVERBOSE] total_num_pages = {total_num_pages!r}") + kv_cache_shape = total_num_pages, 2, num_kv_heads, page_size, head_dim_qk + kv_cache = paddle.randn(shape=kv_cache_shape, dtype=kv_init_dtype).to(device) if "trtllm-gen" in backends: kv_cache_for_trt = kv_cache.detach().clone() - kv_cache = kv_cache.as_strided( - kv_cache.shape, - ( + shape=tuple(kv_cache.shape), + stride=( 2 * page_size * num_kv_heads * head_dim_qk, page_size * num_kv_heads * head_dim_qk, head_dim_qk, @@ -333,15 +293,11 @@ def testBatchDecodeWithPagedKVCacheWrapper(args): ), ) k_cache_view, v_cache_view = kv_cache[:, 0, :, :, :], kv_cache[:, 1, :, :, :] - if "trtllm-gen" in backends: - # kv_cache now has different tensor stride and logical values. Copy over values to kv_cache_for_trt. - # Result is kv_cache and kv_cache_for_trt have the same logical values but different tensor strides. 
- kv_cache_for_trt.copy_(kv_cache) - + paddle.assign(kv_cache, output=kv_cache_for_trt) v_cache = v_cache_view.as_strided( - v_cache_view.shape, - ( + shape=tuple(v_cache_view.shape), + stride=( 2 * page_size * num_kv_heads * head_dim_qk, head_dim_qk, num_kv_heads * head_dim_qk, @@ -349,76 +305,65 @@ def testBatchDecodeWithPagedKVCacheWrapper(args): ), ) k_cache = k_cache_view.as_strided( - k_cache_view.shape, - ( + shape=tuple(k_cache_view.shape), + stride=( 2 * page_size * num_kv_heads * head_dim_qk, head_dim_qk, num_kv_heads * head_dim_qk, 1, ), ) - - # Now initialize the page tables - block_tables = torch.tensor( - [ - [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + block_tables = paddle.to_tensor( + data=[ + [(k + i * num_pages_per_seq) for k in range(num_pages_per_seq)] for i in range(batch_size) ], - dtype=torch.int, - device=device, + dtype="int32", + place=device, ) - kv_indptr = ( - torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum( - (actual_seq_lens_kv.flatten() + page_size - 1) // page_size, dim=0 + paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum( + x=(actual_seq_lens_kv.flatten() + page_size - 1) // page_size, + axis=0, ), ] ) - .int() + .astype(dtype="int32") .to(device) ) - - # kv_indices[-1] is the total number of actual pages - kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32) + kv_indices = paddle.zeros(shape=kv_indptr[-1], dtype="int32") for i in range(len(kv_indptr) - 1): start_idx = kv_indptr[i] end_idx = kv_indptr[i + 1] - kv_indices[start_idx:end_idx] = torch.arange( - i * num_pages_per_seq, - i * num_pages_per_seq + (end_idx - start_idx), - device=device, + kv_indices[start_idx:end_idx] = paddle.arange( + start=i * num_pages_per_seq, + end=i * num_pages_per_seq + (end_idx - start_idx), ) - kv_last_page_len = ( - torch.where( - actual_seq_lens_kv.flatten() % page_size == 0, - torch.full((batch_size,), page_size, device=device), - actual_seq_lens_kv.flatten() % page_size, + paddle.where( + condition=actual_seq_lens_kv.flatten() % page_size == 0, + x=paddle.full(shape=(batch_size,), fill_value=page_size), + y=actual_seq_lens_kv.flatten() % page_size, ) - .int() + .astype(dtype="int32") .to(device) ) - ragged_q = ( - torch.arange(0, batch_size + 1, device=device) * (num_qo_heads * head_dim_qk) - ).long() # For cuDNN - - scale = float(1.0 / (head_dim_qk**0.5)) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - + paddle.arange(start=0, end=batch_size + 1) * (num_qo_heads * head_dim_qk) + ).astype(dtype="int64") + scale = float(1.0 / head_dim_qk**0.5) + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") if args.verbose >= 2: - print(f"[VVERBOSE] {kv_cache.shape = }") - print(f"[VVERBOSE] {kv_cache.stride() = }") - print(f"[VVERBOSE] {block_tables.shape = }") - print(f"[VVERBOSE] {kv_indptr.shape = }") - print(f"[VVERBOSE] {kv_indices.shape = }") - print(f"[VVERBOSE] {kv_last_page_len.shape = }") - print(f"[VVERBOSE] {scale = }") - - # Prepare wrappers + print(f"[VVERBOSE] kv_cache.shape = {tuple(kv_cache.shape)!r}") + print(f"[VVERBOSE] kv_cache.stride() = {kv_cache.get_strides()!r}") + print(f"[VVERBOSE] block_tables.shape = {tuple(block_tables.shape)!r}") + print(f"[VVERBOSE] kv_indptr.shape = {tuple(kv_indptr.shape)!r}") + print(f"[VVERBOSE] kv_indices.shape = {tuple(kv_indices.shape)!r}") + print(f"[VVERBOSE] kv_last_page_len.shape = {tuple(kv_last_page_len.shape)!r}") + print(f"[VVERBOSE] scale = {scale!r}") backend_wrappers 
= {} for backend in backends: if backend in ["fa2", "fa2_tc", "trtllm-gen"]: @@ -429,7 +374,7 @@ def testBatchDecodeWithPagedKVCacheWrapper(args): workspace_buffer, "HND", use_cuda_graph=is_cuda_graph_compatible, - use_tensor_cores=(backend != "fa2"), + use_tensor_cores=backend != "fa2", paged_kv_indptr_buffer=plan_kv_indptr, paged_kv_indices_buffer=kv_indices, paged_kv_last_page_len_buffer=kv_last_page_len, @@ -446,23 +391,21 @@ def testBatchDecodeWithPagedKVCacheWrapper(args): q_data_type=q_dtype, data_type=kv_dtype, ) - - ## If FP8, prepare k_scale, v_scale = None, None - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if q_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: q = q.to(q_dtype) - if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - k_data, v_data = torch.chunk(kv_cache, 2, dim=1) + if kv_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: + k_data, v_data = paddle.chunk(x=kv_cache, chunks=2, axis=1) k_scale = k_data.amax().item() / 256 v_scale = v_data.amax().item() / 256 k_fp8 = (k_data / k_scale).to(kv_dtype) v_fp8 = (v_data / v_scale).to(kv_dtype) - kv_cache = torch.cat([k_fp8, v_fp8], dim=1) + kv_cache = paddle.concat(x=[k_fp8, v_fp8], axis=1) if "trtllm-gen" in backends: - k_data, v_data = torch.chunk(kv_cache_for_trt, 2, dim=1) + k_data, v_data = paddle.chunk(x=kv_cache_for_trt, chunks=2, axis=1) k_fp8 = (k_data / k_scale).to(kv_dtype) v_fp8 = (v_data / v_scale).to(kv_dtype) - kv_cache_for_trt = torch.cat([k_fp8, v_fp8], dim=1) + kv_cache_for_trt = paddle.concat(x=[k_fp8, v_fp8], axis=1) def run_backend_wrapper(backend): if backend in ["fa2", "fa2_tc", "trtllm-gen"]: @@ -505,8 +448,6 @@ def run_backend_wrapper(backend): .detach() ) has_reference_output = True - - # Iterate over each backend: for cur_backend in backends: if run_refcheck: outputs[cur_backend] = run_backend_wrapper(cur_backend).detach() @@ -531,24 +472,22 @@ def run_backend_wrapper(backend): l2_flush_device=device, sleep_after_run=False, ) - - # Perform reference check tested_backends = list(outputs.keys()) tested_outputs = list(outputs.values()) if len(tested_backends) > 1: if run_refcheck and has_reference_output: - if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if reference_output.dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: if args.verbose >= 2: print( "[VVERBOSE] Reference output is FP8. Converting to float32 for reference check." 
) - reference_output = reference_output.to(torch.float32) - tested_outputs = [output.to(torch.float32) for output in tested_outputs] + reference_output = reference_output.to("float32") + tested_outputs = [output.to("float32") for output in tested_outputs] for i in range(len(tested_outputs)): try: - torch.testing.assert_close( - reference_output, tested_outputs[i], rtol=rtol, atol=atol - ) + assert paddle.allclose( + x=reference_output, y=tested_outputs[i], rtol=rtol, atol=atol + ).item(), "" except AssertionError as e: print( f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}" @@ -556,14 +495,13 @@ def run_backend_wrapper(backend): if not args.allow_output_mismatch: print(e) raise - # Compute perf metrics res = [] for backend in backends: if len(backend_times[backend]) > 0: median_time = np.median(backend_times[backend]) std_time = np.std(backend_times[backend]) actual_seq_lens_kv_flat = actual_seq_lens_kv.flatten().to("cpu") - actual_seq_lens_q_flat = torch.ones_like(actual_seq_lens_kv_flat) + actual_seq_lens_q_flat = paddle.ones_like(x=actual_seq_lens_kv_flat) tflops = attention_tflops_per_sec_with_actual_seq_lens( actual_seq_lens_q_flat, actual_seq_lens_kv_flat, @@ -586,7 +524,6 @@ def run_backend_wrapper(backend): o_dtype=q_dtype, ) print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: cur_res = defaultdict(str) cur_res["routine"] = args.routine @@ -633,28 +570,21 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): if args.verbose >= 1: print("[INFO] Running testBatchPrefillWithPagedKVCacheWrapper") print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - - # Basic setup device = get_device(args) if args.generate_repro_command: print( f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" ) - - q_init_dtype = torch.bfloat16 - kv_init_dtype = torch.bfloat16 - rtol = 2e-1 - atol = 1e-2 - + q_init_dtype = "bfloat16" + kv_init_dtype = "bfloat16" + rtol = 0.2 + atol = 0.01 q_dtype = dtype_str_to_torch_dtype(args.q_dtype) - if q_dtype not in [torch.bfloat16, torch.float8_e4m3fn]: + if q_dtype not in ["bfloat16", paddle.float8_e4m3fn]: raise ValueError(f"Unsupported q_dtype: {args.q_dtype}") - kv_dtype = dtype_str_to_torch_dtype(args.kv_dtype) - if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn]: + if kv_dtype not in ["bfloat16", paddle.float8_e4m3fn]: raise ValueError(f"Unsupported kv_dtype: {args.kv_dtype}") - - # Parse and validate backend configurations backends = args.backends page_size = args.page_size batch_size = args.batch_size @@ -666,20 +596,17 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): head_dim_vo = args.head_dim_vo causal = args.causal is_cuda_graph_compatible = not args.no_cuda_graph - # return_lse = not args.no_lse # TO-DO: Add support for this run_refcheck = args.refcheck - - # Check for backend-specific constraints if "fa2" in backends: remove_fa2 = False - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if q_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: print("[INFO] FA2 backend does not support FP8. Skipping.") remove_fa2 = True if remove_fa2: backends.remove("fa2") if "fa3" in backends: remove_fa3 = False - device_capability = torch.cuda.get_device_capability() + device_capability = paddle.device.cuda.get_device_capability() if device_capability[0] != 9: print( f"[INFO] FA3 backend does not support capability {device_capability}. Skipping." 
@@ -689,81 +616,66 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): backends.remove("fa3") if "cudnn" in backends: remove_cudnn = False - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, + if q_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2] or kv_dtype in [ + paddle.float8_e4m3fn, + paddle.float8_e5m2, ]: print("[INFO] cuDNN backend does not support FP8. Skipping.") remove_cudnn = True if remove_cudnn: backends.remove("cudnn") - if "trtllm-gen" in backends: remove_trtllm = False - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, + if q_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2] or kv_dtype in [ + paddle.float8_e4m3fn, + paddle.float8_e5m2, ]: print("[INFO] trtllm-gen backend does not support FP8. Skipping.") remove_trtllm = True if remove_trtllm: backends.remove("trtllm-gen") - if "cutlass" in backends: print("[INFO] CUTLASS backend does not support prefill. Skipping.") remove_cutlass = True if remove_cutlass: backends.remove("cutlass") - if len(backends) == 0: print("[ERROR] No backends to test. Exiting.") return - - # Check for layer-specific constraints layer_not_supported = False - if not ((head_dim_qk == 128 and head_dim_qk == head_dim_vo) or head_dim_qk == 192): + if not (head_dim_qk == 128 and head_dim_qk == head_dim_vo or head_dim_qk == 192): print("[ERROR] Head dimension must be 128 or 192") layer_not_supported = True if layer_not_supported: print("[ERROR] Layer not supported. Exiting.") return - - # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} outputs = {} - - # Randomly sample actual_seq_lens_q. Assume actual_seq_lens_kv is the same as actual_seq_lens_q.
actual_seq_lens_q = sample_actual_seq_lens( s_qo, batch_size, None, args.random_actual_seq_len ) actual_seq_lens_kv = actual_seq_lens_q.clone() - avg_seq_len_q = actual_seq_lens_q.sum().item() // batch_size if args.verbose >= 1: print(f"[VERBOSE] Average actual seq len: {avg_seq_len_q}") if args.verbose >= 2: - print(f"[VVERBOSE] {actual_seq_lens_q.flatten() = }") - - cumsum_s_qo = torch.sum(actual_seq_lens_q) - q = torch.randn( - cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype - ) + print( + f"[VVERBOSE] actual_seq_lens_q.flatten() = {actual_seq_lens_q.flatten()!r}" + ) + cumsum_s_qo = paddle.sum(x=actual_seq_lens_q) + q = paddle.randn(shape=[cumsum_s_qo, num_qo_heads, head_dim_qk], dtype=q_init_dtype) if args.verbose >= 2: - print(f"[VVERBOSE] {q.shape = }") - - # Create KV cache + print(f"[VVERBOSE] q.shape = {tuple(q.shape)!r}") num_pages_per_seq = (s_kv + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size - if args.verbose >= 2: - print(f"[VVERBOSE] {num_pages_per_seq = }") - print(f"[VVERBOSE] {total_num_pages = }") - - kv_cache_shape = (total_num_pages, 2, num_kv_heads, page_size, head_dim_qk) - kv_cache = torch.randn(size=kv_cache_shape, dtype=kv_init_dtype).to(device) + print(f"[VVERBOSE] num_pages_per_seq = {num_pages_per_seq!r}") + print(f"[VVERBOSE] total_num_pages = {total_num_pages!r}") + kv_cache_shape = total_num_pages, 2, num_kv_heads, page_size, head_dim_qk + kv_cache = paddle.randn(shape=kv_cache_shape, dtype=kv_init_dtype).to(device) kv_cache = kv_cache.as_strided( - kv_cache.shape, - ( + shape=tuple(kv_cache.shape), + stride=( 2 * page_size * num_kv_heads * head_dim_qk, page_size * num_kv_heads * head_dim_qk, head_dim_qk, @@ -772,10 +684,9 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): ), ) k_cache_view, v_cache_view = kv_cache[:, 0, :, :, :], kv_cache[:, 1, :, :, :] - v_cache = v_cache_view.as_strided( - v_cache_view.shape, - ( + shape=tuple(v_cache_view.shape), + stride=( 2 * page_size * num_kv_heads * head_dim_qk, head_dim_qk, num_kv_heads * head_dim_qk, @@ -783,112 +694,103 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): ), ) k_cache = k_cache_view.as_strided( - k_cache_view.shape, - ( + shape=tuple(k_cache_view.shape), + stride=( 2 * page_size * num_kv_heads * head_dim_qk, head_dim_qk, num_kv_heads * head_dim_qk, 1, ), ) - - # Now initialize the page tables - block_tables = torch.tensor( - [ - [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + block_tables = paddle.to_tensor( + data=[ + [(k + i * num_pages_per_seq) for k in range(num_pages_per_seq)] for i in range(batch_size) ], - dtype=torch.int, - device=device, + dtype="int32", + place=device, ) - actual_seq_lens_q_device = actual_seq_lens_q.to(device) actual_seq_lens_kv_device = actual_seq_lens_kv.to(device) q_indptr = ( - torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0) + paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_q_device.view(-1), axis=0) * head_dim_qk * num_qo_heads, ] ) - .long() + .astype(dtype="int64") .to(device) - ) # For cuDNN + ) qo_indptr = ( - torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0), + paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_q_device.view(-1), axis=0), ] ) - .int() + .astype(dtype="int32") .to(device) ) - - # Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will 
become the same as qo_indptr kv_indptr = ( - torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum( - (actual_seq_lens_kv_device.flatten() + page_size - 1) // page_size, - dim=0, + paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum( + x=(actual_seq_lens_kv_device.flatten() + page_size - 1) + // page_size, + axis=0, ), ] ) - .int() + .astype(dtype="int32") .to(device) ) - kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32) + kv_indices = paddle.zeros(shape=kv_indptr[-1], dtype="int32") for i in range(len(kv_indptr) - 1): start_idx = kv_indptr[i] end_idx = kv_indptr[i + 1] - kv_indices[start_idx:end_idx] = torch.arange( - i * num_pages_per_seq, - i * num_pages_per_seq + (end_idx - start_idx), - device=device, + kv_indices[start_idx:end_idx] = paddle.arange( + start=i * num_pages_per_seq, + end=i * num_pages_per_seq + (end_idx - start_idx), ) kv_last_page_len = ( - torch.where( - actual_seq_lens_kv_device.flatten() % page_size == 0, - torch.full((batch_size,), page_size, device=device), - actual_seq_lens_kv_device.flatten() % page_size, + paddle.where( + condition=actual_seq_lens_kv_device.flatten() % page_size == 0, + x=paddle.full(shape=(batch_size,), fill_value=page_size), + y=actual_seq_lens_kv_device.flatten() % page_size, ) - .int() + .astype(dtype="int32") .to(device) ) - - scale = float(1.0 / (head_dim_qk**0.5)) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - + scale = float(1.0 / head_dim_qk**0.5) + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") if args.verbose >= 2: - print(f"[VVERBOSE] {kv_cache.shape = }") - print(f"[VVERBOSE] {kv_cache.stride() = }") - print(f"[VVERBOSE] {block_tables.shape = }") - print(f"[VVERBOSE] {qo_indptr.shape = }") - print(f"[VVERBOSE] {qo_indptr.dtype = }") - print(f"[VVERBOSE] {kv_indptr.shape = }") - print(f"[VVERBOSE] {kv_indices.shape = }") - print(f"[VVERBOSE] {kv_last_page_len.shape = }") - print(f"[VVERBOSE] {scale = }") - - # Prepare wrappers + print(f"[VVERBOSE] kv_cache.shape = {tuple(kv_cache.shape)!r}") + print(f"[VVERBOSE] kv_cache.stride() = {kv_cache.get_strides()!r}") + print(f"[VVERBOSE] block_tables.shape = {tuple(block_tables.shape)!r}") + print(f"[VVERBOSE] qo_indptr.shape = {tuple(qo_indptr.shape)!r}") + print(f"[VVERBOSE] qo_indptr.dtype = {qo_indptr.dtype!r}") + print(f"[VVERBOSE] kv_indptr.shape = {tuple(kv_indptr.shape)!r}") + print(f"[VVERBOSE] kv_indices.shape = {tuple(kv_indices.shape)!r}") + print(f"[VVERBOSE] kv_last_page_len.shape = {tuple(kv_last_page_len.shape)!r}") + print(f"[VVERBOSE] scale = {scale!r}") backend_wrappers = {} for backend in backends: if backend in ["fa2", "fa3", "trtllm-gen"]: - backend_wrappers[backend] = ( - flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, - "HND", - use_cuda_graph=is_cuda_graph_compatible, - qo_indptr_buf=qo_indptr, - paged_kv_indptr_buf=kv_indptr, - paged_kv_indices_buf=kv_indices, - paged_kv_last_page_len_buf=kv_last_page_len, - backend=backend, - ) + backend_wrappers[ + backend + ] = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, + "HND", + use_cuda_graph=is_cuda_graph_compatible, + qo_indptr_buf=qo_indptr, + paged_kv_indptr_buf=kv_indptr, + paged_kv_indices_buf=kv_indices, + paged_kv_last_page_len_buf=kv_last_page_len, + backend=backend, ) backend_wrappers[backend].plan( qo_indptr, @@ -904,17 +806,16 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): q_data_type=q_dtype, kv_data_type=kv_dtype, ) - 
k_scale, v_scale = None, None - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if q_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: q = q.to(q_dtype) - if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - k_data, v_data = torch.chunk(kv_cache, 2, dim=1) + if kv_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: + k_data, v_data = paddle.chunk(x=kv_cache, chunks=2, axis=1) k_scale = k_data.amax().item() / 256 v_scale = v_data.amax().item() / 256 k_fp8 = (k_data / k_scale).to(kv_dtype) v_fp8 = (v_data / v_scale).to(kv_dtype) - kv_cache = torch.cat([k_fp8, v_fp8], dim=1) + kv_cache = paddle.concat(x=[k_fp8, v_fp8], axis=1) def run_backend_wrapper(backend): if backend in ["fa2", "fa3", "trtllm-gen"]: @@ -963,8 +864,6 @@ def run_backend_wrapper(backend): q, kv_cache, k_scale=k_scale, v_scale=v_scale ) has_reference_output = True - - # Iterate over each backend: for cur_backend in backends: if run_refcheck: outputs[cur_backend] = run_backend_wrapper(cur_backend) @@ -989,24 +888,22 @@ def run_backend_wrapper(backend): l2_flush_device=device, sleep_after_run=False, ) - - # Perform reference check tested_backends = list(outputs.keys()) tested_outputs = list(outputs.values()) if len(tested_backends) > 1: if run_refcheck and has_reference_output: - if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if reference_output.dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: if args.verbose >= 2: print( "[VVERBOSE] Reference output is FP8. Converting to float32 for reference check." ) - reference_output = reference_output.to(torch.float32) - tested_outputs = [output.to(torch.float32) for output in tested_outputs] + reference_output = reference_output.to("float32") + tested_outputs = [output.to("float32") for output in tested_outputs] for i in range(len(tested_backends)): try: - torch.testing.assert_close( - reference_output, tested_outputs[i], rtol=rtol, atol=atol - ) + assert paddle.allclose( + x=reference_output, y=tested_outputs[i], rtol=rtol, atol=atol + ).item(), "" except AssertionError as e: print( f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}" @@ -1014,8 +911,6 @@ def run_backend_wrapper(backend): if not args.allow_output_mismatch: print(e) raise - - # Compute perf metrics res = [] for backend in backends: if len(backend_times[backend]) > 0: @@ -1045,7 +940,6 @@ def run_backend_wrapper(backend): o_dtype=q_dtype, ) print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: cur_res = defaultdict(str) cur_res["routine"] = args.routine @@ -1092,27 +986,21 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): if args.verbose >= 1: print("[INFO] Running testBatchPrefillWithRaggedKVCacheWrapper") print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - - # Basic setup device = get_device(args) if args.generate_repro_command: print( f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" ) - - q_init_dtype = torch.bfloat16 - kv_init_dtype = torch.bfloat16 - rtol = 2e-1 - atol = 1e-2 - + q_init_dtype = "bfloat16" + kv_init_dtype = "bfloat16" + rtol = 0.2 + atol = 0.01 q_dtype = dtype_str_to_torch_dtype(args.q_dtype) - if q_dtype not in [torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2]: + if q_dtype not in ["bfloat16", paddle.float8_e4m3fn, paddle.float8_e5m2]: raise ValueError(f"Unsupported q_dtype: {args.q_dtype}") kv_dtype = dtype_str_to_torch_dtype(args.kv_dtype) - if kv_dtype not in [torch.bfloat16, 
torch.float8_e4m3fn, torch.float8_e5m2]: + if kv_dtype not in ["bfloat16", paddle.float8_e4m3fn, paddle.float8_e5m2]: raise ValueError(f"Unsupported kv_dtype: {args.kv_dtype}") - - # Parse and validate backend configurations backends = args.backends batch_size = args.batch_size s_qo = args.s_qo @@ -1123,183 +1011,151 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): head_dim_vo = args.head_dim_vo causal = args.causal is_cuda_graph_compatible = not args.no_cuda_graph - # return_lse = not args.no_lse # TO-DO: Add support for this run_refcheck = args.refcheck - - # Check for backend-specific constraints if "cudnn" in backends: remove_cudnn = False - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, + if q_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2] or kv_dtype in [ + paddle.float8_e4m3fn, + paddle.float8_e5m2, ]: print("[INFO] CUDNN backend does not support FP8. Skipping.") remove_cudnn = True if remove_cudnn: backends.remove("cudnn") - if "cutlass" in backends: remove_cutlass = False - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, + if q_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2] or kv_dtype in [ + paddle.float8_e4m3fn, + paddle.float8_e5m2, ]: print("[INFO] CUTLASS backend does not support FP8. Skipping.") remove_cutlass = True if remove_cutlass: backends.remove("cutlass") - if "trtllm-gen" in backends: print("[INFO] trtllm-gen backend does not support ragged prefill. Skipping.") remove_trtllm = True if remove_trtllm: backends.remove("trtllm-gen") - if len(backends) == 0: print("[ERROR] No backends to test. Exiting.") return - - # Check for layer-specific constraints layer_not_supported = False - if not ((head_dim_qk == 128 and head_dim_qk == head_dim_vo) or head_dim_qk == 192): + if not (head_dim_qk == 128 and head_dim_qk == head_dim_vo or head_dim_qk == 192): print("[ERROR] Head dimension must be 128 or 192") layer_not_supported = True if layer_not_supported: print("[ERROR] Layer not supported. Exiting.") return - backend_times = {backend: [] for backend in backends} outputs = {} - - # Randomly sample actual_seq_lens_q. Assume actual_seq_lens_kv is the same as actual_seq_lens_q.
actual_seq_lens_q = sample_actual_seq_lens( s_qo, batch_size, None, args.random_actual_seq_len ) actual_seq_lens_kv = actual_seq_lens_q.clone() - avg_seq_len_q = actual_seq_lens_q.sum().item() // batch_size if args.verbose >= 1: print(f"[VERBOSE] Average actual seq len: {avg_seq_len_q}") if args.verbose >= 2: - print(f"[VVERBOSE] {actual_seq_lens_q.flatten() = }") - - cumsum_s_qo = torch.sum(actual_seq_lens_q) - cumsum_s_kv = torch.sum(actual_seq_lens_kv) - q = torch.randn( - cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype - ) + print( + f"[VVERBOSE] actual_seq_lens_q.flatten() = {actual_seq_lens_q.flatten()!r}" + ) + cumsum_s_qo = paddle.sum(x=actual_seq_lens_q) + cumsum_s_kv = paddle.sum(x=actual_seq_lens_kv) + q = paddle.randn(shape=[cumsum_s_qo, num_qo_heads, head_dim_qk], dtype=q_init_dtype) if args.verbose >= 2: - print(f"[VVERBOSE] {q.shape = }") - - k = torch.randn( - cumsum_s_kv, num_kv_heads, head_dim_qk, device=device, dtype=kv_init_dtype + print(f"[VVERBOSE] q.shape = {tuple(q.shape)!r}") + k = paddle.randn( + shape=[cumsum_s_kv, num_kv_heads, head_dim_qk], dtype=kv_init_dtype ) - v = torch.randn( - cumsum_s_kv, num_kv_heads, head_dim_vo, device=device, dtype=kv_init_dtype + v = paddle.randn( + shape=[cumsum_s_kv, num_kv_heads, head_dim_vo], dtype=kv_init_dtype ) - block_tables = None - - ## The following are for BatchPrefillWithRaggedKVCacheWrapper actual_seq_lens_q_device = actual_seq_lens_q.to(device) actual_seq_lens_kv_device = actual_seq_lens_kv.to(device) - q_indptr = ( - torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0) + paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_q_device.view(-1), axis=0) * head_dim_qk * num_qo_heads, ] ) - .long() + .astype(dtype="int64") .to(device) - ) # For cuDNN - - k_indptr = torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_kv_device.view(-1), dim=0) + ) + k_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_kv_device.view(-1), axis=0) * head_dim_qk * num_kv_heads, ] - ).long() - - v_indptr = torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_kv_device.view(-1), dim=0) + ).astype(dtype="int64") + v_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_kv_device.view(-1), axis=0) * head_dim_vo * num_kv_heads, ] - ).long() - - o_indptr = torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0) + ).astype(dtype="int64") + o_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_q_device.view(-1), axis=0) * head_dim_vo * num_qo_heads, ] - ).long() - - batch_offsets_stats = torch.cat( - [ - torch.zeros( - 1, - device=actual_seq_lens_q_device.device, - dtype=actual_seq_lens_q_device.dtype, - ), - torch.cumsum(actual_seq_lens_q_device.flatten(), dim=0) * num_qo_heads, + ).astype(dtype="int64") + batch_offsets_stats = paddle.concat( + x=[ + paddle.zeros(shape=[1], dtype=actual_seq_lens_q_device.dtype), + paddle.cumsum(x=actual_seq_lens_q_device.flatten(), axis=0) * num_qo_heads, ] ).cuda() - qo_indptr = ( - torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0), + paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_q_device.view(-1), axis=0), ] 
) - .int() + .astype(dtype="int32") .to(device) ) - # Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr kv_indptr = ( - torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_kv_device.view(-1), dim=0), + paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_kv_device.view(-1), axis=0), ] ) - .int() + .astype(dtype="int32") .to(device) ) - - scale = float(1.0 / (head_dim_qk**0.5)) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - + scale = float(1.0 / head_dim_qk**0.5) + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") if args.verbose >= 2: - print(f"[VVERBOSE] {k.shape = }") - print(f"[VVERBOSE] {v.shape = }") - print(f"[VVERBOSE] {qo_indptr.shape = }") - print(f"[VVERBOSE] {kv_indptr.shape = }") - print(f"[VVERBOSE] {scale = }") - - # Prepare wrappers + print(f"[VVERBOSE] k.shape = {tuple(k.shape)!r}") + print(f"[VVERBOSE] v.shape = {tuple(v.shape)!r}") + print(f"[VVERBOSE] qo_indptr.shape = {tuple(qo_indptr.shape)!r}") + print(f"[VVERBOSE] kv_indptr.shape = {tuple(kv_indptr.shape)!r}") + print(f"[VVERBOSE] scale = {scale!r}") backend_wrappers = {} for backend in backends: if backend in ["cutlass", "fa2", "fa3", "trtllm-gen"]: - backend_wrappers[backend] = ( - flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - workspace_buffer, - "NHD", - use_cuda_graph=is_cuda_graph_compatible, - qo_indptr_buf=qo_indptr, - kv_indptr_buf=kv_indptr, - backend=backend, - ) + backend_wrappers[ + backend + ] = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, + "NHD", + use_cuda_graph=is_cuda_graph_compatible, + qo_indptr_buf=qo_indptr, + kv_indptr_buf=kv_indptr, + backend=backend, ) backend_wrappers[backend].plan( qo_indptr, @@ -1312,11 +1168,10 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): q_data_type=q_dtype, kv_data_type=kv_dtype, ) - k_scale, v_scale = None, None - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if q_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: q = q.to(q_dtype) - if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if kv_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: k_scale = k.amax().item() / 256 v_scale = v.amax().item() / 256 k = (k / k_scale).to(kv_dtype) @@ -1353,8 +1208,6 @@ def run_backend_wrapper(backend): if run_refcheck and "fa2" in backends: reference_output = backend_wrappers["fa2"].run_return_lse(q, k, v)[0] has_reference_output = True - - # Iterate over each backend: for cur_backend in backends: if run_refcheck: outputs[cur_backend] = run_backend_wrapper(cur_backend) @@ -1379,24 +1232,22 @@ def run_backend_wrapper(backend): l2_flush_device=device, sleep_after_run=True, ) - - # Perform reference check tested_backends = list(outputs.keys()) tested_outputs = list(outputs.values()) if len(tested_backends) > 1: if run_refcheck and has_reference_output: - if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if reference_output.dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: if args.verbose >= 2: print( "[VVERBOSE] Reference output is FP8. Converting to float32 for reference check." 
) - reference_output = reference_output.to(torch.float32) - tested_outputs = [output.to(torch.float32) for output in tested_outputs] + reference_output = reference_output.to("float32") + tested_outputs = [output.to("float32") for output in tested_outputs] for i in range(len(tested_backends)): try: - torch.testing.assert_close( - reference_output, tested_outputs[i], rtol=rtol, atol=atol - ) + assert paddle.allclose( + x=reference_output, y=tested_outputs[i], rtol=rtol, atol=atol + ).item(), "" except AssertionError as e: print( f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}" @@ -1404,8 +1255,6 @@ def run_backend_wrapper(backend): if not args.allow_output_mismatch: print(e) raise - - # Compute perf metrics res = [] for backend in backends: if len(backend_times[backend]) > 0: @@ -1434,9 +1283,7 @@ def run_backend_wrapper(backend): kv_dtype=kv_dtype, o_dtype=q_dtype, ) - print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: cur_res = defaultdict(str) cur_res["routine"] = args.routine @@ -1445,7 +1292,7 @@ def run_backend_wrapper(backend): cur_res["tflops"] = tflops cur_res["tb_per_sec"] = tb_per_sec cur_res["backend"] = backend - cur_res["page_size"] = 0 # No page size for ragged + cur_res["page_size"] = 0 cur_res["batch_size"] = batch_size cur_res["s_qo"] = s_qo cur_res["s_kv"] = s_kv @@ -1483,160 +1330,123 @@ def testBatchMLAPagedAttentionWrapper(args): if args.verbose >= 1: print("[INFO] Running testBatchMLAPagedAttentionWrapper") print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - - # Basic setup device = get_device(args) if args.generate_repro_command: print( f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" ) - - q_init_dtype = torch.bfloat16 - kv_init_dtype = torch.bfloat16 - rtol = 2e-1 - atol = 1e-2 - - # Handle different query data types. + q_init_dtype = "bfloat16" + kv_init_dtype = "bfloat16" + rtol = 0.2 + atol = 0.01 q_dtype = dtype_str_to_torch_dtype(args.q_dtype) - if q_dtype not in [torch.bfloat16, torch.float8_e4m3fn]: + if q_dtype not in ["bfloat16", paddle.float8_e4m3fn]: raise ValueError(f"Unsupported q_dtype: {args.q_dtype}") - - # Handle different KV cache data types. kv_dtype = dtype_str_to_torch_dtype(args.kv_dtype) - if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn]: + if kv_dtype not in ["bfloat16", paddle.float8_e4m3fn]: raise ValueError(f"Unsupported kv_dtype: {args.kv_dtype}") - backends = args.backends page_size = args.page_size batch_size = args.batch_size s_qo = args.s_qo s_kv = args.s_kv num_qo_heads = args.num_qo_heads - # num_kv_heads not used in MLA - # head_dim_qk = args.head_dim_qk assert args.head_dim_ckv is not None, "head_dim_ckv must be provided for MLA" assert args.head_dim_kpe is not None, "head_dim_kpe must be provided for MLA" head_dim_ckv = args.head_dim_ckv head_dim_kpe = args.head_dim_kpe is_cuda_graph_compatible = not args.no_cuda_graph - causal = False # False for MLA + causal = False run_refcheck = args.refcheck - - # Check for backend-specific constraints if "fa2" in backends: remove_fa2 = False - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, + if q_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2] or kv_dtype in [ + paddle.float8_e4m3fn, + paddle.float8_e5m2, ]: print("[INFO] FA2 backend does not support FP8.
Skipping.") remove_fa2 = True if remove_fa2: backends.remove("fa2") - - # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} outputs = {} - actual_seq_lens_kv = sample_actual_seq_lens( s_kv, batch_size, device, args.random_actual_seq_len ) - sum_seq_kv = torch.sum(actual_seq_lens_kv).item() + sum_seq_kv = paddle.sum(x=actual_seq_lens_kv).item() avg_seq_len_kv = sum_seq_kv // batch_size - if args.verbose >= 1: print(f"[VERBOSE] Average actual seq len: {avg_seq_len_kv}") if args.verbose >= 2: - print(f"[VVERBOSE] {actual_seq_lens_kv.flatten() = }") - - q_nope = torch.rand( - batch_size, num_qo_heads, head_dim_ckv, dtype=q_init_dtype, device="cuda" + print( + f"[VVERBOSE] actual_seq_lens_kv.flatten() = {actual_seq_lens_kv.flatten()!r}" + ) + q_nope = paddle.rand( + shape=[batch_size, num_qo_heads, head_dim_ckv], dtype=q_init_dtype ) - q_pe = torch.zeros( - batch_size, num_qo_heads, head_dim_kpe, dtype=q_init_dtype, device="cuda" + q_pe = paddle.zeros( + shape=[batch_size, num_qo_heads, head_dim_kpe], dtype=q_init_dtype ) - q = torch.cat([q_nope, q_pe], dim=2) - + q = paddle.concat(x=[q_nope, q_pe], axis=2) if args.verbose >= 2: - print(f"[VVERBOSE] {q_nope.shape = }") - print(f"[VVERBOSE] {q_pe.shape = }") - print(f"[VVERBOSE] {q.shape = }") - - # Create KV cache + print(f"[VVERBOSE] q_nope.shape = {tuple(q_nope.shape)!r}") + print(f"[VVERBOSE] q_pe.shape = {tuple(q_pe.shape)!r}") + print(f"[VVERBOSE] q.shape = {tuple(q.shape)!r}") num_pages_per_seq = (s_kv + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size - - # Now initialize the page tables - block_tables = torch.tensor( - [ - [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + block_tables = paddle.to_tensor( + data=[ + [(k + i * num_pages_per_seq) for k in range(num_pages_per_seq)] for i in range(batch_size) ], - dtype=torch.int, - device=device, + dtype="int32", + place=device, ) - if args.verbose >= 2: - print(f"[VVERBOSE] {num_pages_per_seq = }") - print(f"[VVERBOSE] {total_num_pages = }") - print(f"[VVERBOSE] {block_tables.shape = }") - - # Initialize KV cache with appropriate shape and stride - ckv_cache_shape = ( - total_num_pages, - page_size, - head_dim_ckv, - ) - ckv_cache = torch.randn(size=ckv_cache_shape, dtype=kv_init_dtype, device=device) - - kpe_cache_shape = ( - total_num_pages, - page_size, - head_dim_kpe, - ) - kpe_cache = torch.randn(size=kpe_cache_shape, dtype=q_init_dtype, device=device) - kv_cache = torch.cat([ckv_cache, kpe_cache], dim=2) - - qo_indptr = torch.arange(0, batch_size + 1, device=device).int() + print(f"[VVERBOSE] num_pages_per_seq = {num_pages_per_seq!r}") + print(f"[VVERBOSE] total_num_pages = {total_num_pages!r}") + print(f"[VVERBOSE] block_tables.shape = {tuple(block_tables.shape)!r}") + ckv_cache_shape = total_num_pages, page_size, head_dim_ckv + ckv_cache = paddle.randn(shape=ckv_cache_shape, dtype=kv_init_dtype) + kpe_cache_shape = total_num_pages, page_size, head_dim_kpe + kpe_cache = paddle.randn(shape=kpe_cache_shape, dtype=q_init_dtype) + kv_cache = paddle.concat(x=[ckv_cache, kpe_cache], axis=2) + qo_indptr = paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") kv_indptr = ( - torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum( - (actual_seq_lens_kv.flatten() + page_size - 1) // page_size, dim=0 + paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum( + x=(actual_seq_lens_kv.flatten() + page_size - 1) // page_size, + axis=0, ), ] ) - .int() + 
.astype(dtype="int32") .to(device) ) - - # kv_indices[-1] is the total number of actual pages - kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32) + kv_indices = paddle.zeros(shape=kv_indptr[-1], dtype="int32") for i in range(len(kv_indptr) - 1): start_idx = kv_indptr[i] end_idx = kv_indptr[i + 1] - kv_indices[start_idx:end_idx] = torch.arange( - i * num_pages_per_seq, - i * num_pages_per_seq + (end_idx - start_idx), - device=device, + kv_indices[start_idx:end_idx] = paddle.arange( + start=i * num_pages_per_seq, + end=i * num_pages_per_seq + (end_idx - start_idx), ) - - sm_scale = 1.0 / ((head_dim_ckv + head_dim_kpe) ** 0.5) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - + sm_scale = 1.0 / (head_dim_ckv + head_dim_kpe) ** 0.5 + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") if args.verbose >= 2: - print(f"[VVERBOSE] {ckv_cache.shape = }") - print(f"[VVERBOSE] {kpe_cache.shape = }") - print(f"[VVERBOSE] {kv_cache.shape = }") - print(f"[VVERBOSE] {qo_indptr.shape = }") - print(f"[VVERBOSE] {kv_indptr.shape = }") - print(f"[VVERBOSE] {kv_indices.shape = }") - print(f"[VVERBOSE] {actual_seq_lens_kv.shape = }") - print(f"[VVERBOSE] {sm_scale = }") - print(f"[VVERBOSE] {workspace_buffer.shape = }") - - # Create wrapper + print(f"[VVERBOSE] ckv_cache.shape = {tuple(ckv_cache.shape)!r}") + print(f"[VVERBOSE] kpe_cache.shape = {tuple(kpe_cache.shape)!r}") + print(f"[VVERBOSE] kv_cache.shape = {tuple(kv_cache.shape)!r}") + print(f"[VVERBOSE] qo_indptr.shape = {tuple(qo_indptr.shape)!r}") + print(f"[VVERBOSE] kv_indptr.shape = {tuple(kv_indptr.shape)!r}") + print(f"[VVERBOSE] kv_indices.shape = {tuple(kv_indices.shape)!r}") + print( + f"[VVERBOSE] actual_seq_lens_kv.shape = {tuple(actual_seq_lens_kv.shape)!r}" + ) + print(f"[VVERBOSE] sm_scale = {sm_scale!r}") + print(f"[VVERBOSE] workspace_buffer.shape = {tuple(workspace_buffer.shape)!r}") if "fa2" in backends: fi_fa2_mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( float_workspace_buffer=workspace_buffer, @@ -1661,12 +1471,11 @@ def testBatchMLAPagedAttentionWrapper(args): q_data_type=q_dtype, kv_data_type=kv_dtype, ) - - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if q_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: q = q.to(q_dtype) q_pe = q_pe.to(q_dtype) q_nope = q_nope.to(q_dtype) - if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if kv_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: ckv_cache = ckv_cache.to(kv_dtype) kpe_cache = kpe_cache.to(kv_dtype) kv_cache = kv_cache.to(kv_dtype) @@ -1678,10 +1487,10 @@ def run_backend_wrapper(backend): ) if backend == "trtllm-gen-native": return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( - query=q.unsqueeze(1), - kv_cache=kv_cache.unsqueeze(1), + query=q.unsqueeze(axis=1), + kv_cache=kv_cache.unsqueeze(axis=1), workspace_buffer=workspace_buffer, - qk_nope_head_dim=128, # To-do: Why?? 
+ qk_nope_head_dim=128, kv_lora_rank=head_dim_ckv, qk_rope_head_dim=head_dim_kpe, block_tables=block_tables, @@ -1700,8 +1509,6 @@ def run_backend_wrapper(backend): has_reference_output = True else: has_reference_output = False - - # Iterate over each backend: for cur_backend in backends: if run_refcheck: outputs[cur_backend] = run_backend_wrapper(cur_backend).detach() @@ -1726,20 +1533,18 @@ def run_backend_wrapper(backend): l2_flush_device=device, sleep_after_run=False, ) - - # Perform reference check tested_backends = list(outputs.keys()) tested_outputs = list(outputs.values()) if len(tested_backends) > 1: if run_refcheck and has_reference_output: - if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - reference_output = reference_output.to(torch.float32) - tested_outputs = [output.to(torch.float32) for output in tested_outputs] + if reference_output.dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: + reference_output = reference_output.to("float32") + tested_outputs = [output.to("float32") for output in tested_outputs] for i in range(len(tested_outputs)): try: - torch.testing.assert_close( - reference_output, tested_outputs[i], rtol=rtol, atol=atol - ) + assert paddle.allclose( + x=reference_output, y=tested_outputs[i], rtol=rtol, atol=atol + ).item(), "" except AssertionError as e: print( f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}" @@ -1747,45 +1552,40 @@ def run_backend_wrapper(backend): if not args.allow_output_mismatch: print(e) raise - - # Compute perf metrics res = [] for backend in backends: if len(backend_times[backend]) > 0: median_time = np.median(backend_times[backend]) std_time = np.std(backend_times[backend]) actual_seq_lens_kv_flat = actual_seq_lens_kv.flatten().to("cpu") - actual_seq_lens_q_flat = torch.ones_like( - actual_seq_lens_kv.flatten().to("cpu") + actual_seq_lens_q_flat = paddle.ones_like( + x=actual_seq_lens_kv.flatten().to("cpu") ) o_mem_bytes = ( - actual_seq_lens_q_flat.numel() + actual_seq_lens_q_flat.size * num_qo_heads * head_dim_ckv - * q_dtype.itemsize + * q_dtype.element_size() ) qkv_mem_bytes = sum( [ - _.numel() * _.element_size() + (_.size * _.element_size()) for _ in [q_nope, q_pe, ckv_cache, kpe_cache] ] ) total_mem_bytes = o_mem_bytes + qkv_mem_bytes - tb_per_sec = (total_mem_bytes / (median_time * 1e9)).item() + tb_per_sec = (total_mem_bytes / (median_time * 1000000000.0)).item() tflops_total = ( 2 - * torch.dot( - actual_seq_lens_q_flat.to(torch.float32), - actual_seq_lens_kv_flat.to(torch.float32), + * paddle.dot( + x=actual_seq_lens_q_flat.to("float32"), + y=actual_seq_lens_kv_flat.to("float32"), ) * num_qo_heads * (2 * head_dim_ckv + head_dim_kpe) ) - tflops = (tflops_total / (median_time * 1e9)).item() - + tflops = (tflops_total / (median_time * 1000000000.0)).item() print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - - # TO-Do: if args.output_path is not None: cur_res = defaultdict(str) cur_res["routine"] = args.routine diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index d6bc21c630..2629b22b65 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -1,17 +1,13 @@ -import torch +import sys + + +import paddle +from flashinfer.paddle_utils import * from flashinfer.testing.utils import set_seed -# Output columns for the test results. 
output_column_dict = { - "perf": [ - "routine", - "median_time", - "std_time", - "tflops", - "tb_per_sec", - "backend", - ], + "perf": ["routine", "median_time", "std_time", "tflops", "tb_per_sec", "backend"], "attention": [ "page_size", "batch_size", @@ -60,7 +56,6 @@ "input_dtype", "weight_dtype", "gated_act", - # CUTLASS fused MoE specific "cutlass_variant", "quantized_input", "tp_size", @@ -78,7 +73,6 @@ "repro_command", ], } - full_output_columns = ( output_column_dict["perf"] + output_column_dict["attention"] @@ -86,7 +80,6 @@ + output_column_dict["moe"] + output_column_dict["general"] ) - benchmark_apis = { "attention": [ "BatchDecodeWithPagedKVCacheWrapper", @@ -118,25 +111,27 @@ def print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec): def get_device(args): set_seed(args.random_seed) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - gpu_name = torch.cuda.get_device_name(torch.cuda.current_device()).replace(" ", "_") + device = device2str("cuda" if paddle.device.cuda.device_count() >= 1 else "cpu") + gpu_name = paddle.device.cuda.get_device_name( + device=paddle.device.get_device() + ).replace(" ", "_") if args.verbose >= 2: - print(f"[VVERBOSE] {gpu_name = }") + print(f"[VVERBOSE] gpu_name = {gpu_name!r}") return device def dtype_str_to_torch_dtype(dtype_str): if dtype_str == "bfloat16": - return torch.bfloat16 + return "bfloat16" elif dtype_str == "float16": - return torch.float16 + return "float16" elif dtype_str == "float32": - return torch.float32 + return "float32" elif dtype_str == "float64": - return torch.float64 + return "float64" elif dtype_str == "fp8_e4m3": - return torch.float8_e4m3fn + return paddle.float8_e4m3fn elif dtype_str == "fp8_e5m2": - return torch.float8_e5m2 + return paddle.float8_e5m2 else: raise ValueError(f"Unsupported dtype: {dtype_str}") diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 91315926dc..9a6d3618e7 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -1,23 +1,20 @@ +import sys + + from collections import defaultdict +import einops import numpy as np -import torch -import torch.nn.functional as F -from einops import einsum +import paddle +from flashinfer.paddle_utils import * import flashinfer -from flashinfer.testing.utils import ( - bench_gpu_time, - bench_gpu_time_with_cudagraph, - dequantize_fp8, - quantize_fp8, -) +from flashinfer.testing.utils import (bench_gpu_time, + bench_gpu_time_with_cudagraph, + dequantize_fp8, quantize_fp8) -from .flashinfer_benchmark_utils import ( - dtype_str_to_torch_dtype, - get_device, - print_perf_metrics, -) +from .flashinfer_benchmark_utils import (dtype_str_to_torch_dtype, get_device, + print_perf_metrics) def run_gemm_test(args): @@ -137,20 +134,24 @@ def parse_gemm_args(line, parser): action="store_true", help="Use 128x4 SF layout for the input and mat2.", ) - args = parser.parse_args(line) if args.verbose >= 1: - print(f"[INFO] {args = }") + print(f"[INFO] args = {args!r}") return args -def to_float8(x, dtype=torch.float8_e4m3fn): - finfo = torch.finfo(dtype) - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) +def to_float8(x, dtype=paddle.float8_e4m3fn): + finfo = paddle.finfo(dtype=dtype) + min_val, max_val = tuple( + [ + paddle.amin(x, axis=None, keepdim=False), + paddle.max(x, axis=None, keepdim=False), + ] + ) + amax = paddle.maximum(x=min_val.abs(), y=max_val.abs()).clip(min=1e-12) scale = finfo.max / amax - x_scl_sat = (x *
scale).clamp(min=finfo.min, max=finfo.max) - return x_scl_sat.to(dtype), scale.float().reciprocal() + x_scl_sat = (x * scale).clip(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.astype(dtype="float32").reciprocal() def testGemmFp8NtGroupwise(args): @@ -173,14 +174,11 @@ def testGemmFp8NtGroupwise(args): if args.verbose >= 1: print("[INFO] Running testGemmFp8NtGroupwise") print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - device = get_device(args) if args.generate_repro_command: print( f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" ) - - ## Parse input arguments backends = args.backends m = args.m n = args.n @@ -190,56 +188,43 @@ def testGemmFp8NtGroupwise(args): mma_sm = args.mma_sm is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck - out_dtype = dtype_str_to_torch_dtype(args.out_dtype) - if out_dtype not in [torch.bfloat16, torch.float16]: + if out_dtype not in ["bfloat16", "float16"]: raise ValueError(f"Unsupported output dtype: {args.out_dtype}") - ## Done parsing input arguments - if "trtllm" in backends: remove_trtllm = True print("[INFO] trtllm backend testing not supported yet") if remove_trtllm: backends.remove("trtllm") - - ## Prepare input tensors - a_val = torch.randn((m, k), dtype=torch.float, device=device) - b_val = torch.randn((n, k), dtype=torch.float, device=device) / np.sqrt(k) - + a_val = paddle.randn(shape=(m, k), dtype="float32") + b_val = paddle.randn(shape=(n, k), dtype="float32") / np.sqrt(k) if args.verbose >= 2: - print(f"[VVERBOSE] {a_val.shape = }") - print(f"[VVERBOSE] {b_val.shape = }") - + print(f"[VVERBOSE] a_val.shape = {tuple(a_val.shape)!r}") + print(f"[VVERBOSE] b_val.shape = {tuple(b_val.shape)!r}") if scale_major_mode == "K": - a_scale_shape = (m, k // tile_size) - b_scale_shape = (n // tile_size, k // tile_size) + a_scale_shape = m, k // tile_size + b_scale_shape = n // tile_size, k // tile_size else: - a_scale_shape = (k // tile_size, m) - b_scale_shape = (k // tile_size, n // tile_size) - - a_tile_shape = (1, tile_size) - b_tile_shape = (tile_size, tile_size) - + a_scale_shape = k // tile_size, m + b_scale_shape = k // tile_size, n // tile_size + a_tile_shape = 1, tile_size + b_tile_shape = tile_size, tile_size a_fp8, a_scale = quantize_fp8(a_val, a_scale_shape, a_tile_shape, scale_major_mode) b_fp8, b_scale = quantize_fp8(b_val, b_scale_shape, b_tile_shape, scale_major_mode) - if "trtllm" in backends: - a_scale_shape_trtllm = (m, k // tile_size) - b_scale_shape_trtllm = (k // tile_size, n // tile_size) - + a_scale_shape_trtllm = m, k // tile_size + b_scale_shape_trtllm = k // tile_size, n // tile_size a_fp8_trtllm, a_scale_trtllm = quantize_fp8( a_val, a_scale_shape_trtllm, a_tile_shape, "K" ) b_fp8_trtllm, b_scale_trtllm = quantize_fp8( b_val, b_scale_shape_trtllm, b_tile_shape, "MN" ) - if args.verbose >= 2: - print(f"[VVERBOSE] {a_fp8.shape = }") - print(f"[VVERBOSE] {b_fp8.shape = }") - print(f"[VVERBOSE] {a_scale.shape = }") - print(f"[VVERBOSE] {b_scale.shape = }") - + print(f"[VVERBOSE] a_fp8.shape = {tuple(a_fp8.shape)!r}") + print(f"[VVERBOSE] b_fp8.shape = {tuple(b_fp8.shape)!r}") + print(f"[VVERBOSE] a_scale.shape = {tuple(a_scale.shape)!r}") + print(f"[VVERBOSE] b_scale.shape = {tuple(b_scale.shape)!r}") a_dequant = dequantize_fp8(a_fp8, a_scale, scale_major_mode) b_dequant = dequantize_fp8(b_fp8, b_scale, scale_major_mode) @@ -271,10 +256,10 @@ def run_backend(backend): has_reference_output = False if run_refcheck: - reference_output = 
einsum(a_dequant, b_dequant, "m k, n k -> m n").to(out_dtype) + reference_output = einops.einsum(a_dequant, b_dequant, "m k, n k -> m n").to( + out_dtype + ) has_reference_output = True - - # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} outputs = {} for cur_backend in backends: @@ -288,7 +273,7 @@ def run_backend(backend): l2_flush=True, l2_flush_size_mb=256, l2_flush_device=device, - sleep_after_run=True, # GEMMs are very MMA-heavy, so prefer sleep to reduce throttling. + sleep_after_run=True, ) else: backend_times[cur_backend] = bench_gpu_time( @@ -298,18 +283,17 @@ def run_backend(backend): l2_flush=True, l2_flush_size_mb=256, l2_flush_device=device, - sleep_after_run=True, # GEMMs are very MMA-heavy, so prefer sleep to reduce throttling. + sleep_after_run=True, ) - tested_backends = list(outputs.keys()) tested_outputs = list(outputs.values()) if len(tested_backends) > 0: if run_refcheck and has_reference_output: for i in range(len(tested_backends)): try: - torch.testing.assert_close( - reference_output, tested_outputs[i], rtol=1e-2, atol=1e-2 - ) + assert paddle.allclose( + x=reference_output, y=tested_outputs[i], rtol=0.01, atol=0.01 + ).item(), "" except AssertionError as e: print( f"[ERROR] Output tensor mismatch from backend {tested_backends[i]}" @@ -317,21 +301,18 @@ def run_backend(backend): if not args.allow_output_mismatch: print(e) raise - res = [] for backend in backends: if len(backend_times[backend]) > 0: median_time = np.median(backend_times[backend]) std_time = np.std(backend_times[backend]) - problem_flops = 2 * m * n * k - problem_bytes = (m * k + n * k) * torch.float8_e4m3fn.itemsize + ( - m * n - ) * out_dtype.itemsize - tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec - tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + problem_bytes = ( + m * k + n * k + ) * paddle.float8_e4m3fn.itemsize + m * n * out_dtype.element_size() + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: cur_res = defaultdict(str) cur_res["routine"] = args.routine @@ -372,15 +353,12 @@ def testGroupGemmFp8NtGroupwise(args): if args.verbose >= 1: print("[INFO] Running testGroupGemmFp8NtGroupwise") print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - device = get_device(args) if args.generate_repro_command: print( f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" ) - - ## Parse input arguments - backends = ["cutlass"] # Cutlass is currently the only supported backend + backends = ["cutlass"] m = args.m n = args.n k = args.k @@ -390,45 +368,33 @@ def testGroupGemmFp8NtGroupwise(args): mma_sm = args.mma_sm is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck - out_dtype = dtype_str_to_torch_dtype(args.out_dtype) - if out_dtype not in [torch.bfloat16, torch.float16]: + if out_dtype not in ["bfloat16", "float16"]: raise ValueError(f"Unsupported output dtype: {args.out_dtype}") - ## Done parsing input arguments - ## Prepare input tensors - a_val = torch.randn((group_size * m, k), dtype=torch.float, device="cuda") - b_val = torch.randn((group_size, n, k), dtype=torch.float, device="cuda") / np.sqrt( - k - ) - + a_val = paddle.randn(shape=(group_size * m, k), dtype="float32") + b_val = paddle.randn(shape=(group_size, n, k), dtype="float32") / np.sqrt(k) if args.verbose >= 2: - print(f"[VVERBOSE] 
{a_val.shape = }") - print(f"[VVERBOSE] {b_val.shape = }") - + print(f"[VVERBOSE] a_val.shape = {tuple(a_val.shape)!r}") + print(f"[VVERBOSE] b_val.shape = {tuple(b_val.shape)!r}") if scale_major_mode == "K": - a_scale_shape = (group_size * m, k // tile_size) - b_scale_shape = (group_size, n // tile_size, k // tile_size) + a_scale_shape = group_size * m, k // tile_size + b_scale_shape = group_size, n // tile_size, k // tile_size else: - a_scale_shape = (k // tile_size, m * group_size) - b_scale_shape = (group_size, k // tile_size, n // tile_size) - - a_tile_shape = (1, tile_size) - b_tile_shape = (1, tile_size, tile_size) - + a_scale_shape = k // tile_size, m * group_size + b_scale_shape = group_size, k // tile_size, n // tile_size + a_tile_shape = 1, tile_size + b_tile_shape = 1, tile_size, tile_size a_fp8, a_scale = quantize_fp8(a_val, a_scale_shape, a_tile_shape, scale_major_mode) b_fp8, b_scale = quantize_fp8(b_val, b_scale_shape, b_tile_shape, scale_major_mode) - a_dequant = dequantize_fp8(a_fp8, a_scale, scale_major_mode) b_dequant = dequantize_fp8(b_fp8, b_scale, scale_major_mode) - - m_indptr = torch.arange(0, group_size + 1, dtype=torch.int32, device="cuda") * m - + m_indptr = paddle.arange(start=0, end=group_size + 1, dtype="int32") * m if args.verbose >= 2: - print(f"[VVERBOSE] {a_fp8.shape = }") - print(f"[VVERBOSE] {b_fp8.shape = }") - print(f"[VVERBOSE] {a_scale.shape = }") - print(f"[VVERBOSE] {b_scale.shape = }") - print(f"[VVERBOSE] {m_indptr.shape = }") + print(f"[VVERBOSE] a_fp8.shape = {tuple(a_fp8.shape)!r}") + print(f"[VVERBOSE] b_fp8.shape = {tuple(b_fp8.shape)!r}") + print(f"[VVERBOSE] a_scale.shape = {tuple(a_scale.shape)!r}") + print(f"[VVERBOSE] b_scale.shape = {tuple(b_scale.shape)!r}") + print(f"[VVERBOSE] m_indptr.shape = {tuple(m_indptr.shape)!r}") def run_backend(backend): if backend == "cutlass": @@ -448,15 +414,13 @@ def run_backend(backend): has_reference_output = False if run_refcheck: reference_output = ( - einsum( + einops.einsum( a_dequant.view((group_size, m, k)), b_dequant, "b m k, b n k -> b m n" ) .view((group_size * m, n)) .to(out_dtype) ) has_reference_output = True - - # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} outputs = {} for cur_backend in backends: @@ -470,7 +434,7 @@ def run_backend(backend): l2_flush=True, l2_flush_size_mb=256, l2_flush_device=device, - sleep_after_run=True, # GEMMs are very MMA-heavy, so prefer sleep to reduce throttling. + sleep_after_run=True, ) else: backend_times[cur_backend] = bench_gpu_time( @@ -480,18 +444,17 @@ def run_backend(backend): l2_flush=True, l2_flush_size_mb=256, l2_flush_device=device, - sleep_after_run=True, # GEMMs are very MMA-heavy, so prefer sleep to reduce throttling. 
+ sleep_after_run=True, ) - tested_backends = list(outputs.keys()) tested_outputs = list(outputs.values()) if len(tested_backends) > 0: if run_refcheck and has_reference_output: for i in range(len(tested_backends)): try: - torch.testing.assert_close( - reference_output, tested_outputs[i], rtol=1e-2, atol=1e-2 - ) + assert paddle.allclose( + x=reference_output, y=tested_outputs[i], rtol=0.01, atol=0.01 + ).item(), "" except AssertionError as e: print( f"[ERROR] Output tensor mismatch from backend {tested_backends[i]}" @@ -499,7 +462,6 @@ def run_backend(backend): if not args.allow_output_mismatch: print(e) raise - res = [] for backend in backends: if len(backend_times[backend]) > 0: @@ -507,13 +469,12 @@ def run_backend(backend): std_time = np.std(backend_times[backend]) problem_flops = 2 * m * n * k * group_size problem_bytes = ( - group_size * m * k + group_size * n * k - ) * torch.float8_e4m3fn.itemsize + (group_size * m * n) * out_dtype.itemsize - tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec - tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec - + (group_size * m * k + group_size * n * k) * paddle.float8_e4m3fn.itemsize + + group_size * m * n * out_dtype.element_size() + ) + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: cur_res = defaultdict(str) cur_res["routine"] = args.routine @@ -555,14 +516,11 @@ def testBmmFp8(args): if args.verbose >= 1: print("[INFO] Running testBmmFp8") print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - device = get_device(args) if args.generate_repro_command: print( f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" ) - - ## Parse input arguments backends = args.backends batch_size = args.batch_size m = args.m @@ -574,44 +532,38 @@ def testBmmFp8(args): backends = args.backends is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck - input_dtype = dtype_str_to_torch_dtype(args.input_dtype) - if input_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + if input_dtype not in [paddle.float8_e4m3fn, paddle.float8_e5m2]: raise ValueError( f"Unsupported input dtype: {input_dtype}. Supported dtypes are fp8_e4m3 and fp8_e5m2." ) - mat2_dtype = dtype_str_to_torch_dtype(args.mat2_dtype) - if mat2_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + if mat2_dtype not in [paddle.float8_e4m3fn, paddle.float8_e5m2]: raise ValueError( f"Unsupported mat2 dtype: {mat2_dtype}. Supported dtypes are fp8_e4m3 and fp8_e5m2." ) - res_dtype = dtype_str_to_torch_dtype(args.out_dtype) - if res_dtype not in [torch.bfloat16, torch.float16]: + if res_dtype not in ["bfloat16", "float16"]: raise ValueError( f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16." 
) - ## Done parsing input arguments - - ## Prepare input tensors - input = torch.randn([batch_size, m, k], device=device, dtype=torch.bfloat16) + input = paddle.randn(shape=[batch_size, m, k], dtype="bfloat16") input_fp8, input_inv_s = to_float8(input, dtype=input_dtype) - - mat2 = torch.randn( - [batch_size, n, k], device=device, dtype=torch.bfloat16 - ).transpose(-2, -1) + mat2 = paddle.randn(shape=[batch_size, n, k], dtype="bfloat16").transpose( + perm=dim2perm( + paddle.randn(shape=[batch_size, n, k], dtype="bfloat16").ndim, -2, -1 + ) + ) mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype) - if args.verbose >= 2: - print(f"[VVERBOSE] {input_fp8.shape = }") - print(f"[VVERBOSE] {input_fp8.dtype = }") - print(f"[VVERBOSE] {mat2_fp8.shape = }") - print(f"[VVERBOSE] {mat2_fp8.dtype = }") - print(f"[VVERBOSE] {input_inv_s = }") - print(f"[VVERBOSE] {input_inv_s.dtype = }") - print(f"[VVERBOSE] {mat2_inv_s = }") - print(f"[VVERBOSE] {mat2_inv_s.dtype = }") + print(f"[VVERBOSE] input_fp8.shape = {tuple(input_fp8.shape)!r}") + print(f"[VVERBOSE] input_fp8.dtype = {input_fp8.dtype!r}") + print(f"[VVERBOSE] mat2_fp8.shape = {tuple(mat2_fp8.shape)!r}") + print(f"[VVERBOSE] mat2_fp8.dtype = {mat2_fp8.dtype!r}") + print(f"[VVERBOSE] input_inv_s = {input_inv_s!r}") + print(f"[VVERBOSE] input_inv_s.dtype = {input_inv_s.dtype!r}") + print(f"[VVERBOSE] mat2_inv_s = {mat2_inv_s!r}") + print(f"[VVERBOSE] mat2_inv_s.dtype = {mat2_inv_s.dtype!r}") def run_backend(backend): if backend in ["cudnn", "cublas", "cutlass"]: @@ -628,10 +580,8 @@ def run_backend(backend): has_reference_output = False if run_refcheck: - reference_output = torch.bmm(input, mat2) + reference_output = paddle.bmm(x=input, y=mat2) has_reference_output = True - - # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} outputs = {} for cur_backend in backends: @@ -658,23 +608,22 @@ def run_backend(backend): l2_flush_device=device, sleep_after_run=True, ) - tested_backends = list(outputs.keys()) tested_outputs = list(outputs.values()) if len(tested_backends) > 0: if run_refcheck and has_reference_output: - if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + if reference_output.dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: print( "[INFO] Reference output is FP8. Converting to float32 for reference check." 
) - reference_output = reference_output.to(torch.float32) - tested_outputs = [output.to(torch.float32) for output in tested_outputs] + reference_output = reference_output.to("float32") + tested_outputs = [output.to("float32") for output in tested_outputs] for i in range(len(tested_backends)): try: - cos_sim = F.cosine_similarity( - reference_output.reshape(-1), - tested_outputs[i].reshape(-1), - dim=0, + cos_sim = paddle.nn.functional.cosine_similarity( + x1=reference_output.reshape(-1), + x2=tested_outputs[i].reshape(-1), + axis=0, ) assert cos_sim > 0.99 except AssertionError as e: @@ -684,7 +633,6 @@ def run_backend(backend): if not args.allow_output_mismatch: print(e) raise - res = [] for backend in backends: if len(backend_times[backend]) > 0: @@ -692,14 +640,13 @@ def run_backend(backend): std_time = np.std(backend_times[backend]) problem_flops = 2 * m * n * k * batch_size problem_bytes = ( - m * k * input_dtype.itemsize - + n * k * mat2_dtype.itemsize - + m * n * res_dtype.itemsize + m * k * input_dtype.element_size() + + n * k * mat2_dtype.element_size() + + m * n * res_dtype.element_size() ) - tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec - tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: cur_res = defaultdict(str) cur_res["batch_size"] = batch_size @@ -740,14 +687,11 @@ def testMmFp4(args): if args.verbose >= 1: print("[INFO] Running testMmFp4") print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - device = get_device(args) if args.generate_repro_command: print( f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" ) - - ## Parse input arguments backends = args.backends m = args.m n = args.n @@ -757,17 +701,14 @@ def testMmFp4(args): is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck use_128x4_sf_layout = args.use_128x4_sf_layout - res_dtype = dtype_str_to_torch_dtype(args.out_dtype) - if res_dtype not in [torch.bfloat16, torch.float16]: + if res_dtype not in ["bfloat16", "float16"]: raise ValueError( f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16." ) - ## Done parsing input arguments - if "trtllm" in backends: remove_trtllm = False - if res_dtype == torch.float16: + if res_dtype == "float16": print("[INFO] trtllm backend does not suppot float16 output") remove_trtllm = True if remove_trtllm: @@ -786,22 +727,18 @@ def testMmFp4(args): remove_cudnn = True if remove_cudnn: backends.remove("cudnn") - if len(backends) == 0: print("[ERROR] No backends to test. 
Exiting.") return - - input = torch.randn([m, k], device=device, dtype=torch.bfloat16) - mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16) + input = paddle.randn(shape=[m, k], dtype="bfloat16") + mat2 = paddle.randn(shape=[n, k], dtype="bfloat16") a_sf_layout = ( flashinfer.SfLayout.layout_128x4 if use_128x4_sf_layout else flashinfer.SfLayout.layout_8x4 ) - - global_sf_input = (448 * 6) / input.float().abs().nan_to_num().max() - global_sf_mat2 = (448 * 6) / mat2.float().abs().nan_to_num().max() - + global_sf_input = 448 * 6 / input.astype(dtype="float32").abs().nan_to_num()._max() + global_sf_mat2 = 448 * 6 / mat2.astype(dtype="float32").abs().nan_to_num()._max() input_fp4, input_inv_s = flashinfer.nvfp4_quantize( input, global_sf_input, sfLayout=a_sf_layout, do_shuffle=False ) @@ -818,15 +755,12 @@ def testMmFp4(args): sfLayout=flashinfer.SfLayout.layout_128x4, do_shuffle=True, ) - if args.verbose >= 2: - print(f"[VVERBOSE] {input_fp4.shape = }") - print(f"[VVERBOSE] {input_fp4.dtype = }") - print(f"[VVERBOSE] {mat2_fp4.shape = }") - print(f"[VVERBOSE] {mat2_fp4.dtype = }") - + print(f"[VVERBOSE] input_fp4.shape = {tuple(input_fp4.shape)!r}") + print(f"[VVERBOSE] input_fp4.dtype = {input_fp4.dtype!r}") + print(f"[VVERBOSE] mat2_fp4.shape = {tuple(mat2_fp4.shape)!r}") + print(f"[VVERBOSE] mat2_fp4.dtype = {mat2_fp4.dtype!r}") alpha = 1.0 / (global_sf_input * global_sf_mat2) - # res = torch.empty([m, n], device="cuda", dtype=res_dtype) def run_backend(backend): if backend in ["cudnn", "trtllm", "cutlass"]: @@ -837,7 +771,7 @@ def run_backend(backend): b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T, alpha=alpha, out_dtype=res_dtype, - block_size=16, # Only supports 16 + block_size=16, use_8x4_sf_layout=not use_128x4_sf_layout, backend=backend, ) @@ -846,10 +780,8 @@ def run_backend(backend): has_reference_output = False if run_refcheck: - reference_output = torch.mm(input, mat2.T) + reference_output = paddle.mm(input=input, mat2=mat2.T) has_reference_output = True - - # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} outputs = {} for cur_backend in backends: @@ -876,17 +808,16 @@ def run_backend(backend): l2_flush_device=device, sleep_after_run=True, ) - tested_backends = list(outputs.keys()) tested_outputs = list(outputs.values()) if len(tested_backends) > 0: if run_refcheck and has_reference_output: for i in range(len(tested_backends)): try: - cos_sim = F.cosine_similarity( - reference_output.reshape(-1), - tested_outputs[i].reshape(-1), - dim=0, + cos_sim = paddle.nn.functional.cosine_similarity( + x1=reference_output.reshape(-1), + x2=tested_outputs[i].reshape(-1), + axis=0, ) assert cos_sim > 0.97 except AssertionError as e: @@ -896,20 +827,16 @@ def run_backend(backend): if not args.allow_output_mismatch: print(e) raise - res = [] for backend in backends: if len(backend_times[backend]) > 0: median_time = np.median(backend_times[backend]) std_time = np.std(backend_times[backend]) problem_flops = 2 * m * n * k - problem_bytes = ( - m * k * 0.5 + n * k * 0.5 + m * n * res_dtype.itemsize - ) # 0.5 for fp4 - tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec - tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + problem_bytes = m * k * 0.5 + n * k * 0.5 + m * n * res_dtype.element_size() + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not 
None: cur_res = defaultdict(str) cur_res["routine"] = args.routine diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index de1c49c98e..98f91e6fa3 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -1,29 +1,26 @@ +import sys + + from collections import defaultdict from typing import Optional import numpy as np -import torch +import paddle +from flashinfer.paddle_utils import * import flashinfer -from flashinfer.fused_moe import ( - WeightLayout, - trtllm_fp4_block_scale_moe, - trtllm_fp8_block_scale_moe, - trtllm_fp8_per_tensor_scale_moe, - cutlass_fused_moe, - convert_to_block_layout, -) from flashinfer import fp4_quantize, shuffle_matrix_a -from flashinfer.testing.utils import ( - bench_gpu_time, - bench_gpu_time_with_cudagraph, -) +from flashinfer.autotuner import autotune +from flashinfer.fused_moe import (WeightLayout, convert_to_block_layout, + cutlass_fused_moe, + trtllm_fp4_block_scale_moe, + trtllm_fp8_block_scale_moe, + trtllm_fp8_per_tensor_scale_moe) +from flashinfer.testing.utils import (bench_gpu_time, + bench_gpu_time_with_cudagraph) -from .flashinfer_benchmark_utils import ( - dtype_str_to_torch_dtype, - get_device, - print_perf_metrics, -) +from .flashinfer_benchmark_utils import (dtype_str_to_torch_dtype, get_device, + print_perf_metrics) def run_moe_test(args): @@ -127,16 +124,8 @@ def parse_moe_args(line, parser): type=str, required=False, default="deepseek_v3", - choices=[ - "renormalize", - "deepseek_v3", - "llama4", - "renormalize_naive", - "topk", - ], - help=( - "Routing method: renormalize | deepseek_v3 | llama4 | renormalize_naive | topk." - ), + choices=["renormalize", "deepseek_v3", "llama4", "renormalize_naive", "topk"], + help="Routing method: renormalize | deepseek_v3 | llama4 | renormalize_naive | topk.", ) parser.add_argument( "--use_shuffled_weight", @@ -186,8 +175,12 @@ def parse_moe_args(line, parser): choices=["swiglu", "geglu"], help="Type of gated activation function: swiglu | geglu.", ) - - # CUTLASS fused MoE specific + parser.add_argument( + "--autotune", + action="store_true", + default=False, + help="Enable autotuner warmup for supported routines (trtllm_fp4_block_scale_moe and cutlass_fused_moe).", + ) parser.add_argument( "--cutlass_variant", type=str, @@ -230,10 +223,7 @@ def parse_moe_args(line, parser): default=0, help="Expert parallel rank for cutlass_fused_moe.", ) - args = parser.parse_args(line) - - # Normalize routing method (map string to internal int expected by kernels) routing_method_name_to_type = { "renormalize": 1, "deepseek_v3": 2, @@ -242,16 +232,10 @@ def parse_moe_args(line, parser): "topk": 5, } args.routing_method_type = routing_method_name_to_type[args.routing_method] - - # Normalize gated act type (map string to internal int expected by kernels) - gated_act_name_to_type = { - "swiglu": 0, - "geglu": 1, - } + gated_act_name_to_type = {"swiglu": 0, "geglu": 1} args.gated_act_type = gated_act_name_to_type[args.gated_act] - if args.verbose >= 1: - print(f"[INFO] {args = }") + print(f"[INFO] args = {args!r}") return args @@ -262,9 +246,9 @@ def create_trtllm_moe_test_data( num_experts: int, routing_method_type: int, use_routing_bias: bool, - input_dtype: torch.dtype, - weight_dtype: torch.dtype, - device: torch.device, + input_dtype: paddle.dtype, + weight_dtype: paddle.dtype, + device: str, moe_kernel_type: str = "fp8_per_tensor", ): """ @@ -278,71 +262,45 @@ def create_trtllm_moe_test_data( Returns: Tuple of tensors needed for trtllm fused MoE computation """ - # Create routing 
logits - dtype depends on both routing method AND MOE kernel type - # Different MOE kernels have different routing_logits dtype requirements: - if moe_kernel_type == "fp8_block_scale": - # FP8 block scale MOE always expects float32 routing logits (line 333 in kernel_launcher.cu) - routing_logits = torch.randn( - (num_tokens, num_experts), device=device, dtype=torch.float32 - ) + routing_logits = paddle.randn(shape=(num_tokens, num_experts), dtype="float32") elif moe_kernel_type == "fp8_per_tensor": - # FP8 per-tensor MOE dtype depends on use_routing_scales_on_input parameter - # For Llama4: use_routing_scales_on_input=True -> bfloat16 - # For others: use_routing_scales_on_input=False -> float32 - if routing_method_type == 3: # Llama4 uses routing scales on input - routing_logits = torch.randn( - (num_tokens, num_experts), device=device, dtype=torch.bfloat16 + if routing_method_type == 3: + routing_logits = paddle.randn( + shape=(num_tokens, num_experts), dtype="bfloat16" ) else: - routing_logits = torch.randn( - (num_tokens, num_experts), device=device, dtype=torch.float32 + routing_logits = paddle.randn( + shape=(num_tokens, num_experts), dtype="float32" ) elif moe_kernel_type == "fp4_block_scale": - # FP4 block scale MOE follows the test pattern: float32 for DeepSeekV3, bfloat16 for others - if routing_method_type == 2: # DeepSeekV3 - uses float32 - routing_logits = torch.randn( - (num_tokens, num_experts), device=device, dtype=torch.float32 + if routing_method_type == 2: + routing_logits = paddle.randn( + shape=(num_tokens, num_experts), dtype="float32" ) - else: # All other routing methods (Renormalize, RenormalizeNaive, Llama4) - use bfloat16 - routing_logits = torch.randn( - (num_tokens, num_experts), device=device, dtype=torch.bfloat16 + else: + routing_logits = paddle.randn( + shape=(num_tokens, num_experts), dtype="bfloat16" ) else: raise ValueError(f"Unknown MOE kernel type: {moe_kernel_type}") - - # Create routing bias if needed - always bfloat16 routing_bias = None if use_routing_bias: - routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16) - - # Create hidden states - always start with bfloat16 for proper quantization - hidden_states = 2 * torch.randn( - (num_tokens, hidden_size), device=device, dtype=torch.bfloat16 - ) - - # Create weights - always start with bfloat16 for proper quantization - gemm1_weights = torch.randn( - (num_experts, 2 * intermediate_size, hidden_size), - device=device, - dtype=torch.bfloat16, + routing_bias = paddle.randn(shape=num_experts, dtype="bfloat16") + hidden_states = 2 * paddle.randn(shape=(num_tokens, hidden_size), dtype="bfloat16") + gemm1_weights = paddle.randn( + shape=(num_experts, 2 * intermediate_size, hidden_size), dtype="bfloat16" ) - gemm2_weights = torch.randn( - (num_experts, hidden_size, intermediate_size), - device=device, - dtype=torch.bfloat16, + gemm2_weights = paddle.randn( + shape=(num_experts, hidden_size, intermediate_size), dtype="bfloat16" ) - - return routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights + return (routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights) def calculate_fp4_global_scale_factor(tensor): """Calculate global scale factor for FP4 quantization.""" - # Calculate as a tensor on the same device - # Using the same formula as in test files: FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax - tensor_amax = tensor.abs().max().to(torch.float32) - # FLOAT8_E4M3_MAX = 448, FLOAT4_E2M1_MAX = 6 - global_scale = (448.0 * 6.0) / tensor_amax + 
tensor_amax = tensor.abs()._max().to("float32") + global_scale = 448.0 * 6.0 / tensor_amax return global_scale @@ -352,12 +310,9 @@ def quant_fp4_simple(a, a_global_sf, use_ue8m0=False, is_sf_swizzled_layout=True In production, use the actual fp4_quantize function. """ sf_vec_size = 16 - - # Use the actual fp4_quantize function from flashinfer a_fp4, a_sf = fp4_quantize( a, a_global_sf, sf_vec_size, use_ue8m0, is_sf_swizzled_layout ) - return a_fp4, a_sf, a_global_sf @@ -369,7 +324,6 @@ def quant_fp4_batches_simple( sfs = [] global_sfs = [] for i in range(num_experts): - # Calculate global scale factor (returns tensor) a_global_sf = calculate_fp4_global_scale_factor(a[i]) a_fp4, a_sf, _ = quant_fp4_simple( a[i], a_global_sf, use_ue8m0, is_sf_swizzled_layout @@ -377,11 +331,9 @@ def quant_fp4_batches_simple( quant_a.append(a_fp4) sfs.append(a_sf) global_sfs.append(a_global_sf) - - result_quant_a = torch.stack(quant_a) - result_sfs = torch.stack(sfs) - result_global_sfs = torch.stack(global_sfs) - + result_quant_a = paddle.stack(x=quant_a) + result_sfs = paddle.stack(x=sfs) + result_global_sfs = paddle.stack(x=global_sfs) return result_quant_a, result_sfs, result_global_sfs @@ -404,14 +356,11 @@ def calculate_moe_tflops( For each token, we only compute for top_k experts. """ - # FLOPS per token per expert (base calculation) flops_per_token_per_expert = ( - 2 * hidden_size * 2 * intermediate_size # First GEMM - + 2 * intermediate_size * hidden_size # Second GEMM + 2 * hidden_size * 2 * intermediate_size + 2 * intermediate_size * hidden_size ) - total_flops = num_tokens * top_k * flops_per_token_per_expert - tflops = total_flops / (time_ms * 1e-3) / 1e12 # Convert to TFLOPS + tflops = total_flops / (time_ms * 0.001) / 1000000000000.0 return tflops @@ -422,11 +371,11 @@ def calculate_moe_bandwidth( num_experts: int, top_k: int, time_ms: float, - input_dtype: torch.dtype, - weight_dtype: torch.dtype, + input_dtype: paddle.dtype, + weight_dtype: paddle.dtype, input_format: Optional[str] = None, weight_format: Optional[str] = None, - routing_logits_dtype: Optional[torch.dtype] = torch.float32, + routing_logits_dtype: Optional[paddle.dtype] = "float32", active_experts: Optional[int] = None, ) -> float: """ @@ -438,62 +387,57 @@ def calculate_moe_bandwidth( routing_logits_dtype: Dtype for routing logits memory accounting (default float32) """ - # Get effective byte sizes - def get_effective_bytes(dtype: torch.dtype, fmt: Optional[str]) -> float: + def get_effective_bytes(dtype: paddle.dtype, fmt: Optional[str]) -> float: if fmt == "fp4": return 0.5 if fmt == "fp8": return 1.0 - return dtype.itemsize + return dtype.element_size() input_bytes_per_element = get_effective_bytes(input_dtype, input_format) weight_bytes_per_element = get_effective_bytes(weight_dtype, weight_format) - - # Input memory: hidden states + routing logits - # Note: routing logits dtype depends on kernel; pass in when known, default float32; None means excluded routing_logits_bytes = ( - 0 if routing_logits_dtype is None else routing_logits_dtype.itemsize + 0 if routing_logits_dtype is None else routing_logits_dtype.element_size() ) input_bytes = ( - # Count hidden states once; kernels typically reuse inputs for multiple experts num_tokens * hidden_size * input_bytes_per_element + num_tokens * num_experts * routing_logits_bytes ) - - # Weight memory (reuse weights across tokens by grouping tokens per expert) - # Assume each active expert's weights are read once per run. 
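For reference, the cost model implemented by calculate_moe_tflops and calculate_moe_bandwidth above can be checked with a standalone sketch; the sizes and the measured time below are illustrative placeholders, not benchmark defaults.

# Worked example of the MoE FLOPs model above (illustrative numbers only).
hidden_size, intermediate_size = 4096, 14336
num_tokens, top_k = 128, 2
time_ms = 0.25  # hypothetical measured median kernel time

# Per token and per routed expert: gemm1 maps hidden -> 2*intermediate,
# gemm2 maps intermediate -> hidden; each GEMM costs 2*M*N*K FLOPs.
flops_per_token_per_expert = (
    2 * hidden_size * 2 * intermediate_size
    + 2 * intermediate_size * hidden_size
)
total_flops = num_tokens * top_k * flops_per_token_per_expert
tflops = total_flops / (time_ms * 1e-3) / 1e12
print(f"estimated throughput: {tflops:.2f} TFLOPs/s")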
weight_bytes_per_expert = ( - 2 * intermediate_size * hidden_size * weight_bytes_per_element # gemm1 - + hidden_size * intermediate_size * weight_bytes_per_element # gemm2 + 2 * intermediate_size * hidden_size * weight_bytes_per_element + + hidden_size * intermediate_size * weight_bytes_per_element ) if active_experts is not None: num_active_experts = active_experts else: num_active_experts = min(num_experts, top_k * num_tokens) weight_bytes = num_active_experts * weight_bytes_per_expert - - # Output memory (typically full precision) - output_bytes = num_tokens * hidden_size * input_dtype.itemsize - + output_bytes = num_tokens * hidden_size * input_dtype.element_size() total_bytes = input_bytes + weight_bytes + output_bytes - tb_per_sec = total_bytes / (time_ms * 1e-3) / 1e12 # Convert to TB/sec + tb_per_sec = total_bytes / (time_ms * 0.001) / 1000000000000.0 return tb_per_sec -def _compute_routing(router_logits: torch.Tensor, top_k: int): - routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.float() +def _compute_routing(router_logits: paddle.Tensor, top_k: int): + routing_weights = paddle.nn.functional.softmax( + x=router_logits, axis=1, dtype="float32" + ) + routing_weights, selected_experts = paddle.topk(x=routing_weights, k=top_k, axis=-1) + routing_weights /= routing_weights.sum(axis=-1, keepdim=True) + routing_weights = routing_weights.astype(dtype="float32") return routing_weights, selected_experts -def _dynamic_per_tensor_fp8_quant(x: torch.Tensor): - fp8_max = torch.finfo(torch.float8_e4m3fn).max - x_max = x.abs().max().float().clamp(min=1e-6) +def _dynamic_per_tensor_fp8_quant(x: paddle.Tensor): + fp8_max = paddle.finfo(dtype=paddle.float8_e4m3fn).max + x_max = x.abs()._max().astype(dtype="float32").clip(min=1e-06) scale = x_max / fp8_max inv_scale = 1.0 / scale - out = (x.float() * inv_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn) + out = ( + (x.astype(dtype="float32") * inv_scale) + .clip(min=-fp8_max, max=fp8_max) + .to(paddle.float8_e4m3fn) + ) return out, scale.view((1,)) @@ -515,17 +459,13 @@ def testTrtllmFp4BlockScaleMoe(args): if args.verbose >= 1: print("[INFO] Running testTrtllmFp4BlockScaleMoe") print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - device = get_device(args) if args.generate_repro_command: print( f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" ) - input_dtype = dtype_str_to_torch_dtype(args.input_dtype) weight_dtype = dtype_str_to_torch_dtype(args.weight_dtype) - - # Parse configuration num_tokens = args.num_tokens hidden_size = args.hidden_size intermediate_size = args.intermediate_size @@ -557,107 +497,86 @@ def testTrtllmFp4BlockScaleMoe(args): weight_layout = args.weight_layout is_cuda_graph_compatible = not args.no_cuda_graph gated_act_type = args.gated_act_type - if args.verbose >= 1: print( - f"[INFO] Configuration: tokens={num_tokens}, hidden={hidden_size}, " - f"intermediate={intermediate_size}, experts={num_experts}, top_k={top_k}" - ) - - # Create test data - routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights = ( - create_trtllm_moe_test_data( - num_tokens, - hidden_size, - intermediate_size, - num_experts, - routing_method_type, - args.use_routing_bias, - input_dtype, - weight_dtype, - device, - moe_kernel_type="fp4_block_scale", + f"[INFO] Configuration: 
tokens={num_tokens}, hidden={hidden_size}, intermediate={intermediate_size}, experts={num_experts}, top_k={top_k}" ) + ( + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + ) = create_trtllm_moe_test_data( + num_tokens, + hidden_size, + intermediate_size, + num_experts, + routing_method_type, + args.use_routing_bias, + input_dtype, + weight_dtype, + device, + moe_kernel_type="fp4_block_scale", ) - - # For FP4, we need to properly quantize weights and create scales use_ue8m0 = False - - # Calculate global scale factor for hidden states hidden_states_scale_global = calculate_fp4_global_scale_factor(hidden_states) - - # Quantize weights using proper FP4 quantization - gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = ( - quant_fp4_batches_simple(gemm1_weights, num_experts, use_ue8m0, True) - ) - gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = ( - quant_fp4_batches_simple(gemm2_weights, num_experts, use_ue8m0, True) - ) - - # Quantize hidden states + ( + gemm1_weights_fp4_bytes, + gemm1_scales_fp4_bytes, + gemm1_scales_global, + ) = quant_fp4_batches_simple(gemm1_weights, num_experts, use_ue8m0, True) + ( + gemm2_weights_fp4_bytes, + gemm2_scales_fp4_bytes, + gemm2_scales_global, + ) = quant_fp4_batches_simple(gemm2_weights, num_experts, use_ue8m0, True) hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, _ = quant_fp4_simple( hidden_states, hidden_states_scale_global, use_ue8m0, True ) - - # Reshape hidden states for the kernel (pack 2 FP4 values into 1 byte) - # Keep as uint8 format for FP4 packed data - hidden_states_fp4 = hidden_states_fp4_bytes.view(torch.uint8).reshape( - hidden_states.shape[0], hidden_states.shape[1] // 2 + hidden_states_fp4 = hidden_states_fp4_bytes.view("uint8").reshape( + tuple(hidden_states.shape)[0], tuple(hidden_states.shape)[1] // 2 ) hidden_states_scale_linear_fp4 = hidden_states_scale_fp4_bytes.view( - torch.float8_e4m3fn - ).reshape(-1) - # Ensure expected vector size (16 elements per hidden value for NvFP4) - expected_scale_elems = (num_tokens * hidden_size) // 16 - if hidden_states_scale_linear_fp4.numel() != expected_scale_elems: + paddle.float8_e4m3fn + ) + expected_scale_elems = num_tokens * hidden_size // 16 + if hidden_states_scale_linear_fp4.size != expected_scale_elems: if args.verbose >= 1: print( - f"[INFO] Adjusting FP4 hidden_states_scale from {hidden_states_scale_linear_fp4.numel()} to {expected_scale_elems} elements" + f"[INFO] Adjusting FP4 hidden_states_scale from {hidden_states_scale_linear_fp4.size} to {expected_scale_elems} elements" ) - hidden_states_scale_linear_fp4 = torch.ones( - expected_scale_elems, device=device, dtype=torch.float8_e4m3fn + hidden_states_scale_linear_fp4 = paddle.ones( + shape=expected_scale_elems, dtype=paddle.float8_e4m3fn ) - - # Prepare weights for kernel - # For FP4 weights, keep them as uint8 (packed format) - don't convert to float8_e4m3fn - gemm1_weights_fp4 = gemm1_weights_fp4_bytes.view(torch.uint8).reshape( + hidden_states_scale_linear_fp4 = hidden_states_scale_linear_fp4.reshape( + num_tokens, hidden_size // 16 + ) + gemm1_weights_fp4 = gemm1_weights_fp4_bytes.view("uint8").reshape( num_experts, 2 * intermediate_size, hidden_size // 2 ) - # Scale factors should be viewed as float8_e4m3fn - gemm1_weights_scale = gemm1_scales_fp4_bytes.view(torch.float8_e4m3fn).reshape( + gemm1_weights_scale = gemm1_scales_fp4_bytes.view(paddle.float8_e4m3fn).reshape( num_experts, 2 * intermediate_size, hidden_size // 16 ) - - 
gemm2_weights_fp4 = gemm2_weights_fp4_bytes.view(torch.uint8).reshape( + gemm2_weights_fp4 = gemm2_weights_fp4_bytes.view("uint8").reshape( num_experts, hidden_size, intermediate_size // 2 ) - gemm2_weights_scale = gemm2_scales_fp4_bytes.view(torch.float8_e4m3fn).reshape( + gemm2_weights_scale = gemm2_scales_fp4_bytes.view(paddle.float8_e4m3fn).reshape( num_experts, hidden_size, intermediate_size // 16 ) - - # Optional parameters for FP4 (using None for simplicity in benchmarking) gemm1_bias = None gemm1_alpha = None gemm1_beta = None gemm1_clamp_limit = None gemm2_bias = None - - # Create scale scalars (simplified - in practice these would be computed) - output1_scale_scalar = torch.ones( - local_num_experts, device=device, dtype=torch.float32 - ) - output1_scale_gate_scalar = torch.ones( - local_num_experts, device=device, dtype=torch.float32 - ) - output2_scale_scalar = torch.ones( - local_num_experts, device=device, dtype=torch.float32 - ) - + output1_scale_scalar = paddle.ones(shape=local_num_experts, dtype="float32") + output1_scale_gate_scalar = paddle.ones(shape=local_num_experts, dtype="float32") + output2_scale_scalar = paddle.ones(shape=local_num_experts, dtype="float32") if args.verbose >= 2: - print(f"[VVERBOSE] routing_logits.shape = {routing_logits.shape}") - print(f"[VVERBOSE] hidden_states.shape = {hidden_states.shape}") - print(f"[VVERBOSE] gemm1_weights_fp4.shape = {gemm1_weights_fp4.shape}") - print(f"[VVERBOSE] gemm2_weights_fp4.shape = {gemm2_weights_fp4.shape}") + print(f"[VVERBOSE] routing_logits.shape = {tuple(routing_logits.shape)}") + print(f"[VVERBOSE] hidden_states.shape = {tuple(hidden_states.shape)}") + print(f"[VVERBOSE] gemm1_weights_fp4.shape = {tuple(gemm1_weights_fp4.shape)}") + print(f"[VVERBOSE] gemm2_weights_fp4.shape = {tuple(gemm2_weights_fp4.shape)}") def run_fp4_moe(): return trtllm_fp4_block_scale_moe( @@ -691,7 +610,19 @@ def run_fp4_moe(): do_finalize=True, ) - # Benchmark timing + backend = "trtllm" + if getattr(args, "autotune", False): + warmup_iters = ( + args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 + ) + backend = "trtllm_autotune" + if args.verbose >= 1: + print( + f"[INFO] Autotune warmup for FP4 block scale MoE: {warmup_iters} iters" + ) + with autotune(True): + for _ in range(warmup_iters): + run_fp4_moe() if is_cuda_graph_compatible: times = bench_gpu_time_with_cudagraph( fn=run_fp4_moe, @@ -713,8 +644,6 @@ def run_fp4_moe(): l2_flush_device=device, sleep_after_run=False, ) - - # Compute performance metrics median_time = np.median(times) std_time = np.std(times) tflops = calculate_moe_tflops( @@ -733,10 +662,7 @@ def run_fp4_moe(): weight_format="fp4", routing_logits_dtype=routing_logits.dtype, ) - - backend = "trtllm" print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - res = [] if args.output_path is not None: cur_res = defaultdict(str) @@ -766,7 +692,6 @@ def run_fp4_moe(): cur_res["weight_dtype"] = weight_dtype cur_res["gated_act"] = args.gated_act res.append(cur_res) - return res @@ -782,16 +707,12 @@ def testCutlassFusedMoe(args): if args.verbose >= 1: print("[INFO] Running testCutlassFusedMoe") print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - device = get_device(args) if args.generate_repro_command: print( f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" ) - input_dtype = dtype_str_to_torch_dtype(args.input_dtype) - - # Shapes num_tokens = args.num_tokens hidden_size = args.hidden_size intermediate_size = 
args.intermediate_size @@ -802,75 +723,54 @@ def testCutlassFusedMoe(args): ep_size = getattr(args, "ep_size", 1) ep_rank = getattr(args, "ep_rank", 0) is_cuda_graph_compatible = not args.no_cuda_graph - - # Create base tensors - torch.manual_seed(args.random_seed) - x = torch.randn(num_tokens, hidden_size, dtype=input_dtype, device=device) + paddle.seed(seed=args.random_seed) + x = paddle.randn(shape=[num_tokens, hidden_size], dtype=input_dtype) w31_weight = ( - torch.randn( - num_experts, - 2 * intermediate_size, - hidden_size, - dtype=input_dtype, - device=device, + paddle.randn( + shape=[num_experts, 2 * intermediate_size, hidden_size], dtype=input_dtype ) / 10 ) w2_weight = ( - torch.randn( - num_experts, - hidden_size, - intermediate_size, - dtype=input_dtype, - device=device, + paddle.randn( + shape=[num_experts, hidden_size, intermediate_size], dtype=input_dtype ) / 10 ) - - # Routing - router_logits = torch.randn( - num_tokens, num_experts, dtype=input_dtype, device=device - ) + router_logits = paddle.randn(shape=[num_tokens, num_experts], dtype=input_dtype) routing_weights, selected_experts = _compute_routing(router_logits, top_k) - if args.verbose >= 2: - print(f"[VVERBOSE] x.shape = {x.shape}") - print(f"[VVERBOSE] w31_weight.shape = {w31_weight.shape}") - print(f"[VVERBOSE] w2_weight.shape = {w2_weight.shape}") - - # Build local weights per EP/TP like tests do + print(f"[VVERBOSE] x.shape = {tuple(x.shape)}") + print(f"[VVERBOSE] w31_weight.shape = {tuple(w31_weight.shape)}") + print(f"[VVERBOSE] w2_weight.shape = {tuple(w2_weight.shape)}") experts_per_rank = num_experts // max(ep_size, 1) expert_start = ep_rank * experts_per_rank expert_end = expert_start + experts_per_rank w31_ep = w31_weight[expert_start:expert_end, :] w2_ep = w2_weight[expert_start:expert_end, :] - def build_tp_shards(w31_ep_tensor: torch.Tensor, w2_ep_tensor: torch.Tensor): + def build_tp_shards(w31_ep_tensor: paddle.Tensor, w2_ep_tensor: paddle.Tensor): if tp_size <= 1: return w31_ep_tensor, w2_ep_tensor - # Split w31 into w3 and w1 along intermediate dim - w3_weight, w1_weight = torch.chunk(w31_ep_tensor, 2, dim=1) + w3_weight, w1_weight = paddle.chunk(x=w31_ep_tensor, chunks=2, axis=1) shard = intermediate_size // tp_size start = tp_rank * shard end = start + shard w3_local = w3_weight[:, start:end, :] w1_local = w1_weight[:, start:end, :] - w31_local = torch.cat([w3_local, w1_local], dim=1) + w31_local = paddle.concat(x=[w3_local, w1_local], axis=1) w2_local = w2_ep_tensor[:, :, start:end] return w31_local.contiguous(), w2_local.contiguous() w31_local, w2_local = build_tp_shards(w31_ep, w2_ep) - - # Prepare variant-specific inputs (outside of the timed/captured region) variant = getattr(args, "cutlass_variant", "base") - out = torch.empty_like(x) - + out = paddle.empty_like(x=x) if variant == "base": def run_cutlass(): return cutlass_fused_moe( x, - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, w31_local, w2_local, @@ -884,44 +784,36 @@ def run_cutlass(): ) elif variant == "fp8": - # Per-tensor FP8 for weights and activation scale - w31_weight_fp8 = torch.empty_like(w31_local, dtype=torch.float8_e4m3fn) - w2_weight_fp8 = torch.empty_like(w2_local, dtype=torch.float8_e4m3fn) - local_num_experts = w31_local.shape[0] - w31_scales = torch.empty(local_num_experts, 2, dtype=input_dtype, device=device) - w2_scales = torch.empty(local_num_experts, 1, dtype=input_dtype, device=device) - - # Quantize weights per expert + w31_weight_fp8 = paddle.empty_like(x=w31_local, 
dtype=paddle.float8_e4m3fn) + w2_weight_fp8 = paddle.empty_like(x=w2_local, dtype=paddle.float8_e4m3fn) + local_num_experts = tuple(w31_local.shape)[0] + w31_scales = paddle.empty(shape=[local_num_experts, 2], dtype=input_dtype) + w2_scales = paddle.empty(shape=[local_num_experts, 1], dtype=input_dtype) for expert_id in range(local_num_experts): w31_expert = w31_local[expert_id] w2_expert = w2_local[expert_id] w31_q, s31 = _dynamic_per_tensor_fp8_quant(w31_expert) w2_q, s2 = _dynamic_per_tensor_fp8_quant(w2_expert) - w31_weight_fp8[expert_id].copy_(w31_q) - w2_weight_fp8[expert_id].copy_(w2_q) - # Store the same scalar twice to mimic test layout (avoid torch.tensor()) + paddle.assign(w31_q, output=w31_weight_fp8[expert_id]) + paddle.assign(w2_q, output=w2_weight_fp8[expert_id]) w31_scales[expert_id, 0] = s31.to(dtype=input_dtype, device=device) w31_scales[expert_id, 1] = s31.to(dtype=input_dtype, device=device) w2_scales[expert_id, 0] = s2.to(dtype=input_dtype, device=device) - x_quant, hidden_states_scale = _dynamic_per_tensor_fp8_quant(x) hidden_states_scale_scalar = hidden_states_scale[0].to(device) - - # Note: follow tests quant_scales format - # [w1_scales * hidden_states_scale, 1.0, 1.0 * w2_scales, hidden_states_scale] w1_scales = w31_scales[:, 1] - one_const = torch.ones((), device=device) + one_const = paddle.ones(shape=()) quant_scales = [ - (w1_scales * hidden_states_scale_scalar).float().squeeze(), + (w1_scales * hidden_states_scale_scalar).astype(dtype="float32").squeeze(), one_const, - w2_scales.squeeze().float(), + w2_scales.squeeze().astype(dtype="float32"), hidden_states_scale_scalar, ] def run_cutlass(): return cutlass_fused_moe( x_quant, - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, w31_weight_fp8, w2_weight_fp8, @@ -935,70 +827,59 @@ def run_cutlass(): ) elif variant == "nvfp4": - # NVFP4: FP4 block-scale weights, optional quantized input FLOAT4_E2M1_MAX = 6.0 - FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + FLOAT8_E4M3_MAX = paddle.finfo(dtype=paddle.float8_e4m3fn).max def round_up(x_val, y): return (x_val + y - 1) // y * y - e = w31_local.shape[0] - n = w2_local.shape[2] # local intermediate size after TP + e = tuple(w31_local.shape)[0] + n = tuple(w2_local.shape)[2] k = hidden_size quant_blocksize = 16 - - # Weight quantization buffers - w1_q = torch.empty((e, 2 * n, k // 2), device=device, dtype=torch.uint8) - w2_q = torch.empty((e, k, n // 2), device=device, dtype=torch.uint8) - w1_blockscale = torch.empty( - (e, round_up(2 * n, 128), round_up(k // quant_blocksize, 4)), - device=device, - dtype=torch.float8_e4m3fn, + w1_q = paddle.empty(shape=(e, 2 * n, k // 2), dtype="uint8") + w2_q = paddle.empty(shape=(e, k, n // 2), dtype="uint8") + w1_blockscale = paddle.empty( + shape=(e, round_up(2 * n, 128), round_up(k // quant_blocksize, 4)), + dtype=paddle.float8_e4m3fn, ) - w2_blockscale = torch.empty( - (e, round_up(k, 128), round_up(n // quant_blocksize, 4)), - device=device, - dtype=torch.float8_e4m3fn, + w2_blockscale = paddle.empty( + shape=(e, round_up(k, 128), round_up(n // quant_blocksize, 4)), + dtype=paddle.float8_e4m3fn, ) - w1_gs = torch.empty((e,), device=device, dtype=torch.float32) - w2_gs = torch.empty((e,), device=device, dtype=torch.float32) - - # Quantize from local shards + w1_gs = paddle.empty(shape=(e,), dtype="float32") + w2_gs = paddle.empty(shape=(e,), dtype="float32") for expert in range(e): w1_src = w31_local[expert] - # w31 layout is [2n, k]; w2 layout is [k, n] - w2_src = 
w2_local[expert].contiguous() # [hidden_size, n] - w1_amax = torch.abs(w1_src).max().to(torch.float32) - w2_amax = torch.abs(w2_src).max().to(torch.float32) + w2_src = w2_local[expert].contiguous() + w1_amax = paddle.abs(x=w1_src)._max().to("float32") + w2_amax = paddle.abs(x=w2_src)._max().to("float32") w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1_src, w1_gs[expert]) w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2_src, w2_gs[expert]) - - a1_gs = torch.ones((), device=device, dtype=torch.float32) - a2_gs = torch.ones((), device=device, dtype=torch.float32) - + a1_gs = paddle.ones(shape=(), dtype="float32") + a2_gs = paddle.ones(shape=(), dtype="float32") hidden_states = x input_sf = None if getattr(args, "quantized_input", False): hidden_states, input_sf = fp4_quantize(x, a1_gs) - quant_scales = [ a1_gs, - w1_blockscale.view(torch.int32), + w1_blockscale.view("int32"), 1.0 / (a1_gs * w1_gs), a2_gs, - w2_blockscale.view(torch.int32), + w2_blockscale.view("int32"), 1.0 / (a2_gs * w2_gs), ] def run_cutlass(): return cutlass_fused_moe( hidden_states, - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, - w1_q.contiguous().view(torch.long), - w2_q.contiguous().view(torch.long), + w1_q.contiguous().view("int64"), + w2_q.contiguous().view("int64"), input_dtype, tp_size=tp_size, tp_rank=tp_rank, @@ -1008,10 +889,20 @@ def run_cutlass(): input_sf=input_sf, output=out, ) + else: raise ValueError(f"Unknown cutlass_variant: {variant}") - - # Measure + backend = "cutlass" + if getattr(args, "autotune", False): + warmup_iters = ( + args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 + ) + backend = "cutlass_autotune" + if args.verbose >= 1: + print(f"[INFO] Autotune warmup for CUTLASS fused MoE: {warmup_iters} iters") + with autotune(True): + for _ in range(warmup_iters): + run_cutlass() if is_cuda_graph_compatible: times = bench_gpu_time_with_cudagraph( fn=run_cutlass, @@ -1033,7 +924,6 @@ def run_cutlass(): l2_flush_device=device, sleep_after_run=False, ) - median_time = np.median(times) std_time = np.std(times) tflops = calculate_moe_tflops( @@ -1048,25 +938,20 @@ def run_cutlass(): median_time, input_dtype, input_dtype, - input_format=( - "fp8" - if variant == "fp8" - else ( - "fp4" - if (variant == "nvfp4" and getattr(args, "quantized_input", False)) - else None - ) - ), - weight_format=( - "fp8" if variant == "fp8" else ("fp4" if variant == "nvfp4" else None) - ), + input_format="fp8" + if variant == "fp8" + else "fp4" + if variant == "nvfp4" and getattr(args, "quantized_input", False) + else None, + weight_format="fp8" + if variant == "fp8" + else "fp4" + if variant == "nvfp4" + else None, routing_logits_dtype=router_logits.dtype, - active_experts=int(selected_experts.unique().numel()), + active_experts=int(selected_experts.unique().size), ) - - backend = "cutlass" print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - res = [] if args.output_path is not None: cur_res = defaultdict(str) @@ -1081,13 +966,11 @@ def run_cutlass(): cur_res["intermediate_size"] = intermediate_size cur_res["num_experts"] = num_experts cur_res["top_k"] = top_k - # Routing method/weight layout not applicable; leave defaults cur_res["use_shuffled_weight"] = False cur_res["weight_layout"] = 0 cur_res["use_routing_scales_on_input"] = False cur_res["input_dtype"] = input_dtype cur_res["weight_dtype"] = input_dtype - # 
CUTLASS fused MoE specific cur_res["cutlass_variant"] = variant cur_res["quantized_input"] = args.quantized_input cur_res["tp_size"] = tp_size @@ -1095,7 +978,6 @@ def run_cutlass(): cur_res["ep_size"] = ep_size cur_res["ep_rank"] = ep_rank res.append(cur_res) - return res @@ -1117,17 +999,13 @@ def testTrtllmFp8BlockScaleMoe(args): if args.verbose >= 1: print("[INFO] Running testTrtllmFp8BlockScaleMoe") print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - device = get_device(args) if args.generate_repro_command: print( f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" ) - input_dtype = dtype_str_to_torch_dtype(args.input_dtype) weight_dtype = dtype_str_to_torch_dtype(args.weight_dtype) - - # Parse configuration num_tokens = args.num_tokens hidden_size = args.hidden_size intermediate_size = args.intermediate_size @@ -1158,47 +1036,40 @@ def testTrtllmFp8BlockScaleMoe(args): use_shuffled_weight = args.use_shuffled_weight weight_layout = args.weight_layout is_cuda_graph_compatible = not args.no_cuda_graph - if args.verbose >= 1: print( - f"[INFO] Configuration: tokens={num_tokens}, hidden={hidden_size}, " - f"intermediate={intermediate_size}, experts={num_experts}, top_k={top_k}" - ) - - # Create test data - routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights = ( - create_trtllm_moe_test_data( - num_tokens, - hidden_size, - intermediate_size, - num_experts, - routing_method_type, - args.use_routing_bias, - input_dtype, - weight_dtype, - device, - moe_kernel_type="fp8_block_scale", + f"[INFO] Configuration: tokens={num_tokens}, hidden={hidden_size}, intermediate={intermediate_size}, experts={num_experts}, top_k={top_k}" ) + ( + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + ) = create_trtllm_moe_test_data( + num_tokens, + hidden_size, + intermediate_size, + num_experts, + routing_method_type, + args.use_routing_bias, + input_dtype, + weight_dtype, + device, + moe_kernel_type="fp8_block_scale", ) - - # For FP8 block scale, create quantized weights and block scales - # Quantize to FP8 - gemm1_weights_fp8 = gemm1_weights.to(torch.float8_e4m3fn) - gemm2_weights_fp8 = gemm2_weights.to(torch.float8_e4m3fn) - - # Optionally shuffle weights and convert to BlockMajorK layout to match kernel expectation + gemm1_weights_fp8 = gemm1_weights.to(paddle.float8_e4m3fn) + gemm2_weights_fp8 = gemm2_weights.to(paddle.float8_e4m3fn) if use_shuffled_weight: - # This tile size follows test implementations epilogue_tile_m = 64 - gemm1_weights_fp8_shuffled = [] gemm2_weights_fp8_shuffled = [] for i in range(num_experts): tmp_w1 = shuffle_matrix_a( - gemm1_weights_fp8[i].view(torch.uint8), epilogue_tile_m + gemm1_weights_fp8[i].view("uint8"), epilogue_tile_m ) tmp_w2 = shuffle_matrix_a( - gemm2_weights_fp8[i].view(torch.uint8), epilogue_tile_m + gemm2_weights_fp8[i].view("uint8"), epilogue_tile_m ) if weight_layout == WeightLayout.BlockMajorK: block_k = 128 @@ -1206,40 +1077,31 @@ def testTrtllmFp8BlockScaleMoe(args): tmp_w2 = convert_to_block_layout(tmp_w2, block_k) gemm1_weights_fp8_shuffled.append(tmp_w1) gemm2_weights_fp8_shuffled.append(tmp_w2) - - kernel_gemm1_weights = torch.stack(gemm1_weights_fp8_shuffled).view( - torch.float8_e4m3fn + kernel_gemm1_weights = paddle.stack(x=gemm1_weights_fp8_shuffled).view( + paddle.float8_e4m3fn ) - kernel_gemm2_weights = torch.stack(gemm2_weights_fp8_shuffled).view( - torch.float8_e4m3fn + kernel_gemm2_weights = paddle.stack(x=gemm2_weights_fp8_shuffled).view( + 
paddle.float8_e4m3fn ) else: kernel_gemm1_weights = gemm1_weights_fp8 kernel_gemm2_weights = gemm2_weights_fp8 - - # Create block scale tensors for hidden states and weights (use float32 for scales) - # TensorRT-LLM FP8 block-scale expects hidden_states_scale shape [hidden_size // 128, num_tokens] - hidden_states_scale = 2.0 * torch.ones( - (hidden_size // 128, num_tokens), device=device, dtype=torch.float32 + hidden_states_scale = 2.0 * paddle.ones( + shape=(hidden_size // 128, num_tokens), dtype="float32" ) - gemm1_weights_scale = 2.0 * torch.ones( - (num_experts, 2 * intermediate_size // 128, hidden_size // 128), - device=device, - dtype=torch.float32, + gemm1_weights_scale = 2.0 * paddle.ones( + shape=(num_experts, 2 * intermediate_size // 128, hidden_size // 128), + dtype="float32", ) - gemm2_weights_scale = 2.0 * torch.ones( - (num_experts, hidden_size // 128, intermediate_size // 128), - device=device, - dtype=torch.float32, + gemm2_weights_scale = 2.0 * paddle.ones( + shape=(num_experts, hidden_size // 128, intermediate_size // 128), + dtype="float32", ) - if args.verbose >= 2: - print(f"[VVERBOSE] routing_logits.shape = {routing_logits.shape}") - print(f"[VVERBOSE] hidden_states.shape = {hidden_states.shape}") - print(f"[VVERBOSE] gemm1_weights_fp8.shape = {gemm1_weights_fp8.shape}") - print(f"[VVERBOSE] gemm2_weights_fp8.shape = {gemm2_weights_fp8.shape}") - - # Match test heuristic for tile_tokens_dim when using BlockMajorK + print(f"[VVERBOSE] routing_logits.shape = {tuple(routing_logits.shape)}") + print(f"[VVERBOSE] hidden_states.shape = {tuple(hidden_states.shape)}") + print(f"[VVERBOSE] gemm1_weights_fp8.shape = {tuple(gemm1_weights_fp8.shape)}") + print(f"[VVERBOSE] gemm2_weights_fp8.shape = {tuple(gemm2_weights_fp8.shape)}") if use_shuffled_weight and weight_layout == WeightLayout.BlockMajorK: def _next_pow2(x: int) -> int: @@ -1252,7 +1114,7 @@ def _next_pow2(x: int) -> int: x |= x >> 16 return x + 1 - tokens_per_expert = max(1, (num_tokens * top_k) // max(local_num_experts, 1)) + tokens_per_expert = max(1, num_tokens * top_k // max(local_num_experts, 1)) suggested_tile = min(max(_next_pow2(tokens_per_expert), 8), 64) if suggested_tile != tile_tokens_dim and args.verbose >= 1: print( @@ -1261,10 +1123,7 @@ def _next_pow2(x: int) -> int: tile_tokens_dim = suggested_tile def run_fp8_block_moe(): - # Quantize hidden states to FP8 for block scale MOE - hidden_states_fp8 = hidden_states.to(torch.float8_e4m3fn) - # Note: FP8 block scale MOE expects int64_t for n_group/topk_group, not Optional[int64_t] - # So we convert None to 0 to indicate "no groups" mode + hidden_states_fp8 = hidden_states.to(paddle.float8_e4m3fn) return trtllm_fp8_block_scale_moe( routing_logits=routing_logits, routing_bias=routing_bias, @@ -1289,7 +1148,6 @@ def run_fp8_block_moe(): enable_pdl=True, ) - # Benchmark timing if is_cuda_graph_compatible: times = bench_gpu_time_with_cudagraph( fn=run_fp8_block_moe, @@ -1311,8 +1169,6 @@ def run_fp8_block_moe(): l2_flush_device=device, sleep_after_run=False, ) - - # Compute performance metrics median_time = np.median(times) std_time = np.std(times) tflops = calculate_moe_tflops( @@ -1331,10 +1187,8 @@ def run_fp8_block_moe(): weight_format="fp8", routing_logits_dtype=routing_logits.dtype, ) - backend = "trtllm" print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - res = [] if args.output_path is not None: cur_res = defaultdict(str) @@ -1363,7 +1217,6 @@ def run_fp8_block_moe(): cur_res["input_dtype"] = input_dtype cur_res["weight_dtype"] = 
weight_dtype res.append(cur_res) - return res @@ -1385,17 +1238,13 @@ def testTrtllmFp8PerTensorScaleMoe(args): if args.verbose >= 1: print("[INFO] Running testTrtllmFp8PerTensorScaleMoe") print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - device = get_device(args) if args.generate_repro_command: print( f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" ) - input_dtype = dtype_str_to_torch_dtype(args.input_dtype) weight_dtype = dtype_str_to_torch_dtype(args.weight_dtype) - - # Parse configuration num_tokens = args.num_tokens hidden_size = args.hidden_size intermediate_size = args.intermediate_size @@ -1425,57 +1274,41 @@ def testTrtllmFp8PerTensorScaleMoe(args): routing_method_type = args.routing_method_type use_routing_scales_on_input = args.use_routing_scales_on_input is_cuda_graph_compatible = not args.no_cuda_graph - if args.verbose >= 1: print( - f"[INFO] Configuration: tokens={num_tokens}, hidden={hidden_size}, " - f"intermediate={intermediate_size}, experts={num_experts}, top_k={top_k}" - ) - - # Create test data - routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights = ( - create_trtllm_moe_test_data( - num_tokens, - hidden_size, - intermediate_size, - num_experts, - routing_method_type, - args.use_routing_bias, - input_dtype, - weight_dtype, - device, - moe_kernel_type="fp8_per_tensor", + f"[INFO] Configuration: tokens={num_tokens}, hidden={hidden_size}, intermediate={intermediate_size}, experts={num_experts}, top_k={top_k}" ) - ) - - # For FP8 per-tensor scale, create quantized weights and per-tensor scales - # Quantize to FP8 - gemm1_weights_fp8 = gemm1_weights.to(torch.float8_e4m3fn) - gemm2_weights_fp8 = gemm2_weights.to(torch.float8_e4m3fn) - - # Quantize hidden states to FP8 for per-tensor scale - hidden_states_fp8 = hidden_states.to(torch.float8_e4m3fn) - - # Create per-tensor scale scalars - output1_scales_scalar = torch.ones( - local_num_experts, device=device, dtype=torch.float32 - ) - output1_scales_gate_scalar = torch.ones( - local_num_experts, device=device, dtype=torch.float32 - ) - output2_scales_scalar = torch.ones( - local_num_experts, device=device, dtype=torch.float32 - ) - + ( + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + ) = create_trtllm_moe_test_data( + num_tokens, + hidden_size, + intermediate_size, + num_experts, + routing_method_type, + args.use_routing_bias, + input_dtype, + weight_dtype, + device, + moe_kernel_type="fp8_per_tensor", + ) + gemm1_weights_fp8 = gemm1_weights.to(paddle.float8_e4m3fn) + gemm2_weights_fp8 = gemm2_weights.to(paddle.float8_e4m3fn) + hidden_states_fp8 = hidden_states.to(paddle.float8_e4m3fn) + output1_scales_scalar = paddle.ones(shape=local_num_experts, dtype="float32") + output1_scales_gate_scalar = paddle.ones(shape=local_num_experts, dtype="float32") + output2_scales_scalar = paddle.ones(shape=local_num_experts, dtype="float32") if args.verbose >= 2: - print(f"[VVERBOSE] routing_logits.shape = {routing_logits.shape}") - print(f"[VVERBOSE] hidden_states.shape = {hidden_states.shape}") - print(f"[VVERBOSE] gemm1_weights_fp8.shape = {gemm1_weights_fp8.shape}") - print(f"[VVERBOSE] gemm2_weights_fp8.shape = {gemm2_weights_fp8.shape}") + print(f"[VVERBOSE] routing_logits.shape = {tuple(routing_logits.shape)}") + print(f"[VVERBOSE] hidden_states.shape = {tuple(hidden_states.shape)}") + print(f"[VVERBOSE] gemm1_weights_fp8.shape = {tuple(gemm1_weights_fp8.shape)}") + print(f"[VVERBOSE] gemm2_weights_fp8.shape = 
{tuple(gemm2_weights_fp8.shape)}") def run_fp8_per_tensor_moe(): - # Note: FP8 per-tensor MOE expects int64_t for n_group/topk_group, not Optional[int64_t] - # So we convert None to 0 to indicate "no groups" mode return trtllm_fp8_per_tensor_scale_moe( routing_logits=routing_logits, routing_bias=routing_bias, @@ -1498,7 +1331,6 @@ def run_fp8_per_tensor_moe(): routing_method_type=routing_method_type, ) - # Benchmark timing if is_cuda_graph_compatible: times = bench_gpu_time_with_cudagraph( fn=run_fp8_per_tensor_moe, @@ -1520,8 +1352,6 @@ def run_fp8_per_tensor_moe(): l2_flush_device=device, sleep_after_run=False, ) - - # Compute performance metrics median_time = np.median(times) std_time = np.std(times) tflops = calculate_moe_tflops( @@ -1540,10 +1370,8 @@ def run_fp8_per_tensor_moe(): weight_format="fp8", routing_logits_dtype=routing_logits.dtype, ) - backend = "trtllm" print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - res = [] if args.output_path is not None: cur_res = defaultdict(str) @@ -1570,5 +1398,4 @@ def run_fp8_per_tensor_moe(): cur_res["input_dtype"] = input_dtype cur_res["weight_dtype"] = weight_dtype res.append(cur_res) - return res diff --git a/benchmarks/samples/sample_testlist_output.csv b/benchmarks/samples/sample_testlist_output.csv index cb2e8160ba..11b3a43c5a 100644 --- a/benchmarks/samples/sample_testlist_output.csv +++ b/benchmarks/samples/sample_testlist_output.csv @@ -13,13 +13,13 @@ BatchDecodeWithPagedKVCacheWrapper,0.010036800056695938,4.2086111493617856e-05,2 BatchDecodeWithPagedKVCacheWrapper,0.010139200091361999,3.394793422831141e-05,25.941764007999502,3.2944295111068262,trtllm-gen-native,16,16,1,1024,64,8,128,128,,,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,42,Llama-3.1-70B-1k,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --generate_repro_command --case_tag Llama-3.1-70B-1k BatchMLAPagedAttentionWrapper,0.04095999896526337,0.00018861917453252353,54.58359909057617,0.5696000143893066,fa2,32,16,1,1024,128,,,,512,64,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,True,False,42,DeepSeek-R1-1k,True,python3 flashinfer_benchmark.py --routine BatchMLAPagedAttentionWrapper --backends fa2 trtllm-gen-native --page_size 32 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len -vv --refcheck --no_cuda_graph --q_dtype bfloat16 --kv_dtype bfloat16 --generate_repro_command --case_tag DeepSeek-R1-1k BatchMLAPagedAttentionWrapper,0.03270399942994118,0.0009912135423119442,68.36302185058594,0.713393358814707,trtllm-gen-native,32,16,1,1024,128,,,,512,64,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,True,False,42,DeepSeek-R1-1k,True,python3 flashinfer_benchmark.py --routine BatchMLAPagedAttentionWrapper --backends fa2 trtllm-gen-native --page_size 32 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len -vv --refcheck --no_cuda_graph --q_dtype bfloat16 --kv_dtype bfloat16 --generate_repro_command --case_tag DeepSeek-R1-1k 
-BatchPrefillWithPagedKVCacheWrapper,0.217056006193161,0.0020098751951206587,199.91474588076713,0.4451410753131434,trtllm-gen-native,16,16,1024,1024,64,8,128,128,,,True,torch.float8_e4m3fn,torch.float8_e4m3fn,327,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,True,False,42,Llama-3.1-70B-1k,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --no_cuda_graph --q_dtype fp8_e4m3 --kv_dtype fp8_e4m3 --generate_repro_command --case_tag Llama-3.1-70B-1k -BatchMLAPagedAttentionWrapper,0.022272000089287758,0.0009335132358293922,100.38362884521484,0.5237701128427505,trtllm-gen-native,32,16,1,1024,128,,,,512,64,False,torch.float8_e4m3fn,torch.float8_e4m3fn,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,True,False,42,DeepSeek-R1-1k,True,python3 flashinfer_benchmark.py --routine BatchMLAPagedAttentionWrapper --backends fa2 trtllm-gen-native --page_size 32 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len -vv --refcheck --no_cuda_graph --q_dtype fp8_e4m3 --kv_dtype fp8_e4m3 --generate_repro_command --case_tag DeepSeek-R1-1k +BatchPrefillWithPagedKVCacheWrapper,0.217056006193161,0.0020098751951206587,199.91474588076713,0.4451410753131434,trtllm-gen-native,16,16,1024,1024,64,8,128,128,,,True,paddle.float8_e4m3fn,paddle.float8_e4m3fn,327,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,True,False,42,Llama-3.1-70B-1k,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --no_cuda_graph --q_dtype fp8_e4m3 --kv_dtype fp8_e4m3 --generate_repro_command --case_tag Llama-3.1-70B-1k +BatchMLAPagedAttentionWrapper,0.022272000089287758,0.0009335132358293922,100.38362884521484,0.5237701128427505,trtllm-gen-native,32,16,1,1024,128,,,,512,64,False,paddle.float8_e4m3fn,paddle.float8_e4m3fn,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,True,False,42,DeepSeek-R1-1k,True,python3 flashinfer_benchmark.py --routine BatchMLAPagedAttentionWrapper --backends fa2 trtllm-gen-native --page_size 32 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len -vv --refcheck --no_cuda_graph --q_dtype fp8_e4m3 --kv_dtype fp8_e4m3 --generate_repro_command --case_tag DeepSeek-R1-1k gemm_fp8_nt_groupwise,0.5826559960842133,0.03992277071294015,1887.0682446681349,0.46071002067093136,cutlass,,,,,,,,,,,,,,,,8192,4096,16384,,128,MN,torch.bfloat16,2,,,,,,,,,,,,,,,,,,,,,,,,,,True,True,False,42,gemm_fp8_nt_groupwise_sample,True,python3 flashinfer_benchmark.py --routine gemm_fp8_nt_groupwise --m 8192 --n 4096 --k 16384 --mma_sm 2 --no_cuda_graph --refcheck -vv --backend cutlass trtllm --scale_major_mode MN --generate_repro_command --case_tag gemm_fp8_nt_groupwise_sample group_gemm_fp8_nt_groupwise,1.4235039949417114,0.019968493175593464,1544.795984673049,0.3771474571955686,cutlass,,,,,,,,,,,,,,,,8192,4096,16384,2,128,K,torch.bfloat16,2,,,,,,,,,,,,,,,,,,,,,,,,,,True,True,False,42,group_gemm_fp8_nt_groupwise_sample,True,python3 flashinfer_benchmark.py --routine group_gemm_fp8_nt_groupwise --m 8192 --n 
4096 --k 16384 --mma_sm 2 --group_size 2 --no_cuda_graph --scale_major_mode K --refcheck -vv --generate_repro_command --case_tag group_gemm_fp8_nt_groupwise_sample -bmm_fp8,0.38998879194259645,0.020104750197258048,2819.341607996366,0.6883158222647378,cudnn,,1,,,,,,,,,,,,,,8192,4096,16384,,,,torch.bfloat16,,,,,,,,,,,,,,,,,,,torch.float8_e4m3fn,,,,,,,,True,False,False,42,bmm_fp8_sample,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --m 8192 --n 4096 --k 16384 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command --case_tag bmm_fp8_sample -bmm_fp8,0.3898000001907349,0.027221875225817732,2820.7070991226083,0.6886491941217305,cublas,,1,,,,,,,,,,,,,,8192,4096,16384,,,,torch.bfloat16,,,,,,,,,,,,,,,,,,,torch.float8_e4m3fn,,,,,,,,True,False,False,42,bmm_fp8_sample,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --m 8192 --n 4096 --k 16384 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command --case_tag bmm_fp8_sample -bmm_fp8,0.7879528045654296,0.05200885894277077,1395.4028990129693,0.3406745358918382,cutlass,,1,,,,,,,,,,,,,,8192,4096,16384,,,,torch.bfloat16,,,,,,,,,,,,,,,,,,,torch.float8_e4m3fn,,,,,,,,True,False,False,42,bmm_fp8_sample,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --m 8192 --n 4096 --k 16384 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command --case_tag bmm_fp8_sample +bmm_fp8,0.38998879194259645,0.020104750197258048,2819.341607996366,0.6883158222647378,cudnn,,1,,,,,,,,,,,,,,8192,4096,16384,,,,torch.bfloat16,,,,,,,,,,,,,,,,,,,paddle.float8_e4m3fn,,,,,,,,True,False,False,42,bmm_fp8_sample,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --m 8192 --n 4096 --k 16384 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command --case_tag bmm_fp8_sample +bmm_fp8,0.3898000001907349,0.027221875225817732,2820.7070991226083,0.6886491941217305,cublas,,1,,,,,,,,,,,,,,8192,4096,16384,,,,torch.bfloat16,,,,,,,,,,,,,,,,,,,paddle.float8_e4m3fn,,,,,,,,True,False,False,42,bmm_fp8_sample,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --m 8192 --n 4096 --k 16384 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command --case_tag bmm_fp8_sample +bmm_fp8,0.7879528045654296,0.05200885894277077,1395.4028990129693,0.3406745358918382,cutlass,,1,,,,,,,,,,,,,,8192,4096,16384,,,,torch.bfloat16,,,,,,,,,,,,,,,,,,,paddle.float8_e4m3fn,,,,,,,,True,False,False,42,bmm_fp8_sample,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --m 8192 --n 4096 --k 16384 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command --case_tag bmm_fp8_sample mm_fp4,0.23255040645599365,0.030370730682184946,4728.057217926491,0.7214442776377091,cudnn,,,,,,,,,,,,,,,,8192,4096,16384,,,,torch.bfloat16,,True,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,42,mm_fp4_sample,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 4096 --k 16384 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --refcheck -vv --generate_repro_command --case_tag mm_fp4_sample 
mm_fp4,0.20704638957977295,0.007477970763357942,5310.4602790108975,0.8103119322221218,cutlass,,,,,,,,,,,,,,,,8192,4096,16384,,,,torch.bfloat16,,True,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,42,mm_fp4_sample,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 4096 --k 16384 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --refcheck -vv --generate_repro_command --case_tag mm_fp4_sample mm_fp4,0.44398319721221924,0.0008919232676501257,2476.47125990321,0.37787952574206696,trtllm,,,,,,,,,,,,,,,,8192,4096,16384,,,,torch.bfloat16,,True,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,42,mm_fp4_sample,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 4096 --k 16384 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --refcheck -vv --generate_repro_command --case_tag mm_fp4_sample diff --git a/benchmarks/samples/sample_testlist_output.txt b/benchmarks/samples/sample_testlist_output.txt index e7a699b85a..08ed3c691d 100644 --- a/benchmarks/samples/sample_testlist_output.txt +++ b/benchmarks/samples/sample_testlist_output.txt @@ -169,9 +169,9 @@ $ python3 flashinfer_benchmark.py --testlist samples/sample_testlist.txt --outpu [VVERBOSE] gpu_name = 'NVIDIA_B200' [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine bmm_fp8 --m 8192 --n 4096 --k 16384 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command --case_tag bmm_fp8_sample [VVERBOSE] input_fp8.shape = torch.Size([1, 8192, 16384]) -[VVERBOSE] input_fp8.dtype = torch.float8_e4m3fn +[VVERBOSE] input_fp8.dtype = paddle.float8_e4m3fn [VVERBOSE] mat2_fp8.shape = torch.Size([1, 16384, 4096]) -[VVERBOSE] mat2_fp8.dtype = torch.float8_e4m3fn +[VVERBOSE] mat2_fp8.dtype = paddle.float8_e4m3fn [VVERBOSE] input_inv_s = tensor(0.0127, device='cuda:0') [VVERBOSE] input_inv_s.dtype = torch.float32 [VVERBOSE] mat2_inv_s = tensor(0.0127, device='cuda:0') diff --git a/ci/scripts/jenkins/git_skip_ci.py b/ci/scripts/jenkins/git_skip_ci.py index 199b8f9ff9..977b1a7dbf 100644 --- a/ci/scripts/jenkins/git_skip_ci.py +++ b/ci/scripts/jenkins/git_skip_ci.py @@ -1,21 +1,3 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- import argparse import logging import os @@ -33,15 +15,12 @@ ) args = parser.parse_args() init_log() - branch = git(["rev-parse", "--abbrev-ref", "HEAD"]) log = git(["log", "--format=%s", "-1"]) - # Check the PR's title (don't check this until everything else passes first) def check_pr_title(): remote = git(["config", "--get", f"remote.{args.remote}.url"]) user, repo = parse_remote(remote) - if args.pr_title: title = args.pr_title else: diff --git a/ci/scripts/jenkins/git_skip_ci_globs.py b/ci/scripts/jenkins/git_skip_ci_globs.py index 2a71ab1294..66a8cd3979 100644 --- a/ci/scripts/jenkins/git_skip_ci_globs.py +++ b/ci/scripts/jenkins/git_skip_ci_globs.py @@ -1,21 +1,3 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - import argparse import fnmatch from typing import Optional @@ -55,15 +37,11 @@ def match_any(f: str) -> Optional[str]: diff = diff.split("\n") diff = [d.strip() for d in diff] diff = [d for d in diff if d != ""] - print(f"Changed files:\n{diff}") - if len(diff) == 0: print("Found no changed files, skipping CI") exit(0) - print(f"Checking with globs:\n{globs}") - for file in diff: match = match_any(file) if match is None: @@ -71,6 +49,5 @@ def match_any(f: str) -> Optional[str]: exit(1) else: print(f"{file} matched glob {match}") - print("All files matched a glob, skipping CI") exit(0) diff --git a/ci/scripts/jenkins/git_utils.py b/ci/scripts/jenkins/git_utils.py index 221da61b5d..99b9a0b467 100644 --- a/ci/scripts/jenkins/git_utils.py +++ b/ci/scripts/jenkins/git_utils.py @@ -1,21 +1,3 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- import base64 import json import logging @@ -30,7 +12,7 @@ def compress_query(query: str) -> str: query = query.replace("\n", "") - query = re.sub(r"\s+", " ", query) + query = re.sub("\\s+", " ", query) return query @@ -41,14 +23,11 @@ def post(url: str, body: Optional[Any] = None, auth: Optional[Tuple[str, str]] = if auth is not None: auth_str = base64.b64encode(f"{auth[0]}:{auth[1]}".encode()) req.add_header("Authorization", f"Basic {auth_str.decode()}") - if body is None: body = "" - req.add_header("Content-Type", "application/json; charset=utf-8") data = json.dumps(body).encode("utf-8") req.add_header("Content-Length", str(len(data))) - with request.urlopen(req, data) as response: return response.read() @@ -71,9 +50,7 @@ def __init__(self, user, repo, token, test_data=None): self.base = f"https://api.github.com/repos/{user}/{repo}/" def headers(self): - return { - "Authorization": f"Bearer {self.token}", - } + return {"Authorization": f"Bearer {self.token}"} def dry_run(self) -> bool: return self.token == DRY_RUN @@ -84,17 +61,20 @@ def graphql( query = compress_query(query) if variables is None: variables = {} - response = self._request( - self.GRAPHQL_URL, - {"query": query, "variables": variables}, - method="POST", + self.GRAPHQL_URL, {"query": query, "variables": variables}, method="POST" ) if self.dry_run(): return self.testing_response("POST", self.GRAPHQL_URL) - if "data" not in response: - msg = f"Error fetching data with query:\n{query}\n\nvariables:\n{variables}\n\nerror:\n{json.dumps(response, indent=2)}" + msg = f"""Error fetching data with query: +{query} + +variables: +{variables} + +error: +{json.dumps(response, indent=2)}""" raise RuntimeError(msg) return response @@ -114,13 +94,11 @@ def _request( f"Dry run, would have requested a {method} to {full_url} with {body}" ) return self.testing_response(method, full_url) - logging.info(f"Requesting {method} to {full_url} with {body}") req = request.Request(full_url, headers=self.headers(), method=method.upper()) req.add_header("Content-Type", "application/json; charset=utf-8") data = json.dumps(body).encode("utf-8") req.add_header("Content-Length", str(len(data))) - try: with request.urlopen(req, data) as response: content = response.read() @@ -128,13 +106,11 @@ def _request( msg = str(e) error_data = e.read().decode() raise RuntimeError(f"Error response: {msg}\n{error_data}") from e - logging.info(f"Got response from {full_url}: {content}") try: response = json.loads(content) except json.decoder.JSONDecodeError: return content - return response def put(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: @@ -174,18 +150,15 @@ def parse_remote(remote: str) -> Tuple[str, str]: Get a GitHub (user, repo) pair out of a git remote """ if remote.startswith("https://"): - # Parse HTTP remote parts = remote.split("/") if len(parts) < 2: raise RuntimeError(f"Unable to parse remote '{remote}'") user, repo = parts[-2], parts[-1].replace(".git", "") else: - # Parse SSH remote - m = re.search(r":(.*)/(.*)\.git", remote) + m = re.search(":(.*)/(.*)\\.git", remote) if m is None or len(m.groups()) != 2: raise RuntimeError(f"Unable to parse remote '{remote}'") user, repo = m.groups() - user = os.getenv("DEBUG_USER", user) repo = os.getenv("DEBUG_REPO", repo) return user, repo @@ -201,14 +174,12 @@ def git(command, **kwargs): def find_ccs(body: str) -> List[str]: - matches = re.findall(r"(cc( @[-A-Za-z0-9]+)+)", body, flags=re.MULTILINE) + matches = re.findall("(cc( @[-A-Za-z0-9]+)+)", body, flags=re.MULTILINE) matches = [full for full, 
last in matches] - reviewers = set() for match in matches: if match.startswith("cc "): match = match.replace("cc ", "") users = [x.strip() for x in match.split("@")] reviewers.update(users) - return [x for x in reviewers if x != ""] diff --git a/custom_backend.py b/custom_backend.py index cd799f9232..80a0611fff 100644 --- a/custom_backend.py +++ b/custom_backend.py @@ -7,7 +7,6 @@ _data_dir = _root / "flashinfer" / "data" _aot_ops_dir = _root / "aot-ops" _aot_ops_package_dir = _root / "build" / "aot-ops-package-dir" - _requires_for_aot = ["torch", "ninja", "numpy", "requests"] @@ -39,11 +38,8 @@ def ln(source: str, target: str) -> None: def _prepare_for_wheel(): - # Remove data directory if _data_dir.exists(): shutil.rmtree(_data_dir) - - # Link AOT ops directory to "aot-ops" _rm_aot_ops_package_dir() if not _aot_ops_dir.exists(): _aot_ops_dir.mkdir() @@ -55,7 +51,6 @@ def _prepare_for_wheel(): def _prepare_for_editable(): _create_data_dir() - _rm_aot_ops_package_dir() _aot_ops_dir.mkdir(parents=True, exist_ok=True) _aot_ops_package_dir.parent.mkdir(parents=True, exist_ok=True) @@ -63,11 +58,8 @@ def _prepare_for_editable(): def _prepare_for_sdist(): - # Remove data directory if _data_dir.exists(): shutil.rmtree(_data_dir) - - # Create an empty directory for AOT ops _rm_aot_ops_package_dir() _aot_ops_package_dir.parent.mkdir(parents=True, exist_ok=True) _aot_ops_package_dir.mkdir(parents=True) diff --git a/docs/conf.py b/docs/conf.py index ce09d510fb..6ff6af5819 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -3,15 +3,6 @@ from pathlib import Path from typing import Any, List -# import tlcpack_sphinx_addon -# Configuration file for the Sphinx documentation builder. -# -# For the full list of built-in configuration values, see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Project information ----------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information - root = Path(__file__).parents[1].resolve() sys.path.insert(0, str(root)) os.environ["BUILD_DOC"] = "1" @@ -24,18 +15,12 @@ "einops", "mpi4py", ] - project = "FlashInfer" author = "FlashInfer Contributors" copyright = f"2023-2025, {author}" - package_version = (root / "version.txt").read_text().strip() version = package_version release = package_version - -# -- General configuration --------------------------------------------------- -# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration - extensions = [ "sphinx_tabs.tabs", "sphinx.ext.autodoc", @@ -43,31 +28,16 @@ "sphinx.ext.autosummary", "sphinx.ext.mathjax", ] - autodoc_default_flags = ["members"] autosummary_generate = True - source_suffix = [".rst"] - language = "en" - exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] - -# The name of the Pygments (syntax highlighting) style to use. pygments_style = "sphinx" - -# A list of ignored prefixes for module index sorting. -# If true, `todo` and `todoList` produce output, else they produce nothing. 
todo_include_todos = False - -# -- Options for HTML output ---------------------------------------------- - -html_theme = "furo" # "sphinx_rtd_theme" - +html_theme = "furo" templates_path: List[Any] = [] - html_static_path = ["_static"] - html_theme_options = { "logo_only": True, "light_logo": "FlashInfer-white-background.png", diff --git a/docs/wrap_run_llm.py b/docs/wrap_run_llm.py index d6a711882c..1b7a367164 100644 --- a/docs/wrap_run_llm.py +++ b/docs/wrap_run_llm.py @@ -1,22 +1,19 @@ +import os + """ HTML post-processing script to insert RunLLM widget into documentation. Based on: https://github.com/sgl-project/sglang/blob/499f5e620c243b6a9980b63f7aa54d096a9a3ddd/docs/wrap_run_llm.py Copyright (c) 2023 SGLang Project (Apache 2.0 License) """ - -import os import re def insert_runllm_widget(html_content): - # RunLLM Widget script to be inserted for FlashInfer widget_script = """ """ - - # Find the closing body tag and insert the widget script before it - return re.sub(r"</body>", f"{widget_script}\n</body>", html_content) + return re.sub("</body>", f"{widget_script}\n</body>", html_content) def process_html_files(build_dir): @@ -25,30 +22,19 @@ def process_html_files(build_dir): for file in files: if file.endswith(".html"): file_path = os.path.join(root, file) - - # Read the HTML file with open(file_path, "r", encoding="utf-8") as f: content = f.read() - - # Insert the RunLLM widget modified_content = insert_runllm_widget(content) - - # Write back the modified content with open(file_path, "w", encoding="utf-8") as f: f.write(modified_content) - processed_count += 1 - print(f"Processed {processed_count} HTML files") def main(): - # Get the build directory path build_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "_build", "html" ) - - # Process all HTML files if os.path.exists(build_dir): print(f"Processing HTML files in: {build_dir}") process_html_files(build_dir) diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index f3d8621254..2e0bee8eaa 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -13,65 +13,51 @@ See the License for the specific language governing permissions and limitations under the License. """ - try: from ._build_meta import __version__ as __version__ except ModuleNotFoundError: __version__ = "0.0.0+unknown" - from .
import jit as jit from .activation import gelu_and_mul as gelu_and_mul from .activation import gelu_tanh_and_mul as gelu_tanh_and_mul from .activation import silu_and_mul as silu_and_mul from .attention import BatchAttention as BatchAttention from .autotuner import autotune as autotune -from .cascade import ( - BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper, -) -from .cascade import ( - BatchPrefillWithSharedPrefixPagedKVCacheWrapper as BatchPrefillWithSharedPrefixPagedKVCacheWrapper, -) -from .cascade import ( - MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper, -) +from .cascade import \ + BatchDecodeWithSharedPrefixPagedKVCacheWrapper as \ + BatchDecodeWithSharedPrefixPagedKVCacheWrapper +from .cascade import \ + BatchPrefillWithSharedPrefixPagedKVCacheWrapper as \ + BatchPrefillWithSharedPrefixPagedKVCacheWrapper +from .cascade import \ + MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper from .cascade import merge_state as merge_state from .cascade import merge_state_in_place as merge_state_in_place from .cascade import merge_states as merge_states -from .decode import ( - BatchDecodeMlaWithPagedKVCacheWrapper as BatchDecodeMlaWithPagedKVCacheWrapper, -) -from .decode import ( - BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper, -) -from .decode import ( - CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper, -) -from .decode import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache +from .decode import \ + BatchDecodeMlaWithPagedKVCacheWrapper as \ + BatchDecodeMlaWithPagedKVCacheWrapper +from .decode import \ + BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper +from .decode import \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper as \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper +from .decode import \ + cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache from .decode import single_decode_with_kv_cache as single_decode_with_kv_cache -from .fp4_quantization import ( - SfLayout, - block_scale_interleave, - nvfp4_block_scale_interleave, - e2m1_and_ufp8sf_scale_to_float, - fp4_quantize, - mxfp4_dequantize_host, - mxfp4_dequantize, - mxfp4_quantize, - nvfp4_quantize, - shuffle_matrix_a, - shuffle_matrix_sf_a, -) +from .fp4_quantization import (SfLayout, block_scale_interleave, + e2m1_and_ufp8sf_scale_to_float, fp4_quantize, + mxfp4_dequantize, mxfp4_dequantize_host, + mxfp4_quantize, nvfp4_block_scale_interleave, + nvfp4_quantize, shuffle_matrix_a, + shuffle_matrix_sf_a) from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize -from .fused_moe import ( - RoutingMethodType, - GatedActType, - cutlass_fused_moe, - reorder_rows_for_gated_act_gemm, - trtllm_fp4_block_scale_moe, - trtllm_fp4_block_scale_routed_moe, - trtllm_fp8_block_scale_moe, - trtllm_fp8_per_tensor_scale_moe, -) +from .fused_moe import (GatedActType, RoutingMethodType, cutlass_fused_moe, + reorder_rows_for_gated_act_gemm, + trtllm_fp4_block_scale_moe, + trtllm_fp4_block_scale_routed_moe, + trtllm_fp8_block_scale_moe, + trtllm_fp8_per_tensor_scale_moe) from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper from .gemm import bmm_fp8 as bmm_fp8 from .gemm import mm_fp4 as mm_fp4 @@ -85,32 +71,32 @@ from .page import get_batch_indices_positions as get_batch_indices_positions from .page import get_seq_lens as get_seq_lens from .pod import PODWithPagedKVCacheWrapper as PODWithPagedKVCacheWrapper -from .prefill 
import ( - BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper, -) -from .prefill import ( - BatchPrefillWithRaggedKVCacheWrapper as BatchPrefillWithRaggedKVCacheWrapper, -) -from .prefill import single_prefill_with_kv_cache as single_prefill_with_kv_cache -from .prefill import ( - single_prefill_with_kv_cache_return_lse as single_prefill_with_kv_cache_return_lse, -) +from .prefill import \ + BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper +from .prefill import \ + BatchPrefillWithRaggedKVCacheWrapper as \ + BatchPrefillWithRaggedKVCacheWrapper +from .prefill import \ + single_prefill_with_kv_cache as single_prefill_with_kv_cache +from .prefill import \ + single_prefill_with_kv_cache_return_lse as \ + single_prefill_with_kv_cache_return_lse from .quantization import packbits as packbits from .quantization import segment_packbits as segment_packbits from .rope import apply_llama31_rope as apply_llama31_rope from .rope import apply_llama31_rope_inplace as apply_llama31_rope_inplace from .rope import apply_llama31_rope_pos_ids as apply_llama31_rope_pos_ids -from .rope import ( - apply_llama31_rope_pos_ids_inplace as apply_llama31_rope_pos_ids_inplace, -) +from .rope import \ + apply_llama31_rope_pos_ids_inplace as apply_llama31_rope_pos_ids_inplace from .rope import apply_rope as apply_rope from .rope import apply_rope_inplace as apply_rope_inplace from .rope import apply_rope_pos_ids as apply_rope_pos_ids from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace -from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache -from .rope import ( - apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace, -) +from .rope import \ + apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache +from .rope import \ + apply_rope_with_cos_sin_cache_inplace as \ + apply_rope_with_cos_sin_cache_inplace from .sampling import chain_speculative_sampling as chain_speculative_sampling from .sampling import min_p_sampling_from_probs as min_p_sampling_from_probs from .sampling import sampling_from_logits as sampling_from_logits @@ -119,14 +105,13 @@ from .sampling import top_k_mask_logits as top_k_mask_logits from .sampling import top_k_renorm_probs as top_k_renorm_probs from .sampling import top_k_sampling_from_probs as top_k_sampling_from_probs -from .sampling import ( - top_k_top_p_sampling_from_logits as top_k_top_p_sampling_from_logits, -) -from .sampling import top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs +from .sampling import \ + top_k_top_p_sampling_from_logits as top_k_top_p_sampling_from_logits +from .sampling import \ + top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs from .sampling import top_p_renorm_probs as top_p_renorm_probs from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper -from .sparse import ( - VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper, -) +from .sparse import \ + VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper from .utils import next_positive_power_of_2 as next_positive_power_of_2 diff --git a/flashinfer/__main__.py b/flashinfer/__main__.py index 08401d8802..3c59aee96a 100644 --- a/flashinfer/__main__.py +++ b/flashinfer/__main__.py @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - -# flashinfer-cli import argparse from .artifacts import download_artifacts @@ -24,9 +22,7 @@ parser.add_argument( "--download-cubin", action="store_true", help="Download artifacts" ) - args = parser.parse_args() - if args.download_cubin: if download_artifacts(): print("✅ All cubin download tasks completed successfully.") diff --git a/flashinfer/activation.py b/flashinfer/activation.py index a8cacf3280..3b81018940 100644 --- a/flashinfer/activation.py +++ b/flashinfer/activation.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,38 +15,32 @@ See the License for the specific language governing permissions and limitations under the License. """ - import functools from types import SimpleNamespace from typing import Optional -import torch - from .jit import JitSpec from .jit import gen_act_and_mul_module as gen_act_and_mul_module_impl from .utils import device_support_pdl, register_custom_op, register_fake_op -silu_def_cu_str = r""" +silu_def_cu_str = """ __device__ __forceinline__ float silu(const float& val) { return val / (1.0f + __expf(-val)); } """ - -gelu_def_cu_str = r""" +gelu_def_cu_str = """ __device__ __forceinline__ float gelu(const float& val) { constexpr float kAlpha = M_SQRT1_2; return val * 0.5f * (1.0f + ::erf(val * kAlpha)); } """ - -gelu_def_tanh_cu_str = r""" +gelu_def_tanh_cu_str = """ __device__ __forceinline__ float gelu_tanh(const float& val) { const float cdf = 0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val)))); return val * cdf; } """ - act_func_def_str = { "silu": silu_def_cu_str, "gelu": gelu_def_cu_str, @@ -59,43 +55,40 @@ def gen_act_and_mul_module(act_func_name: str) -> JitSpec: @functools.cache def get_act_and_mul_module(act_func_name: str): module = gen_act_and_mul_module(act_func_name).build_and_load() - - # torch library for act_and_mul fname = f"{act_func_name}_and_mul" fn = getattr(module, fname).default @register_custom_op(f"flashinfer::{fname}", mutates_args=("out",)) def _act_and_mul( - out: torch.Tensor, input: torch.Tensor, enable_pdl: Optional[bool] = None + out: paddle.Tensor, input: paddle.Tensor, enable_pdl: Optional[bool] = None ) -> None: if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) fn(out, input, enable_pdl) @register_fake_op(f"flashinfer::{fname}") def _fake_act_and_mul( - out: torch.Tensor, input: torch.Tensor, enable_pdl: Optional[bool] = None + out: paddle.Tensor, input: paddle.Tensor, enable_pdl: Optional[bool] = None ) -> None: pass - # Register the module return SimpleNamespace(**{fname: _act_and_mul}) -def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: +def _check_shape(input: paddle.Tensor, output: paddle.Tensor) -> None: assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}" - assert input.shape[:-1] == output.shape[:-1], ( - f"{input.shape[:-1]} != {output.shape[:-1]}" - ) - assert input.shape[-1] == 2 * output.shape[-1], ( - f"{input.shape[-1]} != {2 * output.shape[-1]}" - ) + assert ( + tuple(input.shape)[:-1] == tuple(output.shape)[:-1] + ), f"{tuple(input.shape)[:-1]} != {tuple(output.shape)[:-1]}" + assert ( + tuple(input.shape)[-1] == 2 * tuple(output.shape)[-1] + ), f"{tuple(input.shape)[-1]} != {2 * tuple(output.shape)[-1]}" def silu_and_mul( - input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None -) -> torch.Tensor: - r"""Fused SiLU and Mul operation. 
+ input: paddle.Tensor, out: paddle.Tensor = None, enable_pdl: Optional[bool] = None +) -> paddle.Tensor: + """Fused SiLU and Mul operation. ``silu(input[..., :hidden_size]) * input[..., hidden_size:]`` @@ -117,29 +110,24 @@ def silu_and_mul( Output tensor, shape (..., hidden_size). """ if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) - if input.shape[-1] * input.dtype.itemsize % 16 != 0: + enable_pdl = device_support_pdl(input.place) + if tuple(input.shape)[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: _check_shape(input, out) else: - out = torch.empty( - input.shape[:-1] + (input.shape[-1] // 2,), - device=input.device, + out = paddle.empty( + shape=tuple(input.shape)[:-1] + (tuple(input.shape)[-1] // 2,), dtype=input.dtype, ) - get_act_and_mul_module("silu").silu_and_mul( - out, - input, - enable_pdl, - ) + get_act_and_mul_module("silu").silu_and_mul(out, input, enable_pdl) return out def gelu_tanh_and_mul( - input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None -) -> torch.Tensor: - r"""Fused GeLU Tanh and Mul operation. + input: paddle.Tensor, out: paddle.Tensor = None, enable_pdl: Optional[bool] = None +) -> paddle.Tensor: + """Fused GeLU Tanh and Mul operation. ``gelu(tanh(input[..., :hidden_size])) * input[..., hidden_size:]`` @@ -161,15 +149,14 @@ def gelu_tanh_and_mul( Output tensor, shape (..., hidden_size). """ if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) - if input.shape[-1] * input.dtype.itemsize % 16 != 0: + enable_pdl = device_support_pdl(input.place) + if tuple(input.shape)[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: _check_shape(input, out) else: - out = torch.empty( - input.shape[:-1] + (input.shape[-1] // 2,), - device=input.device, + out = paddle.empty( + shape=tuple(input.shape)[:-1] + (tuple(input.shape)[-1] // 2,), dtype=input.dtype, ) get_act_and_mul_module("gelu_tanh").gelu_tanh_and_mul(out, input, enable_pdl) @@ -177,9 +164,9 @@ def gelu_tanh_and_mul( def gelu_and_mul( - input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None -) -> torch.Tensor: - r"""Fused GeLU and Mul operation. + input: paddle.Tensor, out: paddle.Tensor = None, enable_pdl: Optional[bool] = None +) -> paddle.Tensor: + """Fused GeLU and Mul operation. ``gelu(input[..., :hidden_size]) * input[..., hidden_size:]`` @@ -201,15 +188,14 @@ def gelu_and_mul( Output tensor, shape (..., hidden_size). 
""" if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) - if input.shape[-1] * input.dtype.itemsize % 16 != 0: + enable_pdl = device_support_pdl(input.place) + if tuple(input.shape)[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: _check_shape(input, out) else: - out = torch.empty( - input.shape[:-1] + (input.shape[-1] // 2,), - device=input.device, + out = paddle.empty( + shape=tuple(input.shape)[:-1] + (tuple(input.shape)[-1] // 2,), dtype=input.dtype, ) get_act_and_mul_module("gelu").gelu_and_mul(out, input, enable_pdl) diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 4a65e3e355..d65fb5b4b8 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -1,36 +1,29 @@ +import sys + + import argparse import os import shutil from itertools import product from pathlib import Path -from typing import List, Tuple, Iterator +from typing import Iterator, List, Tuple -import torch -import torch.version -from torch.utils.cpp_extension import _get_cuda_arch_flags +import paddle +from flashinfer.paddle_utils import * from .activation import act_func_def_str, gen_act_and_mul_module from .cascade import gen_cascade_module -from .fp4_quantization import ( - gen_fp4_quantization_sm100_module, - gen_fp4_quantization_sm90_module, -) -from .fused_moe import ( - gen_cutlass_fused_moe_sm100_module, - gen_cutlass_fused_moe_sm90_module, -) +from .fp4_quantization import (gen_fp4_quantization_sm90_module, + gen_fp4_quantization_sm100_module) +from .fused_moe import (gen_cutlass_fused_moe_sm90_module, + gen_cutlass_fused_moe_sm100_module) from .gemm import gen_gemm_module, gen_gemm_sm90_module, gen_gemm_sm100_module from .jit import JitSpec, build_jit_specs from .jit import env as jit_env -from .jit import ( - gen_batch_decode_module, - gen_batch_mla_module, - gen_batch_prefill_module, - gen_fmha_cutlass_sm100a_module, - gen_jit_spec, - gen_single_decode_module, - gen_single_prefill_module, -) +from .jit import (gen_batch_decode_module, gen_batch_mla_module, + gen_batch_prefill_module, gen_fmha_cutlass_sm100a_module, + gen_jit_spec, gen_single_decode_module, + gen_single_prefill_module) from .mla import gen_mla_module from .norm import gen_norm_module from .page import gen_page_module @@ -42,18 +35,17 @@ def gen_fa2( - dtype_qo: torch.dtype, - dtype_kv: torch.dtype, + dtype_qo: paddle.dtype, + dtype_kv: paddle.dtype, head_dim_qk: int, head_dim_vo: int, use_sliding_window: bool, use_logits_soft_cap: bool, ) -> Iterator[JitSpec]: - if dtype_qo.itemsize == dtype_kv.itemsize and dtype_qo != dtype_kv: + if dtype_qo.element_size() == dtype_kv.element_size() and dtype_qo != dtype_kv: + return + if dtype_qo.element_size() == 1: return - if dtype_qo.itemsize == 1: - return # fp8 tensor cores not supported in fa2 - yield gen_single_prefill_module( backend="fa2", dtype_q=dtype_qo, @@ -66,13 +58,12 @@ def gen_fa2( use_logits_soft_cap=use_logits_soft_cap, use_fp16_qk_reduction=False, ) - yield gen_batch_prefill_module( backend="fa2", dtype_q=dtype_qo, dtype_kv=dtype_kv, dtype_o=dtype_qo, - dtype_idx=torch.int32, + dtype_idx="int32", head_dim_qk=head_dim_qk, head_dim_vo=head_dim_vo, pos_encoding_mode=0, @@ -80,7 +71,6 @@ def gen_fa2( use_logits_soft_cap=use_logits_soft_cap, use_fp16_qk_reduction=False, ) - yield gen_single_decode_module( dtype_q=dtype_qo, dtype_kv=dtype_kv, @@ -91,12 +81,11 @@ def gen_fa2( use_sliding_window=use_sliding_window, use_logits_soft_cap=use_logits_soft_cap, ) - yield gen_batch_decode_module( 
dtype_q=dtype_qo, dtype_kv=dtype_kv, dtype_o=dtype_qo, - dtype_idx=torch.int32, + dtype_idx="int32", head_dim_qk=head_dim_qk, head_dim_vo=head_dim_vo, pos_encoding_mode=0, @@ -106,30 +95,28 @@ def gen_fa2( def gen_fa3( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, head_dim_qk: int, head_dim_vo: int, use_sliding_window: bool, use_logits_soft_cap: bool, ) -> Iterator[JitSpec]: if dtype_q != dtype_kv: - return # fa3 template do not support mixed precision - if dtype_q.itemsize == 2: + return + if dtype_q.element_size() == 2: if dtype_q != dtype_o: - return # for fp16, dtype_o must be the same as dtype_q/dtype_kv - - if dtype_kv.itemsize == 1: + return + if dtype_kv.element_size() == 1: if head_dim_qk == 192 or head_dim_qk == 64: - return # (192, 128) & (64, 64) not supported for fp8 yet. - + return yield gen_batch_prefill_module( backend="fa3", dtype_q=dtype_q, dtype_kv=dtype_kv, dtype_o=dtype_o, - dtype_idx=torch.int32, + dtype_idx="int32", head_dim_qk=head_dim_qk, head_dim_vo=head_dim_vo, pos_encoding_mode=0, @@ -140,8 +127,8 @@ def gen_fa3( def gen_attention( - f16_dtype_: List[torch.dtype], - f8_dtype_: List[torch.dtype], + f16_dtype_: List[paddle.dtype], + f8_dtype_: List[paddle.dtype], fa2_head_dim_: List[Tuple[int, int]], fa3_head_dim_: List[Tuple[int, int]], use_sliding_window_: List[bool], @@ -153,8 +140,6 @@ def gen_attention( ) -> Iterator[JitSpec]: head_dim_ckv = 512 head_dim_kpe = 64 - - # FA2 MHA / MQA / GQA for ( (head_dim_qk, head_dim_vo), dtype_qo, @@ -176,8 +161,6 @@ def gen_attention( use_sliding_window=use_sliding_window, use_logits_soft_cap=use_logits_soft_cap, ) - - # FA3 MHA / MQA / GQA if has_sm90: for ( (head_dim_qk, head_dim_vo), @@ -201,17 +184,9 @@ def gen_attention( use_sliding_window=use_sliding_window, use_logits_soft_cap=use_logits_soft_cap, ) - - # Gemma if add_gemma: - for ( - dtype_qo, - dtype_kv, - (use_sliding_window, use_logits_soft_cap), - ) in product( - f16_dtype_, - f16_dtype_ + f8_dtype_, - [(True, True)], + for dtype_qo, dtype_kv, (use_sliding_window, use_logits_soft_cap) in product( + f16_dtype_, f16_dtype_ + f8_dtype_, [(True, True)] ): yield from gen_fa2( dtype_qo=dtype_qo, @@ -226,11 +201,7 @@ def gen_attention( dtype_qkv, dtype_o, (use_sliding_window, use_logits_soft_cap), - ) in product( - f16_dtype_ + f8_dtype_, - f16_dtype_, - [(True, True)], - ): + ) in product(f16_dtype_ + f8_dtype_, f16_dtype_, [(True, True)]): yield from gen_fa3( dtype_q=dtype_qkv, dtype_kv=dtype_qkv, @@ -240,8 +211,6 @@ def gen_attention( use_sliding_window=use_sliding_window, use_logits_soft_cap=use_logits_soft_cap, ) - - # OAI OSS if add_oai_oss: from .jit.attention import gen_batch_prefill_attention_sink_module @@ -253,30 +222,24 @@ def gen_attention( dtype_q=dtype, dtype_kv=dtype, dtype_o=dtype, - dtype_idx=torch.int32, + dtype_idx="int32", head_dim_qk=64, head_dim_vo=64, pos_encoding_mode=0, use_sliding_window=use_swa, ) - - # fmha_cutlass_sm100a - # NOTE: currently there's only one uri. 
if has_sm100: yield gen_fmha_cutlass_sm100a_module( - dtype_q=torch.bfloat16, - dtype_kv=torch.bfloat16, - dtype_o=torch.bfloat16, - dtype_idx=torch.int32, + dtype_q="bfloat16", + dtype_kv="bfloat16", + dtype_o="bfloat16", + dtype_idx="int32", head_dim_qk=128, head_dim_vo=128, pos_encoding_mode=0, use_sliding_window=False, use_logits_soft_cap=False, ) - - # MLA - # NOTE: fp8 kv not supported in MLA mla_backend_ = ["fa2"] + (["fa3"] if has_sm90 else []) for dtype_qo in f16_dtype_: for backend in mla_backend_: @@ -285,20 +248,18 @@ def gen_attention( dtype_q=dtype_qo, dtype_kv=dtype_qo, dtype_o=dtype_qo, - dtype_idx=torch.int32, + dtype_idx="int32", head_dim_ckv=head_dim_ckv, head_dim_kpe=head_dim_kpe, use_profiler=False, ) - - # MLA SM100 if has_sm100: yield gen_mla_module() def gen_all_modules( - f16_dtype_: List[torch.dtype], - f8_dtype_: List[torch.dtype], + f16_dtype_: List[paddle.dtype], + f8_dtype_: List[paddle.dtype], fa2_head_dim_: List[Tuple[int, int]], fa3_head_dim_: List[Tuple[int, int]], use_sliding_window_: List[bool], @@ -313,7 +274,6 @@ def gen_all_modules( add_misc: bool, ) -> List[JitSpec]: jit_specs: List[JitSpec] = [] - jit_specs += list( gen_attention( f16_dtype_, @@ -328,11 +288,9 @@ def gen_all_modules( add_oai_oss, ) ) - if add_act: for act_name in act_func_def_str: jit_specs.append(gen_act_and_mul_module(act_name)) - if add_moe: jit_specs.append(gen_gemm_module()) if has_sm90: @@ -343,7 +301,6 @@ def gen_all_modules( jit_specs.append(gen_fp4_quantization_sm100_module()) jit_specs.append(gen_cutlass_fused_moe_sm100_module()) jit_specs.append(gen_gemm_sm100_module()) - if add_comm: from .comm import gen_trtllm_comm_module, gen_vllm_comm_module from .comm.nvshmem import gen_nvshmem_module @@ -352,7 +309,6 @@ def gen_all_modules( if has_sm100: jit_specs.append(gen_trtllm_comm_module()) jit_specs.append(gen_vllm_comm_module()) - if add_misc: jit_specs += [ gen_cascade_module(), @@ -364,8 +320,6 @@ def gen_all_modules( ] if has_sm90: jit_specs.append(get_trtllm_utils_spec()) - - # dedup names = set() ret: List[JitSpec] = [] for jit_spec in jit_specs: @@ -375,10 +329,7 @@ def gen_all_modules( return ret -def copy_built_kernels( - jit_specs: List[JitSpec], - out_dir: Path, -) -> None: +def copy_built_kernels(jit_specs: List[JitSpec], out_dir: Path) -> None: if out_dir.exists(): shutil.rmtree(out_dir) out_dir.mkdir(parents=True, exist_ok=False) @@ -407,16 +358,8 @@ def main(): parser = argparse.ArgumentParser( description="Ahead-of-Time (AOT) build all modules" ) - parser.add_argument( - "--out-dir", - type=Path, - help="Output directory", - ) - parser.add_argument( - "--build-dir", - type=Path, - help="Build directory", - ) + parser.add_argument("--out-dir", type=Path, help="Output directory") + parser.add_argument("--build-dir", type=Path, help="Build directory") parser.add_argument( "--fa2-head-dim", nargs="*", @@ -440,15 +383,9 @@ def main(): help="8-bit data type", ) parser.add_argument( - "--use-sliding-window", - nargs="*", - help="Use sliding window attention", - ) - parser.add_argument( - "--use-logits-soft-cap", - nargs="*", - help="Use logits soft cap", + "--use-sliding-window", nargs="*", help="Use sliding window attention" ) + parser.add_argument("--use-logits-soft-cap", nargs="*", help="Use logits soft cap") parser.add_argument( "--add-comm", type=parse_bool, @@ -464,62 +401,25 @@ def main(): type=parse_bool, help="Add kernels for OAI OSS Model (head_dim=64, use_sliding_window)", ) - parser.add_argument( - "--add-moe", - type=parse_bool, - help="Add MoE 
kernels", - ) - parser.add_argument( - "--add-act", - type=parse_bool, - help="Add activation kernels", - ) - parser.add_argument( - "--add-misc", - type=parse_bool, - help="Add miscellaneous kernels", - ) + parser.add_argument("--add-moe", type=parse_bool, help="Add MoE kernels") + parser.add_argument("--add-act", type=parse_bool, help="Add activation kernels") + parser.add_argument("--add-misc", type=parse_bool, help="Add miscellaneous kernels") args = parser.parse_args() - - # Default values project_root = Path(__file__).resolve().parents[1] out_dir = project_root / "aot-ops" build_dir = project_root / "build" / "aot" - fa2_head_dim_ = [ - (64, 64), - (128, 128), - # (256, 256), - ] - fa3_head_dim_ = [ - (192, 128), - (128, 128), - # (64, 64), - # (256, 256), - ] - f16_dtype_ = [ - torch.float16, - torch.bfloat16, - ] - f8_dtype_ = [ - torch.float8_e4m3fn, - # torch.float8_e5m2, - ] - use_sliding_window_ = [ - False, - # True, - ] - use_logits_soft_cap_ = [ - False, - # True, - ] + fa2_head_dim_ = [(64, 64), (128, 128)] + fa3_head_dim_ = [(192, 128), (128, 128)] + f16_dtype_ = ["float16", "bfloat16"] + f8_dtype_ = [paddle.float8_e4m3fn] + use_sliding_window_ = [False] + use_logits_soft_cap_ = [False] add_comm = False add_gemma = False add_oai_oss = True add_moe = False add_act = False add_misc = True - - # Override if args.out_dir: out_dir = Path(args.out_dir) if args.build_dir: @@ -548,23 +448,19 @@ def main(): add_act = bool(args.add_act) if args.add_misc is not None: add_misc = bool(args.add_misc) - - # Cuda Arch if "TORCH_CUDA_ARCH_LIST" not in os.environ: raise RuntimeError("Please explicitly set env var TORCH_CUDA_ARCH_LIST.") - gencode_flags = _get_cuda_arch_flags() +>>>>>> gencode_flags = torch.utils.cpp_extension._get_cuda_arch_flags() def has_sm(compute: str, version: str) -> bool: if not any(compute in flag for flag in gencode_flags): return False - if torch.version.cuda is None: +>>>>>> if torch.version.cuda is None: return True - return version_at_least(torch.version.cuda, version) +>>>>>> return version_at_least(torch.version.cuda, version) has_sm90 = has_sm("compute_90", "12.3") has_sm100 = has_sm("compute_100", "12.8") - - # Update data dir jit_env.FLASHINFER_CSRC_DIR = project_root / "csrc" jit_env.FLASHINFER_INCLUDE_DIR = project_root / "include" jit_env.CUTLASS_INCLUDE_DIRS = [ @@ -572,15 +468,11 @@ def has_sm(compute: str, version: str) -> bool: project_root / "3rdparty" / "cutlass" / "tools" / "util" / "include", ] jit_env.SPDLOG_INCLUDE_DIR = project_root / "3rdparty" / "spdlog" / "include" - - # Update workdir jit_env.FLASHINFER_WORKSPACE_DIR = build_dir jit_env.FLASHINFER_JIT_DIR = build_dir / "cached_ops" jit_env.FLASHINFER_GEN_SRC_DIR = build_dir / "generated" jit_env.FLASHINFER_JIT_DIR.mkdir(parents=True, exist_ok=True) jit_env.FLASHINFER_GEN_SRC_DIR.mkdir(parents=True, exist_ok=True) - - # Print summary print("AOT build summary:") print(" out_dir:", out_dir) print(" build_dir:", build_dir) @@ -599,15 +491,11 @@ def has_sm(compute: str, version: str) -> bool: print(" add_moe:", add_moe) print(" add_act:", add_act) print(" add_misc:", add_misc) - - # Generate JIT specs print("Generating JIT specs...") jit_specs = [ gen_jit_spec( "logging", - [ - jit_env.FLASHINFER_CSRC_DIR / "logging.cc", - ], + [jit_env.FLASHINFER_CSRC_DIR / "logging.cc"], extra_include_paths=[ jit_env.SPDLOG_INCLUDE_DIR, jit_env.FLASHINFER_INCLUDE_DIR, @@ -631,11 +519,7 @@ def has_sm(compute: str, version: str) -> bool: add_misc, ) print("Total ops:", len(jit_specs)) - - # Build 
build_jit_specs(jit_specs, verbose=True, skip_prebuilt=False) - - # Copy built kernels copy_built_kernels(jit_specs, out_dir) print("AOT kernels saved to:", out_dir) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index a8810239ec..2fd3fa6205 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -1,3 +1,5 @@ +import os + """ Copyright (c) 2025 by FlashInfer team. @@ -13,13 +15,11 @@ See the License for the specific language governing permissions and limitations under the License. """ - -import os import re import time from concurrent.futures import ThreadPoolExecutor, as_completed -import requests # type: ignore[import-untyped] +import requests from .jit.core import logger from .jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY, get_cubin @@ -30,15 +30,13 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10): try: response = requests.get(source, timeout=timeout) response.raise_for_status() - hrefs = re.findall(r'\<a href=".*\.cubin">', response.text) + hrefs = re.findall('\\<a href=".*\\.cubin">', response.text) files = [(h[9:-8], ".cubin") for h in hrefs] return files - except requests.exceptions.RequestException as e: logger.warning( f"Fetching available files {source}: attempt {attempt} failed: {e}" ) - if attempt < retries: logger.info(f"Retrying in {delay} seconds...") time.sleep(delay) @@ -106,5 +104,4 @@ def download_artifacts() -> bool: os.environ.pop("FLASHINFER_CUBIN_CHECKSUM_DISABLED") else: os.environ["FLASHINFER_CUBIN_CHECKSUM_DISABLED"] = env_backup - return all_success diff --git a/flashinfer/attention.py b/flashinfer/attention.py index a43798b26d..39373e4fd7 100644 --- a/flashinfer/attention.py +++ b/flashinfer/attention.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,21 +19,13 @@ See the License for the specific language governing permissions and limitations under the License.
""" - import functools import math from typing import Optional, Tuple, Union -import torch - from .jit import gen_batch_attention_module -from .utils import ( - MaskMode, - PosEncodingMode, - TensorLayout, - _check_kv_layout, - _unpack_paged_kv_cache, -) +from .utils import (MaskMode, PosEncodingMode, TensorLayout, _check_kv_layout, + _unpack_paged_kv_cache) @functools.cache @@ -36,37 +34,23 @@ def get_holistic_attention_module(*args): class BatchAttention: - def __init__( - self, - kv_layout: str = "NHD", - device: str = "cuda", - ): + def __init__(self, kv_layout: str = "NHD", device: str = "cuda"): _check_kv_layout(kv_layout) self._kv_layout = kv_layout - - self.float_workspace_buffer = torch.empty( - 384 * 1024 * 1024, - dtype=torch.uint8, - device=torch.device(device), - ) - self.int_workspace_buffer = torch.empty( - 8 * 1024 * 1024, - dtype=torch.uint8, - device=torch.device(device), - ) - self.page_locked_int_workspace_buffer = torch.empty( - 8 * 1024 * 1024, - dtype=torch.uint8, - device=torch.device("cpu"), - pin_memory=True, + self.float_workspace_buffer = paddle.empty( + shape=384 * 1024 * 1024, dtype="uint8" ) + self.int_workspace_buffer = paddle.empty(shape=8 * 1024 * 1024, dtype="uint8") + self.page_locked_int_workspace_buffer = paddle.empty( + shape=8 * 1024 * 1024, dtype="uint8" + ).pin_memory() def plan( self, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, - kv_indices: torch.Tensor, - kv_len_arr: torch.Tensor, + qo_indptr: paddle.Tensor, + kv_indptr: paddle.Tensor, + kv_indices: paddle.Tensor, + kv_len_arr: paddle.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim_qk: int, @@ -75,15 +59,13 @@ def plan( causal: bool = False, sm_scale: float = None, logits_soft_cap: Optional[float] = None, - q_data_type: torch.dtype = torch.bfloat16, - kv_data_type: torch.dtype = torch.bfloat16, + q_data_type: paddle.dtype = "bfloat16", + kv_data_type: paddle.dtype = "bfloat16", use_profiler: bool = False, ) -> None: if logits_soft_cap is None: logits_soft_cap = 0.0 self._logits_soft_cap = logits_soft_cap - - # get jit module get_module_args = ( q_data_type, kv_data_type, @@ -93,16 +75,14 @@ def plan( head_dim_vo, PosEncodingMode["NONE"].value, logits_soft_cap > 0.0, - use_profiler, # different compiler path + use_profiler, ) self.module = get_holistic_attention_module(*get_module_args) - - qo_indptr_host = qo_indptr.to(torch.device("cpu"), non_blocking=True) - kv_indptr_host = kv_indptr.to(torch.device("cpu"), non_blocking=True) - kv_len_arr_host = kv_len_arr.to(torch.device("cpu"), non_blocking=True) - torch.cuda.synchronize() - - batch_size = kv_len_arr.shape[0] + qo_indptr_host = qo_indptr.to(device2str("cpu"), blocking=not True) + kv_indptr_host = kv_indptr.to(device2str("cpu"), blocking=not True) + kv_len_arr_host = kv_len_arr.to(device2str("cpu"), blocking=not True) + paddle.device.synchronize() + batch_size = tuple(kv_len_arr.shape)[0] self._page_size = page_size self._sm_scale = sm_scale self._mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value @@ -111,9 +91,6 @@ def plan( self._page_size = page_size self._sm_scale = sm_scale self._use_profiler = use_profiler - - # No addtional buf allocated for CUDA graph tensor - # Allocate outside FlashInfer self._kv_indices = kv_indices self._plan_info = self.module.plan( self.float_workspace_buffer, @@ -131,13 +108,13 @@ def plan( def run( self, - q: torch.Tensor, - kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + q: 
paddle.Tensor, + kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, logits_soft_cap: float = 0.0, - profiler_buffer: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + profiler_buffer: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: if profiler_buffer is None: if self._use_profiler: raise ValueError( @@ -147,22 +124,17 @@ def run( raise ValueError( "logits_soft_cap used in kernel run but not provided in plan(). This will cause template deduction error." ) - k_cache, v_cache = _unpack_paged_kv_cache(kv_cache, self._kv_layout) if out is None: - out = torch.empty_like(q) + out = paddle.empty_like(x=q) if lse is None: - # lse shape: [batch_size, num_qo_heads] - lse = torch.empty( - q.shape[0], q.shape[1], device=q.device, dtype=torch.float32 + lse = paddle.empty( + shape=[tuple(q.shape)[0], tuple(q.shape)[1]], dtype="float32" ) - head_dim_qk = q.shape[2] + head_dim_qk = tuple(q.shape)[2] if self._sm_scale is None: self._sm_scale = 1.0 / math.sqrt(head_dim_qk) - - # profiler_buffer is optional profiler_args = (profiler_buffer,) if self._use_profiler else () - self.module.run( self.float_workspace_buffer, self.int_workspace_buffer, @@ -180,9 +152,6 @@ def run( self._page_size, self._sm_scale, logits_soft_cap, - # ADDITIONAL_FUNC_PARAMS - # PROFILER_FUNC_PARAMS - *profiler_args, + *profiler_args ) - return out, lse diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index 32bf52d113..8f15cb8d2a 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -1,3 +1,6 @@ +import sys + + import contextlib import copy import importlib @@ -7,24 +10,20 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from functools import lru_cache -from typing import Any, Callable, Dict, List, Set, Tuple, Union, Optional +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union -import torch +import paddle +from flashinfer.paddle_utils import * -# from tensorrt_llm.bindings.internal.runtime import delay_kernel -# from tensorrt_llm.logger import logger from flashinfer.tllm_utils import delay_kernel from .jit.core import logger -# This version should be updated whenever the nvfp4_cutlass backend is changed, -# such as when new kernels or configs are added. In such cases, the tuning configs -# should also be updated. Currently, this process is manual, but it should be automated in the future. 
_nvfp4_cutlass_version = "0.1" def get_config_path(is_module: bool): - dev_name = torch.cuda.get_device_name(0).replace(" ", "_") + dev_name = paddle.device.cuda.get_device_name(device=0).replace(" ", "_") cutlass_ver = _nvfp4_cutlass_version.replace(".", "_") config_name = f"v{cutlass_ver}_trtllm_fused_moe_{dev_name}" if is_module: @@ -58,22 +57,17 @@ class DynamicTensorSpec: tensor_initializers: List[Callable] = field(default_factory=lambda: None) def __post_init__(self): - # Set default tensor_initializers if not provided if self.tensor_initializers is None: self.tensor_initializers = [ - lambda shapes, dtype, device: torch.randn(shapes, device=device).to( - dtype - ) + (lambda shapes, dtype, device: paddle.randn(shape=shapes).to(dtype)) for _ in range(len(self.input_idx)) ] def __hash__(self) -> int: - # FIXME: currently not hasing tensor_initializers return hash( ( self.input_idx, self.dim_idx, - # For gen_tuning_buckets, only hash if it's a tuple, otherwise hash its id self.gen_tuning_buckets if isinstance(self.gen_tuning_buckets, tuple) else id(self.gen_tuning_buckets), @@ -176,25 +170,23 @@ def get_hash_key(self): def get_opt_shapes(self): """Only the opt shapes are considered as hash key""" - # TODO: remove duplicate shape generation opt_shapes = [] for t in self.shapes: opt_shapes.append(tuple([d._opt() for d in t])) return tuple(opt_shapes) -# TODO: can/shall we use the torch builtin FakeTensor class? @dataclass class FakeTensor: - dtype: torch.dtype - device: torch.device + dtype: paddle.dtype + device: str shape: List[Dim] class TunableRunner(ABC): @abstractmethod def get_valid_tactics( - self, inputs: List[torch.Tensor], profile: OptimizationProfile + self, inputs: List[paddle.Tensor], profile: OptimizationProfile ) -> List[int]: """One tactic corresponding to one cuda kernel normally, but how to interpret the meaning of tactic is pure internal details of the runner. @@ -219,10 +211,10 @@ def __call__(self, inputs, **kwargs): @abstractmethod def forward( self, - inputs: List[torch.Tensor], + inputs: List[paddle.Tensor], tactic: int = -1, do_preparation: bool = False, - **kwargs, # all others are keyword args only + **kwargs, ) -> Any: """Forward pass for tunable runners. 
@@ -292,14 +284,13 @@ def __str__(self) -> str: stats_str += f" {op}:\n" for profile in sorted(profiles, key=str): stats_str += f" - Config: {profile}\n" - if self.tuned_op_total_configs: stats_str += "Tuned operations:\n" for op in sorted(self.tuned_op_total_configs.keys()): total = self.tuned_op_total_configs[op] successful = self.tuned_op_successful_configs.get(op, 0) failed = len(self.failed_profiling_count.get(op, set())) - success_rate = (successful / total * 100) if total > 0 else 0 + success_rate = successful / total * 100 if total > 0 else 0 stats_str += f" {op}:\n" stats_str += f" - Total configs tried: {total}\n" stats_str += f" - Successful configs: {successful}\n" @@ -309,7 +300,6 @@ def __str__(self) -> str: for failed_key in self.failed_profiling_count[op]: stats_str += f" - {failed_key}\n" stats_str += f" - Success rate: {success_rate:.1f}%\n" - return stats_str @@ -352,10 +342,7 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000): self.stream_delay_micro_secs = stream_delay_micro_secs self.profiling_cache = {} self.is_tuning_mode = False - - # Add statistics tracking self.stats = AutoTunerStatistics() - self.profiling_debug = True @classmethod @@ -368,7 +355,7 @@ def search_cache( self, custom_op: str, runners: List[TunableRunner], - input_shapes: Tuple[torch.Size], + input_shapes: Tuple[list], tuning_config: TuningConfig, ) -> Tuple[bool, int, int, OptimizationProfile]: """Search for cached profiling results matching the current configuration. @@ -394,7 +381,6 @@ def search_cache( return output elif cache_key in self.profiling_cache: return True, *self.profiling_cache[cache_key] - return False, 0, -1, None def choose_one( @@ -402,7 +388,7 @@ def choose_one( custom_op: str, runners: List[TunableRunner], tuning_config: TuningConfig, - inputs: List[torch.Tensor], + inputs: List[paddle.Tensor], **kwargs, ) -> Tuple[TunableRunner, int]: """Choose the best runner and tactic combination through performance profiling. @@ -426,20 +412,12 @@ def choose_one( Although runners[0] with tactic=-1 is always treated as the fallback runner. Runner authors are suggested to provide a fallback implementation for each runner to avoid potential issues. """ - input_shapes = tuple(self._get_input_sizes(inputs)) - - # Early return if it's not tuning, use cache found one or fallback one if not self.is_tuning_mode: is_cache_hit, runner_id, tactic, stored_profile = self.search_cache( custom_op, runners, input_shapes, tuning_config ) runner = runners[runner_id] - # TODO: check the stored runner and tactic can implement this shape here - # Should not directly try (runner, tactic) here, or it will hurt a lot of inference perf. - - # Record the cache miss config. - # Expect no cache miss in inference. Thus, any cache miss should be recorded. 
if not is_cache_hit: logger.debug( f"[AutoTunner]: Using fallback tactic for {custom_op} with input shapes {input_shapes}" @@ -448,16 +426,12 @@ def choose_one( f"[AutoTunner]: Generated key{AutoTuner._get_cache_key(custom_op, runners[0], input_shapes, tuning_config)}" ) return runner, tactic - assert len(runners) > 0, "At least one runner is required" - assert all([isinstance(r, TunableRunner) for r in runners]), ( - "All Given runners must be subclass of TunableRunner" - ) - + assert all( + [isinstance(r, TunableRunner) for r in runners] + ), "All Given runners must be subclass of TunableRunner" profiles = self._generate_optimization_profiles(tuning_config, inputs) - # Record the total configs to try self.stats.tuned_op_total_configs[custom_op] = len(profiles) - for p in profiles: tensors = self._prepare_input_tensors(p, inputs) is_cache_hit, runner_id, tactic, _ = self.search_cache( @@ -465,10 +439,8 @@ def choose_one( ) if not is_cache_hit: min_time = float("inf") - # Initialize runner and tactic as None in case of no valid tactic or runners are found runner_id, tactic = None, None for r_id, r in enumerate(runners): - # TODO: use FakeTensor here. valid_tactics = r.get_valid_tactics(tensors, p) runner_arg_names = { p.name for p in inspect.signature(r.forward).parameters.values() @@ -482,12 +454,9 @@ def choose_one( ) except Exception as e: shapes = self._get_input_sizes(tensors) - logger.error( f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}" ) - - # Record the failed profiling combinations if custom_op not in self.stats.failed_profiling_count: self.stats.failed_profiling_count[custom_op] = set() self.stats.failed_profiling_count[custom_op].add( @@ -495,46 +464,35 @@ def choose_one( custom_op, r, p.get_opt_shapes(), tuning_config ) ) - - # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics - # or some runtime error occurs during profiling. 
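The selection loop here amounts to: time every (runner, tactic) candidate, record failures as infinite cost, and keep the minimum. A pure-Python sketch under that assumption, with wall-clock timing standing in for the CUDA-event timing the real code uses:

import time

def pick_best(candidates):
    # candidates: iterable of (name, zero-arg callable); returns the fastest one.
    best_name, best_time = None, float("inf")
    for name, fn in candidates:
        try:
            t0 = time.perf_counter()
            fn()
            elapsed = time.perf_counter() - t0
        except Exception:
            elapsed = float("inf")  # a failing tactic is recorded but never selected
        if elapsed < best_time:
            best_name, best_time = name, elapsed
    return best_name, best_time

print(pick_best([("ok", lambda: sum(range(10_000))), ("broken", lambda: 1 / 0)]))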
time_measured = float("inf") if time_measured < min_time: min_time = time_measured runner_id, tactic = r_id, tac if runner_id is not None: - # At least one valid (runner, tactic) pair is found cache_key = AutoTuner._get_cache_key( custom_op, runners[runner_id], p.get_opt_shapes(), tuning_config ) - # inspect call stack - self.profiling_cache[cache_key] = (runner_id, tactic, p) + self.profiling_cache[cache_key] = runner_id, tactic, p self.stats.tuned_op_successful_configs[custom_op] = ( self.stats.tuned_op_successful_configs.get(custom_op, 0) + 1 ) logger.debug( f"[Autotuner]: profiling chosen runner: {runners[runner_id]} {tactic} for {cache_key}" ) - - # Get the best runner and tactic from cache - # If no valid tactic is found, the fallback runner and tactic will be used _, runner_id, tactic, _ = self.search_cache( custom_op, runners, input_shapes, tuning_config ) - return runners[runner_id], tactic - def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]: - # Handle None tensors for optional inputs and non-Tensor scalar values + def _get_input_sizes(self, inputs: List[paddle.Tensor]) -> List[list]: sizes = [ - input.size() if isinstance(input, torch.Tensor) else torch.Size((0,)) + (tuple(input.shape) if isinstance(input, paddle.Tensor) else tuple((0,))) for input in inputs ] - return sizes def _profile_single_kernel( - self, runner: TunableRunner, inputs: List[torch.Tensor], tactic: int, **kwargs + self, runner: TunableRunner, inputs: List[paddle.Tensor], tactic: int, **kwargs ) -> float: """Profile a single kernel implementation for performance measurement. @@ -551,37 +509,28 @@ def _profile_single_kernel( to get an average execution time. Stream synchronization and delays are used to ensure accurate timing. """ - stream = torch.cuda.current_stream() - # warm up, no timing + stream = paddle.device.current_stream() for _ in range(self.warmup): runner(inputs, tactic=tactic, **kwargs) stream.synchronize() - - # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling. - # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops) - # Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity. if self.stream_delay_micro_secs > 0: delay_kernel(self.stream_delay_micro_secs) - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - + start = paddle.device.cuda.Event(enable_timing=True) + end = paddle.device.cuda.Event(enable_timing=True) start.record(stream=stream) for _ in range(self.repeat): runner(inputs, tactic=tactic, **kwargs) end.record(stream=stream) stream.synchronize() - avg_time = start.elapsed_time(end) / self.repeat - shapes = self._get_input_sizes(inputs) logger.debug( f"[Autotuner]: profiling {runner} {tactic}, shapes={shapes}, avg_time {avg_time}" ) - return avg_time def _generate_optimization_profiles( - self, tuning_config: TuningConfig, inputs: List[torch.Tensor] + self, tuning_config: TuningConfig, inputs: List[paddle.Tensor] ) -> List[OptimizationProfile]: """Generate optimization profiles for autotuning. @@ -596,41 +545,31 @@ def _generate_optimization_profiles( This method performs a cartesian product of all possible dimension combinations specified in dynamic_tensor_specs. """ - # every dimension created from the concrete input tensor shape - # generate some dynamic dimension description based on the dynamic_tensors - - # Zero handles the case where a TRTLLM op has optional or scalar inputs. 
base_profile = OptimizationProfile( [ ( - [StaticDim(x) for x in t.size()] - if isinstance(t, torch.Tensor) + [StaticDim(x) for x in tuple(t.shape)] + if isinstance(t, paddle.Tensor) else [StaticDim(0)] ) for t in inputs ], [None] * len(inputs), ) - generated_profiles: List[OptimizationProfile] = [] - dynamic_dims: List[Tuple[Any, ...]] = [] - for spec in tuning_config.dynamic_tensor_specs: assert inspect.isfunction(spec.gen_tuning_buckets) or isinstance( spec.gen_tuning_buckets, (list, tuple) - ), ( - "The given dynamic dimension must provide a opt value generation function or a list of opt values" - ) - assert len(spec.input_idx) == len(spec.dim_idx), ( - f"The number of input indices and dimension indices must be the same, got {len(spec.input_idx)} and {len(spec.dim_idx)}" - ) - assert len(spec.tensor_initializers) == len(spec.input_idx), ( - f"The number of tensor initializers and input indices must be the same, got {len(spec.tensor_initializers)} and {len(spec.input_idx)}" - ) + ), "The given dynamic dimension must provide a opt value generation function or a list of opt values" + assert len(spec.input_idx) == len( + spec.dim_idx + ), f"The number of input indices and dimension indices must be the same, got {len(spec.input_idx)} and {len(spec.dim_idx)}" + assert len(spec.tensor_initializers) == len( + spec.input_idx + ), f"The number of tensor initializers and input indices must be the same, got {len(spec.tensor_initializers)} and {len(spec.input_idx)}" for i, idx in enumerate(spec.input_idx): base_profile.tensor_initializers[idx] = spec.tensor_initializers[i] - if inspect.isfunction(spec.gen_tuning_buckets): opt_shapes = spec.gen_tuning_buckets( base_profile.shapes[spec.input_idx[0]][spec.dim_idx[0]]._opt() @@ -644,8 +583,6 @@ def _generate_optimization_profiles( dynamic_dims.append( (spec.input_idx, spec.dim_idx, opt_shapes_max, opt_shapes) ) - - # grid search, do cartesian product for all the dynamic axis dim_grids = itertools.product(*[d[-1] for d in dynamic_dims]) for opt_point in dim_grids: p = copy.deepcopy(base_profile) @@ -653,22 +590,19 @@ def _generate_optimization_profiles( dynamic_dims ): opt_value = opt_point[pos] - # TODO: fix me, how to set the min and max? min_value = opt_value max_value = opt_shapes_max[opt_value] for i in range(len(input_idx)): p.shapes[input_idx[i]][dim_idx[i]] = DynamicDim( min_value, opt_value, max_value ) - - # Adjust the profile to satisfy the constraints for constraint_spec in tuning_config.constraint_specs: min_value = opt_value = max_value = constraint_spec.infer_shape( p.get_opt_shapes() ) - p.shapes[constraint_spec.input_idx][constraint_spec.dim_idx] = ( - DynamicDim(min_value, opt_value, max_value) - ) + p.shapes[constraint_spec.input_idx][ + constraint_spec.dim_idx + ] = DynamicDim(min_value, opt_value, max_value) generated_profiles.append(p) logger.debug(f"[Autotuner]: generated profile: {p}") return generated_profiles @@ -676,7 +610,7 @@ def _generate_optimization_profiles( @classmethod @lru_cache(maxsize=None) def _find_nearest_profile( - cls, shapes: Tuple[torch.Size], tuning_config: TuningConfig + cls, shapes: Tuple[list], tuning_config: TuningConfig ) -> Tuple: """Find the nearest optimization profile for given inputs User can define their own nearest profile generation method to reduce the host overhead. 
@@ -691,15 +625,12 @@ def _find_nearest_profile( - profile: Tuple of input tensor shapes """ base_profile = list(list(shape) for shape in shapes) - for spec in tuning_config.dynamic_tensor_specs: - base_profile[spec.input_idx[0]][spec.dim_idx[0]] = ( - spec.map_to_tuning_buckets( - base_profile[spec.input_idx[0]][spec.dim_idx[0]] - ) + base_profile[spec.input_idx[0]][ + spec.dim_idx[0] + ] = spec.map_to_tuning_buckets( + base_profile[spec.input_idx[0]][spec.dim_idx[0]] ) - - # associated dimensions dependent on other free dynamic dimensions, so assign -1 in the profile for constraint_spec in tuning_config.constraint_specs: base_profile[constraint_spec.input_idx][constraint_spec.dim_idx] = -1 return tuple(tuple(shape) for shape in base_profile) @@ -709,7 +640,7 @@ def _get_cache_key( cls, custom_op: str, runner: TunableRunner, - input_shapes: Tuple[torch.Size], + input_shapes: Tuple[list], tuning_config: TuningConfig, ) -> Tuple: return ( @@ -720,8 +651,8 @@ def _get_cache_key( ) def _create_tensor_like( - self, origin_tensor: torch.Tensor, dims: List[Dim], initializer: Callable - ) -> torch.Tensor: + self, origin_tensor: paddle.Tensor, dims: List[Dim], initializer: Callable + ) -> paddle.Tensor: """Create a new tensor matching the properties of the original tensor. Args: @@ -736,30 +667,27 @@ def _create_tensor_like( but with dimensions specified by the dims parameter. """ dtype = origin_tensor.dtype - device = origin_tensor.device + device = origin_tensor.place shapes = [] for d in dims: if isinstance(d, StaticDim): shapes.append(d.val) else: - # TODO: how to make sure the created Tensor has the min/max info assert isinstance(d, DynamicDim) shapes.append(d.opt) return initializer(shapes, dtype, device) def _prepare_input_tensors( - self, profile: OptimizationProfile, inputs: List[torch.Tensor] - ) -> List[torch.Tensor]: - default_initializer = lambda shapes, dtype, device: torch.rand( - shapes, device=device + self, profile: OptimizationProfile, inputs: List[paddle.Tensor] + ) -> List[paddle.Tensor]: + default_initializer = lambda shapes, dtype, device: paddle.rand( + shape=shapes ).to(dtype) tensors = [] for i, p in enumerate(profile.shapes): if any(isinstance(d, DynamicDim) for d in p): tensor = self._create_tensor_like( - inputs[i], - p, - profile.tensor_initializers[i] or default_initializer, + inputs[i], p, profile.tensor_initializers[i] or default_initializer ) else: tensor = inputs[i] diff --git a/flashinfer/cascade.py b/flashinfer/cascade.py index 6d3246cd0e..639ac0f312 100644 --- a/flashinfer/cascade.py +++ b/flashinfer/cascade.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2023 by FlashInfer team. @@ -13,17 +15,15 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import functools from typing import List, Optional, Tuple, Union -import torch - from .decode import BatchDecodeWithPagedKVCacheWrapper from .jit import JitSpec from .jit import env as jit_env from .jit import gen_jit_spec -from .prefill import BatchPrefillWithPagedKVCacheWrapper, single_prefill_with_kv_cache +from .prefill import (BatchPrefillWithPagedKVCacheWrapper, + single_prefill_with_kv_cache) from .utils import register_custom_op, register_fake_op @@ -44,9 +44,9 @@ def get_cascade_module(): @register_custom_op("flashinfer::merge_state", mutates_args=()) def merge_state( - v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Merge the attention output ``V`` and the logsumexp value ``S`` from the two + v_a: paddle.Tensor, s_a: paddle.Tensor, v_b: paddle.Tensor, s_b: paddle.Tensor +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Merge the attention output ``V`` and the logsumexp value ``S`` from the two KV-segments. Check :ref:`our tutorial ` on the mathematical details. @@ -91,32 +91,32 @@ def merge_state( >>> s_merged.shape torch.Size([2048, 32]) """ - s_a = s_a.to(torch.float32) - s_b = s_b.to(torch.float32) - v_merged = torch.empty_like(v_a) - s_merged = torch.empty_like(s_a) + s_a = s_a.to("float32") + s_b = s_b.to("float32") + v_merged = paddle.empty_like(x=v_a) + s_merged = paddle.empty_like(x=s_a) get_cascade_module().merge_state(v_a, s_a, v_b, s_b, v_merged, s_merged) return v_merged, s_merged @register_fake_op("flashinfer::merge_state") def _fake_merge_state( - v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - v = torch.empty_like(v_a) - s = torch.empty_like(s_a) + v_a: paddle.Tensor, s_a: paddle.Tensor, v_b: paddle.Tensor, s_b: paddle.Tensor +) -> Tuple[paddle.Tensor, paddle.Tensor]: + v = paddle.empty_like(x=v_a) + s = paddle.empty_like(x=s_a) return v, s @register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s")) def merge_state_in_place( - v: torch.Tensor, - s: torch.Tensor, - v_other: torch.Tensor, - s_other: torch.Tensor, - mask: Optional[torch.Tensor] = None, + v: paddle.Tensor, + s: paddle.Tensor, + v_other: paddle.Tensor, + s_other: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, ) -> None: - r"""Merge the self-attention state ``(v, s)`` with another state + """Merge the self-attention state ``(v, s)`` with another state ``(v_other, s_other)`` in-place. Parameters @@ -152,25 +152,27 @@ def merge_state_in_place( >>> s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") >>> flashinfer.merge_state_in_place(v, s, v_other, s_other) """ - s = s.to(torch.float32) - s_other = s_other.to(torch.float32) + s = s.to("float32") + s_other = s_other.to("float32") get_cascade_module().merge_state_in_place(v, s, v_other, s_other, mask) @register_fake_op("flashinfer::merge_state_in_place") def _fake_merge_state_in_place( - v: torch.Tensor, - s: torch.Tensor, - v_other: torch.Tensor, - s_other: torch.Tensor, - mask: Optional[torch.Tensor] = None, + v: paddle.Tensor, + s: paddle.Tensor, + v_other: paddle.Tensor, + s_other: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, ) -> None: pass @register_custom_op("flashinfer::merge_states", mutates_args=()) -def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Merge multiple attention states (v, s). 
+def merge_states( + v: paddle.Tensor, s: paddle.Tensor +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Merge multiple attention states (v, s). Parameters ---------- @@ -206,27 +208,27 @@ def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch. >>> s_merged.shape torch.Size([2048, 32]) """ - device = v.device - s = s.to(torch.float32) - seq_len, _, num_heads, head_dim = v.size() - v_merged = torch.empty(seq_len, num_heads, head_dim, dtype=v.dtype, device=device) - s_merged = torch.empty(seq_len, num_heads, dtype=torch.float32, device=device) + device = v.place + s = s.to("float32") + seq_len, _, num_heads, head_dim = tuple(v.shape) + v_merged = paddle.empty(shape=[seq_len, num_heads, head_dim], dtype=v.dtype) + s_merged = paddle.empty(shape=[seq_len, num_heads], dtype="float32") get_cascade_module().merge_states(v, s, v_merged, s_merged) return v_merged, s_merged @register_fake_op("flashinfer::merge_states") def _fake_merge_states( - v: torch.Tensor, s: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - seq_len, _, num_heads, head_dim = v.size() - v_merged = torch.empty(seq_len, num_heads, head_dim, dtype=v.dtype) - s_merged = torch.empty(seq_len, num_heads, dtype=torch.float32) + v: paddle.Tensor, s: paddle.Tensor +) -> Tuple[paddle.Tensor, paddle.Tensor]: + seq_len, _, num_heads, head_dim = tuple(v.shape) + v_merged = paddle.empty(shape=[seq_len, num_heads, head_dim], dtype=v.dtype) + s_merged = paddle.empty(shape=[seq_len, num_heads], dtype="float32") return v_merged, s_merged class MultiLevelCascadeAttentionWrapper: - r"""Attention wrapper for memory efficient multi-level cascade inference, this API assumes all + """Attention wrapper for memory efficient multi-level cascade inference, this API assumes all levels KV-Cache are stored in a unified paged table. Please check :ref:`cascade-inference-data-layout` for data layout in cascade inference. @@ -302,15 +304,15 @@ class MultiLevelCascadeAttentionWrapper: def __init__( self, num_levels, - float_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, - qo_indptr_buf_arr: Optional[List[torch.Tensor]] = None, - paged_kv_indptr_buf_arr: Optional[List[torch.Tensor]] = None, - paged_kv_indices_buf_arr: Optional[List[torch.Tensor]] = None, - paged_kv_last_page_len_buf_arr: Optional[List[torch.Tensor]] = None, + qo_indptr_buf_arr: Optional[List[paddle.Tensor]] = None, + paged_kv_indptr_buf_arr: Optional[List[paddle.Tensor]] = None, + paged_kv_indices_buf_arr: Optional[List[paddle.Tensor]] = None, + paged_kv_last_page_len_buf_arr: Optional[List[paddle.Tensor]] = None, ) -> None: - r"""Constructor of :class:`MultiLevelCascadeAttentionWrapper`. + """Constructor of :class:`MultiLevelCascadeAttentionWrapper`. Parameters ---------- @@ -351,12 +353,7 @@ def __init__( paged_kv_indices_buf=paged_kv_indices_buf, paged_kv_last_page_len_buf=paged_kv_last_page_len_buf, ) - for ( - qo_indptr_buf, - paged_kv_indptr_buf, - paged_kv_indices_buf, - paged_kv_last_page_len_buf, - ) in zip( + for qo_indptr_buf, paged_kv_indptr_buf, paged_kv_indices_buf, paged_kv_last_page_len_buf in zip( qo_indptr_buf_arr, paged_kv_indptr_buf_arr, paged_kv_indices_buf_arr, @@ -377,10 +374,10 @@ def is_cuda_graph_enabled(self) -> bool: def reset_workspace_buffer( self, - float_workspace_buffer: torch.Tensor, - int_workspace_buffers: List[torch.Tensor], + float_workspace_buffer: paddle.Tensor, + int_workspace_buffers: List[paddle.Tensor], ) -> None: - r"""Reset the workspace buffer. 
+ """Reset the workspace buffer. Parameters ---------- @@ -399,10 +396,10 @@ def reset_workspace_buffer( def plan( self, - qo_indptr_arr: List[torch.Tensor], - paged_kv_indptr_arr: List[torch.Tensor], - paged_kv_indices_arr: List[torch.Tensor], - paged_kv_last_page_len: List[torch.Tensor], + qo_indptr_arr: List[paddle.Tensor], + paged_kv_indptr_arr: List[paddle.Tensor], + paged_kv_indices_arr: List[paddle.Tensor], + paged_kv_last_page_len: List[paddle.Tensor], num_qo_heads: int, num_kv_heads: int, head_dim: int, @@ -416,9 +413,9 @@ def plan( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, q_data_type: str = "float16", - kv_data_type: Optional[Union[str, torch.dtype]] = None, + kv_data_type: Optional[Union[str, paddle.dtype]] = None, ): - r"""Create auxiliary data structures for multi-level cascade attention for multiple + """Create auxiliary data structures for multi-level cascade attention for multiple forward calls within the same decode step. Please check :ref:`cascade-inference-data-layout` for data layout in cascade inference. @@ -463,7 +460,7 @@ def plan( The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, + :math:`\\texttt{logits_soft_cap} \\times \\mathrm{tanh}(x / \\texttt{logits_soft_cap})`, where :math:`x` is the input logits. sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to @@ -516,12 +513,8 @@ def plan( begin_forward = plan - def run( - self, - q: torch.Tensor, - paged_kv_cache: torch.Tensor, - ): - r"""Compute multi-level cascade attention. + def run(self, q: paddle.Tensor, paged_kv_cache: paddle.Tensor): + """Compute multi-level cascade attention. Parameters ---------- @@ -542,21 +535,18 @@ def run( ``paged_kv_cache[:, 1]`` is the value-cache. """ out, lse = self._batch_prefill_wrappers[-1].run( - q, - paged_kv_cache, - return_lse=True, + q, paged_kv_cache, return_lse=True ) for wrapper in self._batch_prefill_wrappers[:-1]: out_i, lse_i = wrapper.run(q, paged_kv_cache, return_lse=True) merge_state_in_place(out, lse, out_i, lse_i) - return out forward = run class BatchDecodeWithSharedPrefixPagedKVCacheWrapper: - r"""Wrapper class for decode attention with shared-prefix paged kv-cache for batch + """Wrapper class for decode attention with shared-prefix paged kv-cache for batch of requests. The shared-prefix KV-Cache was stored in a standalone tensors, and the unique KV-Cache of each request was stored in a paged KV-Cache data structure. @@ -640,7 +630,7 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper: """ def __init__( - self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD" + self, float_workspace_buffer: paddle.Tensor, kv_layout: str = "NHD" ) -> None: self._batch_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( float_workspace_buffer, kv_layout @@ -648,9 +638,9 @@ def __init__( self._kv_layout = kv_layout def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + self, float_workspace_buffer: paddle.Tensor, int_workspace_buffer ) -> None: - r"""Reset the workspace buffer. + """Reset the workspace buffer. 
Parameters ---------- @@ -668,16 +658,16 @@ def reset_workspace_buffer( def begin_forward( self, - unique_kv_indptr: torch.Tensor, - unique_kv_indices: torch.Tensor, - unique_kv_last_page_len: torch.Tensor, + unique_kv_indptr: paddle.Tensor, + unique_kv_indices: paddle.Tensor, + unique_kv_last_page_len: paddle.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, data_type: str = "float16", ) -> None: - r"""Plan shared-prefix batch decode attention for given problem specification. + """Plan shared-prefix batch decode attention for given problem specification. Parameters ---------- @@ -729,12 +719,12 @@ def begin_forward( def forward( self, - q: torch.Tensor, - k_shared: torch.Tensor, - v_shared: torch.Tensor, - unique_kv_cache: torch.Tensor, - ) -> torch.Tensor: - r"""Compute batch decode attention between queries and shared-prefix paged + q: paddle.Tensor, + k_shared: paddle.Tensor, + v_shared: paddle.Tensor, + unique_kv_cache: paddle.Tensor, + ) -> paddle.Tensor: + """Compute batch decode attention between queries and shared-prefix paged kv-cache. Parameters @@ -783,20 +773,18 @@ def forward( return_lse=True, ) V_unique, S_unique = self._batch_decode_wrapper.forward_return_lse( - q, - unique_kv_cache, - pos_encoding_mode="NONE", + q, unique_kv_cache, pos_encoding_mode="NONE" ) merge_state_in_place(V_shared, S_shared, V_unique, S_unique) return V_shared def end_forward(self) -> None: - r"""Warning: this function is deprecated and has no effect""" + """Warning: this function is deprecated and has no effect""" pass class BatchPrefillWithSharedPrefixPagedKVCacheWrapper: - r"""Wrapper class for prefill/append attention with shared-prefix paged kv-cache for + """Wrapper class for prefill/append attention with shared-prefix paged kv-cache for batch of requests. Check :ref:`our tutorial` for paged kv-cache layout. @@ -887,9 +875,9 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper: """ def __init__( - self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD" + self, float_workspace_buffer: paddle.Tensor, kv_layout: str = "NHD" ) -> None: - r"""Constructor of :class:`BatchDecodeWithSharedPrefixPagedKVCacheWrapper`. + """Constructor of :class:`BatchDecodeWithSharedPrefixPagedKVCacheWrapper`. Parameters ---------- @@ -906,9 +894,9 @@ def __init__( self._kv_layout = kv_layout def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor + self, float_workspace_buffer: paddle.Tensor, int_workspace_buffer: paddle.Tensor ) -> None: - r"""Reset the workspace buffer. + """Reset the workspace buffer. Parameters ---------- @@ -926,16 +914,16 @@ def reset_workspace_buffer( def begin_forward( self, - qo_indptr: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, + qo_indptr: paddle.Tensor, + paged_kv_indptr: paddle.Tensor, + paged_kv_indices: paddle.Tensor, + paged_kv_last_page_len: paddle.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, ) -> None: - r"""Create auxiliary data structures for shared-prefix batch prefill/append + """Create auxiliary data structures for shared-prefix batch prefill/append attention for multiple forward calls within the same prefill/append step. 
Parameters @@ -981,17 +969,17 @@ def begin_forward( def forward( self, - q: torch.Tensor, - k_shared: torch.Tensor, - v_shared: torch.Tensor, - unique_kv_cache: torch.Tensor, + q: paddle.Tensor, + k_shared: paddle.Tensor, + v_shared: paddle.Tensor, + unique_kv_cache: paddle.Tensor, causal: bool = False, use_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - ) -> torch.Tensor: - r"""Compute batch prefill/append attention between query and shared-prefix paged + ) -> paddle.Tensor: + """Compute batch prefill/append attention between query and shared-prefix paged kv-cache. Parameters @@ -1071,5 +1059,5 @@ def forward( return V_shared def end_forward(self) -> None: - r"""Warning: this function is deprecated and has no effect""" + """Warning: this function is deprecated and has no effect""" pass diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index f7ae3754ac..8abbec6a83 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -6,37 +6,37 @@ from .trtllm_ar import AllReduceStrategyConfig as AllReduceStrategyConfig from .trtllm_ar import AllReduceStrategyType as AllReduceStrategyType from .trtllm_ar import QuantizationSFLayout as QuantizationSFLayout -from .trtllm_ar import ( - compute_fp4_swizzled_layout_sf_size as compute_fp4_swizzled_layout_sf_size, -) +from .trtllm_ar import \ + compute_fp4_swizzled_layout_sf_size as compute_fp4_swizzled_layout_sf_size from .trtllm_ar import gen_trtllm_comm_module as gen_trtllm_comm_module from .trtllm_ar import trtllm_allreduce_fusion as trtllm_allreduce_fusion -from .trtllm_ar import ( - trtllm_create_ipc_workspace_for_all_reduce as trtllm_create_ipc_workspace_for_all_reduce, -) -from .trtllm_ar import ( - trtllm_create_ipc_workspace_for_all_reduce_fusion as trtllm_create_ipc_workspace_for_all_reduce_fusion, -) +from .trtllm_ar import \ + trtllm_create_ipc_workspace_for_all_reduce as \ + trtllm_create_ipc_workspace_for_all_reduce +from .trtllm_ar import \ + trtllm_create_ipc_workspace_for_all_reduce_fusion as \ + trtllm_create_ipc_workspace_for_all_reduce_fusion from .trtllm_ar import trtllm_custom_all_reduce as trtllm_custom_all_reduce -from .trtllm_ar import ( - trtllm_destroy_ipc_workspace_for_all_reduce as trtllm_destroy_ipc_workspace_for_all_reduce, -) -from .trtllm_ar import ( - trtllm_destroy_ipc_workspace_for_all_reduce_fusion as trtllm_destroy_ipc_workspace_for_all_reduce_fusion, -) +from .trtllm_ar import \ + trtllm_destroy_ipc_workspace_for_all_reduce as \ + trtllm_destroy_ipc_workspace_for_all_reduce +from .trtllm_ar import \ + trtllm_destroy_ipc_workspace_for_all_reduce_fusion as \ + trtllm_destroy_ipc_workspace_for_all_reduce_fusion from .trtllm_ar import trtllm_lamport_initialize as trtllm_lamport_initialize -from .trtllm_ar import trtllm_lamport_initialize_all as trtllm_lamport_initialize_all -from .trtllm_ar import trtllm_moe_allreduce_fusion as trtllm_moe_allreduce_fusion -from .trtllm_ar import ( - trtllm_moe_finalize_allreduce_fusion as trtllm_moe_finalize_allreduce_fusion, -) +from .trtllm_ar import \ + trtllm_lamport_initialize_all as trtllm_lamport_initialize_all +from .trtllm_ar import \ + trtllm_moe_allreduce_fusion as trtllm_moe_allreduce_fusion +from .trtllm_ar import \ + trtllm_moe_finalize_allreduce_fusion as \ + trtllm_moe_finalize_allreduce_fusion from .vllm_ar import all_reduce as vllm_all_reduce from .vllm_ar import dispose as vllm_dispose from .vllm_ar import gen_vllm_comm_module as 
gen_vllm_comm_module -from .vllm_ar import get_graph_buffer_ipc_meta as vllm_get_graph_buffer_ipc_meta +from .vllm_ar import \ + get_graph_buffer_ipc_meta as vllm_get_graph_buffer_ipc_meta from .vllm_ar import init_custom_ar as vllm_init_custom_ar from .vllm_ar import meta_size as vllm_meta_size from .vllm_ar import register_buffer as vllm_register_buffer from .vllm_ar import register_graph_buffers as vllm_register_graph_buffers - -# from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo diff --git a/flashinfer/comm/cuda_ipc.py b/flashinfer/comm/cuda_ipc.py index e85c9f26e8..ca20601c5a 100644 --- a/flashinfer/comm/cuda_ipc.py +++ b/flashinfer/comm/cuda_ipc.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,18 +15,10 @@ See the License for the specific language governing permissions and limitations under the License. """ - import ctypes from dataclasses import dataclass from typing import Any, Dict, List, Optional -import torch.distributed as dist -from torch.distributed import ProcessGroup - -# NOTE(Zihao): we should use cuda-python instead of ctypes cuda runtime bindings. -# However, cuda-python's API is not stable yet, so we use ctypes bindings instead. -# which is copied from vllm codebase. - cudaError_t = ctypes.c_int cudaMemcpyKind = ctypes.c_int @@ -46,7 +40,7 @@ def find_loaded_library(lib_name) -> Optional[str]: the file `/proc/self/maps` contains the memory maps of the process, which includes the shared libraries loaded by the process. We can use this file to find the path of the a loaded library. - """ # noqa + """ found = False with open("/proc/self/maps") as f: for line in f: @@ -54,16 +48,13 @@ def find_loaded_library(lib_name) -> Optional[str]: found = True break if not found: - # the library is not loaded in the current process return None - # if lib_name is libcudart, we need to match a line with: - # address /path/to/libcudart-hash.so.11.0 start = line.index("/") path = line[start:].strip() filename = path.split("/")[-1] - assert filename.rpartition(".so")[0].startswith(lib_name), ( - f"Unexpected filename: {filename} for library {lib_name}" - ) + assert filename.rpartition(".so")[0].startswith( + lib_name + ), f"Unexpected filename: {filename} for library {lib_name}" return path @@ -71,52 +62,36 @@ class CudaRTLibrary: """CudaRTLibrary""" exported_functions = [ - # ​cudaError_t cudaSetDevice ( int device ) Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), - # cudaError_t cudaDeviceSynchronize ( void ) Function("cudaDeviceSynchronize", cudaError_t, []), - # ​cudaError_t cudaDeviceReset ( void ) Function("cudaDeviceReset", cudaError_t, []), - # const char* cudaGetErrorString ( cudaError_t error ) Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), - # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) Function( "cudaMalloc", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t], ), - # ​cudaError_t cudaFree ( void* devPtr ) Function("cudaFree", cudaError_t, [ctypes.c_void_p]), - # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) Function( "cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t] ), - # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa Function( "cudaMemcpy", cudaError_t, [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind], ), - # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa Function( "cudaIpcGetMemHandle", cudaError_t, 
            [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
        ),
-        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
        Function(
            "cudaIpcOpenMemHandle",
            cudaError_t,
            [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],
        ),
    ]
-
-    # class attribute to store the mapping from the path to the library
-    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}
-
-    # class attribute to store the mapping from library path
-    # to the corresponding dictionary
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
    def __init__(self, so_file: Optional[str] = None):
@@ -127,7 +102,6 @@ def __init__(self, so_file: Optional[str] = None):
            lib = ctypes.CDLL(so_file)
            CudaRTLibrary.path_to_library_cache[so_file] = lib
        self.lib = CudaRTLibrary.path_to_library_cache[so_file]
-
        if so_file not in CudaRTLibrary.path_to_dict_mapping:
            _funcs = {}
            for func in CudaRTLibrary.exported_functions:
@@ -195,7 +169,7 @@ def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
 def create_shared_buffer(
-    size_in_bytes: int, group: Optional[ProcessGroup] = None
+    size_in_bytes: int, group: Optional[Any] = None  # paddle.distributed group; None = global group
 ) -> List[int]:
     """
     Creates a shared buffer and returns a list of pointers
@@ -204,34 +178,34 @@ def create_shared_buffer(
     pointer = cudart.cudaMalloc(size_in_bytes)
     handle = cudart.cudaIpcGetMemHandle(pointer)
     if group is None:
-        group = dist.group.WORLD
-    world_size = dist.get_world_size(group=group)
-    rank = dist.get_rank(group=group)
+        pass  # paddle.distributed collectives treat group=None as the global (WORLD) group
+    world_size = paddle.distributed.get_world_size(group=group)
+    rank = paddle.distributed.get_rank(group=group)
     handles = [None] * world_size
-    dist.all_gather_object(handles, handle, group=group)
+    handles = []
+    paddle.distributed.all_gather_object(object_list=handles, obj=handle, group=group)
     handles = [None] * world_size
-    dist.all_gather_object(handles, handle, group=group)
-
+    handles = []
+    paddle.distributed.all_gather_object(object_list=handles, obj=handle, group=group)
     pointers: List[int] = []
     for i, h in enumerate(handles):
         if i == rank:
             pointers.append(pointer.value)
         else:
             pointers.append(cudart.cudaIpcOpenMemHandle(h).value)
-
-    dist.barrier(group=group)
+    paddle.distributed.barrier(group=group)
     return pointers
 def free_shared_buffer(
-    pointers: List[int], group: Optional[ProcessGroup] = None
+    pointers: List[int], group: Optional[Any] = None  # paddle.distributed group; None = global group
 ) -> None:
     """
     Frees a shared buffer.
     """
     if group is None:
-        group = dist.group.WORLD
-    rank = dist.get_rank(group=group)
+        pass  # paddle.distributed collectives treat group=None as the global (WORLD) group
+    rank = paddle.distributed.get_rank(group=group)
     if pointers and len(pointers) > rank and pointers[rank] is not None:
         cudart.cudaFree(ctypes.c_void_p(pointers[rank]))
-    dist.barrier(group=group)
+    paddle.distributed.barrier(group=group)
diff --git a/flashinfer/comm/dlpack_utils.py b/flashinfer/comm/dlpack_utils.py
index f04fa748e0..46bbc731f3 100644
--- a/flashinfer/comm/dlpack_utils.py
+++ b/flashinfer/comm/dlpack_utils.py
@@ -1,71 +1,33 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
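The core of create_shared_buffer above is an all-gather of the per-rank cudaIpcGetMemHandle bytes. A minimal sketch of just that exchange, assuming paddle.distributed.init_parallel_env() has been called and the script runs one process per GPU:

import paddle.distributed as dist

def exchange_ipc_handles(local_handle):
    # Every rank contributes its IPC handle bytes and receives all ranks'
    # handles in rank order; handles[i] for i != rank can then be opened
    # with cudaIpcOpenMemHandle.
    handles = []
    dist.all_gather_object(object_list=handles, obj=local_handle)
    return handles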
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import ctypes -from ctypes import ( - CFUNCTYPE, - POINTER, - c_int, - c_int64, - c_size_t, - c_uint8, - c_uint16, - c_void_p, - pointer, -) - -import torch - - -# Define data structures required for DLPack +from ctypes import (CFUNCTYPE, POINTER, c_int, c_int64, c_size_t, c_uint8, + c_uint16, c_void_p, pointer) + +import paddle + + class DLDataType(ctypes.Structure): - _fields_ = [ - ("code", c_uint8), # Data type code, e.g., 2 for float - ("bits", c_uint8), # Number of bits per element, e.g., 32 - ("lanes", c_uint16), # Number of lanes, usually 1 - ] + _fields_ = [("code", c_uint8), ("bits", c_uint8), ("lanes", c_uint16)] class DLDevice(ctypes.Structure): - _fields_ = [ - ("device_type", c_int), # Device type, typically 2 for GPU - ("device_id", c_int), # Device ID, usually 0 for default GPU - ] + _fields_ = [("device_type", c_int), ("device_id", c_int)] class DLTensor(ctypes.Structure): _fields_ = [ - ("data", c_void_p), # Data pointer - ("device", DLDevice), # Device information - ("ndim", c_int), # Number of dimensions - ("dtype", DLDataType), # Data type - ("shape", POINTER(c_int64)), # Pointer to array of dimension sizes - ( - "strides", - POINTER(c_int64), - ), # Pointer to strides array (can be NULL for default contiguous layout) - ("byte_offset", c_size_t), # Byte offset (usually 0) + ("data", c_void_p), + ("device", DLDevice), + ("ndim", c_int), + ("dtype", DLDataType), + ("shape", POINTER(c_int64)), + ("strides", POINTER(c_int64)), + ("byte_offset", c_size_t), ] -# Deleter type for DLManagedTensor -DLManagedTensorDeleter = CFUNCTYPE( - None, POINTER(ctypes.c_void_p) -) # Not used directly here +DLManagedTensorDeleter = CFUNCTYPE(None, POINTER(ctypes.c_void_p)) -# Define DLManagedTensor structure, with deleter prototype void(*deleter)(DLManagedTensor*) class DLManagedTensor(ctypes.Structure): pass @@ -77,14 +39,11 @@ class DLManagedTensor(ctypes.Structure): ] -# A no-op deleter that doesn't perform any operation @CFUNCTYPE(None, POINTER(DLManagedTensor)) def no_op_deleter(dmt_ptr): - # You can also call cudaFree here if you want to free memory when the tensor's lifecycle ends pass -# Wrapper class to prevent Python garbage collection of DLPack-related objects class CapsuleWrapper: """ A wrapper class that holds references to the PyCapsule and its associated data. 
@@ -103,13 +62,9 @@ def __init__(self, capsule, shape_array, managed_tensor):
            shape_array: The array containing tensor shape information
            managed_tensor: The DLManagedTensor instance that the capsule points to
        """
-        self.capsule = (
-            capsule  # The main PyCapsule object that can be passed to other libraries
-        )
-        self._shape_array = shape_array  # Keep reference to prevent garbage collection
-        self._managed_tensor = (
-            managed_tensor  # Keep reference to prevent garbage collection
-        )
+        self.capsule = capsule
+        self._shape_array = shape_array
+        self._managed_tensor = managed_tensor
 def create_dlpack_capsule(
@@ -129,61 +84,48 @@ def create_dlpack_capsule(
    """
    bits_per_elements = 0
    dldata_type_code = 0
-    # refer to https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h#L160
+    # NOTE: the float8 / unsigned dtype names below assume a Paddle build that exposes them
    if torch_dtype in [
-        torch.float8_e5m2,
-        torch.float8_e4m3fn,
-        torch.bfloat16,
-        torch.float16,
-        torch.float32,
-        torch.float64,
+        paddle.float8_e5m2,
+        paddle.float8_e4m3fn,
+        "bfloat16",
+        "float16",
+        "float32",
+        "float64",
    ]:
-        bits_per_elements = torch.finfo(torch_dtype).bits
+        bits_per_elements = paddle.finfo(dtype=torch_dtype).bits
        dldata_type_code = 2
-    elif torch_dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
-        bits_per_elements = torch.iinfo(torch_dtype).bits
+    elif torch_dtype in ["int8", "int16", "int32", "int64"]:
+        bits_per_elements = paddle.iinfo(dtype=torch_dtype).bits
        dldata_type_code = 0
-    elif torch_dtype in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]:
-        bits_per_elements = torch.iinfo(torch_dtype).bits
+    elif torch_dtype in ["uint8", "uint16", "uint32", "uint64"]:
+        bits_per_elements = paddle.iinfo(dtype=torch_dtype).bits
        dldata_type_code = 1
    else:
        raise NotImplementedError(torch_dtype)
    bytes_per_element = bits_per_elements // 8
-    # Allocate space for shape (constructing a one-dimensional tensor here)
-    ShapeArrayType = c_int64 * 2  # 1 dimension
+    ShapeArrayType = c_int64 * 2
    shape_array = ShapeArrayType(num_segments, segment_size // bytes_per_element)
    stride_array = ShapeArrayType(segment_stride // bytes_per_element, 1)
-    # Set device information: GPU (device_type=2) and device_id=dev_id (modify as needed)
    device = DLDevice(device_type=2, device_id=dev_id)
-    # Set data type
    dtype = DLDataType(code=dldata_type_code, bits=bits_per_elements, lanes=1)
-    # Construct DLTensor
    dltensor = DLTensor()
    dltensor.data = c_void_p(ptr)
    dltensor.device = device
    dltensor.ndim = 2
    dltensor.dtype = dtype
    dltensor.shape = ctypes.cast(shape_array, POINTER(c_int64))
    dltensor.strides = ctypes.cast(stride_array, POINTER(c_int64))
    dltensor.byte_offset = 0
-    # Construct DLManagedTensor and set deleter to no-op (you can also call cudaFree here)
    managed_tensor = DLManagedTensor()
    managed_tensor.dl_tensor = dltensor
    managed_tensor.manager_ctx = None
    managed_tensor.deleter = no_op_deleter
-    # Note: Must ensure that shape_array and managed_tensor are not garbage collected by Python,
-    # A simple way is to attach them to the capsule object.
- # Call PyCapsule_New to create capsule PyCapsule_New = ctypes.pythonapi.PyCapsule_New PyCapsule_New.restype = c_void_p PyCapsule_New.argtypes = [c_void_p, ctypes.c_char_p, c_void_p] - # Allocate managed_tensor on the heap (note that pointer returns a pointer) managed_tensor_ptr = pointer(managed_tensor) - # The capsule name must be "dltensor", as required by the DLPack specification capsule_ptr = PyCapsule_New(managed_tensor_ptr, b"dltensor", None) - # Convert capsule_ptr to Python object capsule = ctypes.cast(capsule_ptr, ctypes.py_object).value - # To prevent shape_array and managed_tensor from being collected, we attach them as attributes to the capsule capsule_wrapper = CapsuleWrapper(capsule, shape_array, managed_tensor) return capsule_wrapper @@ -193,7 +135,7 @@ def pack_strided_memory( segment_size: int, segment_stride: int, num_segments: int, - dtype: torch.dtype, + dtype: paddle.dtype, dev_id, ): """ @@ -214,10 +156,9 @@ def pack_strided_memory( This function creates a new DLPack capsule each time it's called, even with the same pointer. Each capsule is consumed only once. """ - # Create a new capsule each time capsule_wrapper = create_dlpack_capsule( ptr, segment_size, segment_stride, num_segments, dtype, dev_id ) - torch_tensor = torch.utils.dlpack.from_dlpack(capsule_wrapper.capsule) + torch_tensor = paddle.utils.dlpack.from_dlpack(dlpack=capsule_wrapper.capsule) torch_tensor._capsule_wrapper = capsule_wrapper return torch_tensor diff --git a/flashinfer/comm/mapping.py b/flashinfer/comm/mapping.py index eca1481f0e..c43662154d 100644 --- a/flashinfer/comm/mapping.py +++ b/flashinfer/comm/mapping.py @@ -1,21 +1,6 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# Code imported from TensorRT-LLM/tensorrt_llm/mapping.py from typing import List -import torch +import paddle class Mapping(object): @@ -124,72 +109,56 @@ def __init__( cp_config=None, tp_size=1, pp_size=1, - moe_cluster_size=-1, # -1 means no moe - moe_tp_size=-1, # -1 means no moe - moe_ep_size=-1, # -1 means no moe + moe_cluster_size=-1, + moe_tp_size=-1, + moe_ep_size=-1, attn_tp_size=-1, attn_cp_size=-1, auto_parallel=False, enable_attention_dp=False, ): - # set default values for non-moe cases - # or where only one MOE parallelism size is specified if moe_cluster_size == -1: moe_cluster_size = 1 - if moe_tp_size == -1 and moe_ep_size == -1: moe_tp_size = tp_size // moe_cluster_size moe_ep_size = 1 - elif moe_tp_size == -1: moe_tp_size = tp_size // (moe_ep_size * moe_cluster_size) - elif moe_ep_size == -1: moe_ep_size = tp_size // (moe_tp_size * moe_cluster_size) - if attn_tp_size == -1 and attn_cp_size == -1: - # fallback to ulysses attn_tp_size = tp_size * cp_size attn_cp_size = 1 - elif attn_tp_size == -1: attn_tp_size = cp_size * tp_size // attn_cp_size - elif attn_cp_size == -1: attn_cp_size = cp_size * tp_size // attn_tp_size - if attn_cp_size != 1: raise ValueError( f"attn_cp_size must be 1 for now, but got {attn_tp_size}, {attn_cp_size}." ) - if auto_parallel: if tp_size != 1 or pp_size != 1 or tp_size != 1: raise ValueError( f"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, but got {tp_size}, {pp_size}, {cp_size}." ) - else: - if tp_size * pp_size * cp_size != world_size: - raise ValueError( - f"world_size must equal to tp_size * pp_size * cp_size, but got {world_size} != {tp_size} * {pp_size} * {cp_size}." - ) - + elif tp_size * pp_size * cp_size != world_size: + raise ValueError( + f"world_size must equal to tp_size * pp_size * cp_size, but got {world_size} != {tp_size} * {pp_size} * {cp_size}." 
+ ) moe_tp_ep_size = moe_tp_size * moe_ep_size moe_tp_cluster_ep_size = moe_tp_ep_size * moe_cluster_size if moe_tp_cluster_ep_size != tp_size: raise ValueError( f"tp_size must equal to moe_tp_size * moe_ep_size * moe_cluster_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size} * {moe_cluster_size}" ) - attn_tp_cp_size = attn_tp_size * attn_cp_size if attn_tp_cp_size != tp_size * cp_size: raise ValueError( f"tp_size * cp_size must equal to attn_tp_size * attn_cp_size, but got {tp_size} * {cp_size} != {attn_tp_size} * {attn_cp_size}" ) - if moe_ep_size != 1 and cp_size > 1: raise NotImplementedError("CP don't support MoE tp/ep yet") - self.tp_size = tp_size self.cp_size = cp_size self.cp_config = cp_config if cp_config is not None else {} @@ -210,24 +179,17 @@ def __init__( self.moe_cluster_groups = [] self.moe_tp_groups = [] self.moe_ep_groups = [] - if moe_cluster_size > 1: assert moe_ep_size == 1 - - # init pp group for i in range(tp_size * cp_size): ranks = range(i, world_size, tp_size * cp_size) self.pp_groups.append(list(ranks)) - - # init cp group for i in range(pp_size): for j in range(tp_size): ranks = range( i * tp_size * cp_size + j, (i + 1) * tp_size * cp_size + j, tp_size ) self.cp_groups.append(list(ranks)) - - # init tp group for i in range(pp_size): for j in range(cp_size): ranks = range( @@ -235,8 +197,6 @@ def __init__( i * tp_size * cp_size + (j + 1) * tp_size, ) self.tp_groups.append(list(ranks)) - - # init moe tp group for i in range(pp_size): for j in range(moe_cluster_size * moe_ep_size): ranks = range( @@ -245,8 +205,6 @@ def __init__( moe_cluster_size * moe_ep_size, ) self.moe_tp_groups.append(list(ranks)) - - # init moe cluster group for i in range(pp_size): for j in range(moe_tp_size): ranks = range( @@ -255,8 +213,6 @@ def __init__( + (j + 1) * moe_cluster_size * moe_ep_size, ) self.moe_cluster_groups.append(list(ranks)) - - # init moe ep group for i in range(pp_size): for j in range(moe_tp_size): for k in range(moe_cluster_size): @@ -273,7 +229,6 @@ def __init__( def __eq__(self, other): if not isinstance(other, Mapping): return NotImplemented - return ( self.world_size == other.world_size and self.rank == other.rank @@ -313,7 +268,6 @@ def rank(self): @rank.setter def rank(self, rank: int): - # TODO(qijun): skip check for enable_attention_dp temporarily, will support attention_dp_size if not self.enable_attention_dp: if not isinstance(rank, int) or rank < 0 and rank >= self.world_size: raise ValueError( @@ -440,10 +394,9 @@ def has_moe_ep(self): return self.moe_ep_size > 1 def pp_layers(self, num_layers: int) -> List[int]: - # If num_layers % pp_size = n != 0, first n ranks get one extra layer - return torch.tensor_split(torch.arange(num_layers), self.pp_size)[ - self.pp_rank - ].tolist() + return paddle.tensor_split( + x=paddle.arange(end=num_layers), num_or_indices=self.pp_size + )[self.pp_rank].tolist() def ep_experts(self, num_experts: int) -> List[int]: assert self.cp_size == 1 diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index e495825def..a4fa3aa55f 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -1,18 +1,6 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Code imported from TensorRT-LLM/tensorrt_llm/_mnnvl_utils.py +import sys + + import ctypes import logging import os @@ -20,32 +8,28 @@ import sys from typing import Any, Dict, List, Optional -import torch +import paddle from cuda import cuda +from flashinfer.paddle_utils import * from ..cuda_utils import checkCudaErrors from .dlpack_utils import create_dlpack_capsule, pack_strided_memory from .mapping import Mapping IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1" - -# mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here OMPI_COMM_TYPE_HOST = 9 - -# Constants from C++ header -SIGNAL_PAD_SIZE = 2048 # kSIGNAL_PAD_SIZE from header - +SIGNAL_PAD_SIZE = 2048 MNNVL_DEBUG = False def round_up(val: int, gran: int) -> int: """Efficient implementation assuming gran is a power of 2""" - return (val + gran - 1) & ~(gran - 1) + return val + gran - 1 & ~(gran - 1) def create_tensor_from_cuda_memory( - ptr: int, shape: tuple, dtype: torch.dtype, device_id: int -) -> torch.Tensor: + ptr: int, shape: tuple, dtype: paddle.dtype, device_id: int +) -> paddle.Tensor: """ Create a PyTorch tensor from a CUDA memory pointer using DLPack. @@ -58,24 +42,15 @@ def create_tensor_from_cuda_memory( Returns: PyTorch tensor that wraps the CUDA memory """ - # Calculate total size in elements numel = 1 for dim in shape: numel *= dim - - # Get element size in bytes - element_size = torch.tensor([], dtype=dtype).element_size() - - # Create DLPack capsule for contiguous memory (stride = element_size, num_segments = numel) + element_size = paddle.to_tensor(data=[], dtype=dtype).element_size() capsule_wrapper = create_dlpack_capsule( ptr, element_size, element_size, numel, dtype, device_id ) - - # Convert to tensor and reshape - tensor = torch.utils.dlpack.from_dlpack(capsule_wrapper.capsule) - tensor._capsule_wrapper = capsule_wrapper # Keep reference to prevent GC - - # Reshape to desired shape + tensor = paddle.utils.dlpack.from_dlpack(dlpack=capsule_wrapper.capsule) + tensor._capsule_wrapper = capsule_wrapper return tensor.view(shape) @@ -92,16 +67,10 @@ def test_cuda_memory_access(ptr: int, size: int, device_id: int) -> bool: True if memory is accessible, False otherwise """ try: - # Test with a small 4-byte read/write test_size = min(4, size) host_data = bytearray(test_size) - - # Try to copy from device to host checkCudaErrors(cuda.cuMemcpyDtoH(host_data, ptr, test_size)) - - # Try to copy back from host to device checkCudaErrors(cuda.cuMemcpyHtoD(ptr, host_data, test_size)) - print(f"DEBUG: Memory access test PASSED for ptr=0x{ptr:x}") return True except Exception as e: @@ -115,24 +84,19 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]: """ if not host_ptr_array: return None - ArrayType = ctypes.c_uint64 * len(host_ptr_array) c_array = ArrayType(*host_ptr_array) size_in_bytes = ctypes.sizeof(c_array) - device_ptr: cuda.CUdeviceptr = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes)) checkCudaErrors( cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes) ) - # c_array should be freed by GC - return device_ptr if 
IS_BUILDING_DOCS: - # Mock classes for building docs - class MpiComm: # type: ignore[no-redef] + class MpiComm: @classmethod def set_mpi_comm(cls, new_comm): pass @@ -140,24 +104,15 @@ def set_mpi_comm(cls, new_comm): def __getattr__(self, name): return None - class MnnvlMemory: # type: ignore[no-redef] + class MnnvlMemory: initialized: bool = False - current_mem_offset: int = 0 - current_rank_stride: int = 0 # stride for ranks and also address space size. + current_rank_stride: int = 0 current_start_address: int = 0 - - # allocation granularity allocation_granularity: int = 0 - - # fabric address page size (512 MB) fabric_page_size: int = 1 << 29 - - # MPI communicator comm = None - dev_id: int = None - allocated_map: Dict[int, Any] = {} address_refcnt: Dict[int, Any] = {} @@ -210,7 +165,7 @@ def supports_mnnvl() -> bool: import pynvml from mpi4py import MPI - class MpiComm: # type: ignore[no-redef] + class MpiComm: _comm: MPI.Intracomm = MPI.COMM_WORLD @classmethod @@ -220,24 +175,15 @@ def set_mpi_comm(cls, new_comm: MPI.Intracomm): def __getattr__(self, name): return getattr(self._comm, name) - class MnnvlMemory: # type: ignore[no-redef] + class MnnvlMemory: initialized: bool = False - current_mem_offset: int = 0 - current_rank_stride: int = 0 # stride for ranks and also address space size. + current_rank_stride: int = 0 current_start_address: int = 0 - - # allocation granularity allocation_granularity: int = 0 - - # fabric address page size (512 MB) fabric_page_size: int = 1 << 29 - - # MPI communicator comm = None - dev_id: int = None - allocated_map: Dict[int, Any] = {} address_refcnt: Dict[int, Any] = {} @@ -266,9 +212,7 @@ def as_torch_strided_tensor(self, dtype): @staticmethod def initialize(): if not MnnvlMemory.initialized: - # use a dummy torch CUDA tensor to trigger CUDA context initialization - _ = torch.empty(1, device="cuda") - # ensure nvml is initialized. + _ = paddle.empty(shape=[1]) try: pynvml.nvmlDeviceGetCount() except pynvml.NVMLError_Uninitialized: @@ -341,31 +285,28 @@ def open_mnnvl_memory(mapping: Mapping, size: int): dev_id = int(dev) if MnnvlMemory.dev_id is None: MnnvlMemory.dev_id = dev_id - assert dev_id == MnnvlMemory.dev_id, ( - f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}" - ) + assert ( + dev_id == MnnvlMemory.dev_id + ), f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}" comm = MnnvlMemory.get_comm(mapping) comm_rank = comm.Get_rank() comm_size = comm.Get_size() all_rank_allocate_sizes = comm.allgather(size) assert len(all_rank_allocate_sizes) == comm_size - assert all(x == size for x in all_rank_allocate_sizes), ( - "Not all rank allocating same size." - ) + assert all( + x == size for x in all_rank_allocate_sizes + ), "Not all rank allocating same size." 
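The size-agreement assert above relies on an MPI allgather over the mapping's communicator. A minimal mpi4py sketch of the same invariant (run under mpirun with several ranks):

from mpi4py import MPI

comm = MPI.COMM_WORLD
local_size = 1 << 20  # bytes this rank wants to map
all_sizes = comm.allgather(local_size)
assert all(s == local_size for s in all_sizes), "Not all ranks allocating same size."
print(f"rank {comm.Get_rank()} / {comm.Get_size()}: sizes agree")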
granularity = MnnvlMemory.get_allocation_granularity(dev_id) aligned_size = (size + granularity - 1) // granularity * granularity - if ( MnnvlMemory.current_mem_offset + aligned_size > MnnvlMemory.current_rank_stride ): MnnvlMemory.new_mnnvl_memory_address(mapping, aligned_size) - assert ( MnnvlMemory.current_mem_offset + aligned_size <= MnnvlMemory.current_rank_stride ) - allocation_prop = MnnvlMemory.get_allocation_prop(dev_id) allocated_mem_handle = checkCudaErrors( cuda.cuMemCreate(aligned_size, allocation_prop, flags=0) @@ -378,15 +319,10 @@ def open_mnnvl_memory(mapping: Mapping, size: int): ) ) all_handles_data = comm.allgather(exported_fabric_handle.data) - # all_handles_data like b'\x00\x00\x00 \x00\x00\x00\x00\x8f\xec\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x1d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' # noqa: E501 - # can use buf = memoryview(data) to import if using plain buffer for data. - madesc = cuda.CUmemAccessDesc() madesc.location = allocation_prop.location madesc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE - mem_handles = [None] * comm_size - for i, remote_handle_data in enumerate(all_handles_data): rank_ptr = ( MnnvlMemory.current_start_address @@ -394,7 +330,6 @@ def open_mnnvl_memory(mapping: Mapping, size: int): + MnnvlMemory.current_mem_offset ) if i == comm_rank: - # Local memory mapping mem_handles[i] = allocated_mem_handle checkCudaErrors( cuda.cuMemMap( @@ -402,7 +337,6 @@ def open_mnnvl_memory(mapping: Mapping, size: int): ) ) else: - # Fabric memory mapping imported_mem_handle = checkCudaErrors( cuda.cuMemImportFromShareableHandle( remote_handle_data, @@ -413,11 +347,9 @@ def open_mnnvl_memory(mapping: Mapping, size: int): checkCudaErrors( cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0) ) - checkCudaErrors( cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1) ) - ptr = MnnvlMemory.current_start_address + MnnvlMemory.current_mem_offset stride = MnnvlMemory.current_rank_stride MnnvlMemory.allocated_map[ptr] = ( @@ -431,7 +363,6 @@ def open_mnnvl_memory(mapping: Mapping, size: int): MnnvlMemory.address_refcnt[MnnvlMemory.current_start_address] = ( MnnvlMemory.address_refcnt.get(MnnvlMemory.current_start_address, 0) + 1 ) - MnnvlMemory.current_mem_offset += aligned_size return ptr, stride @@ -452,7 +383,6 @@ def close_mnnvl_memory(ptr: int): checkCudaErrors(cuda.cuMemUnmap(rank_ptr, aligned_size)) checkCudaErrors(cuda.cuMemRelease(mem_handles[i])) MnnvlMemory.address_refcnt[start_address] -= 1 - if MnnvlMemory.address_refcnt[start_address] == 0: MnnvlMemory.address_refcnt.pop(start_address) device_ptr = cuda.CUdeviceptr(start_address) @@ -466,7 +396,7 @@ def close_mnnvl_memory(ptr: int): @staticmethod def support_nvlink(need_all_up: bool = True): - dev_id = torch.cuda.current_device() +>>>>>> dev_id = torch.cuda.current_device() handle = pynvml.nvmlDeviceGetHandleByIndex(dev_id) link_count = pynvml.NVML_NVLINK_MAX_LINKS active_links = 0 @@ -490,10 +420,6 @@ def support_nvlink(need_all_up: bool = True): @staticmethod def supports_mnnvl() -> bool: - # TODO: - # We check if it is an aarch64 platform and has all NVLink up now. - # But it is not equivalent to MNNVL support. - # May need better support check. 
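support_nvlink above counts active NVLink links through NVML. A hedged pynvml sketch of that query; device index 0 is assumed and the error handling is illustrative:

import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
active = 0
for link in range(pynvml.NVML_NVLINK_MAX_LINKS):
    try:
        if pynvml.nvmlDeviceGetNvLinkState(handle, link):
            active += 1
    except pynvml.NVMLError:
        pass  # link not present or not supported on this GPU
print(f"active NVLink links: {active}")
pynvml.nvmlShutdown()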
arch = platform.machine().lower() if "aarch64" not in arch: return False @@ -512,23 +438,15 @@ def __init__( is_multi_node: bool = True, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) - primary_ctx = checkCudaErrors(cuda.cuDevicePrimaryCtxRetain(cu_device)) checkCudaErrors(cuda.cuCtxSetCurrent(primary_ctx)) - - # Set CUDA device - # Check if cuda.cudart is available and import accordingly from flashinfer.utils import has_cuda_cudart if has_cuda_cudart(): - # cuda-python <= 12.9 import cuda.cudart as cudart else: - # cuda-python >= 13.0 import cuda.bindings.runtime as cudart - checkCudaErrors(cudart.cudaSetDevice(device_idx)) - self.is_multi_node = is_multi_node self.device_idx = device_idx self.group_size = group_size @@ -536,23 +454,15 @@ def __init__( self.buf_size = buf_size self.signal_pad_offset = 0 self.allocation_size = 0 - - # CUDA memory handles and pointers - self.mc_ptr = 0 # CUdeviceptr mMcPtr - self.uc_ptrs: List[int] = [] # std::vector mUcPtrs - self.signal_pads: List[int] = [] # mSignalPads - self.signal_pads_dev = 0 # std::vector mSignalPadsDev + self.mc_ptr = 0 + self.uc_ptrs: List[int] = [] + self.signal_pads: List[int] = [] + self.signal_pads_dev = 0 self.uc_ptrs_dev = 0 - self.mc_handle = 0 # CUmemGenericAllocationHandle mMcHandle - self.uc_handles: List[ - int - ] = [] # std::vector mUcHandles - - # Signal pad constants + self.mc_handle = 0 + self.uc_handles: List[int] = [] self.SIGNAL_PAD_ALIGNMENT = 16 self.SIGNAL_PAD_SIZE = SIGNAL_PAD_SIZE - - # Check if device supports multicasting multicast_supported = checkCudaErrors( cuda.cuDeviceGetAttribute( cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, @@ -563,18 +473,11 @@ def __init__( raise RuntimeError( "[McastDeviceMemory] Device does not support multicasting." ) - - # Calculate signal pad offset with alignment (matching C++ exactly) self.signal_pad_offset = round_up(buf_size, self.SIGNAL_PAD_ALIGNMENT) - logging.info( - f"[McastDeviceMemory] Rank: {group_rank}, Group size: {group_size}, " - f"mnNvlink: {is_multi_node}, device_idx: {device_idx}, " - f"Signal pad offset: {self.signal_pad_offset}" + f"[McastDeviceMemory] Rank: {group_rank}, Group size: {group_size}, mnNvlink: {is_multi_node}, device_idx: {device_idx}, Signal pad offset: {self.signal_pad_offset}" ) - if self.is_multi_node: - # Check if fabric handle is supported fabric_handle_supported = checkCudaErrors( cuda.cuDeviceGetAttribute( cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, @@ -585,13 +488,9 @@ def __init__( raise RuntimeError( "[McastDeviceMemory] Device does not support fabric handle." 
) - self._alloc_mn_mcast_mem(buf_size) else: - # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem raise NotImplementedError("Single-node NVLS allocation not implemented yet") - - # Initialize signal pads self.signal_pads = [0] * self.group_size for i in range(self.group_size): self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset @@ -599,47 +498,31 @@ def __init__( checkCudaErrors( cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE) ) - - # Create device pointers self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads) self.uc_ptrs_dev = alloc_and_copy_to_cuda(self.uc_ptrs) def __del__(self): """Destructor - cleanup allocated memory""" - - # Check if we're in a valid state for cleanup if not hasattr(self, "is_multi_node"): return - if not self.is_multi_node: return - - # Skip cleanup during Python finalization to avoid segfaults - # Especially cause the CUDA context could be destroyed at this point. if sys.is_finalizing(): return - - # Verify CUDA context is still valid try: cuda.cuCtxGetCurrent() except Exception as e: print(f"Destructor: CUDA context invalid, skipping cleanup: {e}") return - - # Free device pointers if self.signal_pads_dev: checkCudaErrors(cuda.cuMemFree(self.signal_pads_dev)) if self.uc_ptrs_dev: checkCudaErrors(cuda.cuMemFree(self.uc_ptrs_dev)) - - # Unmap UC regions and release their handles if hasattr(self, "uc_handles") and self.uc_handles: for rank in range(self.group_size): if self.uc_handles[rank] != 0: try: - # Release the handle checkCudaErrors(cuda.cuMemRelease(self.uc_handles[rank])) - # Unmap the vmem if rank < len(self.uc_ptrs) and self.uc_ptrs[rank]: checkCudaErrors( cuda.cuMemUnmap( @@ -650,14 +533,10 @@ def __del__(self): print( f"Destructor: Failed to release UC handle for rank {rank}: {e}" ) - - # Free the UC address space if hasattr(self, "uc_base_ptr") and self.uc_base_ptr: checkCudaErrors( cuda.cuMemAddressFree(self.uc_base_ptr, self.total_uc_size) ) - - # Release MC handle if hasattr(self, "mc_handle") and self.mc_handle and self.mc_handle != 0: try: checkCudaErrors(cuda.cuMemUnmap(self.mc_ptr, self.allocation_size)) @@ -688,16 +567,11 @@ def get_unicast_ptr(self, rank: int) -> int: """Get the raw unicast pointer to a given rank""" if rank >= len(self.uc_ptrs): raise ValueError(f"Rank {rank} out of range (0-{len(self.uc_ptrs) - 1})") - data_ptr = self.uc_ptrs[rank] - # Note: In C++, this would call tensorrt_llm::common::registerMcastDevMemBuffer - # For Python port, we skip this registration for now return data_ptr def get_multicast_ptr(self) -> int: """Get the raw multicast pointer""" - # Note: In C++, this would call tensorrt_llm::common::registerMcastDevMemBuffer - # For Python port, we skip this registration for now return int(self.mc_ptr) def get_rank(self) -> int: @@ -710,24 +584,16 @@ def get_world_size(self) -> int: def _alloc_mn_mcast_mem(self, buf_size: int): """Allocate multi-node multicast memory using MNNVL""" - - # Verify CUDA context try: current_device = checkCudaErrors(cuda.cuCtxGetDevice()) - if int(current_device) != self.device_idx: print( f"CUDA context device mismatch! 
Current: {current_device}, Expected: {self.device_idx}" ) except Exception as e: print(f"Error checking CUDA context: {e}") - - # Get MPI communicator comm = MpiComm() - - # Set up allocation properties handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC - allocation_prop = cuda.CUmemAllocationProp() allocation_prop.requestedHandleTypes = handle_type allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED @@ -736,47 +602,31 @@ def _alloc_mn_mcast_mem(self, buf_size: int): cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE ) allocation_prop.location.id = self.device_idx - allocation_prop.allocFlags.gpuDirectRDMACapable = 1 - - # Get allocation granularity alloc_granularity = checkCudaErrors( cuda.cuMemGetAllocationGranularity( allocation_prop, cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM, ) ) - - # mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity); self.allocation_size = round_up( buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity ) - - # Set up multicast properties mc_prop = cuda.CUmulticastObjectProp() mc_prop.numDevices = self.group_size mc_prop.size = self.allocation_size mc_prop.handleTypes = handle_type - - # Get multicast granularity mc_granularity = checkCudaErrors( cuda.cuMulticastGetGranularity( mc_prop, cuda.CUmulticastGranularity_flags.CU_MULTICAST_GRANULARITY_RECOMMENDED, ) ) - self.allocation_size = round_up(self.allocation_size, mc_granularity) - - # Initialize UC handles list self.uc_handles = [0] * self.group_size - - # Allocate local GPU memory self.uc_handles[self.group_rank] = checkCudaErrors( cuda.cuMemCreate(self.allocation_size, allocation_prop, 0) ) - - # Export local handle to fabric handle my_fabric_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.uc_handles[self.group_rank], @@ -784,12 +634,8 @@ def _alloc_mn_mcast_mem(self, buf_size: int): 0, ) ) - - # All-gather fabric handles all_fabric_handles = comm.allgather(my_fabric_handle.data) cuda.cuCtxSynchronize() - - # Import remote handles for p in range(self.group_size): if p != self.group_rank: self.uc_handles[p] = checkCudaErrors( @@ -798,13 +644,8 @@ def _alloc_mn_mcast_mem(self, buf_size: int): cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, ) ) - - # Initialize multicasting if self.group_rank == 0: - # Create multicast object self.mc_handle = checkCudaErrors(cuda.cuMulticastCreate(mc_prop)) - - # Export multicast handle mc_fabric_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.mc_handle, @@ -814,14 +655,10 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) else: mc_fabric_handle = None - - # Broadcast multicast handle mc_fabric_handle_data = comm.bcast( mc_fabric_handle.data if mc_fabric_handle else None, root=0 ) - # Sync device to ensure broadcast is complete cuda.cuCtxSynchronize() - # Import multicast handle for non-root ranks if self.group_rank != 0: self.mc_handle = checkCudaErrors( cuda.cuMemImportFromShareableHandle( @@ -829,29 +666,19 @@ def _alloc_mn_mcast_mem(self, buf_size: int): cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, ) ) - - # Add device to multicast checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx)) - - # Bind memory addresses self.uc_ptrs = [0] * self.group_size - - # Reserve address space for UC pointers total_uc_size = self.allocation_size * self.group_size self.total_uc_size = total_uc_size uc_base_ptr = checkCudaErrors( cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0) ) - self.uc_base_ptr = uc_base_ptr 
# Store for cleanup - - # Set up memory access descriptor + self.uc_base_ptr = uc_base_ptr access_desc = cuda.CUmemAccessDesc() access_desc.location = cuda.CUmemLocation() access_desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE access_desc.location.id = self.device_idx access_desc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE - - # Map UC memory for i in range(self.group_size): offset = self.allocation_size * i self.uc_ptrs[i] = int(uc_base_ptr) + offset @@ -860,13 +687,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0 ) ) - - # Set memory access permissions checkCudaErrors( cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1) ) - - # Bind MC pointer self.mc_ptr = checkCudaErrors( cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0) ) @@ -876,34 +699,29 @@ def _alloc_mn_mcast_mem(self, buf_size: int): checkCudaErrors( cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1) ) - - # Bind memory to multicast checkCudaErrors( cuda.cuMulticastBindMem( self.mc_handle, - 0, # mcOffset + 0, self.uc_handles[self.group_rank], - 0, # memOffset + 0, self.allocation_size, - 0, # flags + 0, ) ) - def lamport_initialize(self, rank: int, dtype: torch.dtype): - if dtype == torch.bfloat16 or dtype == torch.float16: - neg_zero = 0x8000 + def lamport_initialize(self, rank: int, dtype: paddle.dtype): + if dtype == "bfloat16" or dtype == "float16": + neg_zero = 32768 dsize = 2 memset_func = cuda.cuMemsetD16 - elif dtype == torch.float32: - neg_zero = 0x80000000 + elif dtype == "float32": + neg_zero = 2147483648 dsize = 4 memset_func = cuda.cuMemsetD32 else: raise ValueError(f"Unsupported dtype: {dtype}") - - # Calculate number of elements that fit in allocation_size num_elements = self.allocation_size // dsize - checkCudaErrors( memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements) ) @@ -922,7 +740,7 @@ def __init__( buf_size: int, group_size: int, group_rank: int, - device: torch.device, + device: str, mn_nvlink: bool = True, ): """ @@ -941,12 +759,12 @@ def __init__( self.buf_size = buf_size self.local_device = device - def lamport_initialize(self, rank: int, dtype: torch.dtype): + def lamport_initialize(self, rank: int, dtype: paddle.dtype): self.mcast_device_memory.lamport_initialize(rank, dtype) def get_mc_buffer( - self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 - ) -> torch.Tensor: + self, sizes: tuple, dtype: paddle.dtype, storage_offset: int = 0 + ) -> paddle.Tensor: """ Returns a PyTorch tensor view of the multicast buffer portion. 
diff --git a/flashinfer/comm/nvshmem.py b/flashinfer/comm/nvshmem.py index 358eee3016..5c289e92bf 100644 --- a/flashinfer/comm/nvshmem.py +++ b/flashinfer/comm/nvshmem.py @@ -4,7 +4,7 @@ import shlex from typing import Sequence -import torch +import paddle from ..jit import JitSpec from ..jit import env as jit_env @@ -18,7 +18,6 @@ def gen_nvshmem_module() -> JitSpec: + ["-lnvshmem_device"] + shlex.split(os.environ.get("NVSHMEM_LDFLAGS", "")) ) - return gen_jit_spec( "nvshmem", [jit_env.FLASHINFER_CSRC_DIR / "nvshmem_binding.cu"], @@ -30,10 +29,8 @@ def gen_nvshmem_module() -> JitSpec: @functools.cache def get_nvshmem_module(): - # Try to find libnvshmem_host.so first, fallback to libnvshmem_host.so.3 lib_dirs = jit_env.get_nvshmem_lib_dirs() lib_path = None - lib_names = ["libnvshmem_host.so", "libnvshmem_host.so.3"] for lib_dir in lib_dirs: for lib_name in lib_names: @@ -43,19 +40,16 @@ def get_nvshmem_module(): break if lib_path is not None: break - if lib_path is None: raise FileNotFoundError( f"Could not find libnvshmem_host.so or libnvshmem_host.so.3 in {lib_dirs}" ) - ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) module = gen_nvshmem_module().build_and_load() - return module -def get_unique_id() -> torch.Tensor: +def get_unique_id() -> paddle.Tensor: return get_nvshmem_module().nvshmem_get_unique_id() @@ -63,22 +57,22 @@ def unique_id_size() -> int: return get_nvshmem_module().nvshmem_unique_id_size() -def alloc_empty_unique_id() -> torch.Tensor: - return torch.zeros(unique_id_size(), dtype=torch.uint8, device="cpu") +def alloc_empty_unique_id() -> paddle.Tensor: + return paddle.zeros(shape=unique_id_size(), dtype="uint8") -def init(uid: torch.Tensor, rank: int, world_size: int) -> int: +def init(uid: paddle.Tensor, rank: int, world_size: int) -> int: status = get_nvshmem_module().nvshmem_init(uid, rank, world_size) - torch.cuda.synchronize() + paddle.device.synchronize() return status -def alltoall(dest: torch.Tensor, source: torch.Tensor) -> None: +def alltoall(dest: paddle.Tensor, source: paddle.Tensor) -> None: return get_nvshmem_module().nvshmem_alltoall(dest, source) def finalize() -> None: - torch.cuda.synchronize() + paddle.device.synchronize() get_nvshmem_module().nvshmem_finalize() @@ -90,11 +84,7 @@ def n_pes() -> int: return get_nvshmem_module().nvshmem_n_pes() -def malloc( - shape: Sequence[int], - dtype: torch.dtype, - device: torch.device, -) -> torch.Tensor: +def malloc(shape: Sequence[int], dtype: paddle.dtype, device: str) -> paddle.Tensor: """Allocates memory using NVSHMEM collective malloc operation. This is a collective operation that requires participation by all PEs (Processing Elements). @@ -114,7 +104,6 @@ def malloc( Reference: https://docs.nvidia.com/nvshmem/api/gen/api/memory.html#nvshmem-malloc-nvshmem-free-nvshmem-align """ - return get_nvshmem_module().nvshmem_malloc(shape, dtype, device) diff --git a/flashinfer/comm/nvshmem_allreduce.py b/flashinfer/comm/nvshmem_allreduce.py index 57784cad01..336daaa680 100644 --- a/flashinfer/comm/nvshmem_allreduce.py +++ b/flashinfer/comm/nvshmem_allreduce.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2023 by FlashInfer team. @@ -13,12 +15,8 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - from typing import Optional -import torch -from torch.distributed import ProcessGroup - from .nvshmem import get_nvshmem_module @@ -48,9 +46,9 @@ def __init__( local_rank: int, world_size: int, max_buffer_elements: int, - dtype: torch.dtype, - device: torch.device, - group: Optional[ProcessGroup] = None, + dtype: paddle.dtype, + device: str, +>>>>>> group: Optional[torch.distributed.ProcessGroup] = None, should_init: bool = True, ): self.local_rank = local_rank @@ -60,12 +58,9 @@ def __init__( self.max_buffer_elements = max_buffer_elements self.group = group self.nvshmem_module = get_nvshmem_module() - self.should_init = should_init if self.should_init: self.init_nvshmem() - - # assert PE and world size match my_pe = self.nvshmem_module.nvshmem_my_pe() n_pes = self.nvshmem_module.nvshmem_n_pes() if my_pe != local_rank: @@ -78,56 +73,40 @@ def __init__( f"WARNING: Rank {local_rank}: World size mismatch! Expected {world_size}, got {n_pes}", flush=True, ) - - # allocate memory in nvshmem symm heap self.symm_buffer_input = self.nvshmem_module.nvshmem_malloc( - [max_buffer_elements], - self.dtype, - self.device, + [max_buffer_elements], self.dtype, self.device ) self.symm_buffer_output = self.nvshmem_module.nvshmem_malloc( - [max_buffer_elements], - self.dtype, - self.device, + [max_buffer_elements], self.dtype, self.device ) - torch.distributed.barrier(self.group) + paddle.distributed.barrier(group=self.group) def init_nvshmem(self): - torch.zeros( - self.nvshmem_module.nvshmem_unique_id_size(), - dtype=torch.uint8, - device="cpu", - ) + paddle.zeros(shape=self.nvshmem_module.nvshmem_unique_id_size(), dtype="uint8") if self.local_rank == 0: uid = self.nvshmem_module.nvshmem_get_unique_id() else: - uid = torch.zeros( - self.nvshmem_module.nvshmem_unique_id_size(), - dtype=torch.uint8, - device="cpu", + uid = paddle.zeros( + shape=self.nvshmem_module.nvshmem_unique_id_size(), dtype="uint8" ) - torch.distributed.broadcast(uid, src=0) - torch.distributed.barrier(self.group) + paddle.distributed.broadcast(tensor=uid, src=0) + paddle.distributed.barrier(group=self.group) init_status = self.nvshmem_module.nvshmem_init( uid, self.local_rank, self.world_size ) - torch.cuda.synchronize() + paddle.device.synchronize() if init_status != 0: raise RuntimeError("Failed to initialize nvshmem") - def all_reduce(self, inp: torch.Tensor, out: torch.Tensor) -> None: + def all_reduce(self, inp: paddle.Tensor, out: paddle.Tensor) -> None: self.nvshmem_module.nvshmem_allreduce_on_stream_with_copy( - self.symm_buffer_output, - self.symm_buffer_input, - out, - inp, - inp.numel(), + self.symm_buffer_output, self.symm_buffer_input, out, inp, inp.size ) def shutdown(self): del self.symm_buffer_input del self.symm_buffer_output - torch.distributed.barrier(self.group) - torch.cuda.synchronize() + paddle.distributed.barrier(group=self.group) + paddle.device.synchronize() if self.should_init: self.nvshmem_module.nvshmem_finalize() diff --git a/flashinfer/comm/trtllm_alltoall.py b/flashinfer/comm/trtllm_alltoall.py index 595a84990e..1ae5664103 100644 --- a/flashinfer/comm/trtllm_alltoall.py +++ b/flashinfer/comm/trtllm_alltoall.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,14 +19,11 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import functools from dataclasses import dataclass from types import SimpleNamespace from typing import Optional, Tuple -import torch - from ..jit import JitSpec from ..jit import env as jit_env from ..jit import gen_jit_spec @@ -30,59 +33,45 @@ def gen_comm_alltoall_module() -> JitSpec: - return gen_jit_spec( - "comm", - [ - jit_env.FLASHINFER_CSRC_DIR / "trtllm_alltoall.cu", - ], - ) + return gen_jit_spec("comm", [jit_env.FLASHINFER_CSRC_DIR / "trtllm_alltoall.cu"]) @functools.cache def get_comm_alltoall_module(): module = gen_comm_alltoall_module().build_and_load() - @register_custom_op( - "flashinfer::moe_comm_prepare_indices", - mutates_args=[], - ) + @register_custom_op("flashinfer::moe_comm_prepare_indices", mutates_args=[]) def moe_comm_prepare_indices( - gathered_target_rank_ids: torch.Tensor, - real_rank_token_count_cum_sum: Optional[torch.Tensor], + gathered_target_rank_ids: paddle.Tensor, + real_rank_token_count_cum_sum: Optional[paddle.Tensor], max_token_count_per_rank: int, expert_count: int, top_k: int, ep_rank: int, ep_size: int, ) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, ]: - device = gathered_target_rank_ids.device + device = gathered_target_rank_ids.place max_send_ranks_per_token = max(top_k, ep_size) - local_gather_indices = torch.empty( - (max_token_count_per_rank * ep_size), device=device, dtype=torch.int - ) - send_rank_count_cum_sum = torch.empty( - (ep_size,), device=device, dtype=torch.int + local_gather_indices = paddle.empty( + shape=max_token_count_per_rank * ep_size, dtype="int32" ) - send_rank_local_indices = torch.empty( - (max_token_count_per_rank * max_send_ranks_per_token), - device=device, - dtype=torch.int, + send_rank_count_cum_sum = paddle.empty(shape=(ep_size,), dtype="int32") + send_rank_local_indices = paddle.empty( + shape=max_token_count_per_rank * max_send_ranks_per_token, dtype="int32" ) - recv_rank_count_cum_sum = torch.empty((ep_size), device=device, dtype=torch.int) - recv_rank_local_indices = torch.empty( - (max_token_count_per_rank * ep_size), device=device, dtype=torch.int + recv_rank_count_cum_sum = paddle.empty(shape=ep_size, dtype="int32") + recv_rank_local_indices = paddle.empty( + shape=max_token_count_per_rank * ep_size, dtype="int32" ) - backward_recv_rank_local_indice = torch.empty( - (max_token_count_per_rank * max_send_ranks_per_token), - device=device, - dtype=torch.int, + backward_recv_rank_local_indice = paddle.empty( + shape=max_token_count_per_rank * max_send_ranks_per_token, dtype="int32" ) module.moe_comm_prepare_indices( gathered_target_rank_ids, @@ -113,12 +102,12 @@ def moe_comm_prepare_indices( mutates_args=["local_expert_ids", "local_scales"], ) def moe_local_gather( - recv_rank_cum_sum: torch.Tensor, - local_gather_indices: torch.Tensor, - gathered_expert_ids: torch.Tensor, - gathered_scales: torch.Tensor, - local_expert_ids: torch.Tensor, - local_scales: torch.Tensor, + recv_rank_cum_sum: paddle.Tensor, + local_gather_indices: paddle.Tensor, + gathered_expert_ids: paddle.Tensor, + gathered_scales: paddle.Tensor, + local_expert_ids: paddle.Tensor, + local_scales: paddle.Tensor, max_token_count_per_rank: int, expert_count: int, top_k: int, @@ -139,18 +128,15 @@ def moe_local_gather( ep_size, ) - @register_custom_op( - "flashinfer::moe_comm", - mutates_args=["output"], - ) + @register_custom_op("flashinfer::moe_comm", mutates_args=["output"]) def 
moe_comm( - input: torch.Tensor, - send_rank_cum_sum: torch.Tensor, - send_indices: torch.Tensor, - output: torch.Tensor, - recv_rank_cum_sum: torch.Tensor, - recv_indices: torch.Tensor, - all_workspaces: torch.Tensor, + input: paddle.Tensor, + send_rank_cum_sum: paddle.Tensor, + send_indices: paddle.Tensor, + output: paddle.Tensor, + recv_rank_cum_sum: paddle.Tensor, + recv_indices: paddle.Tensor, + all_workspaces: paddle.Tensor, ep_rank: int, ep_size: int, ) -> None: @@ -166,22 +152,14 @@ def moe_comm( ep_size, ) - @register_custom_op( - "flashinfer::set_moe_max_usable_sm_count", - mutates_args=[], - ) - def set_moe_max_usable_sm_count( - max_sm_count: int, - ) -> None: + @register_custom_op("flashinfer::set_moe_max_usable_sm_count", mutates_args=[]) + def set_moe_max_usable_sm_count(max_sm_count: int) -> None: module.set_moe_max_usable_sm_count(max_sm_count) @register_custom_op( - "flashinfer::get_moe_commworkspace_size_per_rank", - mutates_args=[], + "flashinfer::get_moe_commworkspace_size_per_rank", mutates_args=[] ) - def get_moe_commworkspace_size_per_rank( - ep_size: int, - ) -> int: + def get_moe_commworkspace_size_per_rank(ep_size: int) -> int: return module.get_moe_commworkspace_size_per_rank(ep_size) return SimpleNamespace( @@ -194,15 +172,20 @@ def get_moe_commworkspace_size_per_rank( def moe_comm_prepare_indices( - gathered_target_rank_ids: torch.Tensor, - real_rank_token_count_cum_sum: Optional[torch.Tensor], + gathered_target_rank_ids: paddle.Tensor, + real_rank_token_count_cum_sum: Optional[paddle.Tensor], max_token_count_per_rank: int, expert_count: int, top_k: int, ep_rank: int, ep_size: int, ) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, ]: return get_comm_alltoall_module().moe_comm_prepare_indices( gathered_target_rank_ids, @@ -216,12 +199,12 @@ def moe_comm_prepare_indices( def moe_local_gather( - recv_rank_cum_sum: torch.Tensor, - local_gather_indices: torch.Tensor, - gathered_expert_ids: torch.Tensor, - gathered_scales: torch.Tensor, - local_expert_ids: torch.Tensor, - local_scales: torch.Tensor, + recv_rank_cum_sum: paddle.Tensor, + local_gather_indices: paddle.Tensor, + gathered_expert_ids: paddle.Tensor, + gathered_scales: paddle.Tensor, + local_expert_ids: paddle.Tensor, + local_scales: paddle.Tensor, max_token_count_per_rank: int, expert_count: int, top_k: int, @@ -244,13 +227,13 @@ def moe_local_gather( def moe_comm( - input: torch.Tensor, - send_rank_cum_sum: torch.Tensor, - send_indices: torch.Tensor, - output: torch.Tensor, - recv_rank_cum_sum: torch.Tensor, - recv_indices: torch.Tensor, - all_workspaces: torch.Tensor, + input: paddle.Tensor, + send_rank_cum_sum: paddle.Tensor, + send_indices: paddle.Tensor, + output: paddle.Tensor, + recv_rank_cum_sum: paddle.Tensor, + recv_indices: paddle.Tensor, + all_workspaces: paddle.Tensor, ep_rank: int, ep_size: int, ) -> None: @@ -267,32 +250,28 @@ def moe_comm( ) -def set_moe_max_usable_sm_count( - max_sm_count: int, -) -> None: +def set_moe_max_usable_sm_count(max_sm_count: int) -> None: get_comm_alltoall_module().set_moe_max_usable_sm_count(max_sm_count) -def get_moe_commworkspace_size_per_rank( - ep_size: int, -) -> int: +def get_moe_commworkspace_size_per_rank(ep_size: int) -> int: return get_comm_alltoall_module().get_moe_commworkspace_size_per_rank(ep_size) @dataclass class MoEAlltoallInfo: - local_gather_indices: torch.Tensor - send_rank_count_cumsum: 
torch.Tensor - send_rank_local_indices: torch.Tensor - recv_rank_count_cumsum: torch.Tensor - recv_rank_local_indices: torch.Tensor - backward_recv_rank_local_indices: torch.Tensor + local_gather_indices: paddle.Tensor + send_rank_count_cumsum: paddle.Tensor + send_rank_local_indices: paddle.Tensor + recv_rank_count_cumsum: paddle.Tensor + recv_rank_local_indices: paddle.Tensor + backward_recv_rank_local_indices: paddle.Tensor local_token_allocation_count: int class MnnvlMoe: moe_workspace: MnnvlMemory = None - moe_workspace_tensor: torch.Tensor = None + moe_workspace_tensor: paddle.Tensor = None moe_mapping: Mapping = None @staticmethod @@ -300,32 +279,31 @@ def get_moe_workspaces(mapping: Mapping): if MnnvlMoe.moe_workspace is not None: assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now" return MnnvlMoe.moe_workspace_tensor - MnnvlMoe.moe_mapping = mapping workspace_size_per_rank = get_moe_commworkspace_size_per_rank(mapping.tp_size) MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank) MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor( - torch.uint64 +>>>>>> torch.uint64 ) return MnnvlMoe.moe_workspace_tensor @staticmethod def compute_target_rank_id( - token_selected_experts: torch.Tensor, expert_count: int, ep_size: int + token_selected_experts: paddle.Tensor, expert_count: int, ep_size: int ): - assert expert_count % ep_size == 0, ( - "expert_count should be divisible by ep_size" - ) + assert ( + expert_count % ep_size == 0 + ), "expert_count should be divisible by ep_size" expert_per_rank = expert_count // ep_size token_target_rank_ids = token_selected_experts // expert_per_rank return token_target_rank_ids @staticmethod def mnnvl_moe_alltoallv_prepare( - gathered_target_rank_ids: torch.Tensor, - real_rank_token_count_cumsum: torch.Tensor, - gathered_expert_ids: torch.Tensor, - gathered_scales: torch.Tensor, + gathered_target_rank_ids: paddle.Tensor, + real_rank_token_count_cumsum: paddle.Tensor, + gathered_expert_ids: paddle.Tensor, + gathered_scales: paddle.Tensor, max_token_count_per_rank: int, expert_count: int, top_k: int, @@ -348,22 +326,13 @@ def mnnvl_moe_alltoallv_prepare( ep_rank, ep_size, ) - local_token_allocation_count = max_token_count_per_rank * ep_size - - local_expert_ids = torch.empty( - local_token_allocation_count, - top_k, - dtype=torch.int32, - device=torch.device("cuda"), + local_expert_ids = paddle.empty( + shape=[local_token_allocation_count, top_k], dtype="int32" ) - local_scales = torch.empty( - local_token_allocation_count, - top_k, - dtype=torch.float32, - device=torch.device("cuda"), + local_scales = paddle.empty( + shape=[local_token_allocation_count, top_k], dtype="float32" ) - moe_local_gather( recv_rank_count_cumsum, local_gather_indices, @@ -377,7 +346,6 @@ def mnnvl_moe_alltoallv_prepare( ep_rank, ep_size, ) - alltoall_info = MoEAlltoallInfo( local_gather_indices, send_rank_count_cumsum, @@ -391,18 +359,16 @@ def mnnvl_moe_alltoallv_prepare( @staticmethod def mnnvl_moe_alltoallv( - x: torch.Tensor, + x: paddle.Tensor, alltoall_info: MoEAlltoallInfo, - workspace: torch.Tensor, + workspace: paddle.Tensor, ep_rank: int, ep_size: int, ): assert x.dim() == 2, "only 2D tensor supported, please reshape." 
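+        # The receive buffer covers the worst case: local_token_allocation_count
+        # rows, i.e. max_token_count_per_rank tokens from each of the ep_size ranks.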
- output_tensor = torch.empty( - alltoall_info.local_token_allocation_count, - x.shape[1], + output_tensor = paddle.empty( + shape=[alltoall_info.local_token_allocation_count, tuple(x.shape)[1]], dtype=x.dtype, - device=torch.device("cuda"), ) moe_comm( x, @@ -419,17 +385,17 @@ def mnnvl_moe_alltoallv( @staticmethod def mnnvl_moe_alltoallv_combine( - x: torch.Tensor, + x: paddle.Tensor, alltoall_info: MoEAlltoallInfo, - workspace: torch.Tensor, + workspace: paddle.Tensor, ep_rank: int, ep_size: int, top_k: int, token_count: int, ): assert x.dim() == 2, "2D tensor supported, please reshape." - output_tensor = torch.zeros( - token_count * top_k, x.shape[1], dtype=x.dtype, device=torch.device("cuda") + output_tensor = paddle.zeros( + shape=[token_count * top_k, tuple(x.shape)[1]], dtype=x.dtype ) moe_comm( x, @@ -442,6 +408,8 @@ def mnnvl_moe_alltoallv_combine( ep_rank, ep_size, ) - return torch.sum( - output_tensor.reshape(token_count, top_k, x.shape[1]), dim=1, keepdim=False + return paddle.sum( + x=output_tensor.reshape(token_count, top_k, tuple(x.shape)[1]), + axis=1, + keepdim=False, ) diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 0f4e438c33..9fe8cdce63 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,18 +19,12 @@ See the License for the specific language governing permissions and limitations under the License. """ - import functools import logging from ctypes import c_void_p, cast from types import SimpleNamespace from typing import List, Optional, Tuple, Union -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.utils.cpp_extension import _get_cuda_arch_flags - from ..jit import JitSpec from ..jit import env as jit_env from ..jit import gen_jit_spec, sm100a_nvcc_flags @@ -33,7 +33,6 @@ class AllReduceStrategyType: - # NOTE: for trtllm_custom_all_reduce NCCL = 0 MIN_LATENCY = 1 UB = 2 @@ -44,13 +43,11 @@ class AllReduceStrategyType: class AllReduceStrategyConfig: - # NOTE: for trtllm_custom_all_reduce USE_MEMCPY = 1 << 0 PUSH_MODE = 1 << 1 class AllReduceFusionOp: - # NOTE: for trtllm_custom_all_reduce NONE = 0 RESIDUAL_RMS_NORM = 1 LAST_PROCESS_FOR_UB = 2 @@ -64,44 +61,25 @@ class AllReduceFusionOp: class AllReduceFusionPattern: - # NOTE: for trtllm_allreduce_fusion - # Basic all-reduce pattern kAllReduce = 0 - # All-reduce followed by residual add and RMS norm kARResidualRMSNorm = 1 - # All-reduce followed by residual add, RMS norm and FP8 quantization kARResidualRMSNormFP8Quant = 2 - # All-reduce followed by residual add, RMS norm and FP4 quantization kARResidualRMSNormFP4Quant = 3 - # All-reduce followed by residual add, RMS norm and FP8 quantization, with norm output kARResidualRMSNormOutFP8Quant = 4 - # All-reduce followed by residual add, RMS norm and FP4 quantization, with norm output kARResidualRMSNormOutFP4Quant = 5 class QuantizationSFLayout: - # Block scale factors are stored in swizzled layout for cutlass FP4 kernel. Scale factor - # blocks are organized in 512-byte blocks in global memory, with each block having 128x4 FP8 - # values. The SF matrix dimensions are therefore padded - rows to the nearest multiple of 128 and - # columns to the nearest multiple of 4. 
- # - # The scale factor block rows map to data block rows in an interleaved pattern: - # For a scale factor row 'i', it maps to data block row: (i % 4) * 32 + (i / 4) - # Column 'j' in the scale factor block corresponds to scaling the j-th block in the data tensor. - # - # Please refer to https://nvbugs/4165523 for more details about the swizzled layout. SWIZZLED_128x4 = 0 SWIZZLED_8x4 = 1 - # Block scale factors are stored in linear layout (row-major). This is used in some trtllm-gen - # kernels standard. LINEAR = 2 def gen_trtllm_comm_module() -> JitSpec: - gencode_flags = _get_cuda_arch_flags() +>>>>>> gencode_flags = torch.utils.cpp_extension._get_cuda_arch_flags() has_sm100 = any( "compute_100" in flag for flag in gencode_flags - ) and version_at_least(torch.version.cuda, "12.8") +>>>>>> ) and version_at_least(torch.version.cuda, "12.8") return gen_jit_spec( "trtllm_comm", [ @@ -121,7 +99,7 @@ def get_trtllm_comm_module(): "flashinfer::trtllm_lamport_initialize", mutates_args=["buffer"] ) def trtllm_lamport_initialize( - buffer_ptr: int, size: int, dtype: torch.dtype + buffer_ptr: int, size: int, dtype: paddle.dtype ) -> None: module.trtllm_lamport_initialize(buffer_ptr, size, dtype) @@ -134,7 +112,7 @@ def trtllm_lamport_initialize_all( buffer_1_ptr: int, buffer_2_ptr: int, size: int, - dtype: torch.dtype, + dtype: paddle.dtype, ) -> None: module.trtllm_lamport_initialize_all( buffer_0_ptr, buffer_1_ptr, buffer_2_ptr, size, dtype @@ -168,8 +146,8 @@ def trtllm_lamport_initialize_all( ], ) def trtllm_custom_all_reduce( - inp: torch.Tensor, - out: torch.Tensor, + inp: paddle.Tensor, + out: paddle.Tensor, tp_size: int, tp_rank: int, token_num: int, @@ -178,18 +156,18 @@ def trtllm_custom_all_reduce( config_code: AllReduceStrategyConfig, launch_with_pdl: bool, flag_value: int, - peer_comm_buffer_ptrs: torch.Tensor, - peer_barrier_ptrs_in: torch.Tensor, - peer_barrier_ptrs_out: torch.Tensor, - bias: Optional[torch.Tensor], - residual: Optional[torch.Tensor], - weight: Optional[torch.Tensor], - weight_pre_residual_norm: Optional[torch.Tensor], + peer_comm_buffer_ptrs: paddle.Tensor, + peer_barrier_ptrs_in: paddle.Tensor, + peer_barrier_ptrs_out: paddle.Tensor, + bias: Optional[paddle.Tensor], + residual: Optional[paddle.Tensor], + weight: Optional[paddle.Tensor], + weight_pre_residual_norm: Optional[paddle.Tensor], eps: Optional[float], - intermediate_buffer: Optional[torch.Tensor], - lamport_peer_comm_buffer_ptrs_0: Optional[torch.Tensor], - lamport_peer_comm_buffer_ptrs_1: Optional[torch.Tensor], - lamport_peer_comm_buffer_ptrs_2: Optional[torch.Tensor], + intermediate_buffer: Optional[paddle.Tensor], + lamport_peer_comm_buffer_ptrs_0: Optional[paddle.Tensor], + lamport_peer_comm_buffer_ptrs_1: Optional[paddle.Tensor], + lamport_peer_comm_buffer_ptrs_2: Optional[paddle.Tensor], ) -> None: module.trtllm_custom_all_reduce( inp, @@ -243,26 +221,26 @@ def trtllm_custom_all_reduce( ], ) def trtllm_allreduce_fusion( - allreduce_in: torch.Tensor, + allreduce_in: paddle.Tensor, world_size: int, world_rank: int, token_num: int, hidden_dim: int, - workspace_ptrs: torch.Tensor, + workspace_ptrs: paddle.Tensor, launch_with_pdl: bool, use_oneshot: bool, trigger_completion_at_end: bool, fp32_acc: bool, pattern_code: AllReduceFusionPattern, - allreduce_out: Optional[torch.Tensor], - residual_in: Optional[torch.Tensor], - residual_out: Optional[torch.Tensor], - norm_out: Optional[torch.Tensor], - quant_out: Optional[torch.Tensor], - scale_out: Optional[torch.Tensor], - rms_gamma: 
Optional[torch.Tensor], + allreduce_out: Optional[paddle.Tensor], + residual_in: Optional[paddle.Tensor], + residual_out: Optional[paddle.Tensor], + norm_out: Optional[paddle.Tensor], + quant_out: Optional[paddle.Tensor], + scale_out: Optional[paddle.Tensor], + rms_gamma: Optional[paddle.Tensor], rms_eps: Optional[float], - scale_factor: Optional[Union[torch.Tensor, float]], + scale_factor: Optional[Union[paddle.Tensor, float]], layout_code: Optional[QuantizationSFLayout], ) -> None: module.trtllm_allreduce_fusion( @@ -320,22 +298,22 @@ def trtllm_moe_allreduce_fusion( world_rank: int, token_num: int, hidden_dim: int, - workspace_ptrs: torch.Tensor, + workspace_ptrs: paddle.Tensor, launch_with_pdl: bool, - residual_in: torch.Tensor, - rms_gamma: torch.Tensor, + residual_in: paddle.Tensor, + rms_gamma: paddle.Tensor, rms_eps: float, scale_factor: float, moe_reduction_device_num_experts: int, - moe_reduction_scale_input: torch.Tensor, - moe_reduction_active_experts_token_input: torch.Tensor, - moe_reduction_token_input: torch.Tensor, + moe_reduction_scale_input: paddle.Tensor, + moe_reduction_active_experts_token_input: paddle.Tensor, + moe_reduction_token_input: paddle.Tensor, layout_code: Optional[QuantizationSFLayout], - moe_allreduce_out: Optional[torch.Tensor], - residual_out: Optional[torch.Tensor], - norm_out: Optional[torch.Tensor], - quant_out: Optional[torch.Tensor], - scale_out: Optional[torch.Tensor], + moe_allreduce_out: Optional[paddle.Tensor], + residual_out: Optional[paddle.Tensor], + norm_out: Optional[paddle.Tensor], + quant_out: Optional[paddle.Tensor], + scale_out: Optional[paddle.Tensor], ) -> None: module.trtllm_moe_allreduce_fusion( world_size, @@ -365,19 +343,19 @@ def trtllm_moe_allreduce_fusion( mutates_args=["residual_out", "norm_out"], ) def trtllm_moe_finalize_allreduce_fusion( - allreduce_in: torch.Tensor, - residual_in: torch.Tensor, - norm_weight: torch.Tensor, - expanded_idx_to_permuted_idx: torch.Tensor, - norm_out: torch.Tensor, - residual_out: torch.Tensor, + allreduce_in: paddle.Tensor, + residual_in: paddle.Tensor, + norm_weight: paddle.Tensor, + expanded_idx_to_permuted_idx: paddle.Tensor, + norm_out: paddle.Tensor, + residual_out: paddle.Tensor, launch_with_pdl: bool, - workspace: torch.Tensor, + workspace: paddle.Tensor, world_rank: int, world_size: int, eps: float, - shared_expert_output: Optional[torch.Tensor], - expert_scale_factor: Optional[torch.Tensor], + shared_expert_output: Optional[paddle.Tensor], + expert_scale_factor: Optional[paddle.Tensor], ) -> None: module.trtllm_moe_finalize_allreduce_fusion( allreduce_in, @@ -405,8 +383,6 @@ def trtllm_moe_finalize_allreduce_fusion( ) -# NOTE(Yingyi): The customAllReduce and allReduceFusion require different buffer size -# since allreduceFusion kernels are an improved implementation OneShotMaxToken = 128 MAX_ALL_REDUCE_BLOCKS = 24 LamportTokenNumThreshold = 16 @@ -417,7 +393,7 @@ def trtllm_create_ipc_workspace_for_all_reduce( tp_size: int, max_token_num: int, hidden_dim, - group: Optional[ProcessGroup] = None, +>>>>>> group: Optional[torch.distributed.ProcessGroup] = None, ) -> List[List[int]]: """ Parameters: @@ -454,14 +430,11 @@ def trtllm_create_ipc_workspace_for_all_reduce( Reference: trtllm, cpp/tests/unit_tests/kernels/allReduce/allReduceKernelTest.cu, Workspace init """ - buffer_size = tp_size * max_token_num * hidden_dim * 4 FLAG_SIZE = (MAX_ALL_REDUCE_BLOCKS + 1) * 4 flag_size = FLAG_SIZE * tp_size * 2 lamport_buffer_size = tp_size * LamportTokenNumThreshold * tp_size * hidden_dim * 2 - 
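+    # Seven IPC regions are allocated below (two data buffers, two flag buffers,
+    # three lamport buffers), each rounded up to the 2MB allocation granularity.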
ipc_handles = list() - for size in [ buffer_size, buffer_size, @@ -471,29 +444,24 @@ def trtllm_create_ipc_workspace_for_all_reduce( lamport_buffer_size, lamport_buffer_size, ]: - # all sizes should be aligned to 1LU << 21 bytes (2MB) aligned_size = round_up(size, 1 << 21) ipc_handles.append(create_shared_buffer(aligned_size, group)) - print( f"rank {rank} allocated ipc_handles: {[[hex(handle) for handle in sublist] for sublist in ipc_handles]}" ) - trtllm_lamport_initialize_all( ipc_handles[4][rank], ipc_handles[5][rank], ipc_handles[6][rank], lamport_buffer_size // 2, - torch.float16, + "float16", ) - - dist.barrier(group=group) # must sync after create_workspace - + paddle.distributed.barrier(group=group) return ipc_handles def trtllm_destroy_ipc_workspace_for_all_reduce( - workspace: List[List[int]], group: Optional[ProcessGroup] = None +>>>>>> workspace: List[List[int]], group: Optional[torch.distributed.ProcessGroup] = None ) -> None: """ Note: @@ -502,14 +470,12 @@ def trtllm_destroy_ipc_workspace_for_all_reduce( The workspace should be destroyed after calling trtllm_custom_all_reduce. The workspace can be reused for multiple all reduce calls under the same configuration. """ - for ipc_handle in workspace: free_shared_buffer(ipc_handle, group) BarrierFlagCount = 256 - -MAX_COMM_SIZE = 2147483647 & ~((1 << 21) - 1) # MAX_INT32 rounded down to 2MB +MAX_COMM_SIZE = 2147483647 & ~((1 << 21) - 1) def trtllm_create_ipc_workspace_for_all_reduce_fusion( @@ -518,8 +484,8 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( max_token_num: int, hidden_dim, use_fp32_lamport: bool = False, - group: Optional[ProcessGroup] = None, -) -> Tuple[List[List[int]], torch.Tensor]: +>>>>>> group: Optional[torch.distributed.ProcessGroup] = None, +) -> Tuple[List[List[int]], paddle.Tensor]: """ Parameters: - tp_rank: the rank of the current process. 
@@ -544,11 +510,8 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( Reference: trtllm, cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu, Workspace init """ - buffer_size = tp_size * max_token_num * hidden_dim * 2 flag_size = tp_size * BarrierFlagCount * 4 - # lamport_comm_size = tp_size * max(max_token_num, OneShotMaxToken) * hidden_dim * 2 - # enable larger workspace for cases > OneShotMaxToken lamport_comm_size = ( tp_size * max_token_num * hidden_dim * 2 if not use_fp32_lamport @@ -559,42 +522,27 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( f"warning: lamport_comm_size {lamport_comm_size} is greater than MAX_COMM_SIZE {MAX_COMM_SIZE}, set to MAX_COMM_SIZE" ) lamport_comm_size = MAX_COMM_SIZE - lamport_buffer_size = lamport_comm_size * 3 - - # we should init 3 buffers for all reduce fusion: - # [buffer_size, flag_size, lamport_buffer_size] - ipc_handles: List[List[int]] = list() for size in [buffer_size, flag_size, lamport_buffer_size]: - # todo(review): confirm we need this alignment - # all sizes should be aligned to 1LU << 21 bytes (2MB) aligned_size = round_up(size, 1 << 21) ipc_handles.append(create_shared_buffer(aligned_size, group)) - print( f"rank {tp_rank} allocated ipc_handles: {[[hex(handle) for handle in sublist] for sublist in ipc_handles]}" ) - - # Initialize lamport buffer aligned_lamport_buffer_size = round_up(lamport_buffer_size, 1 << 21) if use_fp32_lamport: trtllm_lamport_initialize( - ipc_handles[2][tp_rank], aligned_lamport_buffer_size // 4, torch.float32 + ipc_handles[2][tp_rank], aligned_lamport_buffer_size // 4, "float32" ) else: trtllm_lamport_initialize( - ipc_handles[2][tp_rank], aligned_lamport_buffer_size // 2, torch.float16 + ipc_handles[2][tp_rank], aligned_lamport_buffer_size // 2, "float16" ) - - # initialize workspace workspace = list() - # add ipc handles to workspace for ipc_handle in ipc_handles: for rank in range(tp_size): workspace.append(ipc_handle[rank]) - - # add flags to workspace """ NOTE: The flags are for the lamport communication states. 
@@ -604,34 +552,25 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( lamport triple buffer offset: kernel_flag_ptr[3] = lamport_comm_size; lamport clear size: kernel_flag_ptr[4] = 0; """ - # malloc cuda memory of int32_t * 5 flag_ptr = cudart.cudaMalloc(5 * 4) - # initialize the flag to [0,0,0,lamport_comm_size,0] cudart.cudaMemset(flag_ptr, 0, 5 * 4) - # Set flag_ptr[3] = lamport_comm_size lamport_comm_size_bytes = lamport_comm_size.to_bytes(4, byteorder="little") cudart.cudaMemcpy( c_void_p(flag_ptr.value + 3 * 4), cast(lamport_comm_size_bytes, c_void_p), 4 ) print("set flag_ptr[3] = lamport_comm_size: ", lamport_comm_size) - # add flag_ptr to workspace workspace.append(flag_ptr.value) - for i in range(len(workspace)): print(f"Rank {tp_rank} workspace[{i}] {hex(workspace[i])}") - - # Store workspace pointers in device tensor - workspace_tensor = torch.tensor( - workspace, dtype=torch.int64, device=torch.device("cuda") + workspace_tensor = paddle.to_tensor( + data=workspace, dtype="int64", place=device2str("gpu") ) - - dist.barrier(group=group) # must sync after create_workspace - + paddle.distributed.barrier(group=group) return ipc_handles, workspace_tensor def trtllm_destroy_ipc_workspace_for_all_reduce_fusion( - workspace: List[List[int]], group: Optional[ProcessGroup] = None +>>>>>> workspace: List[List[int]], group: Optional[torch.distributed.ProcessGroup] = None ) -> None: """ Parameters: @@ -644,12 +583,10 @@ def trtllm_destroy_ipc_workspace_for_all_reduce_fusion( The workspace should be destroyed after calling trtllm_custom_all_reduce_fusion. The workspace can be reused for multiple all reduce fusion calls under the same configuration. """ - for ipc_handle in workspace: free_shared_buffer(ipc_handle, group) -# allReduce fused quant utils def compute_fp4_swizzled_layout_sf_size(total_row, total_column): """ Helper function to compute the padded size of the fp4 swizzled layout. @@ -660,14 +597,14 @@ def compute_fp4_swizzled_layout_sf_size(total_row, total_column): """ def pad_up(x, y): - return ((x + y - 1) // y) * y + return (x + y - 1) // y * y padded_row = pad_up(total_row, 128) padded_column = pad_up(total_column, 4) return padded_row * padded_column -def trtllm_lamport_initialize(buffer_ptr: int, size: int, dtype: torch.dtype) -> None: +def trtllm_lamport_initialize(buffer_ptr: int, size: int, dtype: paddle.dtype) -> None: get_trtllm_comm_module().trtllm_lamport_initialize(buffer_ptr, size, dtype) @@ -676,7 +613,7 @@ def trtllm_lamport_initialize_all( buffer_1_ptr: int, buffer_2_ptr: int, size: int, - dtype: torch.dtype, + dtype: paddle.dtype, ) -> None: """ Initialize 3 lamport buffers by negative zero. @@ -688,15 +625,14 @@ def trtllm_lamport_initialize_all( - size: the size of the buffer. - dtype: the data type of the buffer. 
""" - get_trtllm_comm_module().trtllm_lamport_initialize_all( buffer_0_ptr, buffer_1_ptr, buffer_2_ptr, size, dtype ) def trtllm_custom_all_reduce( - inp: torch.Tensor, - out: torch.Tensor, + inp: paddle.Tensor, + out: paddle.Tensor, tp_size: int, tp_rank: int, token_num: int, @@ -705,18 +641,18 @@ def trtllm_custom_all_reduce( config_code: AllReduceStrategyConfig, launch_with_pdl: bool, flag_value: int, - peer_comm_buffer_ptrs: torch.Tensor, - peer_barrier_ptrs_in: torch.Tensor, - peer_barrier_ptrs_out: torch.Tensor, - bias: Optional[torch.Tensor], - residual: Optional[torch.Tensor], - weight: Optional[torch.Tensor], - weight_pre_residual_norm: Optional[torch.Tensor], + peer_comm_buffer_ptrs: paddle.Tensor, + peer_barrier_ptrs_in: paddle.Tensor, + peer_barrier_ptrs_out: paddle.Tensor, + bias: Optional[paddle.Tensor], + residual: Optional[paddle.Tensor], + weight: Optional[paddle.Tensor], + weight_pre_residual_norm: Optional[paddle.Tensor], eps: Optional[float], - intermediate_buffer: Optional[torch.Tensor], - lamport_peer_comm_buffer_ptrs_0: Optional[torch.Tensor], - lamport_peer_comm_buffer_ptrs_1: Optional[torch.Tensor], - lamport_peer_comm_buffer_ptrs_2: Optional[torch.Tensor], + intermediate_buffer: Optional[paddle.Tensor], + lamport_peer_comm_buffer_ptrs_0: Optional[paddle.Tensor], + lamport_peer_comm_buffer_ptrs_1: Optional[paddle.Tensor], + lamport_peer_comm_buffer_ptrs_2: Optional[paddle.Tensor], ) -> None: """ Parameters: @@ -743,7 +679,6 @@ def trtllm_custom_all_reduce( - lamport_peer_comm_buffer_ptrs_1: the lamport peer communication buffer pointers 1. - lamport_peer_comm_buffer_ptrs_2: the lamport peer communication buffer pointers 2. """ - get_trtllm_comm_module().trtllm_custom_all_reduce( inp, out, @@ -771,26 +706,26 @@ def trtllm_custom_all_reduce( def trtllm_allreduce_fusion( - allreduce_in: torch.Tensor, + allreduce_in: paddle.Tensor, world_size: int, world_rank: int, token_num: int, hidden_dim: int, - workspace_ptrs: torch.Tensor, + workspace_ptrs: paddle.Tensor, launch_with_pdl: bool, trigger_completion_at_end: bool, fp32_acc: bool, pattern_code: AllReduceFusionPattern, use_oneshot: Optional[bool], - allreduce_out: Optional[torch.Tensor], - residual_in: Optional[torch.Tensor], - residual_out: Optional[torch.Tensor], - norm_out: Optional[torch.Tensor], - quant_out: Optional[torch.Tensor], - scale_out: Optional[torch.Tensor], - rms_gamma: Optional[torch.Tensor], + allreduce_out: Optional[paddle.Tensor], + residual_in: Optional[paddle.Tensor], + residual_out: Optional[paddle.Tensor], + norm_out: Optional[paddle.Tensor], + quant_out: Optional[paddle.Tensor], + scale_out: Optional[paddle.Tensor], + rms_gamma: Optional[paddle.Tensor], rms_eps: Optional[float], - scale_factor: Optional[Union[torch.Tensor, float]], + scale_factor: Optional[Union[paddle.Tensor, float]], layout_code: Optional[QuantizationSFLayout], ) -> None: """ @@ -821,30 +756,26 @@ def trtllm_allreduce_fusion( Regarding the `use_oneshot` parameter, you could force to use the one-shot strategy based on your use case. Otherwise, it would be enabled if token_num is less than the one-shot max token number (currently 128) for min-latency mode. 
""" - if use_oneshot is None: use_oneshot = token_num <= 128 - if not use_oneshot: assert token_num > world_size, "sequence length should be larger than tp_size" - required_lamport_comm_size = ( token_num * hidden_dim * 2 * world_size - if allreduce_in.dtype != torch.float32 + if allreduce_in.dtype != "float32" else token_num * hidden_dim * 4 * world_size ) - if required_lamport_comm_size > MAX_COMM_SIZE and use_oneshot: logging.warning( f"required_lamport_comm_size {required_lamport_comm_size} is greater than MAX_COMM_SIZE {MAX_COMM_SIZE}. Cannot use oneshot in this case." ) use_oneshot = False if scale_factor is not None: - if isinstance(scale_factor, torch.Tensor): - scale_factor = scale_factor.to(torch.float32) + if isinstance(scale_factor, paddle.Tensor): + scale_factor = scale_factor.to("float32") else: - scale_factor = torch.tensor( - [scale_factor], dtype=torch.float32, device=allreduce_in.device + scale_factor = paddle.to_tensor( + data=[scale_factor], dtype="float32", place=allreduce_in.place ) get_trtllm_comm_module().trtllm_allreduce_fusion( allreduce_in=allreduce_in, @@ -876,22 +807,22 @@ def trtllm_moe_allreduce_fusion( world_rank: int, token_num: int, hidden_dim: int, - workspace_ptrs: torch.Tensor, + workspace_ptrs: paddle.Tensor, launch_with_pdl: bool, - residual_in: torch.Tensor, - rms_gamma: torch.Tensor, + residual_in: paddle.Tensor, + rms_gamma: paddle.Tensor, rms_eps: float, scale_factor: float, moe_reduction_device_num_experts: int, - moe_reduction_scale_input: torch.Tensor, - moe_reduction_active_experts_token_input: torch.Tensor, - moe_reduction_token_input: torch.Tensor, + moe_reduction_scale_input: paddle.Tensor, + moe_reduction_active_experts_token_input: paddle.Tensor, + moe_reduction_token_input: paddle.Tensor, layout_code: Optional[QuantizationSFLayout], - moe_allreduce_out: Optional[torch.Tensor], - residual_out: Optional[torch.Tensor], - norm_out: Optional[torch.Tensor], - quant_out: Optional[torch.Tensor], - scale_out: Optional[torch.Tensor], + moe_allreduce_out: Optional[paddle.Tensor], + residual_out: Optional[paddle.Tensor], + norm_out: Optional[paddle.Tensor], + quant_out: Optional[paddle.Tensor], + scale_out: Optional[paddle.Tensor], ) -> None: """ Parameters: @@ -916,15 +847,11 @@ def trtllm_moe_allreduce_fusion( - quant_out: the quant output tensor. [token_num // 4, hidden_dim], fp16/bf16 -> fp4 - scale_out: the scale output tensor. Initialization referece: tests/test_trtllm_moe_allreduce_fusion.py """ - - required_lamport_comm_size = moe_reduction_token_input.numel() * 2 * world_size - - # Note: only one-shot is supported for moe allreduce fusion. + required_lamport_comm_size = moe_reduction_token_input.size * 2 * world_size if required_lamport_comm_size > MAX_COMM_SIZE: raise ValueError( f"required_lamport_comm_size {required_lamport_comm_size} is greater than MAX_COMM_SIZE {MAX_COMM_SIZE}. Cannot use oneshot in this case." 
) - get_trtllm_comm_module().trtllm_moe_allreduce_fusion( world_size=world_size, world_rank=world_rank, @@ -950,19 +877,19 @@ def trtllm_moe_allreduce_fusion( def trtllm_moe_finalize_allreduce_fusion( - allreduce_in: torch.Tensor, - residual_in: torch.Tensor, - norm_weight: torch.Tensor, - expanded_idx_to_permuted_idx: torch.Tensor, - norm_out: torch.Tensor, - residual_out: torch.Tensor, - workspace_ptrs: torch.Tensor, + allreduce_in: paddle.Tensor, + residual_in: paddle.Tensor, + norm_weight: paddle.Tensor, + expanded_idx_to_permuted_idx: paddle.Tensor, + norm_out: paddle.Tensor, + residual_out: paddle.Tensor, + workspace_ptrs: paddle.Tensor, launch_with_pdl: bool, world_rank: int, world_size: int, eps: float, - shared_expert_output: Optional[torch.Tensor], - expert_scale_factor: Optional[torch.Tensor], + shared_expert_output: Optional[paddle.Tensor], + expert_scale_factor: Optional[paddle.Tensor], ) -> None: """ Parameters: @@ -980,15 +907,11 @@ def trtllm_moe_finalize_allreduce_fusion( - shared_expert_output: the shared expert output tensor. [token_num, hidden_dim] - expert_scale_factor: the expert scale factor tensor. [token_num, top_k] """ - - required_lamport_comm_size = allreduce_in.numel() * 2 * world_size - - # Note: only one-shot is supported for moe allreduce fusion. + required_lamport_comm_size = allreduce_in.size * 2 * world_size if required_lamport_comm_size > MAX_COMM_SIZE: raise ValueError( f"required_lamport_comm_size {required_lamport_comm_size} is greater than MAX_COMM_SIZE {MAX_COMM_SIZE}. Cannot use oneshot in this case." ) - get_trtllm_comm_module().trtllm_moe_finalize_allreduce_fusion( allreduce_in=allreduce_in, residual_in=residual_in, diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 2e004703f0..28754a66a5 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -1,16 +1,20 @@ +import sys + + +import os + +import paddle +from flashinfer.paddle_utils import * + """ MNNVL (Multi-Node NVLink) communication operations for FlashInfer. 
""" - import functools import math -import os from types import SimpleNamespace from typing import Optional, Tuple -import torch - from flashinfer.comm.mapping import Mapping from ..jit import JitSpec @@ -29,10 +33,7 @@ def mpi_barrier(): def gen_trtllm_mnnvl_comm_module() -> JitSpec: return gen_jit_spec( - "trtllm_mnnvl_comm", - [ - jit_env.FLASHINFER_CSRC_DIR / "trtllm_mnnvl_allreduce.cu", - ], + "trtllm_mnnvl_comm", [jit_env.FLASHINFER_CSRC_DIR / "trtllm_mnnvl_allreduce.cu"] ) @@ -56,16 +57,16 @@ def get_trtllm_mnnvl_comm_module(): ], ) def trtllm_mnnvl_all_reduce( - inp: torch.Tensor, - multicast_buffer_ptr: int, # Pointer address as integer - buffer_ptrs_dev: int, # Pointer address as integer - buffer_mnnvl: torch.Tensor, - buffer_flags_mnnvl: torch.Tensor, + inp: paddle.Tensor, + multicast_buffer_ptr: int, + buffer_ptrs_dev: int, + buffer_mnnvl: paddle.Tensor, + buffer_flags_mnnvl: paddle.Tensor, nranks: int, rank: int, wait_for_results: bool, launch_with_pdl: bool, - out: Optional[torch.Tensor], + out: Optional[paddle.Tensor], ) -> None: module.trtllm_mnnvl_all_reduce( inp, @@ -95,12 +96,12 @@ def trtllm_mnnvl_all_reduce( ) def trtllm_mnnvl_rmsnorm( mcast_buffer_input: int, - prenorm_output: torch.Tensor, - normed_output: torch.Tensor, - gamma: torch.Tensor, + prenorm_output: paddle.Tensor, + normed_output: paddle.Tensor, + gamma: paddle.Tensor, epsilon: float, - residual: torch.Tensor, - buffer_flags: torch.Tensor, + residual: paddle.Tensor, + buffer_flags: paddle.Tensor, launch_with_pdl: bool, ) -> None: """Performs MNNVL TwoShot RMSNorm on the communication buffer. @@ -133,8 +134,8 @@ def trtllm_mnnvl_rmsnorm( def get_allreduce_mnnvl_workspace( - mapping: Mapping, dtype: torch.dtype -) -> Tuple[McastGPUBuffer, torch.Tensor, int]: + mapping: Mapping, dtype: paddle.dtype +) -> Tuple[McastGPUBuffer, paddle.Tensor, int]: """Get workspace buffers needed for multi-node NVLink all-reduce operation. 
This function allocates and initializes the workspace buffers required for performing @@ -157,59 +158,42 @@ def get_allreduce_mnnvl_workspace( - int: Maximum number of elements that can fit in buffer """ force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1" - - # buffer shape: [3, 2, buffer_tokens, hidden_dim] - stride = 3 * 2 * dtype.itemsize - # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 - # max_num_elements must be a multiple of 286720 + stride = 3 * 2 * dtype.element_size() lcm_hidden_dim = 286720 - TARGET_WORKSPACE_SIZE_BYTES = 12_000_000 + TARGET_WORKSPACE_SIZE_BYTES = 12000000 buffer_size_in_bytes = math.ceil( TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) ) * (lcm_hidden_dim * stride) max_num_elements = buffer_size_in_bytes // stride - mcast_buffer = McastGPUBuffer( buffer_size_in_bytes, mapping.tp_size, mapping.tp_rank, - torch.device("cuda", mapping.local_rank), + device2str("cuda", mapping.local_rank), mapping.is_multi_node() or force_mn, ) - - # Initialize the unicast buffer with -0.0 mcast_buffer.lamport_initialize(mapping.tp_rank, dtype) - - # CPU barrier since we assume this should not be called in cuda graph - torch.cuda.synchronize() + paddle.device.synchronize() mpi_barrier() - - # This is a buffer to maintain the state of this allreduce Op - # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter] - buffer_flags = torch.tensor( - [0, 2, max_num_elements, 0, 0], - dtype=torch.uint32, - device=torch.device("cuda", mapping.local_rank), - ) - - return ( - mcast_buffer, - buffer_flags, - max_num_elements, + buffer_flags = paddle.to_tensor( + data=[0, 2, max_num_elements, 0, 0], +>>>>>> dtype=torch.uint32, + place=device2str("gpu", mapping.local_rank), ) + return mcast_buffer, buffer_flags, max_num_elements def trtllm_mnnvl_all_reduce( - inp: torch.Tensor, - multicast_buffer_ptr: int, # Pointer address as integer - buffer_ptrs_dev: int, # Pointer address as integer + inp: paddle.Tensor, + multicast_buffer_ptr: int, + buffer_ptrs_dev: int, buffer_M: int, - buffer_flags_mnnvl: torch.Tensor, + buffer_flags_mnnvl: paddle.Tensor, nranks: int, rank: int, wait_for_results: bool, launch_with_pdl: bool, - out: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, ) -> None: """Perform a multi-node NVLink all-reduce operation across multiple GPUs. @@ -250,19 +234,19 @@ def trtllm_mnnvl_all_reduce( def trtllm_mnnvl_fused_allreduce_rmsnorm( - prenorm_output: torch.Tensor, - normed_output: torch.Tensor, - shard_input: torch.Tensor, - multicast_buffer_ptr: int, # Pointer address as integer - buffer_ptrs_dev: int, # Pointer address as integer - unicast_ptr: int, # Local unicast buffer pointer + prenorm_output: paddle.Tensor, + normed_output: paddle.Tensor, + shard_input: paddle.Tensor, + multicast_buffer_ptr: int, + buffer_ptrs_dev: int, + unicast_ptr: int, buffer_M: int, - buffer_flags_mnnvl: torch.Tensor, + buffer_flags_mnnvl: paddle.Tensor, nranks: int, rank: int, - gamma: torch.Tensor, + gamma: paddle.Tensor, epsilon: float, - residual: torch.Tensor, + residual: paddle.Tensor, launch_with_pdl: bool, ) -> None: """Performs MNNVL TwoShot Allreduce + RMSNorm. 
@@ -288,7 +272,6 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( launch_with_pdl: Whether to launch with PDL """ - # allreduce_result = Σ(shard_input across all ranks) trtllm_mnnvl_all_reduce( shard_input, multicast_buffer_ptr, @@ -297,14 +280,10 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( buffer_flags_mnnvl, nranks, rank, - False, # No need to wait to write AR results here as we are not writing them + False, launch_with_pdl, - None, # out parameter - None since wait_for_results=False + None, ) - - # prenorm_output = AllReduce(shard_input) + residual - # rms = sqrt(mean(prenorm_output²) + epsilon) - # normed_output = (prenorm_output / rms) * gamma get_trtllm_mnnvl_comm_module().trtllm_mnnvl_rmsnorm( unicast_ptr, prenorm_output, diff --git a/flashinfer/comm/vllm_ar.py b/flashinfer/comm/vllm_ar.py index 147a35cc19..ccd99ad024 100644 --- a/flashinfer/comm/vllm_ar.py +++ b/flashinfer/comm/vllm_ar.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,13 +15,10 @@ See the License for the specific language governing permissions and limitations under the License. """ - import functools from types import SimpleNamespace from typing import List, Tuple -import torch - from ..jit import JitSpec from ..jit import env as jit_env from ..jit import gen_jit_spec @@ -28,10 +27,7 @@ def gen_vllm_comm_module() -> JitSpec: return gen_jit_spec( - "vllm_comm", - [ - jit_env.FLASHINFER_CSRC_DIR / "vllm_custom_all_reduce.cu", - ], + "vllm_comm", [jit_env.FLASHINFER_CSRC_DIR / "vllm_custom_all_reduce.cu"] ) @@ -39,13 +35,12 @@ def gen_vllm_comm_module() -> JitSpec: def get_vllm_comm_module(): module = gen_vllm_comm_module().build_and_load() - # torch library for all @register_custom_op( "flashinfer::init_custom_ar", mutates_args=["ipc_ptrs", "rank_data", "rank", "full_nvlink"], ) def init_custom_ar( - ipc_ptrs: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool + ipc_ptrs: List[int], rank_data: paddle.Tensor, rank: int, full_nvlink: bool ) -> int: return module.init_custom_ar(ipc_ptrs, rank_data, rank, full_nvlink) @@ -61,11 +56,10 @@ def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: "flashinfer::register_buffer", mutates_args=["fa", "fake_ipc_ptrs"] ) def register_buffer(fa: int, fake_ipc_ptrs: List[int]) -> None: - return module.register_buffer(fa, fake_ipc_ptrs) + return module.register_buffer(name=fa, tensor=fake_ipc_ptrs) @register_custom_op( - "flashinfer::register_graph_buffers", - mutates_args=["fa", "handles", "offsets"], + "flashinfer::register_graph_buffers", mutates_args=["fa", "handles", "offsets"] ) def register_graph_buffers( fa: int, handles: List[List[int]], offsets: List[List[int]] @@ -82,8 +76,8 @@ def meta_size() -> int: ) def all_reduce( fa: int, - inp: torch.Tensor, - out: torch.Tensor, + inp: paddle.Tensor, + out: paddle.Tensor, reg_buffer: int, reg_buffer_sz_bytes: int, num_ctas: int, @@ -102,7 +96,7 @@ def all_reduce( def init_custom_ar( - ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool + ipc_tensors: List[int], rank_data: paddle.Tensor, rank: int, full_nvlink: bool ) -> int: return get_vllm_comm_module().init_custom_ar( ipc_tensors, rank_data, rank, full_nvlink @@ -115,8 +109,8 @@ def dispose(fa: int) -> None: def all_reduce( fa: int, - inp: torch.Tensor, - out: torch.Tensor, + inp: paddle.Tensor, + out: paddle.Tensor, reg_buffer: int, reg_buffer_sz_bytes: int, num_ctas: int, @@ -142,7 +136,7 @@ def get_graph_buffer_ipc_meta(fa) -> Tuple[List[int], List[int]]: def register_buffer(fa: int, 
fake_ipc_ptrs: List[int]) -> None: - return get_vllm_comm_module().register_buffer(fa, fake_ipc_ptrs) + return get_vllm_comm_module().register_buffer(name=fa, tensor=fake_ipc_ptrs) def register_graph_buffers( diff --git a/flashinfer/cuda_utils.py b/flashinfer/cuda_utils.py index 472d660739..22a6653295 100644 --- a/flashinfer/cuda_utils.py +++ b/flashinfer/cuda_utils.py @@ -13,21 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. """ - from flashinfer.utils import has_cuda_cudart -# Check if cuda.cudart module is available and import accordingly if has_cuda_cudart(): - # cuda-python <= 12.9 (has cuda.cudart) import cuda.bindings.driver as driver import cuda.bindings.runtime as runtime import cuda.cudart as cudart import cuda.nvrtc as nvrtc else: - # cuda-python >= 13.0 (no cuda.cudart, use runtime as cudart) from cuda.bindings import driver, nvrtc, runtime - cudart = runtime # Alias runtime as cudart for compatibility + cudart = runtime def _cudaGetErrorEnum(error): diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py index c2763d6c7d..7a46f85cfa 100644 --- a/flashinfer/cudnn/decode.py +++ b/flashinfer/cudnn/decode.py @@ -1,7 +1,11 @@ +import sys + + from enum import Enum from typing import Optional -import torch +import paddle +from flashinfer.paddle_utils import * from ..jit import get_cudnn_fmha_gen_module @@ -12,12 +16,10 @@ except ImportError: cudnn = None CUDNN_AVAILABLE = False - -# Global cudnn handle. need to make it per device in future _cudnn_handle = None -def _create_cudnn_handle(stream: torch.cuda.Stream): +def _create_cudnn_handle(stream: paddle.device.Stream): global _cudnn_handle if _cudnn_handle is None: _cudnn_handle = cudnn.create_handle() @@ -25,49 +27,38 @@ def _create_cudnn_handle(stream: torch.cuda.Stream): return _cudnn_handle -# Tensor ids class UIDs(Enum): RESERVED_INVALID_UID = 0 - - Q_UID = 1 # Query tensor - K_UID = 2 # Key cache tensor - V_UID = 3 # Value cache tensor - - ACTUAL_SEQ_LENS_Q_UID = 100 # Actual sequence lengths for query tensor - ACTUAL_SEQ_LENS_KV_UID = 101 # Actual sequence lengths for key/value tensor - - BLOCK_TABLES_UID = 200 # Block tables tensor - BLOCK_TABLES_K_UID = 201 # Block tables tensor for key - BLOCK_TABLES_V_UID = 202 # Block tables tensor for value - - RAGGED_Q_UID = 50 # Ragged query tensor - RAGGED_O_UID = 51 # Ragged output tensor - RAGGED_STATS_UID = 52 # Ragged stats tensor - - O_UID = 1000 # Output tensor - STATS_UID = 1001 # Stats tensor + Q_UID = 1 + K_UID = 2 + V_UID = 3 + ACTUAL_SEQ_LENS_Q_UID = 100 + ACTUAL_SEQ_LENS_KV_UID = 101 + BLOCK_TABLES_UID = 200 + BLOCK_TABLES_K_UID = 201 + BLOCK_TABLES_V_UID = 202 + RAGGED_Q_UID = 50 + RAGGED_O_UID = 51 + RAGGED_STATS_UID = 52 + O_UID = 1000 + STATS_UID = 1001 def _sdpa_decode_key_fn( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, + q: paddle.Tensor, + k_cache: paddle.Tensor, + v_cache: paddle.Tensor, scale: float, *, max_sequence_kv: int, block_size: Optional[int] = 1, - actual_seq_lens_q: Optional[torch.Tensor] = None, - actual_seq_lens_kv: Optional[torch.Tensor] = None, - block_tables: Optional[torch.Tensor] = None, - batch_offsets_q: Optional[torch.Tensor] = None, - batch_offsets_o: Optional[torch.Tensor] = None, + actual_seq_lens_q: Optional[paddle.Tensor] = None, + actual_seq_lens_kv: Optional[paddle.Tensor] = None, + block_tables: Optional[paddle.Tensor] = None, + batch_offsets_q: Optional[paddle.Tensor] = None, + batch_offsets_o: Optional[paddle.Tensor] = None, ): - return 
( - "decode", - max_sequence_kv, - tuple(q.shape), - tuple(k_cache.shape), - ) + return "decode", max_sequence_kv, tuple(tuple(q.shape)), tuple(tuple(k_cache.shape)) if CUDNN_AVAILABLE: @@ -75,44 +66,38 @@ def _sdpa_decode_key_fn( @cudnn.jit(heur_modes=[cudnn.heur_mode.A]) @cudnn.graph_cache(key_fn=_sdpa_decode_key_fn) def _build_decode_graph( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, + q: paddle.Tensor, + k_cache: paddle.Tensor, + v_cache: paddle.Tensor, scale: float, *, max_sequence_kv: int, block_size: Optional[int] = 1, - actual_seq_lens_q: Optional[torch.Tensor] = None, - actual_seq_lens_kv: Optional[torch.Tensor] = None, - block_tables: Optional[torch.Tensor] = None, - batch_offsets_q: Optional[torch.Tensor] = None, - batch_offsets_o: Optional[torch.Tensor] = None, + actual_seq_lens_q: Optional[paddle.Tensor] = None, + actual_seq_lens_kv: Optional[paddle.Tensor] = None, + block_tables: Optional[paddle.Tensor] = None, + batch_offsets_q: Optional[paddle.Tensor] = None, + batch_offsets_o: Optional[paddle.Tensor] = None, ): - handle = _create_cudnn_handle(torch.cuda.current_stream()) - - # WAR: override batch offsets for now, as it leads to a poor performance + handle = _create_cudnn_handle(paddle.device.current_stream()) batch_offsets_q = None batch_offsets_o = None - with cudnn.graph(handle) as (g, _): if q.dim() == 3: s_qo = 1 - b, h_qo, d_qk = q.shape[0], q.shape[1], q.shape[2] + b, h_qo, d_qk = tuple(q.shape)[0], tuple(q.shape)[1], tuple(q.shape)[2] elif q.dim() == 4: b, h_qo, s_qo, d_qk = ( - q.shape[0], - q.shape[1], - q.shape[2], - q.shape[3], + tuple(q.shape)[0], + tuple(q.shape)[1], + tuple(q.shape)[2], + tuple(q.shape)[3], ) else: raise ValueError(f"q must have 3 or 4 dimensions, got {q.dim()}") - assert s_qo == 1, "q must have a sequence length of 1" assert k_cache.dim() == 4, "k_cache must have 4 dimensions" - - d_vo = v_cache.shape[3] - + d_vo = tuple(v_cache.shape)[3] cudnn_q = g.tensor( name="q", dim=(b, h_qo, s_qo, d_qk), @@ -123,46 +108,38 @@ def _build_decode_graph( ragged_q = g.tensor_like(batch_offsets_q) ragged_q.set_uid(UIDs.RAGGED_Q_UID.value) cudnn_q.set_ragged_offset(ragged_q) - cudnn_k_cache = g.tensor_like(k_cache) cudnn_v_cache = g.tensor_like(v_cache) - cudnn_q.set_uid(UIDs.Q_UID.value) cudnn_k_cache.set_uid(UIDs.K_UID.value) cudnn_v_cache.set_uid(UIDs.V_UID.value) - if block_tables is not None: nd_block_tables = block_tables.reshape( - block_tables.shape[0], 1, block_tables.shape[1], 1 + tuple(block_tables.shape)[0], 1, tuple(block_tables.shape)[1], 1 ) cudnn_k_block_tables = g.tensor_like(nd_block_tables) cudnn_k_block_tables.set_uid(UIDs.BLOCK_TABLES_K_UID.value) - cudnn_v_block_tables = g.tensor_like(nd_block_tables) cudnn_v_block_tables.set_uid(UIDs.BLOCK_TABLES_V_UID.value) - if actual_seq_lens_q is not None: cudnn_actual_seq_lens_q = g.tensor_like(actual_seq_lens_q) cudnn_actual_seq_lens_q.set_uid(UIDs.ACTUAL_SEQ_LENS_Q_UID.value) - if actual_seq_lens_kv is not None: cudnn_actual_seq_lens_kv = g.tensor_like(actual_seq_lens_kv) cudnn_actual_seq_lens_kv.set_uid(UIDs.ACTUAL_SEQ_LENS_KV_UID.value) cudnn_actual_seq_lens_kv.set_is_pass_by_value(False) - padding_mask = actual_seq_lens_kv is not None - O, _ = g.sdpa( name="sdpa", q=cudnn_q, k=cudnn_k_cache, v=cudnn_v_cache, - seq_len_q=( - cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None - ), - seq_len_kv=( - cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None - ), + seq_len_q=cudnn_actual_seq_lens_q + if actual_seq_lens_q is not None + else None, 
+ seq_len_kv=cudnn_actual_seq_lens_kv + if actual_seq_lens_kv is not None + else None, use_padding_mask=padding_mask, is_inference=True, attn_scale=scale, @@ -171,46 +148,41 @@ def _build_decode_graph( paged_attention_max_seq_len_kv=max_sequence_kv, compute_data_type=cudnn.data_type.FLOAT, ) - if batch_offsets_o is not None: ragged_o = g.tensor_like(batch_offsets_o) ragged_o.set_uid(UIDs.RAGGED_O_UID.value) O.set_ragged_offset(ragged_o) - O.set_uid(UIDs.O_UID.value).set_output(True).set_dim( [b, h_qo, s_qo, d_vo] ).set_stride([d_vo * h_qo, d_vo, d_vo * h_qo, 1]).set_data_type( cudnn.data_type.BFLOAT16 ) - tensors_to_return = [cudnn_q, cudnn_k_cache, cudnn_v_cache, O] - if actual_seq_lens_q is not None: tensors_to_return.append(cudnn_actual_seq_lens_q) if actual_seq_lens_kv is not None: tensors_to_return.append(cudnn_actual_seq_lens_kv) - return g, tensors_to_return def _batch_decode_with_kv_cache( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, + q: paddle.Tensor, + k_cache: paddle.Tensor, + v_cache: paddle.Tensor, scale: float, - workspace_buffer: torch.Tensor, + workspace_buffer: paddle.Tensor, *, max_sequence_kv: int, - actual_seq_lens_q: Optional[torch.Tensor] = None, - actual_seq_lens_kv: Optional[torch.Tensor] = None, - block_tables: Optional[torch.Tensor] = None, + actual_seq_lens_q: Optional[paddle.Tensor] = None, + actual_seq_lens_kv: Optional[paddle.Tensor] = None, + block_tables: Optional[paddle.Tensor] = None, block_size: Optional[int] = 1, - batch_offsets_q: Optional[torch.Tensor] = None, - batch_offsets_o: Optional[torch.Tensor] = None, - batch_offsets_k: Optional[torch.Tensor] = None, - batch_offsets_v: Optional[torch.Tensor] = None, - out: torch.Tensor, -) -> torch.Tensor: + batch_offsets_q: Optional[paddle.Tensor] = None, + batch_offsets_o: Optional[paddle.Tensor] = None, + batch_offsets_k: Optional[paddle.Tensor] = None, + batch_offsets_v: Optional[paddle.Tensor] = None, + out: paddle.Tensor, +) -> paddle.Tensor: graph, tensors = _build_decode_graph( q=q, k_cache=k_cache, @@ -224,9 +196,7 @@ def _batch_decode_with_kv_cache( batch_offsets_q=batch_offsets_q if batch_offsets_q is not None else None, batch_offsets_o=batch_offsets_q if batch_offsets_q is not None else None, ) - - handle_ = _create_cudnn_handle(torch.cuda.current_stream()) - + handle_ = _create_cudnn_handle(paddle.device.current_stream()) var_map = { UIDs.Q_UID.value: q, UIDs.K_UID.value: k_cache, @@ -237,38 +207,34 @@ def _batch_decode_with_kv_cache( var_map[UIDs.ACTUAL_SEQ_LENS_Q_UID.value] = actual_seq_lens_q if actual_seq_lens_kv is not None: var_map[UIDs.ACTUAL_SEQ_LENS_KV_UID.value] = actual_seq_lens_kv - if batch_offsets_q is not None: var_map[UIDs.RAGGED_Q_UID.value] = batch_offsets_q if batch_offsets_o is not None: var_map[UIDs.RAGGED_O_UID.value] = batch_offsets_o - if block_tables is not None: var_map[UIDs.BLOCK_TABLES_K_UID.value] = block_tables var_map[UIDs.BLOCK_TABLES_V_UID.value] = block_tables - graph.execute(var_map, workspace=workspace_buffer, handle=handle_) - return out def cudnn_batch_decode_with_kv_cache( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, + q: paddle.Tensor, + k_cache: paddle.Tensor, + v_cache: paddle.Tensor, scale: float, - workspace_buffer: torch.Tensor, + workspace_buffer: paddle.Tensor, *, max_sequence_kv: int, - actual_seq_lens_kv: Optional[torch.Tensor] = None, - block_tables: Optional[torch.Tensor] = None, + actual_seq_lens_kv: Optional[paddle.Tensor] = None, + block_tables: Optional[paddle.Tensor] = None, 
is_cuda_graph_compatible: bool = False, - batch_offsets_q: Optional[torch.Tensor] = None, - batch_offsets_o: Optional[torch.Tensor] = None, - batch_offsets_k: Optional[torch.Tensor] = None, - batch_offsets_v: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, -) -> torch.Tensor: + batch_offsets_q: Optional[paddle.Tensor] = None, + batch_offsets_o: Optional[paddle.Tensor] = None, + batch_offsets_k: Optional[paddle.Tensor] = None, + batch_offsets_v: Optional[paddle.Tensor] = None, + out: Optional[paddle.Tensor] = None, +) -> paddle.Tensor: """Performs batched decode attention with paged KV cache using cuDNN. Args: @@ -296,17 +262,13 @@ def cudnn_batch_decode_with_kv_cache( All tensors must be contiguous and on the same CUDA device Query and KV heads can have different sizes (num_heads_qo >= num_heads_kv) """ - - bs = q.shape[0] - h_qo = q.shape[1] - d_vo = v_cache.shape[3] - + bs = tuple(q.shape)[0] + h_qo = tuple(q.shape)[1] + d_vo = tuple(v_cache.shape)[3] if out is None: - out = torch.empty(bs, h_qo, d_vo, device=q.device, dtype=q.dtype) - + out = paddle.empty(shape=[bs, h_qo, d_vo], dtype=q.dtype) if not CUDNN_AVAILABLE: - actual_seq_lens_kv_gpu = actual_seq_lens_kv.to(q.device, non_blocking=True) - + actual_seq_lens_kv_gpu = actual_seq_lens_kv.to(q.place, blocking=not True) run_func = get_cudnn_fmha_gen_module().decode run_func( max_sequence_kv, @@ -324,11 +286,8 @@ def cudnn_batch_decode_with_kv_cache( is_cuda_graph_compatible, ) else: - actual_seq_lens_q = torch.ones( - (bs, 1, 1, 1), device=q.device, dtype=torch.int32 - ) - block_size = k_cache.shape[2] - + actual_seq_lens_q = paddle.ones(shape=(bs, 1, 1, 1), dtype="int32") + block_size = tuple(k_cache.shape)[2] _batch_decode_with_kv_cache( q=q, k_cache=k_cache, @@ -344,5 +303,4 @@ def cudnn_batch_decode_with_kv_cache( block_size=block_size, out=out, ) - return out diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py index 5c37564cc8..d6887c9d3e 100644 --- a/flashinfer/cudnn/prefill.py +++ b/flashinfer/cudnn/prefill.py @@ -1,7 +1,11 @@ +import sys + + from enum import Enum from typing import Optional -import torch +import paddle +from flashinfer.paddle_utils import * from ..jit import get_cudnn_fmha_gen_module @@ -12,12 +16,10 @@ except Exception: cudnn = None CUDNN_AVAILABLE = False - -# Global cudnn handle. 
need to make it per device in future _cudnn_handle = None -def _create_cudnn_handle(stream: torch.cuda.Stream): +def _create_cudnn_handle(stream: paddle.device.Stream): global _cudnn_handle if _cudnn_handle is None: _cudnn_handle = cudnn.create_handle() @@ -25,68 +27,58 @@ def _create_cudnn_handle(stream: torch.cuda.Stream): return _cudnn_handle -# Tensor ids class UIDs(Enum): RESERVED_INVALID_UID = 0 - - Q_UID = 1 # Query tensor - K_UID = 2 # Key cache tensor - V_UID = 3 # Value cache tensor - - ACTUAL_SEQ_LENS_Q_UID = 100 # Actual sequence lengths for query tensor - ACTUAL_SEQ_LENS_KV_UID = 101 # Actual sequence lengths for key/value tensor - - BLOCK_TABLES_UID = 200 # Block tables tensor - BLOCK_TABLES_K_UID = 201 # Block tables tensor for key - BLOCK_TABLES_V_UID = 202 # Block tables tensor for value - - RAGGED_Q_UID = 50 # Ragged query tensor - RAGGED_O_UID = 51 # Ragged output tensor - RAGGED_STATS_UID = 52 # Ragged stats tensor - RAGGED_K_UID = 53 # Ragged key tensor - RAGGED_V_UID = 54 # Ragged value tensor - - O_UID = 1000 # Output tensor - STATS_UID = 1001 # Stats tensor + Q_UID = 1 + K_UID = 2 + V_UID = 3 + ACTUAL_SEQ_LENS_Q_UID = 100 + ACTUAL_SEQ_LENS_KV_UID = 101 + BLOCK_TABLES_UID = 200 + BLOCK_TABLES_K_UID = 201 + BLOCK_TABLES_V_UID = 202 + RAGGED_Q_UID = 50 + RAGGED_O_UID = 51 + RAGGED_STATS_UID = 52 + RAGGED_K_UID = 53 + RAGGED_V_UID = 54 + O_UID = 1000 + STATS_UID = 1001 def _sdpa_prefill_key_fn( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, + q: paddle.Tensor, + k_cache: paddle.Tensor, + v_cache: paddle.Tensor, scale: float, *, max_token_seq_q: Optional[int] = None, max_sequence_kv: Optional[int] = None, - actual_seq_lens_q: Optional[torch.Tensor] = None, - actual_seq_lens_kv: torch.Tensor, - block_tables: Optional[torch.Tensor] = None, + actual_seq_lens_q: Optional[paddle.Tensor] = None, + actual_seq_lens_kv: paddle.Tensor, + block_tables: Optional[paddle.Tensor] = None, page_size: Optional[int] = None, bottom_right_causal_mask: Optional[bool] = None, return_lse: Optional[bool] = False, - batch_offsets_q: Optional[torch.Tensor] = None, - batch_offsets_o: Optional[torch.Tensor] = None, - batch_offsets_k: Optional[torch.Tensor] = None, - batch_offsets_v: Optional[torch.Tensor] = None, - batch_offsets_stats: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + batch_offsets_q: Optional[paddle.Tensor] = None, + batch_offsets_o: Optional[paddle.Tensor] = None, + batch_offsets_k: Optional[paddle.Tensor] = None, + batch_offsets_v: Optional[paddle.Tensor] = None, + batch_offsets_stats: Optional[paddle.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, ): - graph_b = actual_seq_lens_q.shape[0] - + graph_b = tuple(actual_seq_lens_q.shape)[0] if q.dim() == 3: - h_qo, d_qk = q.shape[1], q.shape[2] + h_qo, d_qk = tuple(q.shape)[1], tuple(q.shape)[2] elif q.dim() == 4: - h_qo, d_qk = q.shape[1], q.shape[3] - + h_qo, d_qk = tuple(q.shape)[1], tuple(q.shape)[3] if v_cache.dim() == 3: - h_kv, d_vo = k_cache.shape[1], k_cache.shape[2] + h_kv, d_vo = tuple(k_cache.shape)[1], tuple(k_cache.shape)[2] elif k_cache.dim() == 4: - h_kv, d_vo = k_cache.shape[1], k_cache.shape[3] - + h_kv, d_vo = tuple(k_cache.shape)[1], tuple(k_cache.shape)[3] if block_tables is not None: - page_size = k_cache.shape[2] - + page_size = tuple(k_cache.shape)[2] key = ( graph_b, q.dim(), @@ -110,66 +102,60 @@ def _sdpa_prefill_key_fn( @cudnn.jit(heur_modes=[cudnn.heur_mode.A]) 
@cudnn.graph_cache(key_fn=_sdpa_prefill_key_fn) def _build_prefill_graph( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, + q: paddle.Tensor, + k_cache: paddle.Tensor, + v_cache: paddle.Tensor, scale: float, *, max_token_seq_q: Optional[int] = None, max_sequence_kv: Optional[int] = None, - actual_seq_lens_q: Optional[torch.Tensor] = None, - actual_seq_lens_kv: Optional[torch.Tensor] = None, - block_tables: Optional[torch.Tensor] = None, + actual_seq_lens_q: Optional[paddle.Tensor] = None, + actual_seq_lens_kv: Optional[paddle.Tensor] = None, + block_tables: Optional[paddle.Tensor] = None, bottom_right_causal_mask: Optional[bool] = True, return_lse: Optional[bool] = False, - batch_offsets_q: Optional[torch.Tensor] = None, - batch_offsets_o: Optional[torch.Tensor] = None, - batch_offsets_k: Optional[torch.Tensor] = None, - batch_offsets_v: Optional[torch.Tensor] = None, - batch_offsets_stats: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + batch_offsets_q: Optional[paddle.Tensor] = None, + batch_offsets_o: Optional[paddle.Tensor] = None, + batch_offsets_k: Optional[paddle.Tensor] = None, + batch_offsets_v: Optional[paddle.Tensor] = None, + batch_offsets_stats: Optional[paddle.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, ): - handle = _create_cudnn_handle(torch.cuda.current_stream(q.device)) - - graph_b = actual_seq_lens_q.shape[0] + handle = _create_cudnn_handle( + paddle.device.current_stream(device=device2str(q.place)) + ) + graph_b = tuple(actual_seq_lens_q.shape)[0] graph_s_qo = max_token_seq_q graph_s_kv = max_sequence_kv - with cudnn.graph(handle) as (g, _): - # Create tensors from the input tensors if q.dim() == 3: - h_qo, d_qk = q.shape[1], q.shape[2] + h_qo, d_qk = tuple(q.shape)[1], tuple(q.shape)[2] elif q.dim() == 4: - h_qo, d_qk = q.shape[2], q.shape[3] + h_qo, d_qk = tuple(q.shape)[2], tuple(q.shape)[3] else: - raise ValueError(f"Invalid query tensor shape: {q.shape}") - + raise ValueError(f"Invalid query tensor shape: {tuple(q.shape)}") cudnn_q = g.tensor( name="q", dim=(graph_b, h_qo, graph_s_qo, d_qk), stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1), data_type=cudnn.data_type.BFLOAT16, ) - if batch_offsets_q is not None: ragged_q = g.tensor_like(batch_offsets_q) ragged_q.set_uid(UIDs.RAGGED_Q_UID.value) cudnn_q.set_ragged_offset(ragged_q) - if v_cache.dim() == 3: - assert block_tables is None, ( - "block_tables needs 4 dimensions of kv cache" - ) - h_kv, d_vo = v_cache.shape[1], v_cache.shape[2] + assert ( + block_tables is None + ), "block_tables needs 4 dimensions of kv cache" + h_kv, d_vo = tuple(v_cache.shape)[1], tuple(v_cache.shape)[2] elif v_cache.dim() == 4: - h_kv, d_vo = ( - v_cache.shape[1], - v_cache.shape[3], - ) + h_kv, d_vo = tuple(v_cache.shape)[1], tuple(v_cache.shape)[3] else: - raise ValueError(f"Invalid kv cache tensor shape: {k_cache.shape}") - + raise ValueError( + f"Invalid kv cache tensor shape: {tuple(k_cache.shape)}" + ) if k_cache.dim() == 3: cudnn_k_cache = g.tensor( name="k_cache", @@ -177,151 +163,136 @@ def _build_prefill_graph( stride=(h_kv * d_qk * graph_s_kv, d_qk, d_qk * h_kv, 1), data_type=cudnn.data_type.BFLOAT16, ) - if batch_offsets_k is not None: ragged_k = g.tensor_like(batch_offsets_k) ragged_k.set_uid(UIDs.RAGGED_K_UID.value) cudnn_k_cache.set_ragged_offset(ragged_k) - cudnn_v_cache = g.tensor( name="v_cache", dim=(graph_b, h_kv, graph_s_kv, d_vo), stride=(h_kv * d_vo * graph_s_kv, d_vo, d_vo * h_kv, 1), 
data_type=cudnn.data_type.BFLOAT16, ) - if batch_offsets_v is not None: ragged_v = g.tensor_like(batch_offsets_v) ragged_v.set_uid(UIDs.RAGGED_V_UID.value) cudnn_v_cache.set_ragged_offset(ragged_v) - elif k_cache.dim() == 4: cudnn_k_cache = g.tensor( name="k_cache", - dim=k_cache.shape, - stride=k_cache.stride(), + dim=tuple(k_cache.shape), + stride=k_cache.get_strides(), data_type=cudnn.data_type.BFLOAT16, ) - cudnn_v_cache = g.tensor( name="v_cache", - dim=v_cache.shape, - stride=v_cache.stride(), + dim=tuple(v_cache.shape), + stride=v_cache.get_strides(), data_type=cudnn.data_type.BFLOAT16, ) - cudnn_q.set_uid(UIDs.Q_UID.value) cudnn_k_cache.set_uid(UIDs.K_UID.value) cudnn_v_cache.set_uid(UIDs.V_UID.value) - if block_tables is not None: nd_block_tables = block_tables.reshape( - block_tables.shape[0], 1, block_tables.shape[1], 1 + tuple(block_tables.shape)[0], 1, tuple(block_tables.shape)[1], 1 ) cudnn_k_block_tables = g.tensor_like(nd_block_tables) cudnn_k_block_tables.set_uid(UIDs.BLOCK_TABLES_K_UID.value) - cudnn_v_block_tables = g.tensor_like(nd_block_tables) cudnn_v_block_tables.set_uid(UIDs.BLOCK_TABLES_V_UID.value) - if actual_seq_lens_q is not None: cudnn_actual_seq_lens_q = g.tensor_like(actual_seq_lens_q) cudnn_actual_seq_lens_q.set_name("actual_seq_lens_q") cudnn_actual_seq_lens_q.set_uid(UIDs.ACTUAL_SEQ_LENS_Q_UID.value) - if actual_seq_lens_kv is not None: cudnn_actual_seq_lens_kv = g.tensor_like(actual_seq_lens_kv) cudnn_actual_seq_lens_kv.set_name("actual_seq_lens_kv") cudnn_actual_seq_lens_kv.set_uid(UIDs.ACTUAL_SEQ_LENS_KV_UID.value) - padding_mask = ( actual_seq_lens_q is not None and actual_seq_lens_kv is not None ) - O, Stats = g.sdpa( name="sdpa", q=cudnn_q, k=cudnn_k_cache, v=cudnn_v_cache, - seq_len_q=( - cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None - ), - seq_len_kv=( - cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None - ), + seq_len_q=cudnn_actual_seq_lens_q + if actual_seq_lens_q is not None + else None, + seq_len_kv=cudnn_actual_seq_lens_kv + if actual_seq_lens_kv is not None + else None, use_padding_mask=padding_mask, attn_scale=scale, generate_stats=return_lse, use_causal_mask_bottom_right=bottom_right_causal_mask, - paged_attention_k_table=( - cudnn_k_block_tables if block_tables is not None else None - ), - paged_attention_v_table=( - cudnn_v_block_tables if block_tables is not None else None - ), - paged_attention_max_seq_len_kv=( - graph_s_kv if block_tables is not None else None - ), + paged_attention_k_table=cudnn_k_block_tables + if block_tables is not None + else None, + paged_attention_v_table=cudnn_v_block_tables + if block_tables is not None + else None, + paged_attention_max_seq_len_kv=graph_s_kv + if block_tables is not None + else None, compute_data_type=cudnn.data_type.FLOAT, ) - if batch_offsets_o is not None: ragged_o = g.tensor_like(batch_offsets_o) ragged_o.set_uid(UIDs.RAGGED_O_UID.value) O.set_ragged_offset(ragged_o) - if batch_offsets_stats is not None: ragged_stats = g.tensor_like(batch_offsets_stats) ragged_stats.set_uid(UIDs.RAGGED_STATS_UID.value) Stats.set_ragged_offset(ragged_stats) - O.set_uid(UIDs.O_UID.value).set_output(True).set_dim( [graph_b, h_qo, graph_s_qo, d_vo] ).set_stride( [graph_s_qo * d_vo * h_qo, d_vo, d_vo * h_qo, 1] - ).set_data_type(cudnn.data_type.BFLOAT16) - + ).set_data_type( + cudnn.data_type.BFLOAT16 + ) if return_lse: Stats.set_uid(UIDs.STATS_UID.value).set_output( return_lse ).set_data_type(cudnn.data_type.FLOAT).set_dim( [graph_b, h_qo, graph_s_qo, 1] - 
).set_stride([graph_s_qo * h_qo, 1, h_qo, 1]) - + ).set_stride( + [graph_s_qo * h_qo, 1, h_qo, 1] + ) tensors_to_return = [cudnn_q, cudnn_k_cache, cudnn_v_cache, O] if return_lse: tensors_to_return.append(Stats) - if actual_seq_lens_q is not None: tensors_to_return.append(cudnn_actual_seq_lens_q) if actual_seq_lens_kv is not None: tensors_to_return.append(cudnn_actual_seq_lens_kv) - return g, tensors_to_return def _batch_prefill_with_kv_cache( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, + q: paddle.Tensor, + k_cache: paddle.Tensor, + v_cache: paddle.Tensor, scale: float, - workspace_buffer: torch.Tensor, + workspace_buffer: paddle.Tensor, *, max_token_per_sequence: int, max_sequence_kv: int, - actual_seq_lens_q: torch.Tensor, - actual_seq_lens_kv: torch.Tensor, - block_tables: Optional[torch.Tensor] = None, + actual_seq_lens_q: paddle.Tensor, + actual_seq_lens_kv: paddle.Tensor, + block_tables: Optional[paddle.Tensor] = None, causal: bool, return_lse: bool, - batch_offsets_q: Optional[torch.Tensor] = None, - batch_offsets_o: Optional[torch.Tensor] = None, - batch_offsets_k: Optional[torch.Tensor] = None, - batch_offsets_v: Optional[torch.Tensor] = None, - batch_offsets_stats: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor]: + batch_offsets_q: Optional[paddle.Tensor] = None, + batch_offsets_o: Optional[paddle.Tensor] = None, + batch_offsets_k: Optional[paddle.Tensor] = None, + batch_offsets_v: Optional[paddle.Tensor] = None, + batch_offsets_stats: Optional[paddle.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, +) -> tuple[paddle.Tensor, paddle.Tensor]: graph, tensors = _build_prefill_graph( q=q, k_cache=k_cache, @@ -342,41 +313,35 @@ def _batch_prefill_with_kv_cache( out=out, lse=lse, ) - var_map = { UIDs.Q_UID.value: q, UIDs.K_UID.value: k_cache, UIDs.V_UID.value: v_cache, UIDs.O_UID.value: out, } - if actual_seq_lens_q is not None: var_map[UIDs.ACTUAL_SEQ_LENS_Q_UID.value] = actual_seq_lens_q if actual_seq_lens_kv is not None: var_map[UIDs.ACTUAL_SEQ_LENS_KV_UID.value] = actual_seq_lens_kv - if batch_offsets_q is not None: var_map[UIDs.RAGGED_Q_UID.value] = batch_offsets_q if batch_offsets_o is not None: var_map[UIDs.RAGGED_O_UID.value] = batch_offsets_o - if batch_offsets_k is not None: var_map[UIDs.RAGGED_K_UID.value] = batch_offsets_k if batch_offsets_v is not None: var_map[UIDs.RAGGED_V_UID.value] = batch_offsets_v - if block_tables is not None: var_map[UIDs.BLOCK_TABLES_K_UID.value] = block_tables var_map[UIDs.BLOCK_TABLES_V_UID.value] = block_tables - if return_lse: var_map[UIDs.STATS_UID.value] = lse if batch_offsets_stats is not None: var_map[UIDs.RAGGED_STATS_UID.value] = batch_offsets_stats - - handle = _create_cudnn_handle(torch.cuda.current_stream(q.device)) + handle = _create_cudnn_handle( + paddle.device.current_stream(device=device2str(q.place)) + ) graph.execute(var_map, workspace=workspace_buffer, handle=handle) - if return_lse: return out, lse else: @@ -384,29 +349,29 @@ def _batch_prefill_with_kv_cache( def cudnn_batch_prefill_with_kv_cache( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, + q: paddle.Tensor, + k_cache: paddle.Tensor, + v_cache: paddle.Tensor, scale: float, - workspace_buffer: torch.Tensor, + workspace_buffer: paddle.Tensor, *, max_token_per_sequence: int, max_sequence_kv: int, - actual_seq_lens_q: torch.Tensor, - actual_seq_lens_kv: torch.Tensor, - block_tables: 
Optional[torch.Tensor] = None, + actual_seq_lens_q: paddle.Tensor, + actual_seq_lens_kv: paddle.Tensor, + block_tables: Optional[paddle.Tensor] = None, causal: bool, return_lse: bool, - batch_offsets_q: Optional[torch.Tensor] = None, - batch_offsets_o: Optional[torch.Tensor] = None, - batch_offsets_k: Optional[torch.Tensor] = None, - batch_offsets_v: Optional[torch.Tensor] = None, - batch_offsets_stats: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + batch_offsets_q: Optional[paddle.Tensor] = None, + batch_offsets_o: Optional[paddle.Tensor] = None, + batch_offsets_k: Optional[paddle.Tensor] = None, + batch_offsets_v: Optional[paddle.Tensor] = None, + batch_offsets_stats: Optional[paddle.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, is_cuda_graph_compatible: bool = False, backend: Optional[str] = None, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[paddle.Tensor, Optional[paddle.Tensor]]: """Performs batched prefill attention with paged KV cache using cuDNN. Args: @@ -440,40 +405,32 @@ def cudnn_batch_prefill_with_kv_cache( Head dimension of query and key must be 128 or 192 Head dimension of value and output must be 128 """ - - num_tokens = q.shape[0] - - num_sequences = actual_seq_lens_q.shape[0] - + num_tokens = tuple(q.shape)[0] + num_sequences = tuple(actual_seq_lens_q.shape)[0] if q.dim() == 3: - h_qo, d_qk = q.shape[1], q.shape[2] + h_qo, d_qk = tuple(q.shape)[1], tuple(q.shape)[2] elif q.dim() == 4: - h_qo, d_qk = q.shape[1], q.shape[3] - + h_qo, d_qk = tuple(q.shape)[1], tuple(q.shape)[3] if v_cache.dim() == 3: - d_vo = v_cache.shape[2] + d_vo = tuple(v_cache.shape)[2] elif v_cache.dim() == 4: - d_vo = v_cache.shape[3] - + d_vo = tuple(v_cache.shape)[3] if return_lse: if lse is None: - lse = torch.empty( - num_sequences, - max_token_per_sequence, - h_qo, - device=q.device, - dtype=torch.float32, + lse = paddle.empty( + shape=[num_sequences, max_token_per_sequence, h_qo], dtype="float32" ) - - if lse is not None and lse.shape != (num_sequences, max_token_per_sequence, h_qo): + if lse is not None and tuple(lse.shape) != ( + num_sequences, + max_token_per_sequence, + h_qo, + ): raise ValueError( "lse must have shape (num_sequences, max_token_per_sequence, h_qo)" ) - if out is None: - out_shape = (num_tokens, h_qo, d_vo) - out = torch.empty(out_shape, device=q.device, dtype=q.dtype) - + out_shape = num_tokens, h_qo, d_vo + out = paddle.empty(shape=out_shape, dtype=q.dtype) if CUDNN_AVAILABLE and backend != "cubin": return _batch_prefill_with_kv_cache( q=q, @@ -498,32 +455,28 @@ def cudnn_batch_prefill_with_kv_cache( ) else: assert return_lse, "Currently only supports return_lse = True" - - assert (d_qk == 192 and block_tables is None) or ( - d_qk == 128 and block_tables is not None - ), ( - "Currently only supports if d_qk = 192 and block_tables is None or d_qk = 128 and block_tables is not None" - ) - + assert ( + d_qk == 192 + and block_tables is None + or d_qk == 128 + and block_tables is not None + ), "Currently only supports if d_qk = 192 and block_tables is None or d_qk = 128 and block_tables is not None" if max_sequence_kv is None: max_sequence_kv = max_token_per_sequence - - actual_seq_lens_q_gpu = actual_seq_lens_q.to(q.device, non_blocking=True) - - actual_seq_lens_kv_gpu = actual_seq_lens_kv.to(q.device, non_blocking=True) - + actual_seq_lens_q_gpu = actual_seq_lens_q.to(q.place, blocking=not True) + actual_seq_lens_kv_gpu = 
actual_seq_lens_kv.to(q.place, blocking=not True) run_func = get_cudnn_fmha_gen_module().prefill run_func( num_sequences, - max_token_per_sequence, # max_s_qo - max_sequence_kv, # max_s_kv + max_token_per_sequence, + max_sequence_kv, q, k_cache, v_cache, scale, workspace_buffer, - actual_seq_lens_q, # actual_seq_lens_q - actual_seq_lens_kv, # actual_seq_lens_kv + actual_seq_lens_q, + actual_seq_lens_kv, actual_seq_lens_q_gpu, actual_seq_lens_kv_gpu, block_tables, @@ -537,5 +490,4 @@ def cudnn_batch_prefill_with_kv_cache( None, is_cuda_graph_compatible, ) - return out, lse diff --git a/flashinfer/cute_dsl/blockscaled_gemm.py b/flashinfer/cute_dsl/blockscaled_gemm.py index f9c8bf6bc0..30489aa22f 100644 --- a/flashinfer/cute_dsl/blockscaled_gemm.py +++ b/flashinfer/cute_dsl/blockscaled_gemm.py @@ -1,32 +1,8 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause +import sys -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. - -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. - -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -from typing import Optional, Tuple, Type, Union +import functools +from typing import Callable, List, Optional, Tuple, Type, Union import cuda.bindings.driver as cuda import cutlass @@ -36,21 +12,16 @@ import cutlass.utils as utils import cutlass.utils.blackwell_helpers as sm100_utils import cutlass.utils.blockscaled_layout as blockscaled_utils -import torch -import functools +import paddle from cutlass._mlir import ir from cutlass.cute.nvgpu import cpasync, tcgen05 from cutlass.cute.runtime import from_dlpack, make_ptr -from cutlass.cutlass_dsl import ( - Int32, - Integer, - dsl_user_op, - extract_mlir_values, - new_from_mlir_values, -) +from cutlass.cutlass_dsl import (Int32, Integer, dsl_user_op, + extract_mlir_values, new_from_mlir_values) from cutlass.utils.static_persistent_tile_scheduler import WorkTileInfo -from .utils import get_cutlass_dtype, cutlass_to_torch_dtype, get_num_sm -from typing import Callable, List +from flashinfer.paddle_utils import * + +from .utils import cutlass_to_torch_dtype, get_cutlass_dtype, get_num_sm class MaskedSchedulerParams: @@ -66,18 +37,15 @@ def __init__( ): if cluster_shape_mnk[2] != 1: raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") - gc = cute.zipped_divide(c, tiler=c_tiler) - problem_shape_ntile_mnl = gc[(0, (None, None, None))].shape + problem_shape_ntile_mnl = tuple(gc[0, (None, None, None)].shape) self.masked_m = masked_m self.c = c self.c_tiler = c_tiler self.problem_shape_ntile_mnl = problem_shape_ntile_mnl - # cluster_shape_mnk is kept for reconstruction self._cluster_shape_mnk = cluster_shape_mnk self.cluster_shape_mn = cluster_shape_mnk[:2] self._loc = loc - self.problem_layout_ncluster_mnl = cute.make_layout( cute.ceil_div( self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip @@ -88,12 +56,7 @@ def __init__( def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [ - self.masked_m, - self.c, - self.c_tiler, - self._cluster_shape_mnk, - ]: + for obj in [self.masked_m, self.c, self.c_tiler, self._cluster_shape_mnk]: obj_values = extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -107,15 +70,14 @@ def __new_from_mlir_values__(self, values): ): obj_list.append(new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return MaskedSchedulerParams(*(tuple(obj_list)), loc=self._loc) + return MaskedSchedulerParams(*tuple(obj_list), loc=self._loc) @dsl_user_op def get_grid_shape( self, max_active_clusters: Int32, *, loc=None, ip=None ) -> Tuple[Integer, Integer, Integer]: num_persistent_clusters = max_active_clusters - - return (*self.cluster_shape_mn, num_persistent_clusters) + return *self.cluster_shape_mn, num_persistent_clusters class MaskedScheduler: @@ -174,7 +136,6 @@ def __new_from_mlir_values__(self, values: list[ir.Value]) -> "MaskedScheduler": new_num_tiles_executed, ) - # called by host @dsl_user_op @staticmethod def create( @@ -186,27 +147,18 @@ def create( ip=None, ): params = params - - # Calculate the number of persistent clusters by dividing the total grid size - # by the number of CTAs per cluster num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size( params.cluster_shape_mn, loc=loc, ip=ip ) - bidx, bidy, bidz = block_idx - - # Initialize workload index equals to the cluster index in the grid current_work_linear_idx = Int32(bidz) current_batch_idx = Int32(0) accum_tile_m = Int32(0) - - # CTA id in the cluster cta_id_in_cluster = ( Int32(bidx % params.cluster_shape_mn[0]), Int32(bidy % 
params.cluster_shape_mn[1]), Int32(0), ) - # Initialize number of tiles executed to zero num_tiles_executed = Int32(0) return MaskedScheduler( params, @@ -218,48 +170,34 @@ def create( num_tiles_executed, ) - # called by host @staticmethod def get_grid_shape( - params: MaskedSchedulerParams, - max_active_clusters: Int32, - *, - loc=None, - ip=None, + params: MaskedSchedulerParams, max_active_clusters: Int32, *, loc=None, ip=None ) -> Tuple[Integer, Integer, Integer]: return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip) - # private method @cute.jit def _get_current_work_for_linear_idx( - self, - current_work_linear_idx: Int32, + self, current_work_linear_idx: Int32 ) -> WorkTileInfo: - # is_valid = current_work_linear_idx < cute.size( - # self.params.problem_layout_ncluster_mnl, loc=loc, ip=ip - # ) num_tiles_n = self.params.problem_shape_ntile_mnl[1] accum_tile_m = self._accum_tile_m batch_idx = self._current_batch_idx - while ( - ( - accum_tile_m - + cute.ceil_div(self.params.masked_m[batch_idx], self.params.c_tiler[0]) - ) - * num_tiles_n - <= current_work_linear_idx - and batch_idx < self.params.masked_m.shape[0] - ): + accum_tile_m + + cute.ceil_div(self.params.masked_m[batch_idx], self.params.c_tiler[0]) + ) * num_tiles_n <= current_work_linear_idx and batch_idx < tuple( + self.params.masked_m.shape + )[ + 0 + ]: accum_tile_m += cute.ceil_div( self.params.masked_m[batch_idx], self.params.c_tiler[0] ) batch_idx += Int32(1) - self._accum_tile_m = accum_tile_m self._current_batch_idx = batch_idx - - is_valid = self._current_batch_idx < self.params.masked_m.shape[0] + is_valid = self._current_batch_idx < tuple(self.params.masked_m.shape)[0] if is_valid: is_valid = ( self._accum_tile_m @@ -268,17 +206,11 @@ def _get_current_work_for_linear_idx( self.params.c_tiler[0], ) ) * num_tiles_n > current_work_linear_idx - - # cur_cluster_coord = self.params.problem_layout_ncluster_mnl.get_hier_coord( - # current_work_linear_idx, loc=loc, ip=ip - # ) cur_cluster_coord = ( current_work_linear_idx // num_tiles_n - self._accum_tile_m, current_work_linear_idx % num_tiles_n, self._current_batch_idx, ) - - # cur_tile_coord is a tuple of i32 values cur_tile_coord = tuple( Int32(x) * Int32(z) + Int32(y) for x, y, z in zip( @@ -287,14 +219,11 @@ def _get_current_work_for_linear_idx( (*self.params.cluster_shape_mn, Int32(1)), ) ) - return WorkTileInfo(cur_tile_coord, is_valid) @dsl_user_op def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: - return self._get_current_work_for_linear_idx( - self._current_work_linear_idx, - ) + return self._get_current_work_for_linear_idx(self._current_work_linear_idx) @dsl_user_op def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: @@ -353,22 +282,13 @@ def num_tiles_executed(self) -> Int32: .. code-block:: bash - python examples/blackwell/dense_blockscaled_gemm_persistent.py \ - --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ - --c_dtype Float16 \ - --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ - --mnkl 8192,8192,1024,1 + python examples/blackwell/dense_blockscaled_gemm_persistent.py --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 --c_dtype Float16 --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 --mnkl 8192,8192,1024,1 To collect performance with NCU profiler: .. 
code-block:: bash - ncu python examples/blackwell/dense_blockscaled_gemm_persistent.py \ - --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ - --c_dtype Float16 \ - --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ - --mnkl 8192,8192,1024,1 \ - --warmup_iterations 1 --iterations 10 --skip_ref_check + ncu python examples/blackwell/dense_blockscaled_gemm_persistent.py --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 --c_dtype Float16 --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 --mnkl 8192,8192,1024,1 --warmup_iterations 1 --iterations 10 --skip_ref_check Constraints: @@ -451,32 +371,21 @@ def __init__( :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. :type cluster_shape_mn: Tuple[int, int] """ - self.acc_dtype = cutlass.Float32 self.sf_vec_size = sf_vec_size self.use_2cta_instrs = mma_tiler_mn[0] == 256 self.cluster_shape_mn = cluster_shape_mn - # K dimension is deferred in _setup_attributes - self.mma_tiler = (*mma_tiler_mn, 1) - + self.mma_tiler = *mma_tiler_mn, 1 self.cta_group = ( tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE ) - self.occupancy = 1 - # Set specialized warp ids - self.epilog_warp_id = ( - 0, - 1, - 2, - 3, - ) + self.epilog_warp_id = 0, 1, 2, 3 self.mma_warp_id = 4 self.tma_warp_id = 5 self.threads_per_cta = 32 * len( (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) ) - # Set barrier id for cta sync, epilogue sync and tmem ptr sync self.cta_sync_bar_id = 0 self.epilog_sync_bar_id = 1 self.tmem_ptr_sync_bar_id = 2 @@ -498,21 +407,17 @@ def _setup_attributes(self): - Computing A/B/SFA/SFB/C shared memory layout - Computing tensor memory allocation columns """ - # Compute mma instruction shapes mma_inst_bits_k = 256 - # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) self.mma_inst_shape_mnk = ( self.mma_tiler[0], self.mma_tiler[1], mma_inst_bits_k // self.a_dtype.width, ) - # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) self.mma_inst_shape_mnk_sfb = ( self.mma_inst_shape_mnk[0] // (2 if self.use_2cta_instrs else 1), cute.round_up(self.mma_inst_shape_mnk[1], 128), self.mma_inst_shape_mnk[2], ) - tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( self.a_dtype, self.a_major_mode, @@ -522,7 +427,6 @@ def _setup_attributes(self): self.cta_group, self.mma_inst_shape_mnk[:2], ) - tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( self.a_dtype, self.a_major_mode, @@ -532,8 +436,6 @@ def _setup_attributes(self): cute.nvgpu.tcgen05.CtaGroup.ONE, self.mma_inst_shape_mnk_sfb[:2], ) - - # Compute mma/cluster/tile shapes mma_inst_tile_k = 4 self.mma_tiler = ( self.mma_inst_shape_mnk[0], @@ -550,34 +452,23 @@ def _setup_attributes(self): self.mma_tiler[1], self.mma_tiler[2], ) - - # Compute cluster layout self.cluster_layout_vmnk = cute.tiled_divide( - cute.make_layout((*self.cluster_shape_mn, 1)), - (tiled_mma.thr_id.shape,), + cute.make_layout((*self.cluster_shape_mn, 1)), (tiled_mma.thr_id.shape,) ) self.cluster_layout_sfb_vmnk = cute.tiled_divide( - cute.make_layout((*self.cluster_shape_mn, 1)), - (tiled_mma_sfb.thr_id.shape,), + cute.make_layout((*self.cluster_shape_mn, 1)), (tiled_mma_sfb.thr_id.shape,) + ) + self.num_mcast_ctas_a = cute.size(tuple(self.cluster_layout_vmnk.shape)[2]) + self.num_mcast_ctas_b = cute.size(tuple(self.cluster_layout_vmnk.shape)[1]) + self.num_mcast_ctas_sfb = cute.size( + tuple(self.cluster_layout_sfb_vmnk.shape)[1] ) - - # Compute number of multicast CTAs for A/B - self.num_mcast_ctas_a = 
cute.size(self.cluster_layout_vmnk.shape[2]) - self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) - self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) self.is_a_mcast = self.num_mcast_ctas_a > 1 self.is_b_mcast = self.num_mcast_ctas_b > 1 self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 - - # Compute epilogue subtile self.epi_tile = sm100_utils.compute_epilogue_tile_shape( - self.cta_tile_shape_mnk, - self.use_2cta_instrs, - self.c_layout, - self.c_dtype, + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.c_dtype ) - - # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( tiled_mma, self.mma_tiler, @@ -593,37 +484,20 @@ def _setup_attributes(self): self.smem_capacity, self.occupancy, ) - - # Compute A/B/SFA/SFB/C shared memory layout self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( - tiled_mma, - self.mma_tiler, - self.a_dtype, - self.num_ab_stage, + tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage ) self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( - tiled_mma, - self.mma_tiler, - self.b_dtype, - self.num_ab_stage, + tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage ) self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( - tiled_mma, - self.mma_tiler, - self.sf_vec_size, - self.num_ab_stage, + tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage ) self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( - tiled_mma, - self.mma_tiler, - self.sf_vec_size, - self.num_ab_stage, + tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage ) self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( - self.c_dtype, - self.c_layout, - self.epi_tile, - self.num_c_stage, + self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage ) @cute.jit @@ -666,7 +540,6 @@ def __call__( :type alpha_tensor: cute.Tensor :raises TypeError: If input data types are incompatible with the MMA instruction. 
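For orientation, a small worked example of the K-tile arithmetic performed in ``_setup_attributes`` above, assuming 4-bit ``Float4E2M1FN`` A/B operands as in the example invocation. The numbers are illustrative only and not part of the class:

.. code-block:: python

    # Hypothetical values for a Float4E2M1FN (4-bit) ab_dtype.
    mma_inst_bits_k = 256                                  # fixed K width of one block-scaled MMA, in bits
    ab_dtype_width = 4                                     # bits per Float4E2M1FN element
    mma_inst_shape_k = mma_inst_bits_k // ab_dtype_width   # 64 K elements per MMA instruction
    mma_inst_tile_k = 4                                    # instructions chained per mainloop K tile
    mma_tiler_k = mma_inst_shape_k * mma_inst_tile_k       # 256 K elements consumed per pipeline stage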
""" - # Setup static attributes before smem/grid/tma computation self.a_dtype: Type[cutlass.Numeric] = a_tensor.element_type self.b_dtype: Type[cutlass.Numeric] = b_tensor.element_type self.sf_dtype: Type[cutlass.Numeric] = sfa_tensor.element_type @@ -674,27 +547,17 @@ def __call__( self.a_major_mode = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode() self.b_major_mode = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode() self.c_layout = utils.LayoutEnum.from_tensor(c_tensor) - - # Check if input data types are compatible with MMA instruction if cutlass.const_expr(self.a_dtype != self.b_dtype): raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") - - # Setup attributes that dependent on gemm inputs self._setup_attributes() - - # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout - # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( - a_tensor.shape, self.sf_vec_size + tuple(a_tensor.shape), self.sf_vec_size ) sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) - - # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( - b_tensor.shape, self.sf_vec_size + tuple(b_tensor.shape), self.sf_vec_size ) sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) - tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( self.a_dtype, self.a_major_mode, @@ -704,7 +567,6 @@ def __call__( self.cta_group, self.mma_inst_shape_mnk[:2], ) - tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( self.a_dtype, self.a_major_mode, @@ -715,8 +577,6 @@ def __call__( self.mma_inst_shape_mnk_sfb[:2], ) atom_thr_size = cute.size(tiled_mma.thr_id.shape) - - # Setup TMA load for A a_op = sm100_utils.cluster_shape_to_tma_atom_A( self.cluster_shape_mn, tiled_mma.thr_id ) @@ -727,10 +587,8 @@ def __call__( a_smem_layout, self.mma_tiler, tiled_mma, - self.cluster_layout_vmnk.shape, + tuple(self.cluster_layout_vmnk.shape), ) - - # Setup TMA load for B b_op = sm100_utils.cluster_shape_to_tma_atom_B( self.cluster_shape_mn, tiled_mma.thr_id ) @@ -741,10 +599,8 @@ def __call__( b_smem_layout, self.mma_tiler, tiled_mma, - self.cluster_layout_vmnk.shape, + tuple(self.cluster_layout_vmnk.shape), ) - - # Setup TMA load for SFA sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( self.cluster_shape_mn, tiled_mma.thr_id ) @@ -757,11 +613,9 @@ def __call__( sfa_smem_layout, self.mma_tiler, tiled_mma, - self.cluster_layout_vmnk.shape, + tuple(self.cluster_layout_vmnk.shape), internal_type=cutlass.Int16, ) - - # Setup TMA load for SFB sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( self.cluster_shape_mn, tiled_mma.thr_id ) @@ -774,10 +628,9 @@ def __call__( sfb_smem_layout, self.mma_tiler_sfb, tiled_mma_sfb, - self.cluster_layout_sfb_vmnk.shape, + tuple(self.cluster_layout_sfb_vmnk.shape), internal_type=cutlass.Int16, ) - a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) @@ -785,28 +638,19 @@ def __call__( self.num_tma_load_bytes = ( a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size ) * atom_thr_size - - # Setup TMA store for C epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( - cpasync.CopyBulkTensorTileS2GOp(), - c_tensor, - epi_smem_layout, - self.epi_tile, + cpasync.CopyBulkTensorTileS2GOp(), c_tensor, epi_smem_layout, self.epi_tile ) - - # 
Compute grid size self.tile_sched_params, grid = self._compute_grid( - masked_m_tensor, # add masked layout + masked_m_tensor, c_tensor, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters, ) - self.buffer_align_bytes = 1024 - # Define shared storage for kernel @cute.struct class SharedStorage: ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] @@ -815,36 +659,30 @@ class SharedStorage: acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] tmem_dealloc_mbar_ptr: cutlass.Int64 tmem_holding_buf: cutlass.Int32 - # (EPI_TILE_M, EPI_TILE_N, STAGE) sC: cute.struct.Align[ cute.struct.MemRange[ - self.c_dtype, - cute.cosize(self.c_smem_layout_staged.outer), + self.c_dtype, cute.cosize(self.c_smem_layout_staged.outer) ], self.buffer_align_bytes, ] - # (MMA, MMA_M, MMA_K, STAGE) sA: cute.struct.Align[ cute.struct.MemRange[ self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) ], self.buffer_align_bytes, ] - # (MMA, MMA_N, MMA_K, STAGE) sB: cute.struct.Align[ cute.struct.MemRange[ self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) ], self.buffer_align_bytes, ] - # (MMA, MMA_M, MMA_K, STAGE) sSFA: cute.struct.Align[ cute.struct.MemRange[ self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged) ], self.buffer_align_bytes, ] - # (MMA, MMA_N, MMA_K, STAGE) sSFB: cute.struct.Align[ cute.struct.MemRange[ self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged) @@ -853,8 +691,6 @@ class SharedStorage: ] self.shared_storage = SharedStorage - - # Launch the kernel synchronously self.kernel( tiled_mma, tiled_mma_sfb, @@ -882,12 +718,11 @@ class SharedStorage: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), - smem=self.shared_storage.size_in_bytes(), # type: ignore[attr-defined] + smem=self.shared_storage.size_in_bytes(), stream=stream, ) return - # GPU device kernel @cute.kernel def kernel( self, @@ -919,23 +754,13 @@ def kernel( """ warp_idx = cute.arch.warp_idx() warp_idx = cute.arch.make_warp_uniform(warp_idx) - - # - # Prefetch tma desc - # if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_atom_a) cpasync.prefetch_descriptor(tma_atom_b) cpasync.prefetch_descriptor(tma_atom_sfa) cpasync.prefetch_descriptor(tma_atom_sfb) cpasync.prefetch_descriptor(tma_atom_c) - use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 - - # - # Setup cta/thread coordinates - # - # Coords inside cluster bidx, bidy, bidz = cute.arch.block_idx() mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) is_leader_cta = mma_tile_coord_v == 0 @@ -948,19 +773,11 @@ def kernel( block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( cta_rank_in_cluster ) - # Coord inside cta tidx, _, _ = cute.arch.thread_idx() - - # - # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier - # smem = utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr tmem_holding_buf = storage.tmem_holding_buf - - # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 ab_pipeline_consumer_group = pipeline.CooperativeGroup( @@ -974,8 +791,6 @@ def kernel( tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cluster_layout_vmnk, ) - - # Initialize acc_pipeline (barrier) and states acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_acc_consumer_threads = 
len(self.epilog_warp_id) * ( 2 if use_2cta_instrs else 1 @@ -990,8 +805,6 @@ def kernel( consumer_group=acc_pipeline_consumer_group, cta_layout_vmnk=cluster_layout_vmnk, ) - - # Tensor memory dealloc barrier init if use_2cta_instrs: if warp_idx == self.tma_warp_id: num_tmem_dealloc_threads = 32 @@ -1000,34 +813,19 @@ def kernel( tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads ) cute.arch.mbarrier_init_fence() - - # Cluster arrive after barrier init if cute.size(self.cluster_shape_mn) > 1: cute.arch.cluster_arrive_relaxed() - - # - # Setup smem tensor A/B/SFA/SFB/C - # - # (EPI_TILE_M, EPI_TILE_N, STAGE) sC = storage.sC.get_tensor( c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner ) - # (MMA, MMA_M, MMA_K, STAGE) sA = storage.sA.get_tensor( a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner ) - # (MMA, MMA_N, MMA_K, STAGE) sB = storage.sB.get_tensor( b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner ) - # (MMA, MMA_M, MMA_K, STAGE) sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) - # (MMA, MMA_N, MMA_K, STAGE) sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) - - # - # Compute multicast mask for A/B/SFA/SFB buffer full - # a_full_mcast_mask = None b_full_mcast_mask = None sfa_full_mcast_mask = None @@ -1045,57 +843,32 @@ def kernel( sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 ) - - # - # Local_tile partition global tensors - # - # (bM, bK, RestM, RestK, RestL) gA_mkl = cute.local_tile( mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) ) - # (bN, bK, RestN, RestK, RestL) gB_nkl = cute.local_tile( mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) ) - # (bM, bK, RestM, RestK, RestL) gSFA_mkl = cute.local_tile( mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) ) - # (bN, bK, RestN, RestK, RestL) gSFB_nkl = cute.local_tile( mSFB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) ) - # (bM, bN, RestM, RestN, RestL) gC_mnl = cute.local_tile( mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) ) k_block_cnt = cute.size(gA_mkl, mode=[3]) - - # - # Partition global tensor for TiledMMA_A/B/C - # thr_mma = tiled_mma.get_slice(mma_tile_coord_v) thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) - # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) tCgA = thr_mma.partition_A(gA_mkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) tCgB = thr_mma.partition_B(gB_nkl) - # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) tCgSFA = thr_mma.partition_A(gSFA_mkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) - # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) tCgC = thr_mma.partition_C(gC_mnl) - - # - # Partition global/shared tensor for TMA load A/B - # - # TMA load A partition_S/D a_cta_layout = cute.make_layout( cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape ) - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestM, RestK, RestL) tAsA, tAgA = cpasync.tma_partition( tma_atom_a, block_in_cluster_coord_vmnk[2], @@ -1103,12 +876,9 @@ def kernel( cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3), ) - # TMA load B partition_S/D b_cta_layout = cute.make_layout( cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape ) - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) tBsB, tBgB = cpasync.tma_partition( tma_atom_b, block_in_cluster_coord_vmnk[1], @@ -1116,11 +886,7 @@ def kernel( cute.group_modes(sB, 0, 3), 
cute.group_modes(tCgB, 0, 3), ) - - # TMA load SFA partition_S/D sfa_cta_layout = a_cta_layout - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestM, RestK, RestL) tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( tma_atom_sfa, block_in_cluster_coord_vmnk[2], @@ -1130,13 +896,9 @@ def kernel( ) tAsSFA = cute.filter_zeros(tAsSFA) tAgSFA = cute.filter_zeros(tAgSFA) - - # TMA load SFB partition_S/D sfb_cta_layout = cute.make_layout( cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape ) - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( tma_atom_sfb, block_in_cluster_coord_sfb_vmnk[1], @@ -1146,173 +908,108 @@ def kernel( ) tBsSFB = cute.filter_zeros(tBsSFB) tBgSFB = cute.filter_zeros(tBgSFB) - - # - # Partition shared/tensor memory tensor for TiledMMA_A/B/C - # - # (MMA, MMA_M, MMA_K, STAGE) tCrA = tiled_mma.make_fragment_A(sA) - # (MMA, MMA_N, MMA_K, STAGE) tCrB = tiled_mma.make_fragment_B(sB) - # (MMA, MMA_M, MMA_N) acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) - # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_fake = tiled_mma.make_fragment_C( cute.append(acc_shape, self.num_acc_stage) ) - - # - # Cluster wait before tensor memory alloc - # if cute.size(self.cluster_shape_mn) > 1: cute.arch.cluster_wait() else: cute.arch.barrier( barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta ) - - # - # Specialized TMA load warp - # if warp_idx == self.tma_warp_id: - # - # Persistent tile scheduling loop - # tile_sched = MaskedScheduler.create( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() - ab_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.num_ab_stage ) - while work_tile.is_valid_tile: - # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx mma_tile_coord_mnl = ( cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), cur_tile_coord[1], cur_tile_coord[2], ) - - # - # Slice to per mma tile index - # - # ((atom_v, rest_v), RestK) tAgA_slice = tAgA[ - (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2] ] - # ((atom_v, rest_v), RestK) tBgB_slice = tBgB[ - (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2] ] - - # ((atom_v, rest_v), RestK) tAgSFA_slice = tAgSFA[ - (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2] ] - # ((atom_v, rest_v), RestK) tBgSFB_slice = tBgSFB[ - (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2] ] - - # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt ab_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) if ab_producer_state.count < k_block_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire( ab_producer_state ) - # - # Tma load loop - # - for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1): # noqa: B007 - # Conditionally wait for AB buffer empty + for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1): ab_pipeline.producer_acquire( ab_producer_state, peek_ab_empty_status ) - - # TMA load A/B/SFA/SFB cute.copy( tma_atom_a, - tAgA_slice[(None, ab_producer_state.count)], - tAsA[(None, ab_producer_state.index)], + tAgA_slice[None, ab_producer_state.count], + tAsA[None, ab_producer_state.index], 
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), mcast_mask=a_full_mcast_mask, ) cute.copy( tma_atom_b, - tBgB_slice[(None, ab_producer_state.count)], - tBsB[(None, ab_producer_state.index)], + tBgB_slice[None, ab_producer_state.count], + tBsB[None, ab_producer_state.index], tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), mcast_mask=b_full_mcast_mask, ) cute.copy( tma_atom_sfa, - tAgSFA_slice[(None, ab_producer_state.count)], - tAsSFA[(None, ab_producer_state.index)], + tAgSFA_slice[None, ab_producer_state.count], + tAsSFA[None, ab_producer_state.index], tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), mcast_mask=sfa_full_mcast_mask, ) cute.copy( tma_atom_sfb, - tBgSFB_slice[(None, ab_producer_state.count)], - tBsSFB[(None, ab_producer_state.index)], + tBgSFB_slice[None, ab_producer_state.count], + tBsSFB[None, ab_producer_state.index], tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), mcast_mask=sfb_full_mcast_mask, ) - - # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1 ab_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) if ab_producer_state.count < k_block_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire( ab_producer_state ) - - # - # Advance to next tile - # tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() - - # - # Wait A/B buffer empty - # ab_pipeline.producer_tail(ab_producer_state) - - # - # Specialized MMA warp - # if warp_idx == self.mma_warp_id: - # - # Bar sync for retrieve tensor memory ptr from shared mem - # tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) cute.arch.barrier( barrier_id=self.tmem_ptr_sync_bar_id, number_of_threads=tmem_ptr_read_threads, ) - - # - # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor - # - # Make accumulator tmem tensor acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf, ) - # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) - - # Make SFA tmem tensor sfa_tmem_ptr = cute.recast_ptr( acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), dtype=self.sf_dtype, ) - # (MMA, MMA_M, MMA_K) tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( tiled_mma, self.mma_tiler, @@ -1320,15 +1017,12 @@ def kernel( cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), ) tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) - - # Make SFB tmem tensor sfb_tmem_ptr = cute.recast_ptr( acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA), dtype=self.sf_dtype, ) - # (MMA, MMA_N, MMA_K) tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( tiled_mma, self.mma_tiler, @@ -1336,74 +1030,48 @@ def kernel( cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), ) tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) - # - # Partition for S2T copy of SFA/SFB - # - tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = ( - self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) - ) - tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = ( - self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) - ) - - # - # Persistent tile scheduling loop - # + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) 
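            # Persistent tile scheduling loop for the MMA warp: each CTA repeatedly
            # pulls a masked work tile from the scheduler created below, runs the MMA
            # mainloop over all k blocks, and commits the accumulator to the epilogue
            # pipeline before advancing to the next tile.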
tile_sched = MaskedScheduler.create( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() - ab_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.num_ab_stage ) acc_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.num_acc_stage ) - while work_tile.is_valid_tile: - # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx mma_tile_coord_mnl = ( cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), cur_tile_coord[1], cur_tile_coord[2], ) - - # Set tensor memory buffer for current tile - # (MMA, MMA_M, MMA_N) - tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] - - # Peek (try_wait) AB buffer full for k_block = 0 + tCtAcc = tCtAcc_base[None, None, None, acc_producer_state.index] ab_consumer_state.reset_count() peek_ab_full_status = cutlass.Boolean(1) if ab_consumer_state.count < k_block_cnt and is_leader_cta: peek_ab_full_status = ab_pipeline.consumer_try_wait( ab_consumer_state ) - - # - # Wait for accumulator buffer empty - # if is_leader_cta: acc_pipeline.producer_acquire(acc_producer_state) - - # - # Reset the ACCUMULATE field for each tile - # tiled_mma.set(tcgen05.Field.ACCUMULATE, False) - - # - # Mma mainloop - # - for k_block in cutlass.range_constexpr(k_block_cnt): # noqa: B007 + for k_block in cutlass.range_constexpr(k_block_cnt): if is_leader_cta: - # Conditionally wait for AB buffer full ab_pipeline.consumer_wait( ab_consumer_state, peek_ab_full_status ) - - # Copy SFA/SFB from smem to tmem s2t_stage_coord = ( None, None, @@ -1423,8 +1091,6 @@ def kernel( tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t, ) - - # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB num_kphases = cute.size(tCrA, mode=[2]) for kphase_idx in cutlass.range(num_kphases, unroll_full=True): kphase_coord = ( @@ -1433,18 +1099,13 @@ def kernel( kphase_idx, ab_consumer_state.index, ) - - # Set SFA/SFB tensor to tiled_mma - sf_kphase_coord = (None, None, kphase_idx) + sf_kphase_coord = None, None, kphase_idx tiled_mma.set( - tcgen05.Field.SFA, - tCtSFA[sf_kphase_coord].iterator, + tcgen05.Field.SFA, tCtSFA[sf_kphase_coord].iterator ) tiled_mma.set( - tcgen05.Field.SFB, - tCtSFB[sf_kphase_coord].iterator, + tcgen05.Field.SFB, tCtSFB[sf_kphase_coord].iterator ) - cute.gemm( tiled_mma, tCtAcc, @@ -1452,14 +1113,8 @@ def kernel( tCrB[kphase_coord], tCtAcc, ) - - # Enable accumulate on tCtAcc after first kphase tiled_mma.set(tcgen05.Field.ACCUMULATE, True) - - # Async arrive AB buffer empty ab_pipeline.consumer_release(ab_consumer_state) - - # Peek (try_wait) AB buffer full for k_block = k_block + 1 ab_consumer_state.advance() peek_ab_full_status = cutlass.Boolean(1) if ab_consumer_state.count < k_block_cnt: @@ -1467,169 +1122,92 @@ def kernel( peek_ab_full_status = ab_pipeline.consumer_try_wait( ab_consumer_state ) - - # - # Async arrive accumulator buffer full - # if is_leader_cta: acc_pipeline.producer_commit(acc_producer_state) acc_producer_state.advance() - - # - # Advance to next tile - # tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() - - # - # Wait for accumulator buffer empty - # acc_pipeline.producer_tail(acc_producer_state) - # - # Specialized epilogue warps - # if warp_idx < self.mma_warp_id: - # - # Alloc tensor memory buffer - # if warp_idx == self.epilog_warp_id[0]: cute.arch.alloc_tmem( self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs, ) - - # - # Bar sync for retrieve tensor memory ptr from shared 
memory - # tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) cute.arch.barrier( barrier_id=self.tmem_ptr_sync_bar_id, number_of_threads=tmem_ptr_read_threads, ) - - # - # Retrieving tensor memory ptr and make accumulator tensor - # acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf, ) - # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) - - # - # Partition for epilogue - # epi_tidx = tidx - tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( - self.epilog_tmem_copy_and_partition( - epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs - ) + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs ) - - tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tTR_rC = cute.make_fragment(tuple(tTR_rAcc.shape), self.c_dtype) tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( tiled_copy_t2r, tTR_rC, epi_tidx, sC ) - tma_atom_c, bSG_sC, bSG_gC_partitioned = ( - self.epilog_gmem_copy_and_partition( - epi_tidx, tma_atom_c, tCgC, epi_tile, sC - ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC ) - - # - # Persistent tile scheduling loop - # tile_sched = MaskedScheduler.create( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) work_tile = tile_sched.initial_work_tile_info() - acc_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.num_acc_stage ) - - # Threads/warps participating in tma store pipeline c_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 32 * len(self.epilog_warp_id), ) c_pipeline = pipeline.PipelineTmaStore.create( - num_stages=self.num_c_stage, - producer_group=c_producer_group, + num_stages=self.num_c_stage, producer_group=c_producer_group ) - while work_tile.is_valid_tile: - # Get tile coord from tile scheduler cur_tile_coord = work_tile.tile_idx mma_tile_coord_mnl = ( cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), cur_tile_coord[1], cur_tile_coord[2], ) - - # - # Slice to per mma tile index - # - # ((ATOM_V, REST_V), EPI_M, EPI_N) - bSG_gC = bSG_gC_partitioned[ - ( - None, - None, - None, - *mma_tile_coord_mnl, - ) - ] - - # Set tensor memory buffer for current tile - # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + bSG_gC = bSG_gC_partitioned[None, None, None, *mma_tile_coord_mnl] tTR_tAcc = tTR_tAcc_base[ - (None, None, None, None, None, acc_consumer_state.index) + None, None, None, None, None, acc_consumer_state.index ] - - # - # Wait for accumulator buffer full - # acc_pipeline.consumer_wait(acc_consumer_state) - tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) - - # - # Store accumulator to global memory in subtiles - # - subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + subtile_cnt = cute.size(tuple(tTR_tAcc.shape), mode=[3]) num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt for subtile_idx in cutlass.range(subtile_cnt): - # - # Load accumulator from tensor memory buffer to register - # - tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + tTR_tAcc_mn = tTR_tAcc[None, None, None, subtile_idx] cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) - - # - # Convert to C type - # acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() if cutlass.const_expr(alpha is not None): 
acc_vec = acc_vec * alpha[work_tile.tile_idx[2]] - acc_vec = acc_vec.to(self.c_dtype) tRS_rC.store(acc_vec) - - # - # Store C to shared memory - # c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage cute.copy( - tiled_copy_r2s, - tRS_rC, - tRS_sC[(None, None, None, c_buffer)], + tiled_copy_r2s, tRS_rC, tRS_sC[None, None, None, c_buffer] ) - # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta, @@ -1639,40 +1217,23 @@ def kernel( barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads, ) - - # - # TMA store C to global memory - # if warp_idx == self.epilog_warp_id[0]: cute.copy( tma_atom_c, - bSG_sC[(None, c_buffer)], - bSG_gC[(None, subtile_idx)], + bSG_sC[None, c_buffer], + bSG_gC[None, subtile_idx], ) - # Fence and barrier to make sure shared memory store is visible to TMA store c_pipeline.producer_commit() c_pipeline.producer_acquire() cute.arch.barrier( barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads, ) - - # - # Async arrive accumulator buffer empty - # with cute.arch.elect_one(): acc_pipeline.consumer_release(acc_consumer_state) acc_consumer_state.advance() - - # - # Advance to next tile - # tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() - - # - # Dealloc the tensor memory buffer - # if warp_idx == self.epilog_warp_id[0]: cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) epilog_threads = 32 * len(self.epilog_warp_id) @@ -1688,15 +1249,10 @@ def kernel( cute.arch.dealloc_tmem( acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs ) - # - # Wait for C store complete - # c_pipeline.producer_tail() def mainloop_s2t_copy_and_partition( - self, - sSF: cute.Tensor, - tSF: cute.Tensor, + self, sSF: cute.Tensor, tSF: cute.Tensor ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: """ Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). 
@@ -1712,28 +1268,18 @@ def mainloop_s2t_copy_and_partition( - tSF_compact_s2t: The partitioned scale factor tensor in tmem :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] """ - # (MMA, MMA_MN, MMA_K, STAGE) tCsSF_compact = cute.filter_zeros(sSF) - # (MMA, MMA_MN, MMA_K) tCtSF_compact = cute.filter_zeros(tSF) - - # Make S2T CopyAtom and tiledCopy copy_atom_s2t = cute.make_copy_atom( - tcgen05.Cp4x32x128bOp(self.cta_group), - self.sf_dtype, + tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype ) tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) thr_copy_s2t = tiled_copy_s2t.get_slice(0) - - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( tiled_copy_s2t, tCsSF_compact_s2t_ ) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) - return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t def epilog_tmem_copy_and_partition( @@ -1764,7 +1310,6 @@ def epilog_tmem_copy_and_partition( - tTR_rAcc: The accumulated tensor in register used to hold t2r results :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] """ - # Make tiledCopy for tensor memory load copy_atom_t2r = sm100_utils.get_tmem_load_op( self.cta_tile_shape_mnk, self.c_layout, @@ -1773,29 +1318,18 @@ def epilog_tmem_copy_and_partition( epi_tile, use_2cta_instrs, ) - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) - tAcc_epi = cute.flat_divide( - tAcc[((None, None), 0, 0, None)], - epi_tile, - ) - # (EPI_TILE_M, EPI_TILE_N) + tAcc_epi = cute.flat_divide(tAcc[(None, None), 0, 0, None], epi_tile) tiled_copy_t2r = tcgen05.make_tmem_copy( - copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + copy_atom_t2r, tAcc_epi[None, None, 0, 0, 0] ) - thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) - # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) - - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) gC_mnl_epi = cute.flat_divide( - gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + gC_mnl[(None, None), 0, 0, None, None, None], epi_tile ) - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) - # (T2R, T2R_M, T2R_N) tTR_rAcc = cute.make_fragment( - tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + tuple(tTR_gC[None, None, None, 0, 0, 0, 0, 0].shape), self.acc_dtype ) return tiled_copy_t2r, tTR_tAcc, tTR_rAcc @@ -1829,10 +1363,8 @@ def epilog_smem_copy_and_partition( self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r ) tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) - # (R2S, R2S_M, R2S_N, PIPE_D) thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) tRS_sC = thr_copy_r2s.partition_D(sC) - # (R2S, R2S_M, R2S_N) tRS_rC = tiled_copy_r2s.retile(tTR_rC) return tiled_copy_r2s, tRS_rC, tRS_sC @@ -1864,16 +1396,12 @@ def epilog_gmem_copy_and_partition( - bSG_gC: The partitioned global tensor C :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] """ - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) gC_epi = cute.flat_divide( - gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + gC_mnl[(None, None), 0, 0, None, None, None], epi_tile ) - tma_atom_c = atom sC_for_tma_partition = cute.group_modes(sC, 0, 2) gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) - # ((ATOM_V, REST_V), EPI_M, EPI_N) - # ((ATOM_V, REST_V), EPI_M, EPI_N, 
RestM, RestN, RestL) bSG_sC, bSG_gC = cpasync.tma_partition( tma_atom_c, 0, @@ -1932,45 +1460,23 @@ def _compute_stages( (ACC stages, A/B operand stages, C stages) :rtype: tuple[int, int, int] """ - # ACC stages num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 - - # Default C stages num_c_stage = 2 - - # Calculate smem layout and size for one stage of A, B, SFA, SFB and C a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( - tiled_mma, - mma_tiler_mnk, - a_dtype, - 1, # a tmp 1 stage is provided + tiled_mma, mma_tiler_mnk, a_dtype, 1 ) b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( - tiled_mma, - mma_tiler_mnk, - b_dtype, - 1, # a tmp 1 stage is provided + tiled_mma, mma_tiler_mnk, b_dtype, 1 ) sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - 1, # a tmp 1 stage is provided + tiled_mma, mma_tiler_mnk, sf_vec_size, 1 ) sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - 1, # a tmp 1 stage is provided + tiled_mma, mma_tiler_mnk, sf_vec_size, 1 ) - c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( - c_dtype, - c_layout, - epi_tile, - 1, + c_dtype, c_layout, epi_tile, 1 ) - ab_bytes_per_stage = ( cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) @@ -1980,24 +1486,14 @@ def _compute_stages( mbar_helpers_bytes = 1024 c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) c_bytes = c_bytes_per_stage * num_c_stage - - # Calculate A/B/SFA/SFB stages: - # Start with total smem per CTA (capacity / occupancy) - # Subtract reserved bytes and initial C stages bytes - # Divide remaining by bytes needed per A/B/SFA/SFB stage num_ab_stage = ( smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) ) // ab_bytes_per_stage - - # Refine epilogue stages: - # Calculate remaining smem after allocating for A/B/SFA/SFB stages and reserved bytes - # Add remaining unused smem to epilogue num_c_stage += ( smem_capacity - occupancy * ab_bytes_per_stage * num_ab_stage - occupancy * (mbar_helpers_bytes + c_bytes) ) // (occupancy * c_bytes_per_stage) - return num_acc_stage, num_ab_stage, num_c_stage @staticmethod @@ -2025,13 +1521,11 @@ def _compute_grid( :rtype: Tuple[MaskedSchedulerParams, tuple[int, int, int]] """ c_tiler = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) - cluster_shape_mnl = (*cluster_shape_mn, 1) - + cluster_shape_mnl = *cluster_shape_mn, 1 tile_sched_params = MaskedSchedulerParams( masked_m_tensor, c, c_tiler, cluster_shape_mnl ) grid = MaskedScheduler.get_grid_shape(tile_sched_params, max_active_clusters) - return tile_sched_params, grid @staticmethod @@ -2057,30 +1551,20 @@ def is_valid_dtypes_and_scale_factor_vec_size( :rtype: bool """ is_valid = True - - # Check valid ab_dtype if ab_dtype not in { cutlass.Float4E2M1FN, cutlass.Float8E5M2, cutlass.Float8E4M3FN, }: is_valid = False - - # Check valid sf_vec_size if sf_vec_size not in {16, 32}: is_valid = False - - # Check valid sf_dtype if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: is_valid = False - - # Check valid sf_dtype and sf_vec_size combinations if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: is_valid = False if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16: is_valid = False - - # Check valid c_dtype if c_dtype not in { cutlass.Float32, cutlass.Float16, @@ -2089,7 +1573,6 @@ def is_valid_dtypes_and_scale_factor_vec_size( 
cutlass.Float8E4M3FN, }: is_valid = False - return is_valid @staticmethod @@ -2118,15 +1601,13 @@ def is_valid_layouts( :rtype: bool """ is_valid = True - if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): is_valid = False return is_valid @staticmethod def is_valid_mma_tiler_and_cluster_shape( - mma_tiler_mn: Tuple[int, int], - cluster_shape_mn: Tuple[int, int], + mma_tiler_mn: Tuple[int, int], cluster_shape_mn: Tuple[int, int] ) -> bool: """ Check if the mma tiler and cluster shape are valid @@ -2140,22 +1621,17 @@ def is_valid_mma_tiler_and_cluster_shape( :rtype: bool """ is_valid = True - # Skip invalid mma tile shape if mma_tiler_mn[0] not in [128, 256]: is_valid = False if mma_tiler_mn[1] not in [128, 256]: is_valid = False - # Skip illegal cluster shape if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0: is_valid = False - # Skip invalid cluster shape - is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + is_power_of_2 = lambda x: x > 0 and x & x - 1 == 0 if ( cluster_shape_mn[0] * cluster_shape_mn[1] > 16 or cluster_shape_mn[0] <= 0 or cluster_shape_mn[1] <= 0 - # Special cluster shape check for scale factor multicasts. - # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. or cluster_shape_mn[0] > 4 or cluster_shape_mn[1] > 4 or not is_power_of_2(cluster_shape_mn[0]) @@ -2267,22 +1743,18 @@ def can_implement( :rtype: bool """ can_implement = True - # Skip unsupported types if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size( ab_dtype, sf_dtype, sf_vec_size, c_dtype ): can_implement = False - # Skip unsupported layouts if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_layouts( ab_dtype, c_dtype, a_major, b_major, c_major ): can_implement = False - # Skip invalid mma tile shape and cluster shape if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( mma_tiler_mn, cluster_shape_mn ): can_implement = False - # Skip illegal problem shape for load/store alignment if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment( m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major ): @@ -2292,12 +1764,9 @@ def can_implement( @cute.jit def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - sf_ref_tensor: cute.Tensor, - sf_mma_tensor: cute.Tensor, + sf_ref_tensor: cute.Tensor, sf_mma_tensor: cute.Tensor ): """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout""" - # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) - # group to ((32, 4, rest_m), (4, rest_k), l) sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) for i in cutlass.range(cute.size(sf_ref_tensor)): @@ -2306,25 +1775,19 @@ def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( @cute.jit -def cvt_sf_MKL_to_M32x4xrm_K4xrk_L_mma_spec( - sf_mma_tensor: cute.Tensor, -): +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L_mma_spec(sf_mma_tensor: cute.Tensor): """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout""" - # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) - # group to ((32, 4, rest_m), (4, rest_k), l) sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) -# Create scale factor tensor SFA/SFB def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype, device): def ceil_div(a, b): return (a + b - 1) // b sf_k = ceil_div(k, sf_vec_size) - ref_shape = (l, mn, sf_k) - - atom_m 
= (32, 4) + ref_shape = l, mn, sf_k + atom_m = 32, 4 atom_k = 4 mma_shape = ( l, @@ -2334,67 +1797,40 @@ def ceil_div(a, b): atom_m[1], atom_k, ) - - ref_permute_order = (1, 2, 0) - mma_permute_order = (3, 4, 1, 5, 2, 0) - - # Create f32 ref torch tensor + ref_permute_order = 1, 2, 0 + mma_permute_order = 3, 4, 1, 5, 2, 0 ref_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( ref_shape, - torch.float32, + "float32", permute_order=ref_permute_order, init_type=cutlass_torch.TensorInitType.RANDOM, - init_config=cutlass_torch.RandomInitConfig( - min_val=1, - max_val=3, - ), + init_config=cutlass_torch.RandomInitConfig(min_val=1, max_val=3), ) - - # Create f32 cute torch tensor cute_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( mma_shape, - torch.float32, + "float32", permute_order=mma_permute_order, init_type=cutlass_torch.TensorInitType.RANDOM, - init_config=cutlass_torch.RandomInitConfig( - min_val=0, - max_val=1, - ), + init_config=cutlass_torch.RandomInitConfig(min_val=0, max_val=1), ) - - # convert ref f32 tensor to cute f32 tensor cvt_sf_MKL_to_M32x4xrm_K4xrk_L( - from_dlpack(ref_f32_torch_tensor_cpu), - from_dlpack(cute_f32_torch_tensor_cpu), + from_dlpack(ref_f32_torch_tensor_cpu), from_dlpack(cute_f32_torch_tensor_cpu) ) - cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.to(device, non_blocking=True) - - # reshape makes memory contiguous + cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.to(device, blocking=not True) ref_f32_torch_tensor_cpu = ( - ref_f32_torch_tensor_cpu.permute(2, 0, 1) - .unsqueeze(-1) - .expand(l, mn, sf_k, sf_vec_size) + ref_f32_torch_tensor_cpu.transpose(perm=[2, 0, 1]) + .unsqueeze(axis=-1) + .expand(shape=[l, mn, sf_k, sf_vec_size]) .reshape(l, mn, sf_k * sf_vec_size) - .permute(*ref_permute_order) + .transpose(perm=ref_permute_order) ) - # prune to mkl for reference check. 
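    # Each scale factor covers sf_vec_size consecutive K elements, so the
    # expand/reshape above repeats every factor sf_vec_size times along K for the
    # element-wise reference check; e.g. with k = 70 and sf_vec_size = 16,
    # sf_k = ceil(70 / 16) = 5, the broadcast K extent is 5 * 16 = 80, and the
    # slice below prunes it back to k = 70.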
ref_f32_torch_tensor_cpu = ref_f32_torch_tensor_cpu[:, :k, :] - ref_f32_torch_tensor = ref_f32_torch_tensor_cpu.to(device, non_blocking=True) - - # Create dtype cute torch tensor (cpu) + ref_f32_torch_tensor = ref_f32_torch_tensor_cpu.to(device, blocking=not True) cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( - cute_f32_torch_tensor_cpu, - dtype, - is_dynamic_layout=True, - assumed_align=16, + cute_f32_torch_tensor_cpu, dtype, is_dynamic_layout=True, assumed_align=16 ) - - # Convert f32 cute tensor to dtype cute tensor cute_tensor = cutlass_torch.convert_cute_tensor( - cute_f32_torch_tensor, - cute_tensor, - dtype, - is_dynamic_layout=True, + cute_f32_torch_tensor, cute_tensor, dtype, is_dynamic_layout=True ) return ref_f32_torch_tensor, cute_tensor, cute_torch_tensor @@ -2409,10 +1845,10 @@ def __init__( a_major: str, b_major: str, c_major: str, - ab_dtype: torch.dtype, - sf_dtype: torch.dtype, - c_dtype: torch.dtype, - alpha_dtype: torch.dtype, + ab_dtype: paddle.dtype, + sf_dtype: paddle.dtype, + c_dtype: paddle.dtype, + alpha_dtype: paddle.dtype, sf_vec_size: int, mma_tiler_mn: Tuple[int, int], cluster_shape_mn: Tuple[int, int], @@ -2432,7 +1868,6 @@ def __init__( self._sf_vec_size = sf_vec_size self._mma_tiler_mn = mma_tiler_mn self._cluster_shape_mn = cluster_shape_mn - if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( ab_dtype, sf_dtype, @@ -2451,8 +1886,6 @@ def __init__( raise TypeError( f"MaskedBatchedMatmulCuteDSL: Unsupported with {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" ) - - # Compute max active clusters on current device hardware_info = cutlass.utils.HardwareInfo() self._max_active_clusters = min( hardware_info.get_max_active_clusters( @@ -2469,7 +1902,7 @@ def __call__( sfa_ptr: cute.Pointer, sfb_ptr: cute.Pointer, c_ptr: cute.Pointer, - masked_mptr: cute.Pointer, + masked_m_ptr: cute.Pointer, alpha_ptr: cute.Pointer, current_stream: cuda.CUstream, ): @@ -2495,13 +1928,11 @@ def __call__( ), ) - # calculate sf_tensor shape and order def ceil_div(a, b): return (a + b - 1) // b sf_k = ceil_div(self._k, self._sf_vec_size) - - atom_m = (32, 4) + atom_m = 32, 4 atom_k = 4 mma_shape_a = ( self._l, @@ -2519,40 +1950,27 @@ def ceil_div(a, b): atom_m[1], atom_k, ) - mma_permute_order = (3, 4, 1, 5, 2, 0) - + mma_permute_order = 3, 4, 1, 5, 2, 0 sfa_tensor = cute.make_tensor( sfa_ptr, - layout=cute.make_ordered_layout( - mma_shape_a, - order=mma_permute_order, - ), + layout=cute.make_ordered_layout(mma_shape_a, order=mma_permute_order), ) sfb_tensor = cute.make_tensor( sfb_ptr, - layout=cute.make_ordered_layout( - mma_shape_b, - order=mma_permute_order, - ), + layout=cute.make_ordered_layout(mma_shape_b, order=mma_permute_order), ) cvt_sf_MKL_to_M32x4xrm_K4xrk_L_mma_spec(sfa_tensor) cvt_sf_MKL_to_M32x4xrm_K4xrk_L_mma_spec(sfb_tensor) - masked_m_tensor = cute.make_tensor( - masked_mptr, - layout=cute.make_ordered_layout((self._l,), order=(0,)), + masked_m_ptr, layout=cute.make_ordered_layout((self._l,), order=(0,)) ) - - # Use const_expr for compile-time conditional alpha_tensor = ( cute.make_tensor( - alpha_ptr, - layout=cute.make_ordered_layout((self._l,), order=(0,)), + alpha_ptr, layout=cute.make_ordered_layout((self._l,), order=(0,)) ) if cutlass.const_expr(alpha_ptr is not None) else None ) - Sm100BlockScaledPersistentDenseGemmKernel( sf_vec_size=self._sf_vec_size, mma_tiler_mn=self._mma_tiler_mn, @@ -2589,7 +2007,7 @@ def 
get_cute_dsl_compiled_masked_gemm_kernel( sm_count: int, ) -> Callable: def get_cute_pointers( - input_tensors: Optional[List[torch.tensor]], + input_tensors: Optional[List[paddle.to_tensor]], ) -> List[cute.Pointer]: if input_tensors is None: ( @@ -2600,7 +2018,7 @@ def get_cute_pointers( c_data_ptr, masked_m_data_ptr, alpha_data_ptr, - ) = [16 for _ in range(7)] + ) = [(16) for _ in range(7)] else: ( a_tensor_gpu, @@ -2628,54 +2046,25 @@ def get_cute_pointers( masked_m_tensor_gpu.data_ptr(), alpha_tensor_gpu.data_ptr() if alpha_tensor_gpu is not None else None, ) - - a_ptr = make_ptr( - ab_dtype, - a_data_ptr, - cute.AddressSpace.gmem, - assumed_align=16, - ) - b_ptr = make_ptr( - ab_dtype, - b_data_ptr, - cute.AddressSpace.gmem, - assumed_align=16, - ) + a_ptr = make_ptr(ab_dtype, a_data_ptr, cute.AddressSpace.gmem, assumed_align=16) + b_ptr = make_ptr(ab_dtype, b_data_ptr, cute.AddressSpace.gmem, assumed_align=16) sfa_ptr = make_ptr( - sf_dtype, - sfa_data_ptr, - cute.AddressSpace.gmem, - assumed_align=16, + sf_dtype, sfa_data_ptr, cute.AddressSpace.gmem, assumed_align=16 ) sfb_ptr = make_ptr( - sf_dtype, - sfb_data_ptr, - cute.AddressSpace.gmem, - assumed_align=16, - ) - c_ptr = make_ptr( - c_dtype, - c_data_ptr, - cute.AddressSpace.gmem, - assumed_align=16, + sf_dtype, sfb_data_ptr, cute.AddressSpace.gmem, assumed_align=16 ) + c_ptr = make_ptr(c_dtype, c_data_ptr, cute.AddressSpace.gmem, assumed_align=16) masked_m_ptr = make_ptr( - cutlass.Int32, - masked_m_data_ptr, - cute.AddressSpace.gmem, - assumed_align=16, + cutlass.Int32, masked_m_data_ptr, cute.AddressSpace.gmem, assumed_align=16 ) alpha_ptr = ( make_ptr( - alpha_dtype, - alpha_data_ptr, - cute.AddressSpace.gmem, - assumed_align=16, + alpha_dtype, alpha_data_ptr, cute.AddressSpace.gmem, assumed_align=16 ) if alpha_data_ptr is not None and alpha_dtype is not None else None ) - return [a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, masked_m_ptr, alpha_ptr] kernel = cute.compile( @@ -2701,25 +2090,19 @@ def get_cute_pointers( ) def tensor_api( - a_tensor_gpu: torch.Tensor, - b_tensor_gpu: torch.Tensor, - sfa_tensor_gpu: torch.Tensor, - sfb_tensor_gpu: torch.Tensor, - masked_m_tensor_gpu: torch.Tensor, - c_tensor_gpu: Optional[torch.Tensor] = None, - alpha_tensor_gpu: Optional[torch.Tensor] = None, + a_tensor_gpu: paddle.Tensor, + b_tensor_gpu: paddle.Tensor, + sfa_tensor_gpu: paddle.Tensor, + sfb_tensor_gpu: paddle.Tensor, + masked_m_tensor_gpu: paddle.Tensor, + c_tensor_gpu: Optional[paddle.Tensor] = None, + alpha_tensor_gpu: Optional[paddle.Tensor] = None, ): if c_tensor_gpu is None: - # fp4 gemm output is not supported - c_tensor_gpu = torch.empty( - (l, m, n), - dtype=cutlass_to_torch_dtype(c_dtype), - device="cuda", + c_tensor_gpu = paddle.empty( + shape=(l, m, n), dtype=cutlass_to_torch_dtype(c_dtype) ) - - # fp4 or fp8 torch tensor to cute tensor current_stream = cutlass_torch.current_stream() - nonlocal kernel kernel( *get_cute_pointers( @@ -2735,17 +2118,16 @@ def tensor_api( ), current_stream, ) - return c_tensor_gpu return tensor_api def grouped_gemm_nt_masked( - lhs: Tuple[torch.Tensor, torch.Tensor], - rhs: Tuple[torch.Tensor, torch.Tensor], - out: torch.Tensor, - masked_m: torch.Tensor, + lhs: Tuple[paddle.Tensor, paddle.Tensor], + rhs: Tuple[paddle.Tensor, paddle.Tensor], + out: paddle.Tensor, + masked_m: paddle.Tensor, *, ab_dtype: str, sf_dtype: str, @@ -2787,27 +2169,19 @@ def grouped_gemm_nt_masked( - If alpha is provided, each batch output is multiplied by its corresponding alpha value. out = alpha * (A @ B). 
- The result is written to c_tensor. """ - a_torch, sfa_torch = lhs b_torch, sfb_torch = rhs c_torch = out - - m, k, l = a_torch.shape - n, _, _ = b_torch.shape - + m, k, l = tuple(a_torch.shape) + n, _, _ = tuple(b_torch.shape) if ab_dtype == "float4_e2m1fn": - # todo(yingyi): update mnk based on a_major and b_major, and support more major. - # Note: only support deepgemm-like shape for now k = k * 2 - mma_tiler_mn = kwargs.get("mma_tiler_mm", (128, 128)) cluster_shape_mn = kwargs.get("cluster_shape_mm", (1, 1)) if sm_count is None: - sm_count = get_num_sm(a_torch.device) - + sm_count = get_num_sm(a_torch.place) alpha = kwargs.get("alpha") alpha_dtype = kwargs.get("alpha_dtype") - return get_cute_dsl_compiled_masked_gemm_kernel( m=m, n=n, diff --git a/flashinfer/cute_dsl/utils.py b/flashinfer/cute_dsl/utils.py index 4b035ed6c6..8c5ae08125 100644 --- a/flashinfer/cute_dsl/utils.py +++ b/flashinfer/cute_dsl/utils.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,11 +19,10 @@ See the License for the specific language governing permissions and limitations under the License. """ +import functools +import importlib.util import cutlass -import torch -import importlib.util -import functools def is_cute_dsl_available() -> bool: @@ -45,25 +50,24 @@ def cutlass_to_torch_dtype(cutlass_dtype): Return the corresponding torch.dtype per the given DSL type """ torch_dtype = getattr(torch, cutlass_dtype.__name__.lower(), None) - torch_type_map = { - cutlass.TFloat32: torch.float32, - cutlass.Float32: torch.float32, - cutlass.Float16: torch.float16, - cutlass.BFloat16: torch.bfloat16, - cutlass.Float8E5M2: torch.float8_e5m2, - cutlass.Float8E4M3FN: torch.float8_e4m3fn, - cutlass.Float8E4M3B11FNUZ: torch.float8_e4m3fnuz, + cutlass.TFloat32: "float32", + cutlass.Float32: "float32", + cutlass.Float16: "float16", + cutlass.BFloat16: "bfloat16", + cutlass.Float8E5M2: paddle.float8_e5m2, + cutlass.Float8E4M3FN: paddle.float8_e4m3fn, + cutlass.Float8E4M3B11FNUZ: paddle.float8_e4m3fnuz, } if torch_dtype is None: torch_dtype = torch_type_map.get(cutlass_dtype) - if torch_dtype is None: raise TypeError(f"{cutlass_dtype} is not supported by torch") return torch_dtype @functools.cache -def get_num_sm(device: torch.device) -> int: - # get the compute capability of the device, which would be cached - return torch.cuda.get_device_properties(device).multi_processor_count +def get_num_sm(device: str) -> int: + return paddle.device.cuda.get_device_properties( + device=device2str(device) + ).multi_processor_count diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 5178ccfbae..0106e1333f 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2023 by FlashInfer team. @@ -13,55 +15,30 @@ See the License for the specific language governing permissions and limitations under the License.
""" - import functools import math from types import SimpleNamespace from typing import Any, List, Literal, Optional, Tuple, Union, overload -import torch - -from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache -from .jit import ( - gen_batch_decode_mla_module, - gen_batch_decode_module, - gen_customize_batch_decode_module, - gen_customize_batch_prefill_module, - gen_single_decode_module, - get_batch_decode_uri, - get_batch_prefill_uri, - get_single_decode_uri, - setup_cubin_loader, - trtllm_gen_fmha_module, -) +from .cudnn import \ + cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache +from .jit import (gen_batch_decode_mla_module, gen_batch_decode_module, + gen_customize_batch_decode_module, + gen_customize_batch_prefill_module, gen_single_decode_module, + get_batch_decode_uri, get_batch_prefill_uri, + get_single_decode_uri, setup_cubin_loader, + trtllm_gen_fmha_module) from .page import get_seq_lens -from .prefill import ( - get_batch_prefill_jit_module, - get_batch_prefill_module, - get_single_prefill_module, -) -from .utils import ( - FP4Tensor, - MaskMode, - PosEncodingMode, - TensorLayout, - _check_cached_qkv_data_type, - _check_kv_layout, - _check_pos_encoding_mode, - check_shape_dtype_device, - _get_cache_alibi_slopes_buf, - _get_cache_buf, - _get_range_buf, - _unpack_paged_kv_cache, - canonicalize_torch_dtype, - device_support_pdl, - get_device_sm_count, - is_float8, - register_custom_op, - register_fake_op, - ceil_div, - round_up, -) +from .prefill import (get_batch_prefill_jit_module, get_batch_prefill_module, + get_single_prefill_module) +from .utils import (FP4Tensor, MaskMode, PosEncodingMode, TensorLayout, + _check_cached_qkv_data_type, _check_kv_layout, + _check_pos_encoding_mode, _get_cache_alibi_slopes_buf, + _get_cache_buf, _get_range_buf, _unpack_paged_kv_cache, + canonicalize_torch_dtype, ceil_div, + check_shape_dtype_device, device_support_pdl, + get_device_sm_count, is_float8, register_custom_op, + register_fake_op, round_up) @functools.cache @@ -70,17 +47,15 @@ def get_single_decode_module(*args): module = gen_single_decode_module(*args).build_and_load() run_func = module.run.default - # torch library for single_decode_with_kv_cache - @register_custom_op(f"flashinfer::{uri}_run", mutates_args=("tmp", "o")) def run_single_decode( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - tmp: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], - alibi_slopes: Optional[torch.Tensor], + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + tmp: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], + alibi_slopes: Optional[paddle.Tensor], kv_layout_code: int, window_left: int, logits_soft_cap: float, @@ -100,19 +75,19 @@ def run_single_decode( alibi_slopes, logits_soft_cap, sm_scale, - 1.0 / rope_scale, # rope_rcp_scale - 1.0 / rope_theta, # rope_rcp_theta + 1.0 / rope_scale, + 1.0 / rope_theta, ) @register_fake_op(f"flashinfer::{uri}_run") def _fake_run_single_decode( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - tmp: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], - alibi_slopes: Optional[torch.Tensor], + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + tmp: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], + alibi_slopes: Optional[paddle.Tensor], kv_layout_code: int, window_left: int, logits_soft_cap: float, @@ -122,7 +97,6 @@ def _fake_run_single_decode( ) -> None: pass - # Register the module. 
return SimpleNamespace(run=run_single_decode) @@ -143,17 +117,17 @@ def get_batch_decode_jit_module(module_name: str, jit_module: Any): ), ) def run_batch_decode( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - paged_k_cache: Optional[torch.Tensor], - paged_v_cache: Optional[torch.Tensor], - paged_kv_indptr: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + paged_k_cache: Optional[paddle.Tensor], + paged_v_cache: Optional[paddle.Tensor], + paged_kv_indptr: paddle.Tensor, + paged_kv_indices: paddle.Tensor, + paged_kv_last_page_len: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], kv_layout_code: int, window_left: int, enable_pdl: bool, @@ -179,17 +153,17 @@ def run_batch_decode( @register_fake_op(f"flashinfer::{module_name}_run") def _fake_run_batch_decode( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - paged_k_cache: Optional[torch.Tensor], - paged_v_cache: Optional[torch.Tensor], - paged_kv_indptr: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + paged_k_cache: Optional[paddle.Tensor], + paged_v_cache: Optional[paddle.Tensor], + paged_kv_indptr: paddle.Tensor, + paged_kv_indices: paddle.Tensor, + paged_kv_last_page_len: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], kv_layout_code: int, window_left: int, enable_pdl: bool, @@ -197,10 +171,7 @@ def _fake_run_batch_decode( ) -> None: pass - return SimpleNamespace( - plan=plan_func, - run=run_batch_decode, - ) + return SimpleNamespace(plan=plan_func, run=run_batch_decode) @functools.cache @@ -210,8 +181,6 @@ def get_batch_decode_module(*args): plan_func = mod.plan.default run_func = mod.run.default - # torch library for batch_decode_with_paged_kv_cache_run - @register_custom_op( f"flashinfer::{uri}_run", mutates_args=( @@ -224,21 +193,21 @@ def get_batch_decode_module(*args): ), ) def run_batch_decode( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - paged_k_cache: Optional[torch.Tensor], - paged_v_cache: Optional[torch.Tensor], - paged_kv_indptr: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + paged_k_cache: Optional[paddle.Tensor], + paged_v_cache: Optional[paddle.Tensor], + paged_kv_indptr: paddle.Tensor, + paged_kv_indices: paddle.Tensor, + paged_kv_last_page_len: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], kv_layout_code: int, window_left: int, enable_pdl: bool, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: Optional[paddle.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, @@ -262,27 +231,27 @@ def run_batch_decode( alibi_slopes, logits_soft_cap, sm_scale, - 1.0 / rope_scale, # rope_rcp_scale - 1.0 / rope_theta, # rope_rcp_theta + 1.0 / rope_scale, + 1.0 / rope_theta, ) @register_fake_op(f"flashinfer::{uri}_run") def 
_fake_run_batch_decode( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - paged_k_cache: Optional[torch.Tensor], - paged_v_cache: Optional[torch.Tensor], - paged_kv_indptr: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + paged_k_cache: Optional[paddle.Tensor], + paged_v_cache: Optional[paddle.Tensor], + paged_kv_indptr: paddle.Tensor, + paged_kv_indices: paddle.Tensor, + paged_kv_last_page_len: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], kv_layout_code: int, window_left: int, enable_pdl: bool, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: Optional[paddle.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, @@ -290,14 +259,7 @@ def _fake_run_batch_decode( ) -> None: pass - # Register the module. - # - # Note that plan is not part of model logic. It should not be included in - # Cuda Graph or torch.compile. So, we don't provide a torch library for plan. - return SimpleNamespace( - plan=plan_func, - run=run_batch_decode, - ) + return SimpleNamespace(plan=plan_func, run=run_batch_decode) @functools.cache @@ -310,31 +272,23 @@ def get_trtllm_gen_fmha_module(): def single_decode_with_kv_cache_with_jit_module( jit_module: Any, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, *args, kv_layout: str = "NHD", window_left: int = -1, return_lse: bool = False, ): - device = q.device + device = q.place tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, device) - o = torch.empty_like(q) + o = paddle.empty_like(x=q) if return_lse: - lse = torch.empty((q.size(0)), dtype=torch.float32, device=device) + lse = paddle.empty(shape=q.shape[0], dtype="float32") else: lse = None jit_module.run.default( - q, - k, - v, - tmp, - o, - lse, - TensorLayout[kv_layout].value, - window_left, - *args, + q, k, v, tmp, o, lse, TensorLayout[kv_layout].value, window_left, *args ) return o @@ -346,9 +300,9 @@ def get_batch_decode_mla_module(*args): @overload def single_decode_with_kv_cache( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", use_tensor_cores: bool = False, @@ -361,14 +315,15 @@ def single_decode_with_kv_cache( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, return_lse: Literal[False] = False, -) -> torch.Tensor: ... +) -> paddle.Tensor: + ... @overload def single_decode_with_kv_cache( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", use_tensor_cores: bool = False, @@ -381,13 +336,14 @@ def single_decode_with_kv_cache( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, return_lse: Literal[True] = True, -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> Tuple[paddle.Tensor, paddle.Tensor]: + ... 
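# A minimal usage sketch of the paddle-based decode API defined below (editorial
# addition, not part of the upstream diff). Shapes assume the "NHD" layout used in
# this module: q is [num_qo_heads, head_dim] and k/v are [kv_len, num_kv_heads,
# head_dim]; the helper name `_example_single_decode_usage` is illustrative only.
def _example_single_decode_usage():
    num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 8, 128, 4096
    q = paddle.randn([num_qo_heads, head_dim]).astype("float16")
    k = paddle.randn([kv_len, num_kv_heads, head_dim]).astype("float16")
    v = paddle.randn([kv_len, num_kv_heads, head_dim]).astype("float16")
    # Returns the attention output; pass return_lse=True to also get the log-sum-exp.
    return single_decode_with_kv_cache(q, k, v, kv_layout="NHD")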
def single_decode_with_kv_cache( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", use_tensor_cores: bool = False, @@ -400,8 +356,8 @@ def single_decode_with_kv_cache( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, return_lse: bool = False, -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Decode attention with KV Cache for single request, return attention output. +) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Decode attention with KV Cache for single request, return attention output. Parameters ---------- @@ -437,7 +393,7 @@ def single_decode_with_kv_cache( The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, + :math:`\\texttt{logits_soft_cap} \\times \\mathrm{tanh}(x / \\texttt{logits_soft_cap})`, where :math:`x` is the input logits. sm_scale : Optional[float] The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. @@ -481,8 +437,8 @@ def single_decode_with_kv_cache( """ _check_pos_encoding_mode(pos_encoding_mode) _check_kv_layout(kv_layout) - tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device) - head_dim = q.shape[-1] + tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.place) + head_dim = tuple(q.shape)[-1] if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: @@ -494,60 +450,58 @@ def single_decode_with_kv_cache( if rope_scale is None: rope_scale = 1.0 if rope_theta is None: - rope_theta = 1e4 - num_qo_heads = q.shape[0] - + rope_theta = 10000.0 + num_qo_heads = tuple(q.shape)[0] lse = None if return_lse: - lse = torch.empty((num_qo_heads,), dtype=torch.float32, device=q.device) - + lse = paddle.empty(shape=(num_qo_heads,), dtype="float32") if use_tensor_cores: - out = torch.empty_like(q.unsqueeze(0)) + out = paddle.empty_like(x=q.unsqueeze(axis=0)) get_single_prefill_module( "fa2", q.dtype, k.dtype, q.dtype, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim, + head_dim, PosEncodingMode[pos_encoding_mode].value, - window_left != -1, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap - False, # use_fp16_qk_reduction + window_left != -1, + logits_soft_cap > 0, + False, ).run( - q.unsqueeze(0), + q.unsqueeze(axis=0), k, v, tmp, out, - lse.unsqueeze(0) if lse is not None else None, + lse.unsqueeze(axis=0) if lse is not None else None, MaskMode.NON_CAUSAL.value, TensorLayout[kv_layout].value, window_left, - None, # packed_custom_mask - _get_cache_alibi_slopes_buf(num_qo_heads, q.device), + None, + _get_cache_alibi_slopes_buf(num_qo_heads, q.place), logits_soft_cap, sm_scale, - None, # scale_q, not supported yet - None, # scale_k - None, # scale_v + None, + None, + None, rope_scale, rope_theta, ) - out = out.squeeze(0) + out = out.squeeze(axis=0) if return_lse: - lse = lse.squeeze(0) + lse = lse.squeeze(axis=0) else: - out = torch.empty_like(q) + out = paddle.empty_like(x=q) get_single_decode_module( q.dtype, k.dtype, q.dtype, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim, + head_dim, PosEncodingMode[pos_encoding_mode].value, - window_left != -1, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap + window_left != -1, + 
logits_soft_cap > 0, ).run( q, k, @@ -555,7 +509,7 @@ def single_decode_with_kv_cache( tmp, out, lse, - _get_cache_alibi_slopes_buf(num_qo_heads, q.device), + _get_cache_alibi_slopes_buf(num_qo_heads, q.place), TensorLayout[kv_layout].value, window_left, logits_soft_cap, @@ -563,10 +517,8 @@ def single_decode_with_kv_cache( rope_scale, rope_theta, ) - if v_scale is not None: - # TODO(Zihao): fused into kernel - if out.itemsize == 1: + if out.element_size() == 1: out = (out.to(float) * v_scale).to(out.dtype) else: out *= v_scale @@ -577,7 +529,7 @@ def single_decode_with_kv_cache( class BatchDecodeWithPagedKVCacheWrapper: - r"""Wrapper class for decode attention with paged kv-cache (first proposed in + """Wrapper class for decode attention with paged kv-cache (first proposed in `vLLM `_) for batch of requests. Check :ref:`our tutorial` for page table layout. @@ -644,17 +596,17 @@ class BatchDecodeWithPagedKVCacheWrapper: def __init__( self, - float_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, use_tensor_cores: bool = False, - paged_kv_indptr_buffer: Optional[torch.Tensor] = None, - paged_kv_indices_buffer: Optional[torch.Tensor] = None, - paged_kv_last_page_len_buffer: Optional[torch.Tensor] = None, + paged_kv_indptr_buffer: Optional[paddle.Tensor] = None, + paged_kv_indices_buffer: Optional[paddle.Tensor] = None, + paged_kv_last_page_len_buffer: Optional[paddle.Tensor] = None, backend: str = "auto", jit_args: Optional[List[Any]] = None, ) -> None: - r"""Constructor of :class:`BatchDecodeWithPagedKVCacheWrapper`. + """Constructor of :class:`BatchDecodeWithPagedKVCacheWrapper`. Parameters ---------- @@ -701,7 +653,6 @@ def __init__( otherwise, the wrapper will use default attention implementation. 
""" _check_kv_layout(kv_layout) - if jit_args is not None: if use_tensor_cores: self._jit_module = get_batch_prefill_jit_module( @@ -717,35 +668,28 @@ def __init__( ) else: self._jit_module = None - self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer - self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) - self._pin_memory_int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), - dtype=torch.uint8, - pin_memory=True, - device="cpu", + self.device = float_workspace_buffer.place + self._int_workspace_buffer = paddle.empty( + shape=(8 * 1024 * 1024,), dtype="uint8" ) - self._kv_lens_buffer: Optional[torch.Tensor] = None + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=(8 * 1024 * 1024,), dtype="uint8" + ).pin_memory() + self._kv_lens_buffer: Optional[paddle.Tensor] = None if backend == "trtllm-gen": - self._kv_lens_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device - ) - + self._kv_lens_buffer = paddle.empty(shape=(32768,), dtype="int32") if use_cuda_graph: - if not torch.is_tensor(paged_kv_indptr_buffer): + if not paddle.is_tensor(x=paged_kv_indptr_buffer): raise ValueError( "paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_indices_buffer): + if not paddle.is_tensor(x=paged_kv_indices_buffer): raise ValueError( "paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_last_page_len_buffer): + if not paddle.is_tensor(x=paged_kv_last_page_len_buffer): raise ValueError( "paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode" ) @@ -756,20 +700,15 @@ def __init__( ) else: self._fixed_batch_size = 0 - self._paged_kv_indptr_buf = paged_kv_indptr_buffer self._paged_kv_indices_buf = paged_kv_indices_buffer self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer self._use_tensor_cores = use_tensor_cores or backend == "trtllm-gen" self._use_cuda_graph = use_cuda_graph - if use_tensor_cores: if use_cuda_graph: - # NOTE(Zihao): if once created, no need to update it in plan/run - self._qo_indptr_buf = torch.arange( - self._fixed_batch_size + 1, - dtype=torch.int32, - device=float_workspace_buffer.device, + self._qo_indptr_buf = paddle.arange( + dtype="int32", end=self._fixed_batch_size + 1 ) self._backend = backend @@ -782,9 +721,9 @@ def is_cuda_graph_enabled(self) -> bool: return self._use_cuda_graph def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor + self, float_workspace_buffer: paddle.Tensor, int_workspace_buffer: paddle.Tensor ) -> None: - r"""Reset the workspace buffer. + """Reset the workspace buffer. 
Parameters ---------- @@ -798,18 +737,16 @@ def reset_workspace_buffer( """ self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=tuple(self._int_workspace_buffer.shape), dtype=self._int_workspace_buffer.dtype, - device="cpu", - pin_memory=True, - ) + ).pin_memory() def plan( self, - indptr: torch.Tensor, - indices: torch.Tensor, - last_page_len: torch.Tensor, + indptr: paddle.Tensor, + indices: paddle.Tensor, + last_page_len: paddle.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, @@ -817,17 +754,17 @@ def plan( pos_encoding_mode: str = "NONE", window_left: int = -1, logits_soft_cap: Optional[float] = None, - q_data_type: Optional[Union[str, torch.dtype]] = "float16", - kv_data_type: Optional[Union[str, torch.dtype]] = None, - data_type: Optional[Union[str, torch.dtype]] = None, + q_data_type: Optional[Union[str, paddle.dtype]] = "float16", + kv_data_type: Optional[Union[str, paddle.dtype]] = None, + data_type: Optional[Union[str, paddle.dtype]] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, non_blocking: bool = True, - block_tables: Optional[torch.Tensor] = None, - seq_lens: Optional[torch.Tensor] = None, + block_tables: Optional[paddle.Tensor] = None, + seq_lens: Optional[paddle.Tensor] = None, ) -> None: - r"""Plan batch decode for given problem specification. + """Plan batch decode for given problem specification. Parameters ---------- @@ -857,7 +794,7 @@ def plan( The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, + :math:`\\texttt{logits_soft_cap} \\times \\mathrm{tanh}(x / \\texttt{logits_soft_cap})`, where :math:`x` is the input logits. q_data_type : Optional[Union[str, torch.dtype]] The data type of the query tensor, defaults torch.float16. 
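The `logits_soft_cap` option described above applies the smooth cap `logits_soft_cap * tanh(x / logits_soft_cap)` to the attention logits, bounding them to the open interval `(-logits_soft_cap, logits_soft_cap)`. A small illustration of the formula (not the fused kernel implementation), assuming Paddle:

```python
import paddle

def soft_cap(logits: paddle.Tensor, logits_soft_cap: float) -> paddle.Tensor:
    # logits_soft_cap * tanh(x / logits_soft_cap); a value of 0 disables capping.
    if logits_soft_cap <= 0:
        return logits
    return logits_soft_cap * paddle.tanh(logits / logits_soft_cap)

scores = paddle.to_tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap(scores, 30.0))  # every value squashed into (-30, 30)
```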
@@ -890,13 +827,11 @@ def plan( batch_size = len(last_page_len) if logits_soft_cap is None: logits_soft_cap = 0.0 - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") if self.is_cuda_graph_enabled: if batch_size != self._fixed_batch_size: raise ValueError( - "The batch size should be fixed in cudagraph mode, the runtime batch size {} " - " mismatches the batch size set during initialization {}".format( + "The batch size should be fixed in cudagraph mode, the runtime batch size {} mismatches the batch size set during initialization {}".format( batch_size, self._fixed_batch_size ) ) @@ -904,49 +839,40 @@ def plan( raise ValueError( "The size of indices should be less than or equal to the allocated buffer" ) - self._paged_kv_indptr_buf.copy_(indptr, non_blocking=non_blocking) - self._paged_kv_last_page_len_buf.copy_( - last_page_len, non_blocking=non_blocking - ) - self._paged_kv_indices_buf[: len(indices)].copy_( - indices, non_blocking=(indices.device == self.device) and non_blocking - ) + paddle.assign(indptr, output=self._paged_kv_indptr_buf) + paddle.assign(last_page_len, output=self._paged_kv_last_page_len_buf) + paddle.assign(indices, output=self._paged_kv_indices_buf[: len(indices)]) else: self._paged_kv_indptr_buf = indptr.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) self._paged_kv_indices_buf = indices.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) self._paged_kv_last_page_len_buf = last_page_len.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) self._qo_indptr_buf = qo_indptr_host.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) - indptr_host = indptr.to("cpu") last_page_len_host = last_page_len.to("cpu") - if data_type is not None: if q_data_type is None: q_data_type = data_type if kv_data_type is None: kv_data_type = data_type - q_data_type = canonicalize_torch_dtype(q_data_type) if kv_data_type is None: kv_data_type = q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) - self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type self._batch_size = batch_size self._num_qo_heads = num_qo_heads self._num_kv_heads = num_kv_heads - self._block_tables: Optional[torch.Tensor] = block_tables + self._block_tables: Optional[paddle.Tensor] = block_tables self._max_kv_len: Optional[int] = None - if seq_lens is None: kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size) else: @@ -955,28 +881,26 @@ def plan( assert self._kv_layout == "HND" assert logits_soft_cap == 0.0 self._max_kv_len = max(kv_lens_arr_host).item() - self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( - kv_lens_arr_host, non_blocking=non_blocking + paddle.assign( + kv_lens_arr_host, output=self._kv_lens_buffer[: len(kv_lens_arr_host)] ) if self._block_tables is None: blocks_per_seq = [ - (seq_len + page_size - 1) // page_size + ((seq_len + page_size - 1) // page_size) for seq_len in kv_lens_arr_host ] max_num_blocks_per_seq = max(blocks_per_seq) - self._block_tables = torch.zeros( - (batch_size, max_num_blocks_per_seq), - dtype=torch.int, - device=self.device, + self._block_tables = paddle.zeros( + shape=(batch_size, max_num_blocks_per_seq), dtype="int32" ) block_id = indptr[0] for i in range(batch_size): num_blocks_needed = blocks_per_seq[i] - self._block_tables[i, :num_blocks_needed] = ( - self._paged_kv_indices_buf[ - block_id : block_id + num_blocks_needed - ] - ) + self._block_tables[ + i, 
:num_blocks_needed + ] = self._paged_kv_indices_buf[ + block_id : block_id + num_blocks_needed + ] block_id += num_blocks_needed self._cached_module = get_trtllm_gen_decode_module( q_data_type, @@ -986,11 +910,11 @@ def plan( head_dim, head_dim, PosEncodingMode[pos_encoding_mode].value, - window_left >= 0, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap - False, # use_fp16_qk_reduction + window_left >= 0, + logits_soft_cap > 0, + False, ) - self._plan_info = self._cached_module.plan() # None + self._plan_info = self._cached_module.plan() elif self.use_tensor_cores: self._max_kv_len = max(kv_lens_arr_host).item() if self._jit_module is not None: @@ -1002,14 +926,13 @@ def plan( kv_data_type, q_data_type, indptr.dtype, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim, + head_dim, PosEncodingMode[pos_encoding_mode].value, - window_left != -1, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap - False, # use_fp16_qk_reduction + window_left != -1, + logits_soft_cap > 0, + False, ) - self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, @@ -1017,7 +940,7 @@ def plan( qo_indptr_host, indptr_host, kv_lens_arr_host, - batch_size, # total_num_rows + batch_size, batch_size, num_qo_heads, num_kv_heads, @@ -1025,7 +948,7 @@ def plan( self.is_cuda_graph_enabled, head_dim, head_dim, - False, # causal + False, ) else: if self._jit_module is not None: @@ -1036,13 +959,12 @@ def plan( kv_data_type, q_data_type, indptr.dtype, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim, + head_dim, PosEncodingMode[pos_encoding_mode].value, - window_left != -1, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap + window_left != -1, + logits_soft_cap > 0, ) - self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, @@ -1057,10 +979,9 @@ def plan( logits_soft_cap, head_dim, head_dim, - torch.empty(0, dtype=q_data_type), - torch.empty(0, dtype=kv_data_type), + paddle.empty(shape=[0], dtype=q_data_type), + paddle.empty(shape=[0], dtype=kv_data_type), ) - self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left self._logits_soft_cap = logits_soft_cap @@ -1072,8 +993,8 @@ def plan( def forward( self, - q: torch.Tensor, - paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + q: paddle.Tensor, + paged_kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], pos_encoding_mode: str = "NONE", q_scale: Optional[float] = None, k_scale: Optional[float] = None, @@ -1083,8 +1004,8 @@ def forward( sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - ) -> torch.Tensor: - r"""Warning: this function is deprecated, please use :meth:`run` instead.""" + ) -> paddle.Tensor: + """Warning: this function is deprecated, please use :meth:`run` instead.""" self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left self._logits_soft_cap = logits_soft_cap @@ -1098,51 +1019,53 @@ def forward( @overload def run( self, - q: torch.Tensor, - paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + q: paddle.Tensor, + paged_kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], *args, q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, 
return_lse: Literal[False] = False, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, - ) -> torch.Tensor: ... + ) -> paddle.Tensor: + ... @overload def run( self, - q: torch.Tensor, - paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + q: paddle.Tensor, + paged_kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], *args, q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: Literal[True] = True, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: ... + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + ... def run( self, - q: torch.Tensor, - paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + q: paddle.Tensor, + paged_kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], *args, q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: bool = False, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, - sinks: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Compute batch decode attention between query and paged kv cache. + sinks: Optional[paddle.Tensor] = None, + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Compute batch decode attention between query and paged kv cache. Parameters ---------- @@ -1188,21 +1111,18 @@ def run( * logsumexp of attention scores, shape: ``[batch_size, num_qo_heads]``. """ if enable_pdl is None: - enable_pdl = device_support_pdl(q.device) + enable_pdl = device_support_pdl(q.place) k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) if self._kv_layout == "NHD": - page_size = k_cache.shape[1] + page_size = tuple(k_cache.shape)[1] else: - page_size = k_cache.shape[2] + page_size = tuple(k_cache.shape)[2] _check_cached_qkv_data_type( q, k_cache, self._cached_q_data_type, self._cached_kv_data_type ) - pos_encoding_mode = self._pos_encoding_mode window_left = self._window_left if window_left is None else window_left if self._backend != "trtllm-gen": - # NOTE(Siyuan): since window_left is appeared in the plan function, we need to make sure it is the same as the one in the plan function. - # Remove this check if the backend supports dynamic window_left. 
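As the beginning of `run` above shows, the page size is read from a different axis of the paged K cache depending on the layout. A small sketch of that lookup, assuming the usual FlashInfer shape conventions (`NHD`: `[num_pages, page_size, num_kv_heads, head_dim]`, `HND`: `[num_pages, num_kv_heads, page_size, head_dim]`):

```python
import paddle

def infer_page_size(k_cache: paddle.Tensor, kv_layout: str) -> int:
    # NHD stores page_size on axis 1, HND stores it on axis 2.
    if kv_layout == "NHD":
        return tuple(k_cache.shape)[1]
    if kv_layout == "HND":
        return tuple(k_cache.shape)[2]
    raise KeyError(f"Unsupported kv_layout: {kv_layout}")

k_nhd = paddle.zeros(shape=[16, 8, 4, 64], dtype="float16")
k_hnd = paddle.zeros(shape=[16, 4, 8, 64], dtype="float16")
print(infer_page_size(k_nhd, "NHD"), infer_page_size(k_hnd, "HND"))  # 8 8
```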
assert window_left == self._window_left logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale @@ -1212,7 +1132,7 @@ def run( if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: - head_dim = q.shape[-1] + head_dim = tuple(q.shape)[-1] sm_scale = 1.0 / math.sqrt(head_dim) if q_scale is not None: sm_scale *= q_scale @@ -1221,23 +1141,18 @@ def run( if rope_scale is None: rope_scale = 1.0 if rope_theta is None: - rope_theta = 1e4 - + rope_theta = 10000.0 if return_lse: if lse is None: - lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device - ) + lse = paddle.empty(shape=(q.shape[0], q.shape[1]), dtype="float32") else: check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" + lse, (q.shape[0], q.shape[1]), "float32", q.place, "lse" ) - if out is None: - out = torch.empty_like(q) + out = paddle.empty_like(x=q) else: - check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out") - + check_shape_dtype_device(out, tuple(q.shape), q.dtype, q.place, "out") if self.use_tensor_cores: run_args = [ self._float_workspace_buffer, @@ -1257,25 +1172,24 @@ def run( window_left, enable_pdl, ] - if self._jit_module is not None: run_args.extend(list(args)) else: run_args += [ - None, # packed_custom_mask - None, # mask_indptr_buf - _get_cache_alibi_slopes_buf(q.shape[1], q.device), - None, # maybe_prefix_len_ptr - None, # maybe_token_pos_in_items_ptr - None, # maybe_max_item_len_ptr + None, + None, + _get_cache_alibi_slopes_buf(tuple(q.shape)[1], q.place), + None, + None, + None, logits_soft_cap, sm_scale, - None, # scale_q, not supported yet - None, # scale_k - None, # scale_v + None, + None, + None, rope_scale, rope_theta, - 0, # token_pos_in_items_len + 0, paged_kv_cache, self._num_qo_heads, self._num_kv_heads, @@ -1285,16 +1199,13 @@ def run( self._max_kv_len, sinks, ] - self._cached_module.paged_run(*run_args) else: - # trtllm-gen does not need plan info if self._backend == "trtllm-gen" and self._plan_info is None: plan_info: List[int] = [] else: plan_info = self._plan_info assert plan_info is not None, "plan info is not initialized" - run_args = [ self._float_workspace_buffer, self._int_workspace_buffer, @@ -1311,32 +1222,28 @@ def run( window_left, enable_pdl, ] - if self._jit_module is not None: run_args.extend(list(args)) else: run_args += [ - _get_cache_alibi_slopes_buf(q.shape[1], q.device), + _get_cache_alibi_slopes_buf(tuple(q.shape)[1], q.place), logits_soft_cap, sm_scale, rope_scale, rope_theta, ] - self._cached_module.run(*run_args) if v_scale is not None: - # TODO(Zihao): fused into kernel if is_float8(out): - out = (out.to(torch.float32) * v_scale).to(out.dtype) + out = (out.to("float32") * v_scale).to(out.dtype) else: out *= v_scale - return (out, lse) if return_lse else out def forward_return_lse( self, - q: torch.Tensor, - paged_kv_cache: torch.Tensor, + q: paddle.Tensor, + paged_kv_cache: paddle.Tensor, pos_encoding_mode: str = "NONE", q_scale: Optional[float] = None, k_scale: Optional[float] = None, @@ -1346,8 +1253,8 @@ def forward_return_lse( sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Warning: this function is deprecated, please use :meth:`run_return_lse` instead.""" + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Warning: this function is deprecated, please use :meth:`run_return_lse` instead.""" self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left 
self._logits_soft_cap = logits_soft_cap @@ -1366,12 +1273,12 @@ def forward_return_lse( run_return_lse = functools.partialmethod(run, return_lse=True) def end_forward(self) -> None: - r"""Warning: this function is deprecated and has no effect.""" + """Warning: this function is deprecated and has no effect.""" pass class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWrapper): - r"""CUDAGraph-compatible Wrapper class for decode attention with paged kv-cache (first + """CUDAGraph-compatible Wrapper class for decode attention with paged kv-cache (first proposed in `vLLM `_) for batch of requests. Note that this wrapper may not be as efficient as :class:`BatchDecodeWithPagedKVCacheWrapper` @@ -1391,14 +1298,14 @@ class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWra def __init__( self, - workspace_buffer: torch.Tensor, - indptr_buffer: torch.Tensor, - indices_buffer: torch.Tensor, - last_page_len_buffer: torch.Tensor, + workspace_buffer: paddle.Tensor, + indptr_buffer: paddle.Tensor, + indices_buffer: paddle.Tensor, + last_page_len_buffer: paddle.Tensor, kv_layout: str = "NHD", use_tensor_cores: bool = False, ) -> None: - r"""Constructor of :class:`BatchDecodeWithPagedKVCacheWrapper`. + """Constructor of :class:`BatchDecodeWithPagedKVCacheWrapper`. Parameters ---------- @@ -1441,21 +1348,21 @@ def __init__( class BatchDecodeMlaWithPagedKVCacheWrapper: - r"""Warning: this class is deprecated and will be removed in a future release. + """Warning: this class is deprecated and will be removed in a future release. Please use :class:`flashinfer.mla.BatchMLAPagedAttentionWrapper` instead, which provides a more efficient and general MLA implementation that supports decode and incremental prefill. """ def __init__( self, - float_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, use_cuda_graph: bool = False, use_tensor_cores: bool = False, - paged_kv_indptr_buffer: Optional[torch.Tensor] = None, - paged_kv_indices_buffer: Optional[torch.Tensor] = None, - paged_kv_last_page_len_buffer: Optional[torch.Tensor] = None, + paged_kv_indptr_buffer: Optional[paddle.Tensor] = None, + paged_kv_indices_buffer: Optional[paddle.Tensor] = None, + paged_kv_last_page_len_buffer: Optional[paddle.Tensor] = None, ) -> None: - r"""Constructor of :class:`BatchDecodeWithPagedKVCacheWrapper`. + """Constructor of :class:`BatchDecodeWithPagedKVCacheWrapper`. Parameters ---------- @@ -1490,27 +1397,23 @@ def __init__( Only needed when ``use_cuda_graph`` is ``True``. 
""" self._float_workspace_buffer = float_workspace_buffer - self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) - self._pin_memory_int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), - dtype=torch.uint8, - pin_memory=True, - device="cpu", + self.device = float_workspace_buffer.place + self._int_workspace_buffer = paddle.empty( + shape=(8 * 1024 * 1024,), dtype="uint8" ) - + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=(8 * 1024 * 1024,), dtype="uint8" + ).pin_memory() if use_cuda_graph: - if not torch.is_tensor(paged_kv_indptr_buffer): + if not paddle.is_tensor(x=paged_kv_indptr_buffer): raise ValueError( "paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_indices_buffer): + if not paddle.is_tensor(x=paged_kv_indices_buffer): raise ValueError( "paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_last_page_len_buffer): + if not paddle.is_tensor(x=paged_kv_last_page_len_buffer): raise ValueError( "paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode" ) @@ -1521,7 +1424,6 @@ def __init__( ) else: self._fixed_batch_size = 0 - self._use_tensor_cores = use_tensor_cores self._paged_kv_indptr_buf = paged_kv_indptr_buffer self._paged_kv_indices_buf = paged_kv_indices_buffer @@ -1537,9 +1439,9 @@ def use_tensor_cores(self) -> bool: return self._use_tensor_cores def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor + self, float_workspace_buffer: paddle.Tensor, int_workspace_buffer: paddle.Tensor ) -> None: - r"""Reset the workspace buffer. + """Reset the workspace buffer. Parameters ---------- @@ -1553,30 +1455,28 @@ def reset_workspace_buffer( """ self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=tuple(self._int_workspace_buffer.shape), dtype=self._int_workspace_buffer.dtype, - device="cpu", - pin_memory=True, - ) + ).pin_memory() def plan( self, - indptr: torch.Tensor, - indices: torch.Tensor, - last_page_len: torch.Tensor, + indptr: paddle.Tensor, + indices: paddle.Tensor, + last_page_len: paddle.Tensor, num_qo_heads: int, head_dim_compressed_kv: int, page_size: int, sm_scale: float, window_left: int = -1, logits_soft_cap: Optional[float] = None, - data_type: Union[str, torch.dtype] = "float16", - q_data_type: Optional[Union[str, torch.dtype]] = None, + data_type: Union[str, paddle.dtype] = "float16", + q_data_type: Optional[Union[str, paddle.dtype]] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ) -> None: - r"""Plan batch decode for given problem specification. + """Plan batch decode for given problem specification. Parameters ---------- @@ -1602,7 +1502,7 @@ def plan( The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, + :math:`\\texttt{logits_soft_cap} \\times \\mathrm{tanh}(x / \\texttt{logits_soft_cap})`, where :math:`x` is the input logits. data_type : Union[str, torch.dtype] The data type of the paged kv cache. Defaults to ``float16``. 
@@ -1619,12 +1519,10 @@ def plan( batch_size = len(last_page_len) if logits_soft_cap is None: logits_soft_cap = 0.0 - if self.is_cuda_graph_enabled: if batch_size != self._fixed_batch_size: raise ValueError( - "The batch size should be fixed in cudagraph mode, the runtime batch size {} " - " mismatches the batch size set during initialization {}".format( + "The batch size should be fixed in cudagraph mode, the runtime batch size {} mismatches the batch size set during initialization {}".format( batch_size, self._fixed_batch_size ) ) @@ -1632,21 +1530,18 @@ def plan( raise ValueError( "The size of indices should be less than or equal to the allocated buffer" ) - self._paged_kv_indptr_buf.copy_(indptr) + paddle.assign(indptr, output=self._paged_kv_indptr_buf) self._paged_kv_indices_buf[: len(indices)] = indices - self._paged_kv_last_page_len_buf.copy_(last_page_len) + paddle.assign(last_page_len, output=self._paged_kv_last_page_len_buf) else: self._paged_kv_indptr_buf = indptr.to(self.device) self._paged_kv_indices_buf = indices.to(self.device) self._paged_kv_last_page_len_buf = last_page_len.to(self.device) - data_type = canonicalize_torch_dtype(data_type) if not q_data_type: q_data_type = data_type q_data_type = canonicalize_torch_dtype(q_data_type) - indptr_host = indptr.to("cpu") - self._cached_module = get_batch_decode_mla_module( q_data_type, data_type, @@ -1654,8 +1549,8 @@ def plan( indptr.dtype, head_dim_compressed_kv, num_qo_heads, - window_left != -1, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap + window_left != -1, + logits_soft_cap > 0, self._use_tensor_cores, ) self._plan_info = self._cached_module.plan( @@ -1668,7 +1563,6 @@ def plan( page_size, self.is_cuda_graph_enabled, ) - self._sm_scale = sm_scale self._window_left = window_left self._logits_soft_cap = logits_soft_cap @@ -1677,19 +1571,19 @@ def plan( def run( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - paged_ckv_cache: torch.Tensor, - paged_kpe_cache: torch.Tensor, + q_nope: paddle.Tensor, + q_pe: paddle.Tensor, + paged_ckv_cache: paddle.Tensor, + paged_kpe_cache: paddle.Tensor, q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: bool = False, - enable_pdl: bool = False, # fake placeholder (sm80) - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Compute batch decode attention between query and paged kv cache. + enable_pdl: bool = False, + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Compute batch decode attention between query and paged kv cache. 
Parameters ---------- @@ -1741,29 +1635,25 @@ def run( if rope_scale is None: rope_scale = 1.0 if rope_theta is None: - rope_theta = 1e4 - + rope_theta = 10000.0 device = self.device if out is None: - out = torch.empty_like(q_nope, device=device) + out = paddle.empty_like(x=q_nope) else: check_shape_dtype_device( - out, q_nope.shape, q_nope.dtype, q_nope.device, "out" + out, tuple(q_nope.shape), q_nope.dtype, q_nope.place, "out" ) - if return_lse: if lse is None: - lse = torch.empty( - (q_nope.size(0), q_nope.size(1)), - dtype=torch.float32, - device=device, + lse = paddle.empty( + shape=(q_nope.shape[0], q_nope.shape[1]), dtype="float32" ) else: check_shape_dtype_device( lse, - (q_nope.size(0), q_nope.size(1)), + (q_nope.shape[0], q_nope.shape[1]), q_nope.dtype, - q_nope.device, + q_nope.place, "lse", ) self._cached_module.run( @@ -1789,7 +1679,6 @@ def run( out = [out, lse] if return_lse else [out] if v_scale is not None: out[0] *= v_scale - return tuple(out) if return_lse else out[0] run_return_lse = functools.partialmethod(run, return_lse=True) @@ -1806,31 +1695,28 @@ def __init__(self) -> None: def _paged_run( self, - query: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - workspace_buffer: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, + query: paddle.Tensor, + k_cache: paddle.Tensor, + v_cache: paddle.Tensor, + workspace_buffer: paddle.Tensor, + block_tables: paddle.Tensor, + seq_lens: paddle.Tensor, max_seq_len: int, - bmm1_scale: float, # todo(Yingyi): add dynamic scale tensor later + bmm1_scale: float, bmm2_scale: float, window_left: int = -1, enable_pdl: bool = None, - out: Optional[torch.Tensor] = None, - sinks: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + out: Optional[paddle.Tensor] = None, + sinks: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: if out is None: - out = torch.empty_like(query) + out = paddle.empty_like(x=query) if self._sm_count is None: - self._sm_count = get_device_sm_count(query.device) - + self._sm_count = get_device_sm_count(query.place) self._op.trtllm_paged_attention_decode( out, - None, # fp4 output not supported in wrapper api yet. - query.unsqueeze( - 1 - ), # [B, 1, H, D], no MTP here so second dim is 1 # todo(Yingyi): add MTP?? 
+ None, + query.unsqueeze(axis=1), k_cache, v_cache, workspace_buffer, @@ -1839,9 +1725,9 @@ def _paged_run( max_seq_len, bmm1_scale, bmm2_scale, - -1, # o_sf_scale - -1, # o_sf_vec_size - 0, # o_sf_start_index + -1, + -1, + 0, window_left, self._sm_count, enable_pdl, @@ -1868,44 +1754,44 @@ def get_trtllm_gen_decode_module(*args): ), ) def paged_run( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - paged_k_cache: torch.Tensor, - paged_v_cache: torch.Tensor, - qo_indptr: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + paged_k_cache: paddle.Tensor, + paged_v_cache: paddle.Tensor, + qo_indptr: paddle.Tensor, + paged_kv_indptr: paddle.Tensor, + paged_kv_indices: paddle.Tensor, + paged_kv_last_page_len: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], mask_mode: int, layout: int, window_left: int, enable_pdl: bool, - maybe_custom_mask: Optional[torch.Tensor], - maybe_mask_indptr: Optional[torch.Tensor], - maybe_alibi_slopes: Optional[torch.Tensor], - maybe_prefix_len_ptr: Optional[torch.Tensor], - maybe_token_pos_in_items_ptr: Optional[torch.Tensor], - maybe_max_item_len_ptr: Optional[torch.Tensor], + maybe_custom_mask: Optional[paddle.Tensor], + maybe_mask_indptr: Optional[paddle.Tensor], + maybe_alibi_slopes: Optional[paddle.Tensor], + maybe_prefix_len_ptr: Optional[paddle.Tensor], + maybe_token_pos_in_items_ptr: Optional[paddle.Tensor], + maybe_max_item_len_ptr: Optional[paddle.Tensor], logits_soft_cap: float, sm_scale: float, - scale_q: Optional[torch.Tensor], - scale_k: Optional[torch.Tensor], - scale_v: Optional[torch.Tensor], + scale_q: Optional[paddle.Tensor], + scale_k: Optional[paddle.Tensor], + scale_v: Optional[paddle.Tensor], rope_scale: float, rope_theta: float, token_pos_in_items_len: int, - paged_kv_cache: Optional[torch.Tensor] = None, + paged_kv_cache: Optional[paddle.Tensor] = None, num_qo_heads: Optional[int] = None, num_kv_heads: Optional[int] = None, - block_tables: Optional[torch.Tensor] = None, - kv_lens_buffer: Optional[torch.Tensor] = None, + block_tables: Optional[paddle.Tensor] = None, + kv_lens_buffer: Optional[paddle.Tensor] = None, page_size: Optional[int] = None, max_kv_len: Optional[int] = None, - sinks: Optional[torch.Tensor] = None, + sinks: Optional[paddle.Tensor] = None, ) -> None: assert maybe_lse is None assert paged_kv_cache is not None @@ -1917,7 +1803,7 @@ def paged_run( assert max_kv_len is not None assert enable_pdl is not None o = module._paged_run( - q.contiguous(), # NOTE(Siyuan): without contiguous, the result is incorrect + q.contiguous(), paged_k_cache, paged_v_cache, int_workspace_buffer, @@ -1925,7 +1811,7 @@ def paged_run( kv_lens_buffer, max_kv_len, sm_scale, - 1.0, # NOTE(Siyuan): update this to expose bmm2 scale + 1.0, window_left, enable_pdl, out=o, @@ -1934,71 +1820,64 @@ def paged_run( @register_fake_op(f"flashinfer::{uri}_paged_run") def _fake_paged_run( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - paged_k_cache: torch.Tensor, - paged_v_cache: torch.Tensor, - qo_indptr: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_indices: torch.Tensor, - 
paged_kv_last_page_len: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + paged_k_cache: paddle.Tensor, + paged_v_cache: paddle.Tensor, + qo_indptr: paddle.Tensor, + paged_kv_indptr: paddle.Tensor, + paged_kv_indices: paddle.Tensor, + paged_kv_last_page_len: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], mask_mode: int, layout: int, window_left: int, enable_pdl: bool, - maybe_custom_mask: Optional[torch.Tensor], - maybe_mask_indptr: Optional[torch.Tensor], - maybe_alibi_slopes: Optional[torch.Tensor], - maybe_prefix_len_ptr: Optional[torch.Tensor], - maybe_token_pos_in_items_ptr: Optional[torch.Tensor], - maybe_max_item_len_ptr: Optional[torch.Tensor], + maybe_custom_mask: Optional[paddle.Tensor], + maybe_mask_indptr: Optional[paddle.Tensor], + maybe_alibi_slopes: Optional[paddle.Tensor], + maybe_prefix_len_ptr: Optional[paddle.Tensor], + maybe_token_pos_in_items_ptr: Optional[paddle.Tensor], + maybe_max_item_len_ptr: Optional[paddle.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, rope_theta: float, token_pos_in_items_len: int, - paged_kv_cache: Optional[torch.Tensor] = None, + paged_kv_cache: Optional[paddle.Tensor] = None, num_qo_heads: Optional[int] = None, num_kv_heads: Optional[int] = None, - block_tables: Optional[torch.Tensor] = None, - kv_lens_buffer: Optional[torch.Tensor] = None, + block_tables: Optional[paddle.Tensor] = None, + kv_lens_buffer: Optional[paddle.Tensor] = None, page_size: Optional[int] = None, max_kv_len: Optional[int] = None, - sinks: Optional[torch.Tensor] = None, + sinks: Optional[paddle.Tensor] = None, ) -> None: pass - # Register the module. - # - # Note that plan is not part of model logic. It should not be included in - # Cuda Graph or torch.compile. So, we don't provide a torch library for plan. - return SimpleNamespace( - plan=module._plan, - paged_run=paged_run, - ) + return SimpleNamespace(plan=module._plan, paged_run=paged_run) def trtllm_batch_decode_with_kv_cache( - query: torch.Tensor, - kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - workspace_buffer: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, + query: paddle.Tensor, + kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], + workspace_buffer: paddle.Tensor, + block_tables: paddle.Tensor, + seq_lens: paddle.Tensor, max_seq_len: int, bmm1_scale: float, - bmm2_scale: float, # todo(Yingyi): add dynamic scale tensor later + bmm2_scale: float, window_left: int = -1, - out: Optional[Union[torch.Tensor, FP4Tensor]] = None, - out_dtype: Optional[Union[torch.dtype, str]] = None, + out: Optional[Union[paddle.Tensor, FP4Tensor]] = None, + out_dtype: Optional[Union[paddle.dtype, str]] = None, o_sf_scale: Optional[float] = None, o_sf_vec_size: Optional[int] = None, - sinks: Optional[List[torch.Tensor]] = None, + sinks: Optional[List[paddle.Tensor]] = None, enable_pdl: bool = None, -) -> Union[torch.Tensor, FP4Tensor]: +) -> Union[paddle.Tensor, FP4Tensor]: """ Parameters ---------- @@ -2055,96 +1934,76 @@ def trtllm_batch_decode_with_kv_cache( out : Union[torch.Tensor, FP4Tensor] output torch.Tensor or FP4Tensor. 
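When nvfp4 output is requested, the allocation logic further below packs two FP4 values per `uint8` and keeps one FP8 scale factor per `o_sf_vec_size` (16) elements, with the scale-factor matrix padded to multiples of 128 rows and 4 columns. A pure-Python sketch of those shape computations (the helper names mirror the `ceil_div`/`round_up` utilities used in this module):

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def round_up(a: int, b: int) -> int:
    return ceil_div(a, b) * b

def nvfp4_out_shapes(query_shape, o_sf_vec_size: int = 16):
    # query is [num_tokens, num_heads, head_dim]; two FP4 values share one uint8,
    # so the packed output halves the last dimension (rounded up).
    num_tokens, num_heads, head_dim = query_shape
    out_shape = (num_tokens, num_heads, ceil_div(head_dim, 2))
    # One FP8 scale per o_sf_vec_size elements, padded to 128 rows / 4 columns.
    sf_shape = (
        round_up(num_tokens, 128),
        round_up(num_heads * head_dim // o_sf_vec_size, 4),
    )
    return out_shape, sf_shape

print(nvfp4_out_shapes((7, 32, 128)))  # ((7, 32, 64), (128, 256))
```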
""" - enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl - + enable_pdl = device_support_pdl(query.place) if enable_pdl is None else enable_pdl if isinstance(kv_cache, tuple): k_cache, v_cache = kv_cache + elif tuple(kv_cache.shape)[1] == 1: + k_cache, v_cache = kv_cache, kv_cache else: - if kv_cache.shape[1] == 1: - k_cache, v_cache = kv_cache, kv_cache - else: - assert kv_cache.shape[1] == 2, ( - "When kv_cache is a single tensor, the second dimension must be 1 or 2" - ) - # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...]) - # it doesn't change underlying storage - k_cache, v_cache = kv_cache.unbind(dim=1) - + assert ( + tuple(kv_cache.shape)[1] == 2 + ), "When kv_cache is a single tensor, the second dimension must be 1 or 2" + k_cache, v_cache = kv_cache.unbind(axis=1) run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode - sm_count = get_device_sm_count(query.device) - - if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)): - assert query.dtype == torch.float8_e4m3fn, ( - "query must be fp8 when out_dtype is nvfp4." - ) + sm_count = get_device_sm_count(query.place) + if out_dtype == "nvfp4" or out_dtype is None and isinstance(out, FP4Tensor): + assert ( + query.dtype == paddle.float8_e4m3fn + ), "query must be fp8 when out_dtype is nvfp4." assert o_sf_scale is not None assert o_sf_vec_size in [None, 16], "only o_sf_vec_size = 16 is supported" o_sf_vec_size = o_sf_vec_size or 16 - - fp4_out_shape = query.shape[:-1] + (ceil_div(query.shape[-1], 2),) - + fp4_out_shape = tuple(query.shape)[:-1] + (ceil_div(tuple(query.shape)[-1], 2),) if isinstance(out, FP4Tensor): - fp4_out_scale_shape = ( - out.scale.shape[0], - round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4), + fp4_out_scale_shape = out.scale.shape[0], round_up( + tuple(query.shape)[1] * tuple(query.shape)[2] // o_sf_vec_size, 4 ) out_scale_factor = out.scale o_sf_start_index = out.scale_start_index out = out.data elif out is None: - fp4_out_scale_shape = ( - round_up(query.shape[0], 128), - round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4), + fp4_out_scale_shape = round_up(tuple(query.shape)[0], 128), round_up( + tuple(query.shape)[1] * tuple(query.shape)[2] // o_sf_vec_size, 4 ) - out_scale_factor = torch.empty( - fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device + out_scale_factor = paddle.empty( + shape=fp4_out_scale_shape, dtype=paddle.float8_e4m3fn ) o_sf_start_index = 0 - out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device) + out = paddle.empty(shape=fp4_out_shape, dtype="uint8") else: raise ValueError(f"Invalid out: {out}") - - assert isinstance(out, torch.Tensor) - - # Use uint8 as the container dtype to compliant with next fp4 gemm. - check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out") - + assert isinstance(out, paddle.Tensor) + check_shape_dtype_device(out, fp4_out_shape, "uint8", query.place, "out") check_shape_dtype_device( out_scale_factor, fp4_out_scale_shape, - torch.float8_e4m3fn, - query.device, + paddle.float8_e4m3fn, + query.place, "out_scale_factor", ) - - # Check o_sf_start_index is valid if ( o_sf_start_index < 0 - or o_sf_start_index + out.shape[0] > out_scale_factor.shape[0] + or o_sf_start_index + tuple(out.shape)[0] > tuple(out_scale_factor.shape)[0] ): raise ValueError( - f"o_sf_start_index is out of the valid range of out_scale_factor. 
" - f"o_sf_start_index={o_sf_start_index}, out.shape[0]={out.shape[0]}, " - f"out_scale_factor.shape[0]={out_scale_factor.shape[0]}" + f"o_sf_start_index is out of the valid range of out_scale_factor. o_sf_start_index={o_sf_start_index}, out.shape[0]={tuple(out.shape)[0]}, out_scale_factor.shape[0]={tuple(out_scale_factor.shape)[0]}" ) - - elif isinstance(out_dtype, torch.dtype) or out_dtype is None: + elif isinstance(out_dtype, paddle.dtype) or out_dtype is None: assert o_sf_scale is None assert o_sf_vec_size is None out_scale_factor = None o_sf_start_index = 0 out_dtype = out_dtype or query.dtype - out = out if out is not None else torch.empty_like(query, dtype=out_dtype) - if out_dtype not in (query.dtype, torch.float16, torch.bfloat16): + out = out if out is not None else paddle.empty_like(x=query, dtype=out_dtype) + if out_dtype not in (query.dtype, "float16", "bfloat16"): raise ValueError(f"Unsupported out_dtype: {out_dtype}") - check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out") + check_shape_dtype_device(out, tuple(query.shape), out_dtype, query.place, "out") else: raise ValueError(f"Invalid out_dtype: {out_dtype}") - run_func( out, out_scale_factor, - query.unsqueeze(1), # [B, 1, H, D], no MTP here so second dim is 1 + query.unsqueeze(axis=1), k_cache, v_cache, workspace_buffer, @@ -2161,11 +2020,10 @@ def trtllm_batch_decode_with_kv_cache( enable_pdl, sinks, ) - return ( out if out_dtype != "nvfp4" - else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) + else FP4Tensor(out, out_scale_factor, o_sf_start_index, tuple(query.shape)) ) @@ -2188,18 +2046,13 @@ def _check_trtllm_gen_mla_shape( raise ValueError(f"Expected kv_lora_rank == 512, got {kv_lora_rank}") if qk_rope_head_dim != 64: raise ValueError(f"Expected qk_rope_head_dim == 64, got {qk_rope_head_dim}") - - B_q, Q_len, H, D_q = query.shape - D_ckv = kv_cache.shape[3] - # if H != 128: - # raise ValueError(f"Expected 128 heads for query, got {H}") - # todo(Yingyi): should we check num_heads == 128? Is this deepseek only? 
+ B_q, Q_len, H, D_q = tuple(query.shape) + D_ckv = tuple(kv_cache.shape)[3] if D_q != D_ckv or D_q != 576: raise ValueError( f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}" ) - - B_block_table, block_num = page_table.shape + B_block_table, block_num = tuple(page_table.shape) block_size = page_size if B_q != B_block_table: raise ValueError( @@ -2207,28 +2060,28 @@ def _check_trtllm_gen_mla_shape( ) if block_num % (128 / block_size) != 0: raise ValueError( - f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}" + f"Expected block_num % (128 / block_size) == 0, got block_num={block_num!r} and block_size={block_size!r}" ) def trtllm_batch_decode_with_kv_cache_mla( - query: torch.Tensor, - kv_cache: torch.Tensor, - workspace_buffer: torch.Tensor, + query: paddle.Tensor, + kv_cache: paddle.Tensor, + workspace_buffer: paddle.Tensor, qk_nope_head_dim: int, kv_lora_rank: int, qk_rope_head_dim: int, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, + block_tables: paddle.Tensor, + seq_lens: paddle.Tensor, max_seq_len: int, - out: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, bmm1_scale: Optional[float] = 1.0, bmm2_scale: Optional[float] = 1.0, - bmm1_scale_log2_tensor: Optional[torch.Tensor] = None, - bmm2_scale_tensor: Optional[torch.Tensor] = None, - sinks: Optional[List[torch.Tensor]] = None, + bmm1_scale_log2_tensor: Optional[paddle.Tensor] = None, + bmm2_scale_tensor: Optional[paddle.Tensor] = None, + sinks: Optional[List[paddle.Tensor]] = None, enable_pdl: bool = None, -) -> torch.Tensor: +) -> paddle.Tensor: """ Parameters: query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length. @@ -2265,16 +2118,12 @@ def trtllm_batch_decode_with_kv_cache_mla( - Currently, only fp8 tensor core operation supports this mode. When both are provided, the dynamic scale factor tensors will be used. """ - enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl + enable_pdl = device_support_pdl(query.place) if enable_pdl is None else enable_pdl run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode - sm_count = get_device_sm_count(query.device) - - block_size = kv_cache.size(-2) - if ( - block_size != 32 and block_size != 64 - ): # todo(Yingyi): add support for more block sizes? 
+ sm_count = get_device_sm_count(query.place) + block_size = kv_cache.shape[-2] + if block_size != 32 and block_size != 64: raise ValueError(f"Supported block_size are 32 and 64, got {block_size}") - _check_trtllm_gen_mla_shape( query, kv_cache, @@ -2284,30 +2133,22 @@ def trtllm_batch_decode_with_kv_cache_mla( block_tables, block_size, ) - if out is None: - out_shape = query.shape[:-1] + (kv_lora_rank,) - out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device) + out_shape = tuple(query.shape)[:-1] + (kv_lora_rank,) + out = paddle.empty(shape=out_shape, dtype="bfloat16") else: - batch_size, _, num_q_heads, _ = query.shape + batch_size, _, num_q_heads, _ = tuple(query.shape) check_shape_dtype_device( - out, - [batch_size, num_q_heads, kv_lora_rank], - torch.bfloat16, - query.device, - "out", + out, [batch_size, num_q_heads, kv_lora_rank], "bfloat16", query.place, "out" ) - if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None: - # dynamic scale factors - if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn: + if query.dtype != paddle.float8_e4m3fn or kv_cache.dtype != paddle.float8_e4m3fn: raise ValueError( "Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation" ) - run_func( out, - None, # fp4 output not supported in wrapper api yet. + None, query, kv_cache, kv_cache, @@ -2317,10 +2158,10 @@ def trtllm_batch_decode_with_kv_cache_mla( max_seq_len, bmm1_scale, bmm2_scale, - -1, # o_sf_scale - -1, # o_sf_vec_size - 0, # o_sf_start_index - -1, # window_left + -1, + -1, + 0, + -1, sm_count, enable_pdl, sinks, diff --git a/flashinfer/deep_gemm.py b/flashinfer/deep_gemm.py index 177eafac00..9f25bb1225 100644 --- a/flashinfer/deep_gemm.py +++ b/flashinfer/deep_gemm.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ MIT License @@ -21,9 +27,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - -# Imported and adapted from DeepGEMM. 
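Before launching, `_check_trtllm_gen_mla_shape` above pins down the MLA geometry: the query and cache head dimension must be 576 (kv_lora_rank 512 plus qk_rope_head_dim 64), the cache block size must be 32 or 64, and the number of blocks per page-table row must be a multiple of 128 / block_size. A pure-Python restatement of those constraints (the leading kv_cache dims are placeholders):

```python
def check_mla_decode_shapes(query_shape, kv_cache_shape, page_table_shape,
                            kv_lora_rank=512, qk_rope_head_dim=64):
    # query: [batch, q_len_per_request, num_heads, head_dim_qk]
    # kv_cache: [..., block_size, head_dim_ckv]; page_table: [batch, blocks_per_seq]
    head_dim_qk = kv_lora_rank + qk_rope_head_dim        # 576 = concat(q_nope, q_rope)
    assert query_shape[-1] == head_dim_qk == kv_cache_shape[-1]
    block_size = kv_cache_shape[-2]
    assert block_size in (32, 64)                        # supported block sizes
    batch, blocks_per_seq = page_table_shape
    assert batch == query_shape[0]                       # page-table rows == batch size
    assert blocks_per_seq % (128 // block_size) == 0     # page-table padding requirement

check_mla_decode_shapes((8, 1, 128, 576), (1024, 1, 64, 576), (8, 32))
print("MLA shape constraints satisfied")
```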
- import ctypes import enum import functools @@ -32,7 +35,6 @@ from typing import Any, Dict, Optional, Tuple import cuda.bindings.driver as cbd -import torch from .artifacts import ArtifactPath, MetaInfoHash from .cuda_utils import checkCudaErrors @@ -48,9 +50,9 @@ class GemmType(enum.Enum): def __str__(self) -> str: return { - 0: "GemmType::Normal", - 1: "GemmType::GroupedContiguous", - 2: "GemmType::GroupedMasked", + (0): "GemmType::Normal", + (1): "GemmType::GroupedContiguous", + (2): "GemmType::GroupedMasked", }[self.value] @@ -65,7 +67,7 @@ def non_contiguous_dim(self): return -2 if self.value == 0 else -1 def __str__(self) -> str: - return {0: "cute::UMMA::Major::K", 1: "cute::UMMA::Major::MN"}[self.value] + return {(0): "cute::UMMA::Major::K", (1): "cute::UMMA::Major::MN"}[self.value] class MajorTypeCD(enum.Enum): @@ -76,31 +78,27 @@ def non_contiguous_dim(self): return -2 if self.value == 0 else -1 -def major_check(t: torch.Tensor): +def major_check(t: paddle.Tensor): assert t.dim() in (2, 3) if t.dim() == 3: - assert t.stride(0) == t.size(-2) * t.size(-1), ( - "Grouped dimension cannot have abnormal stride" - ) - assert t.stride(-2) == 1 or t.stride(-1) == 1 + assert ( + t.get_strides()[0] == t.shape[-2] * t.shape[-1] + ), "Grouped dimension cannot have abnormal stride" + assert t.get_strides()[-2] == 1 or t.get_strides()[-1] == 1 -def get_major_type_ab(t: torch.Tensor): +def get_major_type_ab(t: paddle.Tensor): major_check(t) - return MajorTypeAB.KMajor if t.stride(-1) == 1 else MajorTypeAB.MNMajor + return MajorTypeAB.KMajor if t.get_strides()[-1] == 1 else MajorTypeAB.MNMajor -def get_major_type_cd(t: torch.Tensor): +def get_major_type_cd(t: paddle.Tensor): major_check(t) - return MajorTypeCD.NMajor if t.stride(-1) == 1 else MajorTypeCD.MMajor + return MajorTypeCD.NMajor if t.get_strides()[-1] == 1 else MajorTypeCD.MMajor -def get_element_size(dtype: torch.dtype): - return { - torch.float8_e4m3fn: 1, - torch.bfloat16: 2, - torch.float: 4, - }[dtype] +def get_element_size(dtype: paddle.dtype): + return {paddle.float8_e4m3fn: 1, "bfloat16": 2, "float32": 4}[dtype] def get_m_alignment_for_contiguous_layout(): @@ -114,91 +112,79 @@ def get_tma_aligned_size(x: int, element_size: int) -> int: return round_up(x, alignment) -def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor: - # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA - assert x.dtype == torch.float and x.dim() in (2, 3) - - # First, convert into UE8M0 `uint8_t` - ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8) - - # Second, make padded packed tensors - mn, k = x.shape[-2], x.shape[-1] +def get_col_major_tma_aligned_packed_tensor(x: paddle.Tensor) -> paddle.Tensor: + assert x.dtype == "float32" and x.dim() in (2, 3) + ue8m0_tensor = (x.view("int32") >> 23).to("uint8") + mn, k = tuple(x.shape)[-2], tuple(x.shape)[-1] remove_dim = False if x.dim() == 2: - x, remove_dim = x.unsqueeze(0), True - b = x.shape[0] + x, remove_dim = x.unsqueeze(axis=0), True + b = tuple(x.shape)[0] aligned_mn = get_tma_aligned_size(mn, 4) aligned_k = round_up(k, 4) - padded = torch.zeros((b, aligned_mn, aligned_k), device=x.device, dtype=torch.uint8) + padded = paddle.zeros(shape=(b, aligned_mn, aligned_k), dtype="uint8") padded[:, :mn, :k] = ue8m0_tensor - padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn, aligned_k // 4) - - # Finally, transpose - transposed = torch.transpose( - torch.empty((b, aligned_k // 4, aligned_mn), device=x.device, dtype=torch.int), - 1, - 2, + 
padded = padded.view(-1).view(dtype="int32").view(b, aligned_mn, aligned_k // 4) + transposed = paddle.transpose( + x=paddle.empty(shape=(b, aligned_k // 4, aligned_mn), dtype="int32"), + perm=dim2perm( + paddle.empty(shape=(b, aligned_k // 4, aligned_mn), dtype="int32").ndim, + 1, + 2, + ), ) transposed[:, :, :] = padded aligned_x = transposed[:, :mn, :] - return aligned_x.squeeze(0) if remove_dim else aligned_x + return aligned_x.squeeze(axis=0) if remove_dim else aligned_x def check_sf_layout( - sf: torch.Tensor, + sf: paddle.Tensor, mn: int, k: int, gran: Tuple[int, int], num_groups: Optional[int], tma_stride_check: bool = False, - type_check: Optional[torch.dtype] = None, -) -> torch.Tensor: - # Type check + type_check: Optional[paddle.dtype] = None, +) -> paddle.Tensor: if type_check is not None: assert sf.dtype == type_check - - # Always do shape checks - assert sf.dtype in (torch.float, torch.int) + assert sf.dtype in ("float32", "int32") assert sf.dim() == int(num_groups is not None) + 2 if num_groups is not None: - assert sf.size(-3) == num_groups - assert sf.size(-2) == ceil_div(mn, gran[0]) - assert sf.size(-1) == ceil_div(k, gran[1] * (1 if sf.dtype == torch.float else 4)) - - # TMA stride checks: TMA aligned and MN-major + assert sf.shape[-3] == num_groups + assert sf.shape[-2] == ceil_div(mn, gran[0]) + assert sf.shape[-1] == ceil_div(k, gran[1] * (1 if sf.dtype == "float32" else 4)) if tma_stride_check: if num_groups is not None: - assert sf.stride(-3) == sf.stride(-1) * sf.size(-1) - assert sf.stride(-2) == 1 - assert sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()) - + assert sf.get_strides()[-3] == sf.get_strides()[-1] * sf.shape[-1] + assert sf.get_strides()[-2] == 1 + assert sf.get_strides()[-1] == get_tma_aligned_size(mn, sf.element_size()) return sf def transform_sf_into_required_layout( - sf: torch.Tensor, + sf: paddle.Tensor, mn: int, k: int, recipe: Tuple[int, int, int], num_groups: Optional[int] = None, is_sfa: bool = False, ): - gran = (recipe[0 if is_sfa else 1], recipe[2]) - + gran = recipe[0 if is_sfa else 1], recipe[2] should_skip_transform = ( - sf.dtype == torch.int and gran == (1, 128) and get_device_arch() == "100a" - ) or (sf.dtype == torch.int and gran == (128, 128) and get_device_arch() == "100a") - + sf.dtype == "int32" + and gran == (1, 128) + and get_device_arch() == "100a" + or sf.dtype == "int32" + and gran == (128, 128) + and get_device_arch() == "100a" + ) if not should_skip_transform: - # Pre-transform checks check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups) - - # (FP32, 1, 128) on Hopper: transform to TMA-aligned and MN-major - if sf.dtype == torch.float and gran == (1, 128) and get_device_arch() == "90a": + if sf.dtype == "float32" and gran == (1, 128) and get_device_arch() == "90a": raise NotImplementedError - - # (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major - if sf.dtype == torch.float and gran == (1, 128) and get_device_arch() == "100a": + if sf.dtype == "float32" and gran == (1, 128) and get_device_arch() == "100a": sf = get_col_major_tma_aligned_packed_tensor(sf) return check_sf_layout( sf, @@ -207,16 +193,12 @@ def transform_sf_into_required_layout( gran=(1, 128), num_groups=num_groups, tma_stride_check=True, - type_check=torch.int, + type_check="int32", ) - - # (FP32, 128, 128) on Hopper: no need to transform, check shape and whatever-major - if sf.dtype == torch.float and gran == (128, 128) and get_device_arch() == "90a": + if sf.dtype == "float32" and gran == (128, 128) 
and get_device_arch() == "90a": raise NotImplementedError - - # (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major - if sf.dtype == torch.float and gran == (128, 128) and get_device_arch() == "100a": - sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128) + if sf.dtype == "float32" and gran == (128, 128) and get_device_arch() == "100a": + sf = sf.index_select(axis=-2, index=paddle.arange(end=mn) // 128) sf = get_col_major_tma_aligned_packed_tensor(sf) return check_sf_layout( sf, @@ -225,11 +207,9 @@ def transform_sf_into_required_layout( gran=(1, 128), num_groups=num_groups, tma_stride_check=True, - type_check=torch.int, + type_check="int32", ) - if should_skip_transform: - # TODO: add transpose kernel if SF layout is not satisfied return check_sf_layout( sf, mn=mn, @@ -237,15 +217,16 @@ def transform_sf_into_required_layout( gran=(1, 128), num_groups=num_groups, tma_stride_check=True, - type_check=torch.int, + type_check="int32", ) - - AssertionError(f"Unknown cases: {sf.dtype=}, {gran=}, arch={get_device_arch()}") + AssertionError( + f"Unknown cases: sf.dtype={sf.dtype!r}, gran={gran!r}, arch={get_device_arch()}" + ) @functools.lru_cache(maxsize=None) def get_device_arch(): - major, minor = torch.cuda.get_device_capability() + major, minor = paddle.device.cuda.get_device_capability() suffix = "a" if major >= 9 else "" return f"{major * 10 + minor}{suffix}" @@ -258,22 +239,19 @@ def hash_to_hex(s: str) -> str: @functools.lru_cache(maxsize=None) def must_be_k_major() -> bool: - return { - "90a": True, - "100a": False, - }[get_device_arch()] + return {"90a": True, "100a": False}[get_device_arch()] @functools.lru_cache(maxsize=None) def get_default_recipe( - sfa_dtype: torch.dtype, sfb_dtype: torch.dtype + sfa_dtype: paddle.dtype, sfb_dtype: paddle.dtype ) -> Tuple[int, int, int]: - assert sfa_dtype in (torch.float, torch.int) + assert sfa_dtype in ("float32", "int32") return { - ("90a", torch.float): (1, 128, 128), - ("100a", torch.float): (1, 128, 128), - ("100a", torch.int): (1, 1, 128), - }[(get_device_arch(), sfb_dtype)] + ("90a", "float32"): (1, 128, 128), + ("100a", "float32"): (1, 128, 128), + ("100a", "int32"): (1, 1, 128), + }[get_device_arch(), sfb_dtype] class MulticastConfig: @@ -282,12 +260,10 @@ def __init__(self, num_multicast: int, is_multicast_on_a: bool): self.is_multicast_on_a = is_multicast_on_a def get_ab_load_block_m(self, block_m: int): - # NOTES: this for >= SM100 only assert get_device_arch() != "90a" return block_m // (self.num_multicast if self.is_multicast_on_a else 1) def get_ab_load_block_n(self, block_n: int): - # NOTES: this for >= SM100 only assert get_device_arch() != "90a" return block_n // (1 if self.is_multicast_on_a else self.num_multicast) @@ -303,11 +279,8 @@ def __init__( self.smem_size = smem_size self.swizzle_a_mode = swizzle_a_mode self.swizzle_b_mode = swizzle_b_mode - # NOTES: sometimes the default swizzling pattern maybe not compatible (e.g., FP32 output) self.swizzle_cd_mode = swizzle_cd_mode - # TODO: swizzle SF as well self.swizzle_sf_mode = 0 - assert self.swizzle_a_mode != 0 assert self.swizzle_b_mode != 0 assert self.swizzle_cd_mode > 16 @@ -328,31 +301,28 @@ def is_multicast_legal( def get_swizzle_mode(block_size: int, elem_size: int) -> int: - # `> 0` means interleaving - # 16B actually means non-swizzling (but interleaving) for mode_bytes in (128, 64, 32, 16): - if (block_size * elem_size) % mode_bytes == 0: + if block_size * elem_size % mode_bytes == 0: return mode_bytes 
AssertionError("Invalid mode") return 0 -def get_sf_aligned_block_sizes(block_m: int, block_n: int, ab_dtype: torch.dtype): +def get_sf_aligned_block_sizes(block_m: int, block_n: int, ab_dtype: paddle.dtype): num_utccp_aligned_elems = 128 assert block_m % num_utccp_aligned_elems == 0 return { - torch.bfloat16: (0, 0), - torch.float8_e4m3fn: ( + "bfloat16": (0, 0), + paddle.float8_e4m3fn: ( round_up(block_m, num_utccp_aligned_elems), round_up(block_n, num_utccp_aligned_elems), ), }[ab_dtype] -def is_tmem_size_legal(block_m: int, block_n: int, ab_dtype: torch.float): - # M waves or epilogue stages (* 2), SFA and SFB +def is_tmem_size_legal(block_m: int, block_n: int, ab_dtype: "float32"): sf_block_m, sf_block_n = get_sf_aligned_block_sizes(block_m, block_n, ab_dtype) - return ((2 * block_n) + (sf_block_m // 32) + (sf_block_n // 32)) <= 512 + return 2 * block_n + sf_block_m // 32 + sf_block_n // 32 <= 512 def get_smem_config( @@ -362,16 +332,14 @@ def get_smem_config( major_a: MajorTypeAB, major_b: MajorTypeAB, major_d: MajorTypeCD, - ab_dtype: torch.dtype, - cd_dtype: torch.dtype, + ab_dtype: paddle.dtype, + cd_dtype: paddle.dtype, num_stages: int, multicast_config: MulticastConfig, ) -> SharedMemoryConfig: assert major_d == MajorTypeCD.NMajor - ab_elem_size = get_element_size(ab_dtype) cd_elem_size = get_element_size(cd_dtype) - load_block_m = multicast_config.get_ab_load_block_m(block_m) load_block_n = multicast_config.get_ab_load_block_n(block_n) swizzle_a_mode = get_swizzle_mode( @@ -383,30 +351,15 @@ def get_smem_config( swizzle_cd_mode = get_swizzle_mode( block_n if major_d == MajorTypeCD.NMajor else block_m, cd_elem_size ) - - # 2 stages of STSM and TMA store - # TODO: consider other layouts layout_ad_m = 128 smem_d = min(block_m, layout_ad_m) * swizzle_cd_mode * 2 - - # A/B shared memory smem_a_per_stage = load_block_m * block_k * ab_elem_size smem_b_per_stage = load_block_n * block_k * ab_elem_size - - # SF shared memory must be aligned to UTCCP - # Each stage must prefetch next 4 stages' SF (including the current) sf_block_m, sf_block_n = get_sf_aligned_block_sizes(block_m, block_n, ab_dtype) smem_scales_a_per_stage = sf_block_m * 4 smem_scales_b_per_stage = sf_block_n * 4 - - # TODO: remove SF barriers for BF16 GEMMs - # TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers, accumulation full barrier - # NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages - # NOTES: cases without accumulation will not use the accumulation full barrier smem_barrier = num_stages * 8 * 3 + 2 * 8 * 2 + 8 smem_tmem_ptr = 4 - - # Sum them up smem_size = 0 smem_size += smem_d smem_size += num_stages * smem_a_per_stage @@ -415,7 +368,6 @@ def get_smem_config( smem_size += num_stages * smem_scales_b_per_stage smem_size += smem_barrier smem_size += smem_tmem_ptr - return SharedMemoryConfig( smem_size, swizzle_a_mode, swizzle_b_mode, swizzle_cd_mode ) @@ -431,48 +383,38 @@ def get_best_configs( major_a: MajorTypeAB, major_b: MajorTypeAB, major_d: MajorTypeCD, - ab_dtype: torch.dtype, - cd_dtype: torch.dtype, + ab_dtype: paddle.dtype, + cd_dtype: paddle.dtype, num_sms: int, ) -> Tuple[int, int, int, int, int, MulticastConfig, SharedMemoryConfig]: - assert ab_dtype == torch.float8_e4m3fn - assert cd_dtype in (torch.bfloat16, torch.float) - - # `BLOCK_M` and `BLOCK_N` are selected according to MMA instructions + assert ab_dtype == paddle.float8_e4m3fn + assert cd_dtype in ("bfloat16", "float32") block_ms: Tuple[int, ...] 
= None if gemm_type == GemmType.GroupedContiguous: block_ms = (get_m_alignment_for_contiguous_layout(),) else: block_ms = (128,) if major_b == MajorTypeAB.KMajor else (128, 256) - # NOTES: some `% 32 == 16` cases are not compatible with 2-CTA TMA swizzling block_ns = ( tuple(range(16, 257, 16)) if major_b == MajorTypeAB.KMajor else tuple(range(32, 257, 32)) ) - - # `BLOCK_K` is selected in a fixed manner block_k = 128 // get_element_size(ab_dtype) - fix_wave_saturate = lambda x: num_sms if x == 0 else x - get_num_waves = lambda bm, bn: ( - ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) + get_num_waves = ( + lambda bm, bn: ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None ) get_last_wave_util = lambda bm, bn: fix_wave_saturate( - (ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms + ceil_div(m, bm) * ceil_div(n, bn) * num_groups % num_sms ) - - # Decide block sizes by waves - # TODO: move block size search into `common.py` best_block_m, best_block_n = None, None for block_m in block_ms: for block_n in block_ns: success = False - num_waves, best_num_waves = ( - get_num_waves(block_m, block_n), - get_num_waves(best_block_m, best_block_n), + num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves( + best_block_m, best_block_n ) if ( best_block_m is None @@ -481,29 +423,20 @@ def get_best_configs( ): success = True elif num_waves == best_num_waves: - # Check last wave utilization util = get_last_wave_util(block_m, block_n) best_util = get_last_wave_util(best_block_m, best_block_n) success = util > best_util if util == best_util: - # Case 1: same `block_m`, smaller `block_n` (wasted) success |= block_m == best_block_m and block_n < best_block_n - # Case 2: same `block_n`, smaller `block_m` (wasted) success |= block_n == best_block_n and block_m < best_block_m - # Case 3: different for both `block_m` and `block_n`, larger `block_n` is better success |= block_m != best_block_m and block_n > best_block_n success &= is_tmem_size_legal(block_m, block_n, ab_dtype) best_block_m, best_block_n = ( (block_m, block_n) if success else (best_block_m, best_block_n) ) assert best_block_m is not None and best_block_n is not None - - # Decide the number of TMA multicasts and whether broadcast on A best_multicast_config = MulticastConfig(1, True) - - # Try to multicast on the larger block side first is_legal = { - # TODO: support other `tcgen05` layouts "A": False, "B": is_multicast_legal(m, best_block_m, 2, num_sms, True) and gemm_type == GemmType.Normal, @@ -512,10 +445,6 @@ def get_best_configs( if m >= 512 and is_legal[i]: best_multicast_config = MulticastConfig(2, i == "A") break - - # Always pick the longest one - # NOTES: for double B scales, the best number of stages may be reduced - # TODO: move stage search into `common.py` best_num_stages, best_smem_config, sm100_capacity = None, None, 232448 stage_candidates = tuple( filter(lambda s: s <= max(k // 128, 1), (8, 7, 6, 5, 4, 3, 2, 1)) @@ -538,10 +467,6 @@ def get_best_configs( break assert best_smem_config is not None assert best_num_stages is not None - - # Recompute the minimal number of SMs required - # NOTES: less L2 cache usage and less GPU frequency drop - # TODO: move min SM fix into `common.py` num_waves = get_num_waves(best_block_m, best_block_n) num_min_sms = ceil_div( ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves @@ -551,7 +476,6 @@ def get_best_configs( * best_multicast_config.num_multicast ) assert num_min_sms <= num_sms - return ( 
num_min_sms, best_block_m, @@ -564,34 +488,33 @@ def get_best_configs( tmap_type_map: Dict[Any, str] = { - torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, - torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32, - torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64, - torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, - torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32, - torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64, - torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32, - torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, - torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, - torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, - torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + "int8": cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + "int16": cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, + "int32": cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32, + "int64": cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64, + "uint8": cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, +>>>>>> torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16, +>>>>>> torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32, +>>>>>> torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64, + "float32": cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32, + "float16": cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16, + "bfloat16": cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + paddle.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, + paddle.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, +>>>>>> paddle.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, +>>>>>> paddle.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8, } - swizzle_type_map = { - 0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, - 16: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, - 32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B, - 64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B, - 128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, + (0): cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, + (16): cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE, + (32): cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B, + (64): cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B, + (128): cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B, } def make_tma_xd_desc( - t: torch.Tensor, + t: paddle.Tensor, gmem_dims: Tuple[cbd.cuuint64_t, ...], gmem_strides: Tuple[cbd.cuuint64_t, ...], smem_dims: Tuple[cbd.cuuint32_t, ...], @@ -600,7 +523,6 @@ def make_tma_xd_desc( num_dims = len(gmem_dims) assert len(gmem_strides) == num_dims - 1 assert len(smem_dims) == num_dims - tensor_dtype = tmap_type_map[t.dtype] tensor_map = checkCudaErrors( cbd.cuTensorMapEncodeTiled( @@ -621,7 +543,7 @@ def make_tma_xd_desc( def make_tma_2d_desc( - t: torch.Tensor, + t: paddle.Tensor, gmem_inner_dim: int, gmem_outer_dim: int, smem_inner_dim: int, @@ -629,14 +551,12 @@ def make_tma_2d_desc( 
gmem_outer_stride: int, swizzle_mode: int, ) -> cbd.CUtensorMap: - # For swizzling pattern, multiple TMAs are required if swizzle_mode != 0: assert swizzle_mode % t.element_size() == 0 smem_inner_dim = swizzle_mode // t.element_size() - - gmem_dims = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim)) + gmem_dims = cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim) gmem_strides = (cbd.cuuint64_t(gmem_outer_stride * t.element_size()),) - smem_dims = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim)) + smem_dims = cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim) return make_tma_xd_desc( t, gmem_dims, gmem_strides, smem_dims, swizzle_type_map[swizzle_mode] ) @@ -644,7 +564,7 @@ def make_tma_2d_desc( def make_tma_a_desc( major_type: MajorTypeAB, - t: torch.Tensor, + t: paddle.Tensor, shape_m: int, shape_k: int, block_m: int, @@ -655,7 +575,6 @@ def make_tma_a_desc( ) -> cbd.CUtensorMap: if num_groups > 1: assert major_type == MajorTypeAB.KMajor - gmem_inner_dim, gmem_outer_dim = (shape_k, shape_m * num_groups)[ :: major_type.shape_direction() ] @@ -673,7 +592,7 @@ def make_tma_a_desc( def make_tma_b_desc( major_type: MajorTypeAB, - t: torch.Tensor, + t: paddle.Tensor, shape_n: int, shape_k: int, block_n: int, @@ -682,11 +601,9 @@ def make_tma_b_desc( num_groups: int, swizzle_mode: int, ) -> cbd.CUtensorMap: - # `num_groups` is always applied into the outer dimensions io_shapes = (shape_k, shape_n)[:: major_type.shape_direction()] - gmem_inner_dim, gmem_outer_dim = (io_shapes[0], io_shapes[1] * num_groups) + gmem_inner_dim, gmem_outer_dim = io_shapes[0], io_shapes[1] * num_groups smem_inner_dim, smem_outer_dim = (block_k, block_n)[:: major_type.shape_direction()] - return make_tma_2d_desc( t, gmem_inner_dim, @@ -700,7 +617,7 @@ def make_tma_b_desc( def make_tma_cd_desc( major_type: MajorTypeCD, - t: torch.Tensor, + t: paddle.Tensor, shape_m: int, shape_n: int, block_m: int, @@ -710,9 +627,6 @@ def make_tma_cd_desc( swizzle_mode: int, ) -> cbd.CUtensorMap: assert major_type == MajorTypeCD.NMajor - - # Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode` - # bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required layout_ad_m = 128 return make_tma_2d_desc( t, @@ -727,7 +641,7 @@ def make_tma_cd_desc( def make_tma_sf_desc( major_type: MajorTypeAB, - t: torch.Tensor, + t: paddle.Tensor, shape_mn: int, shape_k: int, block_mn: int, @@ -736,11 +650,7 @@ def make_tma_sf_desc( swizzle_mode: int, ) -> cbd.CUtensorMap: assert major_type == MajorTypeAB.MNMajor - - # TODO: maybe swizzle SF as well assert swizzle_mode == 0 - - # Make TMA aligned to 16 bytes shape_mn = get_tma_aligned_size(shape_mn, t.element_size()) return make_tma_2d_desc( t, @@ -753,15 +663,12 @@ def make_tma_sf_desc( ) -# Map some common Python types into C types pytypes_to_ctypes = { - True: "true", - False: "false", - torch.bfloat16: "cutlass::bfloat16_t", - torch.float: "float", + (True): "true", + (False): "false", + "bfloat16": "cutlass::bfloat16_t", + "float32": "float", } - - RUNTIME_CACHE = {} @@ -771,13 +678,10 @@ def __init__(self, path: str, symbol: str) -> None: self.lib = None self.kernel = None self.symbol = symbol - # Store a reference to the cleanup function to avoid import issues during shutdown self._cleanup_func = cbd.cuLibraryUnload def __call__(self, **kwargs) -> cbd.CUresult: - # Load CUBIN if self.kernel is None: - # Load CUBIN path = bytes(self.path, encoding="utf-8") self.lib = checkCudaErrors( 
cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0) @@ -785,8 +689,6 @@ def __call__(self, **kwargs) -> cbd.CUresult: self.kernel = checkCudaErrors( cbd.cuLibraryGetKernel(self.lib, bytes(self.symbol, encoding="utf-8")) ) - - # noinspection PyArgumentList return self.launch(self.kernel, kwargs) def __del__(self) -> None: @@ -794,12 +696,11 @@ def __del__(self) -> None: try: checkCudaErrors(self._cleanup_func(self.lib)) except Exception as e: - # Ignore any errors during shutdown print(f"Failed to delete SM100FP8GemmRuntime: {e}") @staticmethod def generate(kwargs: Dict[str, Any]) -> str: - assert kwargs["CD_DTYPE_T"] in (torch.bfloat16, torch.float) + assert kwargs["CD_DTYPE_T"] in ("bfloat16", "float32") code = f""" #ifdef __CUDACC_RTC__ #include @@ -814,33 +715,32 @@ def generate(kwargs: Dict[str, Any]) -> str: static void __instantiate_kernel() {{ auto ptr = reinterpret_cast(&sm100_fp8_gemm_1d1d_impl< - {kwargs["MAJOR_A"]}, - {kwargs["MAJOR_B"]}, - {kwargs["M"] if "m" in kwargs["COMPILED_DIMS"] else 0}, - {kwargs["N"] if "n" in kwargs["COMPILED_DIMS"] else 0}, - {kwargs["K"] if "k" in kwargs["COMPILED_DIMS"] else 0}, - {kwargs["BLOCK_M"]}, - {kwargs["BLOCK_N"]}, - {kwargs["BLOCK_K"]}, - {kwargs["NUM_GROUPS"]}, - {kwargs["SWIZZLE_A_MODE"]}, - {kwargs["SWIZZLE_B_MODE"]}, - {kwargs["SWIZZLE_CD_MODE"]}, - {kwargs["NUM_STAGES"]}, - {kwargs["NUM_LAST_STAGES"]}, - {kwargs["NUM_NON_EPILOGUE_THREADS"]}, - {kwargs["NUM_EPILOGUE_THREADS"]}, - {kwargs["NUM_MULTICAST"]}, - {pytypes_to_ctypes[kwargs["IS_MULTICAST_ON_A"]]}, - {kwargs["GEMM_TYPE"]}, - {pytypes_to_ctypes[kwargs["WITH_ACCUMULATION"]]}, - {pytypes_to_ctypes[kwargs["CD_DTYPE_T"]]} + {kwargs['MAJOR_A']}, + {kwargs['MAJOR_B']}, + {kwargs['M'] if 'm' in kwargs['COMPILED_DIMS'] else 0}, + {kwargs['N'] if 'n' in kwargs['COMPILED_DIMS'] else 0}, + {kwargs['K'] if 'k' in kwargs['COMPILED_DIMS'] else 0}, + {kwargs['BLOCK_M']}, + {kwargs['BLOCK_N']}, + {kwargs['BLOCK_K']}, + {kwargs['NUM_GROUPS']}, + {kwargs['SWIZZLE_A_MODE']}, + {kwargs['SWIZZLE_B_MODE']}, + {kwargs['SWIZZLE_CD_MODE']}, + {kwargs['NUM_STAGES']}, + {kwargs['NUM_LAST_STAGES']}, + {kwargs['NUM_NON_EPILOGUE_THREADS']}, + {kwargs['NUM_EPILOGUE_THREADS']}, + {kwargs['NUM_MULTICAST']}, + {pytypes_to_ctypes[kwargs['IS_MULTICAST_ON_A']]}, + {kwargs['GEMM_TYPE']}, + {pytypes_to_ctypes[kwargs['WITH_ACCUMULATION']]}, + {pytypes_to_ctypes[kwargs['CD_DTYPE_T']]} >); }}; """ return code - # noinspection PyMethodOverriding @staticmethod def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: checkCudaErrors( @@ -851,7 +751,6 @@ def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: cbd.CUdevice(kwargs["DEVICE_INDEX"]), ) ) - attr_val = cbd.CUlaunchAttributeValue() attr_val.clusterDim.x = kwargs["NUM_MULTICAST"] attr_val.clusterDim.y = 1 @@ -859,7 +758,6 @@ def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: attr = cbd.CUlaunchAttribute() attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION attr.value = attr_val - config = cbd.CUlaunchConfig() config.numAttrs = 1 config.attrs = [attr] @@ -873,7 +771,6 @@ def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: config.blockDimZ = 1 config.sharedMemBytes = kwargs["SMEM_SIZE"] config.hStream = kwargs["STREAM"] - arg_values = ( kwargs["GROUPED_LAYOUT"].data_ptr(), kwargs["M"], @@ -937,31 +834,37 @@ def m_grouped_fp8_gemm_nt_contiguous_static_kwargs_gen( major_b: MajorTypeAB, major_d: MajorTypeCD, compiled_dims: str, - output_dtype: torch.dtype, + 
output_dtype: paddle.dtype, ): - num_sms = torch.cuda.get_device_properties(device="cuda").multi_processor_count - num_sms, block_m, block_n, block_k, num_stages, multicast_config, smem_config = ( - get_best_configs( - GemmType.GroupedContiguous, - m, - n, - k, - num_groups, - major_a, - major_b, - major_d, - torch.float8_e4m3fn, - output_dtype, - num_sms, - ) + num_sms = paddle.device.cuda.get_device_properties( + device="gpu" + ).multi_processor_count + ( + num_sms, + block_m, + block_n, + block_k, + num_stages, + multicast_config, + smem_config, + ) = get_best_configs( + GemmType.GroupedContiguous, + m, + n, + k, + num_groups, + major_a, + major_b, + major_d, + paddle.float8_e4m3fn, + output_dtype, + num_sms, ) kwargs = { - # Templated or runtime arguments according to the `COMPILED_DIMS` "COMPILED_DIMS": compiled_dims, "M": m, "N": n, "K": aligned_k, - # Templated arguments "GEMM_TYPE": GemmType.GroupedContiguous, "NUM_NON_EPILOGUE_THREADS": 128, "NUM_EPILOGUE_THREADS": 128, @@ -993,34 +896,29 @@ def m_grouped_fp8_gemm_nt_contiguous_static_kwargs_gen( def m_grouped_fp8_gemm_nt_contiguous_kwargs_gen( - a: torch.Tensor, - sfa: torch.Tensor, - b: torch.Tensor, - sfb: torch.Tensor, - d: torch.Tensor, - m_indices: torch.Tensor, + a: paddle.Tensor, + sfa: paddle.Tensor, + b: paddle.Tensor, + sfb: paddle.Tensor, + d: paddle.Tensor, + m_indices: paddle.Tensor, major_a: MajorTypeAB, major_b: MajorTypeAB, compiled_dims: str, ): - m, k = a.shape - num_groups, n, _ = b.shape + m, k = tuple(a.shape) + num_groups, n, _ = tuple(b.shape) major_d = MajorTypeCD.NMajor - - # K must be aligned to 128 aligned_k = round_up(k, 128) ( - ( - num_sms, - block_m, - block_n, - block_k, - num_stages, - multicast_config, - smem_config, - ), - static_kwargs, - ) = m_grouped_fp8_gemm_nt_contiguous_static_kwargs_gen( + num_sms, + block_m, + block_n, + block_k, + num_stages, + multicast_config, + smem_config, + ), static_kwargs = m_grouped_fp8_gemm_nt_contiguous_static_kwargs_gen( m, n, k, @@ -1032,7 +930,6 @@ def m_grouped_fp8_gemm_nt_contiguous_kwargs_gen( compiled_dims, d.dtype, ) - # NOTES: you cannot distinguish groups for A, SFA, and D tensor_map_a = make_tma_a_desc( major_a, a, @@ -1040,7 +937,7 @@ def m_grouped_fp8_gemm_nt_contiguous_kwargs_gen( k, multicast_config.get_ab_load_block_m(block_m), block_k, - a.stride(major_a.non_contiguous_dim()), + a.get_strides()[major_a.non_contiguous_dim()], num_groups=1, swizzle_mode=smem_config.swizzle_a_mode, ) @@ -1051,7 +948,7 @@ def m_grouped_fp8_gemm_nt_contiguous_kwargs_gen( k, multicast_config.get_ab_load_block_n(block_n), block_k, - b.stride(major_b.non_contiguous_dim()), + b.get_strides()[major_b.non_contiguous_dim()], num_groups=num_groups, swizzle_mode=smem_config.swizzle_b_mode, ) @@ -1062,7 +959,7 @@ def m_grouped_fp8_gemm_nt_contiguous_kwargs_gen( n, block_m, block_n, - d.stride(major_d.non_contiguous_dim()), + d.get_strides()[major_d.non_contiguous_dim()], num_groups=1, swizzle_mode=smem_config.swizzle_cd_mode, ) @@ -1088,7 +985,6 @@ def m_grouped_fp8_gemm_nt_contiguous_kwargs_gen( ) all_kwargs = { **static_kwargs, - # Runtime arguments "GROUPED_LAYOUT": m_indices, "NUM_SMS": num_sms, "SMEM_SIZE": smem_config.smem_size, @@ -1098,19 +994,19 @@ def m_grouped_fp8_gemm_nt_contiguous_kwargs_gen( "TENSOR_MAP_SFB": tensor_map_sfb, "TENSOR_MAP_C": tensor_map_d, "TENSOR_MAP_D": tensor_map_d, - "STREAM": torch.cuda.current_stream().cuda_stream, + "STREAM": paddle.device.current_stream().cuda_stream, "DEVICE_INDEX": d.device.index, } return static_kwargs, all_kwargs def 
m_grouped_fp8_gemm_nt_contiguous_sm100( - a: torch.Tensor, - sfa: torch.Tensor, - b: torch.Tensor, - sfb: torch.Tensor, - d: torch.Tensor, - m_indices: torch.Tensor, + a: paddle.Tensor, + sfa: paddle.Tensor, + b: paddle.Tensor, + sfb: paddle.Tensor, + d: paddle.Tensor, + m_indices: paddle.Tensor, major_a: MajorTypeAB, major_b: MajorTypeAB, compiled_dims: str, @@ -1118,7 +1014,6 @@ def m_grouped_fp8_gemm_nt_contiguous_sm100( static_kwargs, all_kwargs = m_grouped_fp8_gemm_nt_contiguous_kwargs_gen( a, sfa, b, sfb, d, m_indices, major_a, major_b, compiled_dims ) - # Generate, build and run the kernel code = SM100FP8GemmRuntime.generate(static_kwargs) runtime = load("fp8_m_grouped_gemm", code) runtime(**all_kwargs) @@ -1135,34 +1030,39 @@ def m_grouped_fp8_gemm_nt_masked_static_kwargs_gen( major_b: MajorTypeAB, major_d: MajorTypeCD, compiled_dims: str, - output_dtype: torch.dtype, + output_dtype: paddle.dtype, ): - num_sms = torch.cuda.get_device_properties(device="cuda").multi_processor_count - num_sms, block_m, block_n, block_k, num_stages, multicast_config, smem_config = ( - get_best_configs( - GemmType.GroupedMasked, - expected_m, - n, - k, - num_groups, - major_a, - major_b, - major_d, - torch.float8_e4m3fn, - output_dtype, - num_sms, - ) + num_sms = paddle.device.cuda.get_device_properties( + device="gpu" + ).multi_processor_count + ( + num_sms, + block_m, + block_n, + block_k, + num_stages, + multicast_config, + smem_config, + ) = get_best_configs( + GemmType.GroupedMasked, + expected_m, + n, + k, + num_groups, + major_a, + major_b, + major_d, + paddle.float8_e4m3fn, + output_dtype, + num_sms, ) if num_groups > 1: assert m % block_m == 0 - kwargs = { - # Templated or runtime arguments according to the `COMPILED_DIMS` "COMPILED_DIMS": compiled_dims, "M": m, "N": n, "K": aligned_k, - # Templated arguments "GEMM_TYPE": GemmType.GroupedMasked, "NUM_NON_EPILOGUE_THREADS": 128, "NUM_EPILOGUE_THREADS": 128, @@ -1194,35 +1094,30 @@ def m_grouped_fp8_gemm_nt_masked_static_kwargs_gen( def m_grouped_fp8_gemm_nt_masked_kwargs_gen( - a: torch.Tensor, - sfa: torch.Tensor, - b: torch.Tensor, - sfb: torch.Tensor, - d: torch.Tensor, - masked_m: torch.Tensor, + a: paddle.Tensor, + sfa: paddle.Tensor, + b: paddle.Tensor, + sfb: paddle.Tensor, + d: paddle.Tensor, + masked_m: paddle.Tensor, expected_m: int, major_a: MajorTypeAB, major_b: MajorTypeAB, compiled_dims: str, ): - num_groups, m, k = a.shape - _, n, _ = b.shape + num_groups, m, k = tuple(a.shape) + _, n, _ = tuple(b.shape) major_d = MajorTypeCD.NMajor - - # K must be aligned to 128 aligned_k = round_up(k, 128) ( - ( - num_sms, - block_m, - block_n, - block_k, - num_stages, - multicast_config, - smem_config, - ), - static_kwargs, - ) = m_grouped_fp8_gemm_nt_masked_static_kwargs_gen( + num_sms, + block_m, + block_n, + block_k, + num_stages, + multicast_config, + smem_config, + ), static_kwargs = m_grouped_fp8_gemm_nt_masked_static_kwargs_gen( m, n, k, @@ -1235,7 +1130,6 @@ def m_grouped_fp8_gemm_nt_masked_kwargs_gen( compiled_dims, d.dtype, ) - tensor_map_a = make_tma_a_desc( major_a, a, @@ -1243,7 +1137,7 @@ def m_grouped_fp8_gemm_nt_masked_kwargs_gen( k, multicast_config.get_ab_load_block_m(block_m), block_k, - a.stride(major_a.non_contiguous_dim()), + a.get_strides()[major_a.non_contiguous_dim()], num_groups, smem_config.swizzle_a_mode, ) @@ -1254,7 +1148,7 @@ def m_grouped_fp8_gemm_nt_masked_kwargs_gen( k, multicast_config.get_ab_load_block_n(block_n), block_k, - b.stride(major_b.non_contiguous_dim()), + 
b.get_strides()[major_b.non_contiguous_dim()], num_groups, smem_config.swizzle_b_mode, ) @@ -1265,7 +1159,7 @@ def m_grouped_fp8_gemm_nt_masked_kwargs_gen( n, block_m, block_n, - d.stride(major_d.non_contiguous_dim()), + d.get_strides()[major_d.non_contiguous_dim()], num_groups, smem_config.swizzle_cd_mode, ) @@ -1291,7 +1185,6 @@ def m_grouped_fp8_gemm_nt_masked_kwargs_gen( ) all_kwargs = { **static_kwargs, - # Runtime arguments "GROUPED_LAYOUT": masked_m, "NUM_SMS": num_sms, "SMEM_SIZE": smem_config.smem_size, @@ -1301,19 +1194,19 @@ def m_grouped_fp8_gemm_nt_masked_kwargs_gen( "TENSOR_MAP_SFB": tensor_map_sfb, "TENSOR_MAP_C": tensor_map_d, "TENSOR_MAP_D": tensor_map_d, - "STREAM": torch.cuda.current_stream().cuda_stream, + "STREAM": paddle.device.current_stream().cuda_stream, "DEVICE_INDEX": d.device.index, } return static_kwargs, all_kwargs def m_grouped_fp8_gemm_nt_masked_sm100( - a: torch.Tensor, - sfa: torch.Tensor, - b: torch.Tensor, - sfb: torch.Tensor, - d: torch.Tensor, - masked_m: torch.Tensor, + a: paddle.Tensor, + sfa: paddle.Tensor, + b: paddle.Tensor, + sfb: paddle.Tensor, + d: paddle.Tensor, + masked_m: paddle.Tensor, expected_m: int, major_a: MajorTypeAB, major_b: MajorTypeAB, @@ -1322,60 +1215,46 @@ def m_grouped_fp8_gemm_nt_masked_sm100( static_kwargs, all_kwargs = m_grouped_fp8_gemm_nt_masked_kwargs_gen( a, sfa, b, sfb, d, masked_m, expected_m, major_a, major_b, compiled_dims ) - # Generate, build and run the kernel code = SM100FP8GemmRuntime.generate(static_kwargs) runtime = load("fp8_m_grouped_gemm", code) runtime(**all_kwargs) def m_grouped_fp8_gemm_nt_contiguous( - a_fp8: Tuple[torch.Tensor, torch.Tensor], - b_fp8: Tuple[torch.Tensor, torch.Tensor], - d: torch.Tensor, - m_indices: torch.Tensor, + a_fp8: Tuple[paddle.Tensor, paddle.Tensor], + b_fp8: Tuple[paddle.Tensor, paddle.Tensor], + d: paddle.Tensor, + m_indices: paddle.Tensor, recipe: Optional[Tuple[int, int, int]] = None, compiled_dims: str = "nk", ) -> None: - # Compiled dims can be upper cases compiled_dims = compiled_dims.lower() - - # NOTES: shape must be `[M, K] @ [G, N, K].mT` major_a = get_major_type_ab(a_fp8[0]) major_b = get_major_type_ab(b_fp8[0]) assert major_a == MajorTypeAB.KMajor if must_be_k_major(): assert major_b == MajorTypeAB.KMajor assert m_indices.is_contiguous() - a, sfa = a_fp8 b, sfb = b_fp8 - m, k = a.shape - num_groups, n, k_ = b.shape - m_, n_ = d.shape - m__ = m_indices.numel() - - # Type and shape checks + m, k = tuple(a.shape) + num_groups, n, k_ = tuple(b.shape) + m_, n_ = tuple(d.shape) + m__ = m_indices.size assert m == m_ == m__ and n == n_ and k == k_ assert n > 0 and k > 0 and num_groups > 0 - assert a.dtype == torch.float8_e4m3fn - assert b.dtype == torch.float8_e4m3fn - assert d.dtype == torch.bfloat16 - assert m_indices.dtype == torch.int32 - - # D must be N-major + assert a.dtype == paddle.float8_e4m3fn + assert b.dtype == paddle.float8_e4m3fn + assert d.dtype == "bfloat16" + assert m_indices.dtype == "int32" assert get_major_type_cd(d) == MajorTypeCD.NMajor - - # Do nothing if the problem is empty if m == 0: return - - # Transform SFA and SFB into compute-required layout recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe sfa = transform_sf_into_required_layout(sfa, mn=m, k=k, recipe=recipe, is_sfa=True) sfb = transform_sf_into_required_layout( sfb, mn=n, k=k, recipe=recipe, num_groups=num_groups, is_sfa=False ) - impl = { "100a": functools.partial( m_grouped_fp8_gemm_nt_contiguous_sm100, @@ -1388,43 +1267,33 @@ def 
m_grouped_fp8_gemm_nt_contiguous( def m_grouped_fp8_gemm_nt_masked( - a_fp8: Tuple[torch.Tensor, torch.Tensor], - b_fp8: Tuple[torch.Tensor, torch.Tensor], - d: torch.Tensor, - masked_m: torch.Tensor, + a_fp8: Tuple[paddle.Tensor, paddle.Tensor], + b_fp8: Tuple[paddle.Tensor, paddle.Tensor], + d: paddle.Tensor, + masked_m: paddle.Tensor, expected_m: int, recipe: Optional[Tuple[int, int, int]] = None, compiled_dims: str = "nk", ) -> None: - # Compiled dims can be upper cases compiled_dims = compiled_dims.lower() - - # NOTES: shape must be `[G, M, K] @ [G, N, K].mT` major_a = get_major_type_ab(a_fp8[0]) major_b = get_major_type_ab(b_fp8[0]) assert major_a == major_b == MajorTypeAB.KMajor assert masked_m.is_contiguous() - a, sfa = a_fp8 b, sfb = b_fp8 - num_groups, m, k = a.shape - num_groups_, n, k_ = b.shape - num_groups__, m_, n_ = d.shape - num_groups___ = masked_m.numel() - - # Type and shape checks + num_groups, m, k = tuple(a.shape) + num_groups_, n, k_ = tuple(b.shape) + num_groups__, m_, n_ = tuple(d.shape) + num_groups___ = masked_m.size assert num_groups == num_groups_ == num_groups__ == num_groups___ assert m == m_ and n == n_ and k == k_ assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 - assert a.dtype == torch.float8_e4m3fn - assert b.dtype == torch.float8_e4m3fn - assert d.dtype == torch.bfloat16 - assert masked_m.dtype == torch.int32 - - # D must be N-major + assert a.dtype == paddle.float8_e4m3fn + assert b.dtype == paddle.float8_e4m3fn + assert d.dtype == "bfloat16" + assert masked_m.dtype == "int32" assert get_major_type_cd(d) == MajorTypeCD.NMajor - - # Transform SFA and SFB into compute-required layout recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe sfa = transform_sf_into_required_layout( sfa, mn=m, k=k, recipe=recipe, num_groups=num_groups, is_sfa=True @@ -1432,7 +1301,6 @@ def m_grouped_fp8_gemm_nt_masked( sfb = transform_sf_into_required_layout( sfb, mn=n, k=k, recipe=recipe, num_groups=num_groups, is_sfa=False ) - impl = { "100a": functools.partial( m_grouped_fp8_gemm_nt_masked_sm100, @@ -1451,9 +1319,9 @@ def __init__(self, sha256: str): def init_indices(self): indice_path = ArtifactPath.DEEPGEMM + "kernel_map" - assert get_cubin(indice_path, self.sha256, file_extension=".json"), ( - "cubin kernel map file not found, nor downloaded with matched sha256" - ) + assert get_cubin( + indice_path, self.sha256, file_extension=".json" + ), "cubin kernel map file not found, nor downloaded with matched sha256" path = FLASHINFER_CUBIN_DIR / f"{indice_path}.json" assert path.exists() with open(path, "r") as f: diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index c83928efb7..60e157b5d1 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,30 +19,23 @@ See the License for the specific language governing permissions and limitations under the License. 
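A minimal illustrative sketch (not part of the patch) of the shape and element-count idioms the converted checks above rely on: Paddle's Tensor.shape is a Python list, which the patch wraps in tuple() before unpacking, and the .size attribute replaces torch's .numel().

import paddle

a = paddle.zeros(shape=[8, 128], dtype="float32")
m, k = tuple(a.shape)      # shape is a list in Paddle; tuple() keeps the unpacking explicit
assert (m, k) == (8, 128)
assert a.size == 8 * 128   # .size is an int attribute, the counterpart of torch's .numel()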
""" - import functools from enum import Enum from types import SimpleNamespace from typing import List, Optional, Tuple -import torch - from .jit import JitSpec from .jit import env as jit_env -from .jit import gen_jit_spec, sm100a_nvcc_flags, sm90a_nvcc_flags +from .jit import gen_jit_spec, sm90a_nvcc_flags, sm100a_nvcc_flags from .jit.cpp_ext import is_cuda_version_at_least -from .utils import ( - device_support_pdl, - get_shuffle_matrix_a_row_indices, - get_shuffle_matrix_sf_a_row_indices, - register_custom_op, - register_fake_op, -) +from .utils import (device_support_pdl, get_shuffle_matrix_a_row_indices, + get_shuffle_matrix_sf_a_row_indices, register_custom_op, + register_fake_op) def _pad_scale_factors( - unswizzled_sf: torch.Tensor, m: int, n: int, sf_vec_size: int = 16 -) -> torch.Tensor: + unswizzled_sf: paddle.Tensor, m: int, n: int, sf_vec_size: int = 16 +) -> paddle.Tensor: """Pad scale factors tensor to meet alignment requirements. Args: @@ -49,17 +48,19 @@ def _pad_scale_factors( torch.Tensor: Padded scale factors tensor. """ factor = sf_vec_size * 4 - padded_row = ((m + 128 - 1) // 128) * 128 # Next multiple of 128 - padded_col = ((n + factor - 1) // factor) * factor # Next multiple of 64 - - # Pad the input tensor to [padded_row, padded_col // scaling_vector_size] + padded_row = (m + 128 - 1) // 128 * 128 + padded_col = (n + factor - 1) // factor * factor pad_rows = padded_row - m pad_cols = (padded_col - n) // sf_vec_size if pad_rows == 0 and pad_cols == 0: return unswizzled_sf else: - return torch.nn.functional.pad( - unswizzled_sf, (0, pad_cols, 0, pad_rows), mode="constant", value=0 + return paddle.nn.functional.pad( + x=unswizzled_sf, + pad=(0, pad_cols, 0, pad_rows), + mode="constant", + value=0, + pad_from_left_axis=False, ).contiguous() @@ -111,19 +112,16 @@ def get_fp4_quantization_module(backend: str = "100"): else: raise ValueError(f"Invalid backend: {backend}") - @register_custom_op( - "flashinfer::fp4_quantize_sm100", - mutates_args=(""), - ) + @register_custom_op("flashinfer::fp4_quantize_sm100", mutates_args="") def fp4_quantize_sm100( - input: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, + input: paddle.Tensor, + global_scale: Optional[paddle.Tensor] = None, sf_vec_size: int = 16, sf_use_ue8m0: bool = False, is_sf_swizzled_layout: bool = True, is_sf_8x4_layout: bool = False, enable_pdl: Optional[bool] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Quantize input tensor to FP4 format. 
Args: @@ -142,7 +140,7 @@ def fp4_quantize_sm100( - Scale factors tensor with shape determined by layout and sf_vec_size """ if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) return module.fp4_quantize( input, global_scale, @@ -155,50 +153,33 @@ def fp4_quantize_sm100( @register_fake_op("flashinfer::fp4_quantize_sm100") def _fake_fp4_quantize_sm100( - input: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, + input: paddle.Tensor, + global_scale: Optional[paddle.Tensor] = None, sf_vec_size: int = 16, sf_use_ue8m0: bool = False, is_sf_swizzled_layout: bool = True, - ) -> Tuple[torch.Tensor, torch.Tensor]: - m, k = input.shape - return ( - input.new_empty([m, k // 2], dtype=torch.int64), # FLOAT4_E2M1X2 - input.new_empty([m * k // sf_vec_size], dtype=torch.int32), # Scale factors + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + m, k = tuple(input.shape) + return paddle.empty(shape=[m, k // 2], dtype="int64"), paddle.empty( + shape=[m * k // sf_vec_size], dtype="int32" ) - @register_custom_op( - "flashinfer::mxfp4_dequantize_host", - mutates_args=(""), - ) + @register_custom_op("flashinfer::mxfp4_dequantize_host", mutates_args="") def mxfp4_dequantize_host( - weight: torch.Tensor, - scale: torch.Tensor, - group_size: int = 32, - ) -> torch.Tensor: - return module.mxfp4_dequantize_host( - weight, - scale, - group_size, - ) + weight: paddle.Tensor, scale: paddle.Tensor, group_size: int = 32 + ) -> paddle.Tensor: + return module.mxfp4_dequantize_host(weight, scale, group_size) @register_fake_op("flashinfer::mxfp4_dequantize_host") def _fake_mxfp4_dequantize_host( - weight: torch.Tensor, - scale: torch.Tensor, - group_size: int = 32, - ) -> torch.Tensor: - return weight.new_empty( - [weight.shape[0], weight.shape[1] * 2], dtype=torch.float32 + weight: paddle.Tensor, scale: paddle.Tensor, group_size: int = 32 + ) -> paddle.Tensor: + return paddle.empty( + shape=[tuple(weight.shape)[0], tuple(weight.shape)[1] * 2], dtype="float32" ) - @register_custom_op( - "flashinfer::block_scale_interleave_sm100", - mutates_args=("",), - ) - def block_scale_interleave_sm100( - unswizzled_sf: torch.Tensor, - ) -> torch.Tensor: + @register_custom_op("flashinfer::block_scale_interleave_sm100", mutates_args=("",)) + def block_scale_interleave_sm100(unswizzled_sf: paddle.Tensor) -> paddle.Tensor: """Swizzle block scale tensor for FP4 format. Args: @@ -207,30 +188,28 @@ def block_scale_interleave_sm100( Returns: torch.Tensor: output tensor for swizzled block scale with dtype uint8. 
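A hedged sketch of the allocation idiom used by the fake ops above: torch's input.new_empty([...], dtype=torch.X) becomes paddle.empty(shape=[...], dtype="x") with string dtype names. Note that paddle.empty allocates on the current default place rather than inheriting a device from an existing tensor; this sketch simply assumes that is acceptable for shape-only fake ops.

import paddle

m, k, sf_vec_size = 128, 256, 16
x_q = paddle.empty(shape=[m, k // 2], dtype="int64")            # packed FP4 payload placeholder
sf = paddle.empty(shape=[m * k // sf_vec_size], dtype="int32")  # scale-factor placeholder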
""" - return module.block_scale_interleave_sm100( - unswizzled_sf, - ) + return module.block_scale_interleave_sm100(unswizzled_sf) @register_fake_op("flashinfer::block_scale_interleave_sm100") def _fake_block_scale_interleave_sm100( - unswizzled_sf: torch.Tensor, - ) -> torch.Tensor: - return unswizzled_sf.new_empty( - [unswizzled_sf.shape[0] * unswizzled_sf.shape[1] // 16], dtype=torch.uint8 + unswizzled_sf: paddle.Tensor, + ) -> paddle.Tensor: + return paddle.empty( + shape=[tuple(unswizzled_sf.shape)[0] * tuple(unswizzled_sf.shape)[1] // 16], + dtype="uint8", ) @register_custom_op( - "flashinfer::e2m1_and_ufp8sf_scale_to_float_sm100", - mutates_args=(""), + "flashinfer::e2m1_and_ufp8sf_scale_to_float_sm100", mutates_args="" ) def e2m1_and_ufp8sf_scale_to_float_sm100( - e2m1_tensor: torch.Tensor, - ufp8_scale_tensor: torch.Tensor, - global_scale_tensor: Optional[torch.Tensor] = None, + e2m1_tensor: paddle.Tensor, + ufp8_scale_tensor: paddle.Tensor, + global_scale_tensor: Optional[paddle.Tensor] = None, sf_vec_size: int = 16, ufp8_type: int = 1, is_sf_swizzled_layout: bool = True, - ) -> torch.Tensor: + ) -> paddle.Tensor: """Convert E2M1 format tensor and UFP8 scale factors to float tensor. This function performs dequantization by converting a packed FP4 tensor in E2M1 format @@ -258,18 +237,18 @@ def e2m1_and_ufp8sf_scale_to_float_sm100( @register_fake_op("flashinfer::e2m1_and_ufp8sf_scale_to_float_sm100") def _fake_e2m1_and_ufp8sf_scale_to_float_sm100( - e2m1_tensor: torch.Tensor, - ufp8_scale_tensor: torch.Tensor, - global_scale_tensor: Optional[torch.Tensor] = None, + e2m1_tensor: paddle.Tensor, + ufp8_scale_tensor: paddle.Tensor, + global_scale_tensor: Optional[paddle.Tensor] = None, sf_vec_size: int = 16, ufp8_type: int = 1, is_sf_swizzled_layout: bool = True, - ) -> torch.Tensor: - return e2m1_tensor.new_empty( - [e2m1_tensor.shape[0], e2m1_tensor.shape[1] * 2], dtype=torch.float32 + ) -> paddle.Tensor: + return paddle.empty( + shape=[tuple(e2m1_tensor.shape)[0], tuple(e2m1_tensor.shape)[1] * 2], + dtype="float32", ) - # Register the module return SimpleNamespace( fp4_quantize_sm100=fp4_quantize_sm100, block_scale_interleave_sm100=block_scale_interleave_sm100, @@ -279,14 +258,14 @@ def _fake_e2m1_and_ufp8sf_scale_to_float_sm100( def fp4_quantize( - input: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, + input: paddle.Tensor, + global_scale: Optional[paddle.Tensor] = None, sf_vec_size: int = 16, sf_use_ue8m0: bool = False, is_sf_swizzled_layout: bool = True, is_sf_8x4_layout: bool = False, enable_pdl: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[paddle.Tensor, paddle.Tensor]: """Quantize input tensor to FP4 format. 
This function implements FP4 quantization that converts input tensors to a compressed FP4 format @@ -315,15 +294,12 @@ def fp4_quantize( """ if sf_vec_size != 16 and sf_vec_size != 32: raise NotImplementedError("sf_vec_size can only be 16 or 32") - - # for column major input, we need to transpose the input - is_column_major = input.stride(-2) == 1 + is_column_major = input.get_strides()[-2] == 1 if is_column_major: - input = input.transpose(-2, -1) - - assert input.shape[-1] % sf_vec_size == 0 + input = input.transpose(perm=dim2perm(input.ndim, -2, -1)) + assert tuple(input.shape)[-1] % sf_vec_size == 0 if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) x_q, sf = get_fp4_quantization_module("100").fp4_quantize_sm100( input, global_scale, @@ -333,15 +309,14 @@ def fp4_quantize( is_sf_8x4_layout, enable_pdl, ) - sf = sf.reshape((-1, input.shape[-1] // sf_vec_size)) + sf = sf.reshape((-1, tuple(input.shape)[-1] // sf_vec_size)) if is_column_major: - x_q = x_q.transpose(-2, -1) - sf = sf.transpose(-2, -1) - + x_q = x_q.transpose(perm=dim2perm(x_q.ndim, -2, -1)) + sf = sf.transpose(perm=dim2perm(sf.ndim, -2, -1)) return x_q, sf -def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: +def block_scale_interleave(unswizzled_sf: paddle.Tensor) -> paddle.Tensor: """Swizzle block scale tensor for FP4 format. This function swizzles the block scale tensor to optimize memory access patterns @@ -356,31 +331,27 @@ def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: Raises: AssertionError: If input dtype is not uint8. """ - # TODO(shuw): check input dtype is uint8 - assert unswizzled_sf.dtype == torch.uint8, ( - f"Input dtype must be uint8, got {unswizzled_sf.dtype}" - ) - - major, minor = torch.cuda.get_device_capability() + assert ( + unswizzled_sf.dtype == "uint8" + ), f"Input dtype must be uint8, got {unswizzled_sf.dtype}" + major, minor = paddle.device.cuda.get_device_capability() device_arch = f"{major * 10 + minor}" - return get_fp4_quantization_module(device_arch).block_scale_interleave_sm100( - unswizzled_sf, + unswizzled_sf ) -# Maintain compatibility with libraries using the old name nvfp4_block_scale_interleave = block_scale_interleave def e2m1_and_ufp8sf_scale_to_float( - e2m1_tensor: torch.Tensor, - ufp8_scale_tensor: torch.Tensor, - global_scale_tensor: Optional[torch.Tensor] = None, + e2m1_tensor: paddle.Tensor, + ufp8_scale_tensor: paddle.Tensor, + global_scale_tensor: Optional[paddle.Tensor] = None, sf_vec_size: int = 16, ufp8_type: int = 1, is_sf_swizzled_layout: bool = True, -) -> torch.Tensor: +) -> paddle.Tensor: """Convert E2M1 format tensor and UFP8 scale factors to float tensor. This function performs dequantization by converting a packed FP4 tensor in E2M1 format @@ -398,7 +369,7 @@ def e2m1_and_ufp8sf_scale_to_float( torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. 
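Two mappings used in the fp4_quantize body above, sketched here for reference. torch's t.stride(dim) (element strides) becomes t.get_strides()[dim], and the pairwise transpose(-2, -1) becomes paddle.transpose with a full permutation. dim2perm comes from flashinfer.paddle_utils, which this patch does not show, so the helper below is only an assumed reconstruction; the sketch also assumes a Paddle build that exposes Tensor.get_strides().

import paddle

def dim2perm(ndim: int, dim0: int, dim1: int):
    # Assumed behaviour of flashinfer.paddle_utils.dim2perm: build the full
    # permutation that swaps dim0 and dim1, since paddle.transpose needs a perm list.
    perm = list(range(ndim))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return perm

x = paddle.zeros(shape=[4, 8], dtype="float16")
assert list(x.get_strides()) == [8, 1]            # contiguous row-major strides, in elements
xt = x.transpose(perm=dim2perm(x.ndim, -2, -1))   # equivalent of torch's x.transpose(-2, -1)
assert tuple(xt.shape) == (8, 4)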
""" - major, minor = torch.cuda.get_device_capability() + major, minor = paddle.device.cuda.get_device_capability() device_arch = f"{major * 10 + minor}" return get_fp4_quantization_module( device_arch @@ -412,19 +383,18 @@ def e2m1_and_ufp8sf_scale_to_float( ) -def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: +def shuffle_matrix_a( + input_tensor: paddle.Tensor, epilogue_tile_m: int +) -> paddle.Tensor: """ PyTorch equivalent of trtllm-gen `shuffleMatrixA` """ row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) - - return input_tensor[row_indices.to(input_tensor.device)] + return input_tensor[row_indices.to(input_tensor.place)] def shuffle_matrix_sf_a( - input_tensor: torch.Tensor, - epilogue_tile_m: int, - num_elts_per_sf: int = 16, + input_tensor: paddle.Tensor, epilogue_tile_m: int, num_elts_per_sf: int = 16 ): """ Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat. @@ -436,12 +406,8 @@ def shuffle_matrix_sf_a( and are in linear layout. This function doesn't add padding. """ - row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m) - - w_shuffled = input_tensor[row_indices.to(input_tensor.device)] - - # 128x4 + w_shuffled = input_tensor[row_indices.to(input_tensor.place)] return block_scale_interleave(w_shuffled) @@ -480,9 +446,7 @@ def nvfp4_quantize( - Quantized tensor of shape [M, K/2] with dtype FLOAT4_E2M1X2 - Scale factors tensor with shape determined by layout and sf_vec_size """ - if do_shuffle: - # Weights 128x4 + shuffle. It is done during the model load and we do not care much about the perf assert sfLayout == SfLayout.layout_128x4 a_fp4, a_sf = fp4_quantize( a.cuda(), @@ -493,15 +457,12 @@ def nvfp4_quantize( is_sf_8x4_layout=False, enable_pdl=enable_pdl, ) - epilogue_tile_m = 128 - a_fp4 = shuffle_matrix_a(a_fp4.view(torch.uint8), epilogue_tile_m) - a_sf = shuffle_matrix_sf_a(a_sf.view(torch.uint8), epilogue_tile_m).reshape( - a_sf.shape + a_fp4 = shuffle_matrix_a(a_fp4.view("uint8"), epilogue_tile_m) + a_sf = shuffle_matrix_sf_a(a_sf.view("uint8"), epilogue_tile_m).reshape( + tuple(a_sf.shape) ) else: - # Activations with 8x4 layout for SFs (GEMM with small tileN) - # Activations with 128x4 layout for SFs (GEMM with large tileN) a_fp4, a_sf = fp4_quantize( a.cuda(), a_global_sf.cuda(), @@ -511,7 +472,6 @@ def nvfp4_quantize( is_sf_8x4_layout=sfLayout == SfLayout.layout_8x4, enable_pdl=enable_pdl, ) - return a_fp4, a_sf @@ -527,7 +487,7 @@ def mxfp4_quantize(a): - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - Scale factors tensor with shape determined by layout and sf_vec_size (uint8) """ - a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max() + a_global_sf = 448 * 6 / a.astype(dtype="float32").abs().nan_to_num()._max() a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True) return a_fp4, a_sf @@ -544,9 +504,9 @@ def mxfp4_dequantize(a_fp4, a_sf): torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. 
""" return e2m1_and_ufp8sf_scale_to_float( - a_fp4.cpu().view(torch.uint8), - a_sf.cpu().view(torch.uint8).reshape(-1), - torch.tensor([1.0], device=a_fp4.device), + a_fp4.cpu().view("uint8"), + a_sf.cpu().view("uint8").reshape(-1), + paddle.to_tensor(data=[1.0], place=a_fp4.place), 32, 0, True, @@ -554,10 +514,8 @@ def mxfp4_dequantize(a_fp4, a_sf): def mxfp4_dequantize_host( - weight: torch.Tensor, - scale: torch.Tensor, - group_size: int = 32, -) -> torch.Tensor: + weight: paddle.Tensor, scale: paddle.Tensor, group_size: int = 32 +) -> paddle.Tensor: """ Dequantize input tensor from MXFP4 format on host. @@ -569,10 +527,8 @@ def mxfp4_dequantize_host( Returns: torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. """ - major, minor = torch.cuda.get_device_capability() + major, minor = paddle.device.cuda.get_device_capability() device_arch = f"{major * 10 + minor}" return get_fp4_quantization_module(device_arch).mxfp4_dequantize_host( - weight, - scale, - group_size, + weight, scale, group_size ) diff --git a/flashinfer/fp8_quantization.py b/flashinfer/fp8_quantization.py index 86fd3062d1..60cf32d6b4 100644 --- a/flashinfer/fp8_quantization.py +++ b/flashinfer/fp8_quantization.py @@ -2,16 +2,12 @@ from types import SimpleNamespace from typing import Optional, Tuple -import torch +import paddle from .jit import JitSpec from .jit import env as jit_env from .jit import gen_jit_spec, sm100a_nvcc_flags -from .utils import ( - device_support_pdl, - register_custom_op, - register_fake_op, -) +from .utils import device_support_pdl, register_custom_op, register_fake_op def gen_mxfp8_quantization_sm100_module() -> JitSpec: @@ -27,16 +23,8 @@ def gen_mxfp8_quantization_sm100_module() -> JitSpec: jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp", ], extra_cuda_cflags=sm100a_nvcc_flags - + [ - "-DENABLE_BF16", - "-DENABLE_FP8", - "-DENABLE_FP4", - ], - extra_cflags=[ - "-DENABLE_BF16", - "-DENABLE_FP8", - "-DENABLE_FP4", - ], + + ["-DENABLE_BF16", "-DENABLE_FP8", "-DENABLE_FP4"], + extra_cflags=["-DENABLE_BF16", "-DENABLE_FP8", "-DENABLE_FP4"], extra_include_paths=[ jit_env.FLASHINFER_CSRC_DIR / "nv_internal", jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include", @@ -48,16 +36,13 @@ def gen_mxfp8_quantization_sm100_module() -> JitSpec: def get_mxfp8_quantization_sm100_module(): module = gen_mxfp8_quantization_sm100_module().build_and_load() - @register_custom_op( - "flashinfer::mxfp8_quantize_sm100", - mutates_args=(""), - ) + @register_custom_op("flashinfer::mxfp8_quantize_sm100", mutates_args="") def mxfp8_quantize_sm100( - input: torch.Tensor, + input: paddle.Tensor, is_sf_swizzled_layout: bool = True, alignment: int = 32, enable_pdl: Optional[bool] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Quantize input tensor to MxFP8 format. 
Args: @@ -72,41 +57,29 @@ def mxfp8_quantize_sm100( - Scale factors tensor with shape determined by layout and sf_vec_size """ if input.device.type == "cpu": - return module.mxfp8_quantize_host( - input, - is_sf_swizzled_layout, - ) + return module.mxfp8_quantize_host(input, is_sf_swizzled_layout) else: if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) return module.mxfp8_quantize( - input, - is_sf_swizzled_layout, - alignment, - enable_pdl, + input, is_sf_swizzled_layout, alignment, enable_pdl ) @register_fake_op("flashinfer::mxfp8_quantize_sm100") def _fake_mxfp8_quantize_sm100( - input: torch.Tensor, - is_sf_swizzled_layout: bool = True, - alignment: int = 32, - ) -> Tuple[torch.Tensor, torch.Tensor]: - m, k = input.shape - return ( - input.new_empty([m, k], dtype=torch.int64), # FLOAT8_E4M3 - input.new_empty([m * k // 32], dtype=torch.int32), # Scale factors + input: paddle.Tensor, is_sf_swizzled_layout: bool = True, alignment: int = 32 + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + m, k = tuple(input.shape) + return paddle.empty(shape=[m, k], dtype="int64"), paddle.empty( + shape=[m * k // 32], dtype="int32" ) - @register_custom_op( - "flashinfer::mxfp8_dequantize_host_sm100", - mutates_args=("",), - ) + @register_custom_op("flashinfer::mxfp8_dequantize_host_sm100", mutates_args=("",)) def mxfp8_dequantize_host_sm100( - input: torch.Tensor, - scale_tensor: torch.Tensor, + input: paddle.Tensor, + scale_tensor: paddle.Tensor, is_sf_swizzled_layout: bool = True, - ) -> torch.Tensor: + ) -> paddle.Tensor: """Dequantize input tensor from MxFP8 format. Args: @@ -117,21 +90,18 @@ def mxfp8_dequantize_host_sm100( Returns: torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. """ - return module.mxfp8_dequantize_host( - input, - scale_tensor, - is_sf_swizzled_layout, - ) + return module.mxfp8_dequantize_host(input, scale_tensor, is_sf_swizzled_layout) @register_fake_op("flashinfer::mxfp8_dequantize_host_sm100") def _fake_mxfp8_dequantize_host_sm100( - input: torch.Tensor, - scale_tensor: torch.Tensor, + input: paddle.Tensor, + scale_tensor: paddle.Tensor, is_sf_swizzled_layout: bool = True, - ) -> torch.Tensor: - return input.new_empty([input.shape[0], input.shape[1]], dtype=torch.float32) + ) -> paddle.Tensor: + return paddle.empty( + shape=[tuple(input.shape)[0], tuple(input.shape)[1]], dtype="float32" + ) - # Register the module return SimpleNamespace( mxfp8_quantize_sm100=mxfp8_quantize_sm100, mxfp8_dequantize_host_sm100=mxfp8_dequantize_host_sm100, @@ -139,11 +109,11 @@ def _fake_mxfp8_dequantize_host_sm100( def mxfp8_quantize( - input: torch.Tensor, + input: paddle.Tensor, is_sf_swizzled_layout: bool = True, alignment: int = 32, enable_pdl: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[paddle.Tensor, paddle.Tensor]: """Quantize input tensor to MxFP8 format. 
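A small worked example of the shapes implied by the registered fake op above; sf_vec_size is fixed at 32 for MxFP8, so an [M, K] input maps to an [M, K] payload placeholder and M*K/32 scale factors. This mirrors only the fake op's placeholders, not the real kernel's output layout.

import paddle

m, k = 128, 4096
x_q_placeholder = paddle.empty(shape=[m, k], dtype="int64")        # placeholder dtype used by the fake op
sf_placeholder = paddle.empty(shape=[m * k // 32], dtype="int32")  # one scale factor per 32 elements
assert sf_placeholder.size == 128 * 4096 // 32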
This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format @@ -161,24 +131,20 @@ def mxfp8_quantize( - Scale factors tensor with shape determined by layout and sf_vec_size """ sf_vec_size = 32 - - assert input.shape[-1] % sf_vec_size == 0 + assert tuple(input.shape)[-1] % sf_vec_size == 0 if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) x_q, sf = get_mxfp8_quantization_sm100_module().mxfp8_quantize_sm100( - input, - is_sf_swizzled_layout, - alignment, - enable_pdl, + input, is_sf_swizzled_layout, alignment, enable_pdl ) return x_q, sf def mxfp8_dequantize_host( - input: torch.Tensor, - scale_tensor: torch.Tensor, + input: paddle.Tensor, + scale_tensor: paddle.Tensor, is_sf_swizzled_layout: bool = True, -) -> torch.Tensor: +) -> paddle.Tensor: """Dequantize input tensor from MxFP8 format. This function performs dequantization by converting a packed FP8 tensor in MxFP8 format @@ -193,9 +159,6 @@ def mxfp8_dequantize_host( torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. """ - return get_mxfp8_quantization_sm100_module().mxfp8_dequantize_host_sm100( - input, - scale_tensor, - is_sf_swizzled_layout, + input, scale_tensor, is_sf_swizzled_layout ) diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index 15a4760739..3a0cb19f91 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -13,21 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. """ - -from .core import ( - RoutingMethodType, - GatedActType, - WeightLayout, - convert_to_block_layout, - cutlass_fused_moe, - gen_cutlass_fused_moe_sm100_module, - gen_cutlass_fused_moe_sm90_module, - reorder_rows_for_gated_act_gemm, - trtllm_fp4_block_scale_moe, - trtllm_fp4_block_scale_routed_moe, - trtllm_fp8_block_scale_moe, - trtllm_fp8_per_tensor_scale_moe, -) +from .core import (GatedActType, RoutingMethodType, WeightLayout, + convert_to_block_layout, cutlass_fused_moe, + gen_cutlass_fused_moe_sm90_module, + gen_cutlass_fused_moe_sm100_module, + reorder_rows_for_gated_act_gemm, trtllm_fp4_block_scale_moe, + trtllm_fp4_block_scale_routed_moe, + trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe) __all__ = [ "RoutingMethodType", diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 9a7ba595cf..7310803ba1 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,97 +19,73 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import functools from enum import IntEnum from types import SimpleNamespace from typing import Any, Dict, List, Optional, Tuple, Union -import torch - from ..artifacts import ArtifactPath, MetaInfoHash -from ..autotuner import ( - AutoTuner, - DynamicTensorSpec, - OptimizationProfile, - TunableRunner, - TuningConfig, -) +from ..autotuner import (AutoTuner, DynamicTensorSpec, OptimizationProfile, + TunableRunner, TuningConfig) from ..jit import JitSpec from ..jit import env as jit_env -from ..jit import gen_jit_spec, setup_cubin_loader, sm100a_nvcc_flags, sm90a_nvcc_flags +from ..jit import (gen_jit_spec, setup_cubin_loader, sm90a_nvcc_flags, + sm100a_nvcc_flags) from ..jit.cpp_ext import is_cuda_version_at_least from ..jit.cubin_loader import get_cubin from ..jit.cutlass_gemm.generate_kernels import generate_gemm_operations -from ..utils import ( - check_shape_dtype_device, - device_support_pdl, - get_shuffle_matrix_a_row_indices, - get_shuffle_matrix_sf_a_row_indices, - register_custom_op, - register_fake_op, -) -from .utils import ( - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - next_positive_power_of_2, -) - - -# The type of method in top-K routing, for use in torch custom op -# Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h +from ..utils import (check_shape_dtype_device, device_support_pdl, + get_shuffle_matrix_a_row_indices, + get_shuffle_matrix_sf_a_row_indices, register_custom_op, + register_fake_op) +from .utils import (get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, next_positive_power_of_2) + + class RoutingMethodType(IntEnum): - # Default: Softmax -> TopK Default = (0,) - # Renormalize: TopK -> Softmax Renormalize = (1,) - # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts from the Top4 groups DeepSeekV3 = (2,) - # Llama4: Top1 -> Sigmoid Llama4 = (3,) - # Qwen3: Softmax -> TopK -> Renormalize RenormalizeNaive = (4,) - # TopK only (no softmax) TopK = (5,) - # Unspecified Unspecified = 6 class DtypeTrtllmGen(IntEnum): def __new__(cls, block_format_bit, signed_bit, integer_bit, num_bits, uid): value = ( - (block_format_bit << 24) - | (signed_bit << 20) - | (integer_bit << 16) - | (num_bits << 8) + block_format_bit << 24 + | signed_bit << 20 + | integer_bit << 16 + | num_bits << 8 | uid ) obj = int.__new__(cls, value) obj._value_ = value return obj - # keep the values in sync with include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h - Bfloat16 = (0, 1, 0, 16, 0) - Bool = (0, 0, 1, 1, 1) - E2m1 = (1, 1, 0, 4, 2) - E2m3 = (1, 1, 0, 6, 3) - E3m2 = (1, 1, 0, 6, 4) - E4m3 = (0, 1, 0, 8, 5) - E5m2 = (0, 1, 0, 8, 6) - Fp16 = (0, 1, 0, 16, 7) - Fp32 = (0, 1, 0, 32, 8) - Int8 = (0, 1, 1, 8, 9) - Int32 = (0, 1, 1, 32, 10) - Int64 = (0, 1, 1, 64, 11) - MxE2m1 = (1, 1, 0, 4, 12) - MxE4m3 = (1, 1, 0, 8, 13) - UE8m0 = (0, 0, 0, 8, 14) - UInt8 = (0, 0, 1, 8, 15) - UInt16 = (0, 0, 1, 16, 16) - UInt32 = (0, 0, 1, 32, 17) - UInt64 = (0, 0, 1, 64, 18) - UInt128 = (0, 0, 1, 128, 19) - Void = (0, 1, 0, 0, 20) + Bfloat16 = 0, 1, 0, 16, 0 + Bool = 0, 0, 1, 1, 1 + E2m1 = 1, 1, 0, 4, 2 + E2m3 = 1, 1, 0, 6, 3 + E3m2 = 1, 1, 0, 6, 4 + E4m3 = 0, 1, 0, 8, 5 + E5m2 = 0, 1, 0, 8, 6 + Fp16 = 0, 1, 0, 16, 7 + Fp32 = 0, 1, 0, 32, 8 + Int8 = 0, 1, 1, 8, 9 + Int32 = 0, 1, 1, 32, 10 + Int64 = 0, 1, 1, 64, 11 + MxE2m1 = 1, 1, 0, 4, 12 + MxE4m3 = 1, 1, 0, 8, 13 + UE8m0 = 0, 0, 0, 8, 14 + UInt8 = 0, 0, 1, 8, 15 + UInt16 = 0, 0, 1, 16, 16 + UInt32 = 0, 
0, 1, 32, 17 + UInt64 = 0, 0, 1, 64, 18 + UInt128 = 0, 0, 1, 128, 19 + Void = 0, 1, 0, 0, 20 def trtllm_gen_dtype_has_scale(dtype: DtypeTrtllmGen) -> bool: @@ -119,20 +101,18 @@ def trtllm_gen_dtype_has_scale(dtype: DtypeTrtllmGen) -> bool: def deduce_trtllm_gen_tensor_dtype( - x: torch.Tensor, scale: Optional[torch.Tensor] + x: paddle.Tensor, scale: Optional[paddle.Tensor] ) -> DtypeTrtllmGen: - hidden_size = x.shape[-1] - if x.dtype == torch.uint8: # FIXME(siyuan): use torch.float4_e2m1x2 after torch 2.8 + hidden_size = tuple(x.shape)[-1] + if x.dtype == "uint8": hidden_size *= 2 - if x.dtype == torch.bfloat16: + if x.dtype == "bfloat16": dtype = DtypeTrtllmGen.Bfloat16 - elif x.dtype == torch.float8_e4m3fn: + elif x.dtype == paddle.float8_e4m3fn: dtype = DtypeTrtllmGen.E4m3 if scale is None else DtypeTrtllmGen.MxE4m3 - elif ( - x.dtype == torch.uint8 - ): # FIXME(siyuan): use torch.float4_e2m1x2 after torch 2.8 + elif x.dtype == "uint8": assert scale is not None, "Scale tensor must be provided for float4x2 input" - if scale.shape[-1] == hidden_size // 16: + if tuple(scale.shape)[-1] == hidden_size // 16: dtype = DtypeTrtllmGen.E2m1 else: dtype = DtypeTrtllmGen.MxE2m1 @@ -141,34 +121,24 @@ def deduce_trtllm_gen_tensor_dtype( return dtype -# See MatrixLayout from include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h class WeightLayout(IntEnum): - # K-major layout (default). [Mn, K] MajorK = 0 - # M-major for A and N-major for B. [K, Mn] MajorMn = 1 - # Layout is blocked along the K dimension. [K / blockK, Mn, blockK] - # where blockK is fixed at 128B BlockMajorK = 2 -# The type of gated activation function -# Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h class GatedActType(IntEnum): - # SwiGlu SwiGlu = 0 - # GeGlu GeGlu = 1 def _maybe_get_cached_w3_w1_permute_indices( _cache_permute_indices, - dst_w3_w1_weight: torch.Tensor, + dst_w3_w1_weight: paddle.Tensor, epilogue_tile_m: int, num_elts_per_sf: Union[None, int] = None, -) -> torch.Tensor: - if dst_w3_w1_weight.shape not in _cache_permute_indices: - # Get permute indices and chain them together +) -> paddle.Tensor: + if tuple(dst_w3_w1_weight.shape) not in _cache_permute_indices: permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight) if num_elts_per_sf is None: permute1 = get_shuffle_matrix_a_row_indices( @@ -180,38 +150,36 @@ def _maybe_get_cached_w3_w1_permute_indices( epilogue_tile_m=epilogue_tile_m, num_elts_per_sf=num_elts_per_sf, ) - # Memoize permute indices as recompute is **very** costly - _cache_permute_indices[dst_w3_w1_weight.shape] = permute0[permute1].to( - dst_w3_w1_weight.device + _cache_permute_indices[tuple(dst_w3_w1_weight.shape)] = permute0[permute1].to( + dst_w3_w1_weight.place ) - permute_indices = _cache_permute_indices[dst_w3_w1_weight.shape] + permute_indices = _cache_permute_indices[tuple(dst_w3_w1_weight.shape)] return permute_indices def _maybe_get_cached_w2_permute_indices( _cache_permute_indices, - dst_w2_weight: torch.Tensor, + dst_w2_weight: paddle.Tensor, epilogue_tile_m: int, num_elts_per_sf: Union[None, int] = None, -) -> torch.Tensor: - if dst_w2_weight.shape not in _cache_permute_indices: +) -> paddle.Tensor: + if tuple(dst_w2_weight.shape) not in _cache_permute_indices: if num_elts_per_sf is None: permute_indices = get_shuffle_matrix_a_row_indices( dst_w2_weight, epilogue_tile_m - ).to(dst_w2_weight.device) + ).to(dst_w2_weight.place) else: permute_indices = get_shuffle_matrix_sf_a_row_indices( dst_w2_weight, 
epilogue_tile_m=epilogue_tile_m, num_elts_per_sf=num_elts_per_sf, - ).to(dst_w2_weight.device) - # Memoize permute indices as recompute is **very** costly - _cache_permute_indices[dst_w2_weight.shape] = permute_indices - permute_indices = _cache_permute_indices[dst_w2_weight.shape] + ).to(dst_w2_weight.place) + _cache_permute_indices[tuple(dst_w2_weight.shape)] = permute_indices + permute_indices = _cache_permute_indices[tuple(dst_w2_weight.shape)] return permute_indices -def get_reorder_rows_for_gated_act_gemm_row_indices(x) -> torch.Tensor: +def get_reorder_rows_for_gated_act_gemm_row_indices(x) -> paddle.Tensor: """ Reorders rows in the gemm/MOE_gemm weight matrix for min-latency [r0, r1, r2, r3, ..., rN/2, r(N/2+1), .. r(N-1)] @@ -219,23 +187,14 @@ def get_reorder_rows_for_gated_act_gemm_row_indices(x) -> torch.Tensor: [r0, rN/2, r1, rN/2+1, ..., r(N/2-1), r(N-1)] """ assert x.dim() == 2, f"x should be a 2D tensor, not {x.dim()}" - M, K = x.shape + M, K = tuple(x.shape) assert M % 2 == 0, f"x.shape[0] must be even, not {M}" - - row_indices = torch.arange(M, dtype=torch.long) - - # We split into top half and bottom half, but if M is odd, - # the bottom half is one row larger. - top = row_indices[: (M + 1) // 2] # round up - bot = row_indices[(M + 1) // 2 :] # remainder - - # Create the output - permuted_row_indices = torch.empty_like(row_indices) - - # We'll place rows of `top` and `bot` in alternation + row_indices = paddle.arange(dtype="int64", end=M) + top = row_indices[: (M + 1) // 2] + bot = row_indices[(M + 1) // 2 :] + permuted_row_indices = paddle.empty_like(x=row_indices) permuted_row_indices[0::2] = top permuted_row_indices[1::2] = bot - return permuted_row_indices @@ -244,16 +203,16 @@ def reorder_rows_for_gated_act_gemm(x): PyTorch implementation of trt-llm gen `reorderRowsForGatedActGemm` """ row_indices = get_reorder_rows_for_gated_act_gemm_row_indices(x) - permute = lambda x: x[row_indices] - return permute(x) -def convert_to_block_layout(input_tensor: torch.Tensor, blockK: int) -> torch.Tensor: - M, K = input_tensor.shape +def convert_to_block_layout(input_tensor: paddle.Tensor, blockK: int) -> paddle.Tensor: + M, K = tuple(input_tensor.shape) assert K % blockK == 0, "K must be divisible by blockK" - return input_tensor.view(M, K // blockK, blockK).permute(1, 0, 2).contiguous() + return ( + input_tensor.view(M, K // blockK, blockK).transpose(perm=[1, 0, 2]).contiguous() + ) def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec: @@ -290,19 +249,11 @@ def gen_cutlass_fused_moe_module( jit_env.FLASHINFER_CSRC_DIR / f"nv_internal/tensorrt_llm/cutlass_instantiations/{device_arch}" ) - try: - # Create output directory if it doesn't exist output_dir.mkdir(parents=True, exist_ok=True) - - generate_gemm_operations( - output_dir, - f"{device_arch};{device_arch}-real", - ) - + generate_gemm_operations(output_dir, f"{device_arch};{device_arch}-real") except Exception as e: raise RuntimeError(f"Failed to generate Cutlass kernels: {e}") from e - return gen_jit_spec( f"fused_moe_{device_arch}", [ @@ -342,7 +293,6 @@ def gen_cutlass_fused_moe_module( / "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_ops.cu", jit_env.FLASHINFER_CSRC_DIR / "fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu", - # Add all generated kernels *(output_dir / kernel for kernel in output_dir.rglob("*.generated.cu")), jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp", jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp", @@ -396,9 
+346,8 @@ def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = Fa raise ValueError(f"Invalid backend: {backend}") class MoERunner(TunableRunner): - # avoid overhead of creating a new runner in forward pass runner_dict: Dict[ - Tuple[torch.dtype, torch.dtype, torch.dtype, bool, bool, bool], Any + Tuple[paddle.dtype, paddle.dtype, paddle.dtype, bool, bool, bool], Any ] = dict() tuning_config = TuningConfig( dynamic_tensor_specs=( @@ -413,9 +362,9 @@ class MoERunner(TunableRunner): def __init__( self, - x_dtype: torch.dtype, - weight_dtype: torch.dtype, - output_dtype: torch.dtype, + x_dtype: paddle.dtype, + weight_dtype: paddle.dtype, + output_dtype: paddle.dtype, top_k: int, tp_size: int, tp_rank: int, @@ -454,7 +403,6 @@ def __init__( use_w4_group_scaling, use_mxfp8_act_scaling, ) - if instance_key not in MoERunner.runner_dict: MoERunner.runner_dict[instance_key] = FusedMoeRunner( x_dtype, @@ -464,19 +412,16 @@ def __init__( use_w4_group_scaling, use_mxfp8_act_scaling, ) - self.fused_moe_runner = MoERunner.runner_dict[instance_key] def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, + self, inputs: List[paddle.Tensor], profile: OptimizationProfile ) -> List[int]: return list(range(self.fused_moe_runner.get_tactic_num())) def forward( self, - inputs: List[torch.Tensor], + inputs: List[paddle.Tensor], tactic: int = -1, do_preparation: bool = False, **kwargs, @@ -523,25 +468,22 @@ def refine_tuning_config(cls, tune_max_num_tokens: int): ) ) - @register_custom_op( - "flashinfer::cutlass_fused_moe", - mutates_args=(""), - ) + @register_custom_op("flashinfer::cutlass_fused_moe", mutates_args="") def cutlass_fused_moe( - output: torch.Tensor, - input: torch.Tensor, - token_selected_experts: torch.Tensor, - token_final_scales: torch.Tensor, - fc1_expert_weights: torch.Tensor, - fc1_expert_biases: Optional[torch.Tensor], - fc2_expert_weights: torch.Tensor, - fc2_expert_biases: Optional[torch.Tensor], - output_dtype: torch.dtype, - quant_scales: List[torch.Tensor], - input_sf: Optional[torch.Tensor] = None, - swiglu_alpha: Optional[torch.Tensor] = None, - swiglu_beta: Optional[torch.Tensor] = None, - swiglu_limit: Optional[torch.Tensor] = None, + output: paddle.Tensor, + input: paddle.Tensor, + token_selected_experts: paddle.Tensor, + token_final_scales: paddle.Tensor, + fc1_expert_weights: paddle.Tensor, + fc1_expert_biases: Optional[paddle.Tensor], + fc2_expert_weights: paddle.Tensor, + fc2_expert_biases: Optional[paddle.Tensor], + output_dtype: paddle.dtype, + quant_scales: List[paddle.Tensor], + input_sf: Optional[paddle.Tensor] = None, + swiglu_alpha: Optional[paddle.Tensor] = None, + swiglu_beta: Optional[paddle.Tensor] = None, + swiglu_limit: Optional[paddle.Tensor] = None, tp_size: int = 1, tp_rank: int = 0, ep_size: int = 1, @@ -555,18 +497,16 @@ def cutlass_fused_moe( min_latency_mode: bool = False, tune_max_num_tokens: int = 8192, enable_pdl: Optional[bool] = None, - ) -> List[torch.Tensor]: + ) -> List[paddle.Tensor]: if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) tuner = AutoTuner.get() MoERunner.refine_tuning_config(tune_max_num_tokens) - - # allocate workspace for profiling moe_runner = MoERunner( x_dtype=input.dtype, weight_dtype=fc1_expert_weights.dtype, output_dtype=output_dtype, - top_k=token_selected_experts.size(1), + top_k=token_selected_experts.shape[1], tp_size=tp_size, tp_rank=tp_rank, ep_size=ep_size, @@ -580,7 +520,6 @@ def 
cutlass_fused_moe( min_latency_mode=min_latency_mode, enable_pdl=enable_pdl, ) - _, gemm_tactic_1 = tuner.choose_one( "trtllm::fused_moe::gemm1", [moe_runner], @@ -594,7 +533,6 @@ def cutlass_fused_moe( ], gemm_idx=1, ) - _, gemm_tactic_2 = tuner.choose_one( "trtllm::fused_moe::gemm2", [moe_runner], @@ -608,7 +546,6 @@ def cutlass_fused_moe( ], gemm_idx=2, ) - run_moe = ( moe_runner.fused_moe_runner.run_moe_min_latency if min_latency_mode @@ -639,25 +576,24 @@ def cutlass_fused_moe( [gemm_tactic_1, gemm_tactic_2], enable_pdl, ) - return result if min_latency_mode else [result] @register_fake_op("flashinfer::cutlass_fused_moe") def _fake_cutlass_fused_moe( - output: torch.Tensor, - input: torch.Tensor, - token_selected_experts: torch.Tensor, - token_final_scales: torch.Tensor, - fc1_expert_weights: torch.Tensor, - fc1_expert_biases: Optional[torch.Tensor], - fc2_expert_weights: torch.Tensor, - fc2_expert_biases: Optional[torch.Tensor], - output_dtype: torch.dtype, - quant_scales: List[torch.Tensor], - input_sf: Optional[torch.Tensor] = None, - swiglu_alpha: Optional[torch.Tensor] = None, - swiglu_beta: Optional[torch.Tensor] = None, - swiglu_limit: Optional[torch.Tensor] = None, + output: paddle.Tensor, + input: paddle.Tensor, + token_selected_experts: paddle.Tensor, + token_final_scales: paddle.Tensor, + fc1_expert_weights: paddle.Tensor, + fc1_expert_biases: Optional[paddle.Tensor], + fc2_expert_weights: paddle.Tensor, + fc2_expert_biases: Optional[paddle.Tensor], + output_dtype: paddle.dtype, + quant_scales: List[paddle.Tensor], + input_sf: Optional[paddle.Tensor] = None, + swiglu_alpha: Optional[paddle.Tensor] = None, + swiglu_beta: Optional[paddle.Tensor] = None, + swiglu_limit: Optional[paddle.Tensor] = None, tp_size: int = 1, tp_rank: int = 0, ep_size: int = 1, @@ -672,51 +608,46 @@ def _fake_cutlass_fused_moe( tune_max_num_tokens: int = 8192, enable_pdl: Optional[bool] = None, ): - seq_len = input.shape[0] - hidden_size = fc2_expert_weights.shape[1] - + seq_len = tuple(input.shape)[0] + hidden_size = tuple(fc2_expert_weights.shape)[1] if min_latency_mode: - num_experts_on_rank = fc2_expert_weights.shape[0] + num_experts_on_rank = tuple(fc2_expert_weights.shape)[0] output_shape = [seq_len * num_experts_on_rank, hidden_size] experts_to_token_score_shape = [num_experts_on_rank, seq_len] active_expert_global_ids_shape = [num_experts_on_rank] return [ - input.new_empty(output_shape, dtype=output_dtype), - input.new_empty([1], dtype=torch.int32), - input.new_empty(experts_to_token_score_shape, dtype=torch.float32), - input.new_empty(active_expert_global_ids_shape, dtype=torch.int32), + paddle.empty(shape=output_shape, dtype=output_dtype), + paddle.empty(shape=[1], dtype="int32"), + paddle.empty(shape=experts_to_token_score_shape, dtype="float32"), + paddle.empty(shape=active_expert_global_ids_shape, dtype="int32"), ] else: - return [input.new_empty([seq_len, hidden_size], dtype=output_dtype)] + return [paddle.empty(shape=[seq_len, hidden_size], dtype=output_dtype)] - # Register the module - return SimpleNamespace( - cutlass_fused_moe=cutlass_fused_moe, - ) + return SimpleNamespace(cutlass_fused_moe=cutlass_fused_moe) -# ref: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py#L121 def cutlass_fused_moe( - input: torch.Tensor, - token_selected_experts: torch.Tensor, - token_final_scales: torch.Tensor, - fc1_expert_weights: torch.Tensor, - fc2_expert_weights: torch.Tensor, - output_dtype: torch.dtype, - quant_scales: List[torch.Tensor], - 
fc1_expert_biases: Optional[torch.Tensor] = None, - fc2_expert_biases: Optional[torch.Tensor] = None, - input_sf: Optional[torch.Tensor] = None, - swiglu_alpha: Optional[torch.Tensor] = None, - swiglu_beta: Optional[torch.Tensor] = None, - swiglu_limit: Optional[torch.Tensor] = None, + input: paddle.Tensor, + token_selected_experts: paddle.Tensor, + token_final_scales: paddle.Tensor, + fc1_expert_weights: paddle.Tensor, + fc2_expert_weights: paddle.Tensor, + output_dtype: paddle.dtype, + quant_scales: List[paddle.Tensor], + fc1_expert_biases: Optional[paddle.Tensor] = None, + fc2_expert_biases: Optional[paddle.Tensor] = None, + input_sf: Optional[paddle.Tensor] = None, + swiglu_alpha: Optional[paddle.Tensor] = None, + swiglu_beta: Optional[paddle.Tensor] = None, + swiglu_limit: Optional[paddle.Tensor] = None, tp_size: int = 1, tp_rank: int = 0, ep_size: int = 1, ep_rank: int = 0, cluster_size: int = 1, cluster_rank: int = 0, - output: Optional[torch.Tensor] = None, + output: Optional[paddle.Tensor] = None, enable_alltoall: bool = False, use_deepseek_fp8_block_scale: bool = False, use_w4_group_scaling: bool = False, @@ -724,7 +655,7 @@ def cutlass_fused_moe( min_latency_mode: bool = False, tune_max_num_tokens: int = 8192, enable_pdl: Optional[bool] = None, -) -> torch.Tensor: +) -> paddle.Tensor: """Compute a Mixture of Experts (MoE) layer using CUTLASS backend. This function implements a fused MoE layer that combines expert selection, expert computation, @@ -855,26 +786,21 @@ def cutlass_fused_moe( ) if min_latency_mode: raise NotImplementedError("min latency mode not yet implemented for Blackwell.") - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) - - num_rows = input.shape[0] + enable_pdl = device_support_pdl(input.place) + num_rows = tuple(input.shape)[0] if min_latency_mode: - num_rows *= fc2_expert_weights.shape[0] - hidden_size = fc2_expert_weights.shape[1] - output_shape = (num_rows, hidden_size) - + num_rows *= tuple(fc2_expert_weights.shape)[0] + hidden_size = tuple(fc2_expert_weights.shape)[1] + output_shape = num_rows, hidden_size if output is None: - output = torch.empty(output_shape, dtype=output_dtype, device=input.device) + output = paddle.empty(shape=output_shape, dtype=output_dtype) else: check_shape_dtype_device( - output, output_shape, output_dtype, input.device, "output" + output, output_shape, output_dtype, input.place, "output" ) - - major, minor = torch.cuda.get_device_capability() + major, minor = paddle.device.cuda.get_device_capability() device_arch = f"{major * 10 + minor}" - return get_cutlass_fused_moe_module(device_arch).cutlass_fused_moe( output, input, @@ -906,23 +832,13 @@ def cutlass_fused_moe( ) -# trtllmgen-moe-fp8 - - def trtllm_gen_fused_moe_sm100_module() -> JitSpec: - # Fetch "flashinferMetaInfo.h" from the online kernel cache. This file - # contains the `tllmGenBatchedGemmList` as the list of available kernels - # online. It is included when compiling `trtllm_fused_moe_runner.cu`, etc. 
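# Illustrative sketch (reviewer aid, not part of the patch): how the cutlass_fused_moe
# wrapper above sizes its output buffer. The helper name and the token/expert counts
# below are made up for illustration only.
def _expected_moe_output_shape(num_tokens, num_experts_on_rank, hidden_size, min_latency_mode=False):
    # In min-latency mode each local expert writes its own block of rows, so the
    # row count is multiplied by the number of experts on this rank.
    rows = num_tokens * num_experts_on_rank if min_latency_mode else num_tokens
    return rows, hidden_size
# e.g. 128 tokens, 8 local experts, hidden size 4096:
#   default mode     -> (128, 4096)
#   min-latency mode -> (1024, 4096)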
include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include" header_name = "flashinferMetaInfo" - - # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( f"{include_path}/{header_name}", MetaInfoHash.TRTLLM_GEN_BMM, ".h" ) - # make sure "flashinferMetaInfo.h" is downloaded or cached assert metainfo, f"{header_name}.h not found" - return gen_jit_spec( "fused_moe_trtllm_sm100", [ @@ -950,7 +866,6 @@ def trtllm_gen_fused_moe_sm100_module() -> JitSpec: + sm100a_nvcc_flags, extra_ldflags=["-lcuda"], extra_include_paths=[ - # link "include" sub-directory in cache jit_env.FLASHINFER_CUBIN_DIR / include_path, jit_env.FLASHINFER_CSRC_DIR / "nv_internal", jit_env.FLASHINFER_CSRC_DIR / "nv_internal/include", @@ -966,26 +881,13 @@ def get_trtllm_moe_sm100_module(): class MoERunner(TunableRunner): dynamic_tensor_initializers = [ - lambda shapes, dtype, device: torch.empty( - shapes, device=device, dtype=dtype - ), # output buffer, [num_tokens, hidden_size] - lambda shapes, dtype, device: torch.rand( - shapes, device=device, dtype=dtype - ), # routing_logits, [num_tokens, num_experts] - lambda shapes, dtype, device: torch.empty( - shapes, device=device, dtype=dtype - ), # topk_ids buffer. empty since routing_logits is used. [num_tokens, topk] - lambda shapes, dtype, device: torch.empty( - shapes, device=device, dtype=dtype - ), # expert_weights buffer. empty since routing_logits is used. [num_tokens, topk] - lambda shapes, dtype, device: torch.randn(shapes, device=device).to( - dtype - ), # hidden_states, [num_tokens, hidden_size] - lambda shapes, dtype, device: torch.ones(shapes, device=device).to( - dtype - ), # hidden_states_scale, [num_tokens, hidden_size // sf_vec_size] + lambda shapes, dtype, device: paddle.empty(shape=shapes, dtype=dtype), + lambda shapes, dtype, device: paddle.rand(shape=shapes, dtype=dtype), + lambda shapes, dtype, device: paddle.empty(shape=shapes, dtype=dtype), + lambda shapes, dtype, device: paddle.empty(shape=shapes, dtype=dtype), + lambda shapes, dtype, device: paddle.randn(shape=shapes).to(dtype), + lambda shapes, dtype, device: paddle.ones(shape=shapes).to(dtype), ] - # their first dimension is num_tokens which will be tuned tuning_config_with_hidden_states_scales = TuningConfig( dynamic_tensor_specs=( DynamicTensorSpec( @@ -1006,10 +908,8 @@ class MoERunner(TunableRunner): lambda x: min(last_positive_power_of_2(x), 1024), dynamic_tensor_initializers[:5], ), - ), + ) ) - # cache the valid tactics to reduce the overhead of instantiating the runner - # TODO(siyuan): directly cache the runners valid_tactics_dict = dict() def __init__( @@ -1036,31 +936,15 @@ def __init__( self.tile_tokens_dim = tile_tokens_dim def get_tile_tokens_dim(self, num_tokens: int, top_k: int): - # Factor to account for the imbalance of the experts. - # factor equals to the - # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert - # - 1.0 means perfect expert distribution. - # - > 1.0 means some experts have more - # tokens than the perfect distribution. - # - < 1.0 does not make sense. imbalance_factor = 1.3 - # Calculate the number of tokens per expert - # assuming perfect distribution. - num_tokens_per_expert = (num_tokens * top_k) // self.num_experts - # Apply the imbalance factor. + num_tokens_per_expert = num_tokens * top_k // self.num_experts num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) - # And pad the number to the next power of 2. 
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # Cap to 8-64 tokens per CTA tile - # as it's the range supported by the kernel. tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) - return tile_tokens_dim def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, + self, inputs: List[paddle.Tensor], profile: OptimizationProfile ) -> List[int]: ( output, @@ -1070,7 +954,7 @@ def get_valid_tactics( hidden_states, *extra_inputs, ) = inputs - num_tokens = routing_logits.shape[0] + num_tokens = tuple(routing_logits.shape)[0] tile_tokens_dim = ( self.get_tile_tokens_dim(num_tokens, self.top_k) if self.tile_tokens_dim is None @@ -1089,14 +973,14 @@ def get_valid_tactics( num_tokens, ) if instance_key not in MoERunner.valid_tactics_dict: - MoERunner.valid_tactics_dict[instance_key] = ( - moe_op.trtllm_get_valid_moe_configs(*instance_key) - ) + MoERunner.valid_tactics_dict[ + instance_key + ] = moe_op.trtllm_get_valid_moe_configs(*instance_key) return MoERunner.valid_tactics_dict[instance_key] def forward( self, - inputs: List[torch.Tensor], + inputs: List[paddle.Tensor], tactic: int = -1, do_preparation: bool = False, **kwargs, @@ -1109,45 +993,42 @@ def forward( hidden_states, *extra_inputs, ) = inputs - num_tokens = routing_logits.shape[0] + num_tokens = tuple(routing_logits.shape)[0] tile_tokens_dim = ( self.get_tile_tokens_dim(num_tokens, self.top_k) if self.tile_tokens_dim is None else self.tile_tokens_dim ) - extra_input_idx = 0 if trtllm_gen_dtype_has_scale(self.dtype_act): hidden_states_scale = extra_inputs[extra_input_idx] extra_input_idx += 1 else: hidden_states_scale = None - # sanity checks to ensure that dynamic tensors have the correct shapes - assert output.shape[0] == num_tokens, ( - "output's first dimension must be batch size." - ) - assert topk_ids.shape[0] == num_tokens, ( - "topk_ids's first dimension must be batch size." - ) - assert expert_weights.shape[0] == num_tokens, ( - "expert_weights's first dimension must be batch size." - ) - assert hidden_states.shape[0] == num_tokens, ( - "hidden_states's first dimension must be batch size." - ) - assert hidden_states_scale is None or ( - hidden_states_scale.dim() == 2 - and hidden_states_scale.shape[0] == num_tokens + assert ( + tuple(output.shape)[0] == num_tokens + ), "output's first dimension must be batch size." + assert ( + tuple(topk_ids.shape)[0] == num_tokens + ), "topk_ids's first dimension must be batch size." + assert ( + tuple(expert_weights.shape)[0] == num_tokens + ), "expert_weights's first dimension must be batch size." + assert ( + tuple(hidden_states.shape)[0] == num_tokens + ), "hidden_states's first dimension must be batch size." 
+ assert ( + hidden_states_scale is None + or hidden_states_scale.dim() == 2 + and tuple(hidden_states_scale.shape)[0] == num_tokens ), "hidden_states_scale's first dimension must be batch size" - - # TODO(siyuan): support fp8 moe_op.trtllm_fp4_block_scale_moe( - routing_logits.to(torch.bfloat16), + routing_logits.to("bfloat16"), topk_ids, expert_weights, kwargs["routing_bias"], hidden_states, - hidden_states_scale, # hidden_states_scale + hidden_states_scale, kwargs["gemm1_weights"], kwargs["gemm1_weights_scale"], kwargs["gemm1_bias"], @@ -1200,22 +1081,19 @@ def refine_tuning_config(cls, tune_max_num_tokens: int): lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens), cls.dynamic_tensor_initializers[:5], ), - ), + ) ) - @register_custom_op( - "flashinfer::trtllm_fp8_per_tensor_scale_moe", - mutates_args=(""), - ) + @register_custom_op("flashinfer::trtllm_fp8_per_tensor_scale_moe", mutates_args="") def trtllm_fp8_per_tensor_scale_moe_op( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - gemm1_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - gemm2_weights: torch.Tensor, - output2_scales_scalar: torch.Tensor, + routing_logits: paddle.Tensor, + routing_bias: Optional[paddle.Tensor], + hidden_states: paddle.Tensor, + gemm1_weights: paddle.Tensor, + output1_scales_scalar: paddle.Tensor, + output1_scales_gate_scalar: paddle.Tensor, + gemm2_weights: paddle.Tensor, + output2_scales_scalar: paddle.Tensor, num_experts: int, top_k: int, n_group: int, @@ -1228,10 +1106,9 @@ def trtllm_fp8_per_tensor_scale_moe_op( tile_tokens_dim: int = 8, routing_method_type: int = 0, enable_pdl: Optional[bool] = None, - ) -> torch.Tensor: + ) -> paddle.Tensor: if enable_pdl is None: - enable_pdl = device_support_pdl(hidden_states.device) - # Call the C++ function + enable_pdl = device_support_pdl(hidden_states.place) output = moe_op.trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, @@ -1258,14 +1135,14 @@ def trtllm_fp8_per_tensor_scale_moe_op( @register_fake_op("flashinfer::trtllm_fp8_per_tensor_scale_moe") def _fake_trtllm_fp8_per_tensor_scale_moe( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - gemm1_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - gemm2_weights: torch.Tensor, - output2_scales_scalar: torch.Tensor, + routing_logits: paddle.Tensor, + routing_bias: Optional[paddle.Tensor], + hidden_states: paddle.Tensor, + gemm1_weights: paddle.Tensor, + output1_scales_scalar: paddle.Tensor, + output1_scales_gate_scalar: paddle.Tensor, + gemm2_weights: paddle.Tensor, + output2_scales_scalar: paddle.Tensor, num_experts: int, top_k: int, n_group: int, @@ -1279,24 +1156,20 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( routing_method_type: int = 0, enable_pdl: Optional[bool] = None, ): - seq_len = hidden_states.shape[0] - hidden_size = hidden_states.shape[1] - - return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] + seq_len = tuple(hidden_states.shape)[0] + hidden_size = tuple(hidden_states.shape)[1] + return [paddle.empty(shape=[seq_len, hidden_size], dtype="bfloat16")] - @register_custom_op( - "flashinfer::trtllm_fp8_block_scale_moe", - mutates_args=(""), - ) + @register_custom_op("flashinfer::trtllm_fp8_block_scale_moe", mutates_args="") def trtllm_fp8_block_scale_moe_op( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], 
- hidden_states: torch.Tensor, - hidden_states_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - gemm2_weights: torch.Tensor, - gemm2_weights_scale: torch.Tensor, + routing_logits: paddle.Tensor, + routing_bias: Optional[paddle.Tensor], + hidden_states: paddle.Tensor, + hidden_states_scale: paddle.Tensor, + gemm1_weights: paddle.Tensor, + gemm1_weights_scale: paddle.Tensor, + gemm2_weights: paddle.Tensor, + gemm2_weights_scale: paddle.Tensor, num_experts: int, top_k: int, n_group: int, @@ -1310,10 +1183,9 @@ def trtllm_fp8_block_scale_moe_op( use_shuffled_weight: bool = False, weight_layout: int = 0, enable_pdl: Optional[bool] = None, - ) -> torch.Tensor: + ) -> paddle.Tensor: if enable_pdl is None: - enable_pdl = device_support_pdl(hidden_states.device) - # Call the C++ function for block scale MoE + enable_pdl = device_support_pdl(hidden_states.place) output = moe_op.trtllm_fp8_block_scale_moe( routing_logits, routing_bias, @@ -1337,19 +1209,18 @@ def trtllm_fp8_block_scale_moe_op( weight_layout, enable_pdl, ) - return output @register_fake_op("flashinfer::trtllm_fp8_block_scale_moe") def _fake_trtllm_fp8_block_scale_moe( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - hidden_states_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - gemm2_weights: torch.Tensor, - gemm2_weights_scale: torch.Tensor, + routing_logits: paddle.Tensor, + routing_bias: Optional[paddle.Tensor], + hidden_states: paddle.Tensor, + hidden_states_scale: paddle.Tensor, + gemm1_weights: paddle.Tensor, + gemm1_weights_scale: paddle.Tensor, + gemm2_weights: paddle.Tensor, + gemm2_weights_scale: paddle.Tensor, num_experts: int, top_k: int, n_group: int, @@ -1364,34 +1235,30 @@ def _fake_trtllm_fp8_block_scale_moe( weight_layout: int = 0, enable_pdl: Optional[bool] = None, ): - seq_len = hidden_states.shape[0] - hidden_size = hidden_states.shape[1] + seq_len = tuple(hidden_states.shape)[0] + hidden_size = tuple(hidden_states.shape)[1] + return [paddle.empty(shape=[seq_len, hidden_size], dtype="bfloat16")] - return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] - - @register_custom_op( - "flashinfer::trtllm_fp4_block_scale_moe", - mutates_args=(""), - ) + @register_custom_op("flashinfer::trtllm_fp4_block_scale_moe", mutates_args="") def trtllm_fp4_block_scale_moe_op( - routing_logits: Optional[torch.Tensor], - topk_ids: Optional[torch.Tensor], - expert_weights: Optional[torch.Tensor], - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - hidden_states_scale: Optional[torch.Tensor], - gemm1_weights: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - gemm1_bias: Optional[torch.Tensor], - gemm1_alpha: Optional[torch.Tensor], - gemm1_beta: Optional[torch.Tensor], - gemm1_clamp_limit: Optional[torch.Tensor], - gemm2_weights: torch.Tensor, - gemm2_weights_scale: torch.Tensor, - gemm2_bias: Optional[torch.Tensor], - output1_scale_scalar: Optional[torch.Tensor], - output1_scale_gate_scalar: Optional[torch.Tensor], - output2_scale_scalar: Optional[torch.Tensor], + routing_logits: Optional[paddle.Tensor], + topk_ids: Optional[paddle.Tensor], + expert_weights: Optional[paddle.Tensor], + routing_bias: Optional[paddle.Tensor], + hidden_states: paddle.Tensor, + hidden_states_scale: Optional[paddle.Tensor], + gemm1_weights: paddle.Tensor, + gemm1_weights_scale: paddle.Tensor, + gemm1_bias: Optional[paddle.Tensor], + gemm1_alpha: Optional[paddle.Tensor], + 
gemm1_beta: Optional[paddle.Tensor], + gemm1_clamp_limit: Optional[paddle.Tensor], + gemm2_weights: paddle.Tensor, + gemm2_weights_scale: paddle.Tensor, + gemm2_bias: Optional[paddle.Tensor], + output1_scale_scalar: Optional[paddle.Tensor], + output1_scale_gate_scalar: Optional[paddle.Tensor], + output2_scale_scalar: Optional[paddle.Tensor], num_experts: int, top_k: int, n_group: Optional[int], @@ -1405,41 +1272,31 @@ def trtllm_fp4_block_scale_moe_op( do_finalize: bool, enable_pdl: Optional[bool] = None, gated_act_type: int = 0, - output: Optional[torch.Tensor] = None, + output: Optional[paddle.Tensor] = None, tune_max_num_tokens: int = 1024, - ) -> List[torch.Tensor]: + ) -> List[paddle.Tensor]: if routing_logits is None: - assert topk_ids is not None, ( - "either topk_ids or routing_logits must be provided." - ) - assert topk_ids.dtype == torch.int32, "topk_ids must be an int32 tensor." - routing_dtype = torch.bfloat16 + assert ( + topk_ids is not None + ), "either topk_ids or routing_logits must be provided." + assert topk_ids.dtype == "int32", "topk_ids must be an int32 tensor." + routing_dtype = "bfloat16" else: routing_dtype = routing_logits.dtype - hidden_size = hidden_states.shape[-1] - if hidden_states.dtype == torch.uint8: + hidden_size = tuple(hidden_states.shape)[-1] + if hidden_states.dtype == "uint8": hidden_size = hidden_size * 2 - num_tokens = hidden_states.shape[0] - - # workspace buffers required by trtllm-gen + num_tokens = tuple(hidden_states.shape)[0] if topk_ids is None: - topk_ids = torch.empty( - num_tokens, top_k, dtype=torch.int32, device=hidden_states.device - ) + topk_ids = paddle.empty(shape=[num_tokens, top_k], dtype="int32") if expert_weights is None: - expert_weights = torch.empty( - num_tokens, top_k, dtype=routing_dtype, device=hidden_states.device + expert_weights = paddle.empty( + shape=[num_tokens, top_k], dtype=routing_dtype ) if enable_pdl is None: - enable_pdl = device_support_pdl(hidden_states.device) + enable_pdl = device_support_pdl(hidden_states.place) if output is None: - output = torch.empty( - num_tokens, - hidden_size, - dtype=torch.bfloat16, - device=hidden_states.device, - ) - + output = paddle.empty(shape=[num_tokens, hidden_size], dtype="bfloat16") tuner = AutoTuner.get() MoERunner.refine_tuning_config(tune_max_num_tokens) dtype_act = deduce_trtllm_gen_tensor_dtype(hidden_states, hidden_states_scale) @@ -1455,25 +1312,15 @@ def trtllm_fp4_block_scale_moe_op( hidden_size=hidden_size, intermediate_size=intermediate_size, gated_act_type=gated_act_type, - # NOTE(siyuan): do not fix the tile_tokens_dim to let tunnable runner decide the tile_tokens_dim itself. - # however, when the user chooses a different heuristic for tile_tokens_dim, the autotuner will fail to find the correct cached tactics. 
- # tile_tokens_dim=tile_tokens_dim, ) tunning_config = ( MoERunner.tuning_config_no_hidden_states_scales if hidden_states_scale is None else MoERunner.tuning_config_with_hidden_states_scales ) - inputs = [ - output, - routing_logits, - topk_ids, - expert_weights, - hidden_states, - ] + inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] if hidden_states_scale is not None: inputs.append(hidden_states_scale) - _, tactic = tuner.choose_one( "flashinfer::trtllm_fp4_block_scale_moe", [moe_runner], @@ -1502,8 +1349,6 @@ def trtllm_fp4_block_scale_moe_op( do_finalize=do_finalize, gated_act_type=gated_act_type, ) - - # Call the C++ function for block scale MoE output = moe_op.trtllm_fp4_block_scale_moe( routing_logits, topk_ids, @@ -1539,29 +1384,28 @@ def trtllm_fp4_block_scale_moe_op( output, tactic, ) - return output @register_fake_op("flashinfer::trtllm_fp4_block_scale_moe") def _fake_trtllm_fp4_block_scale_moe( - routing_logits: torch.Tensor, - topk_ids: Optional[torch.Tensor], - expert_weights: Optional[torch.Tensor], - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - hidden_states_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - gemm1_bias: Optional[torch.Tensor], - gemm1_alpha: Optional[torch.Tensor], - gemm1_beta: Optional[torch.Tensor], - gemm1_clamp_limit: Optional[torch.Tensor], - gemm2_weights: torch.Tensor, - gemm2_weights_scale: torch.Tensor, - gemm2_bias: Optional[torch.Tensor], - output1_scale_scalar: Optional[torch.Tensor], - output1_scale_gate_scalar: Optional[torch.Tensor], - output2_scale_scalar: Optional[torch.Tensor], + routing_logits: paddle.Tensor, + topk_ids: Optional[paddle.Tensor], + expert_weights: Optional[paddle.Tensor], + routing_bias: Optional[paddle.Tensor], + hidden_states: paddle.Tensor, + hidden_states_scale: paddle.Tensor, + gemm1_weights: paddle.Tensor, + gemm1_weights_scale: paddle.Tensor, + gemm1_bias: Optional[paddle.Tensor], + gemm1_alpha: Optional[paddle.Tensor], + gemm1_beta: Optional[paddle.Tensor], + gemm1_clamp_limit: Optional[paddle.Tensor], + gemm2_weights: paddle.Tensor, + gemm2_weights_scale: paddle.Tensor, + gemm2_bias: Optional[paddle.Tensor], + output1_scale_scalar: Optional[paddle.Tensor], + output1_scale_gate_scalar: Optional[paddle.Tensor], + output2_scale_scalar: Optional[paddle.Tensor], num_experts: int, top_k: int, n_group: Optional[int], @@ -1575,13 +1419,12 @@ def _fake_trtllm_fp4_block_scale_moe( do_finalize: bool, enable_pdl: bool, gated_act_type: int, - output: Optional[torch.Tensor], + output: Optional[paddle.Tensor], tune_max_num_tokens: int, ): - seq_len = hidden_states.shape[0] - hidden_size = hidden_states.shape[1] - - return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] + seq_len = tuple(hidden_states.shape)[0] + hidden_size = tuple(hidden_states.shape)[1] + return [paddle.empty(shape=[seq_len, hidden_size], dtype="bfloat16")] return SimpleNamespace( trtllm_fp8_per_tensor_scale_moe=trtllm_fp8_per_tensor_scale_moe_op, @@ -1591,14 +1434,14 @@ def _fake_trtllm_fp4_block_scale_moe( def trtllm_fp8_per_tensor_scale_moe( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - gemm1_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - gemm2_weights: torch.Tensor, - output2_scales_scalar: torch.Tensor, + routing_logits: paddle.Tensor, + routing_bias: Optional[paddle.Tensor], + hidden_states: paddle.Tensor, + 
gemm1_weights: paddle.Tensor, + output1_scales_scalar: paddle.Tensor, + output1_scales_gate_scalar: paddle.Tensor, + gemm2_weights: paddle.Tensor, + output2_scales_scalar: paddle.Tensor, num_experts: int, top_k: int, n_group: int, @@ -1611,7 +1454,7 @@ def trtllm_fp8_per_tensor_scale_moe( tile_tokens_dim: int = 8, routing_method_type: int = 0, enable_pdl: Optional[bool] = None, -) -> torch.Tensor: +) -> paddle.Tensor: """FP8 per tensor scale MoE operation. Args: @@ -1664,14 +1507,14 @@ def trtllm_fp8_per_tensor_scale_moe( def trtllm_fp8_block_scale_moe( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - hidden_states_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - gemm2_weights: torch.Tensor, - gemm2_weights_scale: torch.Tensor, + routing_logits: paddle.Tensor, + routing_bias: Optional[paddle.Tensor], + hidden_states: paddle.Tensor, + hidden_states_scale: paddle.Tensor, + gemm1_weights: paddle.Tensor, + gemm1_weights_scale: paddle.Tensor, + gemm2_weights: paddle.Tensor, + gemm2_weights_scale: paddle.Tensor, num_experts: int, top_k: int, n_group: int, @@ -1685,7 +1528,7 @@ def trtllm_fp8_block_scale_moe( use_shuffled_weight: bool = False, weight_layout: int = 0, enable_pdl: Optional[bool] = None, -) -> torch.Tensor: +) -> paddle.Tensor: """FP8 block scale MoE operation. Args: @@ -1737,22 +1580,22 @@ def trtllm_fp8_block_scale_moe( def trtllm_fp4_block_scale_moe( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - hidden_states_scale: Optional[torch.Tensor], - gemm1_weights: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - gemm1_bias: Optional[torch.Tensor], - gemm1_alpha: Optional[torch.Tensor], - gemm1_beta: Optional[torch.Tensor], - gemm1_clamp_limit: Optional[torch.Tensor], - gemm2_weights: torch.Tensor, - gemm2_weights_scale: torch.Tensor, - gemm2_bias: Optional[torch.Tensor], - output1_scale_scalar: Optional[torch.Tensor], - output1_scale_gate_scalar: Optional[torch.Tensor], - output2_scale_scalar: Optional[torch.Tensor], + routing_logits: paddle.Tensor, + routing_bias: Optional[paddle.Tensor], + hidden_states: paddle.Tensor, + hidden_states_scale: Optional[paddle.Tensor], + gemm1_weights: paddle.Tensor, + gemm1_weights_scale: paddle.Tensor, + gemm1_bias: Optional[paddle.Tensor], + gemm1_alpha: Optional[paddle.Tensor], + gemm1_beta: Optional[paddle.Tensor], + gemm1_clamp_limit: Optional[paddle.Tensor], + gemm2_weights: paddle.Tensor, + gemm2_weights_scale: paddle.Tensor, + gemm2_bias: Optional[paddle.Tensor], + output1_scale_scalar: Optional[paddle.Tensor], + output1_scale_gate_scalar: Optional[paddle.Tensor], + output2_scale_scalar: Optional[paddle.Tensor], num_experts: int, top_k: int, n_group: Optional[int], @@ -1766,9 +1609,9 @@ def trtllm_fp4_block_scale_moe( do_finalize: bool = True, enable_pdl: Optional[bool] = None, gated_act_type: int = 0, - output: Optional[torch.Tensor] = None, + output: Optional[paddle.Tensor] = None, tune_max_num_tokens: int = 1024, -) -> List[torch.Tensor]: +) -> List[paddle.Tensor]: """FP4 block scale MoE operation. 
Args: @@ -1869,22 +1712,22 @@ def trtllm_fp4_block_scale_moe( def trtllm_fp4_block_scale_routed_moe( - topk_ids: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - hidden_states_scale: Optional[torch.Tensor], - gemm1_weights: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - gemm1_bias: Optional[torch.Tensor], - gemm1_alpha: Optional[torch.Tensor], - gemm1_beta: Optional[torch.Tensor], - gemm1_clamp_limit: Optional[torch.Tensor], - gemm2_weights: torch.Tensor, - gemm2_weights_scale: torch.Tensor, - gemm2_bias: Optional[torch.Tensor], - output1_scale_scalar: Optional[torch.Tensor], - output1_scale_gate_scalar: Optional[torch.Tensor], - output2_scale_scalar: Optional[torch.Tensor], + topk_ids: paddle.Tensor, + routing_bias: Optional[paddle.Tensor], + hidden_states: paddle.Tensor, + hidden_states_scale: Optional[paddle.Tensor], + gemm1_weights: paddle.Tensor, + gemm1_weights_scale: paddle.Tensor, + gemm1_bias: Optional[paddle.Tensor], + gemm1_alpha: Optional[paddle.Tensor], + gemm1_beta: Optional[paddle.Tensor], + gemm1_clamp_limit: Optional[paddle.Tensor], + gemm2_weights: paddle.Tensor, + gemm2_weights_scale: paddle.Tensor, + gemm2_bias: Optional[paddle.Tensor], + output1_scale_scalar: Optional[paddle.Tensor], + output1_scale_gate_scalar: Optional[paddle.Tensor], + output2_scale_scalar: Optional[paddle.Tensor], num_experts: int, top_k: int, n_group: Optional[int], @@ -1898,9 +1741,9 @@ def trtllm_fp4_block_scale_routed_moe( do_finalize: bool = True, enable_pdl: Optional[bool] = None, gated_act_type: int = 0, - output: Optional[torch.Tensor] = None, + output: Optional[paddle.Tensor] = None, tune_max_num_tokens: int = 1024, -) -> List[torch.Tensor]: +) -> List[paddle.Tensor]: """FP4 block scale MoE operation. Args: diff --git a/flashinfer/fused_moe/utils.py b/flashinfer/fused_moe/utils.py index 2bb196858a..5b78eb203f 100644 --- a/flashinfer/fused_moe/utils.py +++ b/flashinfer/fused_moe/utils.py @@ -1,23 +1,21 @@ +import sys + + import contextlib import threading from dataclasses import dataclass from enum import Enum from typing import Dict, List, Tuple -import torch +import paddle +from flashinfer.paddle_utils import * from ..utils import ceil_div, round_up is_torch_compiling_flag = False - -AuxStreamType = Enum( - "AuxStreamType", - ["Attention", "MoeShared", "MoeChunkingOverlap"], -) +AuxStreamType = Enum("AuxStreamType", ["Attention", "MoeShared", "MoeChunkingOverlap"]) EventType = Enum( - "EventType", - ["Main", "Attention", "MoeShared", "MoeChunkingOverlap"], - start=0, + "EventType", ["Main", "Attention", "MoeShared", "MoeChunkingOverlap"], start=0 ) @@ -68,13 +66,13 @@ def wrapper(self, *args, **kwargs): @dataclass class Fp4QuantizedTensor: - fp4_tensor: torch.Tensor - scaling_factor: torch.Tensor + fp4_tensor: paddle.Tensor + scaling_factor: paddle.Tensor is_sf_swizzled: bool = True @property def shape(self): - return self.fp4_tensor.shape + return tuple(self.fp4_tensor.shape) def compute_swizzled_sf_shape(row: int, col: int): @@ -83,7 +81,7 @@ def compute_swizzled_sf_shape(row: int, col: int): return padded_row, padded_col -def swizzle_sf(sf: torch.Tensor, rows: int, cols: int, scaling_vector_size: int = 16): +def swizzle_sf(sf: paddle.Tensor, rows: int, cols: int, scaling_vector_size: int = 16): """Swizzle FP4 scaling factors using C++ torch op implementation Args: sf: [b, rows, cols_sf] or [rows, cols_sf]. The original unswizzled scaling factors. 
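# Shape arithmetic sketch (reviewer aid, not part of the patch), assuming the padding
# multiples are 128 rows by 4 scale columns, as used by get_fp4_shape later in this file:
#   rows = 100, hidden = 7168, scaling_vector_size = 16
#   sf_cols         = 7168 // 16 = 448
#   padded sf shape = (round_up(100, 128), round_up(448, 4)) = (128, 448)
#   swizzled numel  = 128 * 448 = 57344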
@@ -98,7 +96,9 @@ def swizzle_sf(sf: torch.Tensor, rows: int, cols: int, scaling_vector_size: int return torch.ops.trtllm.block_scale_interleave(sf) -def unswizzle_sf(sf: torch.Tensor, rows: int, cols: int, scaling_vector_size: int = 16): +def unswizzle_sf( + sf: paddle.Tensor, rows: int, cols: int, scaling_vector_size: int = 16 +): """Swizzle FP4 scaling factors using C++ torch op implementation Args: sf: The (padded and) swizzled scaling factors. @@ -113,10 +113,10 @@ def unswizzle_sf(sf: torch.Tensor, rows: int, cols: int, scaling_vector_size: in return torch.ops.trtllm.block_scale_interleave_reverse(sf).view(-1, sf_cols) -@torch.library.custom_op("trtllm::reswizzle_sf", mutates_args=()) +# @torch.library.custom_op("trtllm::reswizzle_sf", mutates_args=()) def reswizzle_sf( - sf: torch.Tensor, rows: int, cols: int, scaling_vector_size: int = 16 -) -> torch.Tensor: + sf: paddle.Tensor, rows: int, cols: int, scaling_vector_size: int = 16 +) -> paddle.Tensor: """Reswizzle FP4 scaling factors using C++ torch op implementation. It unswizzles the scaling factors in each partition first, then concatenates them together, and finally swizzles them back. Args: @@ -130,46 +130,32 @@ def reswizzle_sf( sf_cols = ceil_div(cols, scaling_vector_size) padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols) padded_cols = padded_sf_cols * scaling_vector_size - - assert sf.numel() % (padded_rows * padded_sf_cols) == 0 - num_partitions = sf.numel() // (padded_rows * padded_sf_cols) - + assert sf.size % (padded_rows * padded_sf_cols) == 0 + num_partitions = sf.size // (padded_rows * padded_sf_cols) sf_reshaped = sf.view(num_partitions, padded_rows, padded_sf_cols) - - # Unswizzle each partition sf_unswizzled = unswizzle_sf( sf_reshaped, padded_rows, padded_cols, scaling_vector_size ) - - # Brings the unswizzled scaling factors in each partition together total_rows = num_partitions * rows sf_unswizzled = sf_unswizzled.view(num_partitions, padded_rows, padded_sf_cols) - sf_concatenated = sf_unswizzled[ - :, :rows, :sf_cols - ].contiguous() # TODO: This will incur a elementwise kernel + sf_concatenated = sf_unswizzled[:, :rows, :sf_cols].contiguous() sf_concatenated = sf_concatenated.view(total_rows, sf_cols) - - # Finally swizzle the concatenated scaling factors return swizzle_sf(sf_concatenated, total_rows, cols, scaling_vector_size) -@torch.library.register_fake("trtllm::reswizzle_sf") +# @torch.library.register_fake("trtllm::reswizzle_sf") def _(sf, rows, cols, scaling_vector_size=16): sf_cols = ceil_div(cols, scaling_vector_size) padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols) - num_partitions = sf.numel() // (padded_rows * padded_sf_cols) + num_partitions = sf.size // (padded_rows * padded_sf_cols) total_rows = num_partitions * rows sz = round_up(total_rows, 128) * round_up(cols, 4) - return sf.new_empty(sz) + return paddle.empty(shape=sz, dtype=sf.dtype) def next_positive_power_of_2(x: int) -> int: if x < 1: return 1 - - # Following code is equivalent to 1 << (x - 1).bit_length() - # But this impl does not contain bit_length() so can be used by torch compile. - # It can correctly handle 64bit number which should be enough for now. 
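# Equivalence sketch (reviewer aid, not part of the patch): the bit-twiddling below is
# the branch-free form of the formula in the comment removed above, so for positive x
#   next_positive_power_of_2(x) == 1 << (x - 1).bit_length()
# e.g. 1 -> 1, 3 -> 4, 20 -> 32, 1024 -> 1024, 1025 -> 2048.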
n = x - 1 n |= n >> 1 n |= n >> 2 @@ -184,7 +170,6 @@ def last_positive_power_of_2(x: int) -> int: next = next_positive_power_of_2(x) if next == x: return next - return next // 2 @@ -199,7 +184,6 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]: while m >= 1: num_token_buckets.append(m) m //= 2 - return tuple(num_token_buckets) @@ -219,10 +203,8 @@ def get_fp4_shape(input_shape, sf_vec_size, is_swizzled_layout=True): m = 1 for i in range(len(input_shape) - 1): m *= input_shape[i] - output_shape = [i for i in input_shape] output_shape[-1] //= 2 - scale_shape = ( round_up(m, 128) * round_up(input_shape[-1] // sf_vec_size, 4) if is_swizzled_layout diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index f173f02ea2..bad5f65bf0 100755 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -1,3 +1,11 @@ +import sys + + +import os + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,30 +21,19 @@ See the License for the specific language governing permissions and limitations under the License. """ - import functools -import os from enum import Enum from itertools import product from types import SimpleNamespace from typing import List, Literal, Optional, Tuple import jinja2 -import torch from .artifacts import ArtifactPath, MetaInfoHash -from .autotuner import ( - AutoTuner, - ConstraintSpec, - DynamicTensorSpec, - OptimizationProfile, - TunableRunner, - TuningConfig, -) -from .fused_moe.utils import ( - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, -) +from .autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, + OptimizationProfile, TunableRunner, TuningConfig) +from .fused_moe.utils import (get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2) from .jit.cubin_loader import get_cubin CUDNN_AVAILABLE = False @@ -51,27 +48,20 @@ is_lib_missing = any(ext in error_msg for ext in [".so", ".dll"]) if not is_lib_missing: raise - - from .jit import JitSpec from .jit import env as jit_env from .jit import gen_jit_spec, sm90a_nvcc_flags, sm100a_nvcc_flags from .jit.cubin_loader import setup_cubin_loader -from .jit.utils import dtype_cutlass_map, filename_safe_dtype_map, write_if_different -from .utils import ( - _get_cache_buf, - determine_gemm_backend, - get_indptr, - is_float8, - register_custom_op, - register_fake_op, - get_compute_capability, -) +from .jit.utils import (dtype_cutlass_map, filename_safe_dtype_map, + write_if_different) +from .utils import (_get_cache_buf, determine_gemm_backend, + get_compute_capability, get_indptr, is_float8, + register_custom_op, register_fake_op) DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024 -def _match_sm_version(device: torch.device, sm_version: str): +def _match_sm_version(device: str, sm_version: str): major, minor = get_compute_capability(device) device_arch = f"{major * 10 + minor}" return device_arch == sm_version @@ -93,24 +83,21 @@ def gen_gemm_module() -> JitSpec: def get_gemm_module(): module = gen_gemm_module().build_and_load() - # auto-tuned cublas fp8 gemm runner def cublas_fp8_gemm_runner(): class CublasFp8GemmRunner(TunableRunner): def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, + self, inputs: List[paddle.Tensor], profile: OptimizationProfile ) -> List[int]: - # cublas has heuristic for fp8 gemm, so we only need to use the default tactic return [0] def forward( self, - inputs: List[torch.Tensor], + inputs: List[paddle.Tensor], tactic: int = -1, do_preparation: bool = False, **kwargs, - ) 
-> torch.Tensor: + ) -> paddle.Tensor: + # TODO: Add python api like torch cublas_handle = torch.cuda.current_blas_handle() a, b, scale_a, scale_b, out, workspace_buffer = inputs module.bmm_fp8.default( @@ -120,20 +107,18 @@ def forward( return CublasFp8GemmRunner() - # torch library for cutlass_segment_gemm - - @register_custom_op("flashinfer::cutlass_segment_gemm", mutates_args=("y")) + @register_custom_op("flashinfer::cutlass_segment_gemm", mutates_args="y") def cutlass_segment_gemm( - workspace_buffer: torch.Tensor, - all_problems: torch.Tensor, - x_data: torch.Tensor, - w_data: torch.Tensor, - y_data: torch.Tensor, - x_ld: torch.Tensor, - w_ld: torch.Tensor, - y_ld: torch.Tensor, - y: torch.Tensor, - empty_x_data: torch.Tensor, + workspace_buffer: paddle.Tensor, + all_problems: paddle.Tensor, + x_data: paddle.Tensor, + w_data: paddle.Tensor, + y_data: paddle.Tensor, + x_ld: paddle.Tensor, + w_ld: paddle.Tensor, + y_ld: paddle.Tensor, + y: paddle.Tensor, + empty_x_data: paddle.Tensor, weight_column_major: bool, ) -> None: module.cutlass_segment_gemm.default( @@ -151,36 +136,31 @@ def cutlass_segment_gemm( @register_fake_op("flashinfer::cutlass_segment_gemm") def _fake_cutlass_segment_gemm( - workspace_buffer: torch.Tensor, - all_problems: torch.Tensor, - x_data: torch.Tensor, - w_data: torch.Tensor, - y_data: torch.Tensor, - x_ld: torch.Tensor, - w_ld: torch.Tensor, - y_ld: torch.Tensor, - y: torch.Tensor, - empty_x_data: torch.Tensor, + workspace_buffer: paddle.Tensor, + all_problems: paddle.Tensor, + x_data: paddle.Tensor, + w_data: paddle.Tensor, + y_data: paddle.Tensor, + x_ld: paddle.Tensor, + w_ld: paddle.Tensor, + y_ld: paddle.Tensor, + y: paddle.Tensor, + empty_x_data: paddle.Tensor, weight_column_major: bool, ) -> None: pass - # Register the module _gemm_module = SimpleNamespace( cublas_fp8_gemm_runner=cublas_fp8_gemm_runner, cutlass_segment_gemm=cutlass_segment_gemm, ) - return _gemm_module def gen_gemm_sm100_module_cutlass_fp4() -> JitSpec: gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_fp4" os.makedirs(gen_directory, exist_ok=True) - source_paths = [ - jit_env.FLASHINFER_CSRC_DIR / "fp4_gemm_cutlass.cu", - ] - + source_paths = [jit_env.FLASHINFER_CSRC_DIR / "fp4_gemm_cutlass.cu"] with open(jit_env.FLASHINFER_CSRC_DIR / "fp4_gemm_cutlass.jinja") as f: kernel_inst_templ = jinja2.Template(f.read()) dtype_list = ["__nv_bfloat16", "half"] @@ -198,24 +178,14 @@ def gen_gemm_sm100_module_cutlass_fp4() -> JitSpec: ) source_paths.append(dest_path) source = kernel_inst_templ.render( - type=dtype, - cta_m=cta_m, - cta_n=cta_n, - cta_k=cta_k, + type=dtype, cta_m=cta_m, cta_n=cta_n, cta_k=cta_k ) write_if_different(dest_path, source) - return gen_jit_spec( "fp4_gemm_cutlass", source_paths, - extra_cuda_cflags=sm100a_nvcc_flags - + [ - "-DENABLE_BF16", - "-DENABLE_FP4", - ], - extra_cflags=[ - "-DFAST_BUILD", - ], + extra_cuda_cflags=sm100a_nvcc_flags + ["-DENABLE_BF16", "-DENABLE_FP4"], + extra_cflags=["-DFAST_BUILD"], extra_ldflags=["-lcuda"], ) @@ -223,10 +193,7 @@ def gen_gemm_sm100_module_cutlass_fp4() -> JitSpec: def gen_gemm_sm100_module_cutlass_fp8() -> JitSpec: gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_fp8" os.makedirs(gen_directory, exist_ok=True) - source_paths = [ - jit_env.FLASHINFER_CSRC_DIR / "fp8_gemm_cutlass.cu", - ] - + source_paths = [jit_env.FLASHINFER_CSRC_DIR / "fp8_gemm_cutlass.cu"] with open(jit_env.FLASHINFER_CSRC_DIR / "fp8_gemm_cutlass.jinja") as f: kernel_inst_templ = jinja2.Template(f.read()) dtype_list 
= ["__nv_bfloat16", "half"] @@ -246,23 +213,14 @@ def gen_gemm_sm100_module_cutlass_fp8() -> JitSpec: ) source_paths.append(dest_path) source = kernel_inst_templ.render( - type=dtype, - cta_m=cta_m, - cta_n=cta_n, - cta_k=cta_k, + type=dtype, cta_m=cta_m, cta_n=cta_n, cta_k=cta_k ) write_if_different(dest_path, source) - return gen_jit_spec( "fp8_gemm_cutlass", source_paths, - extra_cuda_cflags=sm100a_nvcc_flags - + [ - "-DENABLE_BF16", - ], - extra_cflags=[ - "-DFAST_BUILD", - ], + extra_cuda_cflags=sm100a_nvcc_flags + ["-DENABLE_BF16"], + extra_cflags=["-DFAST_BUILD"], extra_ldflags=["-lcuda"], ) @@ -276,8 +234,8 @@ def gen_gemm_sm100_module() -> JitSpec: jit_env.FLASHINFER_CSRC_DIR / f"{prefix}_sm100_kernel_inst.jinja" ) as f: kernel_inst_templ = jinja2.Template(f.read()) - dtype_in_list = [torch.float8_e4m3fn, torch.float8_e5m2] - dtype_out_list = [torch.float16, torch.bfloat16] + dtype_in_list = [paddle.float8_e4m3fn, paddle.float8_e5m2] + dtype_out_list = ["float16", "bfloat16"] scale_major_k_list = ["true", "false"] mma_sm_list = [1, 2] for dtype_in, dtype_out, scale_major_k, mma_sm in product( @@ -300,8 +258,8 @@ def gen_gemm_sm100_module() -> JitSpec: prefix = "group_gemm_mxfp4_groupwise" with open(jit_env.FLASHINFER_CSRC_DIR / f"{prefix}_sm100_kernel_inst.jinja") as f: kernel_inst_templ = jinja2.Template(f.read()) - dtype_a_list = [torch.float8_e4m3fn, torch.float8_e5m2] - dtype_d_list = [torch.float16, torch.bfloat16] + dtype_a_list = [paddle.float8_e4m3fn, paddle.float8_e5m2] + dtype_d_list = ["float16", "bfloat16"] mma_sm_list = [1, 2] swap_ab_list = ["true", "false"] for dtype_a, dtype_d, mma_sm, swap_ab in product( @@ -335,47 +293,31 @@ def gen_gemm_sm100_module() -> JitSpec: with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - return gen_jit_spec( - "gemm_sm100", - source_paths, - extra_cuda_cflags=sm100a_nvcc_flags, - ) + return gen_jit_spec("gemm_sm100", source_paths, extra_cuda_cflags=sm100a_nvcc_flags) @functools.cache def get_gemm_sm100_module(): module = gen_gemm_sm100_module().build_and_load() - return module def trtllm_gemm_gen_module() -> JitSpec: - # Fetch "flashinferMetaInfo.h" from the online kernel cache. This file - # contains the `tllmGenGemmList` as the list of available kernels online. - # It is included when compiling `trtllm_gemm_runner.cu`. 
include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include" header_name = "flashinferMetaInfo" - - # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( - f"{include_path}/{header_name}", - MetaInfoHash.TRTLLM_GEN_GEMM, - ".h", + f"{include_path}/{header_name}", MetaInfoHash.TRTLLM_GEN_GEMM, ".h" ) - # make sure "flashinferMetaInfo.h" is downloaded or cached assert metainfo, f"{header_name}.h not found" return gen_jit_spec( "trtllm_gemm", - [ - jit_env.FLASHINFER_CSRC_DIR / "trtllm_gemm_runner.cu", - ], + [jit_env.FLASHINFER_CSRC_DIR / "trtllm_gemm_runner.cu"], extra_cuda_cflags=[ "-DTLLM_GEN_EXPORT_INTERFACE", "-DTLLM_ENABLE_CUDA", f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"', ] + sm100a_nvcc_flags, - # link "include" sub-directory in cache extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path], extra_ldflags=["-lcuda"], ) @@ -396,23 +338,21 @@ def get_gemm_sm100_module_cutlass_fp8(): def cutlass_fp8_gemm_runner(): class CutlassFp8GemmRunner(TunableRunner): def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, + self, inputs: List[paddle.Tensor], profile: OptimizationProfile ) -> List[int]: return list(range(module.fp8_gemm_tactic_num())) def forward( self, - inputs: List[torch.Tensor], + inputs: List[paddle.Tensor], tactic: int = -1, do_preparation: bool = False, **kwargs, - ) -> torch.Tensor: + ) -> paddle.Tensor: a, b, scale_a, scale_b, out, workspace_buffer = inputs module.fp8_gemm.default( a, - b.transpose(-2, -1), + b.transpose(perm=dim2perm(b.ndim, -2, -1)), scale_a, scale_b, out, @@ -423,36 +363,30 @@ def forward( return CutlassFp8GemmRunner() - # Register the module - return SimpleNamespace( - cutlass_fp8_gemm_runner=cutlass_fp8_gemm_runner, - ) + return SimpleNamespace(cutlass_fp8_gemm_runner=cutlass_fp8_gemm_runner) def fp8_gemm_sm100( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out: torch.Tensor, - workspace_buffer: torch.Tensor, + a: paddle.Tensor, + b: paddle.Tensor, + scale_a: paddle.Tensor, + scale_b: paddle.Tensor, + out: paddle.Tensor, + workspace_buffer: paddle.Tensor, runner_names: List[str], ) -> None: runners = [] - # No e5m2 for cutlass - is_e5m2 = a.dtype == torch.float8_e5m2 or b.dtype == torch.float8_e5m2 - is_sm100 = _match_sm_version(a.device, "100") + is_e5m2 = a.dtype == paddle.float8_e5m2 or b.dtype == paddle.float8_e5m2 + is_sm100 = _match_sm_version(a.place, "100") if "cutlass" in runner_names and is_sm100 and not is_e5m2: runners.append(get_gemm_sm100_module_cutlass_fp8().cutlass_fp8_gemm_runner()) if "cublas" in runner_names: runners.append(get_gemm_module().cublas_fp8_gemm_runner()) if CUDNN_AVAILABLE and "cudnn" in runner_names: runners.append(_cudnn_gemm_fp8_runner()) - if len(runners) == 0: - major, minor = get_compute_capability(torch.device("cuda")) + major, minor = get_compute_capability(device2str("cuda")) raise ValueError(f"No valid runner found for current device sm{major}{minor}") - tuner = AutoTuner.get() a_tensor_index = 0 out_tensor_index = 4 @@ -471,15 +405,8 @@ def fp8_gemm_sm100( ), ), ) - inputs = [a, b, scale_a, scale_b, out, workspace_buffer] - runner, tactic = tuner.choose_one( - "fp8_gemm", - runners, - tuning_config, - inputs, - ) - + runner, tactic = tuner.choose_one("fp8_gemm", runners, tuning_config, inputs) runner(inputs=inputs, tactic=tactic) @@ -492,15 +419,13 @@ def __init__(self): self._fp4_gemm_runner = module.fp4_gemm def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: 
OptimizationProfile, + self, inputs: List[paddle.Tensor], profile: OptimizationProfile ) -> List[int]: return list(range(module.fp4_gemm_tactic_num())) def forward( self, - inputs: List[torch.Tensor], + inputs: List[paddle.Tensor], tactic: int = -1, do_preparation: bool = False, **kwargs, @@ -511,27 +436,23 @@ def forward( ) return out - @register_custom_op( - "flashinfer::cutlass_fp4_gemm", - mutates_args=(""), - ) + @register_custom_op("flashinfer::cutlass_fp4_gemm", mutates_args="") def cutlass_fp4_gemm( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: torch.Tensor, - out: torch.Tensor, - workspace_buffer: torch.Tensor, + a: paddle.Tensor, + b: paddle.Tensor, + a_descale: paddle.Tensor, + b_descale: paddle.Tensor, + alpha: paddle.Tensor, + out: paddle.Tensor, + workspace_buffer: paddle.Tensor, ): tuner = AutoTuner.get() - a_tensor_index = 0 a_scale_tensor_index = 2 out_tensor_index = 5 def pad_up(x, y): - return ((x + y - 1) // y) * y + return (x + y - 1) // y * y tuning_config = TuningConfig( dynamic_tensor_specs=( @@ -553,23 +474,14 @@ def pad_up(x, y): ), ), ) - fp4_runner = CutlassFp4GemmRunner() - inputs = [a, b, a_descale, b_descale, alpha, out, workspace_buffer] _, tactic = tuner.choose_one( - "cutlass_fp4_gemm", - [fp4_runner], - tuning_config, - inputs, + "cutlass_fp4_gemm", [fp4_runner], tuning_config, inputs ) - fp4_runner(inputs=inputs, tactic=tactic) - # Register the module - return SimpleNamespace( - cutlass_fp4_gemm=cutlass_fp4_gemm, - ) + return SimpleNamespace(cutlass_fp4_gemm=cutlass_fp4_gemm) def gen_gemm_sm90_module() -> JitSpec: @@ -579,12 +491,12 @@ def gen_gemm_sm90_module() -> JitSpec: with open(jit_env.FLASHINFER_CSRC_DIR / "group_gemm_sm90_kernel_inst.jinja") as f: kernel_inst_templ = jinja2.Template(f.read()) for dtype_in, dtype_out in [ - (torch.float16, torch.float16), - (torch.bfloat16, torch.bfloat16), - (torch.float8_e4m3fn, torch.float16), - (torch.float8_e5m2, torch.float16), - (torch.float8_e4m3fn, torch.bfloat16), - (torch.float8_e5m2, torch.bfloat16), + ("float16", "float16"), + ("bfloat16", "bfloat16"), + (paddle.float8_e4m3fn, "float16"), + (paddle.float8_e5m2, "float16"), + (paddle.float8_e4m3fn, "bfloat16"), + (paddle.float8_e5m2, "bfloat16"), ]: name_dtype_in = filename_safe_dtype_map[dtype_in] name_dtype_out = filename_safe_dtype_map[dtype_out] @@ -593,50 +505,39 @@ def gen_gemm_sm90_module() -> JitSpec: ) source_paths.append(dest_path) source = kernel_inst_templ.render( - dtype_in=dtype_cutlass_map[dtype_in], - dtype_out=dtype_cutlass_map[dtype_out], + dtype_in=dtype_cutlass_map[dtype_in], dtype_out=dtype_cutlass_map[dtype_out] ) write_if_different(dest_path, source) - for filename in [ - "group_gemm_sm90.cu", - "flashinfer_gemm_sm90_ops.cu", - ]: + for filename in ["group_gemm_sm90.cu", "flashinfer_gemm_sm90_ops.cu"]: src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - return gen_jit_spec( - "gemm_sm90", - source_paths, - extra_cuda_cflags=sm90a_nvcc_flags, - ) + return gen_jit_spec("gemm_sm90", source_paths, extra_cuda_cflags=sm90a_nvcc_flags) @functools.cache def get_gemm_sm90_module(): module = gen_gemm_sm90_module().build_and_load() - # torch library for cutlass_segment_gemm_sm90 - @register_custom_op( - "flashinfer::cutlass_segment_gemm_sm90", - mutates_args=("workspace_buffer", "y"), + "flashinfer::cutlass_segment_gemm_sm90", 
mutates_args=("workspace_buffer", "y") ) def cutlass_segment_gemm_sm90( - workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, - all_problems: torch.Tensor, - x_data: torch.Tensor, - w_data: torch.Tensor, - y_data: torch.Tensor, - x_stride: torch.Tensor, - w_stride: torch.Tensor, - y_stride: torch.Tensor, - y: torch.Tensor, - empty_x_data: torch.Tensor, - empty_y_data: torch.Tensor, + workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, + all_problems: paddle.Tensor, + x_data: paddle.Tensor, + w_data: paddle.Tensor, + y_data: paddle.Tensor, + x_stride: paddle.Tensor, + w_stride: paddle.Tensor, + y_stride: paddle.Tensor, + y: paddle.Tensor, + empty_x_data: paddle.Tensor, + empty_y_data: paddle.Tensor, weight_column_major: bool, ) -> None: module.cutlass_segment_gemm_sm90.default( @@ -656,62 +557,53 @@ def cutlass_segment_gemm_sm90( @register_fake_op("flashinfer::cutlass_segment_gemm_sm90") def _fake_cutlass_segment_gemm_sm90( - workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, - all_problems: torch.Tensor, - x_data: torch.Tensor, - w_data: torch.Tensor, - y_data: torch.Tensor, - x_stride: torch.Tensor, - w_stride: torch.Tensor, - y_stride: torch.Tensor, - y: torch.Tensor, - empty_x_data: torch.Tensor, - empty_y_data: torch.Tensor, + workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, + all_problems: paddle.Tensor, + x_data: paddle.Tensor, + w_data: paddle.Tensor, + y_data: paddle.Tensor, + x_stride: paddle.Tensor, + w_stride: paddle.Tensor, + y_stride: paddle.Tensor, + y: paddle.Tensor, + empty_x_data: paddle.Tensor, + empty_y_data: paddle.Tensor, weight_column_major: bool, ) -> None: pass - # Register the module - return SimpleNamespace( - cutlass_segment_gemm_sm90=cutlass_segment_gemm_sm90, - ) + return SimpleNamespace(cutlass_segment_gemm_sm90=cutlass_segment_gemm_sm90) def launch_compute_sm80_group_gemm_args( - x: torch.Tensor, - weights: torch.Tensor, - y: torch.Tensor, + x: paddle.Tensor, + weights: paddle.Tensor, + y: paddle.Tensor, w_column_major: bool, batch_size: int, - seg_indptr: torch.Tensor, - weight_indices: Optional[torch.Tensor] = None, + seg_indptr: paddle.Tensor, + weight_indices: Optional[paddle.Tensor] = None, ): - device = x.device - prob_type = torch.int32 # problem sizes -> int - ptr_type = torch.int64 # pointers -> int64_t - ld_type = torch.int64 # strides -> int64_t - + device = x.place + prob_type = "int32" + ptr_type = "int64" + ld_type = "int64" seg_indptr = seg_indptr.to(ptr_type) if weight_indices is not None: weight_indices = weight_indices.to(ptr_type) - - d_out = weights.size(1) if w_column_major else weights.size(2) - d_in = weights.size(2) if w_column_major else weights.size(1) - - all_problems = torch.empty((batch_size, 3), dtype=prob_type, device=device) - - x_data = torch.empty(batch_size, dtype=ptr_type, device=device) - w_data = torch.empty(batch_size, dtype=ptr_type, device=device) - y_data = torch.empty(batch_size, dtype=ptr_type, device=device) - - x_stride_data = torch.empty(batch_size, dtype=ld_type, device=device) - w_stride_data = torch.empty(batch_size, dtype=ld_type, device=device) - y_stride_data = torch.empty(batch_size, dtype=ld_type, device=device) - + d_out = weights.shape[1] if w_column_major else weights.shape[2] + d_in = weights.shape[2] if w_column_major else weights.shape[1] + all_problems = paddle.empty(shape=(batch_size, 3), dtype=prob_type) + x_data = paddle.empty(shape=batch_size, dtype=ptr_type) + w_data = paddle.empty(shape=batch_size, dtype=ptr_type) 
+ y_data = paddle.empty(shape=batch_size, dtype=ptr_type) + x_stride_data = paddle.empty(shape=batch_size, dtype=ld_type) + w_stride_data = paddle.empty(shape=batch_size, dtype=ld_type) + y_stride_data = paddle.empty(shape=batch_size, dtype=ld_type) from .triton.gemm import compute_sm80_group_gemm_args - compute_sm80_group_gemm_args[(batch_size,)]( + compute_sm80_group_gemm_args[batch_size,]( all_problems, x_data, w_data, @@ -728,7 +620,6 @@ def launch_compute_sm80_group_gemm_args( d_out, w_column_major, ) - return ( all_problems, x_data, @@ -741,39 +632,33 @@ def launch_compute_sm80_group_gemm_args( def launch_compute_sm90_group_gemm_args( - x: torch.Tensor, - weights: torch.Tensor, - y: torch.Tensor, + x: paddle.Tensor, + weights: paddle.Tensor, + y: paddle.Tensor, w_column_major: bool, batch_size: int, - seg_indptr: torch.Tensor, - weight_indices: Optional[torch.Tensor] = None, + seg_indptr: paddle.Tensor, + weight_indices: Optional[paddle.Tensor] = None, ): - device = x.device - prob_type = torch.int32 # problem sizes -> int - ptr_type = torch.int64 # pointers -> int64_t - stride_type = torch.int64 # strides -> int64_t - + device = x.place + prob_type = "int32" + ptr_type = "int64" + stride_type = "int64" seg_indptr = seg_indptr.to(ptr_type) if weight_indices is not None: weight_indices = weight_indices.to(ptr_type) - - d_out = weights.size(1) if w_column_major else weights.size(2) - d_in = weights.size(2) if w_column_major else weights.size(1) - - all_problems = torch.empty((batch_size, 3), dtype=prob_type, device=device) - - x_data = torch.empty(batch_size, dtype=ptr_type, device=device) - w_data = torch.empty(batch_size, dtype=ptr_type, device=device) - y_data = torch.empty(batch_size, dtype=ptr_type, device=device) - - x_stride_data = torch.empty(batch_size, dtype=stride_type, device=device) - w_stride_data = torch.empty(batch_size, dtype=stride_type, device=device) - y_stride_data = torch.empty(batch_size, dtype=stride_type, device=device) - + d_out = weights.shape[1] if w_column_major else weights.shape[2] + d_in = weights.shape[2] if w_column_major else weights.shape[1] + all_problems = paddle.empty(shape=(batch_size, 3), dtype=prob_type) + x_data = paddle.empty(shape=batch_size, dtype=ptr_type) + w_data = paddle.empty(shape=batch_size, dtype=ptr_type) + y_data = paddle.empty(shape=batch_size, dtype=ptr_type) + x_stride_data = paddle.empty(shape=batch_size, dtype=stride_type) + w_stride_data = paddle.empty(shape=batch_size, dtype=stride_type) + y_stride_data = paddle.empty(shape=batch_size, dtype=stride_type) from .triton.gemm import compute_sm90_group_gemm_args - compute_sm90_group_gemm_args[(batch_size,)]( + compute_sm90_group_gemm_args[batch_size,]( all_problems, x_data, w_data, @@ -790,7 +675,6 @@ def launch_compute_sm90_group_gemm_args( d_out, w_column_major, ) - return ( all_problems, x_data, @@ -803,7 +687,7 @@ def launch_compute_sm90_group_gemm_args( class SegmentGEMMWrapper: - r"""Wrapper for segment GEMM kernels. + """Wrapper for segment GEMM kernels. Example ------- @@ -854,9 +738,9 @@ class SegmentGEMMWrapper: """ def __init__( - self, float_workspace_buffer: torch.Tensor, backend: str = "auto" + self, float_workspace_buffer: paddle.Tensor, backend: str = "auto" ) -> None: - r"""Initialize the wrapper. + """Initialize the wrapper. Parameters ---------- @@ -864,16 +748,14 @@ def __init__( The workspace buffer for the kernels, we use it for storing intermediate results in cutlass segment GEMM kernels. Encouraged size is 128MB. 
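        Example
        -------
        A minimal construction sketch (illustrative; assumes a CUDA build of
        Paddle and that ``SegmentGEMMWrapper`` is importable from the top-level
        ``flashinfer`` package):

        >>> import paddle
        >>> import flashinfer
        >>> # ~128 MB byte buffer for CUTLASS intermediate results
        >>> workspace = paddle.empty(shape=[128 * 1024 * 1024], dtype="int8")
        >>> segment_gemm = flashinfer.SegmentGEMMWrapper(workspace, backend="auto")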
""" - self._int_workspace_buffer = torch.empty( - (1024 * 1024,), dtype=torch.int8, device=float_workspace_buffer.device - ) + self._int_workspace_buffer = paddle.empty(shape=(1024 * 1024,), dtype="int8") self._float_workspace_buffer = float_workspace_buffer self.backend = backend def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor + self, float_workspace_buffer: paddle.Tensor, int_workspace_buffer: paddle.Tensor ) -> None: - r"""Reset the workspace buffer. + """Reset the workspace buffer. Parameters ---------- @@ -887,30 +769,30 @@ def reset_workspace_buffer( def run( self, - x: torch.Tensor, - weights: torch.Tensor, + x: paddle.Tensor, + weights: paddle.Tensor, batch_size: int, weight_column_major: bool, - out: Optional[torch.Tensor] = None, - seg_lens: Optional[torch.Tensor] = None, - seg_indptr: Optional[torch.Tensor] = None, - weight_indices: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r"""Run the segment GEMM kernel. + out: Optional[paddle.Tensor] = None, + seg_lens: Optional[paddle.Tensor] = None, + seg_indptr: Optional[paddle.Tensor] = None, + weight_indices: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + """Run the segment GEMM kernel. Compute the matrix multiplication between a batch of input tensor (with variable number of rows, but fixed number of columns) and a batch of weight tensor with fixed number of rows and columns: .. math:: - y[i] = x[i] \times W[i] + y[i] = x[i] \\times W[i] if :attr:`weight_indices` is provided, we will select the weight tensor based on the indices in the :attr:`weight_indices` tensor: .. math:: - y[i] = x[i] \times W[\text{weight_indices}[i]] + y[i] = x[i] \\times W[\\text{weight_indices}[i]] We use Ragged Tensor to represent the input tensor :attr:`x` and the output tensor :attr:`y`, and each x[i] is a segment of the concatenated tensor. Please see :ref:`Ragged Tensor tutorial ` for more details. @@ -919,7 +801,7 @@ def run( .. 
math:: - \text{seg_indptr}[i] = \sum_{j=0}^{i-1} \text{seg_lens}[j], \quad \text{seg_indptr}[0] = 0 + \\text{seg_indptr}[i] = \\sum_{j=0}^{i-1} \\text{seg_lens}[j], \\quad \\text{seg_indptr}[0] = 0 - If ``seg_lens`` is provided, then :attr:`x` has shape ``(sum(seg_lens), d_in)`` and :attr:`y` has shape ``(sum(seg_lens), d_out)``, where ``d_in`` is the number of columns of the input tensor and ``d_out`` is the @@ -962,31 +844,25 @@ def run( if seg_indptr is None: seg_indptr = get_indptr(seg_lens.to(x)) if weight_indices is None: - # create an empty CPU tensor as placeholder - weight_indices = torch.empty(0, dtype=torch.int64) - cumulative_batch_size = x.size(0) - d_out = weights.size(1) if weight_column_major else weights.size(2) + weight_indices = paddle.empty(shape=[0], dtype="int64") + cumulative_batch_size = x.shape[0] + d_out = weights.shape[1] if weight_column_major else weights.shape[2] if out is None: if is_float8(x): - out_dtype = torch.bfloat16 + out_dtype = "bfloat16" else: out_dtype = x.dtype - out = torch.zeros( - (cumulative_batch_size, d_out), dtype=out_dtype, device=x.device + out = paddle.zeros(shape=(cumulative_batch_size, d_out), dtype=out_dtype) + elif tuple(out.shape) != (cumulative_batch_size, d_out): + raise ValueError( + f"Output tensor shape mismatch, expected {cumulative_batch_size, d_out}, got {tuple(out.shape)}" ) - else: - if out.shape != (cumulative_batch_size, d_out): - raise ValueError( - f"Output tensor shape mismatch, expected {cumulative_batch_size, d_out}, got {out.shape}" - ) - empty_x_data = torch.empty(0, dtype=x.dtype, device=x.device) - empty_y_data = torch.empty(0, dtype=out.dtype, device=out.device) - + empty_x_data = paddle.empty(shape=[0], dtype=x.dtype) + empty_y_data = paddle.empty(shape=[0], dtype=out.dtype) if self.backend == "auto": - backend = determine_gemm_backend(x.device) + backend = determine_gemm_backend(x.place) else: backend = self.backend - if backend == "sm90": ( all_problems, @@ -1015,8 +891,8 @@ def run( x_stride_data, w_stride_data, y_stride_data, - out, # for torch compile mutates_args - empty_x_data, # for kernel type dispatch + out, + empty_x_data, empty_y_data, weight_column_major, ) @@ -1075,37 +951,29 @@ def _check_cudnn_availability(): """Check if cuDNN is available and raise exception if not.""" if not CUDNN_AVAILABLE: raise RuntimeError( - "cuDNN is not available. Please install cuDNN to use FP8 GEMM functions. " - "You can install it with: pip install nvidia-cudnn-cu12 nvidia-cudnn-frontend" + "cuDNN is not available. Please install cuDNN to use FP8 GEMM functions. You can install it with: pip install nvidia-cudnn-cu12 nvidia-cudnn-frontend" ) def _check_cudnn_fp4_availability(): """Check if cuDNN FP4 support is available and raise exception if not.""" _check_cudnn_availability() - - # Check cuDNN version for FP4 support (requires 1.13.* or later) try: version_str = cudnn.__version__ major, minor = map(int, version_str.split(".")[:2]) - if (major, minor) < (1, 13): raise RuntimeError( - f"cuDNN FP4 requires version 1.13+, found {version_str}. " - f"Upgrade: pip install --upgrade nvidia-cudnn-cu12 nvidia-cudnn-frontend" + f"cuDNN FP4 requires version 1.13+, found {version_str}. Upgrade: pip install --upgrade nvidia-cudnn-cu12 nvidia-cudnn-frontend" ) except (ImportError, AttributeError, ValueError, IndexError) as e: raise RuntimeError( "Unable to determine cuDNN version. FP4 requires cuDNN 1.13+." 
) from e - - # Check cuDNN backend version for FP4 support (requires >= 91002) try: backend_version = cudnn.backend_version() if backend_version < 91002: raise RuntimeError( - f"cuDNN FP4 requires backend version >= 91002, found {backend_version}. " - f"Please upgrade cuDNN backend." + f"cuDNN FP4 requires backend version >= 91002, found {backend_version}. Please upgrade cuDNN backend." ) except (AttributeError, TypeError) as e: raise RuntimeError( @@ -1116,8 +984,6 @@ def _check_cudnn_fp4_availability(): def _is_cublas_fp4_available_in_cudnn(): """Check if cuBLAS backend for FP4 GEMM is available in cuDNN.""" _check_cudnn_availability() - - # Check cuDNN backend version for FP4 support (requires cudnn_version == 9.11.1 or cudnn_version >= 9.13) backend_version = cudnn.backend_version() CUDNN_VERSION_9_11_1 = 91101 CUDNN_VERSION_9_13_0 = 91300 @@ -1129,17 +995,16 @@ def _is_cublas_fp4_available_in_cudnn(): def _get_native_fp4_dtype(): """get native fp4 datatype if supported in the torch, otherwise return uint8.""" - if hasattr(torch, "float4_e2m1fn_x2"): - return torch.float4_e2m1fn_x2 + if hasattr(paddle, "float4_e2m1fn_x2"): + return paddle.float4_e2m1fn_x2 else: - return torch.uint8 + return "uint8" -# Global cudnn handle. need to make it per device in future _cudnn_handle = None -def _get_cudnn_handle(stream: torch.cuda.Stream): +def _get_cudnn_handle(stream: paddle.device.Stream): """Create and return a cached cuDNN handle.""" global _cudnn_handle if _cudnn_handle is None: @@ -1149,12 +1014,11 @@ def _get_cudnn_handle(stream: torch.cuda.Stream): return _cudnn_handle -def _validate_fp8_output_dtype(dtype: torch.dtype): +def _validate_fp8_output_dtype(dtype: paddle.dtype): """Validate that the output dtype is either bf16 or fp16.""" - if dtype not in (torch.bfloat16, torch.float16): + if dtype not in ("bfloat16", "float16"): raise ValueError( - f"Unsupported output dtype: {dtype}. " - f"Only torch.bfloat16 and torch.float16 are supported for FP8 GEMM operations." + f"Unsupported output dtype: {dtype}. Only torch.bfloat16 and torch.float16 are supported for FP8 GEMM operations." ) @@ -1175,7 +1039,7 @@ def build_cudnn_gemm_block_scale_dequantize_graph( device, ): _check_cudnn_availability() - stream = torch.cuda.current_stream(device) + stream = paddle.device.current_stream(device=device2str(device)) with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _): a_cudnn_tensor = graph.tensor( name="a", dim=a_shape, stride=a_stride, data_type=ab_type @@ -1224,7 +1088,6 @@ def build_cudnn_gemm_block_scale_dequantize_graph( name="gemm", ) c_tensor.set_data_type(cudnn.data_type.FLOAT) - c_final_cudnn_tensor = graph.mul( name="scale_mul", a=c_tensor, @@ -1232,25 +1095,19 @@ def build_cudnn_gemm_block_scale_dequantize_graph( compute_data_type=cudnn.data_type.FLOAT, ) c_final_cudnn_tensor.set_name("c_final").set_output(True).set_data_type(o_type) - a_cudnn_tensor.set_uid(UIDs.A_UID.value) b_cudnn_tensor.set_uid(UIDs.B_UID.value) block_descale_a_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_A_UID.value) block_descale_b_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_B_UID.value) global_scale_cudnn_tensor.set_uid(UIDs.ALPHA_UID.value) c_final_cudnn_tensor.set_uid(UIDs.O_UID.value) - graph.validate() graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.B]) - - # WAR: The alpha (contains the global scale) is not supported by the cuBLAS backend (eng0) - # in older cuDNN versions, so we deselect it. 
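        # Illustrative restatement of the version gates enforced by the checks above
        # (not called anywhere; kept as documentation only): FP4 graphs need
        # cudnn-frontend >= 1.13 and a backend build >= 91002 (cuDNN 9.10.2), while
        # the cuBLAS "eng0" engine deselected just below is only kept on backend
        # 91101 (9.11.1) or >= 91300 (9.13+).
        def _fp4_version_ok() -> bool:
            major, minor = map(int, cudnn.__version__.split(".")[:2])
            return (major, minor) >= (1, 13) and cudnn.backend_version() >= 91002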
if not _is_cublas_fp4_available_in_cudnn(): graph.deselect_engines(["eng0"]) graph.check_support() graph.build_plans() - return graph @@ -1260,19 +1117,14 @@ def execute_cudnn_gemm_fp4_graph( variant_pack = { UIDs.A_UID.value: a.view(_get_native_fp4_dtype()), UIDs.B_UID.value: b.view(_get_native_fp4_dtype()), - UIDs.BLOCK_DESCALE_A_UID.value: a_descale.view(torch.float8_e4m3fn), - UIDs.BLOCK_DESCALE_B_UID.value: b_descale.view(torch.float8_e4m3fn), - UIDs.ALPHA_UID.value: alpha.view(torch.float), + UIDs.BLOCK_DESCALE_A_UID.value: a_descale.view(paddle.float8_e4m3fn), + UIDs.BLOCK_DESCALE_B_UID.value: b_descale.view(paddle.float8_e4m3fn), + UIDs.ALPHA_UID.value: alpha.view("float32"), UIDs.O_UID.value: c_final, } - - if workspace_buffer.numel() < graph.get_workspace_size(): - workspace_buffer = torch.empty( - graph.get_workspace_size(), device=a.device, dtype=torch.uint8 - ) - - stream = torch.cuda.current_stream(a.device) - + if workspace_buffer.size < graph.get_workspace_size(): + workspace_buffer = paddle.empty(shape=graph.get_workspace_size(), dtype="uint8") + stream = paddle.device.current_stream(device=device2str(a.place)) graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream)) @@ -1297,8 +1149,7 @@ def build_cudnn_gemm_with_per_tensor_q_graph( cuDNN graph object """ _check_cudnn_availability() - - stream = torch.cuda.current_stream(device) + stream = paddle.device.current_stream(device=device2str(device)) with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _): a_cudnn_tensor = graph.tensor( name="a", dim=a_shape, stride=a_stride, data_type=a_type @@ -1337,23 +1188,19 @@ def build_cudnn_gemm_with_per_tensor_q_graph( b=b_scale_cudnn_tensor, compute_data_type=cudnn.data_type.FLOAT, ) - c_after_scale_b_cudnn_tensor.set_name("c_final").set_output(True).set_data_type( o_type ) - a_cudnn_tensor.set_uid(UIDs.A_UID.value) b_cudnn_tensor.set_uid(UIDs.B_UID.value) a_scale_cudnn_tensor.set_uid(UIDs.A_SCALE_UID.value) b_scale_cudnn_tensor.set_uid(UIDs.B_SCALE_UID.value) c_after_scale_b_cudnn_tensor.set_uid(UIDs.O_UID.value) - graph.validate() graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) graph.check_support() graph.build_plans() - return graph @@ -1367,53 +1214,46 @@ def execute_cudnn_gemm_with_per_tensor_q_graph( UIDs.B_SCALE_UID.value: b_scale, UIDs.O_UID.value: c_final, } - - stream = torch.cuda.current_stream(a.device) + stream = paddle.device.current_stream(device=device2str(a.place)) cudnn_handle = _get_cudnn_handle(stream) - - if workspace.numel() < graph.get_workspace_size(): - workspace = torch.empty( - graph.get_workspace_size(), device=a.device, dtype=torch.uint8 - ) - + if workspace.size < graph.get_workspace_size(): + workspace = paddle.empty(shape=graph.get_workspace_size(), dtype="uint8") graph.execute(variant_pack, workspace, handle=cudnn_handle) -def _torch_data_type_to_cudnn_data_type(dtype: torch.dtype): - if dtype == torch.bfloat16: +def _torch_data_type_to_cudnn_data_type(dtype: paddle.dtype): + if dtype == "bfloat16": return cudnn.data_type.BFLOAT16 - elif dtype == torch.float16: + elif dtype == "float16": return cudnn.data_type.HALF - elif dtype == torch.float8_e4m3fn: + elif dtype == paddle.float8_e4m3fn: return cudnn.data_type.FP8_E4M3 - elif dtype == torch.float8_e5m2: + elif dtype == paddle.float8_e5m2: return cudnn.data_type.FP8_E5M2 else: raise ValueError(f"Unsupported dtype: {dtype}") def _cudnn_gemm_fp8( - workspace: torch.Tensor, - a: torch.Tensor, - b: torch.Tensor, - a_scale: 
torch.Tensor, - b_scale: torch.Tensor, - out: Optional[torch.Tensor], - torch_out_dtype: torch.dtype, + workspace: paddle.Tensor, + a: paddle.Tensor, + b: paddle.Tensor, + a_scale: paddle.Tensor, + b_scale: paddle.Tensor, + out: Optional[paddle.Tensor], + torch_out_dtype: paddle.dtype, ): _check_cudnn_availability() - graph = build_cudnn_gemm_with_per_tensor_q_graph( - a.shape, - a.stride(), - b.shape, - b.stride(), + tuple(a.shape), + a.get_strides(), + tuple(b.shape), + b.get_strides(), _torch_data_type_to_cudnn_data_type(a.dtype), _torch_data_type_to_cudnn_data_type(b.dtype), _torch_data_type_to_cudnn_data_type(torch_out_dtype), - a.device, + a.place, ) - execute_cudnn_gemm_with_per_tensor_q_graph( graph, a, b, a_scale, b_scale, out, workspace ) @@ -1423,20 +1263,17 @@ def _cudnn_gemm_fp8( def _cudnn_gemm_fp8_runner(): class CudnnFp8GemmRunner(TunableRunner): def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, + self, inputs: List[paddle.Tensor], profile: OptimizationProfile ) -> List[int]: - # cudnn has heuristic for fp8 gemm, so we only need to use the default tactic return [0] def forward( self, - inputs: List[torch.Tensor], + inputs: List[paddle.Tensor], tactic: int = -1, do_preparation: bool = False, **kwargs, - ) -> torch.Tensor: + ) -> paddle.Tensor: a, b, scale_a, scale_b, out, workspace_buffer = inputs _cudnn_gemm_fp8(workspace_buffer, a, b, scale_a, scale_b, out, out.dtype) return out @@ -1445,17 +1282,12 @@ def forward( def _get_real_fp4_shape_from_packed_uint8(packed_fp4_tensor): - # the FP4 data are packed into uint8, we need to expand the shape and stride information to get the real shape and stride to be used in the cuDNN graph. - is_column_major = packed_fp4_tensor.stride(-2) == 1 - real_shape = list(packed_fp4_tensor.shape) - real_stride = list(packed_fp4_tensor.stride()) - - # this function will be used for both mm and bmm, so we need to insert batch dimension if the tensor is 2d + is_column_major = packed_fp4_tensor.get_strides()[-2] == 1 + real_shape = list(tuple(packed_fp4_tensor.shape)) + real_stride = list(packed_fp4_tensor.get_strides()) if len(real_shape) == 2: real_shape.insert(0, 1) - real_stride.insert(0, packed_fp4_tensor.numel()) - - # each packed uint8 contains 2 fp4 elements + real_stride.insert(0, packed_fp4_tensor.size) real_shape[-2 if is_column_major else -1] *= 2 if is_column_major: real_stride[-1] *= 2 @@ -1464,24 +1296,17 @@ def _get_real_fp4_shape_from_packed_uint8(packed_fp4_tensor): else: for i in range(len(real_stride) - 1): real_stride[i] *= 2 - - return (tuple(real_shape), tuple(real_stride)) + return tuple(real_shape), tuple(real_stride) def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size): - # This function will be shared for both mm and bmm, when 2d block scale tensor is provided, we need unfold the batch dimension. the unfoled dim and stride is returned. 
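    # Worked example of the metadata rewrite below (illustrative): a row-major 2-D
    # block-scale tensor of shape (32, 6) that actually stacks batch_size = 4 groups
    # along dim 0 is re-described for the cuDNN graph as a 3-D view of shape
    # (4, 8, 6), with the leading stride chosen so that dim 0 steps over one group's
    # worth of scales -- only shape/stride metadata changes, the underlying buffer is
    # never copied.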
- block_scale_shape = list(block_scale_tensor.shape) - block_scale_stride = list(block_scale_tensor.stride()) - + block_scale_shape = list(tuple(block_scale_tensor.shape)) + block_scale_stride = list(block_scale_tensor.get_strides()) if len(block_scale_shape) == 2: - # expand to 3d block_scale_shape.insert(0, batch_size) block_scale_stride.insert(0, 1) - - # update the stride and shape for the expanded dimension - is_column_major = block_scale_tensor.stride(-2) == 1 + is_column_major = block_scale_tensor.get_strides()[-2] == 1 expand_dim = 2 if is_column_major else 1 - assert block_scale_shape[expand_dim] % batch_size == 0 block_scale_shape[expand_dim] = block_scale_shape[expand_dim] // batch_size block_scale_stride[0] = ( @@ -1493,23 +1318,22 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size): raise ValueError( f"Unsupported block scale tensor shape: {block_scale_shape}, expected 2d or 3d." ) - - return (tuple(block_scale_shape), tuple(block_scale_stride)) + return tuple(block_scale_shape), tuple(block_scale_stride) def mm_fp4( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: torch.Tensor, - out_dtype: torch.dtype, - out: Optional[torch.Tensor] = None, + a: paddle.Tensor, + b: paddle.Tensor, + a_descale: paddle.Tensor, + b_descale: paddle.Tensor, + alpha: paddle.Tensor, + out_dtype: paddle.dtype, + out: Optional[paddle.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", -) -> torch.Tensor: - r"""MM FP4 +) -> paddle.Tensor: + """MM FP4 Parameters ---------- @@ -1567,73 +1391,60 @@ def mm_fp4( >>> out.shape torch.Size([48, 256]) """ - # pre-check the input tensor, block scale tensor and alpha tensor if a.ndim != 2 or b.ndim != 2: - raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}") - if a.shape[1] != b.shape[0]: raise ValueError( - f"K dimension mismatch in mm_fp4. got a.shape[1] = {a.shape[1]}, b.shape[0] = {b.shape[0]}" + f"mm_fp4 accepts 2d tensors, got {tuple(a.shape)} and {tuple(b.shape)}" ) - if a.dtype not in {torch.uint8, _get_native_fp4_dtype()} or b.dtype not in { - torch.uint8, + if tuple(a.shape)[1] != tuple(b.shape)[0]: + raise ValueError( + f"K dimension mismatch in mm_fp4. got a.shape[1] = {tuple(a.shape)[1]}, b.shape[0] = {tuple(b.shape)[0]}" + ) + if a.dtype not in {"uint8", _get_native_fp4_dtype()} or b.dtype not in { + "uint8", _get_native_fp4_dtype(), }: raise ValueError( - f"a and b must have float4_e2m1fn_x2 packed into uint8. " - f"Got {a.dtype} and {b.dtype}." + f"a and b must have float4_e2m1fn_x2 packed into uint8. Got {a.dtype} and {b.dtype}." ) - if a_descale.dtype not in { - torch.float8_e4m3fn, - torch.uint8, - } or b_descale.dtype not in {torch.float8_e4m3fn, torch.uint8}: + if a_descale.dtype not in {paddle.float8_e4m3fn, "uint8"} or b_descale.dtype not in { + paddle.float8_e4m3fn, + "uint8", + }: raise ValueError( - f"a_descale and b_descale must have float8_e4m3fnx2 packed into uint8. " - f"Got {a_descale.dtype} and {b_descale.dtype}." + f"a_descale and b_descale must have float8_e4m3fnx2 packed into uint8. Got {a_descale.dtype} and {b_descale.dtype}." 
) - if alpha.dtype != torch.float: + if alpha.dtype != "float32": raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}") - if alpha.numel() != 1: - raise ValueError(f"alpha must be a scalar, got {alpha.numel()}") - - if out_dtype not in (torch.bfloat16, torch.float16): + if alpha.size != 1: + raise ValueError(f"alpha must be a scalar, got {alpha.size}") + if out_dtype not in ("bfloat16", "float16"): raise ValueError( - f"Unsupported output dtype: {out_dtype}. " - f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations." + f"Unsupported output dtype: {out_dtype}. Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations." ) if block_size != 16: raise ValueError("Only block_size = 16 is supported for FP4 GEMM operations.") if backend != "trtllm" and use_8x4_sf_layout: raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") - - # allocate the output tensor if not provided if out is None: - out = torch.empty( - (a.shape[0], b.shape[1]), - device=a.device, - dtype=out_dtype, + out = paddle.empty( + shape=(tuple(a.shape)[0], tuple(b.shape)[1]), dtype=out_dtype ) - workspace_buffer = _get_cache_buf( - "mm_fp4_workspace", DEFAULT_WORKSPACE_SIZE, a.device + "mm_fp4_workspace", DEFAULT_WORKSPACE_SIZE, a.place ) - if backend == "cudnn": _check_cudnn_fp4_availability() - - # the fp4 cudnn graph will be shared for both mm and bmm, so - # here we need to get the 3d shape and stride including the - # batch dimension for both input and block scale tensors. real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) batch = real_a_shape[0] - expanded_a_descale_shape, expanded_a_descale_stride = ( - _expand_block_scale_tensor_shape(a_descale, batch) - ) - expanded_b_descale_shape, expanded_b_descale_stride = ( - _expand_block_scale_tensor_shape(b_descale, batch) - ) - - # build the fp4 cudnn graph + ( + expanded_a_descale_shape, + expanded_a_descale_stride, + ) = _expand_block_scale_tensor_shape(a_descale, batch) + ( + expanded_b_descale_shape, + expanded_b_descale_stride, + ) = _expand_block_scale_tensor_shape(b_descale, batch) graph = build_cudnn_gemm_block_scale_dequantize_graph( real_a_shape, real_a_stride, @@ -1644,23 +1455,19 @@ def mm_fp4( expanded_b_descale_shape, expanded_b_descale_stride, cudnn.data_type.FP4_E2M1, - torch.float8_e4m3fn, + paddle.float8_e4m3fn, _torch_data_type_to_cudnn_data_type(out_dtype), block_size, - a.device, + a.place, ) - - # execute the fp4 cudnn graph execute_cudnn_gemm_fp4_graph( graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer ) elif backend == "trtllm": - if out_dtype != torch.bfloat16: + if out_dtype != "bfloat16": raise ValueError( - f"Unsupported output dtype: {out_dtype}. " - f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations." + f"Unsupported output dtype: {out_dtype}. Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations." ) - get_trtllm_fp4_gemm_module().trtllm_fp4_gemm( a, b.T, @@ -1672,11 +1479,10 @@ def mm_fp4( workspace_buffer=workspace_buffer, ) elif backend == "cutlass": - # cutlass require uint8 scale when a/b is fp4 packed uint8. 
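        # Backend cheat-sheet (illustrative summary of the constraints enforced
        # earlier in this function): "cudnn" requires cudnn-frontend 1.13+ and
        # backend 9.10.2+, "trtllm" only emits bfloat16 and is the sole backend
        # accepting the 8x4 scale-factor layout, and "cutlass" (this branch) wants
        # the e4m3 block scales reinterpreted as uint8, which the two view() calls
        # below handle. A hypothetical caller-side chooser consistent with those
        # rules:
        #
        #     def pick_mm_fp4_backend(out_dtype, use_8x4_sf_layout, cudnn_ok):
        #         if use_8x4_sf_layout:
        #             return "trtllm"
        #         if not cudnn_ok or out_dtype == "float16":
        #             return "cutlass"
        #         return "cudnn"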
- if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn: - a_descale = a_descale.view(torch.uint8) - if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn: - b_descale = b_descale.view(torch.uint8) + if a.dtype == "uint8" and a_descale.dtype == paddle.float8_e4m3fn: + a_descale = a_descale.view("uint8") + if b.dtype == "uint8" and b_descale.dtype == paddle.float8_e4m3fn: + b_descale = b_descale.view("uint8") get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm( a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer ) @@ -1684,15 +1490,15 @@ def mm_fp4( def bmm_fp8( - A: torch.Tensor, - B: torch.Tensor, - A_scale: torch.Tensor, - B_scale: torch.Tensor, - dtype: torch.dtype, - out: Optional[torch.Tensor] = None, + A: paddle.Tensor, + B: paddle.Tensor, + A_scale: paddle.Tensor, + B_scale: paddle.Tensor, + dtype: paddle.dtype, + out: Optional[paddle.Tensor] = None, backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", -) -> torch.Tensor: - r"""BMM FP8 +) -> paddle.Tensor: + """BMM FP8 Parameters ---------- @@ -1728,7 +1534,7 @@ def bmm_fp8( >>> import torch >>> import torch.nn.functional as F >>> import flashinfer - >>> def to_float8(x, dtype=torch.float8_e4m3fn): + >>> def to_float8(x, dtype=paddle.float8_e4m3fn): ... finfo = torch.finfo(dtype) ... min_val, max_val = x.aminmax() ... amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) @@ -1737,10 +1543,10 @@ def bmm_fp8( ... return x_scl_sat.to(dtype), scale.float().reciprocal() >>> >>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16) - >>> input_fp8, input_inv_s = to_float8(input, dtype=torch.float8_e4m3fn) + >>> input_fp8, input_inv_s = to_float8(input, dtype=paddle.float8_e4m3fn) >>> # column major weight >>> weight = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1) - >>> weight_fp8, weight_inv_s = to_float8(weight, dtype=torch.float8_e4m3fn) + >>> weight_fp8, weight_inv_s = to_float8(weight, dtype=paddle.float8_e4m3fn) >>> out = flashinfer.bmm_fp8(input_fp8, weight_fp8, input_inv_s, weight_inv_s, torch.bfloat16) >>> out.shape torch.Size([16, 48, 80]) @@ -1748,48 +1554,42 @@ def bmm_fp8( torch.bfloat16 """ _validate_fp8_output_dtype(dtype) - if out is None: - out = torch.empty( - (A.shape[0], A.shape[1], B.shape[2]), - device=A.device, - dtype=dtype, + out = paddle.empty( + shape=(tuple(A.shape)[0], tuple(A.shape)[1], tuple(B.shape)[2]), dtype=dtype ) - workspace_buffer = _get_cache_buf( - "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device + "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.place ) - if backend == "cudnn": backends = ["cudnn"] elif backend == "cublas": backends = ["cublas"] elif backend == "cutlass": - if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2: + if A.dtype == paddle.float8_e5m2 or B.dtype == paddle.float8_e5m2: raise ValueError("e5m2 is not supported for cutlass backend") backends = ["cutlass"] elif backend == "auto": backends = ["cutlass", "cublas", "cudnn"] else: raise ValueError(f"Unsupported backend: {backend}") - fp8_gemm_sm100(A, B, A_scale, B_scale, out, workspace_buffer, backends) return out def gemm_fp8_nt_groupwise( - a: torch.Tensor, - b: torch.Tensor, - a_scale: torch.Tensor, - b_scale: torch.Tensor, + a: paddle.Tensor, + b: paddle.Tensor, + a_scale: paddle.Tensor, + b_scale: paddle.Tensor, scale_major_mode: Optional[Literal["MN", "K"]] = None, mma_sm: int = 1, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), - out: Optional[torch.Tensor] = None, - out_dtype: 
Optional[torch.dtype] = None, + out: Optional[paddle.Tensor] = None, + out_dtype: Optional[paddle.dtype] = None, backend: Literal["cutlass", "trtllm"] = "cutlass", -) -> torch.Tensor: - r"""Performs matrix multiplication with FP8 data types using groupwise scaling. +) -> paddle.Tensor: + """Performs matrix multiplication with FP8 data types using groupwise scaling. This function implements a GEMM operation that allows for fine-grained control over scale granularity across different dimensions. Currently only supported on NVIDIA @@ -1851,40 +1651,27 @@ def gemm_fp8_nt_groupwise( The ``m`` should be padded to a multiple of 4 before calling this function, to accommodate the kernel's requirement. """ workspace_buffer = _get_cache_buf( - "gemm_fp8_nt_groupwise_workspace", DEFAULT_WORKSPACE_SIZE, a.device + "gemm_fp8_nt_groupwise_workspace", DEFAULT_WORKSPACE_SIZE, a.place ) if a.ndim != 2 or b.ndim != 2: - raise ValueError(f"Shape mismatch. a.shape = {a.shape}, b.shape = {b.shape}") - - if a.shape[1] != b.shape[1]: raise ValueError( - f"Shape mismatch. a.shape[1] = {a.shape[1]}, b.shape[1] = {b.shape[1]}" + f"Shape mismatch. a.shape = {tuple(a.shape)}, b.shape = {tuple(b.shape)}" + ) + if tuple(a.shape)[1] != tuple(b.shape)[1]: + raise ValueError( + f"Shape mismatch. a.shape[1] = {tuple(a.shape)[1]}, b.shape[1] = {tuple(b.shape)[1]}" ) - if out is None: - out_dtype = out_dtype or torch.bfloat16 + out_dtype = out_dtype or "bfloat16" else: out_dtype = out.dtype - _validate_fp8_output_dtype(out_dtype) - - # NOTE(Zihao): (out_specified, need_padding) - # (False, False) -> create out_padded tensor explicitly - # (False, True) -> create out_padded tensor explicitly - # (True, False) -> use out tensor as out_padded - # (True, True) -> create out_padded tensor explicitly - if out is None: - out = torch.empty( - a.shape[0], - b.shape[0], - device=a.device, - dtype=out_dtype, + out = paddle.empty( + shape=[tuple(a.shape)[0], tuple(b.shape)[0]], dtype=out_dtype ) - - if not _match_sm_version(a.device, "100"): + if not _match_sm_version(a.place, "100"): raise ValueError("gemm_fp8_nt_groupwise is only supported on SM100.") - if backend == "cutlass": assert scale_major_mode is not None get_gemm_sm100_module().gemm_fp8_nt_groupwise.default( @@ -1900,20 +1687,10 @@ def gemm_fp8_nt_groupwise( ) elif backend == "trtllm": assert scale_granularity_mnk == (1, 128, 128) - assert a.shape[1] >= 256 - # mma_sm is ignored + assert tuple(a.shape)[1] >= 256 get_trtllm_gemm_module().trtllm_gemm( - workspace_buffer, - a, - b, - a_scale, - b_scale, - None, - out, - False, - -1, + workspace_buffer, a, b, a_scale, b_scale, None, out, False, -1 ) - return out @@ -1929,27 +1706,16 @@ def __init__(self, use_8x4_sf_layout: bool = True): self._use_8x4_sf_layout = use_8x4_sf_layout def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, + self, inputs: List[paddle.Tensor], profile: OptimizationProfile ) -> List[int]: a_tensor_index = 1 b_tensor_index = 2 - a = profile.get_opt_shapes()[a_tensor_index] b = profile.get_opt_shapes()[b_tensor_index] m = a[0] n = b[0] k = a[1] * 2 - ( - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - ) = inputs + workspace_buffer, a, b, a_descale, b_descale, alpha, out = inputs type_e2m1 = 0 type_bf16 = 2 return list( @@ -1960,20 +1726,12 @@ def get_valid_tactics( def forward( self, - inputs: List[torch.Tensor], + inputs: List[paddle.Tensor], tactic: int = -1, do_preparation: bool = False, **kwargs, ): - ( - workspace_buffer, - a, - b, - 
a_descale, - b_descale, - alpha, - out, - ) = inputs + workspace_buffer, a, b, a_descale, b_descale, alpha, out = inputs op.trtllm_gemm.default( workspace_buffer, a, @@ -1987,28 +1745,24 @@ def forward( ) return out - @register_custom_op( - "flashinfer::trtllm_fp4_gemm", - mutates_args=(""), - ) + @register_custom_op("flashinfer::trtllm_fp4_gemm", mutates_args="") def trtllm_fp4_gemm( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: torch.Tensor, - out: torch.Tensor, + a: paddle.Tensor, + b: paddle.Tensor, + a_descale: paddle.Tensor, + b_descale: paddle.Tensor, + alpha: paddle.Tensor, + out: paddle.Tensor, use_8x4_sf_layout: bool, - workspace_buffer: torch.Tensor, + workspace_buffer: paddle.Tensor, ): tuner = AutoTuner.get() - a_tensor_index = 1 a_scale_tensor_index = 3 out_tensor_index = 6 def pad_up(x, y): - return ((x + y - 1) // y) * y + return (x + y - 1) // y * y tuning_config = TuningConfig( dynamic_tensor_specs=( @@ -2032,44 +1786,30 @@ def pad_up(x, y): ), ), ) - fp4_runner = TrtllmFp4GemmRunner(use_8x4_sf_layout) - - inputs = [ - workspace_buffer, - a, - b, - a_descale, - b_descale, - alpha, - out, - ] + inputs = [workspace_buffer, a, b, a_descale, b_descale, alpha, out] _, tactic = tuner.choose_one( "trtllm_fp4_gemm_8x4" if use_8x4_sf_layout else "trtllm_fp4_gemm_128x4", [fp4_runner], tuning_config, inputs, ) - fp4_runner(inputs=inputs, tactic=tactic) - # Register the module - return SimpleNamespace( - trtllm_fp4_gemm=trtllm_fp4_gemm, - ) + return SimpleNamespace(trtllm_fp4_gemm=trtllm_fp4_gemm) def gemm_fp8_nt_blockscaled( - a: torch.Tensor, - b: torch.Tensor, - a_scale: torch.Tensor, - b_scale: torch.Tensor, + a: paddle.Tensor, + b: paddle.Tensor, + a_scale: paddle.Tensor, + b_scale: paddle.Tensor, scale_major_mode: Optional[Literal["MN", "K"]] = "MN", mma_sm: int = 1, - out: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, -) -> torch.Tensor: - r"""Performs matrix multiplication with FP8 data types using block-scaled scaling. + out: Optional[paddle.Tensor] = None, + out_dtype: Optional[paddle.dtype] = None, +) -> paddle.Tensor: + """Performs matrix multiplication with FP8 data types using block-scaled scaling. Block-scaled scaling is a special case of groupwise scaling where the scale granularity is (128, 128, 128). @@ -2088,28 +1828,28 @@ def gemm_fp8_nt_blockscaled( def group_gemm_fp8_nt_groupwise( - a: torch.Tensor, # (cum_m, k) - b: torch.Tensor, # (batch_size, n, k) - a_scale: torch.Tensor, # (k // block_size, cum_m) - b_scale: torch.Tensor, # (batch_size, k // block_size, n // block_size) - m_indptr: torch.Tensor, # (batch_size + 1, ) + a: paddle.Tensor, + b: paddle.Tensor, + a_scale: paddle.Tensor, + b_scale: paddle.Tensor, + m_indptr: paddle.Tensor, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), scale_major_mode: Literal["MN", "K"] = "MN", mma_sm: int = 1, - out: Optional[torch.Tensor] = None, # (cum_m, n) - out_dtype: Optional[torch.dtype] = None, -) -> torch.Tensor: - r"""Perform group GEMM with FP8 data types using groupwise scaling. Currently only supported on NVIDIA + out: Optional[paddle.Tensor] = None, + out_dtype: Optional[paddle.dtype] = None, +) -> paddle.Tensor: + """Perform group GEMM with FP8 data types using groupwise scaling. Currently only supported on NVIDIA Blackwell architecture. Parameters ---------- a: torch.Tensor - Row-major input tensor shape ``(cum_m, k)``, data type is ``torch.float8_e4m3fn`` or ``torch.float8_e5m2``. 
+ Row-major input tensor shape ``(cum_m, k)``, data type is ``paddle.float8_e4m3fn`` or ``paddle.float8_e5m2``. ``cum_m`` is the cumulative sum of the segment lengths. b: torch.Tensor - Column-major input tensor shape ``(batch_size, n, k)``, data type is ``torch.float8_e4m3fn`` or ``torch.float8_e5m2``. + Column-major input tensor shape ``(batch_size, n, k)``, data type is ``paddle.float8_e4m3fn`` or ``paddle.float8_e5m2``. a_scale: torch.Tensor Column-major scale tensor for a, shape ``(cum_m, k // block_size)`` if scale_major_mode is ``K`` @@ -2151,50 +1891,42 @@ def group_gemm_fp8_nt_groupwise( Each value in ``m_indptr`` should be padded to a multiple of 4 before calling this function, to accommodate the kernel's requirement. """ - if not _match_sm_version(a.device, "100"): + if not _match_sm_version(a.place, "100"): raise ValueError("gemm_fp8_nt_groupwise is only supported on SM100.") - int_workspace_buffer = _get_cache_buf( - "group_gemm_fp8_nt_groupwise_int_workspace", DEFAULT_WORKSPACE_SIZE, a.device + "group_gemm_fp8_nt_groupwise_int_workspace", DEFAULT_WORKSPACE_SIZE, a.place ) float_workspace_buffer = _get_cache_buf( - "group_gemm_fp8_nt_groupwise_float_workspace", DEFAULT_WORKSPACE_SIZE, a.device + "group_gemm_fp8_nt_groupwise_float_workspace", DEFAULT_WORKSPACE_SIZE, a.place ) - - assert a.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - assert b.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - assert a_scale.dtype == torch.float32 - assert b_scale.dtype == torch.float32 - assert m_indptr.dtype == torch.int32 + assert a.dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2] + assert b.dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2] + assert a_scale.dtype == "float32" + assert b_scale.dtype == "float32" + assert m_indptr.dtype == "int32" assert scale_major_mode in ["MN", "K"] assert mma_sm in [1, 2] if out is None: if out_dtype is None: - out_dtype = torch.bfloat16 - else: - if out_dtype is None: - out_dtype = out.dtype + out_dtype = "bfloat16" + elif out_dtype is None: + out_dtype = out.dtype _validate_fp8_output_dtype(out_dtype) - - num_groups = m_indptr.shape[0] - 1 - assert b.shape[0] == num_groups - n = b.shape[1] - k = b.shape[2] - - # assert a.shape[0] == m_indptr[-1].item() # Not enabled in consideration of performance - assert a.shape[1] == k + num_groups = tuple(m_indptr.shape)[0] - 1 + assert tuple(b.shape)[0] == num_groups + n = tuple(b.shape)[1] + k = tuple(b.shape)[2] + assert tuple(a.shape)[1] == k align_n = 8 align_k = 16 assert n % align_n == 0 assert k % align_k == 0 - - out_shape = (a.shape[0], n) + out_shape = tuple(a.shape)[0], n if out is None: - out = torch.empty(out_shape, dtype=out_dtype, device=a.device) + out = paddle.empty(shape=out_shape, dtype=out_dtype) else: - assert out.shape == out_shape + assert tuple(out.shape) == out_shape assert out.dtype == out_dtype - get_gemm_sm100_module().group_gemm_fp8_nt_groupwise.default( int_workspace_buffer, float_workspace_buffer, @@ -2214,26 +1946,26 @@ def group_gemm_fp8_nt_groupwise( def group_gemm_mxfp8_mxfp4_nt_groupwise( - a: torch.Tensor, # (cum_m, k) - b: torch.Tensor, # (batch_size, n, k // 2) - a_scale: torch.Tensor, # (cum_m_padded, k // 32) - b_scale: torch.Tensor, # (batch_size, n_padded, k // 32) - m_indptr: torch.Tensor, # (batch_size + 1, ) + a: paddle.Tensor, + b: paddle.Tensor, + a_scale: paddle.Tensor, + b_scale: paddle.Tensor, + m_indptr: paddle.Tensor, mma_sm: int = 1, tile_m: int = 128, tile_n: int = 128, tile_k: int = 128, swap_ab: bool = True, - out: Optional[torch.Tensor] = None, # 
(cum_m, n) - out_dtype: Optional[torch.dtype] = None, -) -> torch.Tensor: - r"""Perform group GEMM with MXFP4 data types using groupwise scaling. Currently only supported on NVIDIA + out: Optional[paddle.Tensor] = None, + out_dtype: Optional[paddle.dtype] = None, +) -> paddle.Tensor: + """Perform group GEMM with MXFP4 data types using groupwise scaling. Currently only supported on NVIDIA Blackwell architecture. Parameters ---------- a: torch.Tensor - Row-major input tensor, shape ``(cum_m, k)``, data type is ``torch.float8_e4m3fn`` or ``torch.float8_e5m2``. + Row-major input tensor, shape ``(cum_m, k)``, data type is ``paddle.float8_e4m3fn`` or ``paddle.float8_e5m2``. ``cum_m`` is the cumulative sum of the segment lengths. b: torch.Tensor @@ -2282,19 +2014,16 @@ def group_gemm_mxfp8_mxfp4_nt_groupwise( to accommodate the kernel's requirement. """ int_workspace_buffer = _get_cache_buf( - "group_gemm_mxfp4_nt_groupwise_int_workspace", DEFAULT_WORKSPACE_SIZE, a.device + "group_gemm_mxfp4_nt_groupwise_int_workspace", DEFAULT_WORKSPACE_SIZE, a.place ) float_workspace_buffer = _get_cache_buf( - "group_gemm_mxfp4_nt_groupwise_float_workspace", - DEFAULT_WORKSPACE_SIZE, - a.device, + "group_gemm_mxfp4_nt_groupwise_float_workspace", DEFAULT_WORKSPACE_SIZE, a.place ) - - assert a.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - assert b.dtype == torch.uint8 - assert a_scale.dtype == torch.uint8 - assert b_scale.dtype == torch.uint8 - assert m_indptr.dtype == torch.int32 + assert a.dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2] + assert b.dtype == "uint8" + assert a_scale.dtype == "uint8" + assert b_scale.dtype == "uint8" + assert m_indptr.dtype == "int32" assert mma_sm in [1, 2] assert tile_m in [128] assert tile_n in [64, 128, 192, 256] @@ -2302,31 +2031,25 @@ def group_gemm_mxfp8_mxfp4_nt_groupwise( assert swap_ab in [True, False] if out is None: if out_dtype is None: - out_dtype = torch.bfloat16 - else: - if out_dtype is None: - out_dtype = out.dtype - assert out_dtype in [torch.bfloat16, torch.float16] - - num_groups = m_indptr.shape[0] - 1 - assert b.shape[0] == num_groups - n = b.shape[1] - k = b.shape[2] * 2 # Multiply by 2 because b is e2m1 packed as uint8 - - # assert a.shape[0] == m_indptr[-1].item() # Not enabled in consideration of performance - assert a.shape[1] == k + out_dtype = "bfloat16" + elif out_dtype is None: + out_dtype = out.dtype + assert out_dtype in ["bfloat16", "float16"] + num_groups = tuple(m_indptr.shape)[0] - 1 + assert tuple(b.shape)[0] == num_groups + n = tuple(b.shape)[1] + k = tuple(b.shape)[2] * 2 + assert tuple(a.shape)[1] == k align_n = 8 align_k = 128 assert n % align_n == 0 assert k % align_k == 0 - - out_shape = (a.shape[0], n) + out_shape = tuple(a.shape)[0], n if out is None: - out = torch.empty(out_shape, dtype=out_dtype, device=a.device) + out = paddle.empty(shape=out_shape, dtype=out_dtype) else: - assert out.shape == out_shape + assert tuple(out.shape) == out_shape assert out.dtype == out_dtype - get_gemm_sm100_module().group_gemm_mxfp4_nt_groupwise.default( int_workspace_buffer, float_workspace_buffer, @@ -2347,30 +2070,22 @@ def group_gemm_mxfp8_mxfp4_nt_groupwise( return out -# NOTE(Zihao): keep the old name for backward compatibility group_gemm_mxfp4_nt_groupwise = group_gemm_mxfp8_mxfp4_nt_groupwise -def pad_indptr_to_multiple_of_4( - m_indptr: torch.Tensor, -): +def pad_indptr_to_multiple_of_4(m_indptr: paddle.Tensor): from .triton.gemm import compute_padding_mapping - batch_size = m_indptr.shape[0] - 1 + batch_size = 
tuple(m_indptr.shape)[0] - 1 m = m_indptr[1:] - m_indptr[:-1] m = m + 3 - (m + 3) % 4 - padded_m_indptr = torch.cat((torch.zeros((1,), device=m.device, dtype=m.dtype), m)) - padded_m_indptr = padded_m_indptr.cumsum(dim=0, dtype=padded_m_indptr.dtype) - - m_rank = torch.zeros((m_indptr[-1],), dtype=m_indptr.dtype, device=m_indptr.device) - padded_m_rank = torch.zeros( - (m_indptr[-1],), dtype=m_indptr.dtype, device=m_indptr.device - ) - - compute_padding_mapping[(batch_size,)]( - m_indptr, padded_m_indptr, m_rank, padded_m_rank - ) - + padded_m_indptr = paddle.concat(x=(paddle.zeros(shape=(1,), dtype=m.dtype), m)) + padded_m_indptr = padded_m_indptr.cumsum(axis=0, dtype=padded_m_indptr.dtype) + m_rank = paddle.zeros(shape=(m_indptr[-1],), dtype=m_indptr.dtype) + padded_m_rank = paddle.zeros(shape=(m_indptr[-1],), dtype=m_indptr.dtype) + compute_padding_mapping[ + batch_size, + ](m_indptr, padded_m_indptr, m_rank, padded_m_rank) return padded_m_indptr, padded_m_rank @@ -2391,16 +2106,16 @@ def get_deepgemm_sm100_module(): def group_deepgemm_fp8_nt_groupwise( - a: torch.Tensor, # (m, k) - b: torch.Tensor, # (batch_size, n, k) - a_scale: torch.Tensor, # (m, k // block_size) - b_scale: torch.Tensor, # (batch_size, n // block_size, k // block_size) - m_indices: torch.Tensor, # (m, ) + a: paddle.Tensor, + b: paddle.Tensor, + a_scale: paddle.Tensor, + b_scale: paddle.Tensor, + m_indices: paddle.Tensor, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), - out: Optional[torch.Tensor] = None, # (m, n) - out_dtype: Optional[torch.dtype] = None, + out: Optional[paddle.Tensor] = None, + out_dtype: Optional[paddle.dtype] = None, ): - r"""Perform grouped matrix multiplication with FP8 data types using DeepGEMM backend. + """Perform grouped matrix multiplication with FP8 data types using DeepGEMM backend. This function performs a grouped GEMM operation where each group in tensor `b` is multiplied with the corresponding rows in tensor `a`. The grouping is determined by the `m_indices` tensor, @@ -2418,11 +2133,11 @@ def group_deepgemm_fp8_nt_groupwise( Parameters ---------- a : torch.Tensor - Input tensor A of shape ``(m, k)`` with FP8 data type (``torch.float8_e4m3fn``). + Input tensor A of shape ``(m, k)`` with FP8 data type (``paddle.float8_e4m3fn``). This tensor contains all rows that will be multiplied with different groups in `b`. b : torch.Tensor - Input tensor B of shape ``(batch_size, n, k)`` with FP8 data type (``torch.float8_e4m3fn``). + Input tensor B of shape ``(batch_size, n, k)`` with FP8 data type (``paddle.float8_e4m3fn``). Each slice ``b[i]`` represents a different group/expert that will be multiplied with the corresponding rows in `a`. @@ -2476,7 +2191,7 @@ def group_deepgemm_fp8_nt_groupwise( >>> >>> # Quantize to FP8 with appropriate scaling >>> a_fp8, a_scale = per_token_cast_to_fp8(a_f32) - >>> b_fp8 = torch.empty_like(b_f32, dtype=torch.float8_e4m3fn) + >>> b_fp8 = torch.empty_like(b_f32, dtype=paddle.float8_e4m3fn) >>> b_scale = torch.empty((group_size, n // 128, k // 128), device="cuda", dtype=torch.float32) >>> for i in range(group_size): ... 
b_fp8[i], b_scale[i] = per_block_cast_to_fp8(b_f32[i]) @@ -2505,28 +2220,28 @@ def group_deepgemm_fp8_nt_groupwise( from flashinfer.deep_gemm import m_grouped_fp8_gemm_nt_contiguous if out is None: - out_dtype = out_dtype or torch.bfloat16 - out = torch.empty(a.shape[0], b.shape[1], dtype=out_dtype, device=a.device) - + out_dtype = out_dtype or "bfloat16" + out = paddle.empty( + shape=[tuple(a.shape)[0], tuple(b.shape)[1]], dtype=out_dtype + ) m_grouped_fp8_gemm_nt_contiguous( (a, a_scale), (b, b_scale), out, m_indices, scale_granularity_mnk ) - return out def batch_deepgemm_fp8_nt_groupwise( - a: torch.Tensor, # (batch_size, m, k) - b: torch.Tensor, # (batch_size, n, k) - a_scale: torch.Tensor, # (batch_size, m, k // block_size) - b_scale: torch.Tensor, # (batch_size, n // block_size, k // block_size) - masked_m: torch.Tensor, # (batch_size, ) + a: paddle.Tensor, + b: paddle.Tensor, + a_scale: paddle.Tensor, + b_scale: paddle.Tensor, + masked_m: paddle.Tensor, expected_m: int, scale_granularity_mnk: Tuple[int, int, int] = (1, 128, 128), - out: Optional[torch.Tensor] = None, # (batch_size, m, n) - out_dtype: Optional[torch.dtype] = None, + out: Optional[paddle.Tensor] = None, + out_dtype: Optional[paddle.dtype] = None, ): - r"""Perform batch matrix multiplication with FP8 data types using DeepGEMM backend. + """Perform batch matrix multiplication with FP8 data types using DeepGEMM backend. This function performs a batch GEMM operation where each group in tensor `b` is multiplied with the corresponding group of rows in tensor `a`. The results of each group is masked by @@ -2543,12 +2258,12 @@ def batch_deepgemm_fp8_nt_groupwise( Parameters ---------- a : torch.Tensor - Input tensor A of shape ``(batch_size, m, k)`` with FP8 data type (``torch.float8_e4m3fn``). + Input tensor A of shape ``(batch_size, m, k)`` with FP8 data type (``paddle.float8_e4m3fn``). Each slice ``a[i]`` represents a group of rows that will be multiplied with the corresponding group/expert in `b`. b : torch.Tensor - Input tensor B of shape ``(batch_size, n, k)`` with FP8 data type (``torch.float8_e4m3fn``). + Input tensor B of shape ``(batch_size, n, k)`` with FP8 data type (``paddle.float8_e4m3fn``). Each slice ``b[i]`` represents a different group/expert that will be multiplied with the corresponding rows in `a`. @@ -2603,9 +2318,9 @@ def batch_deepgemm_fp8_nt_groupwise( >>> a = torch.rand((group_size, m, k), device="cuda", dtype=torch.float32) >>> b = torch.rand((group_size, n, k), device="cuda", dtype=torch.float32) >>> masked_m = torch.randint(0, m, (group_size,), device="cuda", dtype=torch.int32) - >>> a_fp8 = torch.empty_like(a, device="cuda", dtype=torch.float8_e4m3fn) + >>> a_fp8 = torch.empty_like(a, device="cuda", dtype=paddle.float8_e4m3fn) >>> a_scale = torch.empty((group_size, m, k // 128), device="cuda", dtype=torch.float32) - >>> b_fp8 = torch.empty_like(b, device="cuda", dtype=torch.float8_e4m3fn) + >>> b_fp8 = torch.empty_like(b, device="cuda", dtype=paddle.float8_e4m3fn) >>> b_scale = torch.empty( ... 
(group_size, n // 128, k // 128), device="cuda", dtype=torch.float32 >>> ) @@ -2633,13 +2348,12 @@ def batch_deepgemm_fp8_nt_groupwise( from flashinfer.deep_gemm import m_grouped_fp8_gemm_nt_masked if out is None: - out_dtype = out_dtype or torch.bfloat16 - out = torch.empty( - a.shape[0], a.shape[1], b.shape[1], dtype=out_dtype, device=a.device + out_dtype = out_dtype or "bfloat16" + out = paddle.empty( + shape=[tuple(a.shape)[0], tuple(a.shape)[1], tuple(b.shape)[1]], + dtype=out_dtype, ) - m_grouped_fp8_gemm_nt_masked( (a, a_scale), (b, b_scale), out, masked_m, expected_m, scale_granularity_mnk ) - return out diff --git a/flashinfer/green_ctx.py b/flashinfer/green_ctx.py index 1aa9538bf5..c4f2158e0d 100644 --- a/flashinfer/green_ctx.py +++ b/flashinfer/green_ctx.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,12 +15,10 @@ See the License for the specific language governing permissions and limitations under the License. """ - from typing import List, Tuple import cuda.bindings.driver as driver import cuda.bindings.runtime as runtime -import torch from cuda.bindings.driver import CUdevice, CUdevResource from .cuda_utils import checkCudaErrors @@ -27,18 +27,18 @@ def get_sm_count_constraint(major: int, minor: int) -> Tuple[int, int]: if major == 6: - return (1, 1) + return 1, 1 elif major == 7: - return (2, 2) + return 2, 2 elif major == 8: - return (4, 2) + return 4, 2 elif major >= 9: - return (8, 8) + return 8, 8 else: raise ValueError(f"Unsupported CUDA capability: {major}.{minor}") -def get_cudevice(dev: torch.device) -> CUdevice: +def get_cudevice(dev: str) -> CUdevice: try: cu_dev = checkCudaErrors(driver.cuDeviceGet(dev.index)) except RuntimeError: @@ -56,17 +56,10 @@ def get_device_resource(cu_dev: CUdevice) -> CUdevResource: def split_resource( - resource: CUdevResource, - num_groups: int, - min_count: int, + resource: CUdevResource, num_groups: int, min_count: int ) -> Tuple[CUdevResource, CUdevResource]: results, _, remaining = checkCudaErrors( - driver.cuDevSmResourceSplitByCount( - num_groups, - resource, - 0, # useFlags - min_count, - ) + driver.cuDevSmResourceSplitByCount(num_groups, resource, 0, min_count) ) return results, remaining @@ -78,7 +71,6 @@ def split_resource_by_sm_count( for sm_count in sm_counts: result, remaining = split_resource(resource, 1, sm_count) results.extend(result) - # Refresh the remaining resource for the next iteration desc = checkCudaErrors(driver.cuDevResourceGenerateDesc([remaining], 1)) green_ctx = checkCudaErrors( driver.cuGreenCtxCreate( @@ -90,13 +82,12 @@ def split_resource_by_sm_count( green_ctx, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM ) ) - return results, resource def create_green_ctx_streams( cu_dev: CUdevResource, resources: List[CUdevResource] -) -> List[torch.Stream]: +>>>>>>) -> List[torch.Stream]: streams = [] for split in resources: desc = checkCudaErrors(driver.cuDevResourceGenerateDesc([split], 1)) @@ -107,20 +98,17 @@ def create_green_ctx_streams( ) stream = checkCudaErrors( driver.cuGreenCtxStreamCreate( - green_ctx, - driver.CUstream_flags.CU_STREAM_NON_BLOCKING, - 0, # priority + green_ctx, driver.CUstream_flags.CU_STREAM_NON_BLOCKING, 0 ) ) - streams.append(torch.cuda.get_stream_from_external(stream)) - +>>>>>> streams.append(torch.cuda.get_stream_from_external(stream)) return streams def split_device_green_ctx( - dev: torch.device, num_groups: int, min_count: int -) -> Tuple[List[torch.Stream], List[CUdevResource]]: - r""" + dev: str, num_groups: int, min_count: int +>>>>>>) 
-> Tuple[List[torch.Stream], List[CUdevResource]]: + """ Split the device into multiple `green contexts `_, return the corresponding streams and `CUdevResource` for each group and the remaining SMs. Green contexts allow concurrent execution of multiple kernels on different SM partitions. @@ -173,9 +161,9 @@ def split_device_green_ctx( def split_device_green_ctx_by_sm_count( - dev: torch.device, sm_counts: List[int] -) -> Tuple[List[torch.Stream], List[CUdevResource]]: - r""" + dev: str, sm_counts: List[int] +>>>>>>) -> Tuple[List[torch.Stream], List[CUdevResource]]: + """ Split the device into multiple green contexts, each with a fixed number of SMs, return the corresponding streams and `CUdevResource` for each group and the remaining SMs. Green contexts allow concurrent execution of multiple kernels on different SM partitions. @@ -237,8 +225,6 @@ def split_device_green_ctx_by_sm_count( """ cu_dev = get_cudevice(dev) resource = get_device_resource(cu_dev) - - # Round sm counts to meet the alignment and granularity requirements rounded_sm_counts = [] for sm_count in sm_counts: min_sm_count, sm_alignment = get_sm_count_constraint( @@ -247,8 +233,6 @@ def split_device_green_ctx_by_sm_count( if sm_count <= 0: raise ValueError(f"SM count must be positive, got {sm_count}") rounded_sm_counts.append(round_up(max(sm_count, min_sm_count), sm_alignment)) - - # Split the device into multiple green contexts results, remaining = split_resource_by_sm_count(cu_dev, resource, rounded_sm_counts) resources = results + [remaining] streams = create_green_ctx_streams(cu_dev, resources) diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index faa6660a9f..1367c2db4b 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -1,3 +1,5 @@ +import os + """ Copyright (c) 2024 by FlashInfer team. @@ -13,42 +15,37 @@ See the License for the specific language governing permissions and limitations under the License. """ - import ctypes import functools -import os -# Re-export from . import cubin_loader from . 
import env as env from .activation import gen_act_and_mul_module as gen_act_and_mul_module from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str from .attention import cudnn_fmha_gen_module as cudnn_fmha_gen_module from .attention import gen_batch_attention_module as gen_batch_attention_module -from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module +from .attention import \ + gen_batch_decode_mla_module as gen_batch_decode_mla_module from .attention import gen_batch_decode_module as gen_batch_decode_module from .attention import gen_batch_mla_module as gen_batch_mla_module from .attention import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding from .attention import gen_batch_prefill_module as gen_batch_prefill_module -from .attention import ( - gen_customize_batch_decode_module as gen_customize_batch_decode_module, -) -from .attention import ( - gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding, -) -from .attention import ( - gen_customize_batch_prefill_module as gen_customize_batch_prefill_module, -) -from .attention import ( - gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding, -) -from .attention import ( - gen_customize_single_decode_module as gen_customize_single_decode_module, -) -from .attention import ( - gen_customize_single_prefill_module as gen_customize_single_prefill_module, -) -from .attention import gen_fmha_cutlass_sm100a_module as gen_fmha_cutlass_sm100a_module +from .attention import \ + gen_customize_batch_decode_module as gen_customize_batch_decode_module +from .attention import \ + gen_customize_batch_decode_tvm_binding as \ + gen_customize_batch_decode_tvm_binding +from .attention import \ + gen_customize_batch_prefill_module as gen_customize_batch_prefill_module +from .attention import \ + gen_customize_batch_prefill_tvm_binding as \ + gen_customize_batch_prefill_tvm_binding +from .attention import \ + gen_customize_single_decode_module as gen_customize_single_decode_module +from .attention import \ + gen_customize_single_prefill_module as gen_customize_single_prefill_module +from .attention import \ + gen_fmha_cutlass_sm100a_module as gen_fmha_cutlass_sm100a_module from .attention import gen_pod_module as gen_pod_module from .attention import gen_sampling_tvm_binding as gen_sampling_tvm_binding from .attention import gen_single_decode_module as gen_single_decode_module diff --git a/flashinfer/jit/activation.py b/flashinfer/jit/activation.py index 4d78616e5c..bb33205ceb 100644 --- a/flashinfer/jit/activation.py +++ b/flashinfer/jit/activation.py @@ -1,3 +1,5 @@ +import os + """ Copyright (c) 2024 by FlashInfer team. @@ -13,16 +15,13 @@ See the License for the specific language governing permissions and limitations under the License. """ - -import os - import jinja2 from . 
import env as jit_env from .core import JitSpec, gen_jit_spec from .utils import write_if_different -activation_templ = r""" +activation_templ = """ #include #include "pytorch_extension_utils.h" #include @@ -80,11 +79,5 @@ def gen_act_and_mul_module(act_func_name: str, act_func_def: str) -> JitSpec: gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR os.makedirs(gen_directory, exist_ok=True) sources = [gen_directory / f"{act_func_name}_and_mul.cu"] - write_if_different( - sources[0], - get_act_and_mul_cu_str(act_func_name, act_func_def), - ) - return gen_jit_spec( - f"{act_func_name}_and_mul", - sources, - ) + write_if_different(sources[0], get_act_and_mul_cu_str(act_func_name, act_func_def)) + return gen_jit_spec(f"{act_func_name}_and_mul", sources) diff --git a/flashinfer/jit/attention/__init__.py b/flashinfer/jit/attention/__init__.py index a26b7cee41..511c53bd97 100644 --- a/flashinfer/jit/attention/__init__.py +++ b/flashinfer/jit/attention/__init__.py @@ -13,27 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. """ - from . import pytorch, tvm from .pytorch import cudnn_fmha_gen_module as cudnn_fmha_gen_module from .pytorch import gen_batch_attention_module as gen_batch_attention_module from .pytorch import gen_batch_decode_mla_module as gen_batch_decode_mla_module from .pytorch import gen_batch_decode_module as gen_batch_decode_module from .pytorch import gen_batch_mla_module as gen_batch_mla_module +from .pytorch import \ + gen_batch_prefill_attention_sink_module as \ + gen_batch_prefill_attention_sink_module from .pytorch import gen_batch_prefill_module as gen_batch_prefill_module -from .pytorch import ( - gen_customize_batch_decode_module as gen_customize_batch_decode_module, -) -from .pytorch import ( - gen_customize_batch_prefill_module as gen_customize_batch_prefill_module, -) -from .pytorch import ( - gen_customize_single_decode_module as gen_customize_single_decode_module, -) -from .pytorch import ( - gen_customize_single_prefill_module as gen_customize_single_prefill_module, -) -from .pytorch import gen_fmha_cutlass_sm100a_module as gen_fmha_cutlass_sm100a_module +from .pytorch import \ + gen_customize_batch_decode_module as gen_customize_batch_decode_module +from .pytorch import \ + gen_customize_batch_prefill_module as gen_customize_batch_prefill_module +from .pytorch import \ + gen_customize_single_decode_module as gen_customize_single_decode_module +from .pytorch import \ + gen_customize_single_prefill_module as gen_customize_single_prefill_module +from .pytorch import \ + gen_fmha_cutlass_sm100a_module as gen_fmha_cutlass_sm100a_module from .pytorch import gen_pod_module as gen_pod_module from .pytorch import gen_single_decode_module as gen_single_decode_module from .pytorch import gen_single_prefill_module as gen_single_prefill_module @@ -41,20 +40,19 @@ from .pytorch import get_batch_decode_mla_uri as get_batch_decode_mla_uri from .pytorch import get_batch_decode_uri as get_batch_decode_uri from .pytorch import get_batch_mla_uri as get_batch_mla_uri +from .pytorch import \ + get_batch_prefill_attention_sink_uri as \ + get_batch_prefill_attention_sink_uri from .pytorch import get_batch_prefill_uri as get_batch_prefill_uri from .pytorch import get_pod_uri as get_pod_uri from .pytorch import get_single_decode_uri as get_single_decode_uri from .pytorch import get_single_prefill_uri as get_single_prefill_uri from .pytorch import trtllm_gen_fmha_module as trtllm_gen_fmha_module -from .pytorch import ( - 
gen_batch_prefill_attention_sink_module as gen_batch_prefill_attention_sink_module, - get_batch_prefill_attention_sink_uri as get_batch_prefill_attention_sink_uri, -) from .tvm import gen_batch_mla_tvm_binding as gen_batch_mla_tvm_binding -from .tvm import ( - gen_customize_batch_decode_tvm_binding as gen_customize_batch_decode_tvm_binding, -) -from .tvm import ( - gen_customize_batch_prefill_tvm_binding as gen_customize_batch_prefill_tvm_binding, -) +from .tvm import \ + gen_customize_batch_decode_tvm_binding as \ + gen_customize_batch_decode_tvm_binding +from .tvm import \ + gen_customize_batch_prefill_tvm_binding as \ + gen_customize_batch_prefill_tvm_binding from .tvm import gen_sampling_tvm_binding as gen_sampling_tvm_binding diff --git a/flashinfer/jit/attention/pytorch.py b/flashinfer/jit/attention/pytorch.py index cf2815a704..d880c25ad6 100644 --- a/flashinfer/jit/attention/pytorch.py +++ b/flashinfer/jit/attention/pytorch.py @@ -1,3 +1,7 @@ +import os + +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,100 +17,69 @@ See the License for the specific language governing permissions and limitations under the License. """ - -import os from typing import List import jinja2 -import torch from ...artifacts import ArtifactPath, MetaInfoHash -from .. import env as jit_env -from ..core import JitSpec, gen_jit_spec, logger, sm90a_nvcc_flags, sm100a_nvcc_flags from ...jit.cubin_loader import get_cubin -from ..utils import ( - dtype_map, - filename_safe_dtype_map, - mask_mode_literal, - pos_encoding_mode_literal, - write_if_different, -) +from .. import env as jit_env +from ..core import (JitSpec, gen_jit_spec, logger, sm90a_nvcc_flags, + sm100a_nvcc_flags) +from ..utils import (dtype_map, filename_safe_dtype_map, mask_mode_literal, + pos_encoding_mode_literal, write_if_different) from .utils import generate_additional_params def get_single_decode_uri( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, ) -> str: - return ( - f"single_decode_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" - f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" - f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - f"head_dim_qk_{head_dim_qk}_" - f"head_dim_vo_{head_dim_vo}_" - f"posenc_{pos_encoding_mode}_" - f"use_swa_{use_sliding_window}_" - f"use_logits_cap_{use_logits_soft_cap}" - ) + return f"single_decode_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_dtype_kv_{filename_safe_dtype_map[dtype_kv]}_dtype_o_{filename_safe_dtype_map[dtype_o]}_head_dim_qk_{head_dim_qk}_head_dim_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_use_swa_{use_sliding_window}_use_logits_cap_{use_logits_soft_cap}" def get_batch_decode_uri( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, ) -> str: - return ( - f"batch_decode_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" - f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" - f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" - f"head_dim_qk_{head_dim_qk}_" - f"head_dim_vo_{head_dim_vo}_" - f"posenc_{pos_encoding_mode}_" - 
f"use_swa_{use_sliding_window}_" - f"use_logits_cap_{use_logits_soft_cap}" - ) + return f"batch_decode_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_dtype_kv_{filename_safe_dtype_map[dtype_kv]}_dtype_o_{filename_safe_dtype_map[dtype_o]}_dtype_idx_{filename_safe_dtype_map[dtype_idx]}_head_dim_qk_{head_dim_qk}_head_dim_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_use_swa_{use_sliding_window}_use_logits_cap_{use_logits_soft_cap}" def get_batch_mla_uri( backend: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_ckv: int, head_dim_kpe: int, use_profiler: bool, ) -> str: return ( - f"batch_mla_attention_dtype_q_{filename_safe_dtype_map[dtype_q]}_" - f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" - f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" - f"head_dim_ckv_{head_dim_ckv}_" - f"head_dim_kpe_{head_dim_kpe}_" - f"profiler_{use_profiler}" - ) + ("_sm90" if backend == "fa3" else "") + f"batch_mla_attention_dtype_q_{filename_safe_dtype_map[dtype_q]}_dtype_kv_{filename_safe_dtype_map[dtype_kv]}_dtype_o_{filename_safe_dtype_map[dtype_o]}_dtype_idx_{filename_safe_dtype_map[dtype_idx]}_head_dim_ckv_{head_dim_ckv}_head_dim_kpe_{head_dim_kpe}_profiler_{use_profiler}" + + ("_sm90" if backend == "fa3" else "") + ) def gen_batch_mla_module( backend: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_ckv: int, head_dim_kpe: int, use_profiler: bool, @@ -125,7 +98,6 @@ def gen_batch_mla_module( ) gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri os.makedirs(gen_directory, exist_ok=True) - if backend == "fa2": with open(jit_env.FLASHINFER_CSRC_DIR / "batch_mla_config.jinja") as f: config_templ = jinja2.Template(f.read()) @@ -141,7 +113,6 @@ def gen_batch_mla_module( head_dim_kpe=head_dim_kpe, ), ) - source_paths = [] for filename in [ "batch_mla_plan.cu", @@ -183,74 +154,56 @@ def gen_batch_mla_module( write_if_different(dest_path, source) else: raise ValueError(f"Unsupported backend: {backend}") - extra_cuda_cflags = [] if backend == "fa3": extra_cuda_cflags += sm90a_nvcc_flags if use_profiler: extra_cuda_cflags += ["-DFLASHINFER_ENABLE_PROFILER"] - - return gen_jit_spec( - uri, - source_paths, - extra_cuda_cflags=extra_cuda_cflags, - ) + return gen_jit_spec(uri, source_paths, extra_cuda_cflags=extra_cuda_cflags) def get_batch_decode_mla_uri( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_ckv: int, use_sliding_window: bool, use_logits_soft_cap: bool, arc: str, ) -> str: - return ( - f"batch_decode_mla_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" - f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" - f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" - f"head_dim_ckv{head_dim_ckv}_" - f"use_swa_{use_sliding_window}_" - f"use_logits_cap_{use_logits_soft_cap}_" - f"arc_{arc}" - ) + return 
f"batch_decode_mla_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_dtype_kv_{filename_safe_dtype_map[dtype_kv]}_dtype_o_{filename_safe_dtype_map[dtype_o]}_dtype_idx_{filename_safe_dtype_map[dtype_idx]}_head_dim_ckv{head_dim_ckv}_use_swa_{use_sliding_window}_use_logits_cap_{use_logits_soft_cap}_arc_{arc}" def gen_batch_decode_mla_module( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim: int, num_qo_heads: int, use_sliding_window: bool, use_logits_soft_cap: bool, use_tensor_cores: bool, ) -> JitSpec: - cuda_arch_major = torch.cuda.get_device_properties(0).major - - if cuda_arch_major >= 9: # smem size of SM90 can accommodate all 128 qo-heads data + cuda_arch_major = paddle.device.cuda.get_device_properties(device="gpu:0").major + if cuda_arch_major >= 9: qo_tile_len = 128 else: qo_tile_len = 64 - if ( use_tensor_cores and cuda_arch_major >= 8 and num_qo_heads % qo_tile_len == 0 - and dtype_q == torch.float16 - and dtype_kv == torch.float16 - and dtype_o == torch.float16 + and dtype_q == "float16" + and dtype_kv == "float16" + and dtype_o == "float16" ): logger.info("Use tensor-core SM80 version of MLA decode kernel.") arc = "sm80" else: logger.info("Fall back to cuda-core version of MLA decode kernel.") arc = "cuda_core" - uri = get_batch_decode_mla_uri( dtype_q, dtype_kv, @@ -263,7 +216,6 @@ def gen_batch_decode_mla_module( ) gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri os.makedirs(gen_directory, exist_ok=True) - with open(jit_env.FLASHINFER_CSRC_DIR / "batch_decode_mla_config.jinja") as f: config_templ = jinja2.Template(f.read()) generated_config_path = gen_directory / "mla_config.inc" @@ -281,20 +233,15 @@ def gen_batch_decode_mla_module( use_logits_soft_cap=str(use_logits_soft_cap).lower(), ), ) - filenames = [] if arc == "sm80": - filenames = [ - "batch_decode_mla_cute_sm80.cu", - "batch_decode_mla_pybind.cu", - ] + filenames = ["batch_decode_mla_cute_sm80.cu", "batch_decode_mla_pybind.cu"] else: filenames = [ "batch_decode_mla_plan.cu", "batch_decode_mla_run.cu", "batch_decode_mla_pybind.cu", ] - source_paths = [] for filename in filenames: src_path = jit_env.FLASHINFER_CSRC_DIR / filename @@ -303,15 +250,14 @@ def gen_batch_decode_mla_module( with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - return gen_jit_spec(uri, source_paths) def get_single_prefill_uri( backend: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, @@ -320,54 +266,34 @@ def get_single_prefill_uri( use_fp16_qk_reduction: bool, ) -> str: return ( - f"single_prefill_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" - f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" - f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - f"head_dim_qk_{head_dim_qk}_" - f"head_dim_vo_{head_dim_vo}_" - f"posenc_{pos_encoding_mode}_" - f"use_swa_{use_sliding_window}_" - f"use_logits_cap_{use_logits_soft_cap}_" - f"f16qk_{use_fp16_qk_reduction}" + ("_sm90" if backend == "fa3" else "") + 
f"single_prefill_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_dtype_kv_{filename_safe_dtype_map[dtype_kv]}_dtype_o_{filename_safe_dtype_map[dtype_o]}_head_dim_qk_{head_dim_qk}_head_dim_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_use_swa_{use_sliding_window}_use_logits_cap_{use_logits_soft_cap}_f16qk_{use_fp16_qk_reduction}" + + ("_sm90" if backend == "fa3" else "") ) def get_pod_uri( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, head_dim: int, pos_encoding_mode_p: int, use_sliding_window_p: bool, use_logits_soft_cap_p: bool, use_fp16_qk_reduction: bool, - dtype_idx: torch.dtype, + dtype_idx: paddle.dtype, pos_encoding_mode_d: int, use_sliding_window_d: bool, use_logits_soft_cap_d: bool, ) -> str: - return ( - f"pod_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" - f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" - f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - f"head_dim_{head_dim}_" - f"posenc_p_{pos_encoding_mode_p}_" - f"use_swa_p_{use_sliding_window_p}_" - f"use_logits_cap_p_{use_logits_soft_cap_p}_" - f"posenc_d_{pos_encoding_mode_d}_" - f"use_swa_d_{use_sliding_window_d}_" - f"use_logits_cap_d_{use_logits_soft_cap_d}_" - f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" - f"f16qk_{use_fp16_qk_reduction}" - ) + return f"pod_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_dtype_kv_{filename_safe_dtype_map[dtype_kv]}_dtype_o_{filename_safe_dtype_map[dtype_o]}_head_dim_{head_dim}_posenc_p_{pos_encoding_mode_p}_use_swa_p_{use_sliding_window_p}_use_logits_cap_p_{use_logits_soft_cap_p}_posenc_d_{pos_encoding_mode_d}_use_swa_d_{use_sliding_window_d}_use_logits_cap_d_{use_logits_soft_cap_d}_dtype_idx_{filename_safe_dtype_map[dtype_idx]}_f16qk_{use_fp16_qk_reduction}" def get_batch_prefill_uri( backend: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, @@ -376,69 +302,46 @@ def get_batch_prefill_uri( use_fp16_qk_reduction: bool, ) -> str: return ( - f"batch_prefill_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" - f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" - f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" - f"head_dim_qk_{head_dim_qk}_" - f"head_dim_vo_{head_dim_vo}_" - f"posenc_{pos_encoding_mode}_" - f"use_swa_{use_sliding_window}_" - f"use_logits_cap_{use_logits_soft_cap}_" - f"f16qk_{use_fp16_qk_reduction}" + ("_sm90" if backend == "fa3" else "") + f"batch_prefill_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_dtype_kv_{filename_safe_dtype_map[dtype_kv]}_dtype_o_{filename_safe_dtype_map[dtype_o]}_dtype_idx_{filename_safe_dtype_map[dtype_idx]}_head_dim_qk_{head_dim_qk}_head_dim_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_use_swa_{use_sliding_window}_use_logits_cap_{use_logits_soft_cap}_f16qk_{use_fp16_qk_reduction}" + + ("_sm90" if backend == "fa3" else "") ) def get_batch_prefill_attention_sink_uri( backend: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, use_sliding_window: bool, ) -> str: return ( - 
f"batch_prefill_with_attention_sink_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" - f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" - f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" - f"head_dim_qk_{head_dim_qk}_" - f"head_dim_vo_{head_dim_vo}_" - f"use_swa_{use_sliding_window}_" + ("_sm90" if backend == "fa3" else "") + f"batch_prefill_with_attention_sink_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_dtype_kv_{filename_safe_dtype_map[dtype_kv]}_dtype_o_{filename_safe_dtype_map[dtype_o]}_dtype_idx_{filename_safe_dtype_map[dtype_idx]}_head_dim_qk_{head_dim_qk}_head_dim_vo_{head_dim_vo}_use_swa_{use_sliding_window}_" + + ("_sm90" if backend == "fa3" else "") ) def get_batch_attention_uri( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, use_logits_soft_cap: bool, use_profiler: bool, ) -> str: - return ( - f"batch_attention_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" - f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" - f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" - f"head_dim_qk_{head_dim_qk}_" - f"head_dim_vo_{head_dim_vo}_" - f"posenc_{pos_encoding_mode}_" - f"use_logits_soft_cap_{str(use_logits_soft_cap).lower()}_" - f"use_profiler_{str(use_profiler).lower()}" - ) + return f"batch_attention_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_dtype_kv_{filename_safe_dtype_map[dtype_kv]}_dtype_o_{filename_safe_dtype_map[dtype_o]}_dtype_idx_{filename_safe_dtype_map[dtype_idx]}_head_dim_qk_{head_dim_qk}_head_dim_vo_{head_dim_vo}_posenc_{pos_encoding_mode}_use_logits_soft_cap_{str(use_logits_soft_cap).lower()}_use_profiler_{str(use_profiler).lower()}" def gen_single_decode_module( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, @@ -462,17 +365,12 @@ def gen_single_decode_module( dtype_o, head_dim_qk, head_dim_vo, - ["maybe_alibi_slopes"], # additional_tensor_names - ["float"], # additional_tensor_dtypes - [ - "logits_soft_cap", - "sm_scale", - "rope_rcp_scale", - "rope_rcp_theta", - ], # additional_scalar_names - ["double", "double", "double", "double"], # additional_scalar_dtypes - f"DefaultAttention", # variant_name - "#include", # variant_decl + ["maybe_alibi_slopes"], + ["float"], + ["logits_soft_cap", "sm_scale", "rope_rcp_scale", "rope_rcp_theta"], + ["double", "double", "double", "double"], + f"DefaultAttention", + "#include", pos_encoding_mode=pos_encoding_mode, use_sliding_window=use_sliding_window, use_logits_soft_cap=use_logits_soft_cap, @@ -481,9 +379,9 @@ def gen_single_decode_module( def gen_single_prefill_module( backend: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, @@ -503,12 +401,7 @@ def gen_single_prefill_module( use_logits_soft_cap, use_fp16_qk_reduction, ) - - # use `fp8_enabled` flag to use separate kernel template - # this is used for fp8 tensor core computation - # KV-only quant is not influenced by this flag - fp8_enabled = dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2] - + fp8_enabled = 
dtype_q in [paddle.float8_e4m3fn, paddle.float8_e5m2] if backend == "fa2": assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend" additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes"] @@ -522,22 +415,20 @@ def gen_single_prefill_module( additional_scalar_dtypes = ["double", "double", "double", "double"] variant_name = f"DefaultAttention" variant_decl = "#include" + elif not fp8_enabled: + additional_tensor_names = [] + additional_tensor_dtypes = [] + additional_scalar_names = ["logits_soft_cap", "sm_scale"] + additional_scalar_dtypes = ["double", "double"] + variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>" + variant_decl = "#include" else: - if not fp8_enabled: - additional_tensor_names = [] - additional_tensor_dtypes = [] - additional_scalar_names = ["logits_soft_cap", "sm_scale"] - additional_scalar_dtypes = ["double", "double"] - variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>" - variant_decl = "#include" - else: - additional_tensor_names = ["scale_q", "scale_k", "scale_v"] - additional_tensor_dtypes = ["float", "float", "float"] - additional_scalar_names = ["sm_scale"] - additional_scalar_dtypes = ["double"] - variant_name = "DefaultFP8Attention" - variant_decl = "#include" - + additional_tensor_names = ["scale_q", "scale_k", "scale_v"] + additional_tensor_dtypes = ["float", "float", "float"] + additional_scalar_names = ["sm_scale"] + additional_scalar_dtypes = ["double"] + variant_name = "DefaultFP8Attention" + variant_decl = "#include" return gen_customize_single_prefill_module( backend, uri, @@ -561,15 +452,15 @@ def gen_single_prefill_module( def gen_pod_module( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, head_dim: int, pos_encoding_mode_p: int, use_sliding_window_p: bool, use_logits_soft_cap_p: bool, use_fp16_qk_reduction: bool, - dtype_idx: torch.dtype, + dtype_idx: paddle.dtype, pos_encoding_mode_d: int, use_sliding_window_d: bool, use_logits_soft_cap_d: bool, @@ -600,7 +491,6 @@ def gen_pod_module( variant_name_p = f"DefaultAttention" variant_name_d = f"DefaultAttention" variant_decl = "#include" - return gen_customize_pod_module( uri, dtype_q, @@ -627,10 +517,10 @@ def gen_pod_module( def gen_customize_pod_module( uri: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim: int, additional_tensor_names: List[str], additional_tensor_dtypes: List[str], @@ -648,7 +538,6 @@ def gen_customize_pod_module( use_fp16_qk_reduction: bool = False, ) -> JitSpec: gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri - ( additional_params_decl, additional_func_params, @@ -659,13 +548,10 @@ def gen_customize_pod_module( additional_scalar_names, additional_scalar_dtypes, ) - with open(jit_env.FLASHINFER_CSRC_DIR / "pod_customize_config.jinja") as f: config_templ = jinja2.Template(f.read()) - with open(jit_env.FLASHINFER_CSRC_DIR / "pod_kernel_inst.jinja") as f: kernel_inst_templ = jinja2.Template(f.read()) - kwargs = { "additional_func_params": additional_func_params, "additional_params_decl": additional_params_decl, @@ -687,49 +573,35 @@ def gen_customize_pod_module( "use_logits_soft_cap_d": str(use_logits_soft_cap_d).lower(), "use_fp16_qk_reduction": str(use_fp16_qk_reduction).lower(), } - - generated_inc_str = config_templ.render( - **kwargs, - ) - + 
generated_inc_str = config_templ.render(**kwargs) os.makedirs(gen_directory, exist_ok=True) - source_paths = [] - for mask_mode_p in [0, 1, 2, 3]: for mask_mode_d in [0, 1, 2, 3]: kwargs["mask_mode_p"] = mask_mode_literal[mask_mode_p] kwargs["mask_mode_d"] = mask_mode_literal[mask_mode_d] - filename = f"pod_kernel_mask_{mask_mode_p}p_{mask_mode_d}d.cu" dest_path = gen_directory / filename source_paths.append(dest_path) - source = kernel_inst_templ.render( - **kwargs, - ) + source = kernel_inst_templ.render(**kwargs) write_if_different(dest_path, source) - - for filename in [ - "pod.cu", - "pod_jit_pybind.cu", - ]: + for filename in ["pod.cu", "pod_jit_pybind.cu"]: src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - generated_config_path = gen_directory / "pod_config.inc" write_if_different(generated_config_path, generated_inc_str) return gen_jit_spec(uri, source_paths) def gen_batch_decode_module( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, @@ -755,17 +627,12 @@ def gen_batch_decode_module( dtype_idx, head_dim_qk, head_dim_vo, - ["maybe_alibi_slopes"], # additional_tensor_names - ["float"], # additional_tensor_dtypes - [ - "logits_soft_cap", - "sm_scale", - "rope_rcp_scale", - "rope_rcp_theta", - ], # additional_scalar_names - ["double", "double", "double", "double"], # additional_scalar_dtypes - f"DefaultAttention", # variant_name - "#include", # variant_decl + ["maybe_alibi_slopes"], + ["float"], + ["logits_soft_cap", "sm_scale", "rope_rcp_scale", "rope_rcp_theta"], + ["double", "double", "double", "double"], + f"DefaultAttention", + "#include", pos_encoding_mode=pos_encoding_mode, use_sliding_window=use_sliding_window, use_logits_soft_cap=use_logits_soft_cap, @@ -774,10 +641,10 @@ def gen_batch_decode_module( def gen_batch_prefill_module( backend: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, @@ -798,12 +665,7 @@ def gen_batch_prefill_module( use_logits_soft_cap, use_fp16_qk_reduction, ) - - # use `fp8_enabled` flag to use separate kernel template - # this is used for fp8 tensor core computation - # KV-only quant is not influenced by this flag - fp8_enabled = dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2] - + fp8_enabled = dtype_q in [paddle.float8_e4m3fn, paddle.float8_e5m2] if backend == "fa2": assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend" additional_tensor_names = [ @@ -821,7 +683,7 @@ def gen_batch_prefill_module( "uint32_t", "uint16_t", "uint16_t", - ] # NOTE(Zihao): int32_t should follow dtype_idx + ] additional_scalar_names = [ "logits_soft_cap", "sm_scale", @@ -832,30 +694,28 @@ def gen_batch_prefill_module( additional_scalar_dtypes = ["double", "double", "double", "double", "int64_t"] variant_name = f"DefaultAttention" variant_decl = "#include" + elif not fp8_enabled: + additional_tensor_names = [ + "maybe_prefix_len_ptr", + "maybe_token_pos_in_items_ptr", + "maybe_max_item_len_ptr", + ] + additional_tensor_dtypes = ["uint32_t", "uint16_t", "uint16_t"] + 
additional_scalar_names = [ + "logits_soft_cap", + "sm_scale", + "token_pos_in_items_len", + ] + additional_scalar_dtypes = ["double", "double", "int64_t"] + variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>" + variant_decl = "#include" else: - if not fp8_enabled: - additional_tensor_names = [ - "maybe_prefix_len_ptr", - "maybe_token_pos_in_items_ptr", - "maybe_max_item_len_ptr", - ] - additional_tensor_dtypes = ["uint32_t", "uint16_t", "uint16_t"] - additional_scalar_names = [ - "logits_soft_cap", - "sm_scale", - "token_pos_in_items_len", - ] - additional_scalar_dtypes = ["double", "double", "int64_t"] - variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>" - variant_decl = "#include" - else: - additional_tensor_names = ["scale_q", "scale_k", "scale_v"] - additional_tensor_dtypes = ["float", "float", "float"] - additional_scalar_names = ["sm_scale"] - additional_scalar_dtypes = ["double"] - variant_name = "DefaultFP8Attention" - variant_decl = "#include" - + additional_tensor_names = ["scale_q", "scale_k", "scale_v"] + additional_tensor_dtypes = ["float", "float", "float"] + additional_scalar_names = ["sm_scale"] + additional_scalar_dtypes = ["double"] + variant_name = "DefaultFP8Attention" + variant_decl = "#include" return gen_customize_batch_prefill_module( backend, uri, @@ -881,10 +741,10 @@ def gen_batch_prefill_module( def gen_batch_prefill_attention_sink_module( backend: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, @@ -903,7 +763,6 @@ def gen_batch_prefill_attention_sink_module( pos_encoding_mode, use_sliding_window, ) - return gen_customize_batch_prefill_module( backend, uri, @@ -928,10 +787,10 @@ def gen_batch_prefill_attention_sink_module( def gen_batch_attention_module( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, @@ -949,14 +808,12 @@ def gen_batch_attention_module( use_logits_soft_cap, use_profiler, ) - additional_tensor_names: List[str] = [] additional_tensor_dtypes: List[str] = [] additional_scalar_names: List[str] = [] additional_scalar_dtypes: List[str] = [] variant_name = f"StandardAttention<{str(use_logits_soft_cap).lower()}>" variant_decl = "#include" - return gen_customize_batch_attention_module( uri, dtype_q, @@ -979,9 +836,9 @@ def gen_batch_attention_module( def gen_customize_single_decode_module( uri: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, head_dim_qk: int, head_dim_vo: int, additional_tensor_names: List[str], @@ -995,7 +852,6 @@ def gen_customize_single_decode_module( use_logits_soft_cap: bool = False, ) -> JitSpec: gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri - ( additional_params_decl, additional_func_params, @@ -1006,15 +862,12 @@ def gen_customize_single_decode_module( additional_scalar_names, additional_scalar_dtypes, ) - with open( jit_env.FLASHINFER_CSRC_DIR / "single_decode_customize_config.jinja" ) as f: config_templ = jinja2.Template(f.read()) - with open(jit_env.FLASHINFER_CSRC_DIR / "single_decode_kernel_inst.jinja") as f: kernel_inst_templ = 
jinja2.Template(f.read()) - kwargs = { "additional_func_params": additional_func_params, "additional_params_decl": additional_params_decl, @@ -1030,45 +883,31 @@ def gen_customize_single_decode_module( "use_sliding_window": str(use_sliding_window).lower(), "use_logits_soft_cap": str(use_logits_soft_cap).lower(), } - - generated_inc_str = config_templ.render( - **kwargs, - ) - + generated_inc_str = config_templ.render(**kwargs) os.makedirs(gen_directory, exist_ok=True) - source_paths = [] - dest_path = gen_directory / "single_decode_kernel.cu" source_paths.append(dest_path) - source = kernel_inst_templ.render( - **kwargs, - ) + source = kernel_inst_templ.render(**kwargs) write_if_different(dest_path, source) - - for filename in [ - "single_decode.cu", - "single_decode_jit_pybind.cu", - ]: + for filename in ["single_decode.cu", "single_decode_jit_pybind.cu"]: src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - generated_config_path = gen_directory / "single_decode_config.inc" write_if_different(generated_config_path, generated_inc_str) - return gen_jit_spec(uri, source_paths) def gen_customize_single_prefill_module( backend: str, uri: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, head_dim_qk: int, head_dim_vo: int, additional_tensor_names: List[str], @@ -1100,75 +939,63 @@ def gen_customize_single_prefill_module( raise ValueError("backend should not be auto when jit_args is provided") elif backend == "fa2": gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri - additional_params_decl, additional_func_params, additional_params_setter = ( - generate_additional_params( - additional_tensor_names, - additional_tensor_dtypes, - additional_scalar_names, - additional_scalar_dtypes, - ) + ( + additional_params_decl, + additional_func_params, + additional_params_setter, + ) = generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, ) - with open( jit_env.FLASHINFER_CSRC_DIR / "single_prefill_customize_config.jinja" ) as f: config_templ = jinja2.Template(f.read()) - with open( jit_env.FLASHINFER_CSRC_DIR / "single_prefill_kernel_inst.jinja" ) as f: kernel_inst_templ = jinja2.Template(f.read()) - kwargs |= { "additional_func_params": additional_func_params, "additional_params_decl": additional_params_decl, "additional_params_setter": additional_params_setter, } - - generated_inc_str = config_templ.render( - **kwargs, - ) + generated_inc_str = config_templ.render(**kwargs) os.makedirs(gen_directory, exist_ok=True) - source_paths = [] for mask_mode in [0, 1, 2, 3]: filename = f"single_prefill_kernel_mask_{mask_mode}.cu" dest_path = gen_directory / filename source_paths.append(dest_path) source = kernel_inst_templ.render( - mask_mode=mask_mode_literal[mask_mode], - **kwargs, + mask_mode=mask_mode_literal[mask_mode], **kwargs ) write_if_different(dest_path, source) - - for filename in [ - "single_prefill.cu", - "single_prefill_jit_pybind.cu", - ]: + for filename in ["single_prefill.cu", "single_prefill_jit_pybind.cu"]: src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - generated_config_path = gen_directory / 
"single_prefill_config.inc" write_if_different(generated_config_path, generated_inc_str) - return gen_jit_spec(uri, source_paths) elif backend == "fa3": gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri - - (additional_params_decl, additional_func_params, additional_params_setter) = ( - generate_additional_params( - additional_tensor_names, - additional_tensor_dtypes, - additional_scalar_names, - additional_scalar_dtypes, - is_sm90_template=True, - ) + ( + additional_params_decl, + additional_func_params, + additional_params_setter, + ) = generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + is_sm90_template=True, ) - _file_config = "single_prefill_sm90_customize_config.jinja" if fp8_enabled: _file_kernel_inst = "single_prefill_fp8_sm90_kernel_inst.jinja" @@ -1176,63 +1003,46 @@ def gen_customize_single_prefill_module( else: _file_kernel_inst = "single_prefill_sm90_kernel_inst.jinja" _file_csrc = "single_prefill_sm90.cu" - with open(jit_env.FLASHINFER_CSRC_DIR / _file_config) as f: config_templ = jinja2.Template(f.read()) - with open(jit_env.FLASHINFER_CSRC_DIR / _file_kernel_inst) as f: kernel_inst_templ = jinja2.Template(f.read()) - kwargs |= { "additional_func_params": additional_func_params, "additional_params_decl": additional_params_decl, "additional_params_setter": additional_params_setter, } - - generated_inc_str = config_templ.render( - **kwargs, - ) + generated_inc_str = config_templ.render(**kwargs) os.makedirs(gen_directory, exist_ok=True) - source_paths = [] for mask_mode in [0, 1, 2, 3]: filename = f"single_prefill_sm90_kernel_mask_{mask_mode}.cu" dest_path = gen_directory / filename source_paths.append(dest_path) source = kernel_inst_templ.render( - mask_mode=mask_mode_literal[mask_mode], - **kwargs, + mask_mode=mask_mode_literal[mask_mode], **kwargs ) write_if_different(dest_path, source) - - for filename in [ - _file_csrc, - "single_prefill_sm90_jit_pybind.cu", - ]: + for filename in [_file_csrc, "single_prefill_sm90_jit_pybind.cu"]: src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - generated_config_path = gen_directory / "single_prefill_sm90_config.inc" write_if_different(generated_config_path, generated_inc_str) - return gen_jit_spec( - uri, - source_paths, - extra_cuda_cflags=sm90a_nvcc_flags, - ) + return gen_jit_spec(uri, source_paths, extra_cuda_cflags=sm90a_nvcc_flags) else: raise ValueError(f"Invalid backend: {backend}") def gen_customize_batch_decode_module( uri: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - idtype: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + idtype: paddle.dtype, head_dim_qk: int, head_dim_vo: int, additional_tensor_names: List[str], @@ -1246,15 +1056,16 @@ def gen_customize_batch_decode_module( use_logits_soft_cap: bool = False, ) -> JitSpec: gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri - (additional_params_decl, additional_func_params, additional_params_setter) = ( - generate_additional_params( - additional_tensor_names, - additional_tensor_dtypes, - additional_scalar_names, - additional_scalar_dtypes, - ) + ( + additional_params_decl, + additional_func_params, + additional_params_setter, + ) = generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + 
additional_scalar_dtypes, ) - kwargs = { "additional_params_decl": additional_params_decl, "additional_func_params": additional_func_params, @@ -1271,37 +1082,23 @@ def gen_customize_batch_decode_module( "use_sliding_window": str(use_sliding_window).lower(), "use_logits_soft_cap": str(use_logits_soft_cap).lower(), } - with open(jit_env.FLASHINFER_CSRC_DIR / "batch_decode_customize_config.jinja") as f: config_templ = jinja2.Template(f.read()) - with open(jit_env.FLASHINFER_CSRC_DIR / "batch_decode_kernel_inst.jinja") as f: kernel_inst_templ = jinja2.Template(f.read()) - - generated_inc_str = config_templ.render( - **kwargs, - ) - + generated_inc_str = config_templ.render(**kwargs) source_paths = [] - dest_path = gen_directory / "batch_decode_kernel.cu" source_paths.append(dest_path) - source = kernel_inst_templ.render( - **kwargs, - ) + source = kernel_inst_templ.render(**kwargs) write_if_different(dest_path, source) - - for filename in [ - "batch_decode.cu", - "batch_decode_jit_pybind.cu", - ]: + for filename in ["batch_decode.cu", "batch_decode_jit_pybind.cu"]: src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - generated_config_path = gen_directory / "batch_decode_config.inc" write_if_different(generated_config_path, generated_inc_str) return gen_jit_spec(uri, source_paths) @@ -1310,10 +1107,10 @@ def gen_customize_batch_decode_module( def gen_customize_batch_prefill_module( backend: str, uri: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - idtype: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + idtype: paddle.dtype, head_dim_qk: int, head_dim_vo: int, additional_tensor_names: List[str], @@ -1346,41 +1143,35 @@ def gen_customize_batch_prefill_module( raise ValueError("backend should not be auto when jit_args is provided") elif backend == "fa2": gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri - (additional_params_decl, additional_func_params, additional_params_setter) = ( - generate_additional_params( - additional_tensor_names, - additional_tensor_dtypes, - additional_scalar_names, - additional_scalar_dtypes, - ) + ( + additional_params_decl, + additional_func_params, + additional_params_setter, + ) = generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, ) - with open( jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_customize_config.jinja" ) as f: config_templ = jinja2.Template(f.read()) - with open( jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_paged_kernel_inst.jinja" ) as f: paged_kernel_inst_templ = jinja2.Template(f.read()) - with open( jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_ragged_kernel_inst.jinja" ) as f: ragged_kernel_inst_templ = jinja2.Template(f.read()) - kwargs |= { "additional_params_decl": additional_params_decl, "additional_func_params": additional_func_params, "additional_params_setter": additional_params_setter, } - - generated_inc_str = config_templ.render( - **kwargs, - ) + generated_inc_str = config_templ.render(**kwargs) os.makedirs(gen_directory, exist_ok=True) - source_paths = [] for mask_mode in [0, 1, 2, 3]: dest_path = ( @@ -1388,47 +1179,40 @@ def gen_customize_batch_prefill_module( ) source_paths.append(dest_path) source = paged_kernel_inst_templ.render( - mask_mode=mask_mode_literal[mask_mode], - **kwargs, + 
mask_mode=mask_mode_literal[mask_mode], **kwargs ) write_if_different(dest_path, source) - dest_path = ( gen_directory / f"batch_prefill_ragged_kernel_mask_{mask_mode}.cu" ) source_paths.append(dest_path) source = ragged_kernel_inst_templ.render( - mask_mode=mask_mode_literal[mask_mode], - **kwargs, + mask_mode=mask_mode_literal[mask_mode], **kwargs ) write_if_different(dest_path, source) - - for filename in [ - "batch_prefill.cu", - "batch_prefill_jit_pybind.cu", - ]: + for filename in ["batch_prefill.cu", "batch_prefill_jit_pybind.cu"]: src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - generated_config_path = gen_directory / "batch_prefill_config.inc" write_if_different(generated_config_path, generated_inc_str) return gen_jit_spec(uri, source_paths) elif backend == "fa3": gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri - (additional_params_decl, additional_func_params, additional_params_setter) = ( - generate_additional_params( - additional_tensor_names, - additional_tensor_dtypes, - additional_scalar_names, - additional_scalar_dtypes, - is_sm90_template=True, - ) + ( + additional_params_decl, + additional_func_params, + additional_params_setter, + ) = generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + is_sm90_template=True, ) - _file_config = "batch_prefill_sm90_customize_config.jinja" if fp8_enabled: _file_paged_kernel_inst = "batch_prefill_fp8_paged_sm90_kernel_inst.jinja" @@ -1438,96 +1222,67 @@ def gen_customize_batch_prefill_module( _file_paged_kernel_inst = "batch_prefill_paged_sm90_kernel_inst.jinja" _file_ragged_kernel_inst = "batch_prefill_ragged_sm90_kernel_inst.jinja" _file_csrc = "batch_prefill_sm90.cu" - with open(jit_env.FLASHINFER_CSRC_DIR / _file_config) as f: config_templ = jinja2.Template(f.read()) - with open(jit_env.FLASHINFER_CSRC_DIR / _file_paged_kernel_inst) as f: paged_kernel_inst_templ = jinja2.Template(f.read()) - with open(jit_env.FLASHINFER_CSRC_DIR / _file_ragged_kernel_inst) as f: ragged_kernel_inst_templ = jinja2.Template(f.read()) - kwargs |= { "additional_params_decl": additional_params_decl, "additional_func_params": additional_func_params, "additional_params_setter": additional_params_setter, } generated_inc_str = config_templ.render(**kwargs) - source_paths = [] for mask_mode in [0, 1, 2, 3]: filename = f"batch_prefill_paged_sm90_kernel_mask_{mask_mode}.cu" dest_path = gen_directory / filename source_paths.append(dest_path) source = paged_kernel_inst_templ.render( - mask_mode=mask_mode_literal[mask_mode], - **kwargs, + mask_mode=mask_mode_literal[mask_mode], **kwargs ) write_if_different(dest_path, source) - filename = f"batch_prefill_ragged_sm90_kernel_mask_{mask_mode}.cu" dest_path = gen_directory / filename source_paths.append(dest_path) source = ragged_kernel_inst_templ.render( - mask_mode=mask_mode_literal[mask_mode], - **kwargs, + mask_mode=mask_mode_literal[mask_mode], **kwargs ) write_if_different(dest_path, source) - - for filename in [ - _file_csrc, - "batch_prefill_sm90_jit_pybind.cu", - ]: + for filename in [_file_csrc, "batch_prefill_sm90_jit_pybind.cu"]: src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - generated_config_path = 
gen_directory / "batch_prefill_sm90_config.inc" write_if_different(generated_config_path, generated_inc_str) - return gen_jit_spec( - uri, - source_paths, - extra_cuda_cflags=sm90a_nvcc_flags, - ) + return gen_jit_spec(uri, source_paths, extra_cuda_cflags=sm90a_nvcc_flags) else: raise ValueError(f"Invalid backend: {backend}") def get_fmha_cutlass_sm100a_uri( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, use_sliding_window: bool, use_logits_soft_cap: bool, ) -> str: - # NOTE(Zihao): use different uri after when support customize attention return "fmha_cutlass_sm100a" - # return ( - # f"fmha_cutlass_sm100a_dtype_q_{filename_safe_dtype_map[dtype_q]}_" - # f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" - # f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - # f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" - # f"head_dim_qk_{head_dim_qk}_" - # f"head_dim_vo_{head_dim_vo}_" - # f"posenc_{pos_encoding_mode}_" - # f"use_swa_{use_sliding_window}_" - # f"use_logits_cap_{use_logits_soft_cap}" - # ) def gen_fmha_cutlass_sm100a_module( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, @@ -1545,38 +1300,25 @@ def gen_fmha_cutlass_sm100a_module( use_sliding_window, use_logits_soft_cap, ) - source_paths = [ jit_env.FLASHINFER_CSRC_DIR / "fmha_cutlass_sm100.cu", jit_env.FLASHINFER_CSRC_DIR / "fmha_cutlass_sm100_pybind.cu", jit_env.FLASHINFER_CSRC_DIR / "blackwell_fmha_plan.cu", ] - return gen_jit_spec( - uri, - source_paths, - extra_cuda_cflags=sm100a_nvcc_flags, - ) + return gen_jit_spec(uri, source_paths, extra_cuda_cflags=sm100a_nvcc_flags) def trtllm_gen_fmha_module(): include_path = f"{ArtifactPath.TRTLLM_GEN_FMHA}/include" header_name = "flashInferMetaInfo" - - # use `get_cubin` to get "flashinferMetaInfo.h" metainfo = get_cubin( f"{include_path}/{header_name}", MetaInfoHash.TRTLLM_GEN_FMHA, ".h" ) - - # make sure "flashinferMetaInfo.h" is downloaded or cached assert metainfo, f"{header_name}.h not found" - return gen_jit_spec( "fmha_gen", - [ - jit_env.FLASHINFER_CSRC_DIR / "trtllm_fmha_kernel_launcher.cu", - ], + [jit_env.FLASHINFER_CSRC_DIR / "trtllm_fmha_kernel_launcher.cu"], extra_ldflags=["-lcuda"], - # link "include" sub-directory in cache extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path], extra_cuda_cflags=[ f'-DTLLM_GEN_FMHA_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_FMHA}\\"', @@ -1587,10 +1329,10 @@ def trtllm_gen_fmha_module(): def gen_customize_batch_attention_module( uri: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - idtype: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + idtype: paddle.dtype, head_dim_qk: int, head_dim_vo: int, additional_tensor_names: List[str], @@ -1616,59 +1358,48 @@ def gen_customize_batch_attention_module( "use_logits_soft_cap": str(use_logits_soft_cap).lower(), } gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri - (additional_params_decl, additional_func_params, additional_params_setter) = ( - generate_additional_params( - additional_tensor_names, - additional_tensor_dtypes, - additional_scalar_names, - additional_scalar_dtypes, - ) + ( + 
additional_params_decl, + additional_func_params, + additional_params_setter, + ) = generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, ) with open( jit_env.FLASHINFER_CSRC_DIR / "batch_attention_customize_config.jinja" ) as f: config_templ = jinja2.Template(f.read()) - with open( jit_env.FLASHINFER_CSRC_DIR / "batch_attention_paged_kernel_inst.jinja" ) as f: paged_kernel_inst_templ = jinja2.Template(f.read()) - kwargs |= { "additional_params_decl": additional_params_decl, "additional_func_params": additional_func_params, "additional_params_setter": additional_params_setter, } - - generated_inc_str = config_templ.render( - **kwargs, - ) + generated_inc_str = config_templ.render(**kwargs) os.makedirs(gen_directory, exist_ok=True) - source_paths = [] for mask_mode in [0, 1, 2, 3]: dest_path = gen_directory / f"batch_attention_paged_kernel_mask_{mask_mode}.cu" source_paths.append(dest_path) source = paged_kernel_inst_templ.render( - mask_mode=mask_mode_literal[mask_mode], - **kwargs, + mask_mode=mask_mode_literal[mask_mode], **kwargs ) write_if_different(dest_path, source) - - for filename in [ - "batch_attention.cu", - "batch_attention_jit_pybind.cu", - ]: + for filename in ["batch_attention.cu", "batch_attention_jit_pybind.cu"]: src_path = jit_env.FLASHINFER_CSRC_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - generated_config_path = gen_directory / "batch_attention_config.inc" write_if_different(generated_config_path, generated_inc_str) - return gen_jit_spec( uri, source_paths, @@ -1681,7 +1412,5 @@ def cudnn_fmha_gen_module(): "fmha_cudnn_gen", [jit_env.FLASHINFER_CSRC_DIR / "cudnn_sdpa_kernel_launcher.cu"], extra_ldflags=["-lcuda"], - extra_cuda_cflags=[ - f'-DCUDNN_SDPA_CUBIN_PATH=\\"{ArtifactPath.CUDNN_SDPA}\\"', - ], + extra_cuda_cflags=[f'-DCUDNN_SDPA_CUBIN_PATH=\\"{ArtifactPath.CUDNN_SDPA}\\"'], ) diff --git a/flashinfer/jit/attention/tvm.py b/flashinfer/jit/attention/tvm.py index b52a15a989..a971de72f5 100644 --- a/flashinfer/jit/attention/tvm.py +++ b/flashinfer/jit/attention/tvm.py @@ -1,3 +1,7 @@ +import os + +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,28 +17,20 @@ See the License for the specific language governing permissions and limitations under the License. """ - import itertools -import os from typing import List import jinja2 -import torch from .. 
import env as jit_env -from ..utils import ( - dtype_map, - mask_mode_literal, - pos_encoding_mode_literal, - write_if_different, -) +from ..utils import (dtype_map, mask_mode_literal, pos_encoding_mode_literal, + write_if_different) from .utils import generate_additional_params def gen_sampling_tvm_binding(uri: str): gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri os.makedirs(gen_directory, exist_ok=True) - source_paths = [] for filename in ["sampling.cu", "sampling_jit_tvm_binding.cu"]: src_path = jit_env.FLASHINFER_TVM_BINDING_DIR / filename @@ -43,17 +39,16 @@ def gen_sampling_tvm_binding(uri: str): with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - return uri, source_paths def gen_customize_batch_prefill_tvm_binding( backend: str, uri: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - idtype: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + idtype: paddle.dtype, head_dim_qk: int, head_dim_vo: int, additional_tensor_names: List[str], @@ -81,46 +76,40 @@ def gen_customize_batch_prefill_tvm_binding( "use_fp16_qk_reduction": str(use_fp16_qk_reduction).lower(), } if backend == "fa3": - # NOTE: fa3 backend is not supported for now, which will be resolved in the near future. raise ValueError("TVM binding does not support fa3 backend for now.") - if backend == "auto": raise ValueError("backend should not be auto when jit_args is provided") elif backend == "fa2": gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri - (additional_params_decl, additional_func_params, additional_params_setter) = ( - generate_additional_params( - additional_tensor_names, - additional_tensor_dtypes, - additional_scalar_names, - additional_scalar_dtypes, - ) + ( + additional_params_decl, + additional_func_params, + additional_params_setter, + ) = generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, ) - with open( jit_env.FLASHINFER_TVM_BINDING_DIR / "batch_prefill_customize_config.jinja" ) as f: config_templ = jinja2.Template(f.read()) - with open( jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_paged_kernel_inst.jinja" ) as f: paged_kernel_inst_templ = jinja2.Template(f.read()) - with open( jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_ragged_kernel_inst.jinja" ) as f: ragged_kernel_inst_templ = jinja2.Template(f.read()) - kwargs |= { "additional_params_decl": additional_params_decl, "additional_func_params": additional_func_params, "additional_params_setter": additional_params_setter, } - generated_inc_str = config_templ.render(**kwargs) os.makedirs(gen_directory, exist_ok=True) - source_paths = [] pos_encoding_modes = [0] if enable_inline_rope: @@ -129,8 +118,8 @@ def gen_customize_batch_prefill_tvm_binding( [0, 1], pos_encoding_modes ): dest_path = ( - gen_directory / f"batch_prefill_paged_kernel_mask_{mask_mode}_" - f"pos_encoding_{pos_encoding_mode}.cu" + gen_directory + / f"batch_prefill_paged_kernel_mask_{mask_mode}_pos_encoding_{pos_encoding_mode}.cu" ) source_paths.append(dest_path) source = paged_kernel_inst_templ.render( @@ -139,10 +128,9 @@ def gen_customize_batch_prefill_tvm_binding( **kwargs, ) write_if_different(dest_path, source) - dest_path = ( - gen_directory / f"batch_prefill_ragged_kernel_mask_{mask_mode}_" - f"pos_encoding_{pos_encoding_mode}.cu" + gen_directory + / f"batch_prefill_ragged_kernel_mask_{mask_mode}_pos_encoding_{pos_encoding_mode}.cu" ) source_paths.append(dest_path) source = 
ragged_kernel_inst_templ.render( @@ -151,62 +139,51 @@ def gen_customize_batch_prefill_tvm_binding( **kwargs, ) write_if_different(dest_path, source) - - for filename in [ - "batch_prefill.cu", - "batch_prefill_jit_tvm_binding.cu", - ]: + for filename in ["batch_prefill.cu", "batch_prefill_jit_tvm_binding.cu"]: src_path = jit_env.FLASHINFER_TVM_BINDING_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - generated_config_path = gen_directory / "batch_prefill_config.inc" write_if_different(generated_config_path, generated_inc_str) return uri, source_paths elif backend == "fa3": gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri - (additional_params_decl, additional_func_params, additional_params_setter) = ( - generate_additional_params( - additional_tensor_names, - additional_tensor_dtypes, - additional_scalar_names, - additional_scalar_dtypes, - is_sm90_template=True, - ) + ( + additional_params_decl, + additional_func_params, + additional_params_setter, + ) = generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + is_sm90_template=True, ) - with open( jit_env.FLASHINFER_TVM_BINDING_DIR / "batch_prefill_sm90_customize_config.jinja" ) as f: config_templ = jinja2.Template(f.read()) - with open( jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_paged_sm90_kernel_inst.jinja" ) as f: paged_kernel_inst_templ = jinja2.Template(f.read()) - with open( jit_env.FLASHINFER_CSRC_DIR / "batch_prefill_ragged_sm90_kernel_inst.jinja" ) as f: ragged_kernel_inst_templ = jinja2.Template(f.read()) - kwargs |= { "additional_params_decl": additional_params_decl, "additional_func_params": additional_func_params, "additional_params_setter": additional_params_setter, } generated_inc_str = config_templ.render(**kwargs) - source_paths = [] for mask_mode, pos_encoding_mode in itertools.product([0, 1], [0, 1]): - filename = ( - f"batch_prefill_paged_sm90_kernel_mask_{mask_mode}_" - f"pos_encoding_{pos_encoding_mode}.cu" - ) + filename = f"batch_prefill_paged_sm90_kernel_mask_{mask_mode}_pos_encoding_{pos_encoding_mode}.cu" dest_path = gen_directory / filename source_paths.append(dest_path) source = paged_kernel_inst_templ.render( @@ -215,11 +192,7 @@ def gen_customize_batch_prefill_tvm_binding( **kwargs, ) write_if_different(dest_path, source) - - filename = ( - f"batch_prefill_ragged_sm90_kernel_mask_{mask_mode}_" - f"pos_encoding_{pos_encoding_mode}.cu" - ) + filename = f"batch_prefill_ragged_sm90_kernel_mask_{mask_mode}_pos_encoding_{pos_encoding_mode}.cu" dest_path = gen_directory / filename source_paths.append(dest_path) source = ragged_kernel_inst_templ.render( @@ -228,7 +201,6 @@ def gen_customize_batch_prefill_tvm_binding( **kwargs, ) write_if_different(dest_path, source) - for filename in [ "batch_prefill_sm90.cu", "batch_prefill_sm90_jit_tvm_binding.cu", @@ -239,7 +211,6 @@ def gen_customize_batch_prefill_tvm_binding( with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - generated_config_path = gen_directory / "batch_prefill_sm90_config.inc" write_if_different(generated_config_path, generated_inc_str) return uri, source_paths @@ -249,10 +220,10 @@ def gen_customize_batch_prefill_tvm_binding( def gen_customize_batch_decode_tvm_binding( uri: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - idtype: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: 
paddle.dtype, + dtype_o: paddle.dtype, + idtype: paddle.dtype, head_dim_qk: int, head_dim_vo: int, additional_tensor_names: List[str], @@ -277,23 +248,22 @@ def gen_customize_batch_decode_tvm_binding( "use_logits_soft_cap": str(use_logits_soft_cap).lower(), } gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri - (additional_params_decl, additional_func_params, additional_params_setter) = ( - generate_additional_params( - additional_tensor_names, - additional_tensor_dtypes, - additional_scalar_names, - additional_scalar_dtypes, - ) + ( + additional_params_decl, + additional_func_params, + additional_params_setter, + ) = generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, ) - with open( jit_env.FLASHINFER_TVM_BINDING_DIR / "batch_decode_customize_config.jinja" ) as f: config_templ = jinja2.Template(f.read()) - with open(jit_env.FLASHINFER_CSRC_DIR / "batch_decode_kernel_inst.jinja") as f: kernel_inst_templ = jinja2.Template(f.read()) - kwargs |= { "additional_params_decl": additional_params_decl, "additional_func_params": additional_func_params, @@ -307,22 +277,16 @@ def gen_customize_batch_decode_tvm_binding( ) source_paths.append(dest_path) source = kernel_inst_templ.render( - pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], - **kwargs, + pos_encoding_mode=pos_encoding_mode_literal[pos_encoding_mode], **kwargs ) write_if_different(dest_path, source) - - for filename in [ - "batch_decode.cu", - "batch_decode_jit_tvm_binding.cu", - ]: + for filename in ["batch_decode.cu", "batch_decode_jit_tvm_binding.cu"]: src_path = jit_env.FLASHINFER_TVM_BINDING_DIR / filename dest_path = gen_directory / filename source_paths.append(dest_path) with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - generated_config_path = gen_directory / "batch_decode_config.inc" write_if_different(generated_config_path, generated_inc_str) return uri, source_paths @@ -330,16 +294,15 @@ def gen_customize_batch_decode_tvm_binding( def gen_batch_mla_tvm_binding( uri: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_ckv: int, head_dim_kpe: int, ): gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri os.makedirs(gen_directory, exist_ok=True) - with open(jit_env.FLASHINFER_TVM_BINDING_DIR / "batch_mla_config.jinja") as f: config_templ = jinja2.Template(f.read()) generated_config_path = gen_directory / "batch_mla_config.inc" @@ -354,7 +317,6 @@ def gen_batch_mla_tvm_binding( head_dim_kpe=head_dim_kpe, ), ) - source_paths = [] for filename in [ "batch_mla_plan.cu", @@ -367,5 +329,4 @@ def gen_batch_mla_tvm_binding( with open(src_path, "r") as f: source = f.read() write_if_different(dest_path, source) - return uri, source_paths diff --git a/flashinfer/jit/attention/utils.py b/flashinfer/jit/attention/utils.py index 86352d82e9..6c6df32e81 100644 --- a/flashinfer/jit/attention/utils.py +++ b/flashinfer/jit/attention/utils.py @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - from typing import List @@ -27,10 +26,7 @@ def generate_additional_params( additional_params_decl = "".join( [ f"{dtype}* {var};\n" - for dtype, var in zip( - additional_tensor_dtypes, - additional_tensor_names, - ) + for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names) ] + [ f"{dtype} {var};\n" diff --git a/flashinfer/jit/attention/variants.py b/flashinfer/jit/attention/variants.py index 16ee1c4f3f..f1e28ebb59 100644 --- a/flashinfer/jit/attention/variants.py +++ b/flashinfer/jit/attention/variants.py @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ - -attention_sink_fa2_decl = r""" +attention_sink_fa2_decl = """ struct AttentionSink : AttentionVariantBase { static constexpr bool use_softmax = true; @@ -51,8 +50,7 @@ }); }; """ - -attention_sink_fa3_decl = r""" +attention_sink_fa3_decl = """ template struct OnlineSoftmaxWithSink { @@ -162,8 +160,4 @@ } }; """ - -attention_sink_decl = { - "fa2": attention_sink_fa2_decl, - "fa3": attention_sink_fa3_decl, -} +attention_sink_decl = {"fa2": attention_sink_fa2_decl, "fa3": attention_sink_fa3_decl} diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index 13b1880afe..7b5b92592b 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -7,8 +7,7 @@ from pathlib import Path from typing import List, Optional, Sequence, Union -import torch -import torch.utils.cpp_extension as torch_cpp_ext +import paddle from filelock import FileLock from . import env as jit_env @@ -27,11 +26,9 @@ def __init__(self, name): self.addHandler(logging.StreamHandler()) log_path = jit_env.FLASHINFER_WORKSPACE_DIR / "flashinfer_jit.log" if not os.path.exists(log_path): - # create an empty file - with open(log_path, "w") as f: # noqa: F841 + with open(log_path, "w") as f: pass self.addHandler(logging.FileHandler(log_path)) - # set the format of the log self.handlers[0].setFormatter( logging.Formatter( "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - flashinfer.jit: %(message)s" @@ -48,11 +45,14 @@ def __init__(self, name): def check_cuda_arch(): - # cuda arch check for fp8 at the moment. 
- for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): - arch = int(re.search(r"compute_(\d+)", cuda_arch_flags).group(1)) - if arch < 75: - raise RuntimeError("FlashInfer requires sm75+") + dev = paddle.device.get_device() + dev_id = int(dev.split(":")[1]) + props = paddle.device.cuda.get_device_properties(dev_id) + arch = props.major * 10 + props.minor + if arch < 75: + raise RuntimeError( + f"FlashInfer requires sm75+, but current GPU is compute_{arch} (sm_{arch})" + ) def clear_cache_dir(): @@ -62,10 +62,7 @@ def clear_cache_dir(): shutil.rmtree(jit_env.FLASHINFER_JIT_DIR) -common_nvcc_flags = [ - "-DFLASHINFER_ENABLE_FP8_E8M0", - "-DFLASHINFER_ENABLE_FP4_E2M1", -] +common_nvcc_flags = ["-DFLASHINFER_ENABLE_FP8_E8M0", "-DFLASHINFER_ENABLE_FP4_E2M1"] sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"] + common_nvcc_flags sm100a_nvcc_flags = ["-gencode=arch=compute_100a,code=sm_100a"] + common_nvcc_flags @@ -129,25 +126,21 @@ def build(self, verbose: bool, need_lock: bool = True) -> None: def load(self, so_path: Path, class_name: str = None): load_class = class_name is not None - loader = torch.classes if load_class else torch.ops + loader = paddle.classes if load_class else paddle.ops loader.load_library(so_path) if load_class: - cls = torch._C._get_custom_class_python_wrapper(self.name, class_name) + cls = paddle.base.core.torch_compat._get_custom_class_python_wrapper(self.name, class_name) return cls return getattr(loader, self.name) def build_and_load(self, class_name: str = None): if self.is_aot: return self.load(self.aot_path, class_name) - - # Guard both build and load with the same lock to avoid race condition - # where another process is building the library and removes the .so file. with FileLock(self.lock_path, thread_local=False): so_path = self.jit_library_path verbose = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") == "1" self.build(verbose, need_lock=False) result = self.load(so_path, class_name) - return result @@ -162,7 +155,6 @@ def gen_jit_spec( ) -> JitSpec: check_cuda_arch() verbose = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") == "1" - cflags = ["-O3", "-std=c++17", "-Wno-switch-bool"] cuda_cflags = [ "-O3", @@ -183,25 +175,20 @@ def gen_jit_spec( "-DCUTLASS_DEBUG_TRACE_LEVEL=2", ] else: - # non debug mode cuda_cflags += ["-DNDEBUG"] - if extra_cflags is not None: cflags += extra_cflags if extra_cuda_cflags is not None: cuda_cflags += extra_cuda_cflags - spec = JitSpec( name=name, sources=[Path(x) for x in sources], extra_cflags=cflags, extra_cuda_cflags=cuda_cflags, extra_ldflags=extra_ldflags, - extra_include_dirs=( - [Path(x) for x in extra_include_paths] - if extra_include_paths is not None - else None - ), + extra_include_dirs=[Path(x) for x in extra_include_paths] + if extra_include_paths is not None + else None, needs_device_linking=needs_device_linking, ) spec.write_ninja() @@ -209,7 +196,6 @@ def gen_jit_spec( def get_tmpdir() -> Path: - # TODO(lequn): Try /dev/shm first. This should help Lock on NFS. 
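# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch above). check_cuda_arch()
# now derives the compute capability from paddle.device.cuda.get_device_properties()
# instead of torch's cached arch flags, and core.py keeps per-arch nvcc flag lists
# (sm90a_nvcc_flags, sm100a_nvcc_flags, common_nvcc_flags). The helper below only
# shows how those two pieces could be combined; the flag-selection logic is an
# assumption for illustration, not the repo's actual dispatch.
import paddle

def sketch_pick_gencode_flags() -> list:
    dev = paddle.device.get_device()  # e.g. "gpu:0", or "cpu" without CUDA
    if not dev.startswith("gpu"):
        raise RuntimeError("CUDA device required")
    props = paddle.device.cuda.get_device_properties(int(dev.split(":")[1]))
    arch = props.major * 10 + props.minor
    if arch < 75:
        raise RuntimeError("FlashInfer requires sm75+")
    common = ["-DFLASHINFER_ENABLE_FP8_E8M0", "-DFLASHINFER_ENABLE_FP4_E2M1"]
    if arch >= 100:
        return [f"-gencode=arch=compute_{arch}a,code=sm_{arch}a"] + common
    if arch >= 90:
        return ["-gencode=arch=compute_90a,code=sm_90a"] + common
    return [f"-gencode=arch=compute_{arch},code=sm_{arch}"] + common
# ---------------------------------------------------------------------------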
tmpdir = jit_env.FLASHINFER_JIT_DIR / "tmp" if not tmpdir.exists(): tmpdir.mkdir(parents=True, exist_ok=True) @@ -217,9 +203,7 @@ def get_tmpdir() -> Path: def build_jit_specs( - specs: List[JitSpec], - verbose: bool = False, - skip_prebuilt: bool = True, + specs: List[JitSpec], verbose: bool = False, skip_prebuilt: bool = True ) -> None: lines: List[str] = [] for spec in specs: @@ -228,9 +212,7 @@ def build_jit_specs( lines.append(f"subninja {spec.ninja_path}") if not lines: return - lines = ["ninja_required_version = 1.3"] + lines + [""] - tmpdir = get_tmpdir() with FileLock(tmpdir / "flashinfer_jit.lock", thread_local=False): ninja_path = tmpdir / "flashinfer_jit.ninja" @@ -246,7 +228,6 @@ def load_cuda_ops( extra_ldflags=None, extra_include_paths=None, ): - # TODO(lequn): Remove this function and use JitSpec directly. warnings.warn( "load_cuda_ops is deprecated. Use JitSpec directly.", DeprecationWarning, diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index 26f7a2a073..360b4befbd 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -1,26 +1,27 @@ -# Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/utils/cpp_extension.py - import functools import os import re import subprocess import sys import sysconfig -from packaging.version import Version from pathlib import Path from typing import List, Optional -import torch -from torch.utils.cpp_extension import ( - _TORCH_PATH, - CUDA_HOME, - _get_cuda_arch_flags, - _get_num_workers, - _get_pybind11_abi_build_flags, -) +import paddle +from packaging.version import Version +from paddle.utils.cpp_extension.cpp_extension import CUDA_HOME from . import env as jit_env +# torch compat polyfills +_TORCH_PATH = paddle.__path__[0] +def _get_num_workers(verbose: bool) -> Optional[int]: + max_jobs = os.environ.get('MAX_JOBS') + if max_jobs is not None and max_jobs.isdigit(): + return int(max_jobs) + return None + + @functools.cache def get_cuda_version() -> Version: @@ -29,7 +30,7 @@ def get_cuda_version() -> Version: else: nvcc = os.path.join(CUDA_HOME, "bin/nvcc") txt = subprocess.check_output([nvcc, "--version"], text=True) - matches = re.findall(r"release (\d+\.\d+),", txt) + matches = re.findall("release (\\d+\\.\\d+),", txt) if not matches: raise RuntimeError( f"Could not parse CUDA version from nvcc --version output: {txt}" @@ -43,7 +44,9 @@ def is_cuda_version_at_least(version_str: str) -> bool: def _get_glibcxx_abi_build_flags() -> List[str]: glibcxx_abi_cflags = [ - "-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI)) + # TODO: Provide a python interface like PyTorch + # "-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI)) + "-D_GLIBCXX_USE_CXX11_ABI=1" ] return glibcxx_abi_cflags @@ -64,34 +67,33 @@ def generate_ninja_build_for_op( system_includes = [ sysconfig.get_path("include"), "$torch_home/include", - "$torch_home/include/torch/csrc/api/include", + "$torch_home/include/paddle/phi/api/include/compat", + "$torch_home/include/paddle/phi/api/include/compat/torch/csrc/api/include", "$cuda_home/include", jit_env.FLASHINFER_INCLUDE_DIR.resolve(), jit_env.FLASHINFER_CSRC_DIR.resolve(), ] system_includes += [p.resolve() for p in jit_env.CUTLASS_INCLUDE_DIRS] system_includes.append(jit_env.SPDLOG_INCLUDE_DIR.resolve()) - common_cflags = [ "-DTORCH_EXTENSION_NAME=$name", "-DTORCH_API_INCLUDE_EXTENSION_H", "-DPy_LIMITED_API=0x03090000", + "-DPADDLE_WITH_CUDA", ] - common_cflags += _get_pybind11_abi_build_flags() + + # TODO: Provide a python interface like 
torch + # common_cflags += _get_pybind11_abi_build_flags() + common_cflags += ['-DPYBIND11_COMPILER_TYPE=\\"_gcc\\"', '-DPYBIND11_STDLIB=\\"_libstdcpp\\"', '-DPYBIND11_BUILD_ABI=\\"_cxxabi1018\\"'] common_cflags += _get_glibcxx_abi_build_flags() if extra_include_dirs is not None: for extra_dir in extra_include_dirs: common_cflags.append(f"-I{extra_dir.resolve()}") for sys_dir in system_includes: common_cflags.append(f"-isystem {sys_dir}") - - cflags = [ - "$common_cflags", - "-fPIC", - ] + cflags = ["$common_cflags", "-fPIC"] if extra_cflags is not None: cflags += extra_cflags - cuda_cflags: List[str] = [] cc_env = os.environ.get("CC") if cc_env is not None: @@ -102,27 +104,26 @@ def generate_ninja_build_for_op( "--expt-relaxed-constexpr", ] cuda_version = get_cuda_version() - # enable -static-global-template-stub when cuda version >= 12.8 if cuda_version >= Version("12.8"): cuda_cflags += [ "-static-global-template-stub=false", ] - cuda_cflags += _get_cuda_arch_flags(extra_cuda_cflags) + # TODO: Provide a python interface, currently the `_get_cuda_arch_flags(extra_cuda_cflags)` returns [] + # cuda_cflags += _get_cuda_arch_flags(extra_cuda_cflags) if extra_cuda_cflags is not None: cuda_cflags += extra_cuda_cflags - ldflags = [ "-shared", - "-L$torch_home/lib", + "-L$torch_home/libs", + "-L$torch_home/base", "-L$cuda_home/lib64", - "-lc10", - "-lc10_cuda", - "-ltorch_cpu", - "-ltorch_cuda", - "-ltorch", + "-lpaddle", + "-lphi", + "-lphi_core", + "-lphi_gpu", + "-lcommon", "-lcudart", ] - env_extra_ldflags = os.environ.get("FLASHINFER_EXTRA_LDFLAGS") if env_extra_ldflags: try: @@ -135,14 +136,11 @@ def generate_ninja_build_for_op( file=sys.stderr, ) ldflags += env_extra_ldflags.split() - if extra_ldflags is not None: ldflags += extra_ldflags - cxx = os.environ.get("CXX", "c++") cuda_home = CUDA_HOME or "/usr/local/cuda" nvcc = os.environ.get("PYTORCH_NVCC", "$cuda_home/bin/nvcc") - lines = [ "ninja_required_version = 1.3", f"name = {name}", @@ -169,25 +167,12 @@ def generate_ninja_build_for_op( " deps = gcc", "", ] - - # Add nvcc linking rule for device code if needs_device_linking: lines.extend( - [ - "rule nvcc_link", - " command = $nvcc -shared $in $ldflags -o $out", - "", - ] + ["rule nvcc_link", " command = $nvcc -shared $in $ldflags -o $out", ""] ) else: - lines.extend( - [ - "rule link", - " command = $cxx $in $ldflags -o $out", - "", - ] - ) - + lines.extend(["rule link", " command = $cxx $in $ldflags -o $out", ""]) objects = [] for source in sources: is_cuda = source.suffix == ".cu" @@ -197,13 +182,11 @@ def generate_ninja_build_for_op( obj = f"$name/{obj_name}" objects.append(obj) lines.append(f"build {obj}: {cmd} {source.resolve()}") - lines.append("") link_rule = "nvcc_link" if needs_device_linking else "link" lines.append(f"build $name/$name.so: {link_rule} " + " ".join(objects)) lines.append("default $name/$name.so") lines.append("") - return "\n".join(lines) @@ -220,7 +203,6 @@ def run_ninja(workdir: Path, ninja_file: Path, verbose: bool) -> None: num_workers = _get_num_workers(verbose) if num_workers is not None: command += ["-j", str(num_workers)] - sys.stdout.flush() sys.stderr.flush() try: diff --git a/flashinfer/jit/cubin_loader.py b/flashinfer/jit/cubin_loader.py index bc2bd84e4e..375e53c0f4 100644 --- a/flashinfer/jit/cubin_loader.py +++ b/flashinfer/jit/cubin_loader.py @@ -1,3 +1,5 @@ +import os + """ Copyright (c) 2025 by FlashInfer team. @@ -13,10 +15,8 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import ctypes import hashlib -import os import shutil import time @@ -25,8 +25,6 @@ from .core import logger from .env import FLASHINFER_CUBIN_DIR -# This is the storage path for the cubins, it can be replaced -# with a local path for testing. FLASHINFER_CUBINS_REPOSITORY = os.environ.get( "FLASHINFER_CUBINS_REPOSITORY", "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/", @@ -48,17 +46,13 @@ def download_file(source, local_path, retries=3, delay=5, timeout=10, lock_timeo Returns: - bool: True if download or copy is successful, False otherwise. """ + import requests - import requests # type: ignore[import-untyped] - - lock_path = f"{local_path}.lock" # Lock file path + lock_path = f"{local_path}.lock" lock = filelock.FileLock(lock_path, timeout=lock_timeout) - try: with lock: logger.info(f"Acquired lock for {local_path}") - - # Handle local file copy if os.path.exists(source): try: shutil.copy(source, local_path) @@ -67,41 +61,32 @@ def download_file(source, local_path, retries=3, delay=5, timeout=10, lock_timeo except Exception as e: logger.error(f"Failed to copy local file: {e}") return False - - # Handle URL downloads for attempt in range(1, retries + 1): try: response = requests.get(source, timeout=timeout) response.raise_for_status() - with open(local_path, "wb") as file: file.write(response.content) - logger.info( f"File downloaded successfully: {source} -> {local_path}" ) return True - except requests.exceptions.RequestException as e: logger.warning( f"Downloading {source}: attempt {attempt} failed: {e}" ) - if attempt < retries: logger.info(f"Retrying in {delay} seconds...") time.sleep(delay) else: logger.error("Max retries reached. Download failed.") return False - except filelock.Timeout: logger.error( f"Failed to acquire lock for {local_path} within {lock_timeout} seconds." ) return False - finally: - # Clean up the lock file if os.path.exists(lock_path): os.remove(lock_path) logger.info(f"Lock file {lock_path} removed.") @@ -148,7 +133,6 @@ def get_cubin(name, sha256, file_extension=".cubin"): cubin = load_cubin(cubin_path, sha256) if cubin: return cubin - # either the file does not exist or it is corrupted, we'll download a new one. 
uri = FLASHINFER_CUBINS_REPOSITORY + "/" + cubin_fname logger.info(f"Fetching cubin {name} from {uri}") download_file(uri, cubin_path) @@ -159,28 +143,21 @@ def convert_to_ctypes_char_p(data: bytes): return ctypes.c_char_p(data) -# Keep a reference to the callback for each loaded library to prevent GC dll_cubin_handlers = {} def setup_cubin_loader(dll_path: str): if dll_path in dll_cubin_handlers: return - _LIB = ctypes.CDLL(dll_path) - - # Define the correct callback type CALLBACK_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p) def get_cubin_callback(name, sha256): - # Both name and sha256 are bytes (c_char_p) cubin = get_cubin(name.decode("utf-8"), sha256.decode("utf-8")) _LIB.FlashInferSetCurrentCubin( convert_to_ctypes_char_p(cubin), ctypes.c_int(len(cubin)) ) - # Create the callback and keep a reference to prevent GC cb = CALLBACK_TYPE(get_cubin_callback) dll_cubin_handlers[dll_path] = cb - _LIB.FlashInferSetCubinCallback(cb) diff --git a/flashinfer/jit/cutlass_gemm/cutlass_library.py b/flashinfer/jit/cutlass_gemm/cutlass_library.py index bd76dd208a..f368e30c3c 100644 --- a/flashinfer/jit/cutlass_gemm/cutlass_library.py +++ b/flashinfer/jit/cutlass_gemm/cutlass_library.py @@ -1,77 +1,30 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - """ Data types and tags used for emitting CUTLASS C++ kernels """ - import enum import re -# The following block implements enum.auto() for Python 3.5 variants that don't include it such -# as the default 3.5.2 on Ubuntu 16.04. 
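# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch above). setup_cubin_loader()
# registers a ctypes callback with the freshly loaded .so and stores it in
# dll_cubin_handlers so the garbage collector cannot free the C thunk while the
# library still holds the pointer. The standalone demo below shows that pattern
# without loading a real DLL; names and the print body are mine.
import ctypes

# C-side signature assumed above: void callback(const char* name, const char* sha256)
CUBIN_CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p)

_kept_callbacks = {}  # keyed by dll path in the real loader; prevents GC

def sketch_make_callback(tag: str):
    def on_request(name: bytes, sha256: bytes) -> None:
        # The real callback fetches the cubin via get_cubin() and hands the
        # bytes back through FlashInferSetCurrentCubin().
        print(f"[{tag}] cubin requested:", name.decode(), sha256.decode())

    cb = CUBIN_CALLBACK(on_request)
    _kept_callbacks[tag] = cb  # keep a reference, or ctypes may free the thunk
    return cb

# Usage: cb = sketch_make_callback("demo"); cb(b"kernel.cubin", b"deadbeef")
# ---------------------------------------------------------------------------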
-# -# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility - try: from enum import auto as enum_auto except ImportError: __cutlass_library_auto_enum = 0 - def enum_auto() -> int: # type: ignore[no-redef] + def enum_auto() -> int: global __cutlass_library_auto_enum i = __cutlass_library_auto_enum __cutlass_library_auto_enum += 1 return i -################################################################################################### - - -# class GeneratorTarget(enum.Enum): Library = enum_auto() -# GeneratorTargetNames = {GeneratorTarget.Library: "library"} -# - -################################################################################################### -# class DataType(enum.Enum): - void = enum_auto() # primarily used to disable C tensor for epilogues + void = enum_auto() b1 = enum_auto() u2 = enum_auto() u4 = enum_auto() @@ -120,7 +73,6 @@ class DataType(enum.Enum): invalid = enum_auto() -# ShortDataTypeNames = { DataType.s32: "i", DataType.e4m3: "e4m3", @@ -134,8 +86,6 @@ class DataType(enum.Enum): DataType.f6: "f6", DataType.f4: "f4", } - -# DataTypeNames = { DataType.void: "void", DataType.b1: "b1", @@ -184,7 +134,6 @@ class DataType(enum.Enum): DataType.cs32: "cs32", DataType.cs64: "cs64", } - DataTypeTag = { DataType.void: "void", DataType.b1: "cutlass::uint1b_t", @@ -233,7 +182,6 @@ class DataType(enum.Enum): DataType.cs32: "cutlass::complex", DataType.cs64: "cutlass::complex", } - DataTypeSize = { DataType.void: 0, DataType.b1: 1, @@ -284,39 +232,30 @@ class DataType(enum.Enum): } -################################################################################################### -# class BlasMode(enum.Enum): symmetric = enum_auto() hermitian = enum_auto() -# BlasModeTag = { BlasMode.symmetric: "cutlass::BlasMode::kSymmetric", BlasMode.hermitian: "cutlass::BlasMode::kHermitian", } -# class ComplexTransform(enum.Enum): none = enum_auto() conj = enum_auto() -# ComplexTransformTag = { ComplexTransform.none: "cutlass::ComplexTransform::kNone", ComplexTransform.conj: "cutlass::ComplexTransform::kConjugate", } - -# Used for cutlass3x complex kernel collective mainloop builder instantiation ComplexTransformTag3x = { ComplexTransform.none: "cute::identity", ComplexTransform.conj: "cute::conjugate", } - -# RealComplexBijection = [ (DataType.f16, DataType.cf16), (DataType.f32, DataType.cf32), @@ -324,7 +263,6 @@ class ComplexTransform(enum.Enum): ] -# def is_complex(data_type): return any(data_type == c for _r, c in RealComplexBijection) @@ -351,7 +289,6 @@ def is_grouped(gemm_kind): ) -# def get_complex_from_real(real_type): for r, c in RealComplexBijection: if real_type == r: @@ -359,7 +296,6 @@ def get_complex_from_real(real_type): return DataType.invalid -# def get_real_from_complex(complex_type): for r, c in RealComplexBijection: if complex_type == c: @@ -367,26 +303,20 @@ def get_real_from_complex(complex_type): return DataType.invalid -# TMA requires an alignment of 128 bits for all data types def get_tma_alignment(data_type): if data_type == DataType.void: return 0 elif DataTypeSize[data_type] == 6: - return 128 # 96B alignment for 16U6 format + return 128 else: return 128 // DataTypeSize[data_type] -# class ComplexMultiplyOp(enum.Enum): multiply_add = enum_auto() gaussian = enum_auto() -################################################################################################### - - -# class MathOperation(enum.Enum): multiply_add = enum_auto() multiply_add_saturate = enum_auto() @@ -402,7 +332,6 @@ class 
MathOperation(enum.Enum): multiply_add_fast_accum = enum_auto() -# MathOperationTag = { MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd", MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate", @@ -418,10 +347,7 @@ class MathOperation(enum.Enum): MathOperation.multiply_add_fast_accum: "cutlass::arch::OpMultiplyAddFastAccum", } -################################################################################################### - -# class LayoutType(enum.Enum): ColumnMajor = enum_auto() RowMajor = enum_auto() @@ -445,7 +371,6 @@ class LayoutType(enum.Enum): TensorKCSRT = enum_auto() -# LayoutTag = { LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor", LayoutType.RowMajor: "cutlass::layout::RowMajor", @@ -468,8 +393,6 @@ class LayoutType(enum.Enum): LayoutType.TensorKCSR: "cutlass::layout::TensorKCSR", LayoutType.TensorKCSRT: "cutlass::layout::TensorKCSRT", } - -# TransposedLayout = { LayoutType.ColumnMajor: LayoutType.RowMajor, LayoutType.RowMajor: LayoutType.ColumnMajor, @@ -481,8 +404,6 @@ class LayoutType(enum.Enum): LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, LayoutType.TensorNHWC: LayoutType.TensorNHWC, } - -# ShortLayoutTypeNames = { LayoutType.ColumnMajor: "n", LayoutType.ColumnMajorInterleaved2: "n2", @@ -505,8 +426,6 @@ class LayoutType(enum.Enum): LayoutType.TensorKCSR: "kcsr", LayoutType.TensorKCSRT: "kcsrt", } - -# ShortComplexLayoutNames = { (LayoutType.ColumnMajor, ComplexTransform.none): "n", (LayoutType.ColumnMajor, ComplexTransform.conj): "c", @@ -515,7 +434,6 @@ class LayoutType(enum.Enum): } -################################################################################################### class KernelScheduleType(enum.Enum): ScheduleAuto = enum_auto() Multistage = enum_auto() @@ -534,18 +452,14 @@ class KernelScheduleType(enum.Enum): PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() PtrArrayTmaWarpSpecializedPingpong = enum_auto() PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto() - BlockwiseTmaWarpSpecializedCooperative = enum_auto() PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto() - TmaWarpSpecialized1SmSm100 = enum_auto() TmaWarpSpecialized2SmSm100 = enum_auto() ImplicitTmaWarpSpecialized1SmSm100 = enum_auto() ImplicitTmaWarpSpecialized2SmSm100 = enum_auto() - PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto() PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto() - PtrArrayTmaWarpSpecialized1SmBlockScaledSm100 = enum_auto() PtrArrayTmaWarpSpecialized2SmBlockScaledSm100 = enum_auto() PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum_auto() @@ -554,35 +468,27 @@ class KernelScheduleType(enum.Enum): PtrArrayMxf4TmaWarpSpecialized2SmSm100 = enum_auto() PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() - SparseTmaWarpSpecialized1SmSm100 = enum_auto() SparseTmaWarpSpecialized2SmSm100 = enum_auto() - BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto() BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto() Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() - BlockwiseTmaWarpSpecialized1SmSm100 = enum_auto() BlockwiseTmaWarpSpecialized2SmSm100 = enum_auto() - PtrArrayBlockwiseTmaWarpSpecialized1SmSm100 = enum_auto() PtrArrayBlockwiseTmaWarpSpecialized2SmSm100 = enum_auto() - Mxf4TmaWarpSpecialized1SmSm100 = enum_auto() Mxf4TmaWarpSpecialized2SmSm100 = enum_auto() Nvf4TmaWarpSpecialized1SmSm100 = enum_auto() Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() - 
Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto() Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto() Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto() Nvf4TmaWarpSpecializedPingpongSm120 = enum_auto() Mxf4TmaWarpSpecializedCooperativeSm120 = enum_auto() Mxf4TmaWarpSpecializedPingpongSm120 = enum_auto() - F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto() - BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto() BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto() @@ -645,8 +551,6 @@ class KernelScheduleType(enum.Enum): KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: "cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120", KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: "cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120", } - -# KernelScheduleSuffixes = { KernelScheduleType.ScheduleAuto: "", KernelScheduleType.Multistage: "_cpasync", @@ -726,7 +630,6 @@ class EpilogueScheduleType(enum.Enum): PtrArrayTmaWarpSpecializedCooperative = enum_auto() -# EpilogueScheduleTag = { EpilogueScheduleType.ScheduleAuto: "cutlass::epilogue::collective::EpilogueScheduleAuto", EpilogueScheduleType.EpilogueTransposed: "cutlass::gemm::EpilogueTransposed", @@ -745,8 +648,6 @@ class EpilogueScheduleType(enum.Enum): EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: "cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative", EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: "cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong", } - -# EpilogueScheduleSuffixes = { EpilogueScheduleType.ScheduleAuto: "", EpilogueScheduleType.EpilogueTransposed: "", @@ -772,14 +673,12 @@ class EpilogueFunctor3x(enum.Enum): LinearCombinationBlockScaleFactor = enum_auto() -# EpilogueFunctor3xTag = { EpilogueFunctor3x.LinearCombination: "cutlass::epilogue::fusion::LinearCombination", EpilogueFunctor3x.LinearCombinationBlockScaleFactor: "cutlass::epilogue::fusion::LinCombBlockScaleFactor", } -# TMA epilogues have certain alignment requirements as calculated in get_tma_alignment(data_type) def is_tma_epilogue(epilogue_schedule_type): return epilogue_schedule_type in [ EpilogueScheduleType.ScheduleAuto, @@ -797,9 +696,7 @@ def is_tma_epilogue(epilogue_schedule_type): def to_grouped_schedule(schedule, grouped): if not grouped: return schedule - group_schedule_map = { - # SM90 KernelScheduleType.TmaWarpSpecializedCooperative: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative, KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative, KernelScheduleType.TmaWarpSpecializedPingpong: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong, @@ -808,7 +705,6 @@ def to_grouped_schedule(schedule, grouped): EpilogueScheduleType.TmaWarpSpecialized: EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecializedCooperative: EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized, - # SM100 KernelScheduleType.TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100, KernelScheduleType.TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100, KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100, @@ -822,7 +718,6 @@ def to_grouped_schedule(schedule, grouped): EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, 
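# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch above). The schedule
# enums here are always consumed through a pair of lookups: a *Tag dict giving
# the C++ type to splice into the template instantiation, and a *Suffixes dict
# appended to the generated kernel name. A trimmed stand-in of that pattern;
# the suffix strings below are illustrative, not the library's exact values.
import enum

class SketchSchedule(enum.Enum):
    TmaWarpSpecializedCooperative = enum.auto()
    TmaWarpSpecializedPingpong = enum.auto()

SKETCH_TAG = {
    SketchSchedule.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
    SketchSchedule.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
}
SKETCH_SUFFIX = {
    SketchSchedule.TmaWarpSpecializedCooperative: "_warpspecialized_cooperative",
    SketchSchedule.TmaWarpSpecializedPingpong: "_warpspecialized_pingpong",
}

def sketch_kernel_identifier(base_name: str, sched: SketchSchedule):
    # Returns (procedural kernel name, C++ schedule type) for one schedule.
    return base_name + SKETCH_SUFFIX[sched], SKETCH_TAG[sched]
# ---------------------------------------------------------------------------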
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, } - return group_schedule_map[schedule] @@ -832,78 +727,54 @@ class TileSchedulerType(enum.Enum): StreamK = enum_auto() -# TileSchedulerTag = { TileSchedulerType.Default: "void", TileSchedulerType.Persistent: "cutlass::gemm::PersistentScheduler", TileSchedulerType.StreamK: "cutlass::gemm::StreamKScheduler", } - -# TileSchedulerSuffixes = { TileSchedulerType.Default: "", TileSchedulerType.Persistent: "", TileSchedulerType.StreamK: "_stream_k", } -################################################################################################### - -# class SideMode(enum.Enum): Left = enum_auto() Right = enum_auto() -# SideModeTag = { SideMode.Left: "cutlass::SideMode::kLeft", SideMode.Right: "cutlass::SideMode::kRight", } - -# ShortSideModeNames = {SideMode.Left: "ls", SideMode.Right: "rs"} -################################################################################################### - -# class FillMode(enum.Enum): Lower = enum_auto() Upper = enum_auto() -# FillModeTag = { FillMode.Lower: "cutlass::FillMode::kLower", FillMode.Upper: "cutlass::FillMode::kUpper", } - -# ShortFillModeNames = {FillMode.Lower: "l", FillMode.Upper: "u"} -################################################################################################### - -# class DiagType(enum.Enum): NonUnit = enum_auto() Unit = enum_auto() -# DiagTypeTag = { DiagType.NonUnit: "cutlass::DiagType::kNonUnit", DiagType.Unit: "cutlass::DiagType::kUnit", } - -# ShortDiagTypeNames = {DiagType.NonUnit: "nu", DiagType.Unit: "un"} -################################################################################################### - -# class OpcodeClass(enum.Enum): Simt = enum_auto() TensorOp = enum_auto() @@ -919,7 +790,6 @@ class OpcodeClass(enum.Enum): OpcodeClass.SparseTensorOp: "sptensorop", OpcodeClass.BlockScaledTensorOp: "bstensorop", } - OpcodeClassTag = { OpcodeClass.Simt: "cutlass::arch::OpClassSimt", OpcodeClass.TensorOp: "cutlass::arch::OpClassTensorOp", @@ -928,10 +798,7 @@ class OpcodeClass(enum.Enum): OpcodeClass.BlockScaledTensorOp: "cutlass::arch::OpClassBlockScaledTensorOp", } -################################################################################################### - -# class OperationKind(enum.Enum): Gemm = enum_auto() RankK = enum_auto() @@ -942,7 +809,6 @@ class OperationKind(enum.Enum): Conv3d = enum_auto() -# OperationKindNames = { OperationKind.Gemm: "gemm", OperationKind.RankK: "rank_k", @@ -954,39 +820,32 @@ class OperationKind(enum.Enum): } -# class Target(enum.Enum): library = enum_auto() -# ArchitectureNames = { - 50: "maxwell", - 60: "pascal", - 61: "pascal", - 70: "volta", - 75: "turing", - 80: "ampere", - 89: "ada", - 90: "hopper", + (50): "maxwell", + (60): "pascal", + (61): "pascal", + (70): "volta", + (75): "turing", + (80): "ampere", + (89): "ada", + (90): "hopper", } - -# SharedMemPerCC = { - 70: 96, # 96KB of SMEM - 72: 96, # 96KB of SMEM - 75: 64, # 64KB of SMEM - 80: 163, # 163KB of SMEM - 1KB reserved for the driver - 86: 99, # 99KB of SMEM - 1KB reserved for the driver - 87: 163, # 163KB of SMEM - 1KB reserved for the driver - 89: 99, # 99KB of SMEM - 1KB reserved for the driver - 90: 227, # 227KB of SMEM - 1KB reserved for the driver + (70): 96, + (72): 96, + (75): 64, + (80): 163, + (86): 99, + (87): 163, + (89): 99, + (90): 227, } -################################################################################################### - -# def SubstituteTemplate(template, values): text = 
template changed = True @@ -1001,10 +860,6 @@ def SubstituteTemplate(template, values): return text -################################################################################################### - - -# class GemmKind(enum.Enum): Gemm = enum_auto() Sparse = enum_auto() @@ -1021,7 +876,6 @@ class GemmKind(enum.Enum): GroupedBlockwiseUniversal3x = enum_auto() -# GemmKindNames = { GemmKind.Gemm: "gemm", GemmKind.Sparse: "spgemm", @@ -1039,54 +893,44 @@ class GemmKind(enum.Enum): } -# class RankKKind(enum.Enum): Universal = enum_auto() -# RankKKindNames = {RankKKind.Universal: "rank_k"} -# class TrmmKind(enum.Enum): Universal = enum_auto() -# TrmmKindNames = {TrmmKind.Universal: "trmm"} -# class SymmKind(enum.Enum): Universal = enum_auto() -# SymmKindNames = {SymmKind.Universal: "symm"} -# class EpilogueFunctor(enum.Enum): LinearCombination = enum_auto() LinearCombinationClamp = enum_auto() -# EpilogueFunctorTag = { EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination", EpilogueFunctor.LinearCombinationClamp: "cutlass::epilogue::thread::LinearCombinationClamp", } -# class MixedInputMode(enum.Enum): ConvertOnly = enum_auto() ScaleOnly = enum_auto() ScaleWithZeroPoint = enum_auto() -# class SwizzlingFunctor(enum.Enum): Identity1 = enum_auto() Identity2 = enum_auto() @@ -1099,7 +943,6 @@ class SwizzlingFunctor(enum.Enum): StreamK = enum_auto() -# SwizzlingFunctorTag = { SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>", SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>", @@ -1113,41 +956,32 @@ class SwizzlingFunctor(enum.Enum): } -# class GroupScheduleMode(enum.Enum): Device = (enum_auto(),) Host = enum_auto() -# GroupScheduleModeTag = { GroupScheduleMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly", GroupScheduleMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute", } - -# ShortGroupScheduleModeNames = { GroupScheduleMode.Device: "Device", GroupScheduleMode.Host: "Host", } -################################################################################################### - -# class ConvKind(enum.IntEnum): Fprop = 0 Dgrad = 1 Wgrad = 2 -# ConvKindTag = { ConvKind.Fprop: "cutlass::conv::Operator::kFprop", ConvKind.Dgrad: "cutlass::conv::Operator::kDgrad", ConvKind.Wgrad: "cutlass::conv::Operator::kWgrad", } - ConvKindNames = { ConvKind.Fprop: "fprop", ConvKind.Dgrad: "dgrad", @@ -1160,7 +994,6 @@ class ConvMode(enum.IntEnum): Convolution = 1 -# class IteratorAlgorithm(enum.Enum): Analytic = 0 Optimized = 1 @@ -1169,7 +1002,6 @@ class IteratorAlgorithm(enum.Enum): FixedStrideDilation = 4 -# IteratorAlgorithmTag = { IteratorAlgorithm.Analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic", IteratorAlgorithm.Optimized: "cutlass::conv::IteratorAlgorithm::kOptimized", @@ -1177,7 +1009,6 @@ class IteratorAlgorithm(enum.Enum): IteratorAlgorithm.FewChannels: "cutlass::conv::IteratorAlgorithm::kFewChannels", IteratorAlgorithm.FixedStrideDilation: "cutlass::conv::IteratorAlgorithm::kFixedStrideDilation", } - IteratorAlgorithmNames = { IteratorAlgorithm.Analytic: "analytic", IteratorAlgorithm.Optimized: "optimized", @@ -1187,20 +1018,17 @@ class IteratorAlgorithm(enum.Enum): } -# class StrideSupport(enum.Enum): Strided = 0 Unity = 1 Fixed = 2 -# StrideSupportTag = { StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided", StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity", StrideSupport.Fixed: "cutlass::conv::StrideSupport::kFixed", } - 
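# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch above). The body of
# SubstituteTemplate() is mostly elided by the hunk context; the approximation
# below shows the placeholder-substitution idea it is named for: keep replacing
# ${key} placeholders until a pass makes no change, so substituted values may
# themselves contain further placeholders. Not a claim about the exact code.
def sketch_substitute_template(template: str, values: dict) -> str:
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            placeholder = "${%s}" % key
            if placeholder in text:
                text = text.replace(placeholder, str(value))
                changed = True
    return text

# Example:
#   sketch_substitute_template("Gemm<${elem}, ${layout}>",
#                              {"elem": "cutlass::half_t",
#                               "layout": "cutlass::layout::RowMajor"})
# ---------------------------------------------------------------------------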
StrideSupportNames = { StrideSupport.Strided: "", StrideSupport.Unity: "unity_stride", @@ -1208,35 +1036,28 @@ class StrideSupport(enum.Enum): } -# class GroupMode(enum.Enum): - NoneGroup = enum_auto() # dense conv (G=1) - SingleGroup = enum_auto() # grouped convolution (single group per CTA) - MultipleGroup = enum_auto() # grouped convolution ( multiple groups per CTA) - Depthwise = enum_auto() # Depthwise convolution ( C=K=G ) + NoneGroup = enum_auto() + SingleGroup = enum_auto() + MultipleGroup = enum_auto() + Depthwise = enum_auto() -# GroupModeTag = { GroupMode.NoneGroup: "cutlass::conv::GroupMode::kNone", GroupMode.SingleGroup: "cutlass::conv::GroupMode::kSingleGroup", GroupMode.MultipleGroup: "cutlass::conv::GroupMode::kMultipleGroup", GroupMode.Depthwise: "cutlass::conv::GroupMode::kDepthwise", } - GroupModeNames = { GroupMode.NoneGroup: "", GroupMode.SingleGroup: "single_group", GroupMode.MultipleGroup: "multiple_group", GroupMode.Depthwise: "depthwise", } - DynamicClusterShape = [0, 0, 1] -################################################################################################### - -# class MathInstruction: def __init__( self, @@ -1257,7 +1078,6 @@ def __init__( self.element_scale_factor = element_scale_factor -# class TileDescription: def __init__( self, @@ -1343,7 +1163,6 @@ def procedural_name(self): self.filter_shape[0], self.filter_shape[1], ) - # Fixed Strided and dilation if self.stride != [-1, -1] and self.dilation != [-1, -1]: str_name += "_stride%dx%d_dilation%dx%d" % ( self.stride[0], @@ -1354,7 +1173,6 @@ def procedural_name(self): return str_name -# class TensorDescription: def __init__( self, element, layout, alignment=1, complex_transform=ComplexTransform.none @@ -1365,7 +1183,6 @@ def __init__( self.complex_transform = complex_transform -# class SymmetricTensorDescription: def __init__( self, @@ -1384,7 +1201,6 @@ def __init__( self.side_mode = side_mode -# class TriangularTensorDescription: def __init__( self, @@ -1405,40 +1221,33 @@ def __init__( self.complex_transform = complex_transform -# def CalculateSmemUsage(operation): cta_shape = operation.tile_description.threadblock_shape stages = operation.tile_description.stages - if ( operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse ): - # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity) if DataTypeSize[operation.A.element] == 32: elements_per_8b_md = 2 elif DataTypeSize[operation.A.element] == 4: elements_per_8b_md = 8 else: elements_per_8b_md = 4 - smem_per_stage = ( DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md ) else: - # Few BLAS3 operations only have A tensor data_type_size_a = DataTypeSize[operation.A.element] data_type_size_b = DataTypeSize[operation.A.element] if operation.is_mixed_input(): data_type_size_b = DataTypeSize[operation.B.element] - smem_per_stage = ( data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + data_type_size_b * cta_shape[1] * cta_shape[2] // 8 ) - smem_usage = smem_per_stage * stages return smem_usage >> 10 diff --git a/flashinfer/jit/cutlass_gemm/generate_kernels.py b/flashinfer/jit/cutlass_gemm/generate_kernels.py index 9545d1d047..42405e94ca 100644 --- a/flashinfer/jit/cutlass_gemm/generate_kernels.py +++ b/flashinfer/jit/cutlass_gemm/generate_kernels.py @@ -2,26 +2,14 @@ import os from itertools import chain, product -from .cutlass_library 
import ( - enum_auto, - DataTypeNames, - DataTypeSize, - DataType, - DataTypeTag, - GemmKind, - GemmKindNames, - KernelScheduleType, - KernelScheduleTag, - KernelScheduleSuffixes, - EpilogueScheduleType, - EpilogueScheduleTag, - EpilogueScheduleSuffixes, -) from ..cpp_ext import is_cuda_version_at_least +from .cutlass_library import (DataType, DataTypeNames, DataTypeSize, + DataTypeTag, EpilogueScheduleSuffixes, + EpilogueScheduleTag, EpilogueScheduleType, + GemmKind, GemmKindNames, KernelScheduleSuffixes, + KernelScheduleTag, KernelScheduleType, enum_auto) -################################################################################ -# Epilogue Tag enum and string utils class TrtLlm_EpilogueTag(enum.Enum): epilogue_op_default = enum_auto() epilogue_op_bias = enum_auto() @@ -35,24 +23,21 @@ class TrtLlm_EpilogueFusion(enum.Enum): EpiTagNames = { - TrtLlm_EpilogueTag.epilogue_op_default: "lc", # linear combination - TrtLlm_EpilogueTag.epilogue_op_bias: "lc_bias", # linear combination with bias addition - TrtLlm_EpilogueTag.epilogue_op_silu: "silu", # silu or swiglu - TrtLlm_EpilogueTag.epilogue_op_gelu: "gelu", # gelu or geglu + TrtLlm_EpilogueTag.epilogue_op_default: "lc", + TrtLlm_EpilogueTag.epilogue_op_bias: "lc_bias", + TrtLlm_EpilogueTag.epilogue_op_silu: "silu", + TrtLlm_EpilogueTag.epilogue_op_gelu: "gelu", } - EpiTag = { TrtLlm_EpilogueTag.epilogue_op_default: "tensorrt_llm::cutlass_extensions::EpilogueOpDefault", TrtLlm_EpilogueTag.epilogue_op_bias: "tensorrt_llm::cutlass_extensions::EpilogueOpBias", TrtLlm_EpilogueTag.epilogue_op_silu: "tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu", TrtLlm_EpilogueTag.epilogue_op_gelu: "tensorrt_llm::cutlass_extensions::EpilogueOpDefaultFtGelu", } - EpiFusion = { TrtLlm_EpilogueFusion.epilogue_fusion_none: "tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE", TrtLlm_EpilogueFusion.epilogue_fusion_finalize: "tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE", } - EpiFusionSuffixes = { None: "", TrtLlm_EpilogueFusion.epilogue_fusion_none: "EpilogueFusion_NONE", @@ -60,8 +45,6 @@ class TrtLlm_EpilogueFusion(enum.Enum): } -################################################################################ -# Quantization Operation and string utils class TrtLlm_QuantOp(enum.Enum): per_column_scale_only = enum_auto() finegrained_scale_only = enum_auto() @@ -75,7 +58,6 @@ class TrtLlm_QuantOp(enum.Enum): TrtLlm_QuantOp.finegrained_scale_and_zeros: "fgsz", TrtLlm_QuantOp.none: "noquant", } - QuantOpTag = { TrtLlm_QuantOp.per_column_scale_only: "cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY", TrtLlm_QuantOp.finegrained_scale_only: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY", @@ -83,12 +65,8 @@ class TrtLlm_QuantOp(enum.Enum): TrtLlm_QuantOp.none: "void", } -################################################################################ -# The activations, biases, scales and zeros are instantiated using CUDA types, -# not CUTLASS types. This map materializes the name of the CUDA type. - -class e2m1_type: # WAR until we have upgraded everything to a supported version +class e2m1_type: pass @@ -122,8 +100,6 @@ def GetDataTypeNames(type, is_mx_fpx=None): } -################################################################################ -# A data structure holding all info to instantiate gemm launchers in TRT LLM. 
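# ---------------------------------------------------------------------------
# Illustrative sketch (annotation, not part of the patch above). The generator
# functions further down in this file enumerate (dtype, tile shape, cluster
# shape, ...) combinations with itertools.product and keep only those that pass
# is_op_valid(). A minimal version of that enumerate-then-filter pattern, with
# a deliberately simplified validity rule (the real checks are stricter):
from itertools import product

def sketch_is_valid(tile_mn, cga_mn) -> bool:
    # Simplified stand-in for is_gemm_op_valid(): a cluster dimension of 2
    # requires at least a 128-wide tile along the matching dimension.
    tile_m, tile_n = tile_mn
    cga_m, cga_n = cga_mn
    if cga_m == 2 and tile_m < 128:
        return False
    if cga_n == 2 and tile_n < 128:
        return False
    return True

def sketch_enumerate_configs():
    tiles = list(product([64, 128], [16, 32, 64, 128, 256]))
    cgas = [(1, 1), (2, 1), (1, 2), (2, 2)]
    return [(t, c) for t, c in product(tiles, cgas) if sketch_is_valid(t, c)]
# ---------------------------------------------------------------------------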
class TrtLlm_GemmLauncher: def __init__( self, @@ -182,7 +158,6 @@ def __repr__(self): self.warp_shape[2], self.stages, ) - hopper_suffix = "_{}x{}x{}{}{}{}".format( self.cga_shape[0], self.cga_shape[1], @@ -191,7 +166,6 @@ def __repr__(self): EpilogueScheduleSuffixes[self.epi_schedule], EpiFusionSuffixes[self.epi_fusion], ) - if self.arch >= 90: return kernel_prefix + hopper_suffix elif self.arch > 100: @@ -199,7 +173,6 @@ def __repr__(self): return kernel_prefix -################################################################################ def tuple_to_cute_shape(shape): return f"cute::Shape, cute::Int<{shape[1]}>, cute::Int<{shape[2]}>>" @@ -209,16 +182,12 @@ def instantiate_operation_tma_warp_specialized(operation): scale_zero_tag = CudaTypeName[operation.scalezero_type] bias_tag = CudaTypeName[operation.bias_type] out_tag = CudaTypeName[operation.output_type] - quant_op = QuantOpTag[operation.quant_op] epi_tag = EpiTag[operation.epi_tag] - cute_cta_shape = tuple_to_cute_shape(operation.cta_shape) cute_cga_shape = tuple_to_cute_shape(operation.cga_shape) - kernel_sched = KernelScheduleTag[operation.mainloop_schedule] epi_sched = EpilogueScheduleTag[operation.epi_schedule] - if operation.gemm_kind == GemmKind.Gemm: weight_tag = DataTypeTag[operation.weight_type] instantiation = f""" @@ -234,7 +203,6 @@ def instantiate_operation_tma_warp_specialized(operation): if operation.act_type != operation.weight_type and ( operation.act_type != DataType.e4m3 or operation.weight_type != e2m1 ): - # Mixed MoE GEMM weight_tag = CudaTypeName[operation.weight_type] instantiation = f""" template void sm90_generic_mixed_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {out_tag}, @@ -242,7 +210,6 @@ def instantiate_operation_tma_warp_specialized(operation): GroupedGemmInput<{act_tag}, {weight_tag}, {out_tag}, {out_tag}>inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size); """ else: - # Similar to MixedInput above, we must modify the tags for grouped gemm as CUTLASS library does not have the updated schedules assert operation.mainloop_schedule in [ KernelScheduleType.TmaWarpSpecializedCooperative, KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, @@ -250,16 +217,12 @@ def instantiate_operation_tma_warp_specialized(operation): assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized kernel_sched.replace("::Kernel", "::KernelGrouped") epi_sched += "Grouped" - - # arch_tag = f"cutlass::arch::Sm{operation.arch}" arch_tag = f"Sm{operation.arch}" weight_tag = CudaTypeName[operation.weight_type] assert operation.epi_fusion is not None epi_fusion = EpiFusion[operation.epi_fusion] - epi_fusion = epi_fusion.split(":")[-1] epi_tag = epi_tag.split(":")[-1] - guard_map = { e2m1: "defined(ENABLE_FP4)", DataType.e4m3: "defined(ENABLE_FP8)", @@ -267,16 +230,12 @@ def instantiate_operation_tma_warp_specialized(operation): } guard_act = guard_map.get(operation.act_type, "1") guard_weight = guard_map.get(operation.weight_type, "1") - # TODO Revert this once compiler bug is fixed so we can use template instead of macro again - # instantiation = f""" - # template void tma_warp_specialized_generic_moe_gemm_kernelLauncher<{arch_tag}, {act_tag}, {weight_tag}, {out_tag}, - # {epi_tag}, {epi_fusion}, {cute_cta_shape}, {cute_cga_shape}, false> - # (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*); - # """ instantiation = f""" -#if {guard_act} && {guard_weight}\n +#if {guard_act} && {guard_weight} + 
INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag}, - {epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, false);\n + {epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {'true' if operation.is_mx_fpx else 'false'}, false); + #endif """ return instantiation @@ -286,7 +245,6 @@ def instantiate_operation_sm80(operation): act_tag = DataTypeTag[operation.dtype] weight_tag = DataTypeTag[operation.dtype] epi_tag = EpiTag[operation.epi_tag] - instantiation = f""" template void sm80_generic_fused_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.stage}, {epi_tag}> ({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy); @@ -307,12 +265,10 @@ def get_file_content(launcher_inl_files, operations): for file in launcher_inl_files: include_list.append(f'#include "{file}"') includes = "\n".join(include_list) - insts_list = list() for op in operations: insts_list.append(instantiate_operation(op)) instantiations = "\n".join(insts_list) - file_content = f"""{includes} namespace tensorrt_llm {{ @@ -341,7 +297,6 @@ def clean_leftover_files(output_dir, generated_files): def write_file(launcher_inl_files, operations, output_file): os.makedirs(os.path.dirname(output_file), exist_ok=True) - # Avoid changing modified time if file content is up to date content = get_file_content(launcher_inl_files, operations) try: with open(output_file, mode="r") as f: @@ -357,105 +312,73 @@ def write_file(launcher_inl_files, operations, output_file): def elementwise(x, y, f): - return tuple(f(a, b) for (a, b) in zip(x, y)) + return tuple(f(a, b) for a, b in zip(x, y)) def is_gemm_op_valid_sm100(op): - # TODO These are much more restricted than theory dictates, investigate if more can be enabled in future tile_m, tile_n, _ = elementwise(op.cta_shape, op.cga_shape, truediv) cga_m, cga_n, _ = op.cga_shape - - # Default shapes - # This is epilogue tile size. For two CTA this is actually size 128/256 for the MMA if tile_m not in [64, 128]: return False - - # FP4 Has some much more limited sizes if op.act_type == e2m1 or op.weight_type == e2m1: - # TODO 128x256x256 FP4 compiles but crashes - # if tile_n % 64 != 0 or tile_n < 128: - # return False if tile_n not in [64, 128, 256] or tile_m != 128: return False - - # Shapes for fp8 small N shapes if ( op.act_type == DataType.e4m3 and (tile_n == 16 or tile_n == 8) and (cga_m == 1 and cga_n == 1) ): - # todo: double check why this is disable in CUTLASS backend. 
@yuhan if tile_m == 128 and tile_n == 8: return False else: return True - - # Default alignment requirements if tile_n % 32 != 0 or tile_n < 32 or tile_n > 256: return False - - # Two CTA mode needs bigger tile n alignment if cga_m % 2 == 0 and tile_n % 64 != 0: return False - return True def is_gemm_op_valid(op): tile_m, tile_n, _ = op.cta_shape cga_m, cga_n, _ = op.cga_shape - if cga_m == 1 and cga_n == 1: return True - if cga_m == 2 and cga_n == 1 and tile_m >= 128: return True - if cga_m == 1 and cga_n == 2 and tile_n >= 128: return True - if cga_m == 2 and cga_n == 2 and tile_m >= 128 and tile_n >= 128: return True - return False def is_grouped_gemm_op_valid(op): if not is_gemm_op_valid(op): return False - if op.epi_tag != TrtLlm_EpilogueTag.epilogue_op_default: return False - if op.epi_schedule != EpilogueScheduleType.NoSmemWarpSpecialized: return False - if op.mainloop_schedule not in [ KernelScheduleType.TmaWarpSpecializedCooperative, KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, ]: return False - return True def is_op_valid(op): if op.arch >= 100: return is_gemm_op_valid_sm100(op) - if op.gemm_kind == GemmKind.Gemm: return is_gemm_op_valid(op) if op.gemm_kind == GemmKind.Grouped: return is_grouped_gemm_op_valid(op) -################################################################################ def generate_sm90_mixed_gemm_operations(): arch = 90 - - # For legacy reasons, we use unsigned types for the weights. The instanitated template - # will remap those back to the signed type. - # Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type) supported_dtypes = [ (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16), (DataType.e4m3, DataType.u4, DataType.f16, DataType.bf16, DataType.bf16), @@ -464,34 +387,26 @@ def generate_sm90_mixed_gemm_operations(): (DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16), (DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16, DataType.bf16), ] - quant_ops = [ TrtLlm_QuantOp.per_column_scale_only, TrtLlm_QuantOp.finegrained_scale_only, TrtLlm_QuantOp.finegrained_scale_and_zeros, ] - epi_tags = [TrtLlm_EpilogueTag.epilogue_op_bias] - M_TILES = [64, 128] N_TILES = [16, 32, 64, 128, 256] cta_shapes_mn = product(M_TILES, N_TILES) - warp_shape = [4, 1, 1] - stages = 0 # auto - + stages = 0 cga_shapes = product([1, 2], [1, 2], [1]) - partial_args = product( supported_dtypes, quant_ops, epi_tags, cta_shapes_mn, cga_shapes ) - operations = list() for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args: max_k_bits = 128 * 8 cta_shape_k = max_k_bits // GetDataTypeBits(dtype_combo[0]) cta_shape_mnk = cta_shape_mn + (cta_shape_k,) - use_coop = cta_shape_mn[0] == 128 mainloop_schedule = ( KernelScheduleType.TmaWarpSpecializedCooperative @@ -503,7 +418,6 @@ def generate_sm90_mixed_gemm_operations(): if use_coop else EpilogueScheduleType.TmaWarpSpecialized ) - fpA_intB_operation = TrtLlm_GemmLauncher( GemmKind.Gemm, arch, @@ -517,10 +431,8 @@ def generate_sm90_mixed_gemm_operations(): mainloop_schedule, epi_schedule, ) - if is_op_valid(fpA_intB_operation): operations.append(fpA_intB_operation) - return operations @@ -531,41 +443,33 @@ def generate_sm90_grouped_gemm_operations(is_arch_enabled): supported_dtypes = [DataType.f16, DataType.bf16, DataType.f32, DataType.e4m3] quant_ops = [TrtLlm_QuantOp.none] epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default] - M_TILES = [128] # Currently M tile must be 128 for Grouped GEMM + M_TILES = [128] N_TILES = [16, 32, 64, 
128, 256] cta_shapes_mn = list(product(M_TILES, N_TILES)) + [(256, 128)] - - warp_shape = [0, 0, 0] # ignored except for naming - stages = 0 # auto - + warp_shape = [0, 0, 0] + stages = 0 epi_fusions = [ TrtLlm_EpilogueFusion.epilogue_fusion_none, TrtLlm_EpilogueFusion.epilogue_fusion_finalize, ] - cga_shapes = product([1, 2], [1, 2], [1]) - partial_args = product( supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mn, cga_shapes ) - operations = list() for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mn, cga_shape in partial_args: max_k_bits = 128 * 8 cta_shape_k = max_k_bits // GetDataTypeBits(dtype) cta_shape_mnk = cta_shape_mn + (cta_shape_k,) - mainloop_schedule = ( KernelScheduleType.TmaWarpSpecializedCooperative if dtype != DataType.e4m3 else KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum ) epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized - otypes = [dtype] if dtype == DataType.e4m3: otypes = [DataType.f16, DataType.bf16] - for otype in otypes: moe_gemm_operation = TrtLlm_GemmLauncher( GemmKind.Grouped, @@ -585,7 +489,6 @@ def generate_sm90_grouped_gemm_operations(is_arch_enabled): epi_schedule, epi_fusion, ) - if is_op_valid(moe_gemm_operation): operations.append(moe_gemm_operation) return operations @@ -595,13 +498,10 @@ def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled): if not is_arch_enabled: return [] arch = 90 - - # act_type, weight_type, scalezero_type, bias_type, output_type supported_dtypes_int4 = [ (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16), (DataType.e4m3, DataType.u4, DataType.bf16, DataType.bf16, DataType.bf16), ] - if is_cuda_version_at_least("12.8"): supported_dtypes_fp4 = [ (DataType.f16, DataType.e2m1, DataType.ue8m0, DataType.f16, DataType.f16), @@ -615,27 +515,20 @@ def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled): ] else: supported_dtypes_fp4 = [] - quant_ops = [TrtLlm_QuantOp.finegrained_scale_only] - epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default] - - M_TILES = [64, 128] # Currently M tile must be 128 for Grouped GEMM + M_TILES = [64, 128] N_TILES = [16, 32, 64, 128] K_TILES = [128, 256, 512] cta_shapes_mnk_int4 = list(product(M_TILES, N_TILES, K_TILES)) - - M_TILES = [64, 128] # Currently M tile must be 128 for Grouped GEMM + M_TILES = [64, 128] N_TILES = [16, 32, 64] K_TILES = [128, 256] cta_shapes_mnk_fp4 = list(product(M_TILES, N_TILES, K_TILES)) cta_shapes_mnk_fp4.append((128, 128, 128)) - - warp_shape = [0, 0, 0] # ignored except for naming - stages = 0 # auto - + warp_shape = [0, 0, 0] + stages = 0 cga_shapes = list(product([1, 2], [1, 2], [1])) - partial_args_int4 = product( supported_dtypes_int4, quant_ops, epi_tags, cta_shapes_mnk_int4, cga_shapes ) @@ -643,7 +536,6 @@ def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled): supported_dtypes_fp4, quant_ops, epi_tags, cta_shapes_mnk_fp4, cga_shapes ) partial_args = chain(partial_args_int4, partial_args_fp4) - operations = list() for dtype_combo, quant_op, epi_tag, cta_shape_mnk, cga_shape in partial_args: use_coop = cta_shape_mnk[0] >= 128 @@ -691,9 +583,9 @@ def generate_sm90_operations(is_arch_enabled): def calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype): max_k_bits = 128 * 8 cta_shape_k = max_k_bits // GetDataTypeBits(dtype) - if dtype == DataType.e4m3 and (cta_shape_mn[1] == 8): + if dtype == DataType.e4m3 and cta_shape_mn[1] == 8: cta_shape_k = 256 - if dtype == DataType.e4m3 and (cta_shape_mn[1] == 16): + if dtype == DataType.e4m3 and cta_shape_mn[1] == 16: 
cta_shape_k = 128 return cta_shape_mn + (cta_shape_k,) @@ -711,33 +603,21 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled): [256, 128, 128], [128, 256, 128], ] - - warp_shape = [0, 0, 0] # ignored except for naming - stages = 0 # auto - - epi_fusions = [ - TrtLlm_EpilogueFusion.epilogue_fusion_none, - # TrtLlm_EpilogueFusion.epilogue_fusion_finalize - ] - + warp_shape = [0, 0, 0] + stages = 0 + epi_fusions = [TrtLlm_EpilogueFusion.epilogue_fusion_none] cga_shapes = [[1, 1, 1]] - partial_args = product( supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mnk, cga_shapes ) - operations = list() for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args: cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul) - - # Ignored mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized - otypes = [dtype] if dtype in [DataType.e4m3, e2m1]: otypes = [DataType.f16, DataType.bf16] - for otype in otypes: moe_gemm_operation = TrtLlm_GemmLauncher( GemmKind.Grouped, @@ -757,7 +637,6 @@ def generate_sm120_grouped_gemm_operations(is_arch_enabled): epi_schedule, epi_fusion, ) - operations.append(moe_gemm_operation) return operations @@ -784,39 +663,26 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): cta_shapes_m = [64, 128] cta_shapes_n = [8, 16, 32, 64, 128, 256] cta_shapes_mn = product(cta_shapes_m, cta_shapes_n) - - warp_shape = [0, 0, 0] # ignored except for naming - stages = 0 # auto - - epi_fusions = [ - TrtLlm_EpilogueFusion.epilogue_fusion_none, - # TrtLlm_EpilogueFusion.epilogue_fusion_finalize - ] - + warp_shape = [0, 0, 0] + stages = 0 + epi_fusions = [TrtLlm_EpilogueFusion.epilogue_fusion_none] cga_shapes = list(product([1, 2], [1, 2], [1])) - partial_args = product( supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mn, cga_shapes ) - operations = list() for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mn, cga_shape in partial_args: if isinstance(dtype, tuple): dtype, weight_type = dtype else: weight_type = dtype - cta_shape_mnk = calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype) cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul) - - # Ignored mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized - otypes = [dtype] if dtype in [DataType.e4m3, e2m1]: otypes = [DataType.f16, DataType.bf16] - for otype in otypes: moe_gemm_operation = TrtLlm_GemmLauncher( GemmKind.Grouped, @@ -835,9 +701,8 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled): mainloop_schedule, epi_schedule, epi_fusion, - is_mx_fpx=(dtype == DataType.e4m3 and weight_type == e2m1), + is_mx_fpx=dtype == DataType.e4m3 and weight_type == e2m1, ) - if is_op_valid(moe_gemm_operation): operations.append(moe_gemm_operation) return operations @@ -872,11 +737,8 @@ def generate_sm80_fused_grouped_gemm_operations(): (64, 128, 64), (128, 128, 64), ] - stages = [2, 3, 4] - partial_args = product(supported_dtypes, epi_tags, cta_shapes_mnk, stages) - operations = list() for dtype, epi_tag, cta_shape_mnk, stage in partial_args: item = GemmSm80LauncherConfig( @@ -893,17 +755,11 @@ def generate_sm80_operations(is_arch_enabled): def generate_gemm_operations(output_dir, architectures): arches = architectures.split(";") - # Get the absolute path of the provided directory output_dir = os.path.abspath(output_dir) - fpA_intB_inl = 
"tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl" moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl" - # moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl" moe_mixed_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl" - # moe_mixed_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl" sm80_moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl" - # sm80_moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl" - inl_map = { (GemmKind.Gemm, 90): [fpA_intB_inl], (GemmKind.Grouped, 90): [moe_gemm_inl], @@ -915,8 +771,6 @@ def generate_gemm_operations(output_dir, architectures): def has_arch(sm): return f"{sm}" in arches or f"{sm}-real" in arches - # The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads. - # Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve. operations = [] operations += generate_sm120_operations(has_arch(120) or has_arch(121)) operations += generate_sm100_operations(has_arch(100)) @@ -924,31 +778,22 @@ def has_arch(sm): operations += generate_sm80_operations(has_arch(80) or has_arch(89)) def should_skip(op): - return False # All kernels have a public implementation + return False - # The mixed dtype grouped gemm for w4afp8 has a different launcher def is_mixed_dtype_grouped(op): if isinstance(op, GemmSm80LauncherConfig): return False - # Only w4a8fp8 and not wfp4afp8 return ( - (op.act_type != op.weight_type) - and (op.gemm_kind == GemmKind.Grouped) + op.act_type != op.weight_type + and op.gemm_kind == GemmKind.Grouped and (op.act_type != DataType.e4m3 or op.weight_type != e2m1) ) - # Fix OOM error in CI. If len(operations) is more than GROUP_SIZE, it will be split into multiple sub groups. GROUP_SIZE = 8 op_groups = dict() for op in operations: if should_skip(op): continue - # This dict key is used to group kernels with common instantiations together - # Similar implementations should live in the same file so the compiler can share the cutlass state - # Without this we see significant memory consumption, and separating them also does not reduce the compilation time - # because most time is spent parsing the same cutlass files - # We separate by: Architecture, Leading dimension of the CTA shape, FP4 (i.e. 
block scaled MMA), mixed input - # TODO Do a more scientific analysis of this dict_key = ( op.gemm_kind, op.arch, @@ -962,7 +807,6 @@ def is_mixed_dtype_grouped(op): else: op_group[-1].append(op) op_groups[dict_key] = op_group - file_list = [] for key, value in op_groups.items(): gemm_kind, arch, m, block_scale, is_mixed = key @@ -976,6 +820,4 @@ def is_mixed_dtype_grouped(op): inl_file = [moe_mixed_gemm_inl] if is_mixed else inl_map[key[:2]] write_file(inl_file, op_sub_group, out_file) file_list.append(out_file) - - # Clean up any leftover files from previous runs clean_leftover_files(output_dir, set(file_list)) diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py index 24104ee138..6d637dec90 100644 --- a/flashinfer/jit/env.py +++ b/flashinfer/jit/env.py @@ -1,3 +1,11 @@ +import sys + + +import os + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,22 +21,13 @@ See the License for the specific language governing permissions and limitations under the License. """ - -# NOTE(lequn): Do not "from .jit.env import xxx". -# Do "from .jit import env as jit_env" and use "jit_env.xxx" instead. -# This helps AOT script to override envs. - -import os import pathlib import re import warnings -from torch.utils.cpp_extension import _get_cuda_arch_flags - FLASHINFER_BASE_DIR = pathlib.Path( os.getenv("FLASHINFER_WORKSPACE_BASE", pathlib.Path.home().as_posix()) ) - FLASHINFER_CACHE_DIR = FLASHINFER_BASE_DIR / ".cache" / "flashinfer" FLASHINFER_CUBIN_DIR = pathlib.Path( os.getenv("FLASHINFER_CUBIN_DIR", (FLASHINFER_CACHE_DIR / "cubins").as_posix()) @@ -38,19 +37,18 @@ def _get_workspace_dir_name() -> pathlib.Path: try: with warnings.catch_warnings(): - # Ignore the warning for TORCH_CUDA_ARCH_LIST not set warnings.filterwarnings( - "ignore", r".*TORCH_CUDA_ARCH_LIST.*", module="torch" + "ignore", ".*TORCH_CUDA_ARCH_LIST.*", module="torch" ) - flags = _get_cuda_arch_flags() - arch = "_".join(sorted(set(re.findall(r"compute_(\d+)", "".join(flags))))) + # TODO: provide a python interface like torch + # flags = torch.utils.cpp_extension._get_cuda_arch_flags() + flags = ['-gencode=arch=compute_90,code=compute_90', '-gencode=arch=compute_90,code=sm_90'] + arch = "_".join(sorted(set(re.findall("compute_(\\d+)", "".join(flags))))) except Exception: arch = "noarch" - # e.g.: $HOME/.cache/flashinfer/75_80_89_90/ return FLASHINFER_CACHE_DIR / arch -# use pathlib FLASHINFER_WORKSPACE_DIR = _get_workspace_dir_name() FLASHINFER_JIT_DIR = FLASHINFER_WORKSPACE_DIR / "cached_ops" FLASHINFER_GEN_SRC_DIR = FLASHINFER_WORKSPACE_DIR / "generated" @@ -58,7 +56,6 @@ def _get_workspace_dir_name() -> pathlib.Path: FLASHINFER_DATA = _package_root / "data" FLASHINFER_INCLUDE_DIR = _package_root / "data" / "include" FLASHINFER_CSRC_DIR = _package_root / "data" / "csrc" -# FLASHINFER_SRC_DIR = _package_root / "data" / "src" FLASHINFER_TVM_BINDING_DIR = _package_root / "data" / "tvm_binding" FLASHINFER_AOT_DIR = _package_root / "data" / "aot" CUTLASS_INCLUDE_DIRS = [ @@ -72,7 +69,6 @@ def get_nvshmem_include_dirs(): paths = os.environ.get("NVSHMEM_INCLUDE_PATH") if paths is not None: return [pathlib.Path(p) for p in paths.split(os.pathsep) if p] - import nvidia.nvshmem path = pathlib.Path(nvidia.nvshmem.__path__[0]) / "include" @@ -83,7 +79,6 @@ def get_nvshmem_lib_dirs(): paths = os.environ.get("NVSHMEM_LIBRARY_PATH") if paths is not None: return [pathlib.Path(p) for p in paths.split(os.pathsep) if p] - import nvidia.nvshmem path = pathlib.Path(nvidia.nvshmem.__path__[0]) / 
"lib" diff --git a/flashinfer/jit/utils.py b/flashinfer/jit/utils.py index 4e19212e14..3a4dd5186f 100644 --- a/flashinfer/jit/utils.py +++ b/flashinfer/jit/utils.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,11 +15,8 @@ See the License for the specific language governing permissions and limitations under the License. """ - import pathlib -import torch - def write_if_different(path: pathlib.Path, content: str) -> None: if path.exists(): @@ -31,53 +30,49 @@ def write_if_different(path: pathlib.Path, content: str) -> None: dtype_map = { - torch.float16: "half", - torch.bfloat16: "nv_bfloat16", - torch.float8_e4m3fn: "__nv_fp8_e4m3", - torch.float8_e5m2: "__nv_fp8_e5m2", - torch.int8: "int8_t", - torch.uint8: "uint8_t", - torch.int32: "int32_t", - torch.uint32: "uint32_t", - torch.int64: "int64_t", - torch.uint64: "uint64_t", + "float16": "half", + "bfloat16": "nv_bfloat16", + paddle.float8_e4m3fn: "__nv_fp8_e4m3", + paddle.float8_e5m2: "__nv_fp8_e5m2", + "int8": "int8_t", + "uint8": "uint8_t", + "int32": "int32_t", + "uint32": "uint32_t", + "int64": "int64_t", + "uint64": "uint64_t", } - dtype_cutlass_map = { - torch.float16: "cutlass::half_t", - torch.bfloat16: "cutlass::bfloat16_t", - torch.float8_e4m3fn: "cutlass::float_e4m3_t", - torch.float8_e5m2: "cutlass::float_e5m2_t", - torch.int8: "cutlass::int8_t", - torch.uint8: "cutlass::uint8_t", - torch.int32: "cutlass::int32_t", - torch.uint32: "cutlass::uint32_t", - torch.int64: "cutlass::int64_t", - torch.uint64: "cutlass::uint64_t", + "float16": "cutlass::half_t", + "bfloat16": "cutlass::bfloat16_t", + paddle.float8_e4m3fn: "cutlass::float_e4m3_t", + paddle.float8_e5m2: "cutlass::float_e5m2_t", + "int8": "cutlass::int8_t", + "uint8": "cutlass::uint8_t", + "int32": "cutlass::int32_t", + "uint32": "cutlass::uint32_t", + "int64": "cutlass::int64_t", + "uint64": "cutlass::uint64_t", } - filename_safe_dtype_map = { - torch.float16: "f16", - torch.bfloat16: "bf16", - torch.float8_e4m3fn: "e4m3", - torch.float8_e5m2: "e5m2", - torch.int8: "i8", - torch.uint8: "u8", - torch.int32: "i32", - torch.uint32: "u32", - torch.int64: "i64", - torch.uint64: "u64", + "float16": "f16", + "bfloat16": "bf16", + paddle.float8_e4m3fn: "e4m3", + paddle.float8_e5m2: "e5m2", + "int8": "i8", + "uint8": "u8", + "int32": "i32", + "uint32": "u32", + "int64": "i64", + "uint64": "u64", } - pos_encoding_mode_literal = { - 0: "PosEncodingMode::kNone", - 1: "PosEncodingMode::kRoPELlama", - 2: "PosEncodingMode::kALiBi", + (0): "PosEncodingMode::kNone", + (1): "PosEncodingMode::kRoPELlama", + (2): "PosEncodingMode::kALiBi", } - mask_mode_literal = { - 0: "MaskMode::kNone", - 1: "MaskMode::kCausal", - 2: "MaskMode::kCustom", - 3: "MaskMode::kMultiItemScoring", + (0): "MaskMode::kNone", + (1): "MaskMode::kCausal", + (2): "MaskMode::kCustom", + (3): "MaskMode::kMultiItemScoring", } diff --git a/flashinfer/logits_processor/__init__.py b/flashinfer/logits_processor/__init__.py index 80e611189a..f2be9cdd85 100644 --- a/flashinfer/logits_processor/__init__.py +++ b/flashinfer/logits_processor/__init__.py @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - from .compiler import CompileError as CompileError from .compiler import Compiler as Compiler from .compiler import compile_pipeline as compile_pipeline diff --git a/flashinfer/logits_processor/compiler.py b/flashinfer/logits_processor/compiler.py index cd74cd9ba5..07044ca9a2 100644 --- a/flashinfer/logits_processor/compiler.py +++ b/flashinfer/logits_processor/compiler.py @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. """ - from typing import List, Optional from .fusion_rules import FusionRule, get_default_fusion_rules from .op import Op from .types import TensorType -from .validators import CompileError, ValidityCheck, get_default_validity_checks +from .validators import (CompileError, ValidityCheck, + get_default_validity_checks) class Compiler: @@ -38,36 +38,24 @@ def register_validity_check(self, check: ValidityCheck) -> None: def compile(self, ops: List[Op]) -> List[Op]: if not ops: raise CompileError("Cannot compile empty operator list") - compiled_ops = list(ops) - self._type_check(compiled_ops) - self._run_validity_checks(compiled_ops) - compiled_ops = self._fuse_all(compiled_ops) - return compiled_ops def _type_check(self, ops: List[Op]) -> None: first_op = ops[0] - current_type = first_op.IN - if current_type not in [TensorType.LOGITS, TensorType.PROBS]: raise CompileError( - f"First operator ({first_op.__class__.__name__}) cannot accept standard pipeline inputs. " - f"Expected LOGITS or PROBS, but operator accepts: {first_op.IN}" + f"First operator ({first_op.__class__.__name__}) cannot accept standard pipeline inputs. Expected LOGITS or PROBS, but operator accepts: {first_op.IN}" ) - for i, op in enumerate(ops): if current_type != op.IN: raise CompileError( - f"Type mismatch at operator {i} ({op.__class__.__name__}). " - f"Expected input type: {current_type}, but operator accepts: {op.IN}. " - f"Previous operator output: {current_type}" + f"Type mismatch at operator {i} ({op.__class__.__name__}). Expected input type: {current_type}, but operator accepts: {op.IN}. 
Previous operator output: {current_type}" ) - current_type = op.OUT def _run_validity_checks(self, ops: List[Op]) -> None: @@ -78,38 +66,29 @@ def _fuse_all(self, ops: List[Op]) -> List[Op]: i = 0 while i < len(ops): fusion_applied = False - for rule in self.fusion_rules: span = len(rule.pattern) - if i + span > len(ops): continue - window = ops[i : i + span] - if self._pattern_matches(window, rule.pattern) and rule.guard(window): fused_op = rule.build(window) ops[i : i + span] = [fused_op] - i = max(i - 1, 0) fusion_applied = True break - if not fusion_applied: i += 1 - return ops def _pattern_matches(self, window: List[Op], pattern: tuple) -> bool: if len(window) != len(pattern): return False - return all(isinstance(window[i], pattern[i]) for i in range(len(pattern))) def _install_defaults(self) -> None: for check in get_default_validity_checks(): self.validity_checks.append(check) - for rule in get_default_fusion_rules(): self.register_fusion_rule(rule) @@ -137,13 +116,10 @@ def compile_pipeline( List of compiled operators """ compiler = Compiler() - if custom_fusion_rules: for rule in custom_fusion_rules: compiler.register_fusion_rule(rule) - if custom_validity_checks: for check in custom_validity_checks: compiler.register_validity_check(check) - return compiler.compile(ops) diff --git a/flashinfer/logits_processor/fusion_rules.py b/flashinfer/logits_processor/fusion_rules.py index 1e5152e61b..7968236483 100644 --- a/flashinfer/logits_processor/fusion_rules.py +++ b/flashinfer/logits_processor/fusion_rules.py @@ -13,23 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. """ - from typing import Callable, List, NamedTuple, Tuple from .op import Op -from .operators import ( - FusedProbsMinPSampleOp, - FusedProbsTopKSampleOp, - FusedProbsTopKTopPSampleOp, - FusedProbsTopPSampleOp, - FusedTemperatureSoftmaxOp, - MinPOp, - ProbsSampleOp, - ProbsTopKOp, - SoftmaxOp, - TemperatureOp, - TopPOp, -) +from .operators import (FusedProbsMinPSampleOp, FusedProbsTopKSampleOp, + FusedProbsTopKTopPSampleOp, FusedProbsTopPSampleOp, + FusedTemperatureSoftmaxOp, MinPOp, ProbsSampleOp, + ProbsTopKOp, SoftmaxOp, TemperatureOp, TopPOp) class FusionRule(NamedTuple): diff --git a/flashinfer/logits_processor/legalization.py b/flashinfer/logits_processor/legalization.py index 05922aef73..487e472563 100644 --- a/flashinfer/logits_processor/legalization.py +++ b/flashinfer/logits_processor/legalization.py @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - from typing import List from .op import Op @@ -45,56 +44,41 @@ def legalize_processors( """ if not processors: raise LegalizationError("Cannot legalize empty processor list") - ops = [] current_type = initial_type - for i, processor in enumerate(processors): try: legalized_ops = processor.legalize(current_type) - if not legalized_ops: raise LegalizationError( f"Processor {processor.__class__.__name__} produced no ops" ) - ops.extend(legalized_ops) - current_type = legalized_ops[-1].OUT - except Exception as e: raise LegalizationError( f"Failed to legalize processor {i} ({processor.__class__.__name__}): {e}" ) from e - return ops def infer_initial_type(processors: List[LogitsProcessor]) -> TensorType: if not processors: return TensorType.LOGITS - first_processor = processors[0] valid_types = _get_supported_types(first_processor) - if len(valid_types) > 1: raise LegalizationError( - f"Cannot infer input type: {first_processor.__class__.__name__} can accept both LOGITS and PROBS. " - f"Please specify input_type explicitly when creating the LogitsPipe." + f"Cannot infer input type: {first_processor.__class__.__name__} can accept both LOGITS and PROBS. Please specify input_type explicitly when creating the LogitsPipe." ) - if len(valid_types) == 1: return valid_types[0] - raise LegalizationError( - f"Processor {first_processor.__class__.__name__} cannot accept standard pipeline inputs " - f"(LOGITS or PROBS)" + f"Processor {first_processor.__class__.__name__} cannot accept standard pipeline inputs (LOGITS or PROBS)" ) -def _get_supported_types( - processor: LogitsProcessor, -) -> List[TensorType]: +def _get_supported_types(processor: LogitsProcessor) -> List[TensorType]: valid_types = [] for tensor_type in [TensorType.LOGITS, TensorType.PROBS]: try: @@ -102,16 +86,13 @@ def _get_supported_types( valid_types.append(tensor_type) except (ValueError, LegalizationError): continue - return valid_types def validate_processor_chain(processors: List[LogitsProcessor]) -> None: if not processors: raise LegalizationError("Processor chain cannot be empty") - initial_type = infer_initial_type(processors) - try: legalize_processors(processors, initial_type) except LegalizationError: diff --git a/flashinfer/logits_processor/op.py b/flashinfer/logits_processor/op.py index 30772b861d..329237a460 100644 --- a/flashinfer/logits_processor/op.py +++ b/flashinfer/logits_processor/op.py @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. """ - from abc import ABC, abstractmethod from typing import Any, Dict @@ -37,8 +36,7 @@ def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: def _validate_input_type(self, tensor: TaggedTensor) -> TensorType: if tensor.type != self.IN: raise ValueError( - f"Operator {self.__class__.__name__} cannot accept input type {tensor.type}. " - f"Expected: {self.IN}" + f"Operator {self.__class__.__name__} cannot accept input type {tensor.type}. Expected: {self.IN}" ) return self.OUT diff --git a/flashinfer/logits_processor/operators.py b/flashinfer/logits_processor/operators.py index c9c13ad76b..1b17b0f3f5 100644 --- a/flashinfer/logits_processor/operators.py +++ b/flashinfer/logits_processor/operators.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,11 +15,8 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - from typing import Any, Optional, Tuple, Union -import torch - from flashinfer.sampling import get_sampling_module from flashinfer.utils import _get_cache_buf, device_support_pdl @@ -26,12 +25,12 @@ def _to_tensor_scalar_tuple( - x: Union[torch.Tensor, float, int], -) -> Tuple[Optional[torch.Tensor], Union[float, int]]: - if isinstance(x, torch.Tensor): - return (x, 0 if x.dtype == torch.int32 else 0.0) + x: Union[paddle.Tensor, float, int] +) -> Tuple[Optional[paddle.Tensor], Union[float, int]]: + if isinstance(x, paddle.Tensor): + return x, 0 if x.dtype == "int32" else 0.0 else: - return (None, x) + return None, x class TemperatureOp(ParameterizedOp): @@ -51,21 +50,17 @@ class TemperatureOp(ParameterizedOp): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - temperature = self._get_param("temperature", kwargs, required=True) maybe_temperature_arr, temperature_val = _to_tensor_scalar_tuple(temperature) if maybe_temperature_arr is None and ( not isinstance(temperature_val, float) or temperature_val <= 0 ): raise ValueError("Temperature must be positive float or a tensor array") - if maybe_temperature_arr is not None: temperature = maybe_temperature_arr else: temperature = temperature_val - scaled_logits = tensor.data / temperature - return TaggedTensor(scaled_logits, output_type) @@ -88,15 +83,12 @@ class SoftmaxOp(ParameterizedOp): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - enable_pdl = self.default_params.get("enable_pdl", None) if enable_pdl is None: - enable_pdl = device_support_pdl(tensor.data.device) - + enable_pdl = device_support_pdl(tensor.data.place) workspace_buffer = _get_cache_buf( - "softmax_workspace", 1024 * 1024, tensor.data.device + "softmax_workspace", 1024 * 1024, tensor.data.place ) - probs = get_sampling_module().softmax( workspace_buffer, tensor.data, None, 1.0, enable_pdl ) @@ -126,19 +118,15 @@ class ProbsTopKOp(ParameterizedOp): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - top_k = self._get_param("top_k", kwargs, required=True) maybe_top_k_arr, top_k_val = _to_tensor_scalar_tuple(top_k) - if maybe_top_k_arr is None and ( not isinstance(top_k_val, int) or top_k_val <= 0 ): raise ValueError("top_k must be a positive integer or a tensor array") - renorm_probs = get_sampling_module().top_k_renorm_probs( tensor.data, maybe_top_k_arr, top_k_val ) - return TaggedTensor(renorm_probs, output_type) @@ -165,15 +153,12 @@ class LogitsTopKOp(ParameterizedOp): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - top_k = self._get_param("top_k", kwargs, required=True) maybe_top_k_arr, top_k_val = _to_tensor_scalar_tuple(top_k) - if maybe_top_k_arr is None and ( not isinstance(top_k_val, int) or top_k_val <= 0 ): raise ValueError("top_k must be a positive integer or a tensor array") - masked_logits = get_sampling_module().top_k_mask_logits( tensor.data, maybe_top_k_arr, top_k_val ) @@ -203,17 +188,13 @@ class TopPOp(ParameterizedOp): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - top_p = self._get_param("top_p", kwargs, required=True) maybe_top_p_arr, top_p_val = _to_tensor_scalar_tuple(top_p) - - if maybe_top_p_arr is None and not (0 < top_p_val <= 1): + if maybe_top_p_arr is None and not 0 < top_p_val 
<= 1: raise ValueError("top_p must be float in (0, 1] or a tensor array") - renorm_probs = get_sampling_module().top_p_renorm_probs( tensor.data, maybe_top_p_arr, top_p_val ) - return TaggedTensor(renorm_probs, output_type) @@ -240,26 +221,23 @@ class MinPOp(ParameterizedOp): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - min_p = self._get_param("min_p", kwargs, required=True) maybe_min_p_arr, min_p_val = _to_tensor_scalar_tuple(min_p) - - if maybe_min_p_arr is None and not (0 < min_p_val <= 1): + if maybe_min_p_arr is None and not 0 < min_p_val <= 1: raise ValueError("min_p must be float in (0, 1] or a tensor array") - if maybe_min_p_arr is not None: - min_p_mask = tensor.data >= ( - maybe_min_p_arr.unsqueeze(-1) * tensor.data.max(dim=-1, keepdim=True)[0] + min_p_mask = ( + tensor.data + >= maybe_min_p_arr.unsqueeze(axis=-1) + * tensor.data.max(dim=-1, keepdim=True)[0] ) else: - min_p_mask = tensor.data >= ( - min_p_val * tensor.data.max(dim=-1, keepdim=True)[0] + min_p_mask = ( + tensor.data >= min_p_val * tensor.data.max(dim=-1, keepdim=True)[0] ) - masked_probs = tensor.data.clone() masked_probs[~min_p_mask] = 0 - probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True) - + probs = masked_probs / masked_probs.sum(axis=-1, keepdim=True) return TaggedTensor(probs, output_type) @@ -290,16 +268,12 @@ class ProbsSampleOp(ParameterizedOp): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - deterministic = self.default_params.get("deterministic", True) - indices = self._get_param("indices", kwargs, required=False) generator = self._get_param("generator", kwargs, required=False) - samples = get_sampling_module().sampling_from_probs( tensor.data, indices, deterministic, generator ) - return TaggedTensor(samples, output_type) @@ -330,20 +304,15 @@ class LogitsSampleOp(ParameterizedOp): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - deterministic = self.default_params.get("deterministic", True) - indices = self._get_param("indices", kwargs, required=False) generator = self._get_param("generator", kwargs, required=False) - samples = get_sampling_module().sampling_from_logits( tensor.data, indices, deterministic, generator ) - return TaggedTensor(samples, output_type) -# Fused operators class FusedTemperatureSoftmaxOp(ParameterizedOp): """ Fused temperature scaling and softmax operator. 
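The fused operator produces the same result as scaling by the temperature and then applying a row-wise softmax. A minimal unfused Paddle reference of that computation (a sketch only, assuming 2-D [batch, vocab] logits and a scalar temperature; the fused kernel above is what actually runs):

import paddle
import paddle.nn.functional as F

def temperature_softmax_reference(logits: paddle.Tensor, temperature: float) -> paddle.Tensor:
    # Unfused reference: scale the logits by 1/temperature, then softmax over the vocab axis.
    return F.softmax(logits / temperature, axis=-1)

probs = temperature_softmax_reference(paddle.randn([2, 32000]), temperature=0.7)
assert bool(paddle.allclose(probs.sum(axis=-1), paddle.ones([2])))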
@@ -370,22 +339,18 @@ def __init__(self, enable_pdl: Optional[bool] = None, **default_params: Any): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - temperature = self._get_param("temperature", kwargs, required=True) maybe_temperature_arr, temperature_val = _to_tensor_scalar_tuple(temperature) if maybe_temperature_arr is None and ( not isinstance(temperature_val, float) or temperature_val <= 0 ): raise ValueError("Temperature must be positive float or a tensor array") - workspace_buffer = _get_cache_buf( - "softmax_workspace", 1024 * 1024, tensor.data.device + "softmax_workspace", 1024 * 1024, tensor.data.place ) - enable_pdl = self.default_params.get("enable_pdl", None) if enable_pdl is None: - enable_pdl = device_support_pdl(tensor.data.device) - + enable_pdl = device_support_pdl(tensor.data.place) probs = get_sampling_module().softmax( workspace_buffer, tensor.data, @@ -393,7 +358,6 @@ def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: temperature_val, enable_pdl, ) - return TaggedTensor(probs, output_type) @@ -429,24 +393,18 @@ def __init__(self, deterministic: bool = True, **default_params: Any): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - deterministic = self.default_params.get("deterministic", True) - top_k = self._get_param("top_k", kwargs, required=True) maybe_top_k_arr, top_k_val = _to_tensor_scalar_tuple(top_k) - if maybe_top_k_arr is None and ( not isinstance(top_k_val, int) or top_k_val <= 0 ): raise ValueError("top_k must be a positive integer or a tensor array") - indices = self._get_param("indices", kwargs, required=False) generator = self._get_param("generator", kwargs, required=False) - samples = get_sampling_module().top_k_sampling_from_probs( tensor.data, indices, maybe_top_k_arr, top_k_val, deterministic, generator ) - return TaggedTensor(samples, output_type) @@ -482,22 +440,16 @@ def __init__(self, deterministic: bool = True, **default_params: Any): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - deterministic = self.default_params.get("deterministic", True) - top_p = self._get_param("top_p", kwargs, required=True) maybe_top_p_arr, top_p_val = _to_tensor_scalar_tuple(top_p) - - if maybe_top_p_arr is None and not (0 < top_p_val <= 1): + if maybe_top_p_arr is None and not 0 < top_p_val <= 1: raise ValueError("top_p must be float in (0, 1] or a tensor array") - indices = self._get_param("indices", kwargs, required=False) generator = self._get_param("generator", kwargs, required=False) - samples = get_sampling_module().top_p_sampling_from_probs( tensor.data, indices, maybe_top_p_arr, top_p_val, deterministic, generator ) - return TaggedTensor(samples, output_type) @@ -533,22 +485,16 @@ def __init__(self, deterministic: bool = True, **default_params: Any): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - deterministic = self.default_params.get("deterministic", True) - min_p = self._get_param("min_p", kwargs, required=True) maybe_min_p_arr, min_p_val = _to_tensor_scalar_tuple(min_p) - - if maybe_min_p_arr is None and not (0 < min_p_val <= 1): + if maybe_min_p_arr is None and not 0 < min_p_val <= 1: raise ValueError("min_p must be float in (0, 1] or a tensor array") - indices = self._get_param("indices", kwargs, required=False) generator = 
self._get_param("generator", kwargs, required=False) - samples = get_sampling_module().min_p_sampling_from_probs( tensor.data, indices, maybe_min_p_arr, min_p_val, deterministic, generator ) - return TaggedTensor(samples, output_type) @@ -586,26 +532,19 @@ def __init__(self, deterministic: bool = True, **default_params: Any): def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: output_type = self._validate_input_type(tensor) - deterministic = self.default_params.get("deterministic", True) - top_k = self._get_param("top_k", kwargs, required=True) maybe_top_k_arr, top_k_val = _to_tensor_scalar_tuple(top_k) - top_p = self._get_param("top_p", kwargs, required=True) maybe_top_p_arr, top_p_val = _to_tensor_scalar_tuple(top_p) - if maybe_top_k_arr is None and ( not isinstance(top_k_val, int) or top_k_val <= 0 ): raise ValueError("top_k must be a positive integer or a tensor array") - - if maybe_top_p_arr is None and not (0 < top_p_val <= 1): + if maybe_top_p_arr is None and not 0 < top_p_val <= 1: raise ValueError("top_p must be float in (0, 1] or a tensor array") - indices = self._get_param("indices", kwargs, required=False) generator = self._get_param("generator", kwargs, required=False) - samples = get_sampling_module().top_k_top_p_sampling_from_probs( tensor.data, indices, @@ -616,5 +555,4 @@ def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor: deterministic, generator, ) - return TaggedTensor(samples, output_type) diff --git a/flashinfer/logits_processor/pipeline.py b/flashinfer/logits_processor/pipeline.py index 5c31fb0809..e76697803d 100644 --- a/flashinfer/logits_processor/pipeline.py +++ b/flashinfer/logits_processor/pipeline.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,15 +15,13 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import logging from typing import Any, List, Optional, Union -import torch - from .compiler import compile_pipeline from .fusion_rules import FusionRule -from .legalization import LegalizationError, infer_initial_type, legalize_processors +from .legalization import (LegalizationError, infer_initial_type, + legalize_processors) from .op import Op from .processors import LogitsProcessor from .types import TaggedTensor, TensorType @@ -101,21 +101,13 @@ def __init__( """ if not processors: raise ValueError("Pipeline cannot be empty") - self.processors = list(processors) - try: - # Step 1: Infer initial input tensor type self._initial_type = input_type or infer_initial_type(self.processors) - - # Step 2: Legalization - convert high-level processors to low-level ops self.ops = legalize_processors(self.processors, self._initial_type) - - # Step 3: Compilation - type check, validate, and fuse ops self.compiled_ops: Optional[List[Op]] = None if compile: self.compile(custom_fusion_rules, custom_validity_checks) - except (LegalizationError, CompileError) as e: raise ValueError(f"Pipeline creation failed: {e}") from e @@ -130,24 +122,20 @@ def __repr__(self) -> str: return f"LogitsPipe([{' -> '.join(processor_names)}], ops=[{' -> '.join(op_names)}], compiled_ops=[{' -> '.join(compiled_op_names)}])" def __call__( - self, x: Union[torch.Tensor, TaggedTensor], **kwargs: Any - ) -> torch.Tensor: + self, x: Union[paddle.Tensor, TaggedTensor], **kwargs: Any + ) -> paddle.Tensor: if self.compiled_ops is None: logger.warning("Pipeline is not compiled, running discrete ops.") ops = self.ops else: ops = self.compiled_ops - if isinstance(x, TaggedTensor): tagged_tensor = x + elif self._initial_type == TensorType.PROBS: + tagged_tensor = TaggedTensor.probs(x) else: - if self._initial_type == TensorType.PROBS: - tagged_tensor = TaggedTensor.probs(x) - else: - tagged_tensor = TaggedTensor.logits(x) - + tagged_tensor = TaggedTensor.logits(x) runtime_kwargs = dict(kwargs) - for i, op in enumerate(ops): try: tagged_tensor = op(tagged_tensor, **runtime_kwargs) @@ -155,7 +143,6 @@ def __call__( raise ValueError( f"Error executing operator {i} ({op.__class__.__name__}): {e}" ) from e - return tagged_tensor.data @property diff --git a/flashinfer/logits_processor/processors.py b/flashinfer/logits_processor/processors.py index b15ae5275d..5e25e010f2 100644 --- a/flashinfer/logits_processor/processors.py +++ b/flashinfer/logits_processor/processors.py @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - from abc import ABC, abstractmethod from typing import Any, List, Optional @@ -130,7 +129,6 @@ def legalize(self, input_type: TensorType) -> List[Op]: raise ValueError( f"Temperature can only be applied to LOGITS, got {input_type}" ) - return [TemperatureOp(**self.params)] @@ -186,7 +184,6 @@ def legalize(self, input_type: TensorType) -> List[Op]: if input_type != TensorType.LOGITS: raise ValueError(f"Softmax can only be applied to LOGITS, got {input_type}") - return [SoftmaxOp(**self.params)] @@ -315,7 +312,6 @@ def legalize(self, input_type: TensorType) -> List[Op]: if input_type != TensorType.PROBS: raise ValueError(f"TopP can only be applied to PROBS, got {input_type}") - return [TopPOp(**self.params)] @@ -363,7 +359,6 @@ def legalize(self, input_type: TensorType) -> List[Op]: if input_type != TensorType.PROBS: raise ValueError(f"MinP can only be applied to PROBS, got {input_type}") - return [MinPOp(**self.params)] diff --git a/flashinfer/logits_processor/types.py b/flashinfer/logits_processor/types.py index 4a353800c2..08912d4998 100644 --- a/flashinfer/logits_processor/types.py +++ b/flashinfer/logits_processor/types.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,15 +15,12 @@ See the License for the specific language governing permissions and limitations under the License. """ - -from __future__ import annotations # python 3.7+ +from __future__ import annotations from dataclasses import dataclass from enum import Enum, auto from typing import Any, Dict, Optional, Tuple, Type -import torch - class TensorType(Enum): """ @@ -61,7 +60,7 @@ class TaggedTensor: - Users typically work with plain tensors; tagging happens automatically """ - data: torch.Tensor + data: paddle.Tensor """ The underlying tensor. """ @@ -71,21 +70,21 @@ class TaggedTensor: """ @staticmethod - def logits(t: torch.Tensor) -> TaggedTensor: + def logits(t: paddle.Tensor) -> TaggedTensor: """ Create a TaggedTensor with type :attr:`TensorType.LOGITS`. """ return TaggedTensor(t, TensorType.LOGITS) @staticmethod - def probs(t: torch.Tensor) -> TaggedTensor: + def probs(t: paddle.Tensor) -> TaggedTensor: """ Create a TaggedTensor with type :attr:`TensorType.PROBS`. """ return TaggedTensor(t, TensorType.PROBS) @staticmethod - def indices(t: torch.Tensor) -> TaggedTensor: + def indices(t: paddle.Tensor) -> TaggedTensor: """ Create a TaggedTensor with type :attr:`TensorType.INDICES`. """ @@ -97,43 +96,40 @@ def __torch_function__( types: Tuple[Type, ...], args: Tuple[Any, ...] = (), kwargs: Optional[Dict[str, Any]] = None, - ) -> torch.Tensor: + ) -> paddle.Tensor: kwargs = kwargs or {} - unwrapped_args = tuple( arg.data if isinstance(arg, TaggedTensor) else arg for arg in args ) - result = func(*unwrapped_args, **kwargs) - return result @property - def shape(self) -> torch.Size: + def shape(self) -> list: """ Get the shape of the underlying tensor. """ - return self.data.shape + return tuple(self.data.shape) @property - def device(self) -> torch.device: + def device(self) -> str: """ Get the device of the underlying tensor. """ - return self.data.device + return self.data.place @property - def dtype(self) -> torch.dtype: + def dtype(self) -> paddle.dtype: """ Get the data type of the underlying tensor. """ return self.data.dtype - def size(self, dim: Optional[int] = None) -> torch.Size | int: + def size(self, dim: Optional[int] = None) -> (list | int): """ Get the size of the underlying tensor. 
""" - return self.data.size(dim) + return self.data.shape[dim] def __repr__(self) -> str: return f"TaggedTensor(type={self.type}, shape={self.shape}, dtype={self.dtype}, device={self.device})" diff --git a/flashinfer/logits_processor/validators.py b/flashinfer/logits_processor/validators.py index 293ec2edcb..3cd4c3183a 100644 --- a/flashinfer/logits_processor/validators.py +++ b/flashinfer/logits_processor/validators.py @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. """ - from typing import Callable, List from .op import Op @@ -41,56 +40,29 @@ def single_softmax_rule(ops: List[Op]) -> None: ) -# Disabled since we allow PROBS inputs to TopP, the input type is already guarded by the compiler -# def topp_after_softmax_rule(ops: List[Op]) -> None: -# """ -# R2: TopP-after-Softmax rule. - -# Every TopP must be preceded (anywhere earlier) by a Softmax. -# """ -# seen_softmax = False - -# for op in ops: -# if isinstance(op, SoftmaxOp): -# seen_softmax = True -# elif isinstance(op, TopPOp) and not seen_softmax: -# raise CompileError( -# "TopP operator requires a preceding Softmax operator. " -# "TopP can only operate on probabilities, not logits." -# ) - - def indices_terminal_rule(ops: List[Op]) -> None: """ R3': Indices-terminal rule. If an operator outputs Indices, no operator may follow it. """ - for i, op in enumerate(ops[:-1]): # Check all but the last operator + for i, op in enumerate(ops[:-1]): if TensorType.INDICES == op.OUT: next_op = ops[i + 1] raise CompileError( - f"No operator may follow one that outputs Indices. " - f"Found {next_op.__class__.__name__} after {op.__class__.__name__} " - f"which outputs Indices." + f"No operator may follow one that outputs Indices. Found {next_op.__class__.__name__} after {op.__class__.__name__} which outputs Indices." ) def get_default_validity_checks() -> List[ValidityCheck]: - return [ - single_softmax_rule, - # topp_after_softmax_rule, - indices_terminal_rule, - ] + return [single_softmax_rule, indices_terminal_rule] def validate_pipeline(ops: List[Op], custom_checks: List[ValidityCheck] = None) -> None: if not ops: raise CompileError("Pipeline cannot be empty") - for check in get_default_validity_checks(): check(ops) - if custom_checks: for check in custom_checks: check(ops) diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 8de451ca68..24fd0cb5f2 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2023 by FlashInfer team. @@ -13,12 +15,9 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import functools from typing import Literal, Optional, Tuple, Union, overload -import torch - from .jit import JitSpec from .jit import env as jit_env from .jit import gen_batch_mla_module, gen_jit_spec, sm100a_nvcc_flags @@ -34,23 +33,23 @@ def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table): raise ValueError(f"Expected kv_len.ndim == 1, got {kv_len.ndim}") if page_table.ndim != 2: raise ValueError(f"Expected page_table.ndim == 2, got {page_table.ndim}") - B_q, H, D_q = q_nope_pe.shape - D_ckv = ckv_kpe_cache.shape[2] + B_q, H, D_q = tuple(q_nope_pe.shape) + D_ckv = tuple(ckv_kpe_cache.shape)[2] if H != 128: raise ValueError(f"Expected 128 heads for q_nope_pe, got {H}") if D_q != D_ckv or D_q != 576: raise ValueError( f"Expected head dim 576 for q_nope_pe and ckv_kpe_cache, got {D_q} and {D_ckv}" ) - B_block_table, block_num = page_table.shape - block_size = ckv_kpe_cache.shape[1] + B_block_table, block_num = tuple(page_table.shape) + block_size = tuple(ckv_kpe_cache.shape)[1] if B_q != B_block_table: raise ValueError( f"Expected batch size {B_q} for q_nope_pe and block_table, got {B_q} and {B_block_table}" ) if block_num % (128 / block_size) != 0: raise ValueError( - f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}" + f"Expected block_num % (128 / block_size) == 0, got block_num={block_num!r} and block_size={block_size!r}" ) @@ -76,7 +75,7 @@ def get_batch_mla_module(backend, *args): class BatchMLAPagedAttentionWrapper: - r"""Wrapper class for MLA (`Multi-head Latent Attention `_) + """Wrapper class for MLA (`Multi-head Latent Attention `_) PagedAttention on DeepSeek models. This kernel can be used in decode, and incremental prefill and should be used together with `Matrix Absorption trick `_: @@ -143,15 +142,15 @@ class BatchMLAPagedAttentionWrapper: def __init__( self, - float_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, use_cuda_graph: bool = False, - qo_indptr: Optional[torch.Tensor] = None, - kv_indptr: Optional[torch.Tensor] = None, - kv_indices: Optional[torch.Tensor] = None, - kv_len_arr: Optional[torch.Tensor] = None, + qo_indptr: Optional[paddle.Tensor] = None, + kv_indptr: Optional[paddle.Tensor] = None, + kv_indices: Optional[paddle.Tensor] = None, + kv_len_arr: Optional[paddle.Tensor] = None, backend: str = "auto", ) -> None: - r"""Constructor for BatchMLAPagedAttentionWrapper. + """Constructor for BatchMLAPagedAttentionWrapper. Parameters ---------- @@ -186,21 +185,17 @@ def __init__( other arguments are ignored. 
""" self._float_workspace_buffer = float_workspace_buffer - self.device = float_workspace_buffer.device - + self.device = float_workspace_buffer.place if backend == "cutlass": self._backend = backend return - - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + self._int_workspace_buffer = paddle.empty( + shape=(8 * 1024 * 1024,), dtype="uint8" ) - self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=tuple(self._int_workspace_buffer.shape), dtype=self._int_workspace_buffer.dtype, - pin_memory=True, - device="cpu", - ) + ).pin_memory() self._use_cuda_graph = use_cuda_graph self._qo_indptr_buf = qo_indptr self._kv_indptr_buf = kv_indptr @@ -213,21 +208,21 @@ def __init__( def plan( self, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, - kv_indices: torch.Tensor, - kv_len_arr: torch.Tensor, + qo_indptr: paddle.Tensor, + kv_indptr: paddle.Tensor, + kv_indices: paddle.Tensor, + kv_len_arr: paddle.Tensor, num_heads: int, head_dim_ckv: int, head_dim_kpe: int, page_size: int, causal: bool, sm_scale: float, - q_data_type: torch.dtype, - kv_data_type: torch.dtype, + q_data_type: paddle.dtype, + kv_data_type: paddle.dtype, use_profiler: bool = False, ) -> None: - r"""Plan the MLA attention computation. + """Plan the MLA attention computation. Parameters ---------- @@ -260,18 +255,16 @@ def plan( use_profiler : bool, optional Whether to enable intra-kernel profiler, default is False. """ - for tensor, name in [ (kv_len_arr, "kv_len_arr"), (kv_indptr, "kv_indptr"), (qo_indptr, "qo_indptr"), (kv_indices, "kv_indices"), ]: - if tensor.dtype != torch.int32: + if tensor.dtype != "int32": raise ValueError( f"Expected {name}.dtype == torch.int32, got {tensor.dtype}" ) - self._cached_module = get_batch_mla_module( self._backend, q_data_type, @@ -285,22 +278,20 @@ def plan( qo_indptr_host = qo_indptr.to("cpu") kv_indptr_host = kv_indptr.to("cpu") kv_len_arr_host = kv_len_arr.to("cpu") - if self._use_cuda_graph: - self._qo_indptr_buf.copy_(qo_indptr, non_blocking=True) - self._kv_indptr_buf.copy_(kv_indptr, non_blocking=True) - self._kv_indices_buf[: len(kv_indices)].copy_(kv_indices, non_blocking=True) - self._kv_len_arr_buf.copy_(kv_len_arr, non_blocking=True) + paddle.assign(qo_indptr, output=self._qo_indptr_buf) + paddle.assign(kv_indptr, output=self._kv_indptr_buf) + paddle.assign(kv_indices, output=self._kv_indices_buf[: len(kv_indices)]) + paddle.assign(kv_len_arr, output=self._kv_len_arr_buf) else: - self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=True) - self._kv_indptr_buf = kv_indptr.to(self.device, non_blocking=True) - self._kv_indices_buf = kv_indices.to(self.device, non_blocking=True) - self._kv_len_arr_buf = kv_len_arr.to(self.device, non_blocking=True) + self._qo_indptr_buf = qo_indptr.to(self.device, blocking=not True) + self._kv_indptr_buf = kv_indptr.to(self.device, blocking=not True) + self._kv_indices_buf = kv_indices.to(self.device, blocking=not True) + self._kv_len_arr_buf = kv_len_arr.to(self.device, blocking=not True) self._causal = causal self._page_size = page_size self._sm_scale = sm_scale self._use_profiler = use_profiler - self._plan_info = self._cached_module.plan.default( self._float_workspace_buffer, self._int_workspace_buffer, @@ -309,54 +300,56 @@ def plan( kv_indptr_host, kv_len_arr_host, num_heads, - head_dim_ckv, # head_dim_o + head_dim_ckv, causal, ) @overload def run( self, - q_nope: torch.Tensor, - 
q_pe: torch.Tensor, - ckv_cache: torch.Tensor, - kpe_cache: torch.Tensor, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + q_nope: paddle.Tensor, + q_pe: paddle.Tensor, + ckv_cache: paddle.Tensor, + kpe_cache: paddle.Tensor, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: Literal[False] = False, - profiler_buffer: Optional[torch.Tensor] = None, - kv_len: Optional[torch.Tensor] = None, - page_table: Optional[torch.Tensor] = None, - ) -> torch.Tensor: ... + profiler_buffer: Optional[paddle.Tensor] = None, + kv_len: Optional[paddle.Tensor] = None, + page_table: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + ... @overload def run( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - ckv_cache: torch.Tensor, - kpe_cache: torch.Tensor, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + q_nope: paddle.Tensor, + q_pe: paddle.Tensor, + ckv_cache: paddle.Tensor, + kpe_cache: paddle.Tensor, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: Literal[True] = True, - profiler_buffer: Optional[torch.Tensor] = None, - kv_len: Optional[torch.Tensor] = None, - page_table: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: ... + profiler_buffer: Optional[paddle.Tensor] = None, + kv_len: Optional[paddle.Tensor] = None, + page_table: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + ... def run( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - ckv_cache: torch.Tensor, - kpe_cache: torch.Tensor, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + q_nope: paddle.Tensor, + q_pe: paddle.Tensor, + ckv_cache: paddle.Tensor, + kpe_cache: paddle.Tensor, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: bool = False, - profiler_buffer: Optional[torch.Tensor] = None, - kv_len: Optional[torch.Tensor] = None, - page_table: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Run the MLA attention computation. + profiler_buffer: Optional[paddle.Tensor] = None, + kv_len: Optional[paddle.Tensor] = None, + page_table: Optional[paddle.Tensor] = None, + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Run the MLA attention computation. 
Parameters ---------- @@ -392,15 +385,15 @@ def run( ) self._cached_module = get_mla_module() if out is None: - out = torch.empty_like(q_nope) + out = paddle.empty_like(x=q_nope) else: check_shape_dtype_device( - out, q_nope.shape, q_nope.dtype, q_nope.device, "out" + out, tuple(q_nope.shape), q_nope.dtype, q_nope.place, "out" ) - q_nope_pe = torch.cat([q_nope, q_pe], dim=-1) - ckv_kpe_cache = torch.cat([ckv_cache, kpe_cache], dim=-1) + q_nope_pe = paddle.concat(x=[q_nope, q_pe], axis=-1) + ckv_kpe_cache = paddle.concat(x=[ckv_cache, kpe_cache], axis=-1) _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table) - lse = torch.empty(0, dtype=torch.float32, device=self.device) + lse = paddle.empty(shape=[0], dtype="float32") self._cached_module.cutlass_mla_paged_attention.default( self._float_workspace_buffer, out, @@ -411,31 +404,29 @@ def run( page_table, ) return out - if profiler_buffer is None: if self._use_profiler: raise ValueError( "Profiler is enabled, profiler_buffer must be provided" ) - num_heads = q_nope.shape[1] + num_heads = tuple(q_nope.shape)[1] page_size = self._page_size sm_scale = self._sm_scale causal = self._causal mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value device = self.device if out is None: - out = torch.empty_like(q_nope) + out = paddle.empty_like(x=q_nope) else: check_shape_dtype_device( - out, q_nope.shape, q_nope.dtype, q_nope.device, "out" + out, tuple(q_nope.shape), q_nope.dtype, q_nope.place, "out" ) - if return_lse: if lse is None: - lse = torch.empty(q_nope.shape[:2], dtype=torch.float32, device=device) + lse = paddle.empty(shape=tuple(q_nope.shape)[:2], dtype="float32") else: check_shape_dtype_device( - lse, q_nope.shape[:2], torch.float32, q_nope.device, "lse" + lse, tuple(q_nope.shape)[:2], "float32", q_nope.place, "lse" ) profiler_args = (profiler_buffer,) if self._use_profiler else () self._cached_module.run.default( @@ -455,5 +446,4 @@ def run( sm_scale, *profiler_args, ) - return (out, lse) if return_lse else out diff --git a/flashinfer/norm.py b/flashinfer/norm.py index 079e46a42f..ae615583dc 100644 --- a/flashinfer/norm.py +++ b/flashinfer/norm.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,12 +15,9 @@ See the License for the specific language governing permissions and limitations under the License. """ - import functools from typing import Optional -import torch - from .jit import JitSpec from .jit import env as jit_env from .jit import gen_jit_spec @@ -41,13 +40,13 @@ def get_norm_module(): def rmsnorm( - input: torch.Tensor, - weight: torch.Tensor, - eps: float = 1e-6, - out: Optional[torch.Tensor] = None, + input: paddle.Tensor, + weight: paddle.Tensor, + eps: float = 1e-06, + out: Optional[paddle.Tensor] = None, enable_pdl: Optional[bool] = None, -) -> torch.Tensor: - r"""Root mean square normalization. +) -> paddle.Tensor: + """Root mean square normalization. ``out[i] = (input[i] / RMS(input)) * weight[i]`` @@ -71,31 +70,31 @@ def rmsnorm( Normalized tensor, shape (batch_size, hidden_size). 
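The formula above can be cross-checked against a plain Paddle implementation. A minimal reference sketch (not the CUDA kernel; assumes a 2-D input and a 1-D weight):

import paddle

def rmsnorm_reference(x: paddle.Tensor, weight: paddle.Tensor, eps: float = 1e-6) -> paddle.Tensor:
    # out[i] = x[i] / RMS(x[i]) * weight[i], with RMS taken over the hidden dimension.
    rms = paddle.sqrt(paddle.mean(x * x, axis=-1, keepdim=True) + eps)
    return x / rms * weight

x = paddle.randn([4, 1024])
w = paddle.ones([1024])
ref = rmsnorm_reference(x, w)  # compare against flashinfer.norm.rmsnorm(x, w) on GPU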
""" if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) if out is None: - out = torch.empty_like(input) + out = paddle.empty_like(x=input) _rmsnorm(out, input, weight, eps, enable_pdl) return out @register_custom_op("flashinfer::rmsnorm", mutates_args=("out",)) def _rmsnorm( - out: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, + out: paddle.Tensor, + input: paddle.Tensor, + weight: paddle.Tensor, eps: float, enable_pdl: Optional[bool], ) -> None: if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) get_norm_module().rmsnorm(out, input, weight, eps, enable_pdl) @register_fake_op("flashinfer::rmsnorm") def _rmsnorm_fake( - out: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, + out: paddle.Tensor, + input: paddle.Tensor, + weight: paddle.Tensor, eps: float, enable_pdl: Optional[bool], ) -> None: @@ -104,13 +103,13 @@ def _rmsnorm_fake( @register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual")) def fused_add_rmsnorm( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - eps: float = 1e-6, + input: paddle.Tensor, + residual: paddle.Tensor, + weight: paddle.Tensor, + eps: float = 1e-06, enable_pdl: Optional[bool] = None, ) -> None: - r"""Fused add root mean square normalization. + """Fused add root mean square normalization. Step 1: ``residual[i] += input[i]`` @@ -133,29 +132,29 @@ def fused_add_rmsnorm( `_ """ if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) get_norm_module().fused_add_rmsnorm(input, residual, weight, eps, enable_pdl) @register_fake_op("flashinfer::fused_add_rmsnorm") def _fused_add_rmsnorm_fake( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - eps: float = 1e-6, + input: paddle.Tensor, + residual: paddle.Tensor, + weight: paddle.Tensor, + eps: float = 1e-06, enable_pdl: Optional[bool] = None, ) -> None: pass def gemma_rmsnorm( - input: torch.Tensor, - weight: torch.Tensor, - eps: float = 1e-6, - out: Optional[torch.Tensor] = None, + input: paddle.Tensor, + weight: paddle.Tensor, + eps: float = 1e-06, + out: Optional[paddle.Tensor] = None, enable_pdl: Optional[bool] = None, -) -> torch.Tensor: - r"""Gemma-style root mean square normalization. +) -> paddle.Tensor: + """Gemma-style root mean square normalization. ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)`` @@ -179,31 +178,31 @@ def gemma_rmsnorm( Gemma Normalized tensor, shape (batch_size, hidden_size). 
""" if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) if out is None: - out = torch.empty_like(input) + out = paddle.empty_like(x=input) _gemma_rmsnorm(out, input, weight, eps, enable_pdl) return out @register_custom_op("flashinfer::gemma_rmsnorm", mutates_args=("out",)) def _gemma_rmsnorm( - out: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, + out: paddle.Tensor, + input: paddle.Tensor, + weight: paddle.Tensor, eps: float, enable_pdl: Optional[bool], ) -> None: if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) get_norm_module().gemma_rmsnorm(out, input, weight, eps, enable_pdl) @register_fake_op("flashinfer::gemma_rmsnorm") def _gemma_rmsnorm_fake( - out: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, + out: paddle.Tensor, + input: paddle.Tensor, + weight: paddle.Tensor, eps: float, enable_pdl: Optional[bool], ) -> None: @@ -214,13 +213,13 @@ def _gemma_rmsnorm_fake( "flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual") ) def gemma_fused_add_rmsnorm( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - eps: float = 1e-6, + input: paddle.Tensor, + residual: paddle.Tensor, + weight: paddle.Tensor, + eps: float = 1e-06, enable_pdl: Optional[bool] = None, ) -> None: - r"""Gemma-style fused add root mean square normalization. + """Gemma-style fused add root mean square normalization. Step 1: ``residual[i] += input[i]`` @@ -243,16 +242,16 @@ def gemma_fused_add_rmsnorm( `_ """ if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) get_norm_module().gemma_fused_add_rmsnorm(input, residual, weight, eps, enable_pdl) @register_fake_op("flashinfer::gemma_fused_add_rmsnorm") def _gemma_fused_add_rmsnorm_fake( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - eps: float = 1e-6, + input: paddle.Tensor, + residual: paddle.Tensor, + weight: paddle.Tensor, + eps: float = 1e-06, enable_pdl: Optional[bool] = None, ) -> None: pass diff --git a/flashinfer/paddle_utils.py b/flashinfer/paddle_utils.py new file mode 100644 index 0000000000..9a87c83740 --- /dev/null +++ b/flashinfer/paddle_utils.py @@ -0,0 +1,132 @@ + +import paddle + +############################## 相关utils函数,如下 ############################## +############################ PaConvert 自动生成的代码 ########################### + +def device2str(type=None, index=None, *, device=None): + type = device if device else type + if isinstance(type, int): + type = f'gpu:{type}' + elif isinstance(type, str): + if 'cuda' in type: + type = type.replace('cuda', 'gpu') + if 'cpu' in type: + type = 'cpu' + elif index is not None: + type = f'{type}:{index}' + elif isinstance(type, paddle.CPUPlace) or (type is None): + type = 'cpu' + elif isinstance(type, paddle.CUDAPlace): + type = f'gpu:{type.get_device_id()}' + + return type + +def _Tensor_view(self, *args, **kwargs): + if args: + if len(args)==1 and isinstance(args[0], (tuple, list, str)): + return paddle.view(self, args[0]) + else: + return paddle.view(self, list(args)) + elif kwargs: + return paddle.view(self, shape_or_dtype = list(kwargs.values())[0]) + +setattr(paddle.Tensor, 'view', _Tensor_view) + +def _Tensor_reshape(self, *args, **kwargs): + if args: + if len(args) == 1 and isinstance(args[0], (tuple, list)): + return paddle.reshape(self, args[0]) + else: + return paddle.reshape(self, list(args)) + elif kwargs: + assert 
"shape" in kwargs + return paddle.reshape(self, shape=kwargs["shape"]) + +setattr(paddle.Tensor, "reshape", _Tensor_reshape) + +def dim2perm(ndim, dim0, dim1): + perm = list(range(ndim)) + perm[dim0], perm[dim1] = perm[dim1], perm[dim0] + return perm + +def _Tensor_max(self, *args, **kwargs): + if "other" in kwargs: + kwargs["y"] = kwargs.pop("other") + ret = paddle.maximum(self, *args, **kwargs) + elif len(args) == 1 and isinstance(args[0], paddle.Tensor): + ret = paddle.maximum(self, *args, **kwargs) + else: + if "dim" in kwargs: + kwargs["axis"] = kwargs.pop("dim") + + if "axis" in kwargs or len(args) >= 1: + ret = paddle.max(self, *args, **kwargs), paddle.argmax(self, *args, **kwargs) + else: + ret = paddle.max(self, *args, **kwargs) + + return ret + +setattr(paddle.Tensor, "_max", _Tensor_max) + +def _Tensor_split(self, split_size, dim=0): + if isinstance(split_size, int): + return paddle.split(self, self.shape[dim] // split_size, dim) + else: + return paddle.split(self, split_size, dim) + +setattr(paddle.Tensor, "split", _Tensor_split) + +def _Tensor_add(self, *args, **kwargs): + if "other" in kwargs: + y = kwargs["other"] + elif "y" in kwargs: + y = kwargs["y"] + else: + y = args[0] + if "alpha" in kwargs: + alpha = kwargs["alpha"] + if alpha != 1: + if not isinstance(y, paddle.Tensor): + y = paddle.to_tensor(alpha * y) + else: + y = alpha * y + else: + if not isinstance(y, paddle.Tensor): + y = paddle.to_tensor(y) + return paddle.add(self, y) + +setattr(paddle.Tensor, "add", _Tensor_add) + +def device2int(device): + if isinstance(device, str): + device = device.replace('cuda', 'gpu') + device = device.replace('gpu:', '') + return int(device) + +def paddle_split(x, num_or_sections, axis=0): + if isinstance(num_or_sections, int): + return paddle.split(x, x.shape[axis]//num_or_sections, axis) + else: + return paddle.split(x, num_or_sections, axis) + +def _Tensor_min(self, *args, **kwargs): + if "other" in kwargs: + kwargs["y"] = kwargs.pop("other") + ret = paddle.minimum(self, *args, **kwargs) + elif len(args) == 1 and isinstance(args[0], paddle.Tensor): + ret = paddle.minimum(self, *args, **kwargs) + else: + if "dim" in kwargs: + kwargs["axis"] = kwargs.pop("dim") + + if "axis" in kwargs or len(args) >= 1: + ret = paddle.min(self, *args, **kwargs), paddle.argmin(self, *args, **kwargs) + else: + ret = paddle.min(self, *args, **kwargs) + + return ret + +setattr(paddle.Tensor, "_min", _Tensor_min) +############################## 相关utils函数,如上 ############################## + diff --git a/flashinfer/page.py b/flashinfer/page.py index b141ae4211..720244af7e 100644 --- a/flashinfer/page.py +++ b/flashinfer/page.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2023 by FlashInfer team. @@ -13,22 +15,14 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import functools from typing import Optional, Tuple, Union -import torch - from .jit import JitSpec from .jit import env as jit_env from .jit import gen_jit_spec -from .utils import ( - TensorLayout, - _check_kv_layout, - _unpack_paged_kv_cache, - register_custom_op, - register_fake_op, -) +from .utils import (TensorLayout, _check_kv_layout, _unpack_paged_kv_cache, + register_custom_op, register_fake_op) def gen_page_module() -> JitSpec: @@ -47,27 +41,26 @@ def get_page_module(): def block_sparse_indices_to_vector_sparse_offsets( - block_sparse_indices: torch.Tensor, - block_sparse_indptr: torch.Tensor, - vector_sparse_offsets: torch.Tensor, - vector_sparse_indptr: torch.Tensor, - kv_lens: torch.Tensor, + block_sparse_indices: paddle.Tensor, + block_sparse_indptr: paddle.Tensor, + vector_sparse_offsets: paddle.Tensor, + vector_sparse_indptr: paddle.Tensor, + kv_lens: paddle.Tensor, stride_block: int, stride_n: int, block_size: int, -) -> torch.Tensor: +) -> paddle.Tensor: if block_size == 1: if stride_block == 1: return block_sparse_indices else: return block_sparse_indices * stride_block - - assert block_sparse_indices.dtype == torch.int32 - assert block_sparse_indptr.dtype == torch.int32 - assert vector_sparse_offsets.dtype == torch.int32 - assert vector_sparse_indptr.dtype == torch.int32 - assert kv_lens.dtype == torch.int32 - batch_size = block_sparse_indptr.size(0) - 1 + assert block_sparse_indices.dtype == "int32" + assert block_sparse_indptr.dtype == "int32" + assert vector_sparse_offsets.dtype == "int32" + assert vector_sparse_indptr.dtype == "int32" + assert kv_lens.dtype == "int32" + batch_size = block_sparse_indptr.shape[0] - 1 get_page_module().block_sparse_indices_to_vector_sparse_offsets( block_sparse_indices, block_sparse_indptr, @@ -83,25 +76,24 @@ def block_sparse_indices_to_vector_sparse_offsets( @register_custom_op( - "flashinfer::append_paged_mla_kv_cache", - mutates_args=("ckv_cache", "kpe_cache"), + "flashinfer::append_paged_mla_kv_cache", mutates_args=("ckv_cache", "kpe_cache") ) def _append_paged_mla_kv_cache_kernel( - append_ckv: torch.Tensor, - append_kpe: torch.Tensor, - batch_indices: torch.Tensor, - positions: torch.Tensor, - ckv_cache: Optional[torch.Tensor], - kpe_cache: Optional[torch.Tensor], - kv_indices: torch.Tensor, - kv_indptr: torch.Tensor, - kv_last_page_len: torch.Tensor, + append_ckv: paddle.Tensor, + append_kpe: paddle.Tensor, + batch_indices: paddle.Tensor, + positions: paddle.Tensor, + ckv_cache: Optional[paddle.Tensor], + kpe_cache: Optional[paddle.Tensor], + kv_indices: paddle.Tensor, + kv_indptr: paddle.Tensor, + kv_last_page_len: paddle.Tensor, ) -> None: - batch_indices = batch_indices.int() - positions = positions.int() - kv_indices = kv_indices.int() - kv_indptr = kv_indptr.int() - kv_last_page_len = kv_last_page_len.int() + batch_indices = batch_indices.astype(dtype="int32") + positions = positions.astype(dtype="int32") + kv_indices = kv_indices.astype(dtype="int32") + kv_indptr = kv_indptr.astype(dtype="int32") + kv_last_page_len = kv_last_page_len.astype(dtype="int32") get_page_module().append_paged_mla_kv_cache( append_ckv, append_kpe, @@ -116,26 +108,25 @@ def _append_paged_mla_kv_cache_kernel( @register_custom_op( - "flashinfer::append_paged_kv_cache", - mutates_args=("paged_k_cache", "paged_v_cache"), + "flashinfer::append_paged_kv_cache", mutates_args=("paged_k_cache", "paged_v_cache") ) def _append_paged_kv_cache_kernel( - append_key: torch.Tensor, - append_value: torch.Tensor, - batch_indices: torch.Tensor, - positions: 
torch.Tensor, - paged_k_cache: Optional[torch.Tensor], - paged_v_cache: Optional[torch.Tensor], - kv_indices: torch.Tensor, - kv_indptr: torch.Tensor, - kv_last_page_len: torch.Tensor, + append_key: paddle.Tensor, + append_value: paddle.Tensor, + batch_indices: paddle.Tensor, + positions: paddle.Tensor, + paged_k_cache: Optional[paddle.Tensor], + paged_v_cache: Optional[paddle.Tensor], + kv_indices: paddle.Tensor, + kv_indptr: paddle.Tensor, + kv_last_page_len: paddle.Tensor, layout: int, ) -> None: - batch_indices = batch_indices.int() - positions = positions.int() - kv_indices = kv_indices.int() - kv_indptr = kv_indptr.int() - kv_last_page_len = kv_last_page_len.int() + batch_indices = batch_indices.astype(dtype="int32") + positions = positions.astype(dtype="int32") + kv_indices = kv_indices.astype(dtype="int32") + kv_indptr = kv_indptr.astype(dtype="int32") + kv_last_page_len = kv_last_page_len.astype(dtype="int32") get_page_module().append_paged_kv_cache( append_key, append_value, @@ -152,24 +143,24 @@ def _append_paged_kv_cache_kernel( @register_fake_op("flashinfer::append_paged_kv_cache") def _fake_append_paged_kv_cache_kernel( - append_key: torch.Tensor, - append_value: torch.Tensor, - batch_indices: torch.Tensor, - positions: torch.Tensor, - paged_k_cache: Optional[torch.Tensor], - paged_v_cache: Optional[torch.Tensor], - kv_indices: torch.Tensor, - kv_indptr: torch.Tensor, - kv_last_page_len: torch.Tensor, + append_key: paddle.Tensor, + append_value: paddle.Tensor, + batch_indices: paddle.Tensor, + positions: paddle.Tensor, + paged_k_cache: Optional[paddle.Tensor], + paged_v_cache: Optional[paddle.Tensor], + kv_indices: paddle.Tensor, + kv_indptr: paddle.Tensor, + kv_last_page_len: paddle.Tensor, layout: int, ) -> None: pass def get_batch_indices_positions( - append_indptr: torch.Tensor, seq_lens: torch.Tensor, nnz: int -) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Convert append indptr and sequence lengths to batch indices and positions. + append_indptr: paddle.Tensor, seq_lens: paddle.Tensor, nnz: int +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Convert append indptr and sequence lengths to batch indices and positions. Parameters ---------- @@ -210,21 +201,21 @@ def get_batch_indices_positions( -------- append_paged_kv_cache """ - batch_size = append_indptr.size(0) - 1 - batch_indices = torch.empty((nnz,), device=append_indptr.device, dtype=torch.int32) - positions = torch.empty((nnz,), device=append_indptr.device, dtype=torch.int32) + batch_size = append_indptr.shape[0] - 1 + batch_indices = paddle.empty(shape=(nnz,), dtype="int32") + positions = paddle.empty(shape=(nnz,), dtype="int32") from .triton.page import get_batch_indices_positions_kernel - get_batch_indices_positions_kernel[(batch_size,)]( - append_indptr, seq_lens, batch_indices, positions, num_stages=2 - ) + get_batch_indices_positions_kernel[ + batch_size, + ](append_indptr, seq_lens, batch_indices, positions, num_stages=2) return batch_indices, positions def get_seq_lens( - kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor, page_size: int -) -> torch.Tensor: - r"""Convert KV indptr and last page length to sequence lengths. + kv_indptr: paddle.Tensor, kv_last_page_len: paddle.Tensor, page_size: int +) -> paddle.Tensor: + """Convert KV indptr and last page length to sequence lengths. Parameters ---------- @@ -242,23 +233,23 @@ def get_seq_lens( The sequence lengths of each request in the paged kv-cache, shape: ``[batch_size]``. 
""" return ( - torch.clamp(kv_indptr[1:] - kv_indptr[:-1] - 1, min=0) * page_size + paddle.clip(x=kv_indptr[1:] - kv_indptr[:-1] - 1, min=0) * page_size + kv_last_page_len ) def append_paged_mla_kv_cache( - append_ckv: torch.Tensor, - append_kpe: torch.Tensor, - batch_indices: torch.Tensor, - positions: torch.Tensor, - ckv_cache: Optional[torch.Tensor], - kpe_cache: Optional[torch.Tensor], - kv_indices: torch.Tensor, - kv_indptr: torch.Tensor, - kv_last_page_len: torch.Tensor, + append_ckv: paddle.Tensor, + append_kpe: paddle.Tensor, + batch_indices: paddle.Tensor, + positions: paddle.Tensor, + ckv_cache: Optional[paddle.Tensor], + kpe_cache: Optional[paddle.Tensor], + kv_indices: paddle.Tensor, + kv_indptr: paddle.Tensor, + kv_last_page_len: paddle.Tensor, ) -> None: - r"""Append a batch of key-value pairs to a paged key-value cache, + """Append a batch of key-value pairs to a paged key-value cache, Note: current only support ckv=512 and kpe=64 Parameters @@ -297,17 +288,17 @@ def append_paged_mla_kv_cache( def append_paged_kv_cache( - append_key: torch.Tensor, - append_value: torch.Tensor, - batch_indices: torch.Tensor, - positions: torch.Tensor, - paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - kv_indices: torch.Tensor, - kv_indptr: torch.Tensor, - kv_last_page_len: torch.Tensor, + append_key: paddle.Tensor, + append_value: paddle.Tensor, + batch_indices: paddle.Tensor, + positions: paddle.Tensor, + paged_kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], + kv_indices: paddle.Tensor, + kv_indptr: paddle.Tensor, + kv_last_page_len: paddle.Tensor, kv_layout: str = "NHD", ) -> None: - r"""Append a batch of key-value pairs to a paged key-value cache. + """Append a batch of key-value pairs to a paged key-value cache. Parameters ---------- @@ -421,5 +412,5 @@ def append_paged_kv_cache( kv_indices, kv_indptr, kv_last_page_len, - TensorLayout[kv_layout].value, + TensorLayout[kv_layout].value ) diff --git a/flashinfer/pod.py b/flashinfer/pod.py index 410077f403..f4fc3c8f88 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2023 by FlashInfer team. @@ -13,32 +19,20 @@ See the License for the specific language governing permissions and limitations under the License. """ - import functools import math from types import SimpleNamespace from typing import Any, List, Optional, Tuple, Union -import torch - from .jit import gen_pod_module from .page import get_seq_lens from .prefill import get_batch_prefill_module from .quantization import packbits -from .utils import ( - MaskMode, - PosEncodingMode, - TensorLayout, - _check_cached_qkv_data_type, - _check_kv_layout, - _check_pos_encoding_mode, - _get_cache_alibi_slopes_buf, - _get_cache_buf, - _get_range_buf, - _unpack_paged_kv_cache, - canonicalize_torch_dtype, - device_support_pdl, -) +from .utils import (MaskMode, PosEncodingMode, TensorLayout, + _check_cached_qkv_data_type, _check_kv_layout, + _check_pos_encoding_mode, _get_cache_alibi_slopes_buf, + _get_cache_buf, _get_range_buf, _unpack_paged_kv_cache, + canonicalize_torch_dtype, device_support_pdl) @functools.cache @@ -48,7 +42,7 @@ def get_pod_module(*args): class PODWithPagedKVCacheWrapper: - r"""Wrapper class for POD-Attention with paged kv-cache (first proposed in + """Wrapper class for POD-Attention with paged kv-cache (first proposed in ``_) for batch of requests. Check :ref:`our tutorial` for page table layout. 
@@ -115,15 +109,15 @@ class PODWithPagedKVCacheWrapper: def __init__( self, - float_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, - paged_kv_indptr_buffer: Optional[torch.Tensor] = None, - paged_kv_indices_buffer: Optional[torch.Tensor] = None, - paged_kv_last_page_len_buffer: Optional[torch.Tensor] = None, + paged_kv_indptr_buffer: Optional[paddle.Tensor] = None, + paged_kv_indices_buffer: Optional[paddle.Tensor] = None, + paged_kv_last_page_len_buffer: Optional[paddle.Tensor] = None, jit_args: Optional[List[Any]] = None, ) -> None: - r"""Constructor of :class:`PODWithPagedKVCacheWrapper`. + """Constructor of :class:`PODWithPagedKVCacheWrapper`. Parameters ---------- @@ -173,33 +167,27 @@ def __init__( ) else: """ - # Override options. Only tensor core version is performant. use_tensor_cores = True self._jit_module: SimpleNamespace = None - self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer - self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + self.device = float_workspace_buffer.place + self._int_workspace_buffer = paddle.empty( + shape=(8 * 1024 * 1024,), dtype="uint8" ) - self._pin_memory_int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), - dtype=torch.uint8, - pin_memory=True, - device="cpu", - ) - + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=(8 * 1024 * 1024,), dtype="uint8" + ).pin_memory() if use_cuda_graph: - if not torch.is_tensor(paged_kv_indptr_buffer): + if not paddle.is_tensor(x=paged_kv_indptr_buffer): raise ValueError( "paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_indices_buffer): + if not paddle.is_tensor(x=paged_kv_indices_buffer): raise ValueError( "paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode" ) - if not torch.is_tensor(paged_kv_last_page_len_buffer): + if not paddle.is_tensor(x=paged_kv_last_page_len_buffer): raise ValueError( "paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode" ) @@ -210,19 +198,14 @@ def __init__( ) else: self._fixed_batch_size = 0 - self._paged_kv_indptr_buf = paged_kv_indptr_buffer self._paged_kv_indices_buf = paged_kv_indices_buffer self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer self._use_tensor_cores = use_tensor_cores self._use_cuda_graph = use_cuda_graph - if use_cuda_graph: - # NOTE(Zihao): if once created, no need to update it in plan/run - self._qo_indptr_buf = torch.arange( - self._fixed_batch_size + 1, - dtype=torch.int32, - device=float_workspace_buffer.device, + self._qo_indptr_buf = paddle.arange( + dtype="int32", end=self._fixed_batch_size + 1 ) @property @@ -230,9 +213,9 @@ def is_cuda_graph_enabled(self) -> bool: return self._use_cuda_graph def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor + self, float_workspace_buffer: paddle.Tensor, int_workspace_buffer: paddle.Tensor ) -> None: - r"""Reset the workspace buffer. + """Reset the workspace buffer. 
Parameters ---------- @@ -246,33 +229,31 @@ def reset_workspace_buffer( """ self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=tuple(self._int_workspace_buffer.shape), dtype=self._int_workspace_buffer.dtype, - device="cpu", - pin_memory=True, - ) + ).pin_memory() def plan( self, - indptr: torch.Tensor, - indices: torch.Tensor, - last_page_len: torch.Tensor, + indptr: paddle.Tensor, + indices: paddle.Tensor, + last_page_len: paddle.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, pos_encoding_mode: str = "NONE", window_left: int = -1, - q_data_type: Optional[Union[str, torch.dtype]] = "float16", - kv_data_type: Optional[Union[str, torch.dtype]] = None, - data_type: Optional[Union[str, torch.dtype]] = None, + q_data_type: Optional[Union[str, paddle.dtype]] = "float16", + kv_data_type: Optional[Union[str, paddle.dtype]] = None, + data_type: Optional[Union[str, paddle.dtype]] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, non_blocking: bool = True, ) -> None: - r"""Plan POD's batch decode for given problem specification. + """Plan POD's batch decode for given problem specification. Parameters ---------- @@ -322,16 +303,13 @@ def plan( The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``. """ - # Logits soft cap is not supported currently batch_size = len(last_page_len) logits_soft_cap = 0.0 - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") if self.is_cuda_graph_enabled: if batch_size != self._fixed_batch_size: raise ValueError( - "The batch size should be fixed in cudagraph mode, the runtime batch size {} " - " mismatches the batch size set during initialization {}".format( + "The batch size should be fixed in cudagraph mode, the runtime batch size {} mismatches the batch size set during initialization {}".format( batch_size, self._fixed_batch_size ) ) @@ -339,41 +317,33 @@ def plan( raise ValueError( "The size of indices should be less than or equal to the allocated buffer" ) - self._paged_kv_indptr_buf.copy_(indptr, non_blocking=non_blocking) - self._paged_kv_last_page_len_buf.copy_( - last_page_len, non_blocking=non_blocking - ) - self._paged_kv_indices_buf[: len(indices)].copy_( - indices, non_blocking=(indices.device == self.device) and non_blocking - ) + paddle.assign(indptr, output=self._paged_kv_indptr_buf) + paddle.assign(last_page_len, output=self._paged_kv_last_page_len_buf) + paddle.assign(indices, output=self._paged_kv_indices_buf[: len(indices)]) else: self._paged_kv_indptr_buf = indptr.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) self._paged_kv_indices_buf = indices.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) self._paged_kv_last_page_len_buf = last_page_len.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) self._qo_indptr_buf = qo_indptr_host.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) - indptr_host = indptr.to("cpu") last_page_len_host = last_page_len.to("cpu") - if data_type is not None: if q_data_type is None: q_data_type = data_type if kv_data_type is None: kv_data_type = data_type - q_data_type = canonicalize_torch_dtype(q_data_type) if kv_data_type is None: kv_data_type = 
q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) - self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size) @@ -386,12 +356,12 @@ def plan( kv_data_type, q_data_type, indptr.dtype, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim, + head_dim, PosEncodingMode[pos_encoding_mode].value, - window_left != -1, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap - False, # use_fp16_qk_reduction + window_left != -1, + logits_soft_cap > 0, + False, ) self._plan_info = self._cached_module.plan( self._float_workspace_buffer, @@ -400,7 +370,7 @@ def plan( qo_indptr_host, indptr_host, kv_lens_arr_host, - batch_size, # total_num_rows + batch_size, batch_size, num_qo_heads, num_kv_heads, @@ -408,9 +378,8 @@ def plan( self.is_cuda_graph_enabled, head_dim, head_dim, - False, # causal + False, ) - self._indptr_type = indptr.dtype self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left @@ -423,15 +392,13 @@ def plan( def run( self, - # Main params (prefill and decode) - q_p: torch.Tensor, - k_p: torch.Tensor, - v_p: torch.Tensor, - q_d: torch.Tensor, - paged_kv_cache_d: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - # Prefill options - custom_mask_p: Optional[torch.Tensor] = None, - packed_custom_mask_p: Optional[torch.Tensor] = None, + q_p: paddle.Tensor, + k_p: paddle.Tensor, + v_p: paddle.Tensor, + q_d: paddle.Tensor, + paged_kv_cache_d: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], + custom_mask_p: Optional[paddle.Tensor] = None, + packed_custom_mask_p: Optional[paddle.Tensor] = None, causal_p: bool = False, kv_layout_p: str = "NHD", pos_encoding_mode_p: str = "NONE", @@ -440,9 +407,8 @@ def run( rope_scale_p: Optional[float] = None, rope_theta_p: Optional[float] = None, return_lse_p: bool = False, - # Decode options - custom_mask_d: Optional[torch.Tensor] = None, - packed_custom_mask_d: Optional[torch.Tensor] = None, + custom_mask_d: Optional[paddle.Tensor] = None, + packed_custom_mask_d: Optional[paddle.Tensor] = None, causal_d: bool = False, kv_layout_d: str = "NHD", pos_encoding_mode_d: str = "NONE", @@ -456,55 +422,42 @@ def run( return_lse_d: bool = False, use_fp16_qk_reduction: bool = False, enable_pdl: Optional[bool] = None, - *args, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Compute POD-attention for a batch of requests.""" + *args + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Compute POD-attention for a batch of requests.""" if enable_pdl is None: - enable_pdl = device_support_pdl(q_p.device) - - # Currently unsupported + enable_pdl = device_support_pdl(q_p.place) logits_soft_cap_p = None logits_soft_cap_d = None - # Prefill setup _check_pos_encoding_mode(pos_encoding_mode_p) _check_kv_layout(kv_layout_p) - tmp_p = _get_cache_buf("pod_with_kv_cache_tmp", 32 * 1024 * 1024, q_p.device) + tmp_p = _get_cache_buf("pod_with_kv_cache_tmp", 32 * 1024 * 1024, q_p.place) if logits_soft_cap_p is None: logits_soft_cap_p = 0.0 if sm_scale_p is None: - sm_scale_p = 1.0 / math.sqrt(q_p.size(-1)) + sm_scale_p = 1.0 / math.sqrt(q_p.shape[-1]) if rope_scale_p is None: rope_scale_p = 1.0 if rope_theta_p is None: - rope_theta_p = 1e4 + rope_theta_p = 10000.0 if custom_mask_p is not None and packed_custom_mask_p is None: - # create packed custom mask from custom mask packed_custom_mask_p = packbits( custom_mask_p.contiguous().view(-1), bitorder="little" ) - if packed_custom_mask_p 
is not None: mask_mode_p = MaskMode.CUSTOM.value + elif causal_p: + mask_mode_p = MaskMode.CAUSAL.value else: - if causal_p: - mask_mode_p = MaskMode.CAUSAL.value - else: - mask_mode_p = MaskMode.NON_CAUSAL.value - + mask_mode_p = MaskMode.NON_CAUSAL.value lse_p = None if return_lse_p: - lse_p = torch.empty( - (q_p.size(0), q_p.size(1)), dtype=torch.float32, device=q_p.device - ) - - out_p = torch.empty_like(q_p) - - # Decode setup + lse_p = paddle.empty(shape=(q_p.shape[0], q_p.shape[1]), dtype="float32") + out_p = paddle.empty_like(x=q_p) k_cache_d, v_cache_d = _unpack_paged_kv_cache(paged_kv_cache_d, self._kv_layout) _check_cached_qkv_data_type( q_d, k_cache_d, self._cached_q_data_type, self._cached_kv_data_type ) - # TODO_AK: Where are these coming from? pos_encoding_mode_d = self._pos_encoding_mode window_left_d = self._window_left logits_soft_cap_d = self._logits_soft_cap @@ -512,11 +465,10 @@ def run( rope_scale_d = self._rope_scale rope_theta_d = self._rope_theta _check_pos_encoding_mode(pos_encoding_mode_d) - # What are the above for and what are the below? if logits_soft_cap_d is None: logits_soft_cap_d = 0.0 if sm_scale_d is None: - head_dim = q_d.shape[-1] + head_dim = tuple(q_d.shape)[-1] sm_scale_d = 1.0 / math.sqrt(head_dim) if q_scale is not None: sm_scale_d *= q_scale @@ -525,38 +477,26 @@ def run( if rope_scale_d is None: rope_scale_d = 1.0 if rope_theta_d is None: - rope_theta_d = 1e4 - + rope_theta_d = 10000.0 lse_d = None if return_lse_d: - lse_d = torch.empty( - (q_d.size(0), q_d.size(1)), dtype=torch.float32, device=q_d.device - ) - out_d = torch.empty_like(q_d) - + lse_d = paddle.empty(shape=(q_d.shape[0], q_d.shape[1]), dtype="float32") + out_d = paddle.empty_like(x=q_d) module_getter = get_pod_module( - # Prefill params q_p.dtype, k_p.dtype, q_p.dtype, - q_p.shape[-1], + tuple(q_p.shape)[-1], PosEncodingMode[pos_encoding_mode_p].value, - window_left_p >= 0, # use_sliding_window - logits_soft_cap_p > 0, # use_logits_soft_cap + window_left_p >= 0, + logits_soft_cap_p > 0, use_fp16_qk_reduction, - # Decode params - # q_d.dtype, - # self._cached_kv_data_type, - # self._cached_q_data_type, self._indptr_type, - # head_dim, # head_dim_qk - # head_dim, # head_dim_vo PosEncodingMode[pos_encoding_mode_d].value, - window_left_d != -1, # use_sliding_window - logits_soft_cap_d > 0, # use_logits_soft_cap + window_left_d != -1, + logits_soft_cap_d > 0, ) module_getter.run_tensor( - # Prefill params q_p, k_p, v_p, @@ -567,12 +507,11 @@ def run( TensorLayout[kv_layout_p].value, window_left_p, packed_custom_mask_p, - _get_cache_alibi_slopes_buf(q_p.shape[1], q_p.device), + _get_cache_alibi_slopes_buf(tuple(q_p.shape)[1], q_p.place), logits_soft_cap_p, sm_scale_p, 1.0 / rope_scale_p, 1.0 / rope_theta_p, - # Decode params self._float_workspace_buffer, self._int_workspace_buffer, self._plan_info, @@ -588,21 +527,19 @@ def run( MaskMode.NON_CAUSAL.value, TensorLayout[self._kv_layout].value, window_left_d, - None, # packed_custom_mask - None, # mask_indptr_buf - _get_cache_alibi_slopes_buf(q_d.shape[1], q_d.device), + None, + None, + _get_cache_alibi_slopes_buf(tuple(q_d.shape)[1], q_d.place), logits_soft_cap_d, sm_scale_d, 1.0 / rope_scale_d, 1.0 / rope_theta_d, enable_pdl, ) - if v_scale is not None: out_d *= v_scale - - return (out_p, out_d) + return out_p, out_d def end_forward(self) -> None: - r"""Warning: this function is deprecated and has no effect.""" + """Warning: this function is deprecated and has no effect.""" pass diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py 
index 3805d3a546..d8cc039358 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2023 by FlashInfer team. @@ -13,59 +19,36 @@ See the License for the specific language governing permissions and limitations under the License. """ - import functools import logging import math from types import SimpleNamespace from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload -import torch - -from .jit import ( - gen_batch_prefill_module, - gen_customize_batch_prefill_module, - gen_fmha_cutlass_sm100a_module, - gen_single_prefill_module, - get_batch_prefill_uri, - get_single_prefill_uri, - setup_cubin_loader, - trtllm_gen_fmha_module, -) from .cudnn import cudnn_batch_prefill_with_kv_cache +from .jit import (gen_batch_prefill_module, gen_customize_batch_prefill_module, + gen_fmha_cutlass_sm100a_module, gen_single_prefill_module, + get_batch_prefill_uri, get_single_prefill_uri, + setup_cubin_loader, trtllm_gen_fmha_module) from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens from .quantization import packbits, segment_packbits -from .utils import ( - FP4Tensor, - MaskMode, - PosEncodingMode, - TensorLayout, - _check_cached_qkv_data_type, - _check_kv_layout, - _check_pos_encoding_mode, - check_shape_dtype_device, - _get_cache_alibi_slopes_buf, - _get_cache_buf, - _unpack_paged_kv_cache, - canonicalize_torch_dtype, - determine_attention_backend, - device_support_pdl, - get_device_sm_count, - is_float8, - is_sm100a_supported, - register_custom_op, - register_fake_op, - ceil_div, - round_up, -) +from .utils import (FP4Tensor, MaskMode, PosEncodingMode, TensorLayout, + _check_cached_qkv_data_type, _check_kv_layout, + _check_pos_encoding_mode, _get_cache_alibi_slopes_buf, + _get_cache_buf, _unpack_paged_kv_cache, + canonicalize_torch_dtype, ceil_div, + check_shape_dtype_device, determine_attention_backend, + device_support_pdl, get_device_sm_count, is_float8, + is_sm100a_supported, register_custom_op, register_fake_op, + round_up) @functools.cache def get_fmha_module( - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - dtype_idx: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + dtype_idx: paddle.dtype, head_dim_qk: int, head_dim_vo: int, pos_encoding_mode: int, @@ -73,7 +56,7 @@ def get_fmha_module( use_logits_soft_cap: bool, use_fp16_qk_reduction: bool = False, ): - if is_sm100a_supported(torch.device("cuda")): + if is_sm100a_supported(device2str("cuda")): return gen_fmha_cutlass_sm100a_module( dtype_q, dtype_kv, @@ -101,21 +84,18 @@ def cached_wrapper(*args, **kwargs): @functools.wraps(func) def wrapper(*args, **kwargs): - # Convert unhashable arguments to hashable ones hashable_args = [] for arg in args: if isinstance(arg, list): hashable_args.append(tuple(arg)) else: hashable_args.append(arg) - hashable_kwargs = {} for key, value in kwargs.items(): if isinstance(value, list): hashable_kwargs[key] = tuple(value) else: hashable_kwargs[key] = value - return cached_wrapper(*hashable_args, **hashable_kwargs) return wrapper @@ -125,10 +105,10 @@ def wrapper(*args, **kwargs): def get_customize_batch_prefill_module( backend: str, uri: str, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, - dtype_o: torch.dtype, - idtype: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, + dtype_o: paddle.dtype, + idtype: paddle.dtype, head_dim_qk: int, head_dim_vo: int, 
additional_tensor_names: List[str], @@ -173,30 +153,30 @@ def get_trtllm_gen_prefill_module(): setup_cubin_loader(mod.get_library_path()) def _paged_run( - query: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - workspace_buffer: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, + query: paddle.Tensor, + k_cache: paddle.Tensor, + v_cache: paddle.Tensor, + workspace_buffer: paddle.Tensor, + block_tables: paddle.Tensor, + seq_lens: paddle.Tensor, max_q_len: int, max_kv_len: int, bmm1_scale: float, bmm2_scale: float, batch_size: int, - cum_seq_lens_q: torch.Tensor, - cum_seq_lens_kv: torch.Tensor, + cum_seq_lens_q: paddle.Tensor, + cum_seq_lens_kv: paddle.Tensor, enable_pdl: bool, window_left: int = -1, - out: Optional[torch.Tensor] = None, - sinks: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - sm_count = get_device_sm_count(query.device) + out: Optional[paddle.Tensor] = None, + sinks: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + sm_count = get_device_sm_count(query.place) if out is None: - out = torch.empty_like(query) + out = paddle.empty_like(x=query) op.trtllm_paged_attention_context( out, - None, # fp4 output not supported in wrapper api yet. + None, query, k_cache, v_cache, @@ -207,9 +187,9 @@ def _paged_run( max_kv_len, bmm1_scale, bmm2_scale, - -1, # o_sf_scale - -1, # o_sf_vec_size - 0, # o_sf_start_index + -1, + -1, + 0, batch_size, window_left, cum_seq_lens_q, @@ -221,8 +201,6 @@ def _paged_run( return out def _ragged_run(*args, **kwargs): - # TODO(Zihao): trtllm-gen backend already supports variable length attention, - # but not integrated into flashinfer yet. raise NotImplementedError( "Variable length is not implemented for trtllm-gen backend yet." ) @@ -230,11 +208,7 @@ def _ragged_run(*args, **kwargs): def _plan(*args, **kwargs): pass - return SimpleNamespace( - paged_run=_paged_run, - ragged_run=_ragged_run, - plan=_plan, - ) + return SimpleNamespace(paged_run=_paged_run, ragged_run=_ragged_run, plan=_plan) @functools.cache @@ -243,28 +217,26 @@ def get_single_prefill_module(backend, *args): module = gen_single_prefill_module(backend, *args).build_and_load() run_func = module.run.default - # torch library for single_prefill_with_kv_cache - @register_custom_op( f"flashinfer::{uri}_run", mutates_args=("tmp", "o", "maybe_lse") ) def run_single_prefill( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - tmp: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + tmp: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], mask_mode: int, layout: int, window_left: int, - maybe_packed_custom_mask: Optional[torch.Tensor], - maybe_alibi_slopes: Optional[torch.Tensor], + maybe_packed_custom_mask: Optional[paddle.Tensor], + maybe_alibi_slopes: Optional[paddle.Tensor], logits_soft_cap: float, sm_scale: float, - scale_q: Optional[torch.Tensor], - scale_k: Optional[torch.Tensor], - scale_v: Optional[torch.Tensor], + scale_q: Optional[paddle.Tensor], + scale_k: Optional[paddle.Tensor], + scale_v: Optional[paddle.Tensor], rope_scale: float, rope_theta: float, ) -> None: @@ -284,7 +256,6 @@ def run_single_prefill( sm_scale, ) else: - # FP8 enabled run_func( q, k, @@ -315,24 +286,24 @@ def run_single_prefill( maybe_alibi_slopes, logits_soft_cap, sm_scale, - 1.0 / rope_scale, # rope_rcp_scale - 1.0 / rope_theta, # rope_rcp_theta + 1.0 / rope_scale, + 1.0 / rope_theta, ) return o @register_fake_op(f"flashinfer::{uri}_run") def 
_fake_run_single_prefill( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - tmp: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + tmp: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], mask_mode: int, layout: int, window_left: int, - maybe_packed_custom_mask: Optional[torch.Tensor], - maybe_alibi_slopes: Optional[torch.Tensor], + maybe_packed_custom_mask: Optional[paddle.Tensor], + maybe_alibi_slopes: Optional[paddle.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, @@ -340,7 +311,6 @@ def _fake_run_single_prefill( ) -> None: pass - # Register the module return SimpleNamespace(run=run_single_prefill) @@ -359,8 +329,6 @@ def get_batch_prefill_module(backend, *args): ragged_run_func = module.ragged_run.default paged_run_func = module.paged_run.default - # torch library for ragged_run - @register_custom_op( f"flashinfer::{uri}_ragged_run", mutates_args=( @@ -371,26 +339,26 @@ def get_batch_prefill_module(backend, *args): ), ) def ragged_run( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qo_indptr: paddle.Tensor, + kv_indptr: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], mask_mode: int, layout: int, window_left: int, enable_pdl: bool, - maybe_custom_mask: Optional[torch.Tensor], - maybe_mask_indptr: Optional[torch.Tensor], - maybe_alibi_slopes: Optional[torch.Tensor], - maybe_prefix_len_ptr: Optional[torch.Tensor], - maybe_token_pos_in_items_ptr: Optional[torch.Tensor], - maybe_max_item_len_ptr: Optional[torch.Tensor], + maybe_custom_mask: Optional[paddle.Tensor], + maybe_mask_indptr: Optional[paddle.Tensor], + maybe_alibi_slopes: Optional[paddle.Tensor], + maybe_prefix_len_ptr: Optional[paddle.Tensor], + maybe_token_pos_in_items_ptr: Optional[paddle.Tensor], + maybe_max_item_len_ptr: Optional[paddle.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, @@ -421,8 +389,8 @@ def ragged_run( maybe_max_item_len_ptr, logits_soft_cap, sm_scale, - 1.0 / rope_scale, # rope_rcp_scale - 1.0 / rope_theta, # rope_rcp_theta + 1.0 / rope_scale, + 1.0 / rope_theta, token_pos_in_items_len, ) else: @@ -448,31 +416,30 @@ def ragged_run( sm_scale, token_pos_in_items_len, ) - return o @register_fake_op(f"flashinfer::{uri}_ragged_run") def _fake_ragged_run( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qo_indptr: paddle.Tensor, + kv_indptr: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], mask_mode: int, layout: int, window_left: int, enable_pdl: bool, - maybe_custom_mask: Optional[torch.Tensor], - maybe_mask_indptr: Optional[torch.Tensor], - maybe_alibi_slopes: Optional[torch.Tensor], - maybe_prefix_len_ptr: Optional[torch.Tensor], - maybe_token_pos_in_items_ptr: Optional[torch.Tensor], - 
maybe_max_item_len_ptr: Optional[torch.Tensor], + maybe_custom_mask: Optional[paddle.Tensor], + maybe_mask_indptr: Optional[paddle.Tensor], + maybe_alibi_slopes: Optional[paddle.Tensor], + maybe_prefix_len_ptr: Optional[paddle.Tensor], + maybe_token_pos_in_items_ptr: Optional[paddle.Tensor], + maybe_max_item_len_ptr: Optional[paddle.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, @@ -481,8 +448,6 @@ def _fake_ragged_run( ) -> None: pass - # torch library for paged_run - @register_custom_op( f"flashinfer::{uri}_paged_run", mutates_args=( @@ -495,47 +460,47 @@ def _fake_ragged_run( ), ) def paged_run( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - paged_k_cache: torch.Tensor, - paged_v_cache: torch.Tensor, - qo_indptr: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + paged_k_cache: paddle.Tensor, + paged_v_cache: paddle.Tensor, + qo_indptr: paddle.Tensor, + paged_kv_indptr: paddle.Tensor, + paged_kv_indices: paddle.Tensor, + paged_kv_last_page_len: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], mask_mode: int, layout: int, window_left: int, enable_pdl: bool, - maybe_custom_mask: Optional[torch.Tensor], - maybe_mask_indptr: Optional[torch.Tensor], - maybe_alibi_slopes: Optional[torch.Tensor], - maybe_prefix_len_ptr: Optional[torch.Tensor], - maybe_token_pos_in_items_ptr: Optional[torch.Tensor], - maybe_max_item_len_ptr: Optional[torch.Tensor], + maybe_custom_mask: Optional[paddle.Tensor], + maybe_mask_indptr: Optional[paddle.Tensor], + maybe_alibi_slopes: Optional[paddle.Tensor], + maybe_prefix_len_ptr: Optional[paddle.Tensor], + maybe_token_pos_in_items_ptr: Optional[paddle.Tensor], + maybe_max_item_len_ptr: Optional[paddle.Tensor], logits_soft_cap: float, sm_scale: float, - scale_q: Optional[torch.Tensor], - scale_k: Optional[torch.Tensor], - scale_v: Optional[torch.Tensor], + scale_q: Optional[paddle.Tensor], + scale_k: Optional[paddle.Tensor], + scale_v: Optional[paddle.Tensor], rope_scale: float, rope_theta: float, token_pos_in_items_len: int, num_qo_heads: Optional[int] = None, num_kv_heads: Optional[int] = None, - block_tables: Optional[torch.Tensor] = None, - kv_lens_buffer: Optional[torch.Tensor] = None, + block_tables: Optional[paddle.Tensor] = None, + kv_lens_buffer: Optional[paddle.Tensor] = None, page_size: Optional[int] = None, max_q_len: Optional[int] = None, max_kv_len: Optional[int] = None, batch_size: Optional[int] = None, - cum_seq_lens_q: Optional[torch.Tensor] = None, - cum_seq_lens_kv: Optional[torch.Tensor] = None, - sinks: Optional[torch.Tensor] = None, + cum_seq_lens_q: Optional[paddle.Tensor] = None, + cum_seq_lens_kv: Optional[paddle.Tensor] = None, + sinks: Optional[paddle.Tensor] = None, ) -> None: if backend == "trtllm-gen": assert maybe_lse is None @@ -550,7 +515,7 @@ def paged_run( assert cum_seq_lens_kv is not None assert enable_pdl is not None o = paged_run_func( - q.contiguous(), # NOTE(Siyuan): without contiguous, the result is incorrect + q.contiguous(), paged_k_cache, paged_v_cache, int_workspace_buffer, @@ -559,7 +524,7 @@ def paged_run( max_q_len, max_kv_len, sm_scale, - 1.0, # NOTE(Siyuan): update this to expose bmm2 scale + 1.0, batch_size, cum_seq_lens_q, cum_seq_lens_kv, @@ -595,85 +560,84 @@ def 
paged_run( maybe_max_item_len_ptr, logits_soft_cap, sm_scale, - 1.0 / rope_scale, # rope_rcp_scale - 1.0 / rope_theta, # rope_rcp_theta + 1.0 / rope_scale, + 1.0 / rope_theta, + token_pos_in_items_len, + ) + elif not is_float8(q): + paged_run_func( + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + paged_k_cache, + paged_v_cache, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + o, + maybe_lse, + mask_mode, + layout, + window_left, + enable_pdl, + maybe_prefix_len_ptr, + maybe_token_pos_in_items_ptr, + maybe_max_item_len_ptr, + logits_soft_cap, + sm_scale, token_pos_in_items_len, ) else: - if not is_float8(q): - paged_run_func( - float_workspace_buffer, - int_workspace_buffer, - plan_info_vec, - q, - paged_k_cache, - paged_v_cache, - qo_indptr, - paged_kv_indptr, - paged_kv_indices, - paged_kv_last_page_len, - o, - maybe_lse, - mask_mode, - layout, - window_left, - enable_pdl, - maybe_prefix_len_ptr, - maybe_token_pos_in_items_ptr, - maybe_max_item_len_ptr, - logits_soft_cap, - sm_scale, - token_pos_in_items_len, - ) - else: - paged_run_func( - float_workspace_buffer, - int_workspace_buffer, - plan_info_vec, - q, - paged_k_cache, - paged_v_cache, - qo_indptr, - paged_kv_indptr, - paged_kv_indices, - paged_kv_last_page_len, - o, - maybe_lse, - mask_mode, - layout, - window_left, - enable_pdl, - scale_q, - scale_k, - scale_v, - sm_scale, - ) + paged_run_func( + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + paged_k_cache, + paged_v_cache, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + o, + maybe_lse, + mask_mode, + layout, + window_left, + enable_pdl, + scale_q, + scale_k, + scale_v, + sm_scale, + ) return o @register_fake_op(f"flashinfer::{uri}_paged_run") def _fake_paged_run( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - paged_k_cache: torch.Tensor, - paged_v_cache: torch.Tensor, - qo_indptr: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + paged_k_cache: paddle.Tensor, + paged_v_cache: paddle.Tensor, + qo_indptr: paddle.Tensor, + paged_kv_indptr: paddle.Tensor, + paged_kv_indices: paddle.Tensor, + paged_kv_last_page_len: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], mask_mode: int, layout: int, window_left: int, enable_pdl: bool, - maybe_custom_mask: Optional[torch.Tensor], - maybe_mask_indptr: Optional[torch.Tensor], - maybe_alibi_slopes: Optional[torch.Tensor], - maybe_prefix_len_ptr: Optional[torch.Tensor], - maybe_token_pos_in_items_ptr: Optional[torch.Tensor], - maybe_max_item_len_ptr: Optional[torch.Tensor], + maybe_custom_mask: Optional[paddle.Tensor], + maybe_mask_indptr: Optional[paddle.Tensor], + maybe_alibi_slopes: Optional[paddle.Tensor], + maybe_prefix_len_ptr: Optional[paddle.Tensor], + maybe_token_pos_in_items_ptr: Optional[paddle.Tensor], + maybe_max_item_len_ptr: Optional[paddle.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, @@ -681,26 +645,18 @@ def _fake_paged_run( token_pos_in_items_len: int, num_qo_heads: Optional[int] = None, num_kv_heads: Optional[int] = None, - block_tables: Optional[torch.Tensor] = None, - kv_lens_buffer: Optional[torch.Tensor] = None, + block_tables: Optional[paddle.Tensor] = None, + 
kv_lens_buffer: Optional[paddle.Tensor] = None, page_size: Optional[int] = None, max_q_len: Optional[int] = None, max_kv_len: Optional[int] = None, batch_size: Optional[int] = None, - cum_seq_lens_q: Optional[torch.Tensor] = None, - cum_seq_lens_kv: Optional[torch.Tensor] = None, + cum_seq_lens_q: Optional[paddle.Tensor] = None, + cum_seq_lens_kv: Optional[paddle.Tensor] = None, ) -> None: pass - # Register the module. - # - # Note that plan is not part of model logic. It should not be included in - # Cuda Graph or torch.compile. So, we don't provide a torch library for plan. - return SimpleNamespace( - plan=plan_func, - ragged_run=ragged_run, - paged_run=paged_run, - ) + return SimpleNamespace(plan=plan_func, ragged_run=ragged_run, paged_run=paged_run) @functools.cache @@ -709,7 +665,6 @@ def get_batch_prefill_jit_module(module_name: str, jit_module: Any): ragged_run_func = jit_module.ragged_run.default paged_run_func = jit_module.paged_run.default - # torch library for ragged_run @register_custom_op( f"flashinfer::{module_name}_ragged_run", mutates_args=( @@ -720,16 +675,16 @@ def get_batch_prefill_jit_module(module_name: str, jit_module: Any): ), ) def ragged_run( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qo_indptr: paddle.Tensor, + kv_indptr: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], mask_mode: int, layout: int, window_left: int, @@ -754,16 +709,16 @@ def ragged_run( @register_fake_op(f"flashinfer::{module_name}_ragged_run") def _fake_ragged_run( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qo_indptr: paddle.Tensor, + kv_indptr: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], mask_mode: int, layout: int, window_left: int, @@ -771,7 +726,6 @@ def _fake_ragged_run( ) -> None: pass - # torch library for paged_run @register_custom_op( f"flashinfer::{module_name}_paged_run", mutates_args=( @@ -784,18 +738,18 @@ def _fake_ragged_run( ), ) def paged_run( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - paged_k_cache: torch.Tensor, - paged_v_cache: torch.Tensor, - qo_indptr: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + paged_k_cache: paddle.Tensor, + paged_v_cache: paddle.Tensor, + qo_indptr: paddle.Tensor, + paged_kv_indptr: paddle.Tensor, + paged_kv_indices: paddle.Tensor, + paged_kv_last_page_len: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], mask_mode: int, layout: int, window_left: int, @@ -822,18 +776,18 @@ def paged_run( 
@register_fake_op(f"flashinfer::{module_name}_paged_run") def _fake_paged_run( - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, plan_info_vec: List[int], - q: torch.Tensor, - paged_k_cache: torch.Tensor, - paged_v_cache: torch.Tensor, - qo_indptr: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - o: torch.Tensor, - maybe_lse: Optional[torch.Tensor], + q: paddle.Tensor, + paged_k_cache: paddle.Tensor, + paged_v_cache: paddle.Tensor, + qo_indptr: paddle.Tensor, + paged_kv_indptr: paddle.Tensor, + paged_kv_indices: paddle.Tensor, + paged_kv_last_page_len: paddle.Tensor, + o: paddle.Tensor, + maybe_lse: Optional[paddle.Tensor], mask_mode: int, layout: int, window_left: int, @@ -841,36 +795,28 @@ def _fake_paged_run( ) -> None: pass - # Register the module. - # - # Note that plan is not part of model logic. It should not be included in - # Cuda Graph or torch.compile. So, we don't provide a torch library for plan. - return SimpleNamespace( - plan=plan_func, - ragged_run=ragged_run, - paged_run=paged_run, - ) + return SimpleNamespace(plan=plan_func, ragged_run=ragged_run, paged_run=paged_run) def single_prefill_with_kv_cache_with_jit_module( jit_module: Any, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, *args, kv_layout: str = "NHD", mask_mode: int = MaskMode.NON_CAUSAL.value, window_left: int = -1, return_lse: bool = False, -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - device = q.device +) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + device = q.place tmp = _get_cache_buf( "single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, device=device ) - o = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=device) + o = paddle.empty(shape=tuple(q.shape)[:-1] + tuple(v.shape)[-1:], dtype=q.dtype) lse = None if return_lse: - lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=device) + lse = paddle.empty(shape=(q.shape[0], q.shape[1]), dtype="float32") jit_module.run.default( q, k, @@ -888,15 +834,15 @@ def single_prefill_with_kv_cache_with_jit_module( @overload def single_prefill_with_kv_cache( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale_q: Optional[torch.Tensor] = None, - scale_k: Optional[torch.Tensor] = None, - scale_v: Optional[torch.Tensor] = None, - o_dtype: Optional[torch.dtype] = None, - custom_mask: Optional[torch.Tensor] = None, - packed_custom_mask: Optional[torch.Tensor] = None, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + scale_q: Optional[paddle.Tensor] = None, + scale_k: Optional[paddle.Tensor] = None, + scale_v: Optional[paddle.Tensor] = None, + o_dtype: Optional[paddle.dtype] = None, + custom_mask: Optional[paddle.Tensor] = None, + packed_custom_mask: Optional[paddle.Tensor] = None, causal: bool = False, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", @@ -908,20 +854,21 @@ def single_prefill_with_kv_cache( rope_theta: Optional[float] = None, backend: str = "auto", return_lse: Literal[False] = False, -) -> torch.Tensor: ... +) -> paddle.Tensor: + ... 
@overload def single_prefill_with_kv_cache( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale_q: Optional[torch.Tensor] = None, - scale_k: Optional[torch.Tensor] = None, - scale_v: Optional[torch.Tensor] = None, - o_dtype: Optional[torch.dtype] = None, - custom_mask: Optional[torch.Tensor] = None, - packed_custom_mask: Optional[torch.Tensor] = None, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + scale_q: Optional[paddle.Tensor] = None, + scale_k: Optional[paddle.Tensor] = None, + scale_v: Optional[paddle.Tensor] = None, + o_dtype: Optional[paddle.dtype] = None, + custom_mask: Optional[paddle.Tensor] = None, + packed_custom_mask: Optional[paddle.Tensor] = None, causal: bool = False, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", @@ -933,19 +880,20 @@ def single_prefill_with_kv_cache( rope_theta: Optional[float] = None, backend: str = "auto", return_lse: Literal[True] = True, -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> Tuple[paddle.Tensor, paddle.Tensor]: + ... def single_prefill_with_kv_cache( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale_q: Optional[torch.Tensor] = None, - scale_k: Optional[torch.Tensor] = None, - scale_v: Optional[torch.Tensor] = None, - o_dtype: Optional[torch.dtype] = None, - custom_mask: Optional[torch.Tensor] = None, - packed_custom_mask: Optional[torch.Tensor] = None, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + scale_q: Optional[paddle.Tensor] = None, + scale_k: Optional[paddle.Tensor] = None, + scale_v: Optional[paddle.Tensor] = None, + o_dtype: Optional[paddle.dtype] = None, + custom_mask: Optional[paddle.Tensor] = None, + packed_custom_mask: Optional[paddle.Tensor] = None, causal: bool = False, kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", @@ -957,8 +905,8 @@ def single_prefill_with_kv_cache( rope_theta: Optional[float] = None, backend: str = "auto", return_lse: bool = False, -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Prefill/Append attention with KV cache for single request, return the attention +) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Prefill/Append attention with KV cache for single request, return the attention output. Parameters @@ -1016,7 +964,7 @@ def single_prefill_with_kv_cache( The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, + :math:`\\texttt{logits_soft_cap} \\times \\mathrm{tanh}(x / \\texttt{logits_soft_cap})`, where :math:`x` is the input logits. sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim_qk)``. 
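
A minimal call sketch (not part of the patch) for the converted `single_prefill_with_kv_cache` above. The argument names, the `NHD` layout shapes, and the `(out, lse)` return under `return_lse=True` come from the signatures and docstring in this diff; the tensor sizes are illustrative assumptions and a CUDA-enabled Paddle build is assumed.

```python
import paddle
import flashinfer

# Illustrative GQA sizes: 32 query heads sharing 8 KV heads (assumption).
qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim = 128, 4096, 32, 8, 128

q = paddle.rand(shape=[qo_len, num_qo_heads, head_dim], dtype="float16")
k = paddle.rand(shape=[kv_len, num_kv_heads, head_dim], dtype="float16")
v = paddle.rand(shape=[kv_len, num_kv_heads, head_dim], dtype="float16")

# Causal attention over the NHD layout; returns only the attention output.
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, kv_layout="NHD")

# With return_lse=True the log-sum-exp is returned as well,
# as a float32 tensor of shape [qo_len, num_qo_heads].
o, lse = flashinfer.single_prefill_with_kv_cache(
    q, k, v, causal=True, kv_layout="NHD", return_lse=True
)
```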
@@ -1081,75 +1029,62 @@ def single_prefill_with_kv_cache( """ _check_pos_encoding_mode(pos_encoding_mode) _check_kv_layout(kv_layout) - tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.device) + tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.place) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: - sm_scale = 1.0 / math.sqrt(q.size(-1)) + sm_scale = 1.0 / math.sqrt(q.shape[-1]) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: - rope_theta = 1e4 + rope_theta = 10000.0 if custom_mask is not None and packed_custom_mask is None: - # create packed custom mask from custom mask packed_custom_mask = packbits( custom_mask.contiguous().view(-1), bitorder="little" ) - if packed_custom_mask is not None: mask_mode = MaskMode.CUSTOM.value + elif causal: + mask_mode = MaskMode.CAUSAL.value else: - if causal: - mask_mode = MaskMode.CAUSAL.value - else: - mask_mode = MaskMode.NON_CAUSAL.value - + mask_mode = MaskMode.NON_CAUSAL.value lse = None if return_lse: - lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device) - + lse = paddle.empty(shape=(q.shape[0], q.shape[1]), dtype="float32") if is_float8(q): - # FP8 quant enabled, do sanity check: - # 1. unsupported feature - # 2. dtype check assert window_left == -1 assert q.dtype == k.dtype == v.dtype - assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert tuple(q.shape)[-1] == tuple(k.shape)[-1] == tuple(v.shape)[-1] if scale_q is None: - scale_q = torch.ones(q.shape[1], dtype=torch.float32, device=q.device) + scale_q = paddle.ones(shape=tuple(q.shape)[1], dtype="float32") if scale_k is None: - scale_k = torch.ones(k.shape[1], dtype=torch.float32, device=q.device) + scale_k = paddle.ones(shape=tuple(k.shape)[1], dtype="float32") if scale_v is None: - scale_v = torch.ones(v.shape[1], dtype=torch.float32, device=q.device) - + scale_v = paddle.ones(shape=tuple(v.shape)[1], dtype="float32") if backend == "auto": backend = determine_attention_backend( - q.device, + q.place, PosEncodingMode[pos_encoding_mode].value, use_fp16_qk_reduction, - packed_custom_mask is not None, # use_custom_mask + packed_custom_mask is not None, q.dtype, k.dtype, ) - - # o_dtype should be provided for FP8 attention if o_dtype is None: o_dtype = q.dtype - out = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=o_dtype, device=q.device) - + out = paddle.empty(shape=tuple(q.shape)[:-1] + tuple(v.shape)[-1:], dtype=o_dtype) module = get_single_prefill_module( backend, q.dtype, k.dtype, out.dtype, - q.shape[-1], # head_dim_qk - v.shape[-1], # head_dim_vo + tuple(q.shape)[-1], + tuple(v.shape)[-1], PosEncodingMode[pos_encoding_mode].value, - window_left >= 0, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap + window_left >= 0, + logits_soft_cap > 0, use_fp16_qk_reduction, ) - module.run( q, k, @@ -1161,7 +1096,7 @@ def single_prefill_with_kv_cache( TensorLayout[kv_layout].value, window_left, packed_custom_mask, - _get_cache_alibi_slopes_buf(q.shape[1], q.device), + _get_cache_alibi_slopes_buf(tuple(q.shape)[1], q.place), logits_soft_cap, sm_scale, scale_q, @@ -1170,7 +1105,6 @@ def single_prefill_with_kv_cache( rope_scale, rope_theta, ) - return (out, lse) if return_lse else out @@ -1180,30 +1114,30 @@ def single_prefill_with_kv_cache( def _compute_page_mask_indptr( - qo_indptr: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, + qo_indptr: paddle.Tensor, + paged_kv_indptr: paddle.Tensor, + paged_kv_last_page_len: 
paddle.Tensor, page_size: int, -) -> torch.Tensor: +) -> paddle.Tensor: if len(qo_indptr) != len(paged_kv_indptr): raise ValueError( "The length of qo_indptr and paged_kv_indptr should be the same." ) - mask_indptr = torch.empty_like(qo_indptr) + mask_indptr = paddle.empty_like(x=qo_indptr) mask_indptr[0] = 0 - mask_indptr[1:] = torch.cumsum( - (qo_indptr[1:] - qo_indptr[:-1]) + mask_indptr[1:] = paddle.cumsum( + x=(qo_indptr[1:] - qo_indptr[:-1]) * ( (paged_kv_indptr[1:] - paged_kv_indptr[:-1] - 1) * page_size + paged_kv_last_page_len ), - 0, + axis=0, ) return mask_indptr class BatchPrefillWithPagedKVCacheWrapper: - r"""Wrapper class for prefill/append attention with paged kv-cache for batch of + """Wrapper class for prefill/append attention with paged kv-cache for batch of requests. Check :ref:`our tutorial ` for page table layout. @@ -1306,20 +1240,20 @@ class BatchPrefillWithPagedKVCacheWrapper: def __init__( self, - float_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, - qo_indptr_buf: Optional[torch.Tensor] = None, - paged_kv_indptr_buf: Optional[torch.Tensor] = None, - paged_kv_indices_buf: Optional[torch.Tensor] = None, - paged_kv_last_page_len_buf: Optional[torch.Tensor] = None, - custom_mask_buf: Optional[torch.Tensor] = None, - mask_indptr_buf: Optional[torch.Tensor] = None, + qo_indptr_buf: Optional[paddle.Tensor] = None, + paged_kv_indptr_buf: Optional[paddle.Tensor] = None, + paged_kv_indices_buf: Optional[paddle.Tensor] = None, + paged_kv_last_page_len_buf: Optional[paddle.Tensor] = None, + custom_mask_buf: Optional[paddle.Tensor] = None, + mask_indptr_buf: Optional[paddle.Tensor] = None, backend: str = "auto", jit_args: Optional[List[Any]] = None, jit_kwargs: Optional[Dict[str, Any]] = None, ) -> None: - r"""Constructor of :class:`BatchPrefillWithPagedKVCacheWrapper`. + """Constructor of :class:`BatchPrefillWithPagedKVCacheWrapper`. Parameters ---------- @@ -1382,7 +1316,6 @@ def __init__( The keyword arguments to create the JIT module, defaults to None. 
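`_compute_page_mask_indptr` above is a prefix sum of `qo_len * kv_len` per request, with `kv_len` reconstructed from the page table as `(num_pages - 1) * page_size + last_page_len`. A toy, list-only walk-through of what the resulting offsets index into (illustrative numbers only, no paddle involved):

```python
from itertools import accumulate

page_size = 4
qo_indptr = [0, 2, 5]            # request 0 has 2 query tokens, request 1 has 3
paged_kv_indptr = [0, 3, 5]      # request 0 uses 3 pages, request 1 uses 2
paged_kv_last_page_len = [1, 4]  # valid entries in each request's last page

qo_lens = [qo_indptr[i + 1] - qo_indptr[i] for i in range(len(qo_indptr) - 1)]
kv_lens = [
    (paged_kv_indptr[i + 1] - paged_kv_indptr[i] - 1) * page_size
    + paged_kv_last_page_len[i]
    for i in range(len(paged_kv_indptr) - 1)
]
mask_indptr = [0] + list(accumulate(q * k for q, k in zip(qo_lens, kv_lens)))

print(kv_lens)      # [9, 8]  -> (3-1)*4+1 and (2-1)*4+4
print(mask_indptr)  # [0, 18, 42]: request i's flattened qo_len x kv_len mask
                    # occupies mask[mask_indptr[i]:mask_indptr[i+1]] before packbits
```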
""" _check_kv_layout(kv_layout) - if jit_args is not None: if jit_kwargs is None: jit_kwargs = {} @@ -1392,51 +1325,42 @@ def __init__( ) else: self._jit_module = None - self._kv_layout = kv_layout if backend == "cudnn": assert kv_layout == "NHD", "CUDNN backend only supports NHD layout" - self._float_workspace_buffer = float_workspace_buffer - self.device = float_workspace_buffer.device - self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None + self.device = float_workspace_buffer.place + self._vector_sparse_indptr_buffer: Optional[paddle.Tensor] = None if backend in ["fa3", "auto", "trtllm-gen"]: - # NOTE(Zihao): assume maximum accumulate kv length is 16M - self._vector_sparse_indices_buffer = torch.empty( - (16 * 1024 * 1024,), dtype=torch.int32, device=self.device + self._vector_sparse_indices_buffer = paddle.empty( + shape=(16 * 1024 * 1024,), dtype="int32" ) - # NOTE(Zihao): assume maximum batch size is 32768 - self._vector_sparse_indptr_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device + self._vector_sparse_indptr_buffer = paddle.empty( + shape=(32768,), dtype="int32" ) - - self._kv_lens_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device - ) - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + self._kv_lens_buffer = paddle.empty(shape=(32768,), dtype="int32") + self._int_workspace_buffer = paddle.empty( + shape=(8 * 1024 * 1024,), dtype="uint8" ) - self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=tuple(self._int_workspace_buffer.shape), dtype=self._int_workspace_buffer.dtype, - device="cpu", - pin_memory=True, - ) + ).pin_memory() self._use_cuda_graph = use_cuda_graph if use_cuda_graph: - if not torch.is_tensor(qo_indptr_buf): + if not paddle.is_tensor(x=qo_indptr_buf): raise ValueError( "qo_indptr_buf should be a torch.Tensor in CUDA graph mode" ) - if not torch.is_tensor(paged_kv_indptr_buf): + if not paddle.is_tensor(x=paged_kv_indptr_buf): raise ValueError( "paged_kv_indptr_buf should be a torch.Tensor in CUDA graph mode" ) - if not torch.is_tensor(paged_kv_indices_buf): + if not paddle.is_tensor(x=paged_kv_indices_buf): raise ValueError( "paged_kv_indices_buf should be a torch.Tensor in CUDA graph mode" ) - if not torch.is_tensor(paged_kv_last_page_len_buf): + if not paddle.is_tensor(x=paged_kv_last_page_len_buf): raise ValueError( "paged_kv_last_page_len_buf should be a torch.Tensor in CUDA graph mode" ) @@ -1449,10 +1373,8 @@ def __init__( raise ValueError( "The length of paged_kv_last_page_len_buf should be batch_size." ) - # NOTE(Zihao): do not check custom_mask_buf and mask_indptr_buf here, as they are optional else: self._fixed_batch_size = 0 - self._qo_indptr_buf = qo_indptr_buf self._paged_kv_indptr_buf = paged_kv_indptr_buf self._paged_kv_indices_buf = paged_kv_indices_buf @@ -1472,9 +1394,9 @@ def is_cuda_graph_enabled(self) -> bool: return self._use_cuda_graph def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor + self, float_workspace_buffer: paddle.Tensor, int_workspace_buffer: paddle.Tensor ) -> None: - r"""Reset the workspace buffer. + """Reset the workspace buffer. 
Parameters ---------- @@ -1488,26 +1410,24 @@ def reset_workspace_buffer( """ self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=tuple(self._int_workspace_buffer.shape), dtype=self._int_workspace_buffer.dtype, - device="cpu", - pin_memory=True, - ) + ).pin_memory() def plan( self, - qo_indptr: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, + qo_indptr: paddle.Tensor, + paged_kv_indptr: paddle.Tensor, + paged_kv_indices: paddle.Tensor, + paged_kv_last_page_len: paddle.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim_qk: int, page_size: int, head_dim_vo: Optional[int] = None, - custom_mask: Optional[torch.Tensor] = None, - packed_custom_mask: Optional[torch.Tensor] = None, + custom_mask: Optional[paddle.Tensor] = None, + packed_custom_mask: Optional[paddle.Tensor] = None, causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, @@ -1516,20 +1436,20 @@ def plan( logits_soft_cap: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - q_data_type: Union[str, torch.dtype] = "float16", - kv_data_type: Optional[Union[str, torch.dtype]] = None, + q_data_type: Union[str, paddle.dtype] = "float16", + kv_data_type: Optional[Union[str, paddle.dtype]] = None, non_blocking: bool = True, - prefix_len_ptr: Optional[torch.Tensor] = None, - token_pos_in_items_ptr: Optional[torch.Tensor] = None, + prefix_len_ptr: Optional[paddle.Tensor] = None, + token_pos_in_items_ptr: Optional[paddle.Tensor] = None, token_pos_in_items_len: int = 0, - max_item_len_ptr: Optional[torch.Tensor] = None, - seq_lens: Optional[torch.Tensor] = None, - seq_lens_q: Optional[torch.Tensor] = None, - block_tables: Optional[torch.Tensor] = None, + max_item_len_ptr: Optional[paddle.Tensor] = None, + seq_lens: Optional[paddle.Tensor] = None, + seq_lens_q: Optional[paddle.Tensor] = None, + block_tables: Optional[paddle.Tensor] = None, max_token_per_sequence: Optional[int] = None, max_sequence_kv: Optional[int] = None, ) -> None: - r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification. + """Plan batch prefill/append attention on Paged KV-Cache for given problem specification. Parameters ---------- @@ -1586,7 +1506,7 @@ def plan( The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, + :math:`\\texttt{logits_soft_cap} \\times \\mathrm{tanh}(x / \\texttt{logits_soft_cap})`, where :math:`x` is the input logits. 
sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to @@ -1644,44 +1564,32 @@ def plan( if kv_data_type is None: kv_data_type = q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) - if logits_soft_cap is None: logits_soft_cap = 0.0 if head_dim_vo is None: head_dim_vo = head_dim_qk - batch_size = len(qo_indptr) - 1 self._batch_size = batch_size self._num_qo_heads = num_qo_heads self._num_kv_heads = num_kv_heads if custom_mask is not None or packed_custom_mask is not None: mask_indptr = _compute_page_mask_indptr( - qo_indptr, - paged_kv_indptr, - paged_kv_last_page_len, - page_size, + qo_indptr, paged_kv_indptr, paged_kv_last_page_len, page_size ) if packed_custom_mask is None and custom_mask is not None: - # create packed custom mask from custom mask packed_custom_mask, mask_indptr = segment_packbits( - custom_mask.contiguous().view(-1), - mask_indptr, - bitorder="little", + custom_mask.contiguous().view(-1), mask_indptr, bitorder="little" ) - self._prefix_len_ptr = prefix_len_ptr self._token_pos_in_items_ptr = token_pos_in_items_ptr self._token_pos_in_items_len = token_pos_in_items_len self._max_item_len_ptr = max_item_len_ptr - - # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors if max_token_per_sequence is not None: self._max_q_len = max_token_per_sequence else: qo_indptr_host = qo_indptr.to("cpu") self._max_q_len = max(qo_indptr_host).item() total_num_rows = qo_indptr_host[-1] - if max_sequence_kv is not None: self._max_kv_len = max_sequence_kv else: @@ -1693,27 +1601,22 @@ def plan( ) else: kv_lens_arr_host = seq_lens.cpu().flatten() - self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( - kv_lens_arr_host, non_blocking=non_blocking + paddle.assign( + kv_lens_arr_host, output=self._kv_lens_buffer[: len(kv_lens_arr_host)] ) self._max_kv_len = max(kv_lens_arr_host).item() - if self.is_cuda_graph_enabled: if self._max_total_num_rows is None: self._max_total_num_rows = total_num_rows elif total_num_rows > self._max_total_num_rows: raise ValueError( - "The total number of rows in qo_indptr {} in cuda graph mode cannot " - "exceed the number of rows set during initialization {}.".format( + "The total number of rows in qo_indptr {} in cuda graph mode cannot exceed the number of rows set during initialization {}.".format( total_num_rows, self._max_total_num_rows ) ) - if batch_size != self._fixed_batch_size: raise ValueError( - "The batch size should be fixed during the lifecycle of the wrapper in " - "cuda graph mode, the runtime batch size {} mismatches the batch size {} " - " set during initialization.".format( + "The batch size should be fixed during the lifecycle of the wrapper in cuda graph mode, the runtime batch size {} mismatches the batch size {} set during initialization.".format( batch_size, self._fixed_batch_size ) ) @@ -1721,58 +1624,52 @@ def plan( raise ValueError( "The length of paged_kv_indices exceeds the allocated buffer size." 
) - - self._qo_indptr_buf.copy_(qo_indptr, non_blocking=non_blocking) - self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=non_blocking) - self._paged_kv_last_page_len_buf.copy_( - paged_kv_last_page_len, non_blocking=non_blocking + paddle.assign(qo_indptr, output=self._qo_indptr_buf) + paddle.assign(paged_kv_indptr, output=self._paged_kv_indptr_buf) + paddle.assign( + paged_kv_last_page_len, output=self._paged_kv_last_page_len_buf ) - self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_( + paddle.assign( paged_kv_indices, - non_blocking=(paged_kv_indices.device == self.device) and non_blocking, + output=self._paged_kv_indices_buf[: len(paged_kv_indices)], ) - if packed_custom_mask is not None: - if not torch.is_tensor(self._custom_mask_buf): + if not paddle.is_tensor(x=self._custom_mask_buf): raise ValueError( "custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation." ) - if not torch.is_tensor(self._mask_indptr_buf): + if not paddle.is_tensor(x=self._mask_indptr_buf): raise ValueError( "mask_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation." ) - self._custom_mask_buf[: len(packed_custom_mask)].copy_( + paddle.assign( packed_custom_mask, - non_blocking=(packed_custom_mask.device == self.device) - and non_blocking, + output=self._custom_mask_buf[: len(packed_custom_mask)], ) - # NOTE(Zihao): mask_indptr has the same length as qo_indptr - self._mask_indptr_buf.copy_(mask_indptr, non_blocking=non_blocking) + paddle.assign(mask_indptr, output=self._mask_indptr_buf) else: - self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=non_blocking) + self._qo_indptr_buf = qo_indptr.to(self.device, blocking=not non_blocking) self._paged_kv_indptr_buf = paged_kv_indptr.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) self._paged_kv_indices_buf = paged_kv_indices.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) if packed_custom_mask is not None: self._custom_mask_buf = packed_custom_mask.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) self._mask_indptr_buf = mask_indptr.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) else: self._custom_mask_buf = None self._mask_indptr_buf = None - self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type - if self._jit_module is not None: self._cached_module = self._jit_module else: @@ -1781,7 +1678,7 @@ def plan( self.device, PosEncodingMode[pos_encoding_mode].value, use_fp16_qk_reduction, - self._custom_mask_buf is not None, # use_custom_mask + self._custom_mask_buf is not None, q_data_type, kv_data_type, ) @@ -1794,57 +1691,54 @@ def plan( head_dim_qk, head_dim_vo, PosEncodingMode[pos_encoding_mode].value, - window_left >= 0, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap + window_left >= 0, + logits_soft_cap > 0, use_fp16_qk_reduction, ) - self._cached_module = get_batch_prefill_module( self._backend, *get_module_args ) - if self._backend == "fa3" or self._backend == "trtllm-gen": if page_size != 1: - vector_sparse_indptr_host = torch.cat( - [ - torch.tensor( - [0], dtype=torch.int32, device=kv_lens_arr_host.device + vector_sparse_indptr_host = paddle.concat( + x=[ + 
paddle.to_tensor( + data=[0], dtype="int32", place=kv_lens_arr_host.place ), - torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), + paddle.cumsum(x=kv_lens_arr_host, axis=0, dtype="int32"), + ], + axis=0, + ) + paddle.assign( + vector_sparse_indptr_host, + output=self._vector_sparse_indptr_buffer[ + : len(vector_sparse_indptr_host) ], - dim=0, ) - self._vector_sparse_indptr_buffer[ - : len(vector_sparse_indptr_host) - ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking) paged_kv_indptr_host = vector_sparse_indptr_host - self._block_tables = block_tables if self._backend == "trtllm-gen": assert self._kv_layout == "HND" assert logits_soft_cap == 0.0 if self._block_tables is None: blocks_per_seq = [ - (seq_len + page_size - 1) // page_size + ((seq_len + page_size - 1) // page_size) for seq_len in kv_lens_arr_host ] max_num_blocks_per_seq = max(blocks_per_seq) - self._block_tables = torch.zeros( - (batch_size, max_num_blocks_per_seq), - dtype=torch.int, - device=self.device, + self._block_tables = paddle.zeros( + shape=(batch_size, max_num_blocks_per_seq), dtype="int32" ) block_id = paged_kv_indptr_host[0] for i in range(batch_size): num_blocks_needed = blocks_per_seq[i] - assert self._block_tables is not None, ( - "block_tables is not initialized" - ) + assert ( + self._block_tables is not None + ), "block_tables is not initialized" self._block_tables[i, :num_blocks_needed] = paged_kv_indices[ block_id : block_id + num_blocks_needed ] block_id += num_blocks_needed - if self._cached_module is not None: self._plan_info = self._cached_module.plan( self._float_workspace_buffer, @@ -1863,7 +1757,6 @@ def plan( head_dim_vo, causal, ) - self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction @@ -1879,8 +1772,8 @@ def plan( def forward( self, - q: torch.Tensor, - paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + q: paddle.Tensor, + paged_kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, @@ -1891,8 +1784,8 @@ def forward( sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - ) -> torch.Tensor: - r"""Warning: This function is deprecated, please use :meth:`run` instead.""" + ) -> paddle.Tensor: + """Warning: This function is deprecated, please use :meth:`run` instead.""" self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction @@ -1906,49 +1799,51 @@ def forward( @overload def run( self, - q: torch.Tensor, - paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + q: paddle.Tensor, + paged_kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], *args, k_scale: Optional[float] = None, v_scale: Optional[float] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: Literal[False] = False, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, - ) -> torch.Tensor: ... + ) -> paddle.Tensor: + ... 
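The trtllm-gen branch of `plan` shown above flattens `paged_kv_indices` into a zero-padded `[batch_size, max_num_blocks_per_seq]` block table. A small list-based sketch of that rearrangement under made-up sequence lengths (the real code writes into a paddle int32 tensor on the device):

```python
page_size = 16
kv_lens = [40, 16, 70]                           # tokens per request
paged_kv_indices = [7, 3, 9, 5, 0, 2, 4, 8, 1]   # flat page ids, one run per request

blocks_per_seq = [(n + page_size - 1) // page_size for n in kv_lens]  # [3, 1, 5]
max_blocks = max(blocks_per_seq)

block_tables = []
block_id = 0
for need in blocks_per_seq:
    row = paged_kv_indices[block_id:block_id + need]
    block_tables.append(row + [0] * (max_blocks - need))  # zero-pad to max width
    block_id += need

for row in block_tables:
    print(row)
# [7, 3, 9, 0, 0]
# [5, 0, 0, 0, 0]
# [0, 2, 4, 8, 1]
```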
@overload def run( self, - q: torch.Tensor, - paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + q: paddle.Tensor, + paged_kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], *args, k_scale: Optional[float] = None, v_scale: Optional[float] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: Literal[True] = True, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: ... + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + ... def run( self, - q: torch.Tensor, - paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + q: paddle.Tensor, + paged_kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], *args, q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: bool = False, enable_pdl: Optional[bool] = None, window_left: Optional[int] = None, - sinks: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Compute batch prefill/append attention between query and paged kv-cache. + sinks: Optional[paddle.Tensor] = None, + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Compute batch prefill/append attention between query and paged kv-cache. Parameters ---------- @@ -1993,22 +1888,20 @@ def run( * The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads]``. """ if enable_pdl is None: - enable_pdl = device_support_pdl(q.device) + enable_pdl = device_support_pdl(q.place) k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) _check_cached_qkv_data_type( q, k_cache, self._cached_q_data_type, self._cached_kv_data_type ) - stride_block = k_cache.stride(0) + stride_block = k_cache.get_strides()[0] if self._kv_layout == "NHD": - page_size = k_cache.shape[1] - stride_n = k_cache.stride(1) + page_size = tuple(k_cache.shape)[1] + stride_n = k_cache.get_strides()[1] else: - page_size = k_cache.shape[2] - stride_n = k_cache.stride(2) + page_size = tuple(k_cache.shape)[2] + stride_n = k_cache.get_strides()[2] window_left = self._window_left if window_left is None else window_left if self._backend != "trtllm-gen": - # NOTE(Siyuan): since window_left is appeared in the plan function, we need to make sure it is the same as the one in the plan function. - # Remove this check if the backend supports dynamic window_left. 
assert window_left == self._window_left logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale @@ -2017,7 +1910,7 @@ def run( if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: - sm_scale = 1.0 / math.sqrt(q.size(-1)) + sm_scale = 1.0 / math.sqrt(q.shape[-1]) if q_scale is not None: sm_scale *= q_scale if k_scale is not None: @@ -2025,66 +1918,58 @@ def run( if rope_scale is None: rope_scale = 1.0 if rope_theta is None: - rope_theta = 1e4 + rope_theta = 10000.0 if return_lse: if lse is None: - lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device - ) + lse = paddle.empty(shape=(q.shape[0], q.shape[1]), dtype="float32") else: check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" + lse, (q.shape[0], q.shape[1]), "float32", q.place, "lse" ) - if out is None: - out = torch.empty( - q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.device + out = paddle.empty( + shape=tuple(q.shape)[:-1] + tuple(v_cache.shape)[-1:], dtype=q.dtype ) else: check_shape_dtype_device( - out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out" + out, + tuple(q.shape)[:-1] + tuple(v_cache.shape)[-1:], + q.dtype, + q.place, + "out", ) - if self._custom_mask_buf is not None: mask_mode = MaskMode.CUSTOM.value + elif self._causal: + mask_mode = MaskMode.CAUSAL.value else: - if self._causal: - mask_mode = MaskMode.CAUSAL.value - else: - mask_mode = MaskMode.NON_CAUSAL.value - + mask_mode = MaskMode.NON_CAUSAL.value if self._prefix_len_ptr is not None: mask_mode = MaskMode.MULTIITEMSCORING.value - if self._backend == "fa3": - # NOTE(Zihao): we divide both stride_block and stride_n by stride_n - # because we will multiply stride_n back in the kernel sparse_indices = block_sparse_indices_to_vector_sparse_offsets( self._paged_kv_indices_buf, self._paged_kv_indptr_buf, - self._vector_sparse_indices_buffer, # output + self._vector_sparse_indices_buffer, self._vector_sparse_indptr_buffer, self._kv_lens_buffer, stride_block // stride_n, - 1, # stride_n // stride_n + 1, page_size, ) sparse_indptr = self._vector_sparse_indptr_buffer else: sparse_indices = self._paged_kv_indices_buf sparse_indptr = self._paged_kv_indptr_buf - if self._backend == "cudnn": if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1: self._seq_lens_q = self._seq_lens_q.reshape(self._batch_size, 1, 1, 1) - if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1: self._seq_lens_kv = self._seq_lens_kv.reshape(self._batch_size, 1, 1, 1) - cudnn_batch_prefill_with_kv_cache( q, - k_cache, # Need to be changed - v_cache, # Need to be changed + k_cache, + v_cache, self._sm_scale, self._float_workspace_buffer, actual_seq_lens_q=self._seq_lens_q, @@ -2126,15 +2011,15 @@ def run( run_args += [ self._custom_mask_buf, self._mask_indptr_buf, - _get_cache_alibi_slopes_buf(q.shape[1], q.device), + _get_cache_alibi_slopes_buf(tuple(q.shape)[1], q.place), self._prefix_len_ptr, self._token_pos_in_items_ptr, self._max_item_len_ptr, logits_soft_cap, sm_scale, - None, # scale_q, not supported yet - None, # scale_k - None, # scale_v + None, + None, + None, rope_scale, rope_theta, self._token_pos_in_items_len, @@ -2150,13 +2035,11 @@ def run( self._vector_sparse_indptr_buffer, sinks, ] - assert self._cached_module is not None, "cached module is not initialized" self._cached_module.paged_run(*run_args) if v_scale is not None: - # TODO(Zihao): fused into kernel if is_float8(out): - out = (out.to(torch.float32) * v_scale).to(out.dtype) + out = 
(out.to("float32") * v_scale).to(out.dtype) else: out *= v_scale return (out, lse) if return_lse else out @@ -2165,8 +2048,8 @@ def run( def forward_return_lse( self, - q: torch.Tensor, - paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + q: paddle.Tensor, + paged_kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, @@ -2177,8 +2060,8 @@ def forward_return_lse( sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Warning: This function is deprecated, please use :meth:`run_return_lse` instead.""" + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Warning: This function is deprecated, please use :meth:`run_return_lse` instead.""" self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction @@ -2190,26 +2073,25 @@ def forward_return_lse( return self.run_return_lse(q, paged_kv_cache, k_scale=k_scale, v_scale=v_scale) def end_forward(self) -> None: - r"""Warning: this function is deprecated and has no effect.""" + """Warning: this function is deprecated and has no effect.""" pass def _compute_mask_indptr( - qo_indptr: torch.Tensor, kv_indptr: torch.Tensor -) -> torch.Tensor: + qo_indptr: paddle.Tensor, kv_indptr: paddle.Tensor +) -> paddle.Tensor: if len(qo_indptr) != len(kv_indptr): raise ValueError("The length of qo_indptr and kv_indptr should be the same.") - mask_indptr = torch.empty_like(qo_indptr) + mask_indptr = paddle.empty_like(x=qo_indptr) mask_indptr[0] = 0 - mask_indptr[1:] = torch.cumsum( - (qo_indptr[1:] - qo_indptr[:-1]) * (kv_indptr[1:] - kv_indptr[:-1]), - 0, + mask_indptr[1:] = paddle.cumsum( + x=(qo_indptr[1:] - qo_indptr[:-1]) * (kv_indptr[1:] - kv_indptr[:-1]), axis=0 ) return mask_indptr class BatchPrefillWithRaggedKVCacheWrapper: - r"""Wrapper class for prefill/append attention with ragged (tensor) kv-cache for + """Wrapper class for prefill/append attention with ragged (tensor) kv-cache for batch of requests. Check :ref:`our tutorial ` for ragged kv-cache layout. @@ -2301,18 +2183,18 @@ class BatchPrefillWithRaggedKVCacheWrapper: def __init__( self, - float_workspace_buffer: torch.Tensor, + float_workspace_buffer: paddle.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, - qo_indptr_buf: Optional[torch.Tensor] = None, - kv_indptr_buf: Optional[torch.Tensor] = None, - custom_mask_buf: Optional[torch.Tensor] = None, - mask_indptr_buf: Optional[torch.Tensor] = None, + qo_indptr_buf: Optional[paddle.Tensor] = None, + kv_indptr_buf: Optional[paddle.Tensor] = None, + custom_mask_buf: Optional[paddle.Tensor] = None, + mask_indptr_buf: Optional[paddle.Tensor] = None, backend: str = "auto", jit_args: Optional[List[Any]] = None, jit_kwargs: Optional[Dict[str, Any]] = None, ) -> None: - r"""Constructor of :class:`BatchPrefillWithRaggedKVCacheWrapper`. + """Constructor of :class:`BatchPrefillWithRaggedKVCacheWrapper`. 
Parameters ---------- @@ -2373,26 +2255,22 @@ def __init__( ) else: self._jit_module = None - self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer - self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) - self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, - dtype=torch.uint8, - pin_memory=True, - device="cpu", + self.device = float_workspace_buffer.place + self._int_workspace_buffer = paddle.empty( + shape=(8 * 1024 * 1024,), dtype="uint8" ) + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=tuple(self._int_workspace_buffer.shape), dtype="uint8" + ).pin_memory() self._use_cuda_graph = use_cuda_graph if use_cuda_graph: - if not torch.is_tensor(qo_indptr_buf): + if not paddle.is_tensor(x=qo_indptr_buf): raise ValueError( "qo_indptr_buf should be a torch.Tensor in cuda graph mode" ) - if not torch.is_tensor(kv_indptr_buf): + if not paddle.is_tensor(x=kv_indptr_buf): raise ValueError( "kv_indptr_buf should be a torch.Tensor in cuda graph mode" ) @@ -2403,9 +2281,6 @@ def __init__( len(kv_indptr_buf), self._fixed_batch_size ) ) - # NOTE(Zihao): do not check custom_mask_buf and mask_indptr_buf here, - # as they may not be used. - self._qo_indptr_buf = qo_indptr_buf self._kv_indptr_buf = kv_indptr_buf self._custom_mask_buf = custom_mask_buf @@ -2419,9 +2294,9 @@ def is_cuda_graph_enabled(self) -> bool: return self._use_cuda_graph def reset_workspace_buffer( - self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + self, float_workspace_buffer: paddle.Tensor, int_workspace_buffer ) -> None: - r"""Reset the workspace buffer. + """Reset the workspace buffer. Parameters ---------- @@ -2435,23 +2310,21 @@ def reset_workspace_buffer( """ self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=tuple(self._int_workspace_buffer.shape), dtype=self._int_workspace_buffer.dtype, - device="cpu", - pin_memory=True, - ) + ).pin_memory() def plan( self, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, + qo_indptr: paddle.Tensor, + kv_indptr: paddle.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim_qk: int, head_dim_vo: Optional[int] = None, - custom_mask: Optional[torch.Tensor] = None, - packed_custom_mask: Optional[torch.Tensor] = None, + custom_mask: Optional[paddle.Tensor] = None, + packed_custom_mask: Optional[paddle.Tensor] = None, causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, @@ -2460,15 +2333,15 @@ def plan( sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - q_data_type: Union[str, torch.dtype] = "float16", - kv_data_type: Optional[Union[str, torch.dtype]] = None, + q_data_type: Union[str, paddle.dtype] = "float16", + kv_data_type: Optional[Union[str, paddle.dtype]] = None, non_blocking: bool = True, - prefix_len_ptr: Optional[torch.Tensor] = None, - token_pos_in_items_ptr: Optional[torch.Tensor] = None, + prefix_len_ptr: Optional[paddle.Tensor] = None, + token_pos_in_items_ptr: Optional[paddle.Tensor] = None, token_pos_in_items_len: int = 0, - max_item_len_ptr: Optional[torch.Tensor] = None, + max_item_len_ptr: Optional[paddle.Tensor] = None, ) -> None: - r"""Plan batch prefill/append attention on Ragged 
KV-Cache for given problem specification. + """Plan batch prefill/append attention on Ragged KV-Cache for given problem specification. Parameters ---------- @@ -2520,7 +2393,7 @@ def plan( The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, + :math:`\\texttt{logits_soft_cap} \\times \\mathrm{tanh}(x / \\texttt{logits_soft_cap})`, where :math:`x` is the input logits. sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to @@ -2569,10 +2442,8 @@ def plan( kv_data_type = canonicalize_torch_dtype(kv_data_type) if head_dim_vo is None: head_dim_vo = head_dim_qk - if logits_soft_cap is None: logits_soft_cap = 0.0 - batch_size = len(qo_indptr) - 1 if len(kv_indptr) != batch_size + 1: raise ValueError( @@ -2581,70 +2452,57 @@ def plan( if custom_mask is not None or packed_custom_mask is not None: mask_indptr = _compute_mask_indptr(qo_indptr, kv_indptr) if packed_custom_mask is None and custom_mask is not None: - # create packed custom mask from custom mask packed_custom_mask, mask_indptr = segment_packbits( - custom_mask.contiguous().view(-1), - mask_indptr, - bitorder="little", + custom_mask.contiguous().view(-1), mask_indptr, bitorder="little" ) - - # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors qo_indptr_host = qo_indptr.to("cpu") kv_indptr_host = kv_indptr.to("cpu") - total_num_rows = qo_indptr_host[-1] - if self.is_cuda_graph_enabled: if self._max_total_num_rows is None: self._max_total_num_rows = total_num_rows elif total_num_rows > self._max_total_num_rows: raise ValueError( - "The total number of rows in qo_indptr {} in cuda graph mode cannot " - "exceed the number of rows set during initialization {}.".format( + "The total number of rows in qo_indptr {} in cuda graph mode cannot exceed the number of rows set during initialization {}.".format( total_num_rows, self._max_total_num_rows ) ) - if batch_size != self._fixed_batch_size: raise ValueError( - "The batch size should be fixed in cudagraph mode, the runtime batch size {} " - " mismatches the batch size set during initialization {}.".format( + "The batch size should be fixed in cudagraph mode, the runtime batch size {} mismatches the batch size set during initialization {}.".format( batch_size, self._fixed_batch_size ) ) - self._qo_indptr_buf.copy_(qo_indptr, non_blocking=non_blocking) - self._kv_indptr_buf.copy_(kv_indptr, non_blocking=non_blocking) + paddle.assign(qo_indptr, output=self._qo_indptr_buf) + paddle.assign(kv_indptr, output=self._kv_indptr_buf) if packed_custom_mask is not None: - if not torch.is_tensor(self._custom_mask_buf): + if not paddle.is_tensor(x=self._custom_mask_buf): raise ValueError( "custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation." ) - if not torch.is_tensor(self._mask_indptr_buf): + if not paddle.is_tensor(x=self._mask_indptr_buf): raise ValueError( "mask_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in the attention computation." 
) self._custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask - self._mask_indptr_buf.copy_(mask_indptr, non_blocking=non_blocking) + paddle.assign(mask_indptr, output=self._mask_indptr_buf) else: - self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=non_blocking) - self._kv_indptr_buf = kv_indptr.to(self.device, non_blocking=non_blocking) + self._qo_indptr_buf = qo_indptr.to(self.device, blocking=not non_blocking) + self._kv_indptr_buf = kv_indptr.to(self.device, blocking=not non_blocking) if packed_custom_mask is not None: self._custom_mask_buf = packed_custom_mask.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) self._mask_indptr_buf = mask_indptr.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) - self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type kv_len_arr = kv_indptr_host[1:] - kv_indptr_host[:-1] - self._prefix_len_ptr = prefix_len_ptr self._token_pos_in_items_ptr = token_pos_in_items_ptr self._token_pos_in_items_len = token_pos_in_items_len self._max_item_len_ptr = max_item_len_ptr - if self._jit_module is not None: self._cached_module = self._jit_module else: @@ -2653,11 +2511,10 @@ def plan( self.device, PosEncodingMode[pos_encoding_mode].value, use_fp16_qk_reduction, - self._custom_mask_buf is not None, # use_custom_mask + self._custom_mask_buf is not None, q_data_type, kv_data_type, ) - get_module_args = ( q_data_type, kv_data_type, @@ -2666,8 +2523,8 @@ def plan( head_dim_qk, head_dim_vo, PosEncodingMode[pos_encoding_mode].value, - window_left >= 0, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap + window_left >= 0, + logits_soft_cap > 0, use_fp16_qk_reduction, ) if self._backend == "cutlass": @@ -2676,12 +2533,11 @@ def plan( self._cached_module = get_batch_prefill_module( self._backend, *get_module_args ) - if self._backend == "cutlass": self._plan_info = fmha_varlen_plan( self._cached_module, qo_indptr, kv_indptr, num_qo_heads, causal ) - self._max_qo_len = torch.max(qo_indptr[1:] - qo_indptr[:-1]).item() + self._max_qo_len = paddle.max(x=qo_indptr[1:] - qo_indptr[:-1]).item() else: assert self._cached_module is not None, "cached module is not initialized" self._plan_info = self._cached_module.plan( @@ -2695,13 +2551,12 @@ def plan( batch_size, num_qo_heads, num_kv_heads, - 1, # page_size + 1, self.is_cuda_graph_enabled, head_dim_qk, head_dim_vo, causal, ) - self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction @@ -2715,9 +2570,9 @@ def plan( def forward( self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, @@ -2726,8 +2581,8 @@ def forward( sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - ) -> torch.Tensor: - r"""Warning: This function is deprecated, please use :meth:`run` instead.""" + ) -> paddle.Tensor: + """Warning: This function is deprecated, please use :meth:`run` instead.""" self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction @@ -2741,41 +2596,43 @@ def forward( @overload def run( self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, *args, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = 
None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: Literal[False] = False, enable_pdl: Optional[bool] = None, - ) -> torch.Tensor: ... + ) -> paddle.Tensor: + ... @overload def run( self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, *args, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: Literal[True] = True, enable_pdl: Optional[bool] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: ... + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + ... def run( self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, *args, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: bool = False, enable_pdl: Optional[bool] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Compute batch prefill/append attention between query and kv-cache stored as + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Compute batch prefill/append attention between query and kv-cache stored as ragged tensor. Parameters @@ -2807,11 +2664,10 @@ def run( * The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads]``. """ if enable_pdl is None: - enable_pdl = device_support_pdl(q.device) + enable_pdl = device_support_pdl(q.place) _check_cached_qkv_data_type( q, k, self._cached_q_data_type, self._cached_kv_data_type ) - window_left = self._window_left logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale @@ -2820,27 +2676,25 @@ def run( if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: - sm_scale = 1.0 / math.sqrt(q.size(-1)) + sm_scale = 1.0 / math.sqrt(q.shape[-1]) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: - rope_theta = 1e4 + rope_theta = 10000.0 if return_lse: if lse is None: - lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device - ) + lse = paddle.empty(shape=(q.shape[0], q.shape[1]), dtype="float32") else: check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" + lse, (q.shape[0], q.shape[1]), "float32", q.place, "lse" ) if out is None: - out = torch.empty( - q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device + out = paddle.empty( + shape=tuple(q.shape)[:-1] + tuple(v.shape)[-1:], dtype=q.dtype ) else: check_shape_dtype_device( - out, q.shape[:-1] + v.shape[-1:], q.dtype, q.device, "out" + out, tuple(q.shape)[:-1] + tuple(v.shape)[-1:], q.dtype, q.place, "out" ) if self._backend == "cutlass": out, lse = fmha_varlen( @@ -2857,24 +2711,19 @@ def run( lse=lse, ) return (out, lse) if return_lse else out - if is_float8(q): logging.warning( - "Our current prefill kernel implementation needs f16 input, the f8 inputs " - " are casted to f16, which could result in performance degradation." + "Our current prefill kernel implementation needs f16 input, the f8 inputs are casted to f16, which could result in performance degradation." 
) - q = q.to(torch.float16) - k = k.to(torch.float16) - v = v.to(torch.float16) - + q = q.to("float16") + k = k.to("float16") + v = v.to("float16") if self._custom_mask_buf is not None: mask_mode = MaskMode.CUSTOM.value + elif self._causal: + mask_mode = MaskMode.CAUSAL.value else: - if self._causal: - mask_mode = MaskMode.CAUSAL.value - else: - mask_mode = MaskMode.NON_CAUSAL.value - + mask_mode = MaskMode.NON_CAUSAL.value run_args = [ self._float_workspace_buffer, self._int_workspace_buffer, @@ -2897,7 +2746,7 @@ def run( run_args += [ self._custom_mask_buf, self._mask_indptr_buf, - _get_cache_alibi_slopes_buf(q.shape[1], self.device), + _get_cache_alibi_slopes_buf(tuple(q.shape)[1], self.device), self._prefix_len_ptr, self._token_pos_in_items_ptr, self._max_item_len_ptr, @@ -2907,7 +2756,6 @@ def run( rope_theta, self._token_pos_in_items_len, ] - assert self._cached_module is not None, "cached module is not initialized" self._cached_module.ragged_run(*run_args) return (out, lse) if return_lse else out @@ -2916,9 +2764,9 @@ def run( def forward_return_lse( self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, @@ -2927,8 +2775,8 @@ def forward_return_lse( sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Warning: This function is deprecated, please use :meth:`run_return_lse` instead.""" + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Warning: This function is deprecated, please use :meth:`run_return_lse` instead.""" self._causal = causal self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction @@ -2940,32 +2788,24 @@ def forward_return_lse( return self.run_return_lse(q, k, v) def end_forward(self) -> None: - r"""Warning: this function is deprecated and has no effect.""" + """Warning: this function is deprecated and has no effect.""" pass def fmha_varlen_plan( module, - qo_segment_offsets: torch.Tensor, - kv_segment_offsets: torch.Tensor, + qo_segment_offsets: paddle.Tensor, + kv_segment_offsets: paddle.Tensor, num_qo_heads: int, causal: bool, ): - num_ctas = torch.cuda.get_device_properties( - qo_segment_offsets.device + num_ctas = paddle.device.cuda.get_device_properties( + device=device2str(qo_segment_offsets.place) ).multi_processor_count - work_indptr = torch.empty( - num_ctas + 1, device=qo_segment_offsets.device, dtype=torch.int32 - ) - qo_tile_indices = torch.empty( - 131072, device=qo_segment_offsets.device, dtype=torch.int32 - ) - head_indices = torch.empty( - 131072, device=qo_segment_offsets.device, dtype=torch.int32 - ) - batch_indices = torch.empty( - 131072, device=qo_segment_offsets.device, dtype=torch.int32 - ) + work_indptr = paddle.empty(shape=num_ctas + 1, dtype="int32") + qo_tile_indices = paddle.empty(shape=[131072], dtype="int32") + head_indices = paddle.empty(shape=[131072], dtype="int32") + batch_indices = paddle.empty(shape=[131072], dtype="int32") module.plan( qo_segment_offsets, kv_segment_offsets, @@ -2973,119 +2813,100 @@ def fmha_varlen_plan( qo_tile_indices, head_indices, batch_indices, - 256, # qo_tile_size + 256, num_qo_heads, num_ctas, causal, ) - return ( - work_indptr, - qo_tile_indices, - head_indices, - batch_indices, - ) + return work_indptr, qo_tile_indices, head_indices, batch_indices @overload def fmha_varlen( - q: torch.Tensor, - k: torch.Tensor, 
- v: torch.Tensor, - qo_segment_offsets: torch.Tensor, - kv_segment_offsets: torch.Tensor, - plan_info: Optional[List[torch.Tensor]] = None, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qo_segment_offsets: paddle.Tensor, + kv_segment_offsets: paddle.Tensor, + plan_info: Optional[List[paddle.Tensor]] = None, max_qo_len: Optional[int] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, causal: bool = False, sm_scale: Optional[float] = None, return_lse: Literal[False] = False, -) -> torch.Tensor: ... +) -> paddle.Tensor: + ... @overload def fmha_varlen( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - qo_segment_offsets: torch.Tensor, - kv_segment_offsets: torch.Tensor, - plan_info: Optional[List[torch.Tensor]] = None, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qo_segment_offsets: paddle.Tensor, + kv_segment_offsets: paddle.Tensor, + plan_info: Optional[List[paddle.Tensor]] = None, max_qo_len: Optional[int] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, causal: bool = False, sm_scale: Optional[float] = None, return_lse: Literal[True] = True, -) -> Tuple[torch.Tensor, torch.Tensor]: ... +) -> Tuple[paddle.Tensor, paddle.Tensor]: + ... def fmha_varlen( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - qo_segment_offsets: torch.Tensor, - kv_segment_offsets: torch.Tensor, - plan_info: Optional[List[torch.Tensor]] = None, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qo_segment_offsets: paddle.Tensor, + kv_segment_offsets: paddle.Tensor, + plan_info: Optional[List[paddle.Tensor]] = None, max_qo_len: Optional[int] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, causal: bool = False, sm_scale: Optional[float] = None, return_lse: bool = False, -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: +) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: workspace_buffer = _get_cache_buf( - "fmha_varlen_cutlass_workspace", 32 * 1024 * 1024, q.device + "fmha_varlen_cutlass_workspace", 32 * 1024 * 1024, q.place ) module = get_fmha_module( q.dtype, k.dtype, v.dtype, - torch.int32, - q.shape[2], - v.shape[2], + "int32", + tuple(q.shape)[2], + tuple(v.shape)[2], PosEncodingMode.NONE.value, - False, # use_sliding_window - False, # use_logits_soft_cap + False, + False, ) - - nnz_qo, num_qo_heads, head_dim_qk = q.shape - nnz_kv, num_kv_heads, head_dim_vo = v.shape - + nnz_qo, num_qo_heads, head_dim_qk = tuple(q.shape) + nnz_kv, num_kv_heads, head_dim_vo = tuple(v.shape) mask_mode_code = 1 if causal else 0 if sm_scale is None: sm_scale = 1.0 / math.sqrt(head_dim_qk) - qo_total_len = nnz_qo if max_qo_len is None: - max_qo_len = torch.max(qo_segment_offsets[1:] - qo_segment_offsets[:-1]).item() - + max_qo_len = paddle.max( + x=qo_segment_offsets[1:] - qo_segment_offsets[:-1] + ).item() if plan_info is None: plan_info = fmha_varlen_plan( module, qo_segment_offsets, kv_segment_offsets, num_qo_heads, causal ) - - ( - work_indptr, - qo_tile_indices, - head_indices, - batch_indices, - ) = plan_info - + work_indptr, qo_tile_indices, head_indices, batch_indices = plan_info if out is None: - out = torch.empty( - qo_total_len + max(max_qo_len, 128), - num_qo_heads, - head_dim_vo, - device=q.device, + out 
= paddle.empty( + shape=[qo_total_len + max(max_qo_len, 128), num_qo_heads, head_dim_vo], dtype=q.dtype, )[max(max_qo_len, 128) :] - if lse is None and return_lse: - lse = torch.empty( - qo_total_len, num_qo_heads, device=q.device, dtype=torch.float32 - ) - + lse = paddle.empty(shape=[qo_total_len, num_qo_heads], dtype="float32") module.run( workspace_buffer, q, @@ -3107,7 +2928,6 @@ def fmha_varlen( head_dim_vo, max_qo_len, ) - return out, lse @@ -3120,11 +2940,11 @@ def get_trtllm_gen_fmha_module(): def trtllm_ragged_attention_deepseek( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - workspace_buffer: torch.Tensor, - seq_lens: torch.Tensor, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + workspace_buffer: paddle.Tensor, + seq_lens: paddle.Tensor, max_q_len: int, max_kv_len: int, bmm1_scale: float, @@ -3132,15 +2952,15 @@ def trtllm_ragged_attention_deepseek( o_sf_scale: float, batch_size: int, window_left: int, - cum_seq_lens_q: torch.Tensor, - cum_seq_lens_kv: torch.Tensor, + cum_seq_lens_q: paddle.Tensor, + cum_seq_lens_kv: paddle.Tensor, enable_pdl: bool, is_causal: bool, return_lse: bool, - attention_sinks: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + attention_sinks: Optional[paddle.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, +) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: """ Parameters ---------- @@ -3190,31 +3010,24 @@ def trtllm_ragged_attention_deepseek( If return_lse is True, the output will be a tuple of two tensors, the first is the output tensor, the second is the lse tensor. If return_lse is False, the output will be a single tensor. 
""" - assert query.shape[2] == 192 and key.shape[2] == 192 and value.shape[2] == 128, ( - "currently only support deepseek r1 192 query and 128 value" - ) - + assert ( + tuple(query.shape)[2] == 192 + and tuple(key.shape)[2] == 192 + and tuple(value.shape)[2] == 128 + ), "currently only support deepseek r1 192 query and 128 value" if enable_pdl is None: - enable_pdl = device_support_pdl(query.device) - + enable_pdl = device_support_pdl(query.place) run_func = get_trtllm_gen_fmha_module().trtllm_ragged_attention - sm_count = get_device_sm_count(query.device) + sm_count = get_device_sm_count(query.place) if out is None: - out = torch.empty( - query.shape[0], - query.shape[1], - value.shape[2], - device=query.device, + out = paddle.empty( + shape=[tuple(query.shape)[0], tuple(query.shape)[1], tuple(value.shape)[2]], dtype=query.dtype, ) if return_lse and lse is None: - lse = torch.empty( - query.shape[0], - query.shape[1], - device=query.device, - dtype=torch.float32, + lse = paddle.empty( + shape=[tuple(query.shape)[0], tuple(query.shape)[1]], dtype="float32" ) - run_func( out, query, @@ -3244,26 +3057,26 @@ def trtllm_ragged_attention_deepseek( def trtllm_batch_context_with_kv_cache( - query: torch.Tensor, - kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - workspace_buffer: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, + query: paddle.Tensor, + kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], + workspace_buffer: paddle.Tensor, + block_tables: paddle.Tensor, + seq_lens: paddle.Tensor, max_q_len: int, max_kv_len: int, bmm1_scale: float, bmm2_scale: float, batch_size: int, - cum_seq_lens_q: torch.Tensor, - cum_seq_lens_kv: torch.Tensor, + cum_seq_lens_q: paddle.Tensor, + cum_seq_lens_kv: paddle.Tensor, window_left: int = -1, - out: Optional[Union[torch.Tensor, FP4Tensor]] = None, - out_dtype: Optional[Union[torch.dtype, str]] = None, + out: Optional[Union[paddle.Tensor, FP4Tensor]] = None, + out_dtype: Optional[Union[paddle.dtype, str]] = None, o_sf_scale: Optional[float] = None, o_sf_vec_size: Optional[int] = None, enable_pdl: Optional[bool] = None, - sinks: Optional[List[torch.Tensor]] = None, -) -> Union[torch.Tensor, FP4Tensor]: + sinks: Optional[List[paddle.Tensor]] = None, +) -> Union[paddle.Tensor, FP4Tensor]: """ Parameters ---------- @@ -3311,94 +3124,73 @@ def trtllm_batch_context_with_kv_cache( out: Union[torch.Tensor, FP4Tensor] output torch.Tensor or FP4Tensor. """ - if enable_pdl is None: - enable_pdl = device_support_pdl(query.device) - + enable_pdl = device_support_pdl(query.place) if isinstance(kv_cache, tuple): k_cache, v_cache = kv_cache + elif tuple(kv_cache.shape)[1] == 1: + k_cache, v_cache = kv_cache, kv_cache else: - if kv_cache.shape[1] == 1: - k_cache, v_cache = kv_cache, kv_cache - else: - assert kv_cache.shape[1] == 2, ( - "When kv_cache is a single tensor, the second dimension must be 1 or 2" - ) - # NOTE(Zihao): unbind transforms [num_pages, 2, ...] 
to ([num_pages, ...], [num_pages, ...]) - # it doesn't change underlying storage - k_cache, v_cache = kv_cache.unbind(dim=1) - + assert ( + tuple(kv_cache.shape)[1] == 2 + ), "When kv_cache is a single tensor, the second dimension must be 1 or 2" + k_cache, v_cache = kv_cache.unbind(axis=1) run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_context - sm_count = get_device_sm_count(query.device) - - if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)): - assert query.dtype == torch.float8_e4m3fn, ( - "query must be fp8 when out_dtype is nvfp4." - ) + sm_count = get_device_sm_count(query.place) + if out_dtype == "nvfp4" or out_dtype is None and isinstance(out, FP4Tensor): + assert ( + query.dtype == paddle.float8_e4m3fn + ), "query must be fp8 when out_dtype is nvfp4." assert o_sf_scale is not None assert o_sf_vec_size in [None, 16], "only o_sf_vec_size = 16 is supported" o_sf_vec_size = o_sf_vec_size or 16 - - fp4_out_shape = query.shape[:-1] + (ceil_div(query.shape[-1], 2),) - + fp4_out_shape = tuple(query.shape)[:-1] + (ceil_div(tuple(query.shape)[-1], 2),) if isinstance(out, FP4Tensor): - fp4_out_scale_shape = ( - out.scale.shape[0], - round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4), + fp4_out_scale_shape = out.scale.shape[0], round_up( + tuple(query.shape)[1] * tuple(query.shape)[2] // o_sf_vec_size, 4 ) out_scale_factor = out.scale o_sf_start_index = out.scale_start_index out = out.data elif out is None: - fp4_out_scale_shape = ( - round_up(query.shape[0], 128), - round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4), + fp4_out_scale_shape = round_up(tuple(query.shape)[0], 128), round_up( + tuple(query.shape)[1] * tuple(query.shape)[2] // o_sf_vec_size, 4 ) - out_scale_factor = torch.empty( - fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device + out_scale_factor = paddle.empty( + shape=fp4_out_scale_shape, dtype=paddle.float8_e4m3fn ) o_sf_start_index = 0 - out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device) + out = paddle.empty(shape=fp4_out_shape, dtype="uint8") else: raise ValueError(f"Invalid out: {out}") - - assert isinstance(out, torch.Tensor) - - # Use uint8 as the container dtype to compliant with next fp4 gemm. - check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out") - + assert isinstance(out, paddle.Tensor) + check_shape_dtype_device(out, fp4_out_shape, "uint8", query.place, "out") check_shape_dtype_device( out_scale_factor, fp4_out_scale_shape, - torch.float8_e4m3fn, - query.device, + paddle.float8_e4m3fn, + query.place, "out_scale_factor", ) - - # Check o_sf_start_index is valid if ( o_sf_start_index < 0 - or o_sf_start_index + out.shape[0] > out_scale_factor.shape[0] + or o_sf_start_index + tuple(out.shape)[0] > tuple(out_scale_factor.shape)[0] ): raise ValueError( - f"o_sf_start_index is out of the valid range of out_scale_factor. " - f"o_sf_start_index={o_sf_start_index}, out.shape[0]={out.shape[0]}, " - f"out_scale_factor.shape[0]={out_scale_factor.shape[0]}" + f"o_sf_start_index is out of the valid range of out_scale_factor. 
o_sf_start_index={o_sf_start_index}, out.shape[0]={tuple(out.shape)[0]}, out_scale_factor.shape[0]={tuple(out_scale_factor.shape)[0]}" ) - - elif isinstance(out_dtype, torch.dtype) or out_dtype is None: + elif isinstance(out_dtype, paddle.dtype) or out_dtype is None: assert o_sf_scale is None assert o_sf_vec_size is None out_scale_factor = None o_sf_start_index = 0 out_dtype = out_dtype or query.dtype - if out_dtype not in (query.dtype, torch.float16, torch.bfloat16): + if out_dtype not in (query.dtype, "float16", "bfloat16"): raise ValueError(f"Unsupported out_dtype: {out_dtype}") - out = out if out is not None else torch.empty_like(query, dtype=out_dtype) - check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out") + out = out if out is not None else paddle.empty_like(x=query, dtype=out_dtype) + check_shape_dtype_device(out, tuple(query.shape), out_dtype, query.place, "out") else: raise ValueError(f"Invalid out_dtype: {out_dtype}") - run_func( out, out_scale_factor, @@ -3426,5 +3218,5 @@ def trtllm_batch_context_with_kv_cache( return ( out if out_dtype != "nvfp4" - else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) + else FP4Tensor(out, out_scale_factor, o_sf_start_index, tuple(query.shape)) ) diff --git a/flashinfer/profiler/__init__.py b/flashinfer/profiler/__init__.py index 92e1d76964..831d4d1583 100644 --- a/flashinfer/profiler/__init__.py +++ b/flashinfer/profiler/__init__.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,7 +19,6 @@ See the License for the specific language governing permissions and limitations under the License. """ - import argparse import csv import json @@ -21,7 +26,6 @@ from enum import Enum from typing import Any, Dict, List, Tuple -import torch from tg4perfetto import TraceGenerator @@ -40,62 +44,52 @@ def decode_tag(tag, num_blocks, num_groups): bits 12-23: block_group_idx bits 24-31: sm_id """ - sm_id = (tag >> 24) & 0xFF - block_group_idx = (tag >> 12) & 0xFFF - event_idx = (tag >> 2) & 0x3FF - event_type = tag & 0x3 + sm_id = tag >> 24 & 255 + block_group_idx = tag >> 12 & 4095 + event_idx = tag >> 2 & 1023 + event_type = tag & 3 block_idx = block_group_idx // num_groups group_idx = block_group_idx % num_groups return block_idx, group_idx, event_idx, event_type, sm_id def export_to_perfetto_trace( - profiler_buffer: torch.Tensor, - event_names: List[str], - file_name: str, + profiler_buffer: paddle.Tensor, event_names: List[str], file_name: str ) -> None: - assert profiler_buffer.dtype == torch.uint64 +>>>>>> assert profiler_buffer.dtype == torch.uint64 profiler_buffer_host = profiler_buffer.cpu() - num_blocks, num_groups = profiler_buffer_host[:1].view(dtype=torch.int32) + num_blocks, num_groups = profiler_buffer_host[:1].view(dtype="int32") num_blocks = int(num_blocks) num_groups = int(num_groups) - tgen = TraceGenerator(file_name) - pid_map = {} tid_map = {} track_map: Dict[Tuple[int, int, int], Any] = {} - for i in range(1, len(profiler_buffer_host)): if profiler_buffer_host[i] == 0: continue - tag, timestamp = profiler_buffer_host[i : i + 1].view(dtype=torch.uint32) +>>>>>> tag, timestamp = profiler_buffer_host[i : i + 1].view(dtype=torch.uint32) tag = int(tag) timestamp = int(timestamp) block_idx, group_idx, event_idx, event_type, sm_id = decode_tag( tag, num_blocks, num_groups ) - - # create trackers if block_idx not in pid_map: pid_map[block_idx] = tgen.create_group(f"sm_{sm_id}_block_{block_idx}") pid = pid_map[block_idx] 
if (block_idx, group_idx) not in tid_map: - tid_map[(block_idx, group_idx)] = pid.create_group(f"group_{group_idx}") - tid = tid_map[(block_idx, group_idx)] + tid_map[block_idx, group_idx] = pid.create_group(f"group_{group_idx}") + tid = tid_map[block_idx, group_idx] event = event_names[event_idx] - if (block_idx, group_idx, event_idx) in track_map: - track = track_map[(block_idx, group_idx, event_idx)] + track = track_map[block_idx, group_idx, event_idx] else: track = tid.create_track() - track_map[(block_idx, group_idx, event_idx)] = track - + track_map[block_idx, group_idx, event_idx] = track if event_type == EventType.kBegin.value: track.open(timestamp, event) elif event_type == EventType.kEnd.value: track.close(timestamp) elif event_type == EventType.kInstant.value: track.instant(timestamp, event) - tgen.flush() diff --git a/flashinfer/quantization.py b/flashinfer/quantization.py index af7f9bde57..bae465ab17 100644 --- a/flashinfer/quantization.py +++ b/flashinfer/quantization.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,12 +15,9 @@ See the License for the specific language governing permissions and limitations under the License. """ - import functools from typing import Tuple -import torch - from .jit import JitSpec from .jit import env as jit_env from .jit import gen_jit_spec @@ -41,21 +40,21 @@ def get_quantization_module(): @register_custom_op("flashinfer::packbits", mutates_args=()) -def _packbits(x: torch.Tensor, bitorder: str) -> torch.Tensor: - device = x.device - x = x.to(torch.bool) - y = torch.empty((x.size(0) + 7) // 8, dtype=torch.uint8, device=device) +def _packbits(x: paddle.Tensor, bitorder: str) -> paddle.Tensor: + device = x.place + x = x.to("bool") + y = paddle.empty(shape=(x.shape[0] + 7) // 8, dtype="uint8") get_quantization_module().packbits(x, bitorder, y) return y @register_fake_op("flashinfer::packbits") -def _fake_packbits(x: torch.Tensor, bitorder: str) -> torch.Tensor: - return torch.empty((x.size(0) + 7) // 8, dtype=torch.uint8, device=x.device) +def _fake_packbits(x: paddle.Tensor, bitorder: str) -> paddle.Tensor: + return paddle.empty(shape=(x.shape[0] + 7) // 8, dtype="uint8") -def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor: - r"""Pack the elements of a binary-valued array into bits in a uint8 array. +def packbits(x: paddle.Tensor, bitorder: str = "big") -> paddle.Tensor: + """Pack the elements of a binary-valued array into bits in a uint8 array. The semantics of this function is the same as `numpy.packbits `_. @@ -89,9 +88,9 @@ def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor: def segment_packbits( - x: torch.Tensor, indptr: torch.Tensor, bitorder: str = "big" -) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Pack a batch of binary-valued segments into bits in a uint8 array. + x: paddle.Tensor, indptr: paddle.Tensor, bitorder: str = "big" +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Pack a batch of binary-valued segments into bits in a uint8 array. For each segment, the semantics of this function is the same as `numpy.packbits `_. 
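# A minimal usage sketch for the packbits / segment_packbits entry points converted in
# this file. Assumptions not stated in the diff: a CUDA build of this Paddle port is
# importable and the GPU is the default place; numpy.packbits is used purely as a
# reference for the expected "big"-endian bit order.
import numpy as np
import paddle
from flashinfer.quantization import packbits, segment_packbits

bits = paddle.to_tensor([1, 0, 1, 1, 0, 0, 1, 0, 1], dtype="bool", place=paddle.CUDAPlace(0))
packed = packbits(bits, bitorder="big")  # uint8 tensor with ceil(9 / 8) = 2 bytes
ref = np.packbits(bits.numpy(), bitorder="big")
assert (packed.numpy() == ref).all()

# Segmented variant: indptr = [0, 4, 9] packs bits[0:4] and bits[4:9] independently and
# also returns a new indptr over the packed byte array.
indptr = paddle.to_tensor([0, 4, 9], dtype="int32", place=paddle.CUDAPlace(0))
packed_seg, packed_indptr = segment_packbits(bits, indptr, bitorder="big")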
@@ -136,13 +135,12 @@ def segment_packbits( """ seglen = indptr[1:] - indptr[:-1] packed_len = (seglen + 7) // 8 - indptr_new = torch.zeros(len(indptr), dtype=indptr.dtype, device=indptr.device) - indptr_new[1:] = torch.cumsum(packed_len, 0) + indptr_new = paddle.zeros(shape=len(indptr), dtype=indptr.dtype) + indptr_new[1:] = paddle.cumsum(x=packed_len, axis=0) output_nnzs = indptr_new[-1].item() - - device = x.device - indptr = indptr.to(torch.int32) - indptr_new = indptr_new.to(torch.int32) - y = torch.empty(output_nnzs, dtype=torch.uint8, device=device) + device = x.place + indptr = indptr.to("int32") + indptr_new = indptr_new.to("int32") + y = paddle.empty(shape=output_nnzs, dtype="uint8") get_quantization_module().segment_packbits(x, indptr, indptr_new, bitorder, y) return y, indptr_new diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 5b14d610c6..81821c9720 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,12 +19,9 @@ See the License for the specific language governing permissions and limitations under the License. """ - import functools from typing import Optional, Tuple -import torch - from .jit import JitSpec from .jit import env as jit_env from .jit import gen_jit_spec @@ -42,12 +45,12 @@ def get_rope_module(): @register_custom_op("flashinfer::apply_rope", mutates_args=("q_rope", "k_rope")) def _apply_rope( - q: torch.Tensor, - k: torch.Tensor, - q_rope: torch.Tensor, - k_rope: torch.Tensor, - indptr: torch.Tensor, - offsets: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + q_rope: paddle.Tensor, + k_rope: paddle.Tensor, + indptr: paddle.Tensor, + offsets: paddle.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, @@ -69,12 +72,12 @@ def _apply_rope( @register_fake_op("flashinfer::apply_rope") def _fake_apply_rope( - q: torch.Tensor, - k: torch.Tensor, - q_rope: torch.Tensor, - k_rope: torch.Tensor, - indptr: torch.Tensor, - offsets: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + q_rope: paddle.Tensor, + k_rope: paddle.Tensor, + indptr: paddle.Tensor, + offsets: paddle.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, @@ -85,12 +88,12 @@ def _fake_apply_rope( @register_custom_op("flashinfer::apply_llama31_rope", mutates_args=("q_rope", "k_rope")) def _apply_llama31_rope( - q: torch.Tensor, - k: torch.Tensor, - q_rope: torch.Tensor, - k_rope: torch.Tensor, - indptr: torch.Tensor, - offsets: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + q_rope: paddle.Tensor, + k_rope: paddle.Tensor, + indptr: paddle.Tensor, + offsets: paddle.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, @@ -118,12 +121,12 @@ def _apply_llama31_rope( @register_fake_op("flashinfer::apply_llama31_rope") def _fake_apply_llama31_rope( - q: torch.Tensor, - k: torch.Tensor, - q_rope: torch.Tensor, - k_rope: torch.Tensor, - indptr: torch.Tensor, - offsets: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + q_rope: paddle.Tensor, + k_rope: paddle.Tensor, + indptr: paddle.Tensor, + offsets: paddle.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, @@ -137,36 +140,28 @@ def _fake_apply_llama31_rope( @register_custom_op("flashinfer::apply_rope_pos_ids", mutates_args=("q_rope", "k_rope")) def _apply_rope_pos_ids( - q: torch.Tensor, - k: torch.Tensor, - q_rope: torch.Tensor, - k_rope: torch.Tensor, - pos_ids: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + q_rope: paddle.Tensor, + k_rope: 
paddle.Tensor, + pos_ids: paddle.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, ) -> None: get_rope_module().apply_rope_pos_ids( - q, - k, - q_rope, - k_rope, - pos_ids, - rotary_dim, - interleave, - rope_scale, - rope_theta, + q, k, q_rope, k_rope, pos_ids, rotary_dim, interleave, rope_scale, rope_theta ) @register_fake_op("flashinfer::apply_rope_pos_ids") def _fake_apply_rope_pos_ids( - q: torch.Tensor, - k: torch.Tensor, - q_rope: torch.Tensor, - k_rope: torch.Tensor, - pos_ids: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + q_rope: paddle.Tensor, + k_rope: paddle.Tensor, + pos_ids: paddle.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, @@ -180,16 +175,16 @@ def _fake_apply_rope_pos_ids( mutates_args=("q_rope_out", "k_rope_out", "q_nope_out", "k_nope_out"), ) def _mla_rope_quantize( - q_rope_in: torch.Tensor, - k_rope_in: torch.Tensor, - q_nope_in: torch.Tensor, - k_nope_in: torch.Tensor, - cos_sin_cache: torch.Tensor, - pos_ids: torch.Tensor, - q_rope_out: torch.Tensor, - k_rope_out: torch.Tensor, - q_nope_out: torch.Tensor, - k_nope_out: torch.Tensor, + q_rope_in: paddle.Tensor, + k_rope_in: paddle.Tensor, + q_nope_in: paddle.Tensor, + k_nope_in: paddle.Tensor, + cos_sin_cache: paddle.Tensor, + pos_ids: paddle.Tensor, + q_rope_out: paddle.Tensor, + k_rope_out: paddle.Tensor, + q_nope_out: paddle.Tensor, + k_nope_out: paddle.Tensor, quant_scale_q: float, quant_scale_kv: float, interleave: bool, @@ -213,16 +208,16 @@ def _mla_rope_quantize( @register_fake_op("flashinfer::mla_rope_quantize") def _fake_mla_rope_quantize( - q_rope_in: torch.Tensor, - k_rope_in: torch.Tensor, - q_nope_in: torch.Tensor, - k_nope_in: torch.Tensor, - cos_sin_cache: torch.Tensor, - pos_ids: torch.Tensor, - q_rope_out: torch.Tensor, - k_rope_out: torch.Tensor, - q_nope_out: torch.Tensor, - k_nope_out: torch.Tensor, + q_rope_in: paddle.Tensor, + k_rope_in: paddle.Tensor, + q_nope_in: paddle.Tensor, + k_nope_in: paddle.Tensor, + cos_sin_cache: paddle.Tensor, + pos_ids: paddle.Tensor, + q_rope_out: paddle.Tensor, + k_rope_out: paddle.Tensor, + q_nope_out: paddle.Tensor, + k_nope_out: paddle.Tensor, quant_scale_q: float, quant_scale_kv: float, interleave: bool, @@ -234,34 +229,28 @@ def _fake_mla_rope_quantize( "flashinfer::apply_rope_pos_ids_cos_sin_cache", mutates_args=("q_rope", "k_rope") ) def _apply_rope_pos_ids_cos_sin_cache( - q: torch.Tensor, - k: torch.Tensor, - q_rope: torch.Tensor, - k_rope: torch.Tensor, - cos_sin_cache: torch.Tensor, - pos_ids: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + q_rope: paddle.Tensor, + k_rope: paddle.Tensor, + cos_sin_cache: paddle.Tensor, + pos_ids: paddle.Tensor, interleave: bool, ) -> None: get_rope_module().apply_rope_pos_ids_cos_sin_cache( - q, - k, - q_rope, - k_rope, - cos_sin_cache, - pos_ids, - interleave, + q, k, q_rope, k_rope, cos_sin_cache, pos_ids, interleave ) @register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache") def _fake_apply_rope_pos_ids_cos_sin_cache( - q: torch.Tensor, - k: torch.Tensor, - q_rope: torch.Tensor, - k_rope: torch.Tensor, - cos_cache: torch.Tensor, - sin_cache: torch.Tensor, - pos_ids: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + q_rope: paddle.Tensor, + k_rope: paddle.Tensor, + cos_cache: paddle.Tensor, + sin_cache: paddle.Tensor, + pos_ids: paddle.Tensor, interleave: bool, ) -> None: pass @@ -271,11 +260,11 @@ def _fake_apply_rope_pos_ids_cos_sin_cache( "flashinfer::apply_llama31_rope_pos_ids", mutates_args=("q_rope", "k_rope") ) def 
_apply_llama31_rope_pos_ids( - q: torch.Tensor, - k: torch.Tensor, - q_rope: torch.Tensor, - k_rope: torch.Tensor, - pos_ids: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + q_rope: paddle.Tensor, + k_rope: paddle.Tensor, + pos_ids: paddle.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, @@ -302,11 +291,11 @@ def _apply_llama31_rope_pos_ids( @register_fake_op("flashinfer::apply_llama31_rope_pos_ids") def _fake_apply_llama31_rope_pos_ids( - q: torch.Tensor, - k: torch.Tensor, - q_rope: torch.Tensor, - k_rope: torch.Tensor, - pos_ids: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + q_rope: paddle.Tensor, + k_rope: paddle.Tensor, + pos_ids: paddle.Tensor, rotary_dim: int, interleave: bool, rope_scale: float, @@ -319,16 +308,16 @@ def _fake_apply_llama31_rope_pos_ids( def apply_rope_inplace( - q: torch.Tensor, - k: torch.Tensor, - indptr: torch.Tensor, - offsets: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + indptr: paddle.Tensor, + offsets: paddle.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 1, - rope_theta: float = 1e4, + rope_theta: float = 10000.0, ) -> None: - r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. + """Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th @@ -399,22 +388,22 @@ def apply_rope_inplace( apply_rope """ if rotary_dim is None: - rotary_dim = q.size(-1) + rotary_dim = q.shape[-1] _apply_rope( q, k, q, k, indptr, offsets, rotary_dim, interleave, rope_scale, rope_theta ) def apply_rope_pos_ids_inplace( - q: torch.Tensor, - k: torch.Tensor, - pos_ids: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + pos_ids: paddle.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 1, - rope_theta: float = 1e4, + rope_theta: float = 10000.0, ) -> None: - r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. + """Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th @@ -457,26 +446,26 @@ def apply_rope_pos_ids_inplace( apply_rope_pos_ids """ if rotary_dim is None: - rotary_dim = q.size(-1) + rotary_dim = q.shape[-1] _apply_rope_pos_ids( q, k, q, k, pos_ids, rotary_dim, interleave, rope_scale, rope_theta ) def apply_llama31_rope_inplace( - q: torch.Tensor, - k: torch.Tensor, - indptr: torch.Tensor, - offsets: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + indptr: paddle.Tensor, + offsets: paddle.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 8, - rope_theta: float = 5e5, + rope_theta: float = 500000.0, low_freq_factor: float = 1, high_freq_factor: float = 4, old_context_len: int = 8192, ) -> None: - r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as + """Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel. 
We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th @@ -553,7 +542,7 @@ def apply_llama31_rope_inplace( apply_llama31_rope """ if rotary_dim is None: - rotary_dim = q.size(-1) + rotary_dim = q.shape[-1] _apply_llama31_rope( q, k, @@ -572,18 +561,18 @@ def apply_llama31_rope_inplace( def apply_llama31_rope_pos_ids_inplace( - q: torch.Tensor, - k: torch.Tensor, - pos_ids: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + pos_ids: paddle.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 8, - rope_theta: float = 5e5, + rope_theta: float = 500000.0, low_freq_factor: float = 1, high_freq_factor: float = 4, old_context_len: int = 8192, ) -> None: - r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as + """Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th @@ -632,7 +621,7 @@ def apply_llama31_rope_pos_ids_inplace( apply_llama31_rope_pos_ids """ if rotary_dim is None: - rotary_dim = q.size(-1) + rotary_dim = q.shape[-1] _apply_llama31_rope_pos_ids( q, k, @@ -650,16 +639,16 @@ def apply_llama31_rope_pos_ids_inplace( def apply_rope( - q: torch.Tensor, - k: torch.Tensor, - indptr: torch.Tensor, - offsets: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + indptr: paddle.Tensor, + offsets: paddle.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 1, - rope_theta: float = 1e4, -) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor). + rope_theta: float = 10000.0, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor). cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th @@ -740,10 +729,10 @@ def apply_rope( -------- apply_rope_inplace """ - q_rope = torch.empty_like(q) - k_rope = torch.empty_like(k) + q_rope = paddle.empty_like(x=q) + k_rope = paddle.empty_like(x=k) if rotary_dim is None: - rotary_dim = q.size(-1) + rotary_dim = q.shape[-1] _apply_rope( q, k, @@ -760,15 +749,15 @@ def apply_rope( def apply_rope_pos_ids( - q: torch.Tensor, - k: torch.Tensor, - pos_ids: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + pos_ids: paddle.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 1, - rope_theta: float = 1e4, -) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor). + rope_theta: float = 10000.0, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor). cos/sin values are computed on the fly inside the kernel. 
We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th @@ -817,10 +806,10 @@ def apply_rope_pos_ids( -------- apply_rope_inplace """ - q_rope = torch.empty_like(q) - k_rope = torch.empty_like(k) + q_rope = paddle.empty_like(x=q) + k_rope = paddle.empty_like(x=k) if rotary_dim is None: - rotary_dim = q.size(-1) + rotary_dim = q.shape[-1] _apply_rope_pos_ids( q, k, q_rope, k_rope, pos_ids, rotary_dim, interleave, rope_scale, rope_theta ) @@ -828,19 +817,19 @@ def apply_rope_pos_ids( def apply_llama31_rope( - q: torch.Tensor, - k: torch.Tensor, - indptr: torch.Tensor, - offsets: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + indptr: paddle.Tensor, + offsets: paddle.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 8, - rope_theta: float = 5e5, + rope_theta: float = 500000.0, low_freq_factor: float = 1, high_freq_factor: float = 4, old_context_len: int = 8192, -) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as RaggedTensor). cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th @@ -927,10 +916,10 @@ def apply_llama31_rope( -------- apply_llama31_rope_inplace """ - q_rope = torch.empty_like(q) - k_rope = torch.empty_like(k) + q_rope = paddle.empty_like(x=q) + k_rope = paddle.empty_like(x=k) if rotary_dim is None: - rotary_dim = q.size(-1) + rotary_dim = q.shape[-1] _apply_llama31_rope( q, k, @@ -950,18 +939,18 @@ def apply_llama31_rope( def apply_llama31_rope_pos_ids( - q: torch.Tensor, - k: torch.Tensor, - pos_ids: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + pos_ids: paddle.Tensor, rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 8, - rope_theta: float = 5e5, + rope_theta: float = 500000.0, low_freq_factor: float = 1, high_freq_factor: float = 4, old_context_len: int = 8192, -) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as RaggedTensor). cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th @@ -1015,10 +1004,10 @@ def apply_llama31_rope_pos_ids( -------- apply_llama31_rope_pos_ids_inplace """ - q_rope = torch.empty_like(q) - k_rope = torch.empty_like(k) + q_rope = paddle.empty_like(x=q) + k_rope = paddle.empty_like(x=k) if rotary_dim is None: - rotary_dim = q.size(-1) + rotary_dim = q.shape[-1] _apply_llama31_rope_pos_ids( q, k, @@ -1037,14 +1026,14 @@ def apply_llama31_rope_pos_ids( def apply_rope_with_cos_sin_cache( - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, + positions: paddle.Tensor, + query: paddle.Tensor, + key: paddle.Tensor, head_size: int, - cos_sin_cache: torch.Tensor, + cos_sin_cache: paddle.Tensor, is_neox: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor]: - r""" +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """ Apply rotary embedding to keys and queries with precomputed cos/sin values. This is designed to be compatible with the SGL/vLLM implementation. 
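# A minimal sketch of calling the apply_rope_with_cos_sin_cache variant converted in the
# next hunk. Assumptions beyond what this diff shows: query/key are 2-D
# (num_tokens, num_heads * head_size) tensors, positions are integer token positions, and
# cos_sin_cache follows the SGL/vLLM layout (max_position, rotary_dim) with cos in the
# first half of the last dimension and sin in the second half; the cache itself must be
# float32, as the function asserts.
import math

import paddle
from flashinfer.rope import apply_rope_with_cos_sin_cache

num_tokens, num_heads, head_size, max_pos = 8, 4, 64, 4096
rotary_dim = head_size

inv_freq = paddle.exp(
    paddle.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim * -math.log(10000.0)
)
t = paddle.arange(max_pos, dtype="float32")
freqs = paddle.outer(t, inv_freq)  # (max_pos, rotary_dim // 2)
cos_sin_cache = paddle.concat([paddle.cos(freqs), paddle.sin(freqs)], axis=-1)  # float32

query = paddle.randn([num_tokens, num_heads * head_size]).astype("float16")
key = paddle.randn([num_tokens, num_heads * head_size]).astype("float16")
positions = paddle.arange(num_tokens, dtype="int32")

q_out, k_out = apply_rope_with_cos_sin_cache(
    positions, query, key, head_size, cos_sin_cache, is_neox=True
)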
@@ -1080,34 +1069,31 @@ def apply_rope_with_cos_sin_cache( ---- The rotary dimension is determined by the cosine cache and sine cache. """ - if cos_sin_cache.dtype != torch.float32: + if cos_sin_cache.dtype != "float32": raise ValueError("cos_sin_cache should be float32") - - query_out = torch.empty_like(query) - key_out = torch.empty_like(key) - + query_out = paddle.empty_like(x=query) + key_out = paddle.empty_like(x=key) _apply_rope_pos_ids_cos_sin_cache( - q=query.view(query.shape[0], -1, head_size), - k=key.view(key.shape[0], -1, head_size), - q_rope=query_out.view(query_out.shape[0], -1, head_size), - k_rope=key_out.view(key_out.shape[0], -1, head_size), + q=query.view(tuple(query.shape)[0], -1, head_size), + k=key.view(tuple(key.shape)[0], -1, head_size), + q_rope=query_out.view(tuple(query_out.shape)[0], -1, head_size), + k_rope=key_out.view(tuple(key_out.shape)[0], -1, head_size), cos_sin_cache=cos_sin_cache, pos_ids=positions, - interleave=(not is_neox), + interleave=not is_neox, ) - return query_out, key_out def apply_rope_with_cos_sin_cache_inplace( - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, + positions: paddle.Tensor, + query: paddle.Tensor, + key: paddle.Tensor, head_size: int, - cos_sin_cache: torch.Tensor, + cos_sin_cache: paddle.Tensor, is_neox: bool = True, ) -> None: - r""" + """ Apply rotary embedding to keys and queries with precomputed cos/sin values. This is designed to be compatible with the SGL/vLLM implementation. The result is inplace applied to the input tensors. @@ -1136,71 +1122,64 @@ def apply_rope_with_cos_sin_cache_inplace( ---- The rotary dimension is determined by the cosine cache and sine cache. """ - if cos_sin_cache.dtype != torch.float32: + if cos_sin_cache.dtype != "float32": raise ValueError("cos_sin_cache should be float32") - - # pass q_rope and k_rope as q and k to perform inplace operation _apply_rope_pos_ids_cos_sin_cache( - q=query.view(query.shape[0], -1, head_size), - k=key.view(key.shape[0], -1, head_size), - q_rope=query.view(query.shape[0], -1, head_size), - k_rope=key.view(key.shape[0], -1, head_size), + q=query.view(tuple(query.shape)[0], -1, head_size), + k=key.view(tuple(key.shape)[0], -1, head_size), + q_rope=query.view(tuple(query.shape)[0], -1, head_size), + k_rope=key.view(tuple(key.shape)[0], -1, head_size), cos_sin_cache=cos_sin_cache, pos_ids=positions, - interleave=(not is_neox), + interleave=not is_neox, ) def mla_rope_quantize_fp8( - q_rope: torch.Tensor, - k_rope: torch.Tensor, - q_nope: torch.Tensor, - k_nope: torch.Tensor, - cos_sin_cache: torch.Tensor, - pos_ids: torch.Tensor, + q_rope: paddle.Tensor, + k_rope: paddle.Tensor, + q_nope: paddle.Tensor, + k_nope: paddle.Tensor, + cos_sin_cache: paddle.Tensor, + pos_ids: paddle.Tensor, is_neox: bool = True, - quantize_dtype: Optional[torch.dtype] = None, + quantize_dtype: Optional[paddle.dtype] = None, quant_scale_q: float = 1.0, quant_scale_kv: float = 1.0, - q_rope_out: Optional[torch.Tensor] = None, - k_rope_out: Optional[torch.Tensor] = None, - q_nope_out: Optional[torch.Tensor] = None, - k_nope_out: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - if cos_sin_cache.dtype != torch.float32: + q_rope_out: Optional[paddle.Tensor] = None, + k_rope_out: Optional[paddle.Tensor] = None, + q_nope_out: Optional[paddle.Tensor] = None, + k_nope_out: Optional[paddle.Tensor] = None, +) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + if cos_sin_cache.dtype != "float32": raise 
ValueError("cos_sin_cache should be float32") - - # Infer quantize_dtype from output tensors or default to float8_e4m3fn if quantize_dtype is None: for out in (q_rope_out, k_rope_out, q_nope_out, k_nope_out): if out is not None: quantize_dtype = out.dtype break else: - quantize_dtype = torch.float8_e4m3fn - - # Allocate output tensors if not provided + quantize_dtype = paddle.float8_e4m3fn q_rope_out = ( q_rope_out if q_rope_out is not None - else torch.empty_like(q_rope, dtype=quantize_dtype) + else paddle.empty_like(x=q_rope, dtype=quantize_dtype) ) k_rope_out = ( k_rope_out if k_rope_out is not None - else torch.empty_like(k_rope, dtype=quantize_dtype) + else paddle.empty_like(x=k_rope, dtype=quantize_dtype) ) q_nope_out = ( q_nope_out if q_nope_out is not None - else torch.empty_like(q_nope, dtype=quantize_dtype) + else paddle.empty_like(x=q_nope, dtype=quantize_dtype) ) k_nope_out = ( k_nope_out if k_nope_out is not None - else torch.empty_like(k_nope, dtype=quantize_dtype) + else paddle.empty_like(x=k_nope, dtype=quantize_dtype) ) - _mla_rope_quantize( q_rope, k_rope, @@ -1214,7 +1193,6 @@ def mla_rope_quantize_fp8( k_nope_out, quant_scale_q, quant_scale_kv, - not is_neox, # interleave + not is_neox, ) - return q_rope_out, k_rope_out, q_nope_out, k_nope_out diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 4cd7e5bd5a..d82ced83b4 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,22 +17,15 @@ See the License for the specific language governing permissions and limitations under the License. """ - import functools from types import SimpleNamespace from typing import Optional, Union -import torch - from .jit import JitSpec from .jit import env as jit_env from .jit import gen_jit_spec -from .utils import ( - _get_cache_buf, - device_support_pdl, - register_custom_op, - register_fake_op, -) +from .utils import (_get_cache_buf, device_support_pdl, register_custom_op, + register_fake_op) def gen_sampling_module() -> JitSpec: @@ -48,16 +45,18 @@ def get_sampling_module(): @register_custom_op("flashinfer::softmax", mutates_args=("workspace_buffer",)) def softmax( - workspace_buffer: torch.Tensor, - logits: torch.Tensor, - maybe_temperature_arr: Optional[torch.Tensor], + workspace_buffer: paddle.Tensor, + logits: paddle.Tensor, + maybe_temperature_arr: Optional[paddle.Tensor], temperature_val: float, enable_pdl: bool, - ) -> torch.Tensor: - logits = logits.float() - probs = torch.empty_like(logits, device=logits.device) + ) -> paddle.Tensor: + logits = logits.astype(dtype="float32") + probs = paddle.empty_like(x=logits) maybe_temperature_arr = ( - maybe_temperature_arr.float() if maybe_temperature_arr is not None else None + maybe_temperature_arr.astype(dtype="float32") + if maybe_temperature_arr is not None + else None ) module.softmax.default( workspace_buffer, @@ -71,99 +70,84 @@ def softmax( @register_fake_op("flashinfer::softmax") def _fake_softmax( - workspace_buffer: torch.Tensor, - logits: torch.Tensor, - maybe_temperature_arr: Optional[torch.Tensor], + workspace_buffer: paddle.Tensor, + logits: paddle.Tensor, + maybe_temperature_arr: Optional[paddle.Tensor], temperature_val: float, enable_pdl: bool, - ) -> torch.Tensor: - return torch.empty_like(logits, device=logits.device, dtype=torch.float32) + ) -> paddle.Tensor: + return paddle.empty_like(x=logits, dtype="float32") - # torch library for sampling_from_logits 
@register_custom_op("flashinfer::sampling_from_logits", mutates_args=()) def sampling_from_logits( - logits: torch.Tensor, - indices: Optional[torch.Tensor], + logits: paddle.Tensor, + indices: Optional[paddle.Tensor], deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - device = logits.device - # TODO: support more data types in logits to avoid conversion - # to float32 - logits = logits.float() - batch_size = indices.size(0) if indices is not None else logits.size(0) - samples = torch.empty(batch_size, dtype=torch.int32, device=device) + ) -> paddle.Tensor: + device = logits.place + logits = logits.astype(dtype="float32") + batch_size = indices.shape[0] if indices is not None else logits.shape[0] + samples = paddle.empty(shape=batch_size, dtype="int32") module.sampling_from_logits.default( - logits, - samples, - indices, - deterministic, - generator, + logits, samples, indices, deterministic, generator ) return samples @register_fake_op("flashinfer::sampling_from_logits") def _fake_sampling_from_logits( - logits: torch.Tensor, - indices: Optional[torch.Tensor], + logits: paddle.Tensor, + indices: Optional[paddle.Tensor], deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - batch_size = indices.size(0) if indices is not None else logits.size(0) - return torch.empty(batch_size, dtype=torch.int32, device=logits.device) - - # torch library for sampling_from_probs + ) -> paddle.Tensor: + batch_size = indices.shape[0] if indices is not None else logits.shape[0] + return paddle.empty(shape=batch_size, dtype="int32") @register_custom_op("flashinfer::sampling_from_probs", mutates_args=()) def sampling_from_probs( - probs: torch.Tensor, - indices: Optional[torch.Tensor], + probs: paddle.Tensor, + indices: Optional[paddle.Tensor], deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - device = probs.device - probs = probs.float() - batch_size = indices.size(0) if indices is not None else probs.size(0) - samples = torch.empty(batch_size, dtype=torch.int32, device=device) + ) -> paddle.Tensor: + device = probs.place + probs = probs.astype(dtype="float32") + batch_size = indices.shape[0] if indices is not None else probs.shape[0] + samples = paddle.empty(shape=batch_size, dtype="int32") module.sampling_from_probs.default( - probs, - samples, - indices, - deterministic, - generator, + probs, samples, indices, deterministic, generator ) return samples - # torch library for sampling_from_probs - @register_fake_op("flashinfer::sampling_from_probs") def _fake_sampling_from_probs( - probs: torch.Tensor, - indices: Optional[torch.Tensor], + probs: paddle.Tensor, + indices: Optional[paddle.Tensor], deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - batch_size = indices.size(0) if indices is not None else probs.size(0) - return torch.empty(batch_size, dtype=torch.int32, device=probs.device) - - # torch library for top_p_sampling_from_probs + ) -> paddle.Tensor: + batch_size = indices.shape[0] if indices is not None else probs.shape[0] + return paddle.empty(shape=batch_size, dtype="int32") @register_custom_op("flashinfer::top_p_sampling_from_probs", mutates_args=()) def top_p_sampling_from_probs( - probs: torch.Tensor, - indices: Optional[torch.Tensor], - maybe_top_p_arr: Optional[torch.Tensor], + probs: paddle.Tensor, + indices: Optional[paddle.Tensor], + maybe_top_p_arr: Optional[paddle.Tensor], top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - 
device = probs.device - probs = probs.float() + ) -> paddle.Tensor: + device = probs.place + probs = probs.astype(dtype="float32") maybe_top_p_arr = ( - maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + maybe_top_p_arr.astype(dtype="float32") + if maybe_top_p_arr is not None + else None ) - batch_size = indices.size(0) if indices is not None else probs.size(0) - samples = torch.empty(batch_size, dtype=torch.int32, device=device) + batch_size = indices.shape[0] if indices is not None else probs.shape[0] + samples = paddle.empty(shape=batch_size, dtype="int32") module.top_p_sampling_from_probs.default( probs, samples, @@ -177,32 +161,34 @@ def top_p_sampling_from_probs( @register_fake_op("flashinfer::top_p_sampling_from_probs") def _fake_top_p_sampling_from_probs( - probs: torch.Tensor, - indices: Optional[torch.Tensor], - maybe_top_p_arr: Optional[torch.Tensor], + probs: paddle.Tensor, + indices: Optional[paddle.Tensor], + maybe_top_p_arr: Optional[paddle.Tensor], top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - sample = torch.empty(probs.size(0), dtype=torch.int32, device=probs.device) + ) -> paddle.Tensor: + sample = paddle.empty(shape=probs.shape[0], dtype="int32") return sample - # torch library for top_k_sampling_from_probs - @register_custom_op("flashinfer::top_k_sampling_from_probs", mutates_args=()) def top_k_sampling_from_probs( - probs: torch.Tensor, - indices: Optional[torch.Tensor], - maybe_top_k_arr: Optional[torch.Tensor], + probs: paddle.Tensor, + indices: Optional[paddle.Tensor], + maybe_top_k_arr: Optional[paddle.Tensor], top_k_val: int, deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - device = probs.device - probs = probs.float() - batch_size = indices.size(0) if indices is not None else probs.size(0) - maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None - samples = torch.empty(batch_size, dtype=torch.int32, device=device) + ) -> paddle.Tensor: + device = probs.place + probs = probs.astype(dtype="float32") + batch_size = indices.shape[0] if indices is not None else probs.shape[0] + maybe_top_k_arr = ( + maybe_top_k_arr.astype(dtype="int32") + if maybe_top_k_arr is not None + else None + ) + samples = paddle.empty(shape=batch_size, dtype="int32") module.top_k_sampling_from_probs.default( probs, samples, @@ -216,35 +202,35 @@ def top_k_sampling_from_probs( @register_fake_op("flashinfer::top_k_sampling_from_probs") def _fake_top_k_sampling_from_probs( - probs: torch.Tensor, - indices: Optional[torch.Tensor], - maybe_top_k_arr: Optional[torch.Tensor], + probs: paddle.Tensor, + indices: Optional[paddle.Tensor], + maybe_top_k_arr: Optional[paddle.Tensor], top_k_val: int, deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - batch_size = indices.size(0) if indices is not None else probs.size(0) - sample = torch.empty(batch_size, dtype=torch.int32, device=probs.device) + ) -> paddle.Tensor: + batch_size = indices.shape[0] if indices is not None else probs.shape[0] + sample = paddle.empty(shape=batch_size, dtype="int32") return sample - # torch library for min_p_sampling_from_probs - @register_custom_op("flashinfer::min_p_sampling_from_probs", mutates_args=()) def min_p_sampling_from_probs( - probs: torch.Tensor, - indices: Optional[torch.Tensor], - maybe_min_p_arr: Optional[torch.Tensor], + probs: paddle.Tensor, + indices: Optional[paddle.Tensor], + maybe_min_p_arr: Optional[paddle.Tensor], min_p_val: float, deterministic: 
bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - device = probs.device - probs = probs.float() + ) -> paddle.Tensor: + device = probs.place + probs = probs.astype(dtype="float32") maybe_min_p_arr = ( - maybe_min_p_arr.float() if maybe_min_p_arr is not None else None + maybe_min_p_arr.astype(dtype="float32") + if maybe_min_p_arr is not None + else None ) - batch_size = indices.size(0) if indices is not None else probs.size(0) - samples = torch.empty(batch_size, dtype=torch.int32, device=device) + batch_size = indices.shape[0] if indices is not None else probs.shape[0] + samples = paddle.empty(shape=batch_size, dtype="int32") module.min_p_sampling_from_probs.default( probs, samples, @@ -256,27 +242,31 @@ def min_p_sampling_from_probs( ) return samples - # torch library for top_k_top_p_sampling_from_probs - @register_custom_op("flashinfer::top_k_top_p_sampling_from_probs", mutates_args=()) def top_k_top_p_sampling_from_probs( - probs: torch.Tensor, - indices: Optional[torch.Tensor], - maybe_top_k_arr: Optional[torch.Tensor], + probs: paddle.Tensor, + indices: Optional[paddle.Tensor], + maybe_top_k_arr: Optional[paddle.Tensor], top_k_val: int, - maybe_top_p_arr: Optional[torch.Tensor], + maybe_top_p_arr: Optional[paddle.Tensor], top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - device = probs.device - probs = probs.float() - maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + ) -> paddle.Tensor: + device = probs.place + probs = probs.astype(dtype="float32") + maybe_top_k_arr = ( + maybe_top_k_arr.astype(dtype="int32") + if maybe_top_k_arr is not None + else None + ) maybe_top_p_arr = ( - maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + maybe_top_p_arr.astype(dtype="float32") + if maybe_top_p_arr is not None + else None ) - batch_size = indices.size(0) if indices is not None else probs.size(0) - samples = torch.empty(batch_size, dtype=torch.int32, device=device) + batch_size = indices.shape[0] if indices is not None else probs.shape[0] + samples = paddle.empty(shape=batch_size, dtype="int32") module.top_k_top_p_sampling_from_probs.default( probs, samples, @@ -292,128 +282,108 @@ def top_k_top_p_sampling_from_probs( @register_fake_op("flashinfer::top_k_top_p_sampling_from_probs") def _fake_top_k_top_p_sampling_from_probs( - probs: torch.Tensor, - indices: Optional[torch.Tensor], - maybe_top_k_arr: Optional[torch.Tensor], + probs: paddle.Tensor, + indices: Optional[paddle.Tensor], + maybe_top_k_arr: Optional[paddle.Tensor], top_k_val: int, - maybe_top_p_arr: Optional[torch.Tensor], + maybe_top_p_arr: Optional[paddle.Tensor], top_p_val: float, deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - batch_size = indices.size(0) if indices is not None else probs.size(0) - sample = torch.empty(batch_size, dtype=torch.int32, device=probs.device) + ) -> paddle.Tensor: + batch_size = indices.shape[0] if indices is not None else probs.shape[0] + sample = paddle.empty(shape=batch_size, dtype="int32") return sample - # torch library for top_p_renorm_probs - @register_custom_op("flashinfer::top_p_renorm_probs", mutates_args=()) def top_p_renorm_probs( - probs: torch.Tensor, - maybe_top_p_arr: Optional[torch.Tensor], - top_p_val: float, - ) -> torch.Tensor: - probs = probs.float() + probs: paddle.Tensor, maybe_top_p_arr: Optional[paddle.Tensor], top_p_val: float + ) -> paddle.Tensor: + probs = probs.astype(dtype="float32") maybe_top_p_arr = ( - maybe_top_p_arr.float() 
if maybe_top_p_arr is not None else None + maybe_top_p_arr.astype(dtype="float32") + if maybe_top_p_arr is not None + else None ) - renorm_probs = torch.empty_like(probs) + renorm_probs = paddle.empty_like(x=probs) module.top_p_renorm_probs.default( - probs, - renorm_probs, - maybe_top_p_arr, - top_p_val, + probs, renorm_probs, maybe_top_p_arr, top_p_val ) return renorm_probs @register_fake_op("flashinfer::top_p_renorm_probs") def _fake_top_p_renorm_probs( - probs: torch.Tensor, - maybe_top_p_arr: Optional[torch.Tensor], - top_p_val: float, - ) -> torch.Tensor: - return torch.empty_like(probs) - - # torch library for top_k_renorm_probs + probs: paddle.Tensor, maybe_top_p_arr: Optional[paddle.Tensor], top_p_val: float + ) -> paddle.Tensor: + return paddle.empty_like(x=probs) @register_custom_op("flashinfer::top_k_renorm_probs", mutates_args=()) def top_k_renorm_probs( - probs: torch.Tensor, - maybe_top_k_arr: Optional[torch.Tensor], - top_k_val: int, - ) -> torch.Tensor: - probs = probs.float() - maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None - renorm_probs = torch.empty_like(probs) + probs: paddle.Tensor, maybe_top_k_arr: Optional[paddle.Tensor], top_k_val: int + ) -> paddle.Tensor: + probs = probs.astype(dtype="float32") + maybe_top_k_arr = ( + maybe_top_k_arr.astype(dtype="int32") + if maybe_top_k_arr is not None + else None + ) + renorm_probs = paddle.empty_like(x=probs) module.top_k_renorm_probs.default( - probs, - renorm_probs, - maybe_top_k_arr, - top_k_val, + probs, renorm_probs, maybe_top_k_arr, top_k_val ) return renorm_probs @register_fake_op("flashinfer::top_k_renorm_probs") def _fake_top_k_renorm_probs( - probs: torch.Tensor, - maybe_top_k_arr: Optional[torch.Tensor], - top_k_val: int, - ) -> torch.Tensor: - return torch.empty_like(probs) - - # torch library for top_k_mask_logits + probs: paddle.Tensor, maybe_top_k_arr: Optional[paddle.Tensor], top_k_val: int + ) -> paddle.Tensor: + return paddle.empty_like(x=probs) @register_custom_op("flashinfer::top_k_mask_logits", mutates_args=()) def top_k_mask_logits( - logits: torch.Tensor, - maybe_top_k_arr: Optional[torch.Tensor], - top_k_val: int, - ) -> torch.Tensor: - logits = logits.float() - maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None - mask_logits = torch.empty_like(logits) + logits: paddle.Tensor, maybe_top_k_arr: Optional[paddle.Tensor], top_k_val: int + ) -> paddle.Tensor: + logits = logits.astype(dtype="float32") + maybe_top_k_arr = ( + maybe_top_k_arr.astype(dtype="int32") + if maybe_top_k_arr is not None + else None + ) + mask_logits = paddle.empty_like(x=logits) module.top_k_mask_logits.default( - logits, - mask_logits, - maybe_top_k_arr, - top_k_val, + logits, mask_logits, maybe_top_k_arr, top_k_val ) return mask_logits @register_fake_op("flashinfer::top_k_mask_logits") def _fake_top_k_mask_logits( - logits: torch.Tensor, - maybe_top_k_arr: Optional[torch.Tensor], - top_k_val: int, - ) -> torch.Tensor: - return torch.empty_like(logits) - - # torch library for chain_speculative_sampling + logits: paddle.Tensor, maybe_top_k_arr: Optional[paddle.Tensor], top_k_val: int + ) -> paddle.Tensor: + return paddle.empty_like(x=logits) @register_custom_op( "flashinfer::chain_speculative_sampling", - mutates_args=( - "output_accepted_token_num", - "output_emitted_draft_token_num", - ), + mutates_args=("output_accepted_token_num", "output_emitted_draft_token_num"), ) def chain_speculative_sampling( - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - 
target_probs: torch.Tensor, - output_accepted_token_num: torch.Tensor, - output_emitted_draft_token_num: torch.Tensor, + draft_probs: paddle.Tensor, + draft_token_ids: paddle.Tensor, + target_probs: paddle.Tensor, + output_accepted_token_num: paddle.Tensor, + output_emitted_draft_token_num: paddle.Tensor, deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - device = draft_probs.device - draft_probs = draft_probs.float() - draft_token_ids = draft_token_ids.int() - target_probs = target_probs.float() - output_accepted_token_num = output_accepted_token_num.int() - output_emitted_draft_token_num = output_emitted_draft_token_num.int() - b, n = draft_token_ids.shape - output_token_ids = torch.empty((b, n + 1), dtype=torch.int32, device=device) + ) -> paddle.Tensor: + device = draft_probs.place + draft_probs = draft_probs.astype(dtype="float32") + draft_token_ids = draft_token_ids.astype(dtype="int32") + target_probs = target_probs.astype(dtype="float32") + output_accepted_token_num = output_accepted_token_num.astype(dtype="int32") + output_emitted_draft_token_num = output_emitted_draft_token_num.astype( + dtype="int32" + ) + b, n = tuple(draft_token_ids.shape) + output_token_ids = paddle.empty(shape=(b, n + 1), dtype="int32") module.chain_speculative_sampling.default( draft_probs, draft_token_ids, @@ -428,19 +398,18 @@ def chain_speculative_sampling( @register_fake_op("flashinfer::chain_speculative_sampling") def _fake_chain_speculative_sampling( - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - target_probs: torch.Tensor, - output_accepted_token_num: torch.Tensor, - output_emitted_draft_token_num: torch.Tensor, + draft_probs: paddle.Tensor, + draft_token_ids: paddle.Tensor, + target_probs: paddle.Tensor, + output_accepted_token_num: paddle.Tensor, + output_emitted_draft_token_num: paddle.Tensor, deterministic: bool, generator: Optional[torch.Generator], - ) -> torch.Tensor: - b, n = draft_token_ids.shape - device = draft_token_ids.device - return torch.empty((b, n + 1), dtype=torch.int32, device=device) + ) -> paddle.Tensor: + b, n = tuple(draft_token_ids.shape) + device = draft_token_ids.place + return paddle.empty(shape=(b, n + 1), dtype="int32") - # Register the module return SimpleNamespace( softmax=softmax, sampling_from_probs=sampling_from_probs, @@ -457,18 +426,18 @@ def _fake_chain_speculative_sampling( def _to_tensor_scalar_tuple(x): - if isinstance(x, torch.Tensor): - return (x, 0) + if isinstance(x, paddle.Tensor): + return x, 0 else: - return (None, x) + return None, x def softmax( - logits: torch.Tensor, - temperature: Optional[Union[torch.Tensor, float]] = None, + logits: paddle.Tensor, + temperature: Optional[Union[paddle.Tensor, float]] = None, enable_pdl: Optional[bool] = None, -) -> torch.Tensor: - r"""Fused GPU kernel for `online safe softmax `_ with temperature scaling. +) -> paddle.Tensor: + """Fused GPU kernel for `online safe softmax `_ with temperature scaling. 
Parameters @@ -507,27 +476,25 @@ def softmax( [0.2401, 0.1707, 0.2249, 0.1664, 0.1979], [0.1724, 0.2719, 0.1991, 0.1465, 0.2101]], device='cuda:0') """ - workspace_buffer = _get_cache_buf("softmax_workspace", 1024 * 1024, logits.device) + workspace_buffer = _get_cache_buf("softmax_workspace", 1024 * 1024, logits.place) if temperature is None: temperature = 1.0 - - # Auto-detect PDL support if not specified if enable_pdl is None: - enable_pdl = device_support_pdl(logits.device) - + enable_pdl = device_support_pdl(logits.place) + """Not Support auto convert *.softmax, please judge whether it is Pytorch API and convert by yourself""" return get_sampling_module().softmax( workspace_buffer, logits, *_to_tensor_scalar_tuple(temperature), enable_pdl ) def sampling_from_logits( - logits: torch.Tensor, - indices: Optional[torch.Tensor] = None, + logits: paddle.Tensor, + indices: Optional[paddle.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, -) -> torch.Tensor: - r"""Fused GPU kernel for category sampling from logits. It's equivalent to sampling +) -> paddle.Tensor: + """Fused GPU kernel for category sampling from logits. It's equivalent to sampling from :attr:`logits` after applying softmax. Parameters ---------- @@ -571,7 +538,7 @@ def sampling_from_logits( tensor([0, 1, 1, 1], device='cuda:0', dtype=torch.int32) """ if check_nan: - if torch.any(torch.isnan(logits)): + if paddle.any(x=paddle.isnan(x=logits)): raise ValueError("Input logits contains NaN.") return get_sampling_module().sampling_from_logits( logits, indices, deterministic, generator @@ -579,13 +546,13 @@ def sampling_from_logits( def sampling_from_probs( - probs: torch.Tensor, - indices: Optional[torch.Tensor] = None, + probs: paddle.Tensor, + indices: Optional[paddle.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, -) -> torch.Tensor: - r"""Fused GPU kernel for category sampling from probabilities. +) -> paddle.Tensor: + """Fused GPU kernel for category sampling from probabilities. Parameters ---------- @@ -635,7 +602,7 @@ def sampling_from_probs( This function expects float32 inputs, and the output is int32. """ if check_nan: - if torch.any(torch.isnan(probs)): + if paddle.any(x=paddle.isnan(x=probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().sampling_from_probs( probs, indices, deterministic, generator @@ -643,14 +610,14 @@ def sampling_from_probs( def top_p_sampling_from_probs( - probs: torch.Tensor, - top_p: Union[torch.Tensor, float], - indices: Optional[torch.Tensor] = None, + probs: paddle.Tensor, + top_p: Union[paddle.Tensor, float], + indices: Optional[paddle.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, -) -> torch.Tensor: - r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities, +) -> paddle.Tensor: + """Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. Check the `blog post `_ for more details. 
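# A minimal sketch of the probability-based samplers converted above, assuming a CUDA
# build of this Paddle port. probs is a (batch_size, vocab_size) float32 tensor of
# normalized probabilities; both samplers return one int32 token id per row.
import paddle
import flashinfer

batch_size, vocab_size = 4, 128
logits = paddle.randn([batch_size, vocab_size])
probs = paddle.nn.functional.softmax(logits, axis=-1)

samples = flashinfer.sampling.sampling_from_probs(probs)                         # plain categorical sampling
top_p_samples = flashinfer.sampling.top_p_sampling_from_probs(probs, top_p=0.9)  # nucleus (top-p) rejection sampling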
@@ -717,7 +684,7 @@ def top_p_sampling_from_probs( top_p_renorm_probs """ if check_nan: - if torch.any(torch.isnan(probs)): + if paddle.any(x=paddle.isnan(x=probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().top_p_sampling_from_probs( probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator @@ -725,14 +692,14 @@ def top_p_sampling_from_probs( def top_k_sampling_from_probs( - probs: torch.Tensor, - top_k: Union[torch.Tensor, int], - indices: Optional[torch.Tensor] = None, + probs: paddle.Tensor, + top_k: Union[paddle.Tensor, int], + indices: Optional[paddle.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, -) -> torch.Tensor: - r"""Fused GPU kernel for top-k sampling from probabilities, +) -> paddle.Tensor: + """Fused GPU kernel for top-k sampling from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. Check the `blog post `_ for more details. @@ -799,7 +766,7 @@ def top_k_sampling_from_probs( top_k_renorm_probs """ if check_nan: - if torch.any(torch.isnan(probs)): + if paddle.any(x=paddle.isnan(x=probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().top_k_sampling_from_probs( probs, indices, *_to_tensor_scalar_tuple(top_k), deterministic, generator @@ -807,14 +774,14 @@ def top_k_sampling_from_probs( def min_p_sampling_from_probs( - probs: torch.Tensor, - min_p: Union[torch.Tensor, float], - indices: Optional[torch.Tensor] = None, + probs: paddle.Tensor, + min_p: Union[paddle.Tensor, float], + indices: Optional[paddle.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, -) -> torch.Tensor: - r"""Fused GPU kernel for `min_p sampling `_ from probabilities, +) -> paddle.Tensor: + """Fused GPU kernel for `min_p sampling `_ from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. Check the `blog post `_ for more details. @@ -875,9 +842,8 @@ def min_p_sampling_from_probs( ---- This function expects float32 inputs, and the output is int32. """ - if check_nan: - if torch.any(torch.isnan(probs)): + if paddle.any(x=paddle.isnan(x=probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().min_p_sampling_from_probs( probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator @@ -885,16 +851,16 @@ def min_p_sampling_from_probs( def top_k_top_p_sampling_from_logits( - logits: torch.Tensor, - top_k: Union[torch.Tensor, int], - top_p: Union[torch.Tensor, float], - indices: Optional[torch.Tensor] = None, + logits: paddle.Tensor, + top_k: Union[paddle.Tensor, int], + top_p: Union[paddle.Tensor, float], + indices: Optional[paddle.Tensor] = None, filter_apply_order: str = "top_k_first", deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, -) -> torch.Tensor: - r"""Fused GPU kernel for top-k and top-p sampling from pre-softmax logits, +) -> paddle.Tensor: + """Fused GPU kernel for top-k and top-p sampling from pre-softmax logits, this operator implements GPU-based rejection sampling without explicit sorting. Check the `blog post `_ for more details. 
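# A minimal sketch of the fused top-k/top-p path whose docstring appears above, assuming a
# CUDA build of this Paddle port. With filter_apply_order="top_k_first" (the default) the
# logits are first masked to their top-k entries, softmaxed, and the result is then passed
# through top-p rejection sampling, as the next hunk shows.
import paddle
import flashinfer

logits = paddle.randn([4, 128])
ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, top_k=20, top_p=0.9, filter_apply_order="top_k_first"
)

# Per-request thresholds may also be given as tensors instead of scalars.
top_k = paddle.full([4], 20, dtype="int32")
top_p = paddle.full([4], 0.9, dtype="float32")
ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(logits, top_k, top_p)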
@@ -978,7 +944,7 @@ def top_k_top_p_sampling_from_logits( """ if filter_apply_order == "top_k_first": masked_logits = top_k_mask_logits(logits, top_k) - probs = torch.softmax(masked_logits, dim=-1) + probs = paddle.nn.functional.softmax(x=masked_logits, axis=-1) return top_p_sampling_from_probs( probs, top_p, @@ -988,9 +954,9 @@ def top_k_top_p_sampling_from_logits( generator=generator, ) elif filter_apply_order == "joint": - probs = torch.softmax(logits, dim=-1) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) if check_nan: - if torch.any(torch.isnan(probs)): + if paddle.any(x=paddle.isnan(x=probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().top_k_top_p_sampling_from_probs( probs, @@ -1005,16 +971,16 @@ def top_k_top_p_sampling_from_logits( def top_k_top_p_sampling_from_probs( - probs: torch.Tensor, - top_k: Union[torch.Tensor, int], - top_p: Union[torch.Tensor, float], - indices: Optional[torch.Tensor] = None, + probs: paddle.Tensor, + top_k: Union[paddle.Tensor, int], + top_p: Union[paddle.Tensor, float], + indices: Optional[paddle.Tensor] = None, filter_apply_order: str = "top_k_first", deterministic: bool = True, generator: Optional[torch.Generator] = None, check_nan: bool = False, -) -> torch.Tensor: - r"""Fused GPU kernel for top-k and top-p sampling from probabilities, +) -> paddle.Tensor: + """Fused GPU kernel for top-k and top-p sampling from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. Check the `blog post `_ for more details. @@ -1103,7 +1069,7 @@ def top_k_top_p_sampling_from_probs( ) elif filter_apply_order == "joint": if check_nan: - if torch.any(torch.isnan(probs)): + if paddle.any(x=paddle.isnan(x=probs)): raise ValueError("Input probs contains NaN.") return get_sampling_module().top_k_top_p_sampling_from_probs( probs, @@ -1118,10 +1084,9 @@ def top_k_top_p_sampling_from_probs( def top_p_renorm_probs( - probs: torch.Tensor, - top_p: Union[torch.Tensor, float], -) -> torch.Tensor: - r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding. + probs: paddle.Tensor, top_p: Union[paddle.Tensor, float] +) -> paddle.Tensor: + """Fused GPU kernel for renormalizing probabilities by top-p thresholding. Parameters ---------- @@ -1183,10 +1148,9 @@ def top_p_renorm_probs( def top_k_renorm_probs( - probs: torch.Tensor, - top_k: Union[torch.Tensor, int], -) -> torch.Tensor: - r"""Fused GPU kernel for renormalizing probabilities by top-k thresholding. + probs: paddle.Tensor, top_k: Union[paddle.Tensor, int] +) -> paddle.Tensor: + """Fused GPU kernel for renormalizing probabilities by top-k thresholding. Parameters ---------- @@ -1247,9 +1211,9 @@ def top_k_renorm_probs( def top_k_mask_logits( - logits: torch.Tensor, top_k: Union[torch.Tensor, int] -) -> torch.Tensor: - r"""Fused GPU kernel for masking logits by top-k thresholding. + logits: paddle.Tensor, top_k: Union[paddle.Tensor, int] +) -> paddle.Tensor: + """Fused GPU kernel for masking logits by top-k thresholding. 
Parameters ---------- @@ -1306,12 +1270,12 @@ def chain_speculative_sampling( draft_probs, draft_token_ids, target_probs, - maybe_output_accepted_token_num: Optional[torch.Tensor] = None, - maybe_output_emitted_draft_token_num: Optional[torch.Tensor] = None, + maybe_output_accepted_token_num: Optional[paddle.Tensor] = None, + maybe_output_emitted_draft_token_num: Optional[paddle.Tensor] = None, deterministic: bool = True, generator: Optional[torch.Generator] = None, -) -> torch.Tensor: - r"""Fused-GPU kernel for speculative sampling for sequence generation (proposed in +) -> paddle.Tensor: + """Fused-GPU kernel for speculative sampling for sequence generation (proposed in paper `Accelerating Large Language Model Decoding with Speculative Sampling `_), where the draft model generates a sequence(chain) of tokens for each request. @@ -1381,7 +1345,7 @@ def chain_speculative_sampling( >>> # token 1 was sampled from draft model for the second token >>> draft_token_ids = torch.tensor([[2, 1]], dtype=torch.int32).to(0) >>> target_probs = torch.tensor([[[0.0, 0.1, 0.6, 0.3], [1.0, 0.0, 0.0, 0.0], [0.7, 0.1, 0.1, 0.1]]]).to(0) - >>> output_token_ids, output_accepted_token_num, output_emitted_draft_token_num =\ + >>> output_token_ids, output_accepted_token_num, output_emitted_draft_token_num =\\ ... flashinfer.sampling.chain_speculative_sampling( ... draft_probs, draft_token_ids, target_probs) >>> # the first token is accepted, the second token is rejected and sampled from the difference @@ -1393,14 +1357,14 @@ def chain_speculative_sampling( >>> output_emitted_draft_token_num tensor([1], device='cuda:0') """ - b = draft_probs.size(0) - dev = draft_probs.device + b = draft_probs.shape[0] + dev = draft_probs.place if maybe_output_accepted_token_num is None: - output_accepted_token_num = torch.zeros(b, dtype=torch.int32, device=dev) + output_accepted_token_num = paddle.zeros(shape=b, dtype="int32") else: output_accepted_token_num = maybe_output_accepted_token_num if maybe_output_emitted_draft_token_num is None: - output_emitted_draft_token_num = torch.zeros(b, dtype=torch.int32, device=dev) + output_emitted_draft_token_num = paddle.zeros(shape=b, dtype="int32") else: output_emitted_draft_token_num = maybe_output_emitted_draft_token_num output_token_ids = get_sampling_module().chain_speculative_sampling( @@ -1412,4 +1376,4 @@ def chain_speculative_sampling( deterministic, generator, ) - return output_token_ids, output_accepted_token_num, output_emitted_draft_token_num + return (output_token_ids, output_accepted_token_num, output_emitted_draft_token_num) diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index c573e17363..421e045202 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -1,3 +1,10 @@ +import sys + + +import einops +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,32 +20,23 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import math from typing import Optional, Tuple, Union -import torch - from .decode import get_batch_decode_module from .page import block_sparse_indices_to_vector_sparse_offsets from .prefill import _compute_page_mask_indptr, get_batch_prefill_module from .quantization import segment_packbits -from .utils import ( - MaskMode, - PosEncodingMode, - TensorLayout, - _check_pos_encoding_mode, - check_shape_dtype_device, - _get_cache_alibi_slopes_buf, - canonicalize_torch_dtype, - determine_attention_backend, - device_support_pdl, - is_float8, -) +from .utils import (MaskMode, PosEncodingMode, TensorLayout, + _check_pos_encoding_mode, _get_cache_alibi_slopes_buf, + canonicalize_torch_dtype, check_shape_dtype_device, + determine_attention_backend, device_support_pdl, is_float8) -def convert_bsr_mask_layout(mask: torch.Tensor, indptr: torch.Tensor) -> torch.Tensor: - r"""Convert mask from BSR data layout to flashinfer's flattened mask layout. +def convert_bsr_mask_layout( + mask: paddle.Tensor, indptr: paddle.Tensor +) -> paddle.Tensor: + """Convert mask from BSR data layout to flashinfer's flattened mask layout. Parameters ---------- @@ -52,18 +50,20 @@ def convert_bsr_mask_layout(mask: torch.Tensor, indptr: torch.Tensor) -> torch.T flattened_mask : torch.Tensor A flattenedd mask tensor with shape ``(nnz * R * C,)``. """ - nnz, R, C = mask.shape + nnz, R, C = tuple(mask.shape) MB = len(indptr) - 1 - mask_flashinfer = torch.empty((nnz * R * C,), dtype=mask.dtype, device=mask.device) + mask_flashinfer = paddle.empty(shape=(nnz * R * C,), dtype=mask.dtype) for i in range(MB): mask_flashinfer[indptr[i] * R * C : indptr[i + 1] * R * C] = ( - mask[indptr[i] : indptr[i + 1]].transpose(0, 1).reshape(-1) + mask[indptr[i] : indptr[i + 1]] + .transpose(perm=dim2perm(mask[indptr[i] : indptr[i + 1]].ndim, 0, 1)) + .reshape(-1) ) return mask_flashinfer class BlockSparseAttentionWrapper: - r"""Wrapper class for attention computation with a block-sparse matrix as attention mask. + """Wrapper class for attention computation with a block-sparse matrix as attention mask. The definition of block sparse matrix can be found at `bsr_matrix `_ in SciPy. @@ -108,11 +108,9 @@ class BlockSparseAttentionWrapper: """ def __init__( - self, - float_workspace_buffer: torch.Tensor, - backend: str = "auto", + self, float_workspace_buffer: paddle.Tensor, backend: str = "auto" ) -> None: - r"""Constructs of :class:`BlockSparseAttentionWrapper`. + """Constructs of :class:`BlockSparseAttentionWrapper`. Parameters ---------- @@ -126,38 +124,29 @@ def __init__( device architecture and kernel availability. 
""" self._float_workspace_buffer = float_workspace_buffer - self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + self.device = float_workspace_buffer.place + self._int_workspace_buffer = paddle.empty( + shape=(8 * 1024 * 1024,), dtype="uint8" ) if backend in ["fa3", "auto"]: - # NOTE(Zihao): assume maximum accumulate kv length is 128M - # NOTE(Yilong): 128M is required by video DiT models - self._vector_sparse_indices_buffer = torch.empty( - (128 * 1024 * 1024,), dtype=torch.int32, device=self.device + self._vector_sparse_indices_buffer = paddle.empty( + shape=(128 * 1024 * 1024,), dtype="int32" ) - # NOTE(Zihao): assume maximum batch size is 32768 - self._vector_sparse_indptr_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device + self._vector_sparse_indptr_buffer = paddle.empty( + shape=(32768,), dtype="int32" ) - - self._kv_lens_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device - ) - self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, - dtype=torch.uint8, - pin_memory=True, - device="cpu", - ) + self._kv_lens_buffer = paddle.empty(shape=(32768,), dtype="int32") + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=tuple(self._int_workspace_buffer.shape), dtype="uint8" + ).pin_memory() self._use_cuda_graph = False self._kv_layout = "NHD" - self._qo_indptr: Optional[torch.Tensor] = None - self._paged_kv_indptr_buf: Optional[torch.Tensor] = None - self._paged_kv_indices_buf: Optional[torch.Tensor] = None - self._paged_kv_last_page_len: Optional[torch.Tensor] = None - self._packed_mask_buf: Optional[torch.Tensor] = None - self._mask_indptr_buf: Optional[torch.Tensor] = None + self._qo_indptr: Optional[paddle.Tensor] = None + self._paged_kv_indptr_buf: Optional[paddle.Tensor] = None + self._paged_kv_indices_buf: Optional[paddle.Tensor] = None + self._paged_kv_last_page_len: Optional[paddle.Tensor] = None + self._packed_mask_buf: Optional[paddle.Tensor] = None + self._mask_indptr_buf: Optional[paddle.Tensor] = None self.R: Optional[int] = None self.C: Optional[int] = None self.M: Optional[int] = None @@ -166,12 +155,12 @@ def __init__( def reset_workspace_buffer( self, - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, - vector_sparse_indices_buffer: Optional[torch.Tensor] = None, - vector_sparse_indptr_buffer: Optional[torch.Tensor] = None, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, + vector_sparse_indices_buffer: Optional[paddle.Tensor] = None, + vector_sparse_indptr_buffer: Optional[paddle.Tensor] = None, ) -> None: - r"""Reset the workspace buffer. + """Reset the workspace buffer. 
Parameters ---------- @@ -185,13 +174,10 @@ def reset_workspace_buffer( """ self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=tuple(self._int_workspace_buffer.shape), dtype=self._int_workspace_buffer.dtype, - pin_memory=True, - ) - - # Enable user-defined size + ).pin_memory() if vector_sparse_indices_buffer is not None: self._vector_sparse_indices_buffer = vector_sparse_indices_buffer if vector_sparse_indptr_buffer is not None: @@ -199,8 +185,8 @@ def reset_workspace_buffer( def plan( self, - indptr: torch.Tensor, - indices: torch.Tensor, + indptr: paddle.Tensor, + indices: paddle.Tensor, M: int, N: int, R: int, @@ -208,8 +194,8 @@ def plan( num_qo_heads: int, num_kv_heads: int, head_dim: int, - mask: Optional[torch.Tensor] = None, - packed_mask: Optional[torch.Tensor] = None, + mask: Optional[paddle.Tensor] = None, + packed_mask: Optional[paddle.Tensor] = None, causal: bool = False, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, @@ -217,12 +203,12 @@ def plan( sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - q_data_type: Union[str, torch.dtype] = "float16", - kv_data_type: Optional[Union[str, torch.dtype]] = None, - o_data_type: Union[str, torch.dtype] = "float16", + q_data_type: Union[str, paddle.dtype] = "float16", + kv_data_type: Optional[Union[str, paddle.dtype]] = None, + o_data_type: Union[str, paddle.dtype] = "float16", non_blocking: bool = True, ) -> None: - r"""Create auxiliary data structures for block sparse attention. + """Create auxiliary data structures for block sparse attention. Parameters ---------- @@ -268,7 +254,7 @@ def plan( The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, + :math:`\\texttt{logits_soft_cap} \\times \\mathrm{tanh}(x / \\texttt{logits_soft_cap})`, where :math:`x` is the input logits. 
sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to @@ -302,47 +288,38 @@ def plan( kv_data_type = q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) self._o_dtype = canonicalize_torch_dtype(o_data_type) - if logits_soft_cap is None: logits_soft_cap = 0.0 - num_blocks_row = len(indptr) - 1 - qo_indptr_host = R * torch.arange(num_blocks_row + 1, dtype=torch.int32) + qo_indptr_host = R * paddle.arange(dtype="int32", end=num_blocks_row + 1) qo_indptr_host[-1] = M - qo_indptr = qo_indptr_host.to(indptr.device, non_blocking=non_blocking) - if indices.max().item() * C > N: + qo_indptr = qo_indptr_host.to(indptr.place, blocking=not non_blocking) + if indices._max().item() * C > N: raise ValueError("indices out of bound") - last_block_len = torch.full( - (num_blocks_row,), C, dtype=torch.int32, device=indptr.device + last_block_len = paddle.full( + shape=(num_blocks_row,), fill_value=C, dtype="int32" ) - if mask is not None or packed_mask is not None: mask_indptr = _compute_page_mask_indptr( - qo_indptr, - indptr, # paged_kv_indptr - last_block_len, # paged_kv_last_page_len - C, # page_size + qo_indptr, indptr, last_block_len, C ) if packed_mask is None and mask is not None: - # first convert BSR mask to flashinfer layout mask = convert_bsr_mask_layout(mask, indptr) - # create packed mask from mask packed_mask, mask_indptr = segment_packbits( mask.contiguous().view(-1), mask_indptr, bitorder="little" ) - - self._qo_indptr = qo_indptr.to(self.device, non_blocking=non_blocking) - self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=non_blocking) - self._paged_kv_indices_buf = indices.to(self.device, non_blocking=non_blocking) + self._qo_indptr = qo_indptr.to(self.device, blocking=not non_blocking) + self._paged_kv_indptr_buf = indptr.to(self.device, blocking=not non_blocking) + self._paged_kv_indices_buf = indices.to(self.device, blocking=not non_blocking) self._paged_kv_last_page_len = last_block_len.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) if packed_mask is not None: self._packed_mask_buf = packed_mask.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) self._mask_indptr_buf = mask_indptr.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) mask_mode = MaskMode.CUSTOM.value else: @@ -350,23 +327,16 @@ def plan( self._mask_indptr_buf = None mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value self._mask_mode = mask_mode - self.M = M self.N = N self.R = R self.C = C - kv_indptr_host = indptr.to("cpu") - - # NOTE(Zihao): we haven't supported mask in cuda-core implementations but it should - # be easy to add support for it if needed, leave it as a future work. 
- # at this moment, when mask is provided, we use the tensor-core implementation if ( R * (num_qo_heads // num_kv_heads) < 4 and mask_mode != MaskMode.CUSTOM.value - and q_data_type not in [torch.float8_e4m3fn, torch.float8_e5m2] + and q_data_type not in [paddle.float8_e4m3fn, paddle.float8_e5m2] ): - # If the operation is not compute-bound, we use the cuda-core implementation self._use_tensor_cores = False self._cached_module = get_batch_decode_module( q_data_type, @@ -376,10 +346,9 @@ def plan( head_dim, head_dim, PosEncodingMode[pos_encoding_mode].value, - False, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap + False, + logits_soft_cap > 0, ) - self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, @@ -389,63 +358,60 @@ def plan( num_qo_heads, num_kv_heads, C, - False, # is_cuda_graph_enabled - -1, # window_left - logits_soft_cap, # logits_soft_cap + False, + -1, + logits_soft_cap, head_dim, head_dim, - torch.empty(0, dtype=q_data_type), - torch.empty(0, dtype=kv_data_type), + paddle.empty(shape=[0], dtype=q_data_type), + paddle.empty(shape=[0], dtype=kv_data_type), ) else: - # if the operation is compute-bound, we use the tensor-core implementation self._use_tensor_cores = True - if self._backend == "auto": self._backend = determine_attention_backend( self.device, PosEncodingMode[pos_encoding_mode].value, use_fp16_qk_reduction, - mask_mode == MaskMode.CUSTOM.value, # use_custom_mask + mask_mode == MaskMode.CUSTOM.value, q_data_type, kv_data_type, ) - get_module_args = ( q_data_type, kv_data_type, self._o_dtype, indptr.dtype, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim, + head_dim, PosEncodingMode[pos_encoding_mode].value, - False, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap + False, + logits_soft_cap > 0, use_fp16_qk_reduction, ) self._cached_module = get_batch_prefill_module( self._backend, *get_module_args ) - kv_lens_arr_host = (kv_indptr_host[1:] - kv_indptr_host[:-1]) * self.C - self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( - kv_lens_arr_host, + paddle.assign( + kv_lens_arr_host, output=self._kv_lens_buffer[: len(kv_lens_arr_host)] ) - if self._backend == "fa3": if self.C != 1: - vector_sparse_indptr_host = torch.cat( - [ - torch.tensor([0], dtype=torch.int32), - torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), + vector_sparse_indptr_host = paddle.concat( + x=[ + paddle.to_tensor(data=[0], dtype="int32"), + paddle.cumsum(x=kv_lens_arr_host, axis=0, dtype="int32"), + ], + axis=0, + ) + paddle.assign( + vector_sparse_indptr_host, + output=self._vector_sparse_indptr_buffer[ + : len(vector_sparse_indptr_host) ], - dim=0, ) - self._vector_sparse_indptr_buffer[ - : len(vector_sparse_indptr_host) - ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking) kv_indptr_host = vector_sparse_indptr_host - self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, @@ -453,17 +419,16 @@ def plan( qo_indptr_host, kv_indptr_host, kv_lens_arr_host, - M, # total_num_rows - num_blocks_row, # batch_size + M, + num_blocks_row, num_qo_heads, num_kv_heads, - self.C, # page_size - False, # is_cuda_graph_enabled, + self.C, + False, head_dim, head_dim, causal, ) - self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._logits_soft_cap = logits_soft_cap @@ -475,20 +440,20 @@ def plan( def forward( self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale_q: Optional[torch.Tensor] = 
None, - scale_k: Optional[torch.Tensor] = None, - scale_v: Optional[torch.Tensor] = None, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + scale_q: Optional[paddle.Tensor] = None, + scale_k: Optional[paddle.Tensor] = None, + scale_v: Optional[paddle.Tensor] = None, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - ) -> torch.Tensor: - r"""Warning: This method is deprecated, please use :meth:`run` instead.""" + ) -> paddle.Tensor: + """Warning: This method is deprecated, please use :meth:`run` instead.""" self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._logits_soft_cap = logits_soft_cap @@ -499,18 +464,18 @@ def forward( def run( self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale_q: Optional[torch.Tensor] = None, - scale_k: Optional[torch.Tensor] = None, - scale_v: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + scale_q: Optional[paddle.Tensor] = None, + scale_k: Optional[paddle.Tensor] = None, + scale_v: Optional[paddle.Tensor] = None, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: bool = False, enable_pdl: Optional[bool] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Compute block-sparse attention between Q/K/V tensors. + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Compute block-sparse attention between Q/K/V tensors. Parameters ---------- @@ -549,8 +514,7 @@ def run( * The logsumexp of attention output, shape: ``[M, num_qo_heads]``. 
""" if enable_pdl is None: - enable_pdl = device_support_pdl(q.device) - + enable_pdl = device_support_pdl(q.place) pos_encoding_mode = self._pos_encoding_mode logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale @@ -560,69 +524,59 @@ def run( if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: - sm_scale = 1.0 / math.sqrt(q.size(-1)) + sm_scale = 1.0 / math.sqrt(q.shape[-1]) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: - rope_theta = 1e4 - k = k.reshape(-1, self.C, *k.shape[-2:]) - v = v.reshape(-1, self.C, *v.shape[-2:]) - - stride_block = k.stride(0) - stride_n = k.stride(1) - + rope_theta = 10000.0 + k = k.reshape(-1, self.C, *tuple(k.shape)[-2:]) + v = v.reshape(-1, self.C, *tuple(v.shape)[-2:]) + stride_block = k.get_strides()[0] + stride_n = k.get_strides()[1] if return_lse: if lse is None: - lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device - ) + lse = paddle.empty(shape=(q.shape[0], q.shape[1]), dtype="float32") else: check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" + lse, (q.shape[0], q.shape[1]), "float32", q.place, "lse" ) - if out is None: - out = torch.empty_like(q, dtype=self._o_dtype) + out = paddle.empty_like(x=q, dtype=self._o_dtype) else: - check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out") - + check_shape_dtype_device(out, tuple(q.shape), self._o_dtype, q.place, "out") if is_float8(q): assert q.dtype == k.dtype == v.dtype - assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert tuple(q.shape)[-1] == tuple(k.shape)[-1] == tuple(v.shape)[-1] assert self._backend == "fa3" and self._use_tensor_cores - if scale_q is None: - scale_q = torch.ones(q.shape[1], dtype=torch.float32, device=q.device) + scale_q = paddle.ones(shape=tuple(q.shape)[1], dtype="float32") if scale_k is None: - scale_k = torch.ones(k.shape[1], dtype=torch.float32, device=q.device) + scale_k = paddle.ones(shape=tuple(k.shape)[1], dtype="float32") if scale_v is None: - scale_v = torch.ones(v.shape[1], dtype=torch.float32, device=q.device) - + scale_v = paddle.ones(shape=tuple(v.shape)[1], dtype="float32") if self._use_tensor_cores: if self._backend == "fa3": if ( - self._vector_sparse_indices_buffer.numel() - <= self._paged_kv_indices_buf.numel() * self.C + self._vector_sparse_indices_buffer.size + <= self._paged_kv_indices_buf.size * self.C ): raise ValueError( "_vector_sparse_indices_buffer is not large enough. Please increase the size." 
) - sparse_indices = block_sparse_indices_to_vector_sparse_offsets( self._paged_kv_indices_buf, self._paged_kv_indptr_buf, - self._vector_sparse_indices_buffer, # output + self._vector_sparse_indices_buffer, self._vector_sparse_indptr_buffer, self._kv_lens_buffer, stride_block // stride_n, - 1, # stride_n // stride_n - self.C, # block_size + 1, + self.C, ) sparse_indptr = self._vector_sparse_indptr_buffer else: sparse_indices = self._paged_kv_indices_buf sparse_indptr = self._paged_kv_indptr_buf - self._cached_module.paged_run( self._float_workspace_buffer, self._int_workspace_buffer, @@ -638,15 +592,14 @@ def run( lse, self._mask_mode, TensorLayout[self._kv_layout].value, - -1, # window_left + -1, enable_pdl, - # ADDITIONAL_FUNC_PARAMS self._packed_mask_buf, self._mask_indptr_buf, - _get_cache_alibi_slopes_buf(q.shape[1], self.device), - None, # maybe_prefix_len_ptr - None, # maybe_token_pos_in_items_ptr - None, # maybe_max_item_len_ptr + _get_cache_alibi_slopes_buf(tuple(q.shape)[1], self.device), + None, + None, + None, logits_soft_cap, sm_scale, scale_q, @@ -654,7 +607,7 @@ def run( scale_v, rope_scale, rope_theta, - 0, # token_pos_in_items_len + 0, ) else: self._cached_module.run( @@ -670,25 +623,23 @@ def run( out, lse, TensorLayout[self._kv_layout].value, - -1, # window_left + -1, enable_pdl, - # ADDITIONAL_FUNC_PARAMS - _get_cache_alibi_slopes_buf(q.shape[1], self.device), + _get_cache_alibi_slopes_buf(tuple(q.shape)[1], self.device), logits_soft_cap, sm_scale, rope_scale, rope_theta, ) - return (out, lse) if return_lse else out def end_forward(self) -> None: - r"""Warning: This method is deprecated and has no effect.""" + """Warning: This method is deprecated and has no effect.""" pass class VariableBlockSparseAttentionWrapper: - r"""Wrapper class for attention computation with a block-sparse matrix as attention mask. + """Wrapper class for attention computation with a block-sparse matrix as attention mask. This API supports variable block sizes provided by ``block_row_sz`` and ``block_col_sz``. Besides, each ``kv_head_idx`` can specify its own sparse patterns without using the same mask. @@ -721,11 +672,9 @@ class VariableBlockSparseAttentionWrapper: """ def __init__( - self, - float_workspace_buffer: torch.Tensor, - backend: str = "auto", + self, float_workspace_buffer: paddle.Tensor, backend: str = "auto" ) -> None: - r"""Constructs of :class:`VariableBlockSparseAttentionWrapper`. + """Constructs of :class:`VariableBlockSparseAttentionWrapper`. Parameters ---------- @@ -739,43 +688,37 @@ def __init__( device architecture and kernel availability. 
""" self._float_workspace_buffer = float_workspace_buffer - self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + self.device = float_workspace_buffer.place + self._int_workspace_buffer = paddle.empty( + shape=(8 * 1024 * 1024,), dtype="uint8" ) if backend in ["fa3", "auto"]: - self._vector_sparse_indices_buffer = torch.empty( - (128 * 1024 * 1024,), dtype=torch.int32, device=self.device + self._vector_sparse_indices_buffer = paddle.empty( + shape=(128 * 1024 * 1024,), dtype="int32" ) - self._vector_sparse_indptr_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device + self._vector_sparse_indptr_buffer = paddle.empty( + shape=(32768,), dtype="int32" ) - - self._kv_lens_buffer = torch.empty( - (32768,), dtype=torch.int32, device=self.device - ) - self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, - dtype=torch.uint8, - pin_memory=True, - device="cpu", - ) + self._kv_lens_buffer = paddle.empty(shape=(32768,), dtype="int32") + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=tuple(self._int_workspace_buffer.shape), dtype="uint8" + ).pin_memory() self._use_cuda_graph = False self._kv_layout = "NHD" - self._qo_indptr: Optional[torch.Tensor] = None - self._paged_kv_indptr_buf: Optional[torch.Tensor] = None - self._paged_kv_indices_buf: Optional[torch.Tensor] = None - self._paged_kv_last_page_len: Optional[torch.Tensor] = None + self._qo_indptr: Optional[paddle.Tensor] = None + self._paged_kv_indptr_buf: Optional[paddle.Tensor] = None + self._paged_kv_indices_buf: Optional[paddle.Tensor] = None + self._paged_kv_last_page_len: Optional[paddle.Tensor] = None self._backend = backend def reset_workspace_buffer( self, - float_workspace_buffer: torch.Tensor, - int_workspace_buffer: torch.Tensor, - vector_sparse_indices_buffer: Optional[torch.Tensor] = None, - vector_sparse_indptr_buffer: Optional[torch.Tensor] = None, + float_workspace_buffer: paddle.Tensor, + int_workspace_buffer: paddle.Tensor, + vector_sparse_indices_buffer: Optional[paddle.Tensor] = None, + vector_sparse_indptr_buffer: Optional[paddle.Tensor] = None, ) -> None: - r"""Reset the workspace buffer. + """Reset the workspace buffer. 
Parameters ---------- @@ -789,13 +732,10 @@ def reset_workspace_buffer( """ self._float_workspace_buffer = float_workspace_buffer self._int_workspace_buffer = int_workspace_buffer - self._pin_memory_int_workspace_buffer = torch.empty( - self._int_workspace_buffer.shape, + self._pin_memory_int_workspace_buffer = paddle.empty( + shape=tuple(self._int_workspace_buffer.shape), dtype=self._int_workspace_buffer.dtype, - pin_memory=True, - ) - - # Enable user-defined size + ).pin_memory() if vector_sparse_indices_buffer is not None: self._vector_sparse_indices_buffer = vector_sparse_indices_buffer if vector_sparse_indptr_buffer is not None: @@ -803,9 +743,9 @@ def reset_workspace_buffer( def plan( self, - block_mask_map: torch.Tensor, - block_row_sz: torch.Tensor, - block_col_sz: torch.Tensor, + block_mask_map: paddle.Tensor, + block_row_sz: paddle.Tensor, + block_col_sz: paddle.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, @@ -817,10 +757,10 @@ def plan( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, non_blocking: bool = True, - q_data_type: Union[str, torch.dtype] = "float16", - kv_data_type: Optional[Union[str, torch.dtype]] = None, + q_data_type: Union[str, paddle.dtype] = "float16", + kv_data_type: Optional[Union[str, paddle.dtype]] = None, ) -> None: - r"""Create auxiliary data structures for block sparse attention. + """Create auxiliary data structures for block sparse attention. Parameters ---------- @@ -850,7 +790,7 @@ def plan( The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, + :math:`\\texttt{logits_soft_cap} \\times \\mathrm{tanh}(x / \\texttt{logits_soft_cap})`, where :math:`x` is the input logits. sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to @@ -877,39 +817,25 @@ def plan( kv_data_type = q_data_type kv_data_type = canonicalize_torch_dtype(kv_data_type) self._o_dtype = q_data_type - if logits_soft_cap is None: logits_soft_cap = 0.0 - - # num_blocks are constant across kv_heads - num_blocks_row = block_row_sz.shape[-1] - num_blocks_col = block_col_sz.shape[-1] - - # q layout: [seq_len, num_kv_heads, gqa_group_size, head_dim] - # padded into: [seq_len * num_kv_heads, 1, gqa_group_size, head_dim] - qo_indptr = torch.cat( - [ - torch.zeros(1, dtype=torch.int32, device=block_row_sz.device), - torch.cumsum(block_row_sz.flatten(), dim=0, dtype=torch.int32), + num_blocks_row = tuple(block_row_sz.shape)[-1] + num_blocks_col = tuple(block_col_sz.shape)[-1] + qo_indptr = paddle.concat( + x=[ + paddle.zeros(shape=[1], dtype="int32"), + paddle.cumsum(x=block_row_sz.flatten(), axis=0, dtype="int32"), ], - dim=0, + axis=0, + ) + qo_indptr_host = qo_indptr.to("cpu", blocking=not non_blocking) + last_block_len = paddle.full( + shape=(num_blocks_row * num_kv_heads,), fill_value=1, dtype="int32" ) - qo_indptr_host = qo_indptr.to("cpu", non_blocking=non_blocking) - last_block_len = torch.full( - (num_blocks_row * num_kv_heads,), - 1, - dtype=torch.int32, - device=block_mask_map.device, - ) # We use page_size == 1 for variable length support - # HND kv layout: [num_kv_heads, num_blocks, block_size, head_dim] - # padded into: [num_kv_heads * num_blocks, block_size, 1, head_dim] - # for customized attention mask for each kv_head - # NOTE(Yilong): This could be perf bottleneck. 
Consider Triton implementation. def _block_mask_map_to_expanded_indices( - block_mask_map: torch.Tensor, # [H, R, C] bool / {0,1} - block_col_sz: torch.Tensor, # [H, C] int - ) -> Tuple[torch.Tensor, torch.Tensor]: + block_mask_map: paddle.Tensor, block_col_sz: paddle.Tensor + ) -> Tuple[paddle.Tensor, paddle.Tensor]: """ Args: block_mask_map: bool/int [num_kv_heads, num_blocks_row, num_blocks_col] @@ -918,37 +844,30 @@ def _block_mask_map_to_expanded_indices( kv_indptr: [H*R + 1] int32 — CSR indptr kv_indices: [nnz] int32 — token indices per (head, row) """ - device = block_mask_map.device - dtype_i = torch.int32 - - # 1) Calculate the total length of each row (head, row) - row_lengths = (block_mask_map * block_col_sz[:, None, :]).sum(-1) # [H,R] - kv_indptr = torch.cat( - [ - torch.zeros(1, dtype=dtype_i, device=device), - torch.cumsum(row_lengths.flatten(), 0), + device = block_mask_map.place + dtype_i = "int32" + row_lengths = (block_mask_map * block_col_sz[:, None, :]).sum(axis=-1) + kv_indptr = paddle.concat( + x=[ + paddle.zeros(shape=[1], dtype=dtype_i), + paddle.cumsum(x=row_lengths.flatten(), axis=0), ], - dim=0, + axis=0, ) - - # 2) Calculate the offset of each column block within its head col_offset = ( - torch.cumsum(block_col_sz.to(dtype_i), 1) - block_col_sz - ) # [H,C] - head_len = block_col_sz.sum(1, dtype=dtype_i) - head_offset = torch.cumsum(head_len, 0) - head_len - - # 3) Find all selected (h,r,c) + paddle.cumsum(x=block_col_sz.to(dtype_i), axis=1) - block_col_sz + ) + head_len = block_col_sz.sum(axis=1, dtype=dtype_i) + head_offset = paddle.cumsum(x=head_len, axis=0) - head_len h_idx, r_idx, c_idx = block_mask_map.nonzero(as_tuple=True) - lengths = block_col_sz[h_idx, c_idx].to(dtype_i) # [N] - base = head_offset[h_idx] + col_offset[h_idx, c_idx] # [N] - - # 4) Expand variable-length column blocks into token-level indices - cum = torch.cumsum(lengths, 0) - starts = torch.repeat_interleave(cum - lengths, lengths) # [total] - offsets_within = torch.arange(cum[-1], device=device) - starts - kv_indices = torch.repeat_interleave(base, lengths) + offsets_within - + lengths = block_col_sz[h_idx, c_idx].to(dtype_i) + base = head_offset[h_idx] + col_offset[h_idx, c_idx] + cum = paddle.cumsum(x=lengths, axis=0) + starts = paddle.repeat_interleave(x=cum - lengths, repeats=lengths) + offsets_within = paddle.arange(end=cum[-1]) - starts + kv_indices = ( + paddle.repeat_interleave(x=base, repeats=lengths) + offsets_within + ) return kv_indptr.to(dtype=dtype_i, device=device), kv_indices.to( dtype=dtype_i, device=device ) @@ -956,72 +875,64 @@ def _block_mask_map_to_expanded_indices( kv_indptr, kv_indices = _block_mask_map_to_expanded_indices( block_mask_map, block_col_sz ) - kv_indptr_host = kv_indptr.to("cpu", non_blocking=non_blocking) - kv_indices_host = kv_indices.to("cpu", non_blocking=non_blocking) - - self._qo_indptr = qo_indptr.to(self.device, non_blocking=non_blocking) - self._paged_kv_indptr_buf = kv_indptr.to(self.device, non_blocking=non_blocking) + kv_indptr_host = kv_indptr.to("cpu", blocking=not non_blocking) + kv_indices_host = kv_indices.to("cpu", blocking=not non_blocking) + self._qo_indptr = qo_indptr.to(self.device, blocking=not non_blocking) + self._paged_kv_indptr_buf = kv_indptr.to(self.device, blocking=not non_blocking) self._paged_kv_indices_buf = kv_indices.to( - self.device, non_blocking=non_blocking + self.device, blocking=not non_blocking ) self._paged_kv_last_page_len = last_block_len.to( - self.device, non_blocking=non_blocking + self.device, 
blocking=not non_blocking ) - torch.cuda.synchronize() # for non-blocking copy + paddle.device.synchronize() self._mask_mode = MaskMode.CAUSAL.value if causal else MaskMode.NON_CAUSAL.value - - # Sanity check - assert num_qo_heads % num_kv_heads == 0, ( - "num_qo_heads must be a multiple of num_kv_heads" - ) - assert num_blocks_row * num_kv_heads + 1 == kv_indptr_host.shape[0] - assert kv_indptr_host[-1].item() == kv_indices_host.shape[0], ( - f"{kv_indptr_host[-1].item()} != {kv_indices_host.shape[0]}" - ) - assert num_kv_heads == block_mask_map.shape[0] - assert num_kv_heads == block_row_sz.shape[0] - assert num_kv_heads == block_col_sz.shape[0] - assert num_blocks_row == block_mask_map.shape[1] - assert num_blocks_col == block_mask_map.shape[2] - + assert ( + num_qo_heads % num_kv_heads == 0 + ), "num_qo_heads must be a multiple of num_kv_heads" + assert num_blocks_row * num_kv_heads + 1 == tuple(kv_indptr_host.shape)[0] + assert ( + kv_indptr_host[-1].item() == tuple(kv_indices_host.shape)[0] + ), f"{kv_indptr_host[-1].item()} != {tuple(kv_indices_host.shape)[0]}" + assert num_kv_heads == tuple(block_mask_map.shape)[0] + assert num_kv_heads == tuple(block_row_sz.shape)[0] + assert num_kv_heads == tuple(block_col_sz.shape)[0] + assert num_blocks_row == tuple(block_mask_map.shape)[1] + assert num_blocks_col == tuple(block_mask_map.shape)[2] if self._backend == "auto": self._backend = determine_attention_backend( self.device, PosEncodingMode[pos_encoding_mode].value, use_fp16_qk_reduction, - self._mask_mode == MaskMode.CUSTOM.value, # use_custom_mask + self._mask_mode == MaskMode.CUSTOM.value, q_data_type, kv_data_type, ) - get_module_args = ( q_data_type, kv_data_type, self._o_dtype, kv_indptr_host.dtype, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim, + head_dim, PosEncodingMode[pos_encoding_mode].value, - False, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap + False, + logits_soft_cap > 0, use_fp16_qk_reduction, ) self._cached_module = get_batch_prefill_module(self._backend, *get_module_args) - - kv_lens_arr_host = kv_indptr_host[1:] - kv_indptr_host[:-1] # page_size == 1 - self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( - kv_lens_arr_host, + kv_lens_arr_host = kv_indptr_host[1:] - kv_indptr_host[:-1] + paddle.assign( + kv_lens_arr_host, output=self._kv_lens_buffer[: len(kv_lens_arr_host)] ) - if self._backend == "fa3": - if self._vector_sparse_indptr_buffer.numel() <= kv_indptr.numel(): + if self._vector_sparse_indptr_buffer.size <= kv_indptr.size: raise ValueError( "_vector_sparse_indptr_buffer is not large enough. Please increase the buffer size." 
) - self._vector_sparse_indptr_buffer[: len(kv_indptr)].copy_( - kv_indptr, non_blocking=non_blocking + paddle.assign( + kv_indptr, output=self._vector_sparse_indptr_buffer[: len(kv_indptr)] ) - self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, @@ -1029,17 +940,16 @@ def _block_mask_map_to_expanded_indices( qo_indptr_host, kv_indptr_host, kv_lens_arr_host, - qo_indptr_host[-1].item(), # total_num_rows - num_blocks_row * num_kv_heads, # batch_size - num_qo_heads // num_kv_heads, # num_qo_heads (gqa_group_size) - 1, # num_kv_heads, - 1, # page_size - False, # is_cuda_graph_enabled, + qo_indptr_host[-1].item(), + num_blocks_row * num_kv_heads, + num_qo_heads // num_kv_heads, + 1, + 1, + False, head_dim, head_dim, causal, ) - self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._logits_soft_cap = logits_soft_cap @@ -1051,17 +961,17 @@ def _block_mask_map_to_expanded_indices( def forward( self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, pos_encoding_mode: str = "NONE", use_fp16_qk_reduction: bool = False, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, - ) -> torch.Tensor: - r"""Warning: This method is deprecated, please use :meth:`run` instead.""" + ) -> paddle.Tensor: + """Warning: This method is deprecated, please use :meth:`run` instead.""" self._pos_encoding_mode = pos_encoding_mode self._use_fp16_qk_reduction = use_fp16_qk_reduction self._logits_soft_cap = logits_soft_cap @@ -1072,15 +982,15 @@ def forward( def run( self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: Optional[torch.Tensor] = None, - lse: Optional[torch.Tensor] = None, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + out: Optional[paddle.Tensor] = None, + lse: Optional[paddle.Tensor] = None, return_lse: bool = False, enable_pdl: Optional[bool] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Compute block-sparse attention between Q/K/V tensors. + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """Compute block-sparse attention between Q/K/V tensors. Parameters ---------- @@ -1109,12 +1019,8 @@ def run( * The attention output, shape: ``[M, num_qo_heads, head_dim]``. * The logsumexp of attention output, shape: ``[M, num_qo_heads]``. 
""" - # NOTE(Zihao): defer import of einops - import einops - if enable_pdl is None: - enable_pdl = device_support_pdl(q.device) - + enable_pdl = device_support_pdl(q.place) pos_encoding_mode = self._pos_encoding_mode logits_soft_cap = self._logits_soft_cap sm_scale = self._sm_scale @@ -1124,72 +1030,57 @@ def run( if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: - sm_scale = 1.0 / math.sqrt(q.size(-1)) + sm_scale = 1.0 / math.sqrt(q.shape[-1]) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: - rope_theta = 1e4 - - # reshape to pad num_kv_heads into seq_len - # input [num_qo_heads, qo_len, head_dim] - # kernel layout is NHD -> [qo_len * num_kv_heads, gqa_group_size, head_dim] + rope_theta = 10000.0 q = einops.rearrange( q, "(num_kv_heads gqa_group_size) qo_len head_dim -> (num_kv_heads qo_len) gqa_group_size head_dim", num_kv_heads=self._num_kv_heads, ).contiguous() - # HND -> [kv_len * num_kv_heads (num_pages), 1 (page_size), 1 (new_num_kv_heads), head_dim] k = einops.rearrange( - k, - "num_kv_heads kv_len head_dim -> (num_kv_heads kv_len) 1 1 head_dim", + k, "num_kv_heads kv_len head_dim -> (num_kv_heads kv_len) 1 1 head_dim" ).contiguous() v = einops.rearrange( - v, - "num_kv_heads kv_len head_dim -> (num_kv_heads kv_len) 1 1 head_dim", + v, "num_kv_heads kv_len head_dim -> (num_kv_heads kv_len) 1 1 head_dim" ).contiguous() - - stride_block = k.stride(0) - stride_n = k.stride(1) - + stride_block = k.get_strides()[0] + stride_n = k.get_strides()[1] if return_lse: if lse is None: - lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device - ) + lse = paddle.empty(shape=(q.shape[0], q.shape[1]), dtype="float32") else: check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" + lse, (q.shape[0], q.shape[1]), "float32", q.place, "lse" ) - if out is None: - out = torch.empty_like(q, dtype=self._o_dtype) + out = paddle.empty_like(x=q, dtype=self._o_dtype) else: - check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out") - + check_shape_dtype_device(out, tuple(q.shape), self._o_dtype, q.place, "out") if self._backend == "fa3": if ( - self._vector_sparse_indices_buffer.numel() - <= self._paged_kv_indices_buf.numel() + self._vector_sparse_indices_buffer.size + <= self._paged_kv_indices_buf.size ): raise ValueError( "_vector_sparse_indices_buffer is not large enough. Please increase the buffer size." 
) - sparse_indices = block_sparse_indices_to_vector_sparse_offsets( self._paged_kv_indices_buf, self._paged_kv_indptr_buf, - self._vector_sparse_indices_buffer, # output + self._vector_sparse_indices_buffer, self._vector_sparse_indptr_buffer, self._kv_lens_buffer, stride_block // stride_n, - 1, # stride_n // stride_n - 1, # block_size + 1, + 1, ) sparse_indptr = self._vector_sparse_indptr_buffer else: sparse_indices = self._paged_kv_indices_buf sparse_indptr = self._paged_kv_indptr_buf - self._cached_module.paged_run( self._float_workspace_buffer, self._int_workspace_buffer, @@ -1205,38 +1096,32 @@ def run( lse, self._mask_mode, TensorLayout[self._kv_layout].value, - -1, # window_left + -1, enable_pdl, - # ADDITIONAL_FUNC_PARAMS - # Not supported yet - None, # packed_mask_buf - None, # mask_indptr_buf - None, # alibi_slopes_buf + None, + None, + None, None, None, None, logits_soft_cap, sm_scale, - None, # scale_q - None, # scale_k - None, # scale_v + None, + None, + None, rope_scale, rope_theta, - 0, # token_pos_in_items_len + 0, ) - - # [qo_len * num_kv_heads, gqa_group_size, head_dim] -> HND out = einops.rearrange( out, "(num_kv_heads qo_len) gqa_group_size head_dim -> (num_kv_heads gqa_group_size) qo_len head_dim", num_kv_heads=self._num_kv_heads, ).contiguous() - if return_lse: lse = einops.rearrange( lse, "(num_kv_heads qo_len) gqa_group_size -> (num_kv_heads gqa_group_size) qo_len", num_kv_heads=self._num_kv_heads, ).contiguous() - return (out, lse) if return_lse else out diff --git a/flashinfer/testing/__init__.py b/flashinfer/testing/__init__.py index c034f0784f..577869663d 100644 --- a/flashinfer/testing/__init__.py +++ b/flashinfer/testing/__init__.py @@ -13,16 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. """ - -from .utils import ( - attention_flops, - attention_flops_with_actual_seq_lens, - attention_tb_per_sec, - attention_tb_per_sec_with_actual_seq_lens, - attention_tflops_per_sec, - attention_tflops_per_sec_with_actual_seq_lens, - bench_gpu_time, - bench_gpu_time_with_cudagraph, - set_seed, - sleep_after_kernel_run, -) +from .utils import (attention_flops, attention_flops_with_actual_seq_lens, + attention_tb_per_sec, + attention_tb_per_sec_with_actual_seq_lens, + attention_tflops_per_sec, + attention_tflops_per_sec_with_actual_seq_lens, + bench_gpu_time, bench_gpu_time_with_cudagraph, set_seed, + sleep_after_kernel_run) diff --git a/flashinfer/testing/utils.py b/flashinfer/testing/utils.py index f85b435eba..a2700f19c7 100644 --- a/flashinfer/testing/utils.py +++ b/flashinfer/testing/utils.py @@ -1,3 +1,12 @@ +import sys + + +import os + +import einops +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2023 by FlashInfer team. @@ -13,52 +22,54 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import math import random -import time -from typing import Tuple, Any - -import os import sys +import time +from typing import Any, Tuple import numpy as np -import torch -from einops import rearrange, reduce, repeat from flashinfer.utils import round_up -def _ceil_to_ue8m0(x: torch.Tensor): +def _ceil_to_ue8m0(x: paddle.Tensor): """imported from DeepGEMM""" assert x.view(-1).amax().item() > 0 - return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + return paddle.pow(x=2.0, y=paddle.ceil(x=paddle.log2(x=x.abs()))) -def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def per_token_cast_to_fp8(x: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]: """imported from DeepGEMM""" - assert x.dim() == 2 and x.size(1) % 128 == 0 - m, n = x.shape + assert x.dim() == 2 and x.shape[1] % 128 == 0 + m, n = tuple(x.shape) x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + x_amax = ( + x_view.abs().astype(dtype="float32").amax(axis=2).view(m, -1).clip(min=0.0001) + ) sf = _ceil_to_ue8m0(x_amax / 448.0) - return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf + return (x_view * (1.0 / sf.unsqueeze(axis=2))).to(paddle.float8_e4m3fn).view( + m, n + ), sf -def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def per_block_cast_to_fp8(x: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]: """imported from DeepGEMM""" assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (round_up(m, 128), round_up(n, 128)), dtype=x.dtype, device=x.device - ) + m, n = tuple(x.shape) + x_padded = paddle.zeros(shape=(round_up(m, 128), round_up(n, 128)), dtype=x.dtype) x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_view = x_padded.view(-1, 128, x_padded.shape[1] // 128, 128) + x_amax = ( + x_view.abs() + .astype(dtype="float32") + .amax(axis=(1, 3), keepdim=True) + .clip(min=0.0001) + ) sf = _ceil_to_ue8m0(x_amax / 448.0) - x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( - x_view.size(0), x_view.size(2) + x_scaled = (x_view * (1.0 / sf)).to(paddle.float8_e4m3fn) + return x_scaled.view_as(other=x_padded)[:m, :n].contiguous(), sf.view( + x_view.shape[0], x_view.shape[2] ) @@ -77,86 +88,69 @@ def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode): tuple: A tuple containing the quantized FP8 tensor and the calculated float32 scales. """ - # 1. Assertions and Initial Setup ndim = x.ndim assert ndim in [2, 3], f"x.ndim must be 2 or 3, but got {ndim}" assert ndim == len(scale_shape) == len(tile_shape) - - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_amax = torch.tensor(fp8_info.max, device=x.device, dtype=torch.float32) - - # 2. 
Tiling and Scale Calculation + fp8_info = paddle.finfo(dtype=paddle.float8_e4m3fn) + fp8_amax = paddle.to_tensor(data=fp8_info.max, dtype="float32", place=x.place) if ndim == 2: s0, s1 = scale_shape t0, t1 = tile_shape if scale_major_mode == "K": - # Tile x and find the max absolute value in each tile - x_tiled = rearrange(x, "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1) - abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clamp(1e-4) + x_tiled = einops.rearrange( + x, "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1 + ) + abs_max = einops.reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clip( + min=0.0001 + ) x_scale = abs_max / fp8_amax - x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) - - # Broadcast scales back to the original tensor shape - scales_repeated = repeat(x_scale, "s0 s1 -> (s0 t0) (s1 t1)", t0=t0, t1=t1) + x_scale = paddle.pow(x=2.0, y=paddle.ceil(x=paddle.log2(x=x_scale.abs()))) + scales_repeated = einops.tile( + repeat_times=[x_scale, "s0 s1 -> (s0 t0) (s1 t1)"] + ) else: - # Handle column-major tiling - x_tiled = rearrange(x, "(s1 t0) (s0 t1) -> s0 s1 t0 t1", s0=s0, s1=s1) - abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clamp(1e-4) + x_tiled = einops.rearrange( + x, "(s1 t0) (s0 t1) -> s0 s1 t0 t1", s0=s0, s1=s1 + ) + abs_max = einops.reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clip( + min=0.0001 + ) x_scale = abs_max / fp8_amax - x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) - - # Permute scale axes before repeating to match layout - scales_permuted = rearrange(x_scale, "s0 s1 -> s1 s0") - scales_repeated = repeat( - scales_permuted, "s1 s0 -> (s1 t0) (s0 t1)", t0=t0, t1=t1 + x_scale = paddle.pow(x=2.0, y=paddle.ceil(x=paddle.log2(x=x_scale.abs()))) + scales_permuted = einops.rearrange(x_scale, "s0 s1 -> s1 s0") + scales_repeated = einops.tile( + repeat_times=[scales_permuted, "s1 s0 -> (s1 t0) (s0 t1)"] ) - elif ndim == 3: s0, s1, s2 = scale_shape t0, t1, t2 = tile_shape if scale_major_mode == "K": - # Tile x and find the max absolute value in each tile - x_tiled = rearrange( + x_tiled = einops.rearrange( x, "(s0 t0) (s1 t1) (s2 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2 ) - abs_max = reduce( + abs_max = einops.reduce( x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max" - ).clamp(1e-4) + ).clip(min=0.0001) x_scale = abs_max / fp8_amax - x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) - - # Broadcast scales back to the original tensor shape - scales_repeated = repeat( - x_scale, "s0 s1 s2 -> (s0 t0) (s1 t1) (s2 t2)", t0=t0, t1=t1, t2=t2 + x_scale = paddle.pow(x=2.0, y=paddle.ceil(x=paddle.log2(x=x_scale.abs()))) + scales_repeated = einops.tile( + repeat_times=[x_scale, "s0 s1 s2 -> (s0 t0) (s1 t1) (s2 t2)"] ) else: - # Handle layout where the last two axes are swapped - x_tiled = rearrange( + x_tiled = einops.rearrange( x, "(s0 t0) (s2 t1) (s1 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2 ) - abs_max = reduce( + abs_max = einops.reduce( x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max" - ).clamp(1e-4) + ).clip(min=0.0001) x_scale = abs_max / fp8_amax - x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) - - # Permute scale axes before repeating to match layout - scales_permuted = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1") - scales_repeated = repeat( - scales_permuted, - "s0 s2 s1 -> (s0 t0) (s2 t1) (s1 t2)", - t0=t0, - t1=t1, - t2=t2, + x_scale = paddle.pow(x=2.0, y=paddle.ceil(x=paddle.log2(x=x_scale.abs()))) + scales_permuted = einops.rearrange(x_scale, 
"s0 s1 s2 -> s0 s2 s1") + scales_repeated = einops.tile( + repeat_times=[scales_permuted, "s0 s2 s1 -> (s0 t0) (s2 t1) (s1 t2)"] ) - - # 3. Final Quantization - # Divide the original tensor by the broadcasted scales - x_fp32 = x / (scales_repeated + 1e-8) - - # Convert the result to the target FP8 format - x_fp8 = x_fp32.to(torch.float8_e4m3fn) - + x_fp32 = x / (scales_repeated + 1e-08) + x_fp8 = x_fp32.to(paddle.float8_e4m3fn) return x_fp8, x_scale @@ -175,43 +169,41 @@ def dequantize_fp8(x, x_scale, scale_major_mode): tuple: A tuple containing the quantized FP8 tensor and the calculated float32 scales. """ - # 1. Assertions and Initial Setup ndim = x.ndim assert ndim in [2, 3], f"x.ndim must be 2 or 3, but got {ndim}" - assert ndim == len(x_scale.shape) - - # 2. Tiling and Scale Calculation + assert ndim == len(tuple(x_scale.shape)) if ndim == 2: if scale_major_mode == "K": - s0, s1 = x_scale.shape + s0, s1 = tuple(x_scale.shape) else: - s1, s0 = x_scale.shape - x = rearrange( - x.to(torch.float32), "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1 + s1, s0 = tuple(x_scale.shape) + x = einops.rearrange( + x.to("float32"), "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1 ) if scale_major_mode == "K": - x_scale = rearrange(x_scale, "s0 s1 -> s0 s1 1 1") + x_scale = einops.rearrange(x_scale, "s0 s1 -> s0 s1 1 1") else: - x_scale = rearrange(x_scale, "s0 s1 -> s1 s0 1 1") - out = rearrange(x * x_scale, "s0 s1 t0 t1 -> (s0 t0) (s1 t1)") - + x_scale = einops.rearrange(x_scale, "s0 s1 -> s1 s0 1 1") + out = einops.rearrange(x * x_scale, "s0 s1 t0 t1 -> (s0 t0) (s1 t1)") elif ndim == 3: if scale_major_mode == "K": - s0, s1, s2 = x_scale.shape + s0, s1, s2 = tuple(x_scale.shape) else: - s0, s2, s1 = x_scale.shape - x = rearrange( - x.to(torch.float32), + s0, s2, s1 = tuple(x_scale.shape) + x = einops.rearrange( + x.to("float32"), "(s0 t0) (s1 t1) (s2 t2)-> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2, ) if scale_major_mode == "K": - x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s1 s2 1 1 1") + x_scale = einops.rearrange(x_scale, "s0 s1 s2 -> s0 s1 s2 1 1 1") else: - x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1") - out = rearrange(x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 t2)") + x_scale = einops.rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1") + out = einops.rearrange( + x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 t2)" + ) return out @@ -225,13 +217,12 @@ def set_seed(random_seed): Returns: None """ - torch.manual_seed(random_seed) + paddle.seed(seed=random_seed) random.seed(random_seed) np.random.seed(random_seed) - - if torch.cuda.is_available(): - torch.cuda.manual_seed(random_seed) - torch.cuda.manual_seed_all(random_seed) + if paddle.device.cuda.device_count() >= 1: + paddle.seed(seed=random_seed) + paddle.seed(seed=random_seed) def sleep_after_kernel_run(execution_time): @@ -253,13 +244,7 @@ def sleep_after_kernel_run(execution_time): def attention_flops( - batch_size, - qo_seqlen, - kv_seqlen, - head_dim_qk, - head_dim_vo, - num_qo_heads, - causal, + batch_size, qo_seqlen, kv_seqlen, head_dim_qk, head_dim_vo, num_qo_heads, causal ): """ Calculate FLOPs for a given attention layer. 
Assumes all sequence lengths are the same within the batch @@ -324,44 +309,40 @@ def attention_flops_with_actual_seq_lens( """ if causal: bmm1_flops = ( - torch.dot( - 2 * actual_seq_lens_kv.to(torch.float32) - - actual_seq_lens_q.to(torch.float32), - actual_seq_lens_q.to(torch.float32), + paddle.dot( + x=2 * actual_seq_lens_kv.to("float32") + - actual_seq_lens_q.to("float32"), + y=actual_seq_lens_q.to("float32"), ) * num_qo_heads * head_dim_qk ) bmm2_flops = ( - torch.dot( - 2 * actual_seq_lens_kv.to(torch.float32) - - actual_seq_lens_q.to(torch.float32), - actual_seq_lens_q.to(torch.float32), + paddle.dot( + x=2 * actual_seq_lens_kv.to("float32") + - actual_seq_lens_q.to("float32"), + y=actual_seq_lens_q.to("float32"), ) * num_qo_heads * head_dim_vo ) - else: bmm1_flops = ( 2 - * torch.dot( - actual_seq_lens_kv.to(torch.float32), - actual_seq_lens_q.to(torch.float32), + * paddle.dot( + x=actual_seq_lens_kv.to("float32"), y=actual_seq_lens_q.to("float32") ) * num_qo_heads * head_dim_qk ) bmm2_flops = ( 2 - * torch.dot( - actual_seq_lens_kv.to(torch.float32), - actual_seq_lens_q.to(torch.float32), + * paddle.dot( + x=actual_seq_lens_kv.to("float32"), y=actual_seq_lens_q.to("float32") ) * num_qo_heads * head_dim_vo ) - total_flops = bmm1_flops + bmm2_flops return total_flops @@ -393,15 +374,9 @@ def attention_tflops_per_sec( tflops_per_sec (float): TFLOPS per second for the layer. """ f = attention_flops( - batch_size, - qo_seqlen, - kv_seqlen, - head_dim_qk, - head_dim_vo, - num_qo_heads, - causal, + batch_size, qo_seqlen, kv_seqlen, head_dim_qk, head_dim_vo, num_qo_heads, causal ) - return f / time / 1e9 if not math.isnan(time) else 0.0 + return f / time / 1000000000.0 if not math.isnan(time) else 0.0 def attention_tflops_per_sec_with_actual_seq_lens( @@ -437,7 +412,7 @@ def attention_tflops_per_sec_with_actual_seq_lens( num_qo_heads, causal, ) - return f.item() / time / 1e9 if not math.isnan(time) else 0.0 + return f.item() / time / 1000000000.0 if not math.isnan(time) else 0.0 def attention_tb_per_sec( @@ -449,9 +424,9 @@ def attention_tb_per_sec( num_qo_heads, num_kv_heads, time, - q_dtype=torch.bfloat16, - kv_dtype=torch.bfloat16, - o_dtype=torch.bfloat16, + q_dtype="bfloat16", + kv_dtype="bfloat16", + o_dtype="bfloat16", ): """ Calculate TB per second perf achieved for a given attention layer. Assumes all sequence lengths are the same within the batch. @@ -472,14 +447,21 @@ def attention_tb_per_sec( Returns: tb_per_sec (float): TB per second for the layer. 
""" - q_bytes = batch_size * qo_seqlen * num_qo_heads * head_dim_qk * q_dtype.itemsize - k_bytes = batch_size * kv_seqlen * num_kv_heads * head_dim_qk * kv_dtype.itemsize - v_bytes = batch_size * kv_seqlen * num_kv_heads * head_dim_vo * kv_dtype.itemsize - o_bytes = batch_size * qo_seqlen * num_qo_heads * head_dim_vo * o_dtype.itemsize + q_bytes = ( + batch_size * qo_seqlen * num_qo_heads * head_dim_qk * q_dtype.element_size() + ) + k_bytes = ( + batch_size * kv_seqlen * num_kv_heads * head_dim_qk * kv_dtype.element_size() + ) + v_bytes = ( + batch_size * kv_seqlen * num_kv_heads * head_dim_vo * kv_dtype.element_size() + ) + o_bytes = ( + batch_size * qo_seqlen * num_qo_heads * head_dim_vo * o_dtype.element_size() + ) total_bytes = q_bytes + k_bytes + v_bytes + o_bytes - - time_in_sec = time / 1e3 - bytes_in_tb = total_bytes / 1e12 # TB not TiB + time_in_sec = time / 1000.0 + bytes_in_tb = total_bytes / 1000000000000.0 return bytes_in_tb / time_in_sec if not math.isnan(time) else 0.0 @@ -491,9 +473,9 @@ def attention_tb_per_sec_with_actual_seq_lens( num_qo_heads, num_kv_heads, time, - q_dtype=torch.bfloat16, - kv_dtype=torch.bfloat16, - o_dtype=torch.bfloat16, + q_dtype="bfloat16", + kv_dtype="bfloat16", + o_dtype="bfloat16", ): """ Calculate TB per second perf achieved for a given attention layer with actual sequence lengths. @@ -515,22 +497,32 @@ def attention_tb_per_sec_with_actual_seq_lens( tb_per_sec (float): TB per second for the layer. """ q_bytes = ( - torch.sum(actual_seq_lens_q) * num_qo_heads * head_dim_qk * q_dtype.itemsize + paddle.sum(x=actual_seq_lens_q) + * num_qo_heads + * head_dim_qk + * q_dtype.element_size() ) k_bytes = ( - torch.sum(actual_seq_lens_kv) * num_kv_heads * head_dim_qk * kv_dtype.itemsize + paddle.sum(x=actual_seq_lens_kv) + * num_kv_heads + * head_dim_qk + * kv_dtype.element_size() ) v_bytes = ( - torch.sum(actual_seq_lens_kv) * num_kv_heads * head_dim_vo * kv_dtype.itemsize + paddle.sum(x=actual_seq_lens_kv) + * num_kv_heads + * head_dim_vo + * kv_dtype.element_size() ) o_bytes = ( - torch.sum(actual_seq_lens_q) * num_qo_heads * head_dim_vo * o_dtype.itemsize + paddle.sum(x=actual_seq_lens_q) + * num_qo_heads + * head_dim_vo + * o_dtype.element_size() ) - total_bytes = (q_bytes + k_bytes + v_bytes + o_bytes).item() - - time_in_sec = time / 1e3 - bytes_in_tb = total_bytes / 1e12 # TB not TiB + time_in_sec = time / 1000.0 + bytes_in_tb = total_bytes / 1000000000000.0 return bytes_in_tb / time_in_sec if not math.isnan(time) else 0.0 @@ -570,59 +562,51 @@ def bench_gpu_time( Returns: measured_times: List of measured times. 
""" - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + start_event = paddle.device.cuda.Event(enable_timing=True) + end_event = paddle.device.cuda.Event(enable_timing=True) if l2_flush: l2_flush_size = int(l2_flush_size_mb) * 1024 * 1024 - buffer = torch.empty(l2_flush_size, device=l2_flush_device, dtype=torch.int8) - - ## Estimate kernel execution time by running the kernel 5 times + buffer = paddle.empty(shape=l2_flush_size, dtype="int8") measurement_iters = 5 - torch.cuda.synchronize() - fn() # Call once to exclude initial overhead - torch.cuda.synchronize() + paddle.device.synchronize() + fn() + paddle.device.synchronize() start_event.record() for _ in range(measurement_iters): if l2_flush: buffer.zero_() fn() end_event.record() - torch.cuda.synchronize() + paddle.device.synchronize() estimated_kernel_execution_time = ( start_event.elapsed_time(end_event) / measurement_iters ) - - ## Set dry run and repeat iterations if dry_run_iters is None: dry_run_iters = max(1, int(dry_run_time_ms / estimated_kernel_execution_time)) if repeat_iters is None: repeat_iters = max(1, int(repeat_time_ms / estimated_kernel_execution_time)) - - # Dry runs - torch.cuda.synchronize() + paddle.device.synchronize() for _ in range(dry_run_iters): if l2_flush: buffer.zero_() fn() - torch.cuda.synchronize() - - # Actual run - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(repeat_iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(repeat_iters)] - torch.cuda.synchronize() + paddle.device.synchronize() + start_events = [ + paddle.device.cuda.Event(enable_timing=True) for _ in range(repeat_iters) + ] + end_events = [ + paddle.device.cuda.Event(enable_timing=True) for _ in range(repeat_iters) + ] + paddle.device.synchronize() for iter_idx in range(repeat_iters): if l2_flush: buffer.zero_() start_events[iter_idx].record() fn() end_events[iter_idx].record() - if sleep_after_run: sleep_after_kernel_run(estimated_kernel_execution_time) - - # Synchronize once outside of the loop to avoid synchronization overhead - torch.cuda.synchronize() + paddle.device.synchronize() measured_times = [] for iter_idx in range(repeat_iters): measured_times.append(start_events[iter_idx].elapsed_time(end_events[iter_idx])) @@ -671,30 +655,23 @@ def bench_gpu_time_with_cudagraph( Returns: measured_times: List of measured times. 
""" - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + start_event = paddle.device.cuda.Event(enable_timing=True) + end_event = paddle.device.cuda.Event(enable_timing=True) if l2_flush: l2_flush_size = int(l2_flush_size_mb) * 1024 * 1024 - buffer = torch.empty(l2_flush_size, device=l2_flush_device, dtype=torch.int8) - - # Warmup run - torch.cuda.synchronize() - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): + buffer = paddle.empty(shape=l2_flush_size, dtype="int8") + paddle.device.synchronize() + s = paddle.device.Stream() + s.wait_stream(paddle.device.current_stream()) + with paddle.device.stream_guard(stream=s): for _ in range(3): fn() - torch.cuda.current_stream().wait_stream(s) - - # Capture kernel in graph - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): + paddle.device.current_stream().wait_stream(s) +>>>>>> g = torch.cuda.CUDAGraph() +>>>>>> with torch.cuda.graph(g): for _ in range(num_iters_within_graph): fn() - torch.cuda.synchronize() - - ## Estimate kernel execution time by running the kernel 5 times + paddle.device.synchronize() measurement_iters = 5 start_event.record() for _ in range(measurement_iters): @@ -702,41 +679,36 @@ def bench_gpu_time_with_cudagraph( buffer.zero_() g.replay() end_event.record() - torch.cuda.synchronize() + paddle.device.synchronize() estimated_kernel_execution_time = ( start_event.elapsed_time(end_event) / measurement_iters ) - - ## Set dry run and repeat iterations if dry_run_iters is None: dry_run_iters = max(1, int(dry_run_time_ms / estimated_kernel_execution_time)) if repeat_iters is None: repeat_iters = max(1, int(repeat_time_ms / estimated_kernel_execution_time)) - - # Dry run - torch.cuda.synchronize() + paddle.device.synchronize() for _ in range(dry_run_iters): if l2_flush: buffer.zero_() g.replay() - torch.cuda.synchronize() - - # Actual run - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(repeat_iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(repeat_iters)] - torch.cuda.synchronize() + paddle.device.synchronize() + start_events = [ + paddle.device.cuda.Event(enable_timing=True) for _ in range(repeat_iters) + ] + end_events = [ + paddle.device.cuda.Event(enable_timing=True) for _ in range(repeat_iters) + ] + paddle.device.synchronize() for iter_idx in range(repeat_iters): if l2_flush: buffer.zero_() start_events[iter_idx].record() g.replay() end_events[iter_idx].record() - if sleep_after_run: sleep_after_kernel_run(estimated_kernel_execution_time) - - # Synchronize once outside of the loop to avoid synchronization overhead - torch.cuda.synchronize() + paddle.device.synchronize() measured_times = [] for iter_idx in range(repeat_iters): measured_times.append( @@ -758,19 +730,14 @@ class suppress_stdout_stderr: def __enter__(self): self.outnull_file = open(os.devnull, "w") self.errnull_file = open(os.devnull, "w") - self.old_stdout_fileno_undup = sys.stdout.fileno() self.old_stderr_fileno_undup = sys.stderr.fileno() - self.old_stdout_fileno = os.dup(sys.stdout.fileno()) self.old_stderr_fileno = os.dup(sys.stderr.fileno()) - self.old_stdout = sys.stdout self.old_stderr = sys.stderr - os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) - sys.stdout = self.outnull_file sys.stderr = self.errnull_file return self @@ -778,18 +745,14 @@ def __enter__(self): def __exit__(self, *_): sys.stdout = self.old_stdout 
sys.stderr = self.old_stderr - os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) - os.close(self.old_stdout_fileno) os.close(self.old_stderr_fileno) - self.outnull_file.close() self.errnull_file.close() -# copied from DeepGEMM def bench_kineto( fn, kernel_names, @@ -799,16 +762,9 @@ def bench_kineto( flush_l2: bool = True, with_multiple_kernels: bool = False, ): - # Conflict with Nsight Systems using_nsys = int(os.environ.get("DG_NSYS_PROFILING", 0)) - - # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle - flush_l2_size = int(8e9 // 4) - - # For some auto-tuning kernels with prints + flush_l2_size = int(8000000000.0 // 4) fn() - - # Profile suppress = ( suppress_stdout_stderr if suppress_kineto_output and not using_nsys @@ -816,13 +772,13 @@ def bench_kineto( ) with suppress(): schedule = ( - torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + paddle.profiler.make_scheduler(closed=0, ready=1, record=1, repeat=1) if not using_nsys else None ) profiler: Any = ( - torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule +>>>>>> torch.profiler.profile( +>>>>>> activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule ) if not using_nsys else empty_suppress() @@ -831,22 +787,17 @@ def bench_kineto( for _i in range(2): for _ in range(num_tests): if flush_l2: - torch.empty( - flush_l2_size, dtype=torch.int, device="cuda" - ).zero_() + paddle.empty(shape=flush_l2_size, dtype="int32").zero_() fn() - if not using_nsys: profiler.step() - - # Return 1 if using Nsight Systems if using_nsys: return 1 - - # Parse the profiling table assert isinstance(kernel_names, (str, tuple)) is_tuple = isinstance(kernel_names, tuple) - prof_lines = ( + """Not Support auto convert *.key_averages, please judge whether it is Pytorch API and convert by yourself""" + """Not Support auto convert *.table, please judge whether it is Pytorch API and convert by yourself""" +>>>>>> prof_lines = ( profiler.key_averages() .table(sort_by="cuda_time_total", max_name_column_width=100) .split("\n") @@ -855,16 +806,12 @@ def bench_kineto( assert all([isinstance(name, str) for name in kernel_names]) if not with_multiple_kernels: for name in kernel_names: - assert sum([name in line for line in prof_lines]) == 1, ( - f"Errors of the kernel {name} in the profiling table" - ) - - # Save chrome traces + assert ( + sum([(name in line) for line in prof_lines]) == 1 + ), f"Errors of the kernel {name} in the profiling table" if trace_path is not None: - profiler.export_chrome_trace(trace_path) - - # Return average kernel times - units = {"ms": 1e3, "us": 1e6} + paddle.profiler.export_chrome_tracing(dir_name=trace_path) + units = {"ms": 1000.0, "us": 1000000.0} kernel_times = [] for name in kernel_names: total_time = 0.0 @@ -881,7 +828,6 @@ def bench_kineto( total_num += int(num_str) break kernel_times.append(total_time / total_num) - return tuple(kernel_times) if is_tuple else kernel_times[0] @@ -891,5 +837,5 @@ def count_bytes(*tensors): if isinstance(t, (tuple, list)): total += count_bytes(*t) elif t is not None: - total += t.numel() * t.element_size() + total += t.size * t.element_size() return total diff --git a/flashinfer/triton/__init__.py b/flashinfer/triton/__init__.py index 6247c071fd..5659e83542 100644 --- a/flashinfer/triton/__init__.py +++ b/flashinfer/triton/__init__.py @@ -1,2 +1 @@ -from . import cascade # noqa: F401 -from . 
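# The ">>>>>>" lines and the two "Not Support auto convert" strings above mark the Kineto
# profiling block as untranslated. A hedged sketch of a Paddle-profiler version, using only
# documented paddle.profiler APIs (Profiler, ProfilerTarget, make_scheduler,
# export_chrome_tracing, SortedKeys); parsing per-kernel times out of the report is omitted
# because Paddle's summary format differs from torch's key_averages() table.
import paddle.profiler as profiler

def profile_kernels(fn, num_tests: int = 30, trace_dir: str = None):
    prof = profiler.Profiler(
        targets=[profiler.ProfilerTarget.GPU],
        scheduler=profiler.make_scheduler(closed=0, ready=1, record=1, repeat=1),
        on_trace_ready=profiler.export_chrome_tracing(trace_dir) if trace_dir else None,
    )
    with prof:
        for _ in range(2):        # first pass warms the scheduler up, second pass is recorded
            for _ in range(num_tests):
                fn()
            prof.step()           # advance the scheduler state
    prof.summary(sorted_by=profiler.SortedKeys.GPUTotal)  # prints an operator/kernel time table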
import sm_constraint_gemm # noqa: F401 +from . import cascade, sm_constraint_gemm diff --git a/flashinfer/triton/activation.py b/flashinfer/triton/activation.py index 1def6817b7..8683c0a034 100644 --- a/flashinfer/triton/activation.py +++ b/flashinfer/triton/activation.py @@ -1,18 +1,18 @@ from collections.abc import Mapping from typing import Optional -import torch -import triton # type: ignore[import] +import paddle +import triton from flashinfer.triton.kernels.activation import silu_and_mul_kernel def silu_and_mul( - x: torch.Tensor, - x_scale: Optional[torch.Tensor] = None, - o_scale: Optional[torch.Tensor] = None, - dtype: Optional[torch.dtype] = None, -) -> torch.Tensor: + x: paddle.Tensor, + x_scale: Optional[paddle.Tensor] = None, + o_scale: Optional[paddle.Tensor] = None, + dtype: Optional[paddle.dtype] = None, +) -> paddle.Tensor: """Sigmoid Linear Unit and Multiplication Computes `silu(x[:,:d]) * x[:, d:]`, where `d = x.shape[-1] // 2. @@ -29,29 +29,25 @@ def silu_and_mul( Returns: The output activation, of shape `(b, d)`. """ - - b, n = x.shape - + b, n = tuple(x.shape) assert n % 2 == 0 d = n // 2 - o_dtype = dtype or x.dtype - o = torch.empty((b, d), dtype=o_dtype, device=x.device) + o = paddle.empty(shape=(b, d), dtype=o_dtype) def grid(meta: Mapping[str, int]) -> tuple[int, int]: - return (b, triton.cdiv(d, meta["BLOCK_SIZE"])) + return b, triton.cdiv(d, meta["BLOCK_SIZE"]) silu_and_mul_kernel[grid]( o_ptr=o, - o_stride=o.stride(0), + o_stride=o.get_strides()[0], o_scale_ptr=o_scale, x_ptr=x, - x_stride=x.stride(0), + x_stride=x.get_strides()[0], x_scale_ptr=x_scale, d=d, BLOCK_SIZE=1024, HAS_X_SCALE=x_scale is not None, HAS_O_SCALE=o_scale is not None, ) - return o diff --git a/flashinfer/triton/cascade.py b/flashinfer/triton/cascade.py index e8eac75c61..36062bb822 100644 --- a/flashinfer/triton/cascade.py +++ b/flashinfer/triton/cascade.py @@ -1,18 +1,15 @@ from typing import Optional -import torch +import paddle -from .kernels.cascade import ( - merge_state_in_place_kernel, - merge_state_kernel, - merge_states_kernel, - variable_length_merge_states_kernel, -) +from .kernels.cascade import (merge_state_in_place_kernel, merge_state_kernel, + merge_states_kernel, + variable_length_merge_states_kernel) from .utils import check_device, check_dim, check_input, check_shape def merge_state( - v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor + v_a: paddle.Tensor, s_a: paddle.Tensor, v_b: paddle.Tensor, s_b: paddle.Tensor ): check_input(v_a) check_input(s_a) @@ -25,31 +22,29 @@ def merge_state( check_dim(2, s_b) check_shape(v_a, v_b) check_shape(s_a, s_b) - assert v_a.size(0) == s_a.size(0) - assert v_a.size(1) == s_b.size(1) - s_a = s_a.to(torch.float32) - s_b = s_b.to(torch.float32) - seq_len = v_a.size(0) - num_heads = v_a.size(1) - head_dim = v_a.size(2) - v_merged = torch.empty_like(v_a).to(s_a.device) - s_merged = torch.empty((seq_len, num_heads)).to(s_a.device) + assert v_a.shape[0] == s_a.shape[0] + assert v_a.shape[1] == s_b.shape[1] + s_a = s_a.to("float32") + s_b = s_b.to("float32") + seq_len = v_a.shape[0] + num_heads = v_a.shape[1] + head_dim = v_a.shape[2] + v_merged = paddle.empty_like(x=v_a).to(s_a.place) + s_merged = paddle.empty(shape=(seq_len, num_heads)).to(s_a.place) bdx = head_dim bdy = num_heads - merge_state_kernel[lambda meta: (seq_len,)]( v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy ) - return v_merged, s_merged def merge_state_in_place( - v: torch.Tensor, - s: torch.Tensor, - v_other: 
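# A plain-Paddle reference for the silu_and_mul wrapper above (scaling paths ignored),
# useful for spot-checking the Triton kernel; paddle.nn.functional.silu and paddle.chunk
# are standard Paddle ops. Illustration only.
import paddle
import paddle.nn.functional as F

def silu_and_mul_ref(x: paddle.Tensor) -> paddle.Tensor:
    # x has shape (b, 2*d): the first half is the gate, the second half the linear branch
    gate, up = paddle.chunk(x, chunks=2, axis=-1)
    return F.silu(gate) * up  # silu(a) = a * sigmoid(a), matching tl.sigmoid(a) * a in the kernel

x = paddle.randn([4, 256])
# silu_and_mul(x) from the wrapper above should agree with silu_and_mul_ref(x) up to rounding.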
torch.Tensor, - s_other: torch.Tensor, - mask: Optional[torch.Tensor] = None, + v: paddle.Tensor, + s: paddle.Tensor, + v_other: paddle.Tensor, + s_other: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, ): check_input(v) check_input(s) @@ -62,91 +57,68 @@ def merge_state_in_place( check_dim(2, s_other) check_shape(v, v_other) check_shape(s, s_other) - assert v.size(0) == s.size(0) - assert v.size(1) == s.size(1) - assert s.dtype == torch.float32 - assert s_other.dtype == torch.float32 + assert v.shape[0] == s.shape[0] + assert v.shape[1] == s.shape[1] + assert s.dtype == "float32" + assert s_other.dtype == "float32" if mask is not None: check_dim(1, mask) - assert v.size(0) == mask.size(0) - assert mask.device == v.device - seq_len = v.size(0) - num_heads = v.size(1) - head_dim = v.size(2) - + assert v.shape[0] == mask.shape[0] + assert mask.place == v.place + seq_len = v.shape[0] + num_heads = v.shape[1] + head_dim = v.shape[2] bdx = head_dim bdy = num_heads - merge_state_in_place_kernel[(seq_len,)]( - v, s, v_other, s_other, num_heads, head_dim, mask, bdx=bdx, bdy=bdy - ) + merge_state_in_place_kernel[ + seq_len, + ](v, s, v_other, s_other, num_heads, head_dim, mask, bdx=bdx, bdy=bdy) -def merge_states(v: torch.Tensor, s: torch.Tensor): +def merge_states(v: paddle.Tensor, s: paddle.Tensor): check_input(v) check_input(s) check_device([v, s]) check_dim(4, v) check_dim(3, s) - assert v.size(0) == s.size(0) - assert v.size(1) == s.size(1) - assert v.size(2) == s.size(2) - seq_len = v.size(0) - num_index_sets = v.size(1) - num_heads = v.size(2) - head_dim = v.size(3) - s = s.to(torch.float32) - v_merged = torch.empty( - (seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device - ) - s_merged = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device) - + assert v.shape[0] == s.shape[0] + assert v.shape[1] == s.shape[1] + assert v.shape[2] == s.shape[2] + seq_len = v.shape[0] + num_index_sets = v.shape[1] + num_heads = v.shape[2] + head_dim = v.shape[3] + s = s.to("float32") + v_merged = paddle.empty(shape=(seq_len, num_heads, head_dim), dtype=v.dtype) + s_merged = paddle.empty(shape=(seq_len, num_heads), dtype=s.dtype) bdx = head_dim bdy = num_heads - merge_states_kernel[(seq_len,)]( - v, - s, - v_merged, - s_merged, - num_index_sets, - num_heads, - head_dim, - bdx=bdx, - bdy=bdy, - ) + merge_states_kernel[ + seq_len, + ](v, s, v_merged, s_merged, num_index_sets, num_heads, head_dim, bdx=bdx, bdy=bdy) return v_merged, s_merged def variable_length_merge_states( - v: torch.Tensor, s: torch.Tensor, indptr: torch.Tensor + v: paddle.Tensor, s: paddle.Tensor, indptr: paddle.Tensor ): check_input(v) check_input(s) check_device([v, s]) check_dim(3, v) check_dim(2, s) - assert v.size(0) == s.size(0) - assert v.size(1) == s.size(1) - seq_len = indptr.size(0) - 1 - num_heads = v.size(1) - head_dim = v.size(2) - s = s.to(torch.float32) - indptr = indptr.to(torch.int32) - v_merged = torch.empty( - (seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device - ) - s_merged = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device) - + assert v.shape[0] == s.shape[0] + assert v.shape[1] == s.shape[1] + seq_len = indptr.shape[0] - 1 + num_heads = v.shape[1] + head_dim = v.shape[2] + s = s.to("float32") + indptr = indptr.to("int32") + v_merged = paddle.empty(shape=(seq_len, num_heads, head_dim), dtype=v.dtype) + s_merged = paddle.empty(shape=(seq_len, num_heads), dtype=s.dtype) bdx = head_dim bdy = num_heads - variable_length_merge_states_kernel[(seq_len,)]( - v, - s, - indptr, 
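# A plain-Paddle reference for merge_state above: v_* are per-(token, head) attention
# outputs and s_* their log-sum-exp scores, so merging is a max-shifted weighted average.
# The base-2 convention is assumed here to match the tl.log2 output of the Triton kernel;
# shapes are v [n, heads, head_dim] and s [n, heads]. Illustration only.
import math

import paddle

def merge_state_ref(v_a, s_a, v_b, s_b):
    s_max = paddle.maximum(s_a, s_b)
    w_a = paddle.exp((s_a - s_max) * math.log(2.0))  # 2 ** (s_a - s_max)
    w_b = paddle.exp((s_b - s_max) * math.log(2.0))
    v = (v_a * w_a.unsqueeze(-1) + v_b * w_b.unsqueeze(-1)) / (w_a + w_b).unsqueeze(-1)
    s = s_max + paddle.log2(w_a + w_b)
    return v, s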
- v_merged, - s_merged, - num_heads, - head_dim, - bdx=bdx, - bdy=bdy, - ) + variable_length_merge_states_kernel[ + seq_len, + ](v, s, indptr, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy) return v_merged, s_merged diff --git a/flashinfer/triton/gemm.py b/flashinfer/triton/gemm.py index 70f320e57a..436378c728 100644 --- a/flashinfer/triton/gemm.py +++ b/flashinfer/triton/gemm.py @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. """ - import triton import triton.language as tl @@ -37,24 +36,18 @@ def compute_sm80_group_gemm_args( w_column_major, ): pid = tl.program_id(0) - m = tl.load(xy_indptr + pid + 1) - tl.load(xy_indptr + pid) k, n = d_in, d_out - tl.store(all_problems_ptr + pid * 3, m) tl.store(all_problems_ptr + pid * 3 + 1, n) tl.store(all_problems_ptr + pid * 3 + 2, k) - w_i = tl.load(w_indices + pid) if w_indices else tl.cast(pid, tl.int64) w_curr_ptr = w + w_i * k * n tl.store(w_ptr + pid, w_curr_ptr) - x_curr_ptr = x + tl.load(xy_indptr + pid) * k tl.store(x_ptr + pid, x_curr_ptr) - y_curr_ptr = y + tl.load(xy_indptr + pid) * n tl.store(y_ptr + pid, y_curr_ptr) - tl.store(x_ld_ptr + pid, k) tl.store(w_ld_ptr + pid, k if w_column_major else n) tl.store(y_ld_ptr + pid, n) @@ -79,36 +72,25 @@ def compute_sm90_group_gemm_args( w_column_major, ): pid = tl.program_id(0) - m = tl.load(xy_indptr + pid + 1) - tl.load(xy_indptr + pid) k, n = d_in, d_out - tl.store(all_problems_ptr + pid * 3, m) tl.store(all_problems_ptr + pid * 3 + 1, n) tl.store(all_problems_ptr + pid * 3 + 2, k) - w_i = tl.load(w_indices + pid) if w_indices else tl.cast(pid, tl.int64) w_curr_ptr = w + w_i * k * n tl.store(w_ptr + pid, w_curr_ptr) - x_curr_ptr = x + tl.load(xy_indptr + pid) * k tl.store(x_ptr + pid, x_curr_ptr) - y_curr_ptr = y + tl.load(xy_indptr + pid) * n tl.store(y_ptr + pid, y_curr_ptr) - tl.store(x_stride_ptr + pid, k) tl.store(w_stride_ptr + pid, k if w_column_major else n) tl.store(y_stride_ptr + pid, n) @triton.jit -def compute_padding_mapping( - m_indptr, - padded_m_indptr, - m_rank, - padded_m_rank, -): +def compute_padding_mapping(m_indptr, padded_m_indptr, m_rank, padded_m_rank): pid = tl.program_id(0) m_start = tl.load(m_indptr + pid) m_end = tl.load(m_indptr + pid + 1) diff --git a/flashinfer/triton/kernels/activation.py b/flashinfer/triton/kernels/activation.py index ebb7fe6877..7fede83939 100644 --- a/flashinfer/triton/kernels/activation.py +++ b/flashinfer/triton/kernels/activation.py @@ -1,5 +1,5 @@ -import triton # type: ignore[import] -import triton.language as tl # type: ignore[import] +import triton +import triton.language as tl from flashinfer.triton.kernels.quant import scale_and_clamp @@ -37,28 +37,20 @@ def silu_and_mul_kernel( If scales are provided, the input and output tensors are scaled. 
""" - i = tl.program_id(axis=0).to(tl.int64) j = tl.program_id(axis=1) - o_row_ptr = o_ptr + o_stride * i x_row_ptr = x_ptr + x_stride * i - offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < d - a = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32) b = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32) - if HAS_X_SCALE: x_scale = tl.load(x_scale_ptr) a *= x_scale b *= x_scale - result = tl.sigmoid(a) * a * b - if HAS_O_SCALE: o_scale = tl.load(o_scale_ptr) result = scale_and_clamp(result, o_scale, o_ptr.dtype.element_ty) - tl.store(o_row_ptr + offsets, result, mask=mask) diff --git a/flashinfer/triton/kernels/cascade.py b/flashinfer/triton/kernels/cascade.py index 0439dc0440..b6d1a1c20d 100644 --- a/flashinfer/triton/kernels/cascade.py +++ b/flashinfer/triton/kernels/cascade.py @@ -1,5 +1,5 @@ -import triton # type: ignore[import] -import triton.language as tl # type: ignore[import] +import triton +import triton.language as tl @triton.jit @@ -39,23 +39,17 @@ def merge_state_kernel( for head_idx in tl.range(bdy): s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx) s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx) - offsets = (pos * num_heads + head_idx) * head_dim + tx v_a = tl.load(v_a_ptr + offsets) v_b = tl.load(v_b_ptr + offsets) - v_merged, s_max, d = state_merge( o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1 ) v_merged, s_max, d = state_normalize(v_merged, s_max, d) v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx tl.store(v_merged_ptr + v_merged_offset, v_merged) - if s_merged_ptr: - tl.store( - s_merged_ptr + pos * num_heads + head_idx, - tl.log2(d) + s_max, - ) + tl.store(s_merged_ptr + pos * num_heads + head_idx, tl.log2(d) + s_max) @triton.jit @@ -74,7 +68,6 @@ def merge_state_in_place_kernel( if mask_ptr: if tl.load(mask_ptr + pos) == 0: return - for head_idx in tl.range(bdy): s_val = tl.load(s_ptr + pos * num_heads + head_idx) s_other_val = tl.load(s_other_ptr + pos * num_heads + head_idx) @@ -91,8 +84,7 @@ def merge_state_in_place_kernel( tl.store(v_ptr + offset, v_vec) if s_ptr: tl.store( - s_ptr + pos * num_heads + head_idx, - tl.log2(s_val + s_other_val) + s_max, + s_ptr + pos * num_heads + head_idx, tl.log2(s_val + s_other_val) + s_max ) @@ -109,10 +101,9 @@ def merge_states_kernel( bdy: tl.constexpr, ): pos = tl.program_id(axis=0) - for tx in tl.range(bdx): for head_idx in tl.range(bdy): - o, m, d = 0.0, -5e4, 1.0 + o, m, d = 0.0, -50000.0, 1.0 for iter in tl.range(num_index_sets): s = tl.load( s_ptr + (pos * num_index_sets + iter) * num_heads + head_idx @@ -146,7 +137,7 @@ def variable_length_merge_states_kernel( pos = tl.program_id(axis=0) for tx in tl.range(bdx): for head_idx in tl.range(bdy): - o, m, d = 0.0, -5e4, 1.0 + o, m, d = 0.0, -50000.0, 1.0 for iter in tl.range(tl.load(indptr + pos), tl.load(indptr + pos + 1)): s = tl.load(s_ptr + iter * num_heads + head_idx) v = tl.load(v_ptr + (iter * num_heads + head_idx) * head_dim + tx) diff --git a/flashinfer/triton/kernels/norm.py b/flashinfer/triton/kernels/norm.py index a7042999b0..55d98bc441 100644 --- a/flashinfer/triton/kernels/norm.py +++ b/flashinfer/triton/kernels/norm.py @@ -1,5 +1,5 @@ -import triton # type: ignore[import] -import triton.language as tl # type: ignore[import] +import triton +import triton.language as tl from flashinfer.triton.kernels.quant import scale_and_clamp @@ -25,52 +25,35 @@ def rms_norm_kernel( HAS_RESIDUAL: tl.constexpr, ) -> None: i = tl.program_id(axis=0).to(tl.int64) - - # If r_ptr is present, the input to 
norm is x + r. x_row = x_ptr + i * x_stride o_row = o_ptr + i * o_stride if HAS_OUTPUT else x_row r_row = r_ptr + i * r_stride if HAS_RESIDUAL else None - x_scale = tl.load(x_scale_ptr) if HAS_IN_SCALE else None o_scale = tl.load(o_scale_ptr) if HAS_OUT_SCALE else None - - # Find the root mean square for the given row. square_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, n, BLOCK_SIZE): offsets = off + tl.arange(0, BLOCK_SIZE) mask = offsets < n - x = tl.load(x_row + offsets, mask=mask, other=0.0).to(tl.float32) if HAS_IN_SCALE: x *= x_scale - if HAS_RESIDUAL: r = tl.load(r_row + offsets, mask=mask, other=0.0).to(tl.float32) x += r tl.store(r_row + offsets, x, mask=mask) - square_sum += x * x - - # Compute the norm. rms = tl.rsqrt(tl.sum(square_sum) / n + EPS) - - # x[i] = r[i] + x[i] / rms * weight[i] output_dtype = o_row.dtype.element_ty for off in range(0, n, BLOCK_SIZE): offsets = off + tl.arange(0, BLOCK_SIZE) mask = offsets < n - if HAS_RESIDUAL: x = tl.load(r_row + offsets, mask=mask).to(tl.float32) else: x = tl.load(x_row + offsets, mask=mask).to(tl.float32) if HAS_IN_SCALE: x *= x_scale - w = tl.load(w_ptr + offsets, mask=mask).to(tl.float32) - - # Multiply x with RMS on float32, but cast to the narrower type before - # multiplying with the weights to replicate the HF behaviour precisely. result = w * (x * rms) if HAS_OUT_SCALE: result = scale_and_clamp(result, o_scale, output_dtype) diff --git a/flashinfer/triton/kernels/quant.py b/flashinfer/triton/kernels/quant.py index bab4198f34..392e02400c 100644 --- a/flashinfer/triton/kernels/quant.py +++ b/flashinfer/triton/kernels/quant.py @@ -1,5 +1,5 @@ -import triton # type: ignore[import] -import triton.language as tl # type: ignore[import] +import triton +import triton.language as tl @triton.jit @@ -23,5 +23,4 @@ def scale_and_clamp(x, scale, dtype): clamp_max = 3.3895313892515355e38 else: tl.static_assert(False, f"Unsupported dtype: {dtype}") - return tl.clamp(x.to(tl.float32) * scale, clamp_min, clamp_max).to(dtype) diff --git a/flashinfer/triton/kernels/sm_constraint_gemm.py b/flashinfer/triton/kernels/sm_constraint_gemm.py index 5f156c4510..cd86b68e2b 100644 --- a/flashinfer/triton/kernels/sm_constraint_gemm.py +++ b/flashinfer/triton/kernels/sm_constraint_gemm.py @@ -1,5 +1,5 @@ -import triton # type: ignore[import] -import triton.language as tl # type: ignore[import] +import triton +import triton.language as tl def matmul_get_configs(): @@ -17,7 +17,7 @@ def matmul_get_configs(): for BM in [128] for BN in [128] for BK in [64] - for s in ([3]) + for s in [3] for w in [4] ] @@ -40,15 +40,12 @@ def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m + pid_m = first_pid_m + tile_id % group_size_m + pid_n = tile_id % num_pid_in_group // group_size_m return pid_m, pid_n -@triton.autotune( - configs=matmul_get_configs(), - key=["M", "N", "K"], -) +@triton.autotune(configs=matmul_get_configs(), key=["M", "N", "K"]) @triton.jit(launch_metadata=_matmul_launch_metadata) def gemm_kernel_persistent( a_ptr, @@ -76,14 +73,9 @@ def gemm_kernel_persistent( num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - - # NOTE: There is currently a bug in blackwell pipelining that means it can't handle a value being - # 
used in both the prologue and epilogue, so we duplicate the counters as a work-around. tile_id_c = start_pid - NUM_SMS - offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) num_pid_in_group = GROUP_SIZE_M * num_pid_n - for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): pid_m, pid_n = _compute_pid( tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS @@ -96,7 +88,6 @@ def gemm_kernel_persistent( offs_bn = tl.where(offs_bn < N, offs_bn, 0) offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for ki in range(k_tiles): offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) @@ -106,7 +97,6 @@ def gemm_kernel_persistent( b_ptrs = b_ptr + ( offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn ) - a = tl.load( a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0 ) @@ -114,7 +104,6 @@ def gemm_kernel_persistent( b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0 ) accumulator = tl.dot(a, b, accumulator) - tile_id_c += NUM_SMS pid_m, pid_n = _compute_pid( tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS @@ -124,7 +113,6 @@ def gemm_kernel_persistent( c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) c = accumulator.to(c_ptr.dtype.element_ty) - c = tl.fma(c, alpha, beta * tl.load(c_ptrs, mask=c_mask)) tl.store(c_ptrs, c, mask=c_mask) @@ -133,37 +121,30 @@ def gemm_kernel_persistent( def gemm_kernel_descriptor_persistent( a_ptr, b_ptr, - c_ptr, # + c_ptr, M, N, - K, # + K, alpha, beta, - BLOCK_SIZE_M: tl.constexpr, # - BLOCK_SIZE_N: tl.constexpr, # - BLOCK_SIZE_K: tl.constexpr, # - GROUP_SIZE_M: tl.constexpr, # - EPILOGUE_SUBTILE: tl.constexpr, # + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + EPILOGUE_SUBTILE: tl.constexpr, NUM_SMS: tl.constexpr, -): # +): dtype = c_ptr.dtype.element_ty start_pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - a_desc = tl.make_tensor_descriptor( - a_ptr, - shape=[M, K], - strides=[K, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + a_ptr, shape=[M, K], strides=[K, 1], block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K] ) b_desc = tl.make_tensor_descriptor( - b_ptr, - shape=[N, K], - strides=[K, 1], - block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + b_ptr, shape=[N, K], strides=[K, 1], block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K] ) c_desc = tl.make_tensor_descriptor( c_ptr, @@ -174,33 +155,26 @@ def gemm_kernel_descriptor_persistent( BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2, ], ) - - # tile_id_c is used in the epilogue to break the dependency between - # the prologue and the epilogue tile_id_c = start_pid - NUM_SMS num_pid_in_group = GROUP_SIZE_M * num_pid_n - for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): pid_m, pid_n = _compute_pid( tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS ) offs_am = pid_m * BLOCK_SIZE_M offs_bn = pid_n * BLOCK_SIZE_N - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for ki in range(k_tiles): offs_k = ki * BLOCK_SIZE_K a = a_desc.load([offs_am, offs_k]) b = b_desc.load([offs_bn, offs_k]) accumulator = tl.dot(a, b.T, accumulator) - tile_id_c += NUM_SMS 
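# Pure-Python port of _compute_pid above, only to make the grouped launch order visible:
# tiles walk down a GROUP_SIZE_M-tall column group before moving right, so a group's A tiles
# stay hot in L2 while the B tiles stream through. Illustration only.
def compute_pid_ref(tile_id, num_pid_in_group, num_pid_m, group_size_m_max):
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * group_size_m_max
    group_size_m = min(num_pid_m - first_pid_m, group_size_m_max)
    return first_pid_m + tile_id % group_size_m, tile_id % num_pid_in_group // group_size_m

num_pid_m, num_pid_n, GROUP_SIZE_M = 4, 4, 2
order = [compute_pid_ref(t, GROUP_SIZE_M * num_pid_n, num_pid_m, GROUP_SIZE_M) for t in range(16)]
print(order)  # [(0, 0), (1, 0), (0, 1), (1, 1), ..., (1, 3), (2, 0), (3, 0), (2, 1), ...]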
pid_m, pid_n = _compute_pid( tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS ) offs_cm = pid_m * BLOCK_SIZE_M offs_cn = pid_n * BLOCK_SIZE_N - if EPILOGUE_SUBTILE: acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) acc = tl.permute(acc, (0, 2, 1)) @@ -221,31 +195,27 @@ def gemm_kernel_descriptor_persistent( c_desc.store([offs_cm, offs_cn], c) -# only for testing -@triton.autotune( - configs=matmul_get_configs(), - key=["M", "N", "K"], -) +@triton.autotune(configs=matmul_get_configs(), key=["M", "N", "K"]) @triton.jit(launch_metadata=_matmul_launch_metadata) def gemm_kernel( a_ptr, b_ptr, - c_ptr, # + c_ptr, M, N, - K, # + K, stride_am, - stride_ak, # + stride_ak, stride_bk, - stride_bn, # + stride_bn, stride_cm, - stride_cn, # + stride_cn, alpha, beta, - BLOCK_SIZE_M: tl.constexpr, # - BLOCK_SIZE_N: tl.constexpr, # - BLOCK_SIZE_K: tl.constexpr, # - GROUP_SIZE_M: tl.constexpr, # + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, ): pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -254,34 +224,27 @@ def gemm_kernel( group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - + pid_m = first_pid_m + pid % group_size_m + pid_n = pid % num_pid_in_group // group_size_m start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N - offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) offs_am = tl.where(offs_am < M, offs_am, 0) offs_bn = tl.where(offs_bn < N, offs_bn, 0) - offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) accumulator = tl.dot(a, b, accumulator) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk - c = accumulator.to(c_ptr.dtype.element_ty) - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] diff --git a/flashinfer/triton/norm.py b/flashinfer/triton/norm.py index 9ddc2cdbc0..f4b8046534 100644 --- a/flashinfer/triton/norm.py +++ b/flashinfer/triton/norm.py @@ -1,40 +1,37 @@ from typing import Optional -import torch -import triton # type: ignore[import] +import paddle +import triton from flashinfer.triton.kernels.norm import rms_norm_kernel def rms_norm( - x: torch.Tensor, - weight: torch.Tensor, - out: torch.Tensor, + x: paddle.Tensor, + weight: paddle.Tensor, + out: paddle.Tensor, eps: float, - in_scale: Optional[torch.Tensor] = None, - out_scale: Optional[torch.Tensor] = None, + in_scale: Optional[paddle.Tensor] = None, + out_scale: Optional[paddle.Tensor] = None, ) -> None: """RMS norm. Computes `out[i,j] = x[i,j] * weight[j] / sqrt(eps + sum(x[i]^2) / n)`. 
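# A plain-Paddle reference for the rms_norm wrapper above (no residual, no scales), matching
# the docstring formula out[i, j] = x[i, j] * weight[j] / sqrt(eps + mean(x[i] ** 2)); handy
# as a numerical cross-check against the Triton kernel. Illustration only.
import paddle

def rms_norm_ref(x: paddle.Tensor, weight: paddle.Tensor, eps: float) -> paddle.Tensor:
    xf = x.astype("float32")
    rms = paddle.rsqrt((xf * xf).mean(axis=-1, keepdim=True) + eps)
    # multiply by rms in fp32, then by the weight, mirroring the kernel's w * (x * rms)
    return (weight.astype("float32") * (xf * rms)).astype(x.dtype)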
""" - - b, n = x.shape - + b, n = tuple(x.shape) block_size = triton.next_power_of_2(n) num_warps = max(8, min(32, block_size // 256)) - - rms_norm_kernel[(b,)]( + rms_norm_kernel[b,]( n=n, b=b, x_ptr=x, - x_stride=x.stride(0), + x_stride=x.get_strides()[0], x_scale_ptr=in_scale, r_ptr=None, r_stride=0, w_ptr=weight, o_ptr=out, - o_stride=out.stride(0), + o_stride=out.get_strides()[0], o_scale_ptr=out_scale, EPS=eps, BLOCK_SIZE=block_size, @@ -47,38 +44,34 @@ def rms_norm( def rms_norm_add_residual( - x: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, + x: paddle.Tensor, + residual: paddle.Tensor, + weight: paddle.Tensor, eps: float, - x_out: Optional[torch.Tensor] = None, - x_in_scale: Optional[torch.Tensor] = None, - x_out_scale: Optional[torch.Tensor] = None, + x_out: Optional[paddle.Tensor] = None, + x_in_scale: Optional[paddle.Tensor] = None, + x_out_scale: Optional[paddle.Tensor] = None, ) -> None: """In-place RMS norm with fused residual addition. Computes `r = r + x`, followed by `x = rmsnorm(r)`. """ - - b, n = x.shape - - assert x.shape == residual.shape - assert x.stride(0) == residual.stride(0) - + b, n = tuple(x.shape) + assert tuple(x.shape) == tuple(residual.shape) + assert x.get_strides()[0] == residual.get_strides()[0] block_size = triton.next_power_of_2(n) num_warps = min(32, triton.cdiv(block_size, 32)) - - rms_norm_kernel[(b,)]( + rms_norm_kernel[b,]( n=n, b=b, x_ptr=x, - x_stride=x.stride(0), + x_stride=x.get_strides()[0], x_scale_ptr=x_in_scale, r_ptr=residual, - r_stride=residual.stride(0), + r_stride=residual.get_strides()[0], w_ptr=weight, o_ptr=x_out, - o_stride=x_out.stride(0) if x_out is not None else 0, + o_stride=x_out.get_strides()[0] if x_out is not None else 0, o_scale_ptr=x_out_scale, EPS=eps, BLOCK_SIZE=block_size, diff --git a/flashinfer/triton/page.py b/flashinfer/triton/page.py index 440bbe2f7f..50b371bb15 100644 --- a/flashinfer/triton/page.py +++ b/flashinfer/triton/page.py @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. """ - import triton import triton.language as tl @@ -27,11 +26,9 @@ def get_batch_indices_positions_kernel( num_stages: tl.constexpr, ): batch_idx = tl.program_id(0) - batch_start = tl.load(append_indptr + batch_idx) batch_end = tl.load(append_indptr + batch_idx + 1) seq_len = tl.load(seq_lens_ptr + batch_idx) - for i in tl.range(batch_start, batch_end, 128, num_stages=num_stages): offsets = tl.arange(0, 128) + i mask = offsets < batch_end diff --git a/flashinfer/triton/sm_constraint_gemm.py b/flashinfer/triton/sm_constraint_gemm.py index 2a07b8e10d..595d3240a2 100644 --- a/flashinfer/triton/sm_constraint_gemm.py +++ b/flashinfer/triton/sm_constraint_gemm.py @@ -1,13 +1,11 @@ from typing import Optional -import torch +import paddle import triton -from .kernels.sm_constraint_gemm import ( - gemm_kernel, - gemm_kernel_descriptor_persistent, - gemm_kernel_persistent, -) +from .kernels.sm_constraint_gemm import (gemm_kernel, + gemm_kernel_descriptor_persistent, + gemm_kernel_persistent) from .utils import check_device, check_dim, check_input @@ -25,56 +23,49 @@ def gemm_persistent(a, b, c=None, alpha=1.0, beta=0.0, out_dtype=None, num_sms=N out_dtype: The dtype of the output matrix. Default: fp8 --> bf16. Otherwise, same as a.dtype. num_sms: The number of SMs to use for the computation. """ - - # Check inputs. 
check_input(a) - # b can be non-contiguous check_device([a, b]) check_dim(2, a) check_dim(2, b) - if c is not None: check_input(c) check_device([c]) check_dim(2, c) - - assert a.shape[1] == b.shape[0], "Incompatible dimensions between a and b" + assert ( + tuple(a.shape)[1] == tuple(b.shape)[0] + ), "Incompatible dimensions between a and b" assert a.dtype == b.dtype, "Incompatible dtypes between a and b" - if c is not None: - assert a.shape[0] == c.shape[0], "Incompatible dimensions between a and c" - assert b.shape[1] == c.shape[1], "Incompatible dimensions between b and c" - - M, K = a.shape - K, N = b.shape + assert ( + tuple(a.shape)[0] == tuple(c.shape)[0] + ), "Incompatible dimensions between a and c" + assert ( + tuple(b.shape)[1] == tuple(c.shape)[1] + ), "Incompatible dimensions between b and c" + M, K = tuple(a.shape) + K, N = tuple(b.shape) dtype = a.dtype out_dtype = ( out_dtype if out_dtype else dtype - if dtype != torch.float8_e4m3fn - else torch.bfloat16 - ) - - assert c is None or c.dtype == out_dtype, ( - "Incompatible dtypes between c and out_dtype" + if dtype != paddle.float8_e4m3fn + else "bfloat16" ) - - # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=out_dtype) if c is None else c - - # Set num_sms to be 100% of the available SMs - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + assert ( + c is None or c.dtype == out_dtype + ), "Incompatible dtypes between c and out_dtype" + c = paddle.empty(shape=(M, N), dtype=out_dtype) if c is None else c + NUM_SMS = paddle.device.cuda.get_device_properties( + device="gpu" + ).multi_processor_count num_sms = NUM_SMS if num_sms is None else min(NUM_SMS, num_sms) - - # 1D launch kernel where each block gets its own program. grid = lambda META: ( min( num_sms, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ), ) - gemm_kernel_persistent[grid]( a, b, @@ -82,12 +73,12 @@ def gemm_persistent(a, b, c=None, alpha=1.0, beta=0.0, out_dtype=None, num_sms=N M, N, K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), + a.get_strides()[0], + a.get_strides()[1], + b.get_strides()[0], + b.get_strides()[1], + c.get_strides()[0], + c.get_strides()[1], alpha=alpha, beta=beta, NUM_SMS=num_sms, @@ -109,48 +100,42 @@ def gemm(a, b, c=None, alpha=1.0, beta=0.0, out_dtype=None): out_dtype: The dtype of the output matrix. Default: fp8 --> bf16. Otherwise, same as a.dtype. num_sms: The number of SMs to use for the computation. """ - # Check inputs. 
check_input(a) - # b can be non-contiguous check_device([a, b]) check_dim(2, a) check_dim(2, b) - if c is not None: check_input(c) check_device([c]) check_dim(2, c) - - assert a.shape[1] == b.shape[0], "Incompatible dimensions between a and b" + assert ( + tuple(a.shape)[1] == tuple(b.shape)[0] + ), "Incompatible dimensions between a and b" assert a.dtype == b.dtype, "Incompatible dtypes between a and b" - if c is not None: - assert a.shape[0] == c.shape[0], "Incompatible dimensions between a and c" - assert b.shape[1] == c.shape[1], "Incompatible dimensions between b and c" - - M, K = a.shape - K, N = b.shape + assert ( + tuple(a.shape)[0] == tuple(c.shape)[0] + ), "Incompatible dimensions between a and c" + assert ( + tuple(b.shape)[1] == tuple(c.shape)[1] + ), "Incompatible dimensions between b and c" + M, K = tuple(a.shape) + K, N = tuple(b.shape) dtype = a.dtype out_dtype = ( out_dtype if out_dtype else dtype - if dtype != torch.float8_e4m3fn - else torch.bfloat16 - ) - - assert c is None or c.dtype == out_dtype, ( - "Incompatible dtypes between c and out_dtype" + if dtype != paddle.float8_e4m3fn + else "bfloat16" ) - - # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=out_dtype) if c is None else c - - # 1D launch kernel where each block gets its own program. + assert ( + c is None or c.dtype == out_dtype + ), "Incompatible dtypes between c and out_dtype" + c = paddle.empty(shape=(M, N), dtype=out_dtype) if c is None else c grid = lambda META: ( triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - gemm_kernel[grid]( a, b, @@ -158,12 +143,12 @@ def gemm(a, b, c=None, alpha=1.0, beta=0.0, out_dtype=None): M, N, K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), + a.get_strides()[0], + a.get_strides()[1], + b.get_strides()[0], + b.get_strides()[1], + c.get_strides()[0], + c.get_strides()[1], alpha=alpha, beta=beta, ) @@ -200,75 +185,70 @@ def gemm_descriptor_persistent( num_sms: The number of SMs to use for the computation. EPILOGUE_SUBTILE: Whether to use the epilogue subtile optimization. """ - # Check inputs. 
check_input(a) check_input(b) check_device([a, b]) check_dim(2, a) check_dim(2, b) - if c is not None: check_input(c) check_device([c]) check_dim(2, c) - - assert a.shape[1] == b.shape[1], "Incompatible dimensions between a and b" + assert ( + tuple(a.shape)[1] == tuple(b.shape)[1] + ), "Incompatible dimensions between a and b" assert a.dtype == b.dtype, "Incompatible dtypes between a and b" - if c is not None: - assert a.shape[0] == c.shape[0], "Incompatible dimensions between a and c" - assert b.shape[0] == c.shape[1], "Incompatible dimensions between b and c" - - M, K = a.shape - N, K = b.shape + assert ( + tuple(a.shape)[0] == tuple(c.shape)[0] + ), "Incompatible dimensions between a and c" + assert ( + tuple(b.shape)[0] == tuple(c.shape)[1] + ), "Incompatible dimensions between b and c" + M, K = tuple(a.shape) + N, K = tuple(b.shape) dtype = a.dtype out_dtype = ( out_dtype if out_dtype else dtype - if dtype != torch.float8_e4m3fn - else torch.bfloat16 + if dtype != paddle.float8_e4m3fn + else "bfloat16" ) - - # check on TMA tensor map swizzling granularity - # Swizzle 16B chunks within at least 32B span - if dtype == torch.float8_e4m3fn: + if dtype == paddle.float8_e4m3fn: assert K >= 16, "Least chunk size must be 16B" assert N >= 16, "Least chunk size must be 16B" else: assert K >= 8, "Least chunk size must be 16B" assert N >= 8, "Least chunk size must be 16B" - - c = torch.empty((M, N), device=a.device, dtype=out_dtype) if c is None else c - - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + c = paddle.empty(shape=(M, N), dtype=out_dtype) if c is None else c + NUM_SMS = paddle.device.cuda.get_device_properties( + device="gpu" + ).multi_processor_count num_sms = NUM_SMS if num_sms is None else min(NUM_SMS, num_sms) - # TMA descriptors require a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): - return torch.empty(size, device="cuda", dtype=torch.int8) + return paddle.empty(shape=size, dtype="int8") triton.set_allocator(alloc_fn) - grid = lambda META: ( min( num_sms, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ), ) - gemm_kernel_descriptor_persistent[grid]( a, b, - c, # + c, M, N, - K, # + K, alpha, beta, - NUM_SMS=num_sms, # + NUM_SMS=num_sms, BLOCK_SIZE_M=128, - BLOCK_SIZE_N=128 if dtype != torch.float32 else 64, + BLOCK_SIZE_N=128 if dtype != "float32" else 64, BLOCK_SIZE_K=64, GROUP_SIZE_M=8, num_stages=3, diff --git a/flashinfer/triton/utils.py b/flashinfer/triton/utils.py index 6cd6bd34f2..bd8d475824 100644 --- a/flashinfer/triton/utils.py +++ b/flashinfer/triton/utils.py @@ -1,28 +1,28 @@ from typing import List -import torch +import paddle -def check_input(x: torch.Tensor): - assert x.is_cuda, f"{str(x)} must be a CUDA Tensor" +def check_input(x: paddle.Tensor): + assert x.place.is_gpu_place(), f"{str(x)} must be a CUDA Tensor" assert x.is_contiguous(), f"{str(x)} must be contiguous" -def check_dim(d, x: torch.Tensor): +def check_dim(d, x: paddle.Tensor): assert x.dim() == d, f"{str(x)} must be a {d}D tensor" -def check_shape(a: torch.Tensor, b: torch.Tensor): +def check_shape(a: paddle.Tensor, b: paddle.Tensor): assert a.dim() == b.dim(), "tensors should have same dim" for i in range(a.dim()): - assert a.size(i) == b.size(i), ( - f"tensors shape mismatch, {a.size()} and {b.size()}" - ) + assert ( + a.shape[i] == b.shape[i] + ), f"tensors shape mismatch, {tuple(a.shape)} and {tuple(b.shape)}" -def check_device(tensors: List[torch.Tensor]): - device = tensors[0].device +def 
check_device(tensors: List[paddle.Tensor]): + device = tensors[0].place for t in tensors: - assert t.device == device, ( - f"All tensors should be on the same device, but got {device} and {t.device}" - ) + assert ( + t.place == device + ), f"All tensors should be on the same device, but got {device} and {t.place}" diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 1c716d1e0e..6c8df7a7c1 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -1,3 +1,11 @@ +import sys + + +import os + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2023 by FlashInfer team. @@ -13,19 +21,13 @@ See the License for the specific language governing permissions and limitations under the License. """ - import functools import math -import os from enum import Enum from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union -import torch -import torch.version -from torch.torch_version import TorchVersion -from torch.torch_version import __version__ as torch_version - -from .jit import gen_jit_spec, env as jit_env +from .jit import env as jit_env +from .jit import gen_jit_spec IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1" @@ -48,40 +50,30 @@ class TensorLayout(Enum): HND = 1 -log2e = 1.44269504088896340736 +log2e = 1.4426950408889634 -def _expand_5d(x: torch.Tensor, kv_layout: str) -> torch.Tensor: +def _expand_5d(x: paddle.Tensor, kv_layout: str) -> paddle.Tensor: if x.ndim not in [4, 5]: raise ValueError("x must be 4D or 5D") if x.ndim == 4: - # page_size == 1 if kv_layout == "NHD": - # (num_pages, 2, num_heads, head_dim) -> (num_pages, 2, page_size=1, num_heads, head_dim) - # expand to 5D on the 3nd last dimension - return x.unsqueeze(-3) + return x.unsqueeze(axis=-3) elif kv_layout == "HND": - # (num_pages, 2, num_heads, head_dim) -> (num_pages, 2, num_heads, page_size=1, head_dim) - # expand to 5D on the 2nd last dimension - return x.unsqueeze(-2) + return x.unsqueeze(axis=-2) else: raise KeyError("Invalid kv_layout {}".format(kv_layout)) return x -def _expand_4d(x: torch.Tensor, kv_layout: str) -> torch.Tensor: +def _expand_4d(x: paddle.Tensor, kv_layout: str) -> paddle.Tensor: if x.ndim not in [3, 4]: raise ValueError("x must be 3D or 4D") if x.ndim == 3: - # page_size == 1 if kv_layout == "NHD": - # (num_pages, num_heads, head_dim) -> (num_pages, page_size=1, num_heads, head_dim) - # expand to 4D on the 3nd last dimension - return x.unsqueeze(-3) + return x.unsqueeze(axis=-3) elif kv_layout == "HND": - # (num_pages, num_heads, head_dim) -> (num_pages, num_heads, page_size=1, head_dim) - # expand to 5D on the 2nd last dimension - return x.unsqueeze(-2) + return x.unsqueeze(axis=-2) else: raise KeyError("Invalid kv_layout {}".format(kv_layout)) return x @@ -90,10 +82,6 @@ def _expand_4d(x: torch.Tensor, kv_layout: str) -> torch.Tensor: def next_positive_power_of_2(x: int) -> int: if x < 1: return 1 - - # Following code is equivalent to 1 << (x - 1).bit_length() - # But this impl does not contain bit_length() so can be used by torch compile. - # It can correctly handle 64bit number which should be enough for now. 
n = x - 1 n |= n >> 1 n |= n >> 2 @@ -114,31 +102,29 @@ def _check_kv_layout(kv_layout: str) -> None: raise KeyError("Invalid kv_layout {}".format(kv_layout)) -def is_float8(x: torch.Tensor) -> bool: - return x.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] +def is_float8(x: paddle.Tensor) -> bool: + return x.dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2] -def get_indptr(x: torch.Tensor) -> torch.Tensor: - x = x.to(torch.int64) - ret = torch.zeros(x.shape[0] + 1, dtype=x.dtype, device=x.device) - ret[1:] = x.cumsum(0) +def get_indptr(x: paddle.Tensor) -> paddle.Tensor: + x = x.to("int64") + ret = paddle.zeros(shape=tuple(x.shape)[0] + 1, dtype=x.dtype) + ret[1:] = x.cumsum(axis=0) return ret def _unpack_paged_kv_cache( - paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + paged_kv_cache: Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]], kv_layout: str, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[paddle.Tensor, paddle.Tensor]: if isinstance(paged_kv_cache, tuple): paged_k_cache, paged_v_cache = paged_kv_cache - return ( - _expand_4d(paged_k_cache, kv_layout), - _expand_4d(paged_v_cache, kv_layout), + return _expand_4d(paged_k_cache, kv_layout), _expand_4d( + paged_v_cache, kv_layout ) - elif torch.is_tensor(paged_kv_cache): - # NOTE(Zihao): split on the second dimension + elif paddle.is_tensor(x=paged_kv_cache): paged_kv_cache = _expand_5d(paged_kv_cache, kv_layout) - paged_k_cache, paged_v_cache = paged_kv_cache.unbind(dim=1) + paged_k_cache, paged_v_cache = paged_kv_cache.unbind(axis=1) return paged_k_cache, paged_v_cache else: raise KeyError( @@ -148,48 +134,47 @@ def _unpack_paged_kv_cache( ) -def get_alibi_slopes(n_heads: int) -> torch.Tensor: +def get_alibi_slopes(n_heads: int) -> paddle.Tensor: n = 2 ** math.floor(math.log2(n_heads)) m_0 = 2.0 ** (-8.0 / n) - m = torch.pow(m_0, torch.arange(1, 1 + n)) + m = paddle.pow(x=m_0, y=paddle.arange(start=1, end=1 + n)) if n < n_heads: m_hat_0 = 2.0 ** (-4.0 / n) - m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2)) - m = torch.cat([m, m_hat]) - return m.float() + m_hat = paddle.pow( + x=m_hat_0, y=paddle.arange(start=1, end=1 + 2 * (n_heads - n), step=2) + ) + m = paddle.concat(x=[m, m_hat]) + return m.astype(dtype="float32") -_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {} +_cache_buf: Dict[Tuple[str, str], paddle.Tensor] = {} -def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: - key = (name, device) +def _get_cache_buf(name: str, bytes: int, device: str) -> paddle.Tensor: + key = name, device buf = _cache_buf.get(key) if buf is None: - buf = torch.empty(bytes, dtype=torch.uint8, device=device) + buf = paddle.empty(shape=bytes, dtype="uint8") _cache_buf[key] = buf return buf -# find the least power of 2 that is greater than or equal to x def _ceil_pow2(x: int) -> int: return 1 << (x - 1).bit_length() -def _get_range_buf(seq_len: int, device: torch.device) -> torch.Tensor: +def _get_range_buf(seq_len: int, device: str) -> paddle.Tensor: seq_len_pow2 = _ceil_pow2(seq_len) - key = (f"range_{seq_len_pow2}", device) + key = f"range_{seq_len_pow2}", device buf = _cache_buf.get(key) if buf is None: - buf = torch.arange(seq_len_pow2, device=device, dtype=torch.int32) + buf = paddle.arange(dtype="int32", end=seq_len_pow2) _cache_buf[key] = buf return buf[:seq_len] -def _get_cache_alibi_slopes_buf( - num_qo_heads: int, device: torch.device -) -> torch.Tensor: - key = (f"alibi_slopes_{num_qo_heads}", device) +def 
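# The indptr layout built by get_indptr above is the exclusive prefix sum of per-request
# lengths, used throughout the paged-KV helpers; the same thing spelled out in plain Paddle:
import paddle

lens = paddle.to_tensor([3, 1, 2], dtype="int64")
indptr = paddle.concat([paddle.zeros([1], dtype="int64"), lens.cumsum(axis=0)])
print(indptr.numpy())  # [0 3 4 6]: request i owns rows [indptr[i], indptr[i + 1])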
_get_cache_alibi_slopes_buf(num_qo_heads: int, device: str) -> paddle.Tensor: + key = f"alibi_slopes_{num_qo_heads}", device buf = _cache_buf.get(key) if buf is None: buf = get_alibi_slopes(num_qo_heads).to(device) @@ -197,10 +182,10 @@ def _get_cache_alibi_slopes_buf( return buf -def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: +def canonicalize_torch_dtype(dtype: Union[paddle.dtype, str]) -> paddle.dtype: if isinstance(dtype, str): return getattr(torch, dtype) - elif isinstance(dtype, torch.dtype): + elif isinstance(dtype, paddle.dtype): return dtype else: raise TypeError( @@ -209,14 +194,14 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: @functools.cache -def get_compute_capability(device: torch.device) -> Tuple[int, int]: - if device.type != "cuda": - raise ValueError("device must be a cuda device") - return torch.cuda.get_device_capability(device.index) +def get_compute_capability(device: str) -> Tuple[int, int]: + # if device.type != "cuda": + # raise ValueError("device must be a cuda device") + return paddle.device.cuda.get_device_capability(device.gpu_device_id()) def _check_cached_qkv_data_type( - q: torch.Tensor, k: torch.Tensor, dtype_q: torch.dtype, dtype_kv: torch.dtype + q: paddle.Tensor, k: paddle.Tensor, dtype_q: paddle.dtype, dtype_kv: paddle.dtype ) -> None: if q.dtype != dtype_q: raise ValueError( @@ -227,60 +212,24 @@ def _check_cached_qkv_data_type( f"The dtype of k {k.dtype} does not match the kv_data_type {dtype_kv} specified in plan function." ) +def register_custom_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + device_types: Optional[Union[str, Sequence[str]]] = None, + schema: Optional[str] = None, +) -> Callable: + return lambda x: x + +def register_fake_op(name: str, fn: Optional[Callable] = None) -> Callable: + return lambda x: x -if IS_BUILDING_DOCS or TorchVersion(torch_version) < TorchVersion("2.4"): - - def register_custom_op( - name: str, - fn: Optional[Callable] = None, - /, - *, - mutates_args: Union[str, Iterable[str]], - device_types: Optional[Union[str, Sequence[str]]] = None, - schema: Optional[str] = None, - ) -> Callable: - return lambda x: x - - def register_fake_op( - name: str, - fn: Optional[Callable] = None, - ) -> Callable: - return lambda x: x - -else: - - def register_custom_op( - name: str, - fn: Optional[Callable] = None, - /, - *, - mutates_args: Union[str, Iterable[str]], - device_types: Optional[Union[str, Sequence[str]]] = None, - schema: Optional[str] = None, - ) -> Callable: - # NOTE(Zihao): torch.library.custom_op has significant overhead as mentioned in the following link - # https://github.com/vllm-project/vllm/blob/36e76700453924c8d421db99af70a88a1df835cd/vllm/utils.py#L1660-L1674 - - # return torch.library.custom_op( - # name, - # fn, - # mutates_args=mutates_args, - # device_types=device_types, - # schema=schema, - # ) - return lambda x: x - - def register_fake_op( - name: str, - fn: Optional[Callable] = None, - ) -> Callable: - # return torch.library.register_fake(name, fn) - return lambda x: x - - -def determine_gemm_backend(device: torch.device) -> str: + +def determine_gemm_backend(device: str) -> str: major, _ = get_compute_capability(device) - if major == 9 and torch.version.cuda >= "12.3": + if major == 9: return "sm90" else: return "sm80" @@ -290,8 +239,8 @@ def is_fa3_backend_supported( pos_encoding_mode: int, use_fp16_qk_reductions: bool, use_custom_mask: bool, - dtype_q: torch.dtype, - dtype_kv: 
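# Compute-capability lookup in Paddle, as relied on by get_compute_capability and the
# backend-selection helpers above; with no argument it reports the current device. The
# printed value is only an example.
import paddle

major, minor = paddle.device.cuda.get_device_capability()
print(major, minor)  # e.g. 9 0 on an SM90 (Hopper) part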
torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, ) -> bool: """ Check if the FA3 backend is supported based on the given parameters. @@ -329,8 +278,8 @@ def is_cutlass_backend_supported( pos_encoding_mode: int, use_fp16_qk_reductions: bool, use_custom_mask: bool, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, ) -> bool: """ Check if the cutlass backend is supported based on the given parameters. @@ -359,20 +308,20 @@ def is_cutlass_backend_supported( return False if use_fp16_qk_reductions: return False - if dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2]: + if dtype_q in [paddle.float8_e4m3fn, paddle.float8_e5m2]: return False - if dtype_kv in [torch.float8_e4m3fn, torch.float8_e5m2]: + if dtype_kv in [paddle.float8_e4m3fn, paddle.float8_e5m2]: return False return True def determine_attention_backend( - device: torch.device, + device: str, pos_encoding_mode: int, use_fp16_qk_reductions: bool, use_custom_mask: bool, - dtype_q: torch.dtype, - dtype_kv: torch.dtype, + dtype_q: paddle.dtype, + dtype_kv: paddle.dtype, ) -> str: """ Determine the appropriate attention backend based on the device and parameters. @@ -400,11 +349,7 @@ def determine_attention_backend( The name of the attention backend to be used. """ if is_sm90a_supported(device) and is_fa3_backend_supported( - pos_encoding_mode, - use_fp16_qk_reductions, - use_custom_mask, - dtype_q, - dtype_kv, + pos_encoding_mode, use_fp16_qk_reductions, use_custom_mask, dtype_q, dtype_kv ): return "fa3" else: @@ -429,47 +374,45 @@ def has_cuda_cudart() -> bool: return importlib.util.find_spec("cuda.cudart") is not None -def is_sm90a_supported(device: torch.device) -> bool: +def is_sm90a_supported(device: str) -> bool: major, _ = get_compute_capability(device) - return major == 9 and version_at_least(torch.version.cuda, "12.3") + return major == 9 -def is_sm100a_supported(device: torch.device) -> bool: +def is_sm100a_supported(device: str) -> bool: major, _ = get_compute_capability(device) - return major == 10 and version_at_least(torch.version.cuda, "12.8") + return major == 10 -def determine_mla_backend(device: torch.device) -> str: +def determine_mla_backend(device: str) -> str: return "fa3" if is_sm90a_supported(device) else "fa2" def check_shape_dtype_device( - x: torch.Tensor, + x: paddle.Tensor, expected_shape: Optional[Sequence[int]], - expected_dtype: Optional[torch.dtype], - expected_device: Optional[torch.device], + expected_dtype: Optional[paddle.dtype], + expected_device: Optional[str], name: str, ) -> None: - if expected_shape and x.shape != torch.Size(expected_shape): + if expected_shape and tuple(x.shape) != tuple(expected_shape): raise ValueError( - f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}" + f"Invalid shape of {name}: expected {expected_shape}, got {tuple(x.shape)}" ) if expected_dtype and x.dtype != expected_dtype: raise ValueError( f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}" ) - if expected_device and x.device != expected_device: + if expected_device and x.place != expected_device: raise ValueError( - f"Invalid device of {name}: expected {expected_device}, got {x.device}" + f"Invalid device of {name}: expected {expected_device}, got {x.place}" ) def get_logging_module(): return gen_jit_spec( "logging", - [ - jit_env.FLASHINFER_CSRC_DIR / "logging.cc", - ], + [jit_env.FLASHINFER_CSRC_DIR / "logging.cc"], extra_include_paths=[ jit_env.SPDLOG_INCLUDE_DIR, jit_env.FLASHINFER_INCLUDE_DIR, @@ -500,7 +443,7 @@ 
def set_log_level(lvl_str: str) -> None: get_logging_module().set_log_level(log_level_map[lvl_str].value) -def device_support_pdl(device: torch.device) -> bool: +def device_support_pdl(device: str) -> bool: major, _ = get_compute_capability(device) return major >= 9 @@ -524,8 +467,10 @@ def round_up(x: int, y: int) -> int: return ceil_div(x, y) * y -def get_device_sm_count(device: torch.device) -> int: - return torch.cuda.get_device_properties(device).multi_processor_count +def get_device_sm_count(device: str) -> int: + return paddle.device.cuda.get_device_properties( + device=device2str(device) + ).multi_processor_count class FP4Tensor: @@ -538,8 +483,8 @@ class FP4Tensor: def __init__( self, - data: torch.Tensor, - scale: torch.Tensor, + data: paddle.Tensor, + scale: paddle.Tensor, scale_start_index: int = 0, original_shape: Optional[Tuple[int, ...]] = None, ): @@ -556,44 +501,32 @@ def __init__( original_shape : Optional[Tuple[int, ...]] The original shape before compression. """ - if data.dtype != torch.uint8: + if data.dtype != "uint8": raise ValueError(f"data must be uint8 tensor, got {data.dtype}") - - # Validate scale factor tensor and scale start index - if scale.dtype != torch.float8_e4m3fn: + if scale.dtype != paddle.float8_e4m3fn: raise ValueError(f"scale must be float8_e4m3fn tensor, got {scale.dtype}") - if scale.shape[0] % 128 != 0: + if tuple(scale.shape)[0] % 128 != 0: raise ValueError( - f"scale.shape[0] must be a multiple of 128, got {scale.shape[0]}" + f"scale.shape[0] must be a multiple of 128, got {tuple(scale.shape)[0]}" ) - if scale_start_index < 0 or scale_start_index >= scale.shape[0]: + if scale_start_index < 0 or scale_start_index >= tuple(scale.shape)[0]: raise ValueError( - f"scale start index must be in the range [0, scale.shape[0]). " - f"scale_start_index={scale_start_index}, scale.shape[0]={scale.shape[0]}" + f"scale start index must be in the range [0, scale.shape[0]). scale_start_index={scale_start_index}, scale.shape[0]={tuple(scale.shape)[0]}" ) - if scale_start_index + data.shape[0] > scale.shape[0]: + if scale_start_index + tuple(data.shape)[0] > tuple(scale.shape)[0]: raise ValueError( - f"scale start index + data.shape[0] must not exceed scale.shape[0]. " - f"scale_start_index={scale_start_index}, data.shape[0]={data.shape[0]}, scale.shape[0]={scale.shape[0]}" + f"scale start index + data.shape[0] must not exceed scale.shape[0]. scale_start_index={scale_start_index}, data.shape[0]={tuple(data.shape)[0]}, scale.shape[0]={tuple(scale.shape)[0]}" ) - - # Validate shape relationship if original_shape is provided if original_shape is not None: - if data.shape[:-1] != original_shape[:-1]: + if tuple(data.shape)[:-1] != original_shape[:-1]: raise ValueError( - f"data and original_shape must have the same dimensions except the last one. " - f"data.shape={data.shape}, original_shape={original_shape}" + f"data and original_shape must have the same dimensions except the last one. data.shape={tuple(data.shape)}, original_shape={original_shape}" ) - - # Check the last dimension relationship: data_dim = ceil(original_dim / 2) expected_data_dim = math.ceil(original_shape[-1] / 2) - if data.shape[-1] != expected_data_dim: + if tuple(data.shape)[-1] != expected_data_dim: raise ValueError( - f"data last dimension must be ceil(original_shape[-1] / 2). " - f"data.shape[-1]={data.shape[-1]}, original_shape[-1]={original_shape[-1]}, " - f"expected={expected_data_dim}" + f"data last dimension must be ceil(original_shape[-1] / 2). 
data.shape[-1]={tuple(data.shape)[-1]}, original_shape[-1]={original_shape[-1]}, expected={expected_data_dim}" ) - self.data = data self.scale = scale self.scale_start_index = scale_start_index @@ -601,29 +534,41 @@ def __init__( self.dtype = "nvfp4" -# yapf: disable -srcToDstBlk16RowMap = [ - 0, 8, - 1, 9, - 2, 10, - 3, 11, - 4, 12, - 5, 13, - 6, 14, - 7, 15 -] - +srcToDstBlk16RowMap = [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] srcToDstBlk32RowMap = [ - 0, 8, 16, 24, - 1, 9, 17, 25, - 2, 10, 18, 26, - 3, 11, 19, 27, - 4, 12, 20, 28, - 5, 13, 21, 29, - 6, 14, 22, 30, - 7, 15, 23, 31 + 0, + 8, + 16, + 24, + 1, + 9, + 17, + 25, + 2, + 10, + 18, + 26, + 3, + 11, + 19, + 27, + 4, + 12, + 20, + 28, + 5, + 13, + 21, + 29, + 6, + 14, + 22, + 30, + 7, + 15, + 23, + 31, ] -# yapf: enable def get_shuffle_block_size(epilogue_tile_m: int) -> int: @@ -634,60 +579,42 @@ def get_shuffle_block_size(epilogue_tile_m: int) -> int: def get_shuffle_matrix_a_row_indices( - input_tensor: torch.Tensor, epilogue_tile_m: int -) -> torch.Tensor: + input_tensor: paddle.Tensor, epilogue_tile_m: int +) -> paddle.Tensor: """ Higher-level PyTorch approach to reorder the rows in blocks of size 16 or 32. - We do NOT try to handle custom e2m1 memory usage (i.e. no 'K/2' bytes). - Instead, we purely reorder rows in a standard PyTorch shape [M, K]. """ - assert input_tensor.dim() == 2, ( - f"input_tensor should be a 2D tensor, not {input_tensor.dim()}" - ) - - # M, K from the input - M, K = input_tensor.shape - - # Choose block size 16 or 32 + assert ( + input_tensor.dim() == 2 + ), f"input_tensor should be a 2D tensor, not {input_tensor.dim()}" + M, K = tuple(input_tensor.shape) shuffle_block_size = get_shuffle_block_size(epilogue_tile_m) row_map = srcToDstBlk16RowMap if shuffle_block_size == 16 else srcToDstBlk32RowMap - - assert M % shuffle_block_size == 0, ( - f"input_tensor.shape[0] must be multiples of {shuffle_block_size}" - ) - - # row_indices[new_row] = old_row - # so row_indices is an array of size M telling us from which old_row - # the new_row should be taken. 
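Aside (editorial illustration, not part of the patch): the shuffle being converted here builds a permutation with row_indices[new_row] = old_row from the block-16/32 maps above. A minimal, framework-free sketch of the same mapping, with shuffle_rows as a hypothetical helper name:

def shuffle_rows(rows, row_map):
    # row_map is srcToDstBlk16RowMap or srcToDstBlk32RowMap from above
    block = len(row_map)
    assert len(rows) % block == 0
    out = [None] * len(rows)
    for old_row in range(len(rows)):
        new_row = (old_row // block) * block + row_map[old_row % block]
        out[new_row] = rows[old_row]
    return out

# With srcToDstBlk16RowMap, the first 16 rows come out in source order
# 0, 2, 4, ..., 14, 1, 3, ..., 15.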
- row_indices = torch.empty(M, dtype=torch.long) - + assert ( + M % shuffle_block_size == 0 + ), f"input_tensor.shape[0] must be multiples of {shuffle_block_size}" + row_indices = paddle.empty(shape=M, dtype="int64") for old_row in range(M): block_idx = old_row // shuffle_block_size row_in_block = old_row % shuffle_block_size mapped_row_in_block = row_map[row_in_block] - new_row = block_idx * shuffle_block_size + mapped_row_in_block - row_indices[new_row] = old_row - return row_indices def get_shuffle_matrix_sf_a_row_indices( - input_tensor: torch.Tensor, epilogue_tile_m: int, num_elts_per_sf: int = 16 -) -> torch.Tensor: - assert input_tensor.dtype == torch.uint8 + input_tensor: paddle.Tensor, epilogue_tile_m: int, num_elts_per_sf: int = 16 +) -> paddle.Tensor: + assert input_tensor.dtype == "uint8" assert num_elts_per_sf == 16 - - assert input_tensor.dim() == 2, ( - f"input_tensor should be a 2D tensor, not {input_tensor.dim()}" - ) - - # M, K from the input - M, K = input_tensor.shape + assert ( + input_tensor.dim() == 2 + ), f"input_tensor should be a 2D tensor, not {input_tensor.dim()}" + M, K = tuple(input_tensor.shape) assert M % 128 == 0 assert K % 4 == 0 - row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) - return row_indices diff --git a/include/flashinfer/comm/trtllm_allreduce_fusion.cuh b/include/flashinfer/comm/trtllm_allreduce_fusion.cuh index 72d89300d8..8ec84c4a03 100644 --- a/include/flashinfer/comm/trtllm_allreduce_fusion.cuh +++ b/include/flashinfer/comm/trtllm_allreduce_fusion.cuh @@ -3,7 +3,7 @@ #include #include -#if CUDA_VERSION >= 120800 +#if CUDA_VERSION >= 12080 #include #endif @@ -531,7 +531,7 @@ __forceinline__ __device__ uint32_t pack_bytes(uint8_t c0, uint8_t c1, uint8_t c return (val3 << 24) | (val2 << 16) | (val1 << 8) | val0; } -#if CUDA_VERSION >= 120800 +#if CUDA_VERSION >= 12080 // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). // NOTE: bypass sm_100 requirement by __nv_cvt_float2_to_fp4x2 inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { @@ -623,7 +623,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t& vec, float SFScaleV uint8_t fp8SFVal; // Write the SF to global memory (STG.8). 
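// Editorial note (not part of the patch): the guard below encodes the nvcc
// release as __CUDACC_VER_MAJOR__ * 1000 + __CUDACC_VER_MINOR__ * 10, the same
// scheme cuda.h uses for CUDA_VERSION, so the 12080 threshold corresponds to
// CUDA 12.8; the previous major*10000 + minor*100 >= 120800 form tested the
// same release on a larger scale.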
if constexpr (UE8M0_SF) { -#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800) +#if (__CUDACC_VER_MAJOR__ * 1000 + __CUDACC_VER_MINOR__ * 10 >= 12080) __nv_fp8_e8m0 tmp; tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); SFValue = static_cast(tmp); @@ -945,7 +945,7 @@ class FusedOp { } } -#if CUDA_VERSION >= 120800 +#if CUDA_VERSION >= 12080 if constexpr (GetQuantType == QuantType::kFP4) { // NOTE(Yingyi): might update later auto sf_out = utils::cvt_quant_to_fp4_get_sf_out_offset( @@ -1469,24 +1469,26 @@ cudaError_t allreduce_fusion_op(AllReduceFusionParams const& params, bool lau DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormFP8Quant, NRanks); \ break; \ case AllReduceFusionPattern::kARResidualRMSNormFP4Quant: \ - if constexpr (!std::is_same_v && CUDA_VERSION >= 120800) { \ + if constexpr (!std::is_same_v && CUDA_VERSION >= 12080) { \ DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormFP4Quant, NRanks); \ } else { \ - FLASHINFER_CHECK(false, "FP4Quant pattern cannot work with DType=float!"); \ + FLASHINFER_CHECK(CUDA_VERSION >= 12080, "FP4Quant requires CUDA 12.8 or higher"); \ + FLASHINFER_CHECK(false, "FP4Quant pattern cannot work with DType=float"); \ } \ break; \ case AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant: \ DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant, NRanks); \ break; \ case AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant: \ - if constexpr (!std::is_same_v && CUDA_VERSION >= 120800) { \ + if constexpr (!std::is_same_v && CUDA_VERSION >= 12080) { \ DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant, NRanks); \ } else { \ - FLASHINFER_CHECK(false, "OutFP4Quant pattern cannot work with DType=float!"); \ + FLASHINFER_CHECK(CUDA_VERSION >= 12080, "OutFP4Quant requires CUDA 12.8 or higher"); \ + FLASHINFER_CHECK(false, "OutFP4Quant pattern cannot work with DType=float"); \ } \ break; \ default: \ - FLASHINFER_CHECK(false, "Unsupported allreduce fusion pattern!"); \ + FLASHINFER_CHECK(false, "Unsupported allreduce fusion pattern"); \ } switch (params.nranks) { diff --git a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh index 4853685c1b..143e25de9c 100644 --- a/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh +++ b/include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh @@ -3,7 +3,7 @@ #include #include -#if CUDA_VERSION >= 120800 +#if CUDA_VERSION >= 12080 #include #endif @@ -509,7 +509,7 @@ __forceinline__ __device__ uint32_t pack_bytes(uint8_t c0, uint8_t c1, uint8_t c return (val3 << 24) | (val2 << 16) | (val1 << 8) | val0; } -#if CUDA_VERSION >= 120800 +#if CUDA_VERSION >= 12080 // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). // NOTE:bypass sm_100 requirement by __nv_cvt_float2_to_fp4x2 inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { @@ -601,7 +601,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t& vec, float SFScaleV uint8_t fp8SFVal; // Write the SF to global memory (STG.8). 
if constexpr (UE8M0_SF) { -#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800) +#if (__CUDACC_VER_MAJOR__ * 1000 + __CUDACC_VER_MINOR__ * 10 >= 12080) __nv_fp8_e8m0 tmp; tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); SFValue = static_cast(tmp); @@ -820,7 +820,7 @@ __device__ __forceinline__ void fused_op(vec_t const& val, int acce if constexpr (NormOut) { norm_val.store(reinterpret_cast(params.norm_out) + access_id * VEC_SIZE); } -#if CUDA_VERSION >= 120800 +#if CUDA_VERSION >= 12080 if constexpr (QuantOut) { constexpr int SF_VEC_SIZE = 16; auto sf_out = utils::cvt_quant_to_fp4_get_sf_out_offset( @@ -1478,7 +1478,7 @@ cudaError_t moefinalize_allreduce_fusion_op(MoeFinalizeAllReduceFusionParams auto status = DISPATCH_MOEFINALIZEREDUCTION( params.nranks, params.residual_out, params.rms_gamma, params.quant_out, N_RANKS, RES, RMS, QUANT, [&]() -> cudaError_t { - if constexpr (CUDA_VERSION < 120800 && QUANT) { + if constexpr (CUDA_VERSION < 12080 && QUANT) { FLASHINFER_CHECK(false, "cuda version should be greater equal than 12.8 with " "trtllm_moe_allreduce_fusion quant"); diff --git a/include/flashinfer/cutlass_utils.cuh b/include/flashinfer/cutlass_utils.cuh index 0e2033a5b9..7a7a00789b 100644 --- a/include/flashinfer/cutlass_utils.cuh +++ b/include/flashinfer/cutlass_utils.cuh @@ -42,7 +42,7 @@ #include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/tensor_view_io.h" #if defined(FLASHINFER_ENABLE_FP4_E2M1) -#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800) +#if (__CUDACC_VER_MAJOR__ * 1000 + __CUDACC_VER_MINOR__ * 10 >= 12080) #include #endif #endif diff --git a/profiler/batch_attention.py b/profiler/batch_attention.py index 7e10c4740e..d1b52a6ad4 100644 --- a/profiler/batch_attention.py +++ b/profiler/batch_attention.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,11 +15,8 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import argparse -import torch - import flashinfer from flashinfer.profiler import export_to_perfetto_trace @@ -35,47 +34,32 @@ def profile_persistent_batch_attention( profiler_buffer_size, device="cuda", ): - seq_lens = torch.tensor(kv_lens, dtype=torch.int32) - q_lens = torch.tensor(qo_lens, dtype=torch.int32) - - seq_lens_blocks = torch.ceil(seq_lens / page_size).int() - - q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int() - kv_indptr = torch.cat( - [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0 - ).int() - + seq_lens = paddle.to_tensor(data=kv_lens, dtype="int32") + q_lens = paddle.to_tensor(data=qo_lens, dtype="int32") + seq_lens_blocks = paddle.ceil(x=seq_lens / page_size).astype(dtype="int32") + q_indptr = paddle.concat( + x=[paddle.to_tensor(data=[0]), paddle.cumsum(x=q_lens, axis=0)], axis=0 + ).astype(dtype="int32") + kv_indptr = paddle.concat( + x=[paddle.to_tensor(data=[0]), paddle.cumsum(x=seq_lens_blocks, axis=0)], axis=0 + ).astype(dtype="int32") num_blocks = kv_indptr[-1].item() - - q = torch.rand( - q_indptr[-1].item(), num_qo_heads, head_dim, device=device, dtype=test_dtype + q = paddle.rand( + shape=[q_indptr[-1].item(), num_qo_heads, head_dim], dtype=test_dtype ) if layout == "NHD": - kv_data = torch.randn( - num_blocks, - 2, - page_size, - num_kv_heads, - head_dim, - dtype=test_dtype, - device=device, + kv_data = paddle.randn( + shape=[num_blocks, 2, page_size, num_kv_heads, head_dim], dtype=test_dtype ) elif layout == "HND": - kv_data = torch.randn( - num_blocks, - 2, - num_kv_heads, - page_size, - head_dim, - dtype=test_dtype, - device=device, + kv_data = paddle.randn( + shape=[num_blocks, 2, num_kv_heads, page_size, head_dim], dtype=test_dtype ) - wrapper = flashinfer.BatchAttention(kv_layout=layout) wrapper.plan( q_indptr.to(device), kv_indptr.to(device), - torch.arange(num_blocks).int().to(device), + paddle.arange(end=num_blocks).astype(dtype="int32").to(device), seq_lens.to(device), num_qo_heads, num_kv_heads, @@ -87,21 +71,13 @@ def profile_persistent_batch_attention( kv_data_type=test_dtype, use_profiler=True, ) - - profiler_buffer = torch.zeros( - (profiler_buffer_size,), dtype=torch.uint64, device=device - ) - - # warmup +>>>>>> profiler_buffer = paddle.zeros(shape=(profiler_buffer_size,), dtype=torch.uint64) wrapper.run(q, kv_data, profiler_buffer=profiler_buffer) profiler_buffer.zero_() - wrapper.run(q, kv_data, profiler_buffer=profiler_buffer) - trace_name = "batch_attention.perfetto-trace" events = ["prefill", "decode", "reduction"] export_to_perfetto_trace(profiler_buffer, events, trace_name) - print(f"Profile trace exported to {trace_name}") @@ -117,47 +93,32 @@ def persistent_batch_attention( causal, device="cuda", ): - seq_lens = torch.tensor(kv_lens, dtype=torch.int32) - q_lens = torch.tensor(qo_lens, dtype=torch.int32) - - seq_lens_blocks = torch.ceil(seq_lens / page_size).int() - - q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int() - kv_indptr = torch.cat( - [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0 - ).int() - + seq_lens = paddle.to_tensor(data=kv_lens, dtype="int32") + q_lens = paddle.to_tensor(data=qo_lens, dtype="int32") + seq_lens_blocks = paddle.ceil(x=seq_lens / page_size).astype(dtype="int32") + q_indptr = paddle.concat( + x=[paddle.to_tensor(data=[0]), paddle.cumsum(x=q_lens, axis=0)], axis=0 + ).astype(dtype="int32") + kv_indptr = paddle.concat( + x=[paddle.to_tensor(data=[0]), paddle.cumsum(x=seq_lens_blocks, axis=0)], axis=0 + 
).astype(dtype="int32") num_blocks = kv_indptr[-1].item() - - q = torch.rand( - q_indptr[-1].item(), num_qo_heads, head_dim, device=device, dtype=test_dtype + q = paddle.rand( + shape=[q_indptr[-1].item(), num_qo_heads, head_dim], dtype=test_dtype ) if layout == "NHD": - kv_data = torch.randn( - num_blocks, - 2, - page_size, - num_kv_heads, - head_dim, - dtype=test_dtype, - device=device, + kv_data = paddle.randn( + shape=[num_blocks, 2, page_size, num_kv_heads, head_dim], dtype=test_dtype ) elif layout == "HND": - kv_data = torch.randn( - num_blocks, - 2, - num_kv_heads, - page_size, - head_dim, - dtype=test_dtype, - device=device, + kv_data = paddle.randn( + shape=[num_blocks, 2, num_kv_heads, page_size, head_dim], dtype=test_dtype ) - wrapper = flashinfer.BatchAttention(kv_layout=layout) wrapper.plan( q_indptr.to(device), kv_indptr.to(device), - torch.arange(num_blocks).int().to(device), + paddle.arange(end=num_blocks).astype(dtype="int32").to(device), seq_lens.to(device), num_qo_heads, num_kv_heads, @@ -176,20 +137,16 @@ def persistent_batch_attention( parser.add_argument("--profiler-buffer-size", type=int, default=1048576) parser.add_argument("--use-profiler", action="store_true") args = parser.parse_args() - seq_len_config = [(600, 1)] * 122 + [(10000, 17)] * 8 - kv_lens = [p[0] for p in seq_len_config] qo_lens = [p[1] for p in seq_len_config] - page_size = 1 num_kv_heads = 4 num_qo_heads = 28 head_dim = 128 layout = "NHD" - test_dtype = torch.bfloat16 + test_dtype = "bfloat16" causal = True - if args.use_profiler: profile_persistent_batch_attention( kv_lens=kv_lens, diff --git a/profiler/mla.py b/profiler/mla.py index 7c9867d11e..d3119ec428 100644 --- a/profiler/mla.py +++ b/profiler/mla.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,11 +15,8 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import argparse -import torch - import flashinfer from flashinfer.profiler import export_to_perfetto_trace @@ -28,27 +27,27 @@ def profile_deepseek_mla_decode( head_dim_ckv = 512 head_dim_kpe = 64 page_size = 1 - q_nope = torch.randn( - batch_size * 1, num_heads, head_dim_ckv, dtype=torch.half, device="cuda" - ) - q_pe = torch.zeros( - batch_size * 1, num_heads, head_dim_kpe, dtype=torch.half, device="cuda" - ) - ckv = torch.randn( - batch_size * seq_len, 1, head_dim_ckv, dtype=torch.half, device="cuda" + q_nope = paddle.randn( + shape=[batch_size * 1, num_heads, head_dim_ckv], dtype="float16" ) - kpe = torch.zeros( - batch_size * seq_len, 1, head_dim_kpe, dtype=torch.half, device="cuda" + q_pe = paddle.zeros( + shape=[batch_size * 1, num_heads, head_dim_kpe], dtype="float16" ) - sm_scale = 1.0 / ((head_dim_ckv + head_dim_kpe) ** 0.5) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + ckv = paddle.randn(shape=[batch_size * seq_len, 1, head_dim_ckv], dtype="float16") + kpe = paddle.zeros(shape=[batch_size * seq_len, 1, head_dim_kpe], dtype="float16") + sm_scale = 1.0 / (head_dim_ckv + head_dim_kpe) ** 0.5 + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8").to(0) wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( workspace_buffer, backend=backend ) - q_indptr = torch.arange(0, batch_size + 1).to(0).int() - kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * seq_len - kv_indices = torch.arange(0, batch_size * seq_len).to(0).int() - kv_lens = torch.full((batch_size,), seq_len, dtype=torch.int32).to(0) + q_indptr = paddle.arange(start=0, end=batch_size + 1).to(0).astype(dtype="int32") + kv_indptr = ( + paddle.arange(start=0, end=batch_size + 1).to(0).astype(dtype="int32") * seq_len + ) + kv_indices = ( + paddle.arange(start=0, end=batch_size * seq_len).to(0).astype(dtype="int32") + ) + kv_lens = paddle.full(shape=(batch_size,), fill_value=seq_len, dtype="int32").to(0) wrapper.plan( q_indptr, kv_indptr, @@ -58,26 +57,20 @@ def profile_deepseek_mla_decode( head_dim_ckv, head_dim_kpe, page_size, - False, # causal + False, sm_scale, q_nope.dtype, ckv.dtype, use_profiler=True, ) - profiler_buffer = torch.zeros( - (profiler_buffer_size,), dtype=torch.uint64, device="cuda" - ) - # warmup run +>>>>>> profiler_buffer = paddle.zeros(shape=(profiler_buffer_size,), dtype=torch.uint64) _o = wrapper.run( q_nope, q_pe, ckv, kpe, return_lse=False, profiler_buffer=profiler_buffer ) profiler_buffer.zero_() - - # run wrapper.run( q_nope, q_pe, ckv, kpe, return_lse=False, profiler_buffer=profiler_buffer ) - export_to_perfetto_trace( profiler_buffer, [ diff --git a/scripts/update_whl_index.py b/scripts/update_whl_index.py index 4902d0cc7f..134c730c66 100644 --- a/scripts/update_whl_index.py +++ b/scripts/update_whl_index.py @@ -6,7 +6,7 @@ with open(path, "rb") as f: sha256 = hashlib.sha256(f.read()).hexdigest() ver, cu, torch = re.findall( - r"flashinfer_python-([0-9.]+(?:\.post[0-9]+)?)\+cu(\d+)torch([0-9.]+)-", + "flashinfer_python-([0-9.]+(?:\\.post[0-9]+)?)\\+cu(\\d+)torch([0-9.]+)-", path.name, )[0] index_dir = pathlib.Path(f"flashinfer-whl/cu{cu}/torch{torch}/flashinfer-python") diff --git a/setup.py b/setup.py index 9b62a0ef33..398abc8d6c 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,8 @@ +import os + +import paddle +import setuptools + """ Copyright (c) 2023 by FlashInfer team. @@ -13,17 +18,12 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - -import os import platform import re import subprocess from pathlib import Path from typing import List, Mapping -import setuptools -from setuptools.dist import Distribution - root = Path(__file__).parent.resolve() aot_ops_package_dir = root / "build" / "aot-ops-package-dir" enable_aot = aot_ops_package_dir.is_dir() and any(aot_ops_package_dir.iterdir()) @@ -55,7 +55,7 @@ def generate_build_meta(aot_build_meta: dict) -> None: cmdclass: Mapping[str, type[setuptools.Command]] = {} install_requires = [ "numpy", - "torch", + # "paddlepaddle", # Manually installed by user "ninja", "requests", "cuda-python<=12.9", @@ -65,26 +65,25 @@ def generate_build_meta(aot_build_meta: dict) -> None: "nvidia-cudnn-frontend>=1.13.0", ] generate_build_meta({}) - if enable_aot: - import torch - import torch.utils.cpp_extension as torch_cpp_ext + pass from packaging.version import Version def get_cuda_version() -> Version: - if torch_cpp_ext.CUDA_HOME is None: + if paddle.utils.cpp_extension.cpp_extension.CUDA_HOME is None: nvcc = "nvcc" else: - nvcc = os.path.join(torch_cpp_ext.CUDA_HOME, "bin/nvcc") + nvcc = os.path.join( + paddle.utils.cpp_extension.cpp_extension.CUDA_HOME, "bin/nvcc" + ) txt = subprocess.check_output([nvcc, "--version"], text=True) - return Version(re.findall(r"release (\d+\.\d+),", txt)[0]) + return Version(re.findall("release (\\d+\\.\\d+),", txt)[0]) cuda_version = get_cuda_version() - torch_full_version = Version(torch.__version__) + torch_full_version = Version(paddle.__version__) torch_version = f"{torch_full_version.major}.{torch_full_version.minor}" install_requires = [req for req in install_requires if not req.startswith("torch ")] install_requires.append(f"torch == {torch_version}.*") - aot_build_meta = {} aot_build_meta["cuda_major"] = cuda_version.major aot_build_meta["cuda_minor"] = cuda_version.minor @@ -94,7 +93,7 @@ def get_cuda_version() -> Version: generate_build_meta(aot_build_meta) -class AotDistribution(Distribution): +class AotDistribution(setuptools.dist.Distribution): def has_ext_modules(self) -> bool: return enable_aot diff --git a/tests/alibi_reference.py b/tests/alibi_reference.py index dd03a359de..07c6f0a1b2 100644 --- a/tests/alibi_reference.py +++ b/tests/alibi_reference.py @@ -1,3 +1,5 @@ +import paddle + """ Attention with Linear Biases (ALiBi) reference implementation. @@ -11,55 +13,39 @@ - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/285cb3735bde02fbc8c19ddeb24d0ae7e77135c1/labml_nn/transformers/mha.py - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/285cb3735bde02fbc8c19ddeb24d0ae7e77135c1/labml_nn/transformers/alibi/__init__.py """ - import math from typing import Optional -import torch - def get_slopes(n_heads: int): - r""" + """ ## Get head-specific slope $m$ for each head * `n_heads` is the number of heads in the attention layer $n$ The slope for first head is - $$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$ + $$\\frac{1}{2^{\\frac{8}{n}}} = 2^{-\\frac{8}{n}}$$ The slopes for the rest of the heads are in a geometric series with a ratio same as above. For instance when the number of heads is $8$ the slopes are - $$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$ + $$\\frac{1}{2^1}, \\frac{1}{2^2}, \\dots, \\frac{1}{2^8}$$ """ - - # Get the closest power of 2 to `n_heads`. - # If `n_heads` is not a power of 2, then we first calculate slopes to the closest (smaller) power of 2, - # and then add the remaining slopes. 
n = 2 ** math.floor(math.log2(n_heads)) - # $2^{-\frac{8}{n}}$ m_0 = 2.0 ** (-8.0 / n) - # $2^{-1\frac{8}{n}}, 2^{-2 \frac{8}{n}}, 2^{-3 \frac{8}{n}}, \dots$ - m = torch.pow(m_0, torch.arange(1, 1 + n)) - - # If `n_heads` is not a power of 2, then we add the remaining slopes. - # We calculate the remaining slopes for $n * 2$ (avoiding slopes added previously). - # And pick the slopes upto `n_heads`. + m = paddle.pow(x=m_0, y=paddle.arange(start=1, end=1 + n)) if n < n_heads: - # $2^{-\frac{8}{2n}}$ m_hat_0 = 2.0 ** (-4.0 / n) - # $2^{-1\frac{8}{2n}}, 2^{-3 \frac{8}{2n}}, 2^{-5 \frac{8}{2n}}, \dots$ - # Note that we take steps by $2$ to avoid slopes added previously. - m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2)) - # Concatenate the slopes with the remaining slopes. - m = torch.cat([m, m_hat]) - + m_hat = paddle.pow( + x=m_hat_0, y=paddle.arange(start=1, end=1 + 2 * (n_heads - n), step=2) + ) + m = paddle.concat(x=[m, m_hat]) return m -@torch.no_grad() -def get_alibi_biases(n_heads: int, mask: torch.Tensor): +@paddle.no_grad() +def get_alibi_biases(n_heads: int, mask: paddle.Tensor): """ ## Calculate the attention biases matrix @@ -68,27 +54,16 @@ def get_alibi_biases(n_heads: int, mask: torch.Tensor): This returns a matrix of shape `[seq_len_q, seq_len_k, n_heads, ]` with ALiBi attention biases. """ - - # Get slopes $m$ for each head - m = get_slopes(n_heads).to(mask.device) - - # Calculate distances $[0, 1, \dots, N]$ - # Here we calculate the distances using the mask. - # - # Since it's causal mask we can just use $[0, 1, \dots, N]$ too. - distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[ - None, : - ] - - # Multiply them pair-wise to get the AliBi bias matrix + m = get_slopes(n_heads).to(mask.place) + distance = paddle.arange(dtype="int64", end=tuple(mask.shape)[1])[None, :] return distance[:, :, None] * m[None, None, :] def alibi_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor] = None, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, ): """ query: [q_len, num_heads, head_dim] @@ -96,27 +71,13 @@ def alibi_attention( value: [kv_len, num_heads, head_dim] mask: [q_len, kv_len] """ - q_len, num_heads, head_dim = query.shape - - scores = torch.einsum("qhd,khd->qkh", query.float(), key.float()) - # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$ + q_len, num_heads, head_dim = tuple(query.shape) + scores = paddle.einsum( + "qhd,khd->qkh", query.astype(dtype="float32"), key.astype(dtype="float32") + ) scores *= 1.0 / math.sqrt(head_dim) - - # Create AliBi biases if it's not cached alibi_biases = get_alibi_biases(num_heads, mask) - - # Add AliBi biases to attention scores. 
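Aside (editorial illustration, not part of the patch): for a power-of-two head count the slopes are the geometric series described in the docstring above, and the ALiBi bias for head h at key distance d is simply d * slope[h]. A tiny self-contained check in plain Python:

n_heads = 8
m0 = 2.0 ** (-8.0 / n_heads)                          # 1/2
slopes = [m0 ** k for k in range(1, n_heads + 1)]     # [1/2, 1/4, ..., 1/256]
# For a causal mask of length 4, the per-head bias row added to the scores of
# head h is [0, 1, 2, 3] scaled by slopes[h], matching the broadcast
# distance[:, :, None] * m[None, None, :] used above.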
- # ALiBi biases has shape `[seq_len, seq_len, n_heads]` - # and `scores` has shape `[seq_len, seq_len, batch_size, n_heads]` scores += alibi_biases - - # Apply mask - scores = scores.masked_fill(mask.unsqueeze(-1) == 0, float("-inf")) - - # $softmax$ attention along the key sequence dimension - # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$ - attn = torch.softmax(scores, dim=1) - - # Multiply by values - # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$ - return torch.einsum("ovh,vhd->ohd", attn, value.float()).to(query) + scores = scores.masked_fill(mask=mask.unsqueeze(axis=-1) == 0, value=float("-inf")) + attn = paddle.nn.functional.softmax(x=scores, axis=1) + return paddle.einsum("ovh,vhd->ohd", attn, value.astype(dtype="float32")).to(query) diff --git a/tests/conftest.py b/tests/conftest.py index 97ef03e381..b1bbc40c00 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,15 @@ +import sys + + import functools import gc import os import types from typing import Any, Dict +import paddle import pytest -import torch -from torch.torch_version import TorchVersion -from torch.torch_version import __version__ as torch_version +from flashinfer.paddle_utils import * import flashinfer @@ -53,30 +55,22 @@ flashinfer.sampling.top_k_mask_logits, flashinfer.sampling.chain_speculative_sampling, ] - _TORCH_COMPILE_CACHE: Dict[str, Any] = dict() def _set_torch_compile_options(): - import torch._dynamo.config - - torch._dynamo.config.cache_size_limit = 128 +# >>>>>> torch._dynamo.config.cache_size_limit = 128 + pass def _monkeypatch_add_torch_compile(func): """ Replace the given function with its torch.compile version. """ - - from torch._library.custom_ops import CustomOpDef - if type(func) is types.FunctionType: fn = func - elif isinstance(func, CustomOpDef): - fn = func._init_fn else: - raise ValueError(f"Unsupported fn type {type(func)}") - + return fullname = fn.__module__ + "." + fn.__qualname__ components = fullname.split(".") assert components[0] == "flashinfer" @@ -87,30 +81,7 @@ def _monkeypatch_add_torch_compile(func): raise ValueError(f"Failed to monkeypatch: {fullname}") def wrapper(*args, **kwargs): - compiled = _TORCH_COMPILE_CACHE.get(fullname) - if compiled is None: - # Warmup -- JIT compile / import the kernels. - # - # From user side, users also need to warmup the model beforehand, - # as suggested by PyTorch Cuda Graph docs (not sure if it's also - # recommended for torch.compile as well.) - # - # For the convenience of FlashInfer testing, we do the warmup here, - # on the first run of the function. The caveat is that the first - # call will run twice: once to warmup, and another through the - # compiled version. 
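Aside (editorial illustration, not part of the patch): the wrapper being simplified in this hunk originally compiled each op once, cached the compiled callable by name, and dispatched to the cache on later calls; the Paddle port falls back to calling the original function directly. A minimal framework-agnostic sketch of that compile-once pattern, with compile_fn standing in for whatever compiler is plugged in (hypothetical name):

_COMPILE_CACHE = {}

def cached_compile(fn, compile_fn):
    def wrapper(*args, **kwargs):
        compiled = _COMPILE_CACHE.get(fn.__qualname__)
        if compiled is None:
            fn(*args, **kwargs)        # warmup call: trigger JIT/import work once
            compiled = compile_fn(fn)  # e.g. torch.compile(fn) in the removed path
            _COMPILE_CACHE[fn.__qualname__] = compiled
        return compiled(*args, **kwargs)
    return wrapper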
- func(*args, **kwargs) - - # Compile - compiled = torch.compile( - func, - fullgraph=True, - backend="inductor", - mode="max-autotune-no-cudagraphs", - ) - _TORCH_COMPILE_CACHE[fn.__name__] = compiled - - return compiled(*args, **kwargs) + return fn(*args, **kwargs) setattr(module, fn.__name__, wrapper) print("Applied torch.compile to", fullname) @@ -118,8 +89,6 @@ def wrapper(*args, **kwargs): def pytest_configure(config): if os.environ.get("FLASHINFER_TEST_TORCH_COMPILE", "0") == "1": - if torch_version < TorchVersion("2.4"): - pytest.skip("torch.compile requires torch >= 2.4") _set_torch_compile_options() for fn in TORCH_COMPILE_FNS: _monkeypatch_add_torch_compile(fn) @@ -131,1752 +100,178 @@ def is_cuda_oom_error_str(e: str) -> bool: @pytest.hookimpl(tryfirst=True) def pytest_runtest_call(item): - # skip OOM error try: item.runtest() - except (torch.cuda.OutOfMemoryError, RuntimeError) as e: - if isinstance(e, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(str(e)): + except RuntimeError as e: + if is_cuda_oom_error_str(str(e)): pytest.skip("Skipping due to OOM") else: raise - @functools.cache -def get_device_properties(device: torch.device): - return torch.cuda.get_device_properties(device) +def get_device_properties(device: str): + return paddle.device.cuda.get_device_properties(device=device2str(device)) -def clear_cuda_cache(device: torch.device) -> None: +def clear_cuda_cache(device: str) -> None: total_memory = get_device_properties(device).total_memory - reserved_memory = torch.cuda.memory_reserved() - - # FLASHINFER_TEST_MEMORY_THRESHOLD: threshold for PyTorch reserved memory usage (default: 0.9) + reserved_memory = paddle.device.cuda.memory_reserved() threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.9")) - if reserved_memory > threshold * total_memory: gc.collect() - torch.cuda.empty_cache() + paddle.device.cuda.empty_cache() -# collected from gsk8k trace in sglang VARLEN_INDPTR_PARAMS = [ - [ - 0, - 1276, - 2551, - 3838, - 5115, - 6428, - 7705, - 8985, - 10293, - 11607, - 12909, - 14216, - 15508, - ], - [ - 0, - 1320, - 2637, - 3926, - 5208, - 6494, - 7795, - 9130, - 10415, - 11704, - 12995, - 14270, - 15551, - ], - [ - 0, - 1333, - 2762, - 4068, - 5345, - 6635, - 7936, - 9237, - 10518, - 11874, - 13146, - 14466, - 15785, - ], - [ - 0, - 1310, - 2603, - 3937, - 5231, - 6527, - 7799, - 9107, - 10411, - 11698, - 12978, - 14277, - 15559, - ], - [ - 0, - 1286, - 2561, - 3857, - 5163, - 6459, - 7763, - 9057, - 10380, - 11679, - 12968, - 14284, - 15610, - ], + [0, 1276, 2551, 3838, 5115, 6428, 7705, 8985, 10293, 11607, 12909, 14216, 15508], + [0, 1320, 2637, 3926, 5208, 6494, 7795, 9130, 10415, 11704, 12995, 14270, 15551], + [0, 1333, 2762, 4068, 5345, 6635, 7936, 9237, 10518, 11874, 13146, 14466, 15785], + [0, 1310, 2603, 3937, 5231, 6527, 7799, 9107, 10411, 11698, 12978, 14277, 15559], + [0, 1286, 2561, 3857, 5163, 6459, 7763, 9057, 10380, 11679, 12968, 14284, 15610], [0, 1350, 2667, 4003, 5347, 6631, 7919, 9208, 10524], - [ - 0, - 1293, - 2609, - 3902, - 5196, - 6495, - 7807, - 9086, - 10382, - 11700, - 12989, - 14271, - 15578, - ], - [ - 0, - 1276, - 2551, - 3838, - 5115, - 6428, - 7705, - 8985, - 10293, - 11607, - 12909, - 14216, - 15508, - ], - [ - 0, - 1280, - 2559, - 3874, - 5197, - 6540, - 7850, - 9167, - 10536, - 11820, - 13111, - 14444, - 15756, - ], - [ - 0, - 1306, - 2598, - 3895, - 5181, - 6458, - 7750, - 9047, - 10333, - 11635, - 12906, - 14207, - 15514, - ], - [ - 0, - 1300, - 2620, - 3912, - 5219, - 6500, - 7781, - 9069, - 10386, - 
11665, - 12976, - 14262, - 15545, - ], + [0, 1293, 2609, 3902, 5196, 6495, 7807, 9086, 10382, 11700, 12989, 14271, 15578], + [0, 1276, 2551, 3838, 5115, 6428, 7705, 8985, 10293, 11607, 12909, 14216, 15508], + [0, 1280, 2559, 3874, 5197, 6540, 7850, 9167, 10536, 11820, 13111, 14444, 15756], + [0, 1306, 2598, 3895, 5181, 6458, 7750, 9047, 10333, 11635, 12906, 14207, 15514], + [0, 1300, 2620, 3912, 5219, 6500, 7781, 9069, 10386, 11665, 12976, 14262, 15545], [0, 1800, 3600, 5400, 7200, 9000, 10799, 12620, 14441, 16261], [0, 1298, 2638], [0, 1284], [0, 1297, 2604], [0, 1276], - [ - 0, - 1286, - 2614, - 3909, - 5207, - 6490, - 7785, - 9067, - 10356, - 11633, - 12915, - 14231, - 15511, - ], - [ - 0, - 1312, - 2613, - 3899, - 5203, - 6492, - 7811, - 9151, - 10439, - 11757, - 13052, - 14364, - 15646, - ], + [0, 1286, 2614, 3909, 5207, 6490, 7785, 9067, 10356, 11633, 12915, 14231, 15511], + [0, 1312, 2613, 3899, 5203, 6492, 7811, 9151, 10439, 11757, 13052, 14364, 15646], [0, 1287], - [ - 0, - 1353, - 2684, - 4039, - 5326, - 6615, - 7932, - 9217, - 10528, - 11862, - 13207, - 14490, - 15785, - ], + [0, 1353, 2684, 4039, 5326, 6615, 7932, 9217, 10528, 11862, 13207, 14490, 15785], [0, 1307], - [ - 0, - 1301, - 2587, - 3949, - 5263, - 6620, - 7933, - 9226, - 10521, - 11895, - 13179, - 14531, - 15822, - ], - [ - 0, - 1334, - 2618, - 3914, - 5198, - 6476, - 7771, - 9076, - 10362, - 11675, - 12974, - 14264, - 15540, - ], - [ - 0, - 1323, - 2606, - 3900, - 5209, - 6487, - 7770, - 9074, - 10397, - 11694, - 13047, - 14411, - 15719, - ], - [ - 0, - 1286, - 2563, - 3845, - 5181, - 6471, - 7789, - 9087, - 10394, - 11750, - 13021, - 14335, - 15616, - ], - [ - 0, - 1310, - 2605, - 3889, - 5214, - 6518, - 7794, - 9143, - 10465, - 11802, - 13134, - 14438, - 15713, - ], + [0, 1301, 2587, 3949, 5263, 6620, 7933, 9226, 10521, 11895, 13179, 14531, 15822], + [0, 1334, 2618, 3914, 5198, 6476, 7771, 9076, 10362, 11675, 12974, 14264, 15540], + [0, 1323, 2606, 3900, 5209, 6487, 7770, 9074, 10397, 11694, 13047, 14411, 15719], + [0, 1286, 2563, 3845, 5181, 6471, 7789, 9087, 10394, 11750, 13021, 14335, 15616], + [0, 1310, 2605, 3889, 5214, 6518, 7794, 9143, 10465, 11802, 13134, 14438, 15713], [0, 1285, 2596, 3877, 5163, 6487, 7782, 9104, 10403], - [ - 0, - 1299, - 2594, - 3921, - 5222, - 6494, - 7777, - 9098, - 10406, - 11736, - 13026, - 14317, - 15621, - ], + [0, 1299, 2594, 3921, 5222, 6494, 7777, 9098, 10406, 11736, 13026, 14317, 15621], [0, 1268, 2560], [0, 1536, 3061, 4578, 6177, 7774, 9378, 10958, 12636, 14292, 15954], [0, 1240], - [ - 0, - 1362, - 2653, - 3930, - 5201, - 6505, - 7808, - 9094, - 10421, - 11720, - 12994, - 14285, - 15584, - ], + [0, 1362, 2653, 3930, 5201, 6505, 7808, 9094, 10421, 11720, 12994, 14285, 15584], [0, 1676, 3342, 5094, 6842, 8582, 10342, 12102, 13861, 15618], - [ - 0, - 1329, - 2656, - 3977, - 5253, - 6532, - 7831, - 9150, - 10444, - 11749, - 13090, - 14388, - 15675, - ], - [ - 0, - 1284, - 2578, - 3854, - 5189, - 6513, - 7809, - 9144, - 10463, - 11772, - 13062, - 14368, - 15641, - ], - [ - 0, - 1286, - 2651, - 3960, - 5286, - 6573, - 7857, - 9146, - 10445, - 11723, - 12995, - 14270, - 15537, - ], + [0, 1329, 2656, 3977, 5253, 6532, 7831, 9150, 10444, 11749, 13090, 14388, 15675], + [0, 1284, 2578, 3854, 5189, 6513, 7809, 9144, 10463, 11772, 13062, 14368, 15641], + [0, 1286, 2651, 3960, 5286, 6573, 7857, 9146, 10445, 11723, 12995, 14270, 15537], [0, 1716, 3555, 5298, 7041, 8781], [0, 1281, 2574, 3866, 5176], - [ - 0, - 1331, - 2655, - 3965, - 5318, - 6606, - 7954, - 9238, - 10526, - 11837, - 
13134, - 14436, - 15728, - ], + [0, 1331, 2655, 3965, 5318, 6606, 7954, 9238, 10526, 11837, 13134, 14436, 15728], [0, 1754, 3508, 5259, 7028, 8796, 10564, 12332, 14097, 15862], - [ - 0, - 1293, - 2584, - 3871, - 5157, - 6466, - 7831, - 9118, - 10444, - 11728, - 13017, - 14295, - 15594, - ], - [ - 0, - 1279, - 2586, - 3876, - 5170, - 6440, - 7768, - 9067, - 10351, - 11651, - 12936, - 14239, - 15542, - ], - [ - 0, - 1293, - 2563, - 3861, - 5139, - 6491, - 7776, - 9121, - 10422, - 11731, - 13033, - 14338, - 15639, - ], - [ - 0, - 1292, - 2632, - 3933, - 5257, - 6576, - 7881, - 9178, - 10455, - 11796, - 13095, - 14385, - 15685, - ], + [0, 1293, 2584, 3871, 5157, 6466, 7831, 9118, 10444, 11728, 13017, 14295, 15594], + [0, 1279, 2586, 3876, 5170, 6440, 7768, 9067, 10351, 11651, 12936, 14239, 15542], + [0, 1293, 2563, 3861, 5139, 6491, 7776, 9121, 10422, 11731, 13033, 14338, 15639], + [0, 1292, 2632, 3933, 5257, 6576, 7881, 9178, 10455, 11796, 13095, 14385, 15685], [0, 1307], - [ - 0, - 1307, - 2590, - 3897, - 5206, - 6527, - 7826, - 9104, - 10400, - 11696, - 13022, - 14326, - 15615, - ], - [ - 0, - 1287, - 2597, - 3933, - 5275, - 6555, - 7835, - 9153, - 10445, - 11729, - 13019, - 14303, - 15608, - ], - [ - 0, - 1294, - 2589, - 3904, - 5205, - 6504, - 7803, - 9087, - 10375, - 11671, - 12970, - 14279, - 15615, - ], - [ - 0, - 1312, - 2624, - 3957, - 5243, - 6533, - 7817, - 9095, - 10377, - 11729, - 13053, - 14332, - 15643, - ], - [ - 0, - 1278, - 2579, - 3858, - 5147, - 6461, - 7745, - 9038, - 10376, - 11654, - 12941, - 14265, - 15592, - ], - [ - 0, - 1284, - 2578, - 3855, - 5181, - 6475, - 7787, - 9090, - 10386, - 11661, - 13010, - 14291, - 15595, - ], - [ - 0, - 1293, - 2604, - 3893, - 5211, - 6526, - 7859, - 9139, - 10439, - 11723, - 13071, - 14369, - 15669, - ], - [ - 0, - 1333, - 2618, - 3892, - 5196, - 6478, - 7778, - 9088, - 10381, - 11677, - 12986, - 14276, - 15552, - ], - [ - 0, - 1287, - 2569, - 3876, - 5156, - 6463, - 7754, - 9053, - 10363, - 11642, - 12946, - 14230, - 15501, - ], - [ - 0, - 1293, - 2585, - 3900, - 5183, - 6502, - 7882, - 9185, - 10466, - 11732, - 13017, - 14324, - 15612, - ], - [ - 0, - 1323, - 2597, - 3877, - 5183, - 6483, - 7793, - 9084, - 10417, - 11719, - 12999, - 14294, - 15587, - ], - [ - 0, - 1289, - 2626, - 3900, - 5217, - 6550, - 7820, - 9140, - 10431, - 11717, - 13028, - 14312, - 15615, - ], - [ - 0, - 1306, - 2588, - 3890, - 5192, - 6540, - 7858, - 9170, - 10492, - 11772, - 13051, - 14348, - 15636, - ], - [ - 0, - 1279, - 2583, - 3892, - 5193, - 6481, - 7788, - 9099, - 10394, - 11701, - 13026, - 14348, - 15710, - ], - [ - 0, - 1287, - 2599, - 3939, - 5223, - 6523, - 7822, - 9102, - 10435, - 11714, - 13006, - 14294, - 15622, - ], - [ - 0, - 1302, - 2631, - 3913, - 5192, - 6503, - 7804, - 9121, - 10429, - 11757, - 13064, - 14379, - 15656, - ], - [ - 0, - 1278, - 2569, - 3914, - 5211, - 6480, - 7805, - 9089, - 10383, - 11687, - 12971, - 14281, - 15605, - ], - [ - 0, - 1278, - 2559, - 3834, - 5144, - 6434, - 7754, - 9033, - 10330, - 11607, - 12925, - 14218, - 15510, - ], + [0, 1307, 2590, 3897, 5206, 6527, 7826, 9104, 10400, 11696, 13022, 14326, 15615], + [0, 1287, 2597, 3933, 5275, 6555, 7835, 9153, 10445, 11729, 13019, 14303, 15608], + [0, 1294, 2589, 3904, 5205, 6504, 7803, 9087, 10375, 11671, 12970, 14279, 15615], + [0, 1312, 2624, 3957, 5243, 6533, 7817, 9095, 10377, 11729, 13053, 14332, 15643], + [0, 1278, 2579, 3858, 5147, 6461, 7745, 9038, 10376, 11654, 12941, 14265, 15592], + [0, 1284, 2578, 3855, 5181, 6475, 7787, 9090, 10386, 11661, 13010, 14291, 
15595], + [0, 1293, 2604, 3893, 5211, 6526, 7859, 9139, 10439, 11723, 13071, 14369, 15669], + [0, 1333, 2618, 3892, 5196, 6478, 7778, 9088, 10381, 11677, 12986, 14276, 15552], + [0, 1287, 2569, 3876, 5156, 6463, 7754, 9053, 10363, 11642, 12946, 14230, 15501], + [0, 1293, 2585, 3900, 5183, 6502, 7882, 9185, 10466, 11732, 13017, 14324, 15612], + [0, 1323, 2597, 3877, 5183, 6483, 7793, 9084, 10417, 11719, 12999, 14294, 15587], + [0, 1289, 2626, 3900, 5217, 6550, 7820, 9140, 10431, 11717, 13028, 14312, 15615], + [0, 1306, 2588, 3890, 5192, 6540, 7858, 9170, 10492, 11772, 13051, 14348, 15636], + [0, 1279, 2583, 3892, 5193, 6481, 7788, 9099, 10394, 11701, 13026, 14348, 15710], + [0, 1287, 2599, 3939, 5223, 6523, 7822, 9102, 10435, 11714, 13006, 14294, 15622], + [0, 1302, 2631, 3913, 5192, 6503, 7804, 9121, 10429, 11757, 13064, 14379, 15656], + [0, 1278, 2569, 3914, 5211, 6480, 7805, 9089, 10383, 11687, 12971, 14281, 15605], + [0, 1278, 2559, 3834, 5144, 6434, 7754, 9033, 10330, 11607, 12925, 14218, 15510], [0, 1319], - [ - 0, - 1269, - 2564, - 3849, - 5130, - 6430, - 7740, - 9060, - 10409, - 11698, - 13001, - 14286, - 15557, - ], - [ - 0, - 1288, - 2592, - 3867, - 5214, - 6491, - 7793, - 9110, - 10416, - 11729, - 13020, - 14318, - 15625, - ], - [ - 0, - 1326, - 2643, - 3972, - 5270, - 6591, - 7872, - 9139, - 10437, - 11731, - 13031, - 14327, - 15633, - ], - [ - 0, - 1284, - 2560, - 3857, - 5134, - 6436, - 7728, - 9041, - 10345, - 11625, - 12940, - 14242, - 15530, - ], + [0, 1269, 2564, 3849, 5130, 6430, 7740, 9060, 10409, 11698, 13001, 14286, 15557], + [0, 1288, 2592, 3867, 5214, 6491, 7793, 9110, 10416, 11729, 13020, 14318, 15625], + [0, 1326, 2643, 3972, 5270, 6591, 7872, 9139, 10437, 11731, 13031, 14327, 15633], + [0, 1284, 2560, 3857, 5134, 6436, 7728, 9041, 10345, 11625, 12940, 14242, 15530], [0, 1299, 2576], - [ - 0, - 1296, - 2574, - 3866, - 5162, - 6448, - 7745, - 9020, - 10294, - 11588, - 12895, - 14218, - 15525, - ], - [ - 0, - 1279, - 2563, - 3875, - 5161, - 6461, - 7741, - 9023, - 10305, - 11613, - 12897, - 14204, - 15536, - ], - [ - 0, - 1273, - 2553, - 3848, - 5210, - 6493, - 7775, - 9058, - 10375, - 11695, - 12984, - 14278, - 15588, - ], - [ - 0, - 1283, - 2584, - 3863, - 5160, - 6444, - 7740, - 9061, - 10377, - 11698, - 12994, - 14274, - 15545, - ], - [ - 0, - 1329, - 2648, - 3962, - 5309, - 6622, - 7930, - 9242, - 10544, - 11828, - 13183, - 14476, - 15809, - ], - [ - 0, - 1290, - 2591, - 3891, - 5175, - 6460, - 7766, - 9112, - 10402, - 11701, - 13019, - 14330, - 15633, - ], - [ - 0, - 1333, - 2673, - 3958, - 5270, - 6589, - 7911, - 9203, - 10549, - 11841, - 13146, - 14471, - 15776, - ], - [ - 0, - 1288, - 2643, - 3945, - 5266, - 6595, - 7907, - 9213, - 10486, - 11807, - 13138, - 14430, - 15703, - ], - [ - 0, - 1306, - 2620, - 3944, - 5260, - 6569, - 7852, - 9144, - 10460, - 11785, - 13075, - 14368, - 15672, - ], - [ - 0, - 1294, - 2572, - 3851, - 5164, - 6464, - 7755, - 9090, - 10398, - 11688, - 13002, - 14313, - 15593, - ], - [ - 0, - 1340, - 2651, - 3959, - 5258, - 6545, - 7836, - 9157, - 10465, - 11772, - 13065, - 14368, - 15747, - ], - [ - 0, - 1325, - 2657, - 3935, - 5255, - 6583, - 7874, - 9154, - 10448, - 11732, - 13026, - 14344, - 15620, - ], + [0, 1296, 2574, 3866, 5162, 6448, 7745, 9020, 10294, 11588, 12895, 14218, 15525], + [0, 1279, 2563, 3875, 5161, 6461, 7741, 9023, 10305, 11613, 12897, 14204, 15536], + [0, 1273, 2553, 3848, 5210, 6493, 7775, 9058, 10375, 11695, 12984, 14278, 15588], + [0, 1283, 2584, 3863, 5160, 6444, 7740, 9061, 10377, 11698, 12994, 14274, 
15545], + [0, 1329, 2648, 3962, 5309, 6622, 7930, 9242, 10544, 11828, 13183, 14476, 15809], + [0, 1290, 2591, 3891, 5175, 6460, 7766, 9112, 10402, 11701, 13019, 14330, 15633], + [0, 1333, 2673, 3958, 5270, 6589, 7911, 9203, 10549, 11841, 13146, 14471, 15776], + [0, 1288, 2643, 3945, 5266, 6595, 7907, 9213, 10486, 11807, 13138, 14430, 15703], + [0, 1306, 2620, 3944, 5260, 6569, 7852, 9144, 10460, 11785, 13075, 14368, 15672], + [0, 1294, 2572, 3851, 5164, 6464, 7755, 9090, 10398, 11688, 13002, 14313, 15593], + [0, 1340, 2651, 3959, 5258, 6545, 7836, 9157, 10465, 11772, 13065, 14368, 15747], + [0, 1325, 2657, 3935, 5255, 6583, 7874, 9154, 10448, 11732, 13026, 14344, 15620], [0, 1764, 3551, 5336, 7121, 8905, 10688, 12471, 14252, 16054], - [ - 0, - 1280, - 2590, - 3896, - 5187, - 6520, - 7822, - 9117, - 10397, - 11690, - 12977, - 14270, - 15561, - ], - [ - 0, - 1285, - 2577, - 3862, - 5198, - 6477, - 7762, - 9130, - 10412, - 11694, - 13049, - 14358, - 15666, - ], - [ - 0, - 1287, - 2617, - 3942, - 5240, - 6510, - 7807, - 9090, - 10390, - 11743, - 13031, - 14325, - 15615, - ], + [0, 1280, 2590, 3896, 5187, 6520, 7822, 9117, 10397, 11690, 12977, 14270, 15561], + [0, 1285, 2577, 3862, 5198, 6477, 7762, 9130, 10412, 11694, 13049, 14358, 15666], + [0, 1287, 2617, 3942, 5240, 6510, 7807, 9090, 10390, 11743, 13031, 14325, 15615], [0, 1310, 2584, 3990, 5291, 6598, 7908, 9192], - [ - 0, - 1304, - 2626, - 3930, - 5209, - 6499, - 7810, - 9109, - 10435, - 11731, - 13007, - 14307, - 15593, - ], - [ - 0, - 1308, - 2612, - 3927, - 5227, - 6515, - 7812, - 9146, - 10447, - 11731, - 13017, - 14317, - 15602, - ], + [0, 1304, 2626, 3930, 5209, 6499, 7810, 9109, 10435, 11731, 13007, 14307, 15593], + [0, 1308, 2612, 3927, 5227, 6515, 7812, 9146, 10447, 11731, 13017, 14317, 15602], [0, 1820, 3640, 5460, 7277, 9115, 10953, 12791, 14628], - [ - 0, - 1289, - 2594, - 3903, - 5196, - 6499, - 7799, - 9077, - 10386, - 11662, - 12959, - 14243, - 15543, - ], - [ - 0, - 1300, - 2601, - 3876, - 5165, - 6436, - 7725, - 9039, - 10352, - 11639, - 12927, - 14209, - 15490, - ], + [0, 1289, 2594, 3903, 5196, 6499, 7799, 9077, 10386, 11662, 12959, 14243, 15543], + [0, 1300, 2601, 3876, 5165, 6436, 7725, 9039, 10352, 11639, 12927, 14209, 15490], [0, 1837, 3674, 5206, 6693, 8229, 9790, 11329, 12910, 14474, 16037], - [ - 0, - 1292, - 2604, - 3878, - 5151, - 6453, - 7749, - 9033, - 10363, - 11703, - 13014, - 14301, - 15617, - ], - [ - 0, - 1275, - 2556, - 3843, - 5147, - 6427, - 7712, - 9003, - 10311, - 11600, - 12970, - 14264, - 15545, - ], - [ - 0, - 1285, - 2590, - 3878, - 5169, - 6527, - 7863, - 9161, - 10451, - 11745, - 13066, - 14382, - 15695, - ], + [0, 1292, 2604, 3878, 5151, 6453, 7749, 9033, 10363, 11703, 13014, 14301, 15617], + [0, 1275, 2556, 3843, 5147, 6427, 7712, 9003, 10311, 11600, 12970, 14264, 15545], + [0, 1285, 2590, 3878, 5169, 6527, 7863, 9161, 10451, 11745, 13066, 14382, 15695], [0, 1340, 2635], - [ - 0, - 1314, - 2600, - 3894, - 5194, - 6490, - 7797, - 9105, - 10385, - 11667, - 12967, - 14255, - 15550, - ], - [ - 0, - 1308, - 2605, - 3956, - 5254, - 6582, - 7865, - 9160, - 10459, - 11758, - 13045, - 14341, - 15623, - ], - [ - 0, - 1282, - 2576, - 3882, - 5190, - 6510, - 7819, - 9142, - 10427, - 11736, - 13041, - 14359, - 15683, - ], + [0, 1314, 2600, 3894, 5194, 6490, 7797, 9105, 10385, 11667, 12967, 14255, 15550], + [0, 1308, 2605, 3956, 5254, 6582, 7865, 9160, 10459, 11758, 13045, 14341, 15623], + [0, 1282, 2576, 3882, 5190, 6510, 7819, 9142, 10427, 11736, 13041, 14359, 15683], [0, 1300, 2614, 3924], - [ - 0, - 
1282, - 2600, - 3923, - 5229, - 6580, - 7952, - 9295, - 10593, - 11873, - 13161, - 14458, - 15756, - ], - [ - 0, - 1286, - 2578, - 3884, - 5184, - 6494, - 7779, - 9078, - 10356, - 11677, - 12976, - 14256, - 15560, - ], - [ - 0, - 1303, - 2575, - 3848, - 5119, - 6417, - 7714, - 9020, - 10362, - 11668, - 12983, - 14314, - 15599, - ], + [0, 1282, 2600, 3923, 5229, 6580, 7952, 9295, 10593, 11873, 13161, 14458, 15756], + [0, 1286, 2578, 3884, 5184, 6494, 7779, 9078, 10356, 11677, 12976, 14256, 15560], + [0, 1303, 2575, 3848, 5119, 6417, 7714, 9020, 10362, 11668, 12983, 14314, 15599], [0, 1291, 2584], - [ - 0, - 1299, - 2617, - 3938, - 5328, - 6600, - 7885, - 9163, - 10489, - 11771, - 13053, - 14332, - 15691, - ], + [0, 1299, 2617, 3938, 5328, 6600, 7885, 9163, 10489, 11771, 13053, 14332, 15691], [0, 1305, 2617], [0, 1573, 3257, 4935, 6605, 8256, 9906, 11529, 13171, 14809], - [ - 0, - 1299, - 2591, - 3885, - 5165, - 6445, - 7744, - 9111, - 10413, - 11725, - 13000, - 14304, - 15614, - ], + [0, 1299, 2591, 3885, 5165, 6445, 7744, 9111, 10413, 11725, 13000, 14304, 15614], [0, 1296], - [ - 0, - 1295, - 2570, - 3912, - 5252, - 6527, - 7806, - 9121, - 10408, - 11710, - 12988, - 14270, - 15585, - ], - [ - 0, - 1285, - 2621, - 3937, - 5235, - 6506, - 7790, - 9085, - 10352, - 11630, - 12949, - 14247, - 15528, - ], - [ - 0, - 1297, - 2575, - 3868, - 5146, - 6436, - 7775, - 9066, - 10376, - 11708, - 13005, - 14365, - 15649, - ], - [ - 0, - 1322, - 2638, - 3920, - 5217, - 6522, - 7801, - 9113, - 10472, - 11769, - 13046, - 14372, - 15668, - ], - [ - 0, - 1272, - 2539, - 3871, - 5146, - 6471, - 7791, - 9069, - 10360, - 11688, - 12968, - 14262, - 15580, - ], - [ - 0, - 1322, - 2642, - 3933, - 5229, - 6538, - 7823, - 9126, - 10432, - 11734, - 13089, - 14372, - 15678, - ], - [ - 0, - 1310, - 2658, - 3987, - 5316, - 6608, - 7878, - 9171, - 10463, - 11757, - 13060, - 14356, - 15660, - ], - [ - 0, - 1318, - 2640, - 3924, - 5237, - 6546, - 7832, - 9138, - 10462, - 11762, - 13046, - 14341, - 15609, - ], - [ - 0, - 1280, - 2558, - 3850, - 5191, - 6495, - 7820, - 9113, - 10401, - 11717, - 13040, - 14314, - 15614, - ], - [ - 0, - 1313, - 2596, - 3908, - 5249, - 6542, - 7843, - 9141, - 10456, - 11739, - 13039, - 14348, - 15699, - ], + [0, 1295, 2570, 3912, 5252, 6527, 7806, 9121, 10408, 11710, 12988, 14270, 15585], + [0, 1285, 2621, 3937, 5235, 6506, 7790, 9085, 10352, 11630, 12949, 14247, 15528], + [0, 1297, 2575, 3868, 5146, 6436, 7775, 9066, 10376, 11708, 13005, 14365, 15649], + [0, 1322, 2638, 3920, 5217, 6522, 7801, 9113, 10472, 11769, 13046, 14372, 15668], + [0, 1272, 2539, 3871, 5146, 6471, 7791, 9069, 10360, 11688, 12968, 14262, 15580], + [0, 1322, 2642, 3933, 5229, 6538, 7823, 9126, 10432, 11734, 13089, 14372, 15678], + [0, 1310, 2658, 3987, 5316, 6608, 7878, 9171, 10463, 11757, 13060, 14356, 15660], + [0, 1318, 2640, 3924, 5237, 6546, 7832, 9138, 10462, 11762, 13046, 14341, 15609], + [0, 1280, 2558, 3850, 5191, 6495, 7820, 9113, 10401, 11717, 13040, 14314, 15614], + [0, 1313, 2596, 3908, 5249, 6542, 7843, 9141, 10456, 11739, 13039, 14348, 15699], [0, 1309], [0, 1400, 2689], - [ - 0, - 1362, - 2646, - 3947, - 5228, - 6517, - 7824, - 9116, - 10402, - 11683, - 12976, - 14271, - 15583, - ], - [ - 0, - 1303, - 2653, - 3937, - 5234, - 6541, - 7861, - 9224, - 10606, - 11897, - 13213, - 14544, - 15851, - ], - [ - 0, - 1309, - 2636, - 3924, - 5216, - 6500, - 7775, - 9085, - 10380, - 11696, - 12999, - 14337, - 15613, - ], + [0, 1362, 2646, 3947, 5228, 6517, 7824, 9116, 10402, 11683, 12976, 14271, 15583], + [0, 1303, 
2653, 3937, 5234, 6541, 7861, 9224, 10606, 11897, 13213, 14544, 15851], + [0, 1309, 2636, 3924, 5216, 6500, 7775, 9085, 10380, 11696, 12999, 14337, 15613], [0, 1310, 2611, 3904, 5238, 6532, 7804, 9100, 10408, 11707, 13011], - [ - 0, - 1313, - 2646, - 3956, - 5263, - 6587, - 7949, - 9257, - 10555, - 11837, - 13104, - 14394, - 15724, - ], - [ - 0, - 1321, - 2612, - 3915, - 5231, - 6551, - 7838, - 9128, - 10440, - 11759, - 13099, - 14416, - 15700, - ], - [ - 0, - 1283, - 2592, - 3872, - 5194, - 6467, - 7751, - 9040, - 10321, - 11673, - 13010, - 14304, - 15602, - ], + [0, 1313, 2646, 3956, 5263, 6587, 7949, 9257, 10555, 11837, 13104, 14394, 15724], + [0, 1321, 2612, 3915, 5231, 6551, 7838, 9128, 10440, 11759, 13099, 14416, 15700], + [0, 1283, 2592, 3872, 5194, 6467, 7751, 9040, 10321, 11673, 13010, 14304, 15602], [0, 1270, 2622, 3915, 5193, 6478, 7776, 9085, 10430, 11732, 13033, 14338], [0, 1296, 2631, 3955], - [ - 0, - 1315, - 2622, - 3949, - 5243, - 6592, - 7894, - 9216, - 10533, - 11830, - 13123, - 14419, - 15722, - ], - [ - 0, - 1296, - 2590, - 3913, - 5221, - 6504, - 7778, - 9125, - 10426, - 11782, - 13051, - 14328, - 15637, - ], - [ - 0, - 1294, - 2579, - 3886, - 5160, - 6456, - 7746, - 9047, - 10347, - 11638, - 12962, - 14261, - 15550, - ], + [0, 1315, 2622, 3949, 5243, 6592, 7894, 9216, 10533, 11830, 13123, 14419, 15722], + [0, 1296, 2590, 3913, 5221, 6504, 7778, 9125, 10426, 11782, 13051, 14328, 15637], + [0, 1294, 2579, 3886, 5160, 6456, 7746, 9047, 10347, 11638, 12962, 14261, 15550], [0, 7], - [ - 0, - 1298, - 2599, - 3887, - 5201, - 6506, - 7843, - 9158, - 10456, - 11749, - 13058, - 14337, - 15630, - ], - [ - 0, - 1290, - 2598, - 3876, - 5177, - 6473, - 7790, - 9065, - 10362, - 11640, - 12943, - 14287, - 15582, - ], - [ - 0, - 1333, - 2623, - 3903, - 5189, - 6467, - 7759, - 9063, - 10388, - 11729, - 13022, - 14310, - 15626, - ], - [ - 0, - 1322, - 2615, - 3921, - 5206, - 6491, - 7811, - 9109, - 10394, - 11691, - 12969, - 14256, - 15532, - ], - [ - 0, - 1302, - 2610, - 3942, - 5267, - 6545, - 7859, - 9154, - 10460, - 11733, - 13053, - 14326, - 15661, - ], + [0, 1298, 2599, 3887, 5201, 6506, 7843, 9158, 10456, 11749, 13058, 14337, 15630], + [0, 1290, 2598, 3876, 5177, 6473, 7790, 9065, 10362, 11640, 12943, 14287, 15582], + [0, 1333, 2623, 3903, 5189, 6467, 7759, 9063, 10388, 11729, 13022, 14310, 15626], + [0, 1322, 2615, 3921, 5206, 6491, 7811, 9109, 10394, 11691, 12969, 14256, 15532], + [0, 1302, 2610, 3942, 5267, 6545, 7859, 9154, 10460, 11733, 13053, 14326, 15661], [0, 1289, 2616], - [ - 0, - 1291, - 2640, - 3932, - 5229, - 6547, - 7903, - 9205, - 10547, - 11857, - 13171, - 14484, - 15771, - ], + [0, 1291, 2640, 3932, 5229, 6547, 7903, 9205, 10547, 11857, 13171, 14484, 15771], [0, 1240], - [ - 0, - 1289, - 2665, - 3954, - 5276, - 6576, - 7883, - 9167, - 10535, - 11868, - 13215, - 14548, - 15862, - ], - [ - 0, - 1299, - 2606, - 3913, - 5223, - 6514, - 7793, - 9097, - 10381, - 11652, - 12936, - 14228, - 15513, - ], - [ - 0, - 1334, - 2615, - 3932, - 5214, - 6511, - 7818, - 9109, - 10403, - 11701, - 13036, - 14306, - 15648, - ], - [ - 0, - 1315, - 2613, - 3889, - 5215, - 6490, - 7799, - 9110, - 10407, - 11684, - 13016, - 14333, - 15639, - ], - [ - 0, - 1304, - 2591, - 3907, - 5275, - 6563, - 7887, - 9203, - 10539, - 11836, - 13169, - 14459, - 15745, - ], - [ - 0, - 1279, - 2548, - 3860, - 5216, - 6529, - 7833, - 9102, - 10400, - 11697, - 13002, - 14313, - 15638, - ], - [ - 0, - 1284, - 2569, - 3861, - 5165, - 6452, - 7768, - 9056, - 10424, - 11748, - 13064, - 14361, - 15697, - ], + 
[0, 1289, 2665, 3954, 5276, 6576, 7883, 9167, 10535, 11868, 13215, 14548, 15862], + [0, 1299, 2606, 3913, 5223, 6514, 7793, 9097, 10381, 11652, 12936, 14228, 15513], + [0, 1334, 2615, 3932, 5214, 6511, 7818, 9109, 10403, 11701, 13036, 14306, 15648], + [0, 1315, 2613, 3889, 5215, 6490, 7799, 9110, 10407, 11684, 13016, 14333, 15639], + [0, 1304, 2591, 3907, 5275, 6563, 7887, 9203, 10539, 11836, 13169, 14459, 15745], + [0, 1279, 2548, 3860, 5216, 6529, 7833, 9102, 10400, 11697, 13002, 14313, 15638], + [0, 1284, 2569, 3861, 5165, 6452, 7768, 9056, 10424, 11748, 13064, 14361, 15697], [0, 1302, 2600], [0, 1289, 2586], [0, 1287, 2577, 3855], diff --git a/tests/jit_utils.py b/tests/jit_utils.py index 6a462bc366..bfbab2e77d 100644 --- a/tests/jit_utils.py +++ b/tests/jit_utils.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2023 by FlashInfer team. @@ -13,11 +19,8 @@ See the License for the specific language governing permissions and limitations under the License. """ - import itertools -import torch - import flashinfer from flashinfer.jit import JitSpec from flashinfer.utils import is_fa3_backend_supported, is_sm90a_supported @@ -32,7 +35,6 @@ def gen_decode_attention_modules( use_logits_soft_cap_options, ) -> list[JitSpec]: jit_specs: list[JitSpec] = [] - for ( q_dtype, kv_dtype, @@ -49,16 +51,15 @@ def gen_decode_attention_modules( use_logits_soft_cap_options, ): if q_dtype != kv_dtype: - if kv_dtype.itemsize > 1: - continue # skip fp16/bf16 mixed precision - + if kv_dtype.element_size() > 1: + continue jit_specs.append( flashinfer.decode.gen_single_decode_module( q_dtype, kv_dtype, q_dtype, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim, + head_dim, pos_encoding_mode, use_sliding_window, use_logits_soft_cap, @@ -69,55 +70,40 @@ def gen_decode_attention_modules( q_dtype, kv_dtype, q_dtype, - torch.int32, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + "int32", + head_dim, + head_dim, pos_encoding_mode, use_sliding_window, use_logits_soft_cap, ) ) - return jit_specs def gen_persistent_batch_attention_modules( - q_dtypes, - kv_dtypes, - head_dims, - use_logits_soft_cap_options, + q_dtypes, kv_dtypes, head_dims, use_logits_soft_cap_options ) -> list[JitSpec]: jit_specs: list[JitSpec] = [] - - for ( - q_dtype, - kv_dtype, - head_dim, - use_logits_soft_cap, - ) in itertools.product( - q_dtypes, - kv_dtypes, - head_dims, - use_logits_soft_cap_options, + for q_dtype, kv_dtype, head_dim, use_logits_soft_cap in itertools.product( + q_dtypes, kv_dtypes, head_dims, use_logits_soft_cap_options ): if q_dtype != kv_dtype: - if kv_dtype.itemsize > 1: - continue # skip fp16/bf16 mixed precision - + if kv_dtype.element_size() > 1: + continue jit_specs.append( flashinfer.attention.gen_batch_attention_module( q_dtype, kv_dtype, q_dtype, - torch.int32, - head_dim, # head_dim_qk - head_dim, # head_dim_vo - 0, # pos_encoding_mode + "int32", + head_dim, + head_dim, + 0, use_logits_soft_cap, - False, # use_profiler + False, ) ) - return jit_specs @@ -131,7 +117,6 @@ def gen_prefill_attention_modules( use_fp16_qk_reduction_options, ) -> list[JitSpec]: jit_specs: list[JitSpec] = [] - for ( q_dtype, kv_dtype, @@ -150,10 +135,9 @@ def gen_prefill_attention_modules( use_fp16_qk_reduction_options, ): if q_dtype != kv_dtype: - if kv_dtype.itemsize > 1: - continue # skip fp16/bf16 mixed precision - - if is_sm90a_supported(torch.device("cuda")) and is_fa3_backend_supported( + if kv_dtype.element_size() > 1: + continue + if 
is_sm90a_supported(device2str("cuda")) and is_fa3_backend_supported( pos_encoding_mode, use_fp16_qk_reduction, use_custom_mask=False, @@ -166,24 +150,23 @@ def gen_prefill_attention_modules( q_dtype, kv_dtype, q_dtype, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim, + head_dim, pos_encoding_mode, use_sliding_window, use_logits_soft_cap, use_fp16_qk_reduction, ) ) - jit_specs.append( flashinfer.prefill.gen_batch_prefill_module( "fa3", q_dtype, kv_dtype, q_dtype, - torch.int32, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + "int32", + head_dim, + head_dim, pos_encoding_mode, use_sliding_window, use_logits_soft_cap, @@ -196,8 +179,8 @@ def gen_prefill_attention_modules( q_dtype, kv_dtype, q_dtype, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + head_dim, + head_dim, pos_encoding_mode, use_sliding_window, use_logits_soft_cap, @@ -210,19 +193,15 @@ def gen_prefill_attention_modules( q_dtype, kv_dtype, q_dtype, - torch.int32, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + "int32", + head_dim, + head_dim, pos_encoding_mode, use_sliding_window, use_logits_soft_cap, use_fp16_qk_reduction, ) ) - - # required for attention with custom mask jit_specs.append(flashinfer.quantization.gen_quantization_module()) - jit_specs.append(flashinfer.page.gen_page_module()) - return jit_specs diff --git a/tests/rope_reference.py b/tests/rope_reference.py index 82df2e8489..0d3c497598 100644 --- a/tests/rope_reference.py +++ b/tests/rope_reference.py @@ -1,25 +1,18 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. +import sys + import math from typing import Optional, Tuple, Union -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
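Note on the jit_utils hunks above: the `kv_dtype.element_size() > 1: continue` filter keeps a mismatched (q_dtype, kv_dtype) pair only when the kv-cache dtype is one byte wide (fp8); the original "skip fp16/bf16 mixed precision" comment was dropped in the rewrite. Since this port now passes dtypes as strings, the same intent can be expressed against an explicit size table; a minimal sketch under that assumption (the table entries and fp8 dtype names are illustrative, not taken from the patch):

    # Sketch only: reproduce the "skip fp16/bf16 mixed precision" filter for string dtypes.
    _DTYPE_NBYTES = {"float16": 2, "bfloat16": 2, "float8_e4m3fn": 1, "float8_e5m2": 1}

    def keep_dtype_pair(q_dtype: str, kv_dtype: str) -> bool:
        # matched dtypes always pass; mismatched pairs pass only for 1-byte (fp8) kv caches
        return q_dtype == kv_dtype or _DTYPE_NBYTES.get(kv_dtype, 2) == 1
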
-import torch +import paddle +from flashinfer.paddle_utils import * -def apply_scaling(freqs: torch.Tensor): - # Values obtained from grid search +def apply_scaling(freqs: paddle.Tensor): scale_factor = 8 low_freq_factor = 1 high_freq_factor = 4 - old_context_len = 8192 # original llama3 length - + old_context_len = 8192 low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor new_freqs = [] @@ -35,7 +28,7 @@ def apply_scaling(freqs: torch.Tensor): high_freq_factor - low_freq_factor ) new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) - return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + return paddle.to_tensor(data=new_freqs, dtype=freqs.dtype, place=freqs.place) def precompute_freqs_cis( @@ -45,76 +38,78 @@ def precompute_freqs_cis( use_scaled: bool = False, device: str = "cuda:0", ): - freqs = 1.0 / ( - theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim) + freqs = 1.0 / theta ** ( + paddle.arange(start=0, end=dim, step=2)[: dim // 2].astype(dtype="float32") + / dim ) - t = torch.arange(end, device=freqs.device, dtype=torch.float32) + t = paddle.arange(dtype="float32", end=end) if use_scaled: freqs = apply_scaling(freqs) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + freqs = paddle.outer(x=t, y=freqs) + freqs_cis = paddle.polar(abs=paddle.ones_like(x=freqs), angle=freqs) return freqs_cis -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): +def reshape_for_broadcast(freqs_cis: paddle.Tensor, x: paddle.Tensor): ndim = x.ndim assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + assert tuple(freqs_cis.shape) == (tuple(x.shape)[1], tuple(x.shape)[-1]) + shape = [ + (d if i == 1 or i == ndim - 1 else 1) for i, d in enumerate(tuple(x.shape)) + ] return freqs_cis.view(*shape) def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq: paddle.Tensor, xk: paddle.Tensor, freqs_cis: paddle.Tensor +) -> Tuple[paddle.Tensor, paddle.Tensor]: + xq_ = paddle.as_complex( + x=xq.astype(dtype="float32").reshape(*tuple(xq.shape)[:-1], -1, 2) + ) + xk_ = paddle.as_complex( + x=xk.astype(dtype="float32").reshape(*tuple(xk.shape)[:-1], -1, 2) + ) freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) + xq_out = paddle.as_real(x=xq_ * freqs_cis).flatten(start_axis=3) + xk_out = paddle.as_real(x=xk_ * freqs_cis).flatten(start_axis=3) + return xq_out.astype(dtype=xq.dtype), xk_out.astype(dtype=xk.dtype) def apply_rotary_pos_emb(q, k, cos, sin): - cos = cos.unsqueeze(1) - sin = sin.unsqueeze(1) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) + cos = cos.unsqueeze(axis=1) + sin = sin.unsqueeze(axis=1) + q_embed = q * cos + rotate_half(q) * sin + k_embed = k * cos + rotate_half(k) * sin return q_embed.to(q.dtype), k_embed.to(k.dtype) def rotate_half(x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) + x1 = x[..., : tuple(x.shape)[-1] // 2] + x2 = x[..., 
tuple(x.shape)[-1] // 2 :] + return paddle.concat(x=(-x2, x1), axis=-1) def generate_cos_sin_f32_cache( - max_seq_len, head_dim, theta=1e4, use_scaled: bool = False, device: str = "cuda:0" + max_seq_len, + head_dim, + theta=10000.0, + use_scaled: bool = False, + device: str = "cuda:0", ): - position = torch.arange(max_seq_len, device=device, dtype=torch.float32).unsqueeze( - 1 - ) - freqs = 1.0 / ( - theta - ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim) + position = paddle.arange(dtype="float32", end=max_seq_len).unsqueeze(axis=1) + freqs = 1.0 / theta ** ( + paddle.arange(start=0, end=head_dim, step=2, dtype="float32") / head_dim ) - freqs = torch.cat([freqs, freqs], dim=-1).contiguous() + freqs = paddle.concat(x=[freqs, freqs], axis=-1).contiguous() if use_scaled: freqs = apply_scaling(freqs) args = position * freqs - sin_cache = torch.sin(args) - cos_cache = torch.cos(args) + sin_cache = paddle.sin(x=args) + cos_cache = paddle.cos(x=args) return cos_cache, sin_cache -# The following code is from the vLLM's implementation of RoPE. -# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py - - -class RotaryEmbedding(torch.nn.Module): +class RotaryEmbedding(paddle.nn.Layer): def __init__( self, head_size: int, @@ -122,7 +117,7 @@ def __init__( max_position_embeddings: int, base: int, is_neox_style: bool, - dtype: torch.dtype, + dtype: paddle.dtype, device: str = "cuda:0", ) -> None: super().__init__() @@ -134,41 +129,33 @@ def __init__( self.dtype = dtype self.device = device cache = self._compute_cos_sin_cache() - self.cos_sin_cache: torch.Tensor - self.register_buffer("cos_sin_cache", cache, persistent=False) - - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: - inv_freq = 1.0 / ( - base - ** ( - torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float, device=self.device - ) - / self.rotary_dim - ) + self.cos_sin_cache: paddle.Tensor + self.register_buffer(name="cos_sin_cache", tensor=cache, persistable=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> paddle.Tensor: + inv_freq = 1.0 / base ** ( + paddle.arange(start=0, end=self.rotary_dim, step=2, dtype="float32") + / self.rotary_dim ) return inv_freq - def _compute_cos_sin_cache(self) -> torch.Tensor: + def _compute_cos_sin_cache(self) -> paddle.Tensor: """Compute the cos and sin cache.""" inv_freq = self._compute_inv_freq(self.base) - t = torch.arange( - self.max_position_embeddings, dtype=torch.float, device=self.device - ) - - freqs = torch.einsum("i,j -> ij", t, inv_freq) + t = paddle.arange(dtype="float32", end=self.max_position_embeddings) + freqs = paddle.einsum("i,j -> ij", t, inv_freq) cos = freqs.cos() sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) + cache = paddle.concat(x=(cos, sin), axis=-1) return cache def _apply_rotary_emb( self, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, + x: paddle.Tensor, + cos: paddle.Tensor, + sin: paddle.Tensor, is_neox_style: bool, - ) -> torch.Tensor: + ) -> paddle.Tensor: """ Args: x: [num_tokens, num_heads, head_size] @@ -177,56 +164,48 @@ def _apply_rotary_emb( is_neox_style: Whether to use the Neox-style or GPT-J-style rotary positional embeddings. 
""" - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) + cos = cos.unsqueeze(axis=-2).to(x.dtype) + sin = sin.unsqueeze(axis=-2).to(x.dtype) if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) + x1, x2 = paddle.chunk(x=x, chunks=2, axis=-1) else: x1 = x[..., ::2] x2 = x[..., 1::2] o1 = x1 * cos - x2 * sin o2 = x2 * cos + x1 * sin if is_neox_style: - return torch.cat((o1, o2), dim=-1) + return paddle.concat(x=(o1, o2), axis=-1) else: - return torch.stack((o1, o2), dim=-1).flatten(-2) + return paddle.stack(x=(o1, o2), axis=-1).flatten(start_axis=-2) def forward_native( self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + positions: paddle.Tensor, + query: paddle.Tensor, + key: paddle.Tensor, + offsets: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: """A PyTorch-native implementation of forward().""" if offsets is not None: positions = positions + offsets - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - - # Note: the is different from the vLLM's implementation, - # We added float32 conversion because float32 is required for the rotary embedding to work correctly for long contexts - query = query.to(torch.float32) - key = key.to(torch.float32) - - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape + num_tokens = tuple(positions.shape)[0] + cos_sin = self.cos_sin_cache.index_select(axis=0, index=positions) + query = query.to("float32") + key = key.to("float32") + cos, sin = cos_sin.chunk(chunks=2, axis=-1) + query_shape = tuple(query.shape) query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] query_rot = self._apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) - query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - key_shape = key.shape + query = paddle.concat(x=(query_rot, query_pass), axis=-1).reshape(query_shape) + key_shape = tuple(key.shape) key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] key_rot = self._apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - + key = paddle.concat(x=(key_rot, key_pass), axis=-1).reshape(key_shape) query = query.to(self.dtype) key = key.to(self.dtype) return query, key diff --git a/tests/sink_attention_reference.py b/tests/sink_attention_reference.py index e26707c157..caba2cd923 100644 --- a/tests/sink_attention_reference.py +++ b/tests/sink_attention_reference.py @@ -1,3 +1,10 @@ +import sys + + +import einops +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,42 +20,29 @@ See the License for the specific language governing permissions and limitations under the License. """ - from typing import Optional -import einops -import torch - def sink_softmax(logits, sink): - sink = einops.repeat(sink, "h -> b h m 1", b=logits.shape[0], m=logits.shape[2]) - # (b, h, m, (n + 1)) - logits = torch.cat([logits, sink], dim=-1) - # (s_1, s_2, ..., s_n) - # (s_1, s_2, ..., s_n, log(sink)) - # (exp(s_1), exp(s_2), ..., exp(s_n), sink) - # (exp(s_1) / (exp(s_1) + exp(s_2) + ... + exp(s_n) + sink), - # exp(s_2) / (exp(s_1) + exp(s_2) + ... + exp(s_n) + sink), - # ..., - # exp(s_n) / (exp(s_1) + exp(s_2) + ... 
+ exp(s_n) + sink)) - # sink / (exp(s_1) + exp(s_2) + ... + exp(s_n) + sink) - score = torch.softmax(logits, dim=-1)[..., :-1].contiguous() + sink = einops.tile(repeat_times=[sink, "h -> b h m 1"]) + logits = paddle.concat(x=[logits, sink], axis=-1) + score = paddle.nn.functional.softmax(x=logits, axis=-1)[..., :-1].contiguous() return score def sink_attention_unified( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - sink: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + sink: paddle.Tensor, window_left: int, causal: bool, sm_scale: float, batch_size: Optional[int] = None, mode: str = "auto", - qo_indptr: Optional[torch.Tensor] = None, - kv_indptr: Optional[torch.Tensor] = None, -) -> torch.Tensor: + qo_indptr: Optional[paddle.Tensor] = None, + kv_indptr: Optional[paddle.Tensor] = None, +) -> paddle.Tensor: """ Unified sink attention implementation supporting prefill, incremental, chunk prefill, and variable-length scenarios. @@ -89,30 +83,22 @@ def sink_attention_unified( - Chunk Prefill: [total_q_len, num_qo_heads, head_dim] - Variable Length: [total_q_len, num_qo_heads, head_dim] """ - - # Auto-detect mode if not specified if mode == "auto": - # Check if variable length mode is indicated by presence of indptr if qo_indptr is not None or kv_indptr is not None: mode = "varlen" - elif len(q.shape) == 3 and len(k.shape) == 4: - # q: [batch_size, num_heads, head_dim], k: [batch_size, kv_len, num_heads, head_dim] - # This is incremental mode + elif len(tuple(q.shape)) == 3 and len(tuple(k.shape)) == 4: mode = "incremental" - elif len(q.shape) == 3 and len(k.shape) == 3: - # Both q and k are flattened: [total_len, num_heads, head_dim] + elif len(tuple(q.shape)) == 3 and len(tuple(k.shape)) == 3: if batch_size is None: raise ValueError( "batch_size is required for auto-detection in prefill/chunk modes" ) - - qo_len = q.shape[0] // batch_size - kv_len = k.shape[0] // batch_size - + qo_len = tuple(q.shape)[0] // batch_size + kv_len = tuple(k.shape)[0] // batch_size if qo_len == kv_len: mode = "prefill" elif qo_len == 1: - mode = "incremental" # Special case: single token with flattened format + mode = "incremental" elif qo_len > 1 and qo_len != kv_len: mode = "chunk" else: @@ -121,282 +107,184 @@ def sink_attention_unified( ) else: raise ValueError( - f"Cannot auto-detect mode from tensor shapes: q={q.shape}, k={k.shape}" + f"Cannot auto-detect mode from tensor shapes: q={tuple(q.shape)}, k={tuple(k.shape)}" ) - - # Process based on detected/specified mode if mode == "incremental": - # Incremental generation mode: q_len=1, kv_len from cache - batch_size = q.shape[0] + batch_size = tuple(q.shape)[0] qo_len = 1 - kv_len = k.shape[1] - num_qo_heads = q.shape[1] - num_kv_heads = k.shape[2] - - # Handle GQA + kv_len = tuple(k.shape)[1] + num_qo_heads = tuple(q.shape)[1] + num_kv_heads = tuple(k.shape)[2] if num_qo_heads != num_kv_heads: - k = torch.repeat_interleave( - k, num_qo_heads // num_kv_heads, dim=2 + k = paddle.repeat_interleave( + x=k, repeats=num_qo_heads // num_kv_heads, axis=2 ).contiguous() - v = torch.repeat_interleave( - v, num_qo_heads // num_kv_heads, dim=2 + v = paddle.repeat_interleave( + x=v, repeats=num_qo_heads // num_kv_heads, axis=2 ).contiguous() num_kv_heads = num_qo_heads - - head_dim_qk = q.shape[2] - head_dim_vo = v.shape[3] - - # Compute logits: [batch_size, num_heads, 1, kv_len] + head_dim_qk = tuple(q.shape)[2] + head_dim_vo = tuple(v.shape)[3] logits = ( - torch.einsum( - "bhd,blhd->bhl", - q.float(), - k.float(), - ).unsqueeze(2) 
# Add seq_len=1 dimension + paddle.einsum( + "bhd,blhd->bhl", q.astype(dtype="float32"), k.astype(dtype="float32") + ).unsqueeze(axis=2) * sm_scale ) - elif mode in ["prefill", "chunk"]: - # Prefill or Chunk prefill mode: q and k are flattened tensors if batch_size is None: raise ValueError(f"batch_size is required for {mode} mode") - - qo_len = q.shape[0] // batch_size - kv_len = k.shape[0] // batch_size - num_qo_heads = q.shape[1] - num_kv_heads = k.shape[1] - - # Handle GQA + qo_len = tuple(q.shape)[0] // batch_size + kv_len = tuple(k.shape)[0] // batch_size + num_qo_heads = tuple(q.shape)[1] + num_kv_heads = tuple(k.shape)[1] if num_qo_heads != num_kv_heads: - k = torch.repeat_interleave( - k, num_qo_heads // num_kv_heads, dim=1 + k = paddle.repeat_interleave( + x=k, repeats=num_qo_heads // num_kv_heads, axis=1 ).contiguous() - v = torch.repeat_interleave( - v, num_qo_heads // num_kv_heads, dim=1 + v = paddle.repeat_interleave( + x=v, repeats=num_qo_heads // num_kv_heads, axis=1 ).contiguous() - - head_dim_qk = q.shape[2] - head_dim_vo = v.shape[2] - - # Compute logits: [batch_size, num_heads, qo_len, kv_len] + head_dim_qk = tuple(q.shape)[2] + head_dim_vo = tuple(v.shape)[2] logits = ( - torch.einsum( + paddle.einsum( "bmhd,bnhd->bhmn", - q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(), - k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(), + q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).astype( + dtype="float32" + ), + k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).astype( + dtype="float32" + ), ) * sm_scale ) - elif mode == "varlen": - # Variable length sequences mode if qo_indptr is None or kv_indptr is None: raise ValueError("qo_indptr and kv_indptr are required for varlen mode") - - batch_size = qo_indptr.shape[0] - 1 - num_qo_heads = q.shape[1] - num_kv_heads = k.shape[1] - head_dim_qk = q.shape[2] - head_dim_vo = v.shape[2] - - # Handle GQA + batch_size = tuple(qo_indptr.shape)[0] - 1 + num_qo_heads = tuple(q.shape)[1] + num_kv_heads = tuple(k.shape)[1] + head_dim_qk = tuple(q.shape)[2] + head_dim_vo = tuple(v.shape)[2] if num_qo_heads != num_kv_heads: - k = torch.repeat_interleave( - k, num_qo_heads // num_kv_heads, dim=1 + k = paddle.repeat_interleave( + x=k, repeats=num_qo_heads // num_kv_heads, axis=1 ).contiguous() - v = torch.repeat_interleave( - v, num_qo_heads // num_kv_heads, dim=1 + v = paddle.repeat_interleave( + x=v, repeats=num_qo_heads // num_kv_heads, axis=1 ).contiguous() num_kv_heads = num_qo_heads - - # Process each request in the batch separately output_list = [] - for i in range(batch_size): - # Extract tensors for current request qo_start, qo_end = qo_indptr[i].item(), qo_indptr[i + 1].item() kv_start, kv_end = kv_indptr[i].item(), kv_indptr[i + 1].item() - - q_i = q[qo_start:qo_end] # [qo_len_i, num_heads, head_dim] - k_i = k[kv_start:kv_end] # [kv_len_i, num_heads, head_dim] - v_i = v[kv_start:kv_end] # [kv_len_i, num_heads, head_dim] - + q_i = q[qo_start:qo_end] + k_i = k[kv_start:kv_end] + v_i = v[kv_start:kv_end] qo_len_i = qo_end - qo_start kv_len_i = kv_end - kv_start - - # Compute logits for current request: [1, num_heads, qo_len_i, kv_len_i] logits_i = ( - torch.einsum( + paddle.einsum( "qhd,khd->hqk", - q_i.float(), - k_i.float(), - ).unsqueeze(0) # Add batch dimension + q_i.astype(dtype="float32"), + k_i.astype(dtype="float32"), + ).unsqueeze(axis=0) * sm_scale ) - - # Build attention mask for current request if causal: - # Create causal mask for this specific request - row_idx = torch.arange(qo_len_i, 
dtype=torch.int32, device=q.device)[ - :, None - ] - col_idx = torch.arange(kv_len_i, dtype=torch.int32, device=q.device)[ - None, : - ] - - # Default causal mask: position i can attend to positions 0 to i in the kv sequence - # Assuming queries correspond to the last qo_len_i positions in the kv sequence + row_idx = paddle.arange(dtype="int32", end=qo_len_i)[:, None] + col_idx = paddle.arange(dtype="int32", end=kv_len_i)[None, :] query_positions = kv_len_i - qo_len_i + row_idx mask_i = query_positions >= col_idx - if window_left >= 0: mask_i &= query_positions - window_left <= col_idx else: - # Non-causal mask - mask_i = torch.ones( - qo_len_i, kv_len_i, device=q.device, dtype=torch.bool - ) + mask_i = paddle.ones(shape=[qo_len_i, kv_len_i], dtype="bool") if window_left >= 0: - row_idx = torch.arange( - qo_len_i, dtype=torch.int32, device=q.device - )[:, None] - col_idx = torch.arange( - kv_len_i, dtype=torch.int32, device=q.device - )[None, :] + row_idx = paddle.arange(dtype="int32", end=qo_len_i)[:, None] + col_idx = paddle.arange(dtype="int32", end=kv_len_i)[None, :] query_positions = kv_len_i - qo_len_i + row_idx mask_i = query_positions - window_left <= col_idx - - # Apply mask logits_i = logits_i.masked_fill( - mask_i.unsqueeze(0).unsqueeze(0) == 0, float("-inf") + mask=mask_i.unsqueeze(axis=0).unsqueeze(axis=0) == 0, + value=float("-inf"), ) - - # Apply sink softmax - p_i = sink_softmax(logits_i, sink) # [1, num_heads, qo_len_i, kv_len_i] - - # Compute output for current request + p_i = sink_softmax(logits_i, sink) o_i = ( - torch.einsum( - "bhmn,nhd->bmhd", - p_i, # [1, num_heads, qo_len_i, kv_len_i] - v_i.float(), # [kv_len_i, num_heads, head_dim] - ) + paddle.einsum("bhmn,nhd->bmhd", p_i, v_i.astype(dtype="float32")) .contiguous() .view(qo_len_i, num_qo_heads, head_dim_vo) .to(q) ) - output_list.append(o_i) - - # Concatenate outputs from all requests - o_ref = torch.cat(output_list, dim=0) - + o_ref = paddle.concat(x=output_list, axis=0) return o_ref - else: raise ValueError( f"Unknown mode: {mode}. 
Supported modes: 'auto', 'prefill', 'incremental', 'chunk', 'varlen'" ) - - # Build attention mask (unified for all modes) if causal: if mode == "incremental": - # For incremental: new token can attend to all previous tokens - mask = torch.ones(1, kv_len, device=q.device, dtype=torch.bool) + mask = paddle.ones(shape=[1, kv_len], dtype="bool") if window_left >= 0: - col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device) - mask = (kv_len - 1 - window_left) <= col_idx + col_idx = paddle.arange(dtype="int32", end=kv_len) + mask = kv_len - 1 - window_left <= col_idx elif mode == "prefill": - # For regular prefill: standard causal mask - mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze( - 1 - ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0) + mask = paddle.arange(start=kv_len - qo_len, end=kv_len).unsqueeze( + axis=1 + ) >= paddle.arange(start=0, end=kv_len).unsqueeze(axis=0) if window_left >= 0: - row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[ - :, None - ] - col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[ - None, : - ] + row_idx = paddle.arange(dtype="int32", end=qo_len)[:, None] + col_idx = paddle.arange(dtype="int32", end=kv_len)[None, :] mask &= row_idx - window_left <= col_idx elif mode == "chunk": - # For chunk prefill: each query position can attend to all previous KV positions - # Current chunk positions are at the end: [kv_len - qo_len : kv_len] current_chunk_start = kv_len - qo_len - row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[ - :, None - ] # Positions within chunk - col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[ - None, : - ] # All KV positions - - # Each position can attend to: all historical + positions up to itself in current chunk - abs_row_positions = ( - current_chunk_start + row_idx - ) # Absolute positions in full sequence - mask = abs_row_positions >= col_idx # Standard causal mask - + row_idx = paddle.arange(dtype="int32", end=qo_len)[:, None] + col_idx = paddle.arange(dtype="int32", end=kv_len)[None, :] + abs_row_positions = current_chunk_start + row_idx + mask = abs_row_positions >= col_idx if window_left >= 0: mask &= abs_row_positions - window_left <= col_idx + elif mode == "incremental": + mask = paddle.ones(shape=[1, kv_len], dtype="bool") + if window_left >= 0: + col_idx = paddle.arange(dtype="int32", end=kv_len) + mask = kv_len - 1 - window_left <= col_idx else: - # Non-causal mask - if mode == "incremental": - mask = torch.ones(1, kv_len, device=q.device, dtype=torch.bool) - if window_left >= 0: - col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device) - mask = (kv_len - 1 - window_left) <= col_idx - else: # prefill or chunk - mask = torch.ones(qo_len, kv_len, device=q.device, dtype=torch.bool) - if window_left >= 0: - if mode == "chunk": - # For chunk mode, apply window relative to absolute positions - current_chunk_start = kv_len - qo_len - row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[ - :, None - ] - col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[ - None, : - ] - abs_row_positions = current_chunk_start + row_idx - mask = abs_row_positions - window_left <= col_idx - else: # prefill - row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[ - :, None - ] - col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)[ - None, : - ] - mask = row_idx - window_left <= col_idx - - # Apply mask - logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, 
float("-inf")) - - # Apply sink softmax + mask = paddle.ones(shape=[qo_len, kv_len], dtype="bool") + if window_left >= 0: + if mode == "chunk": + current_chunk_start = kv_len - qo_len + row_idx = paddle.arange(dtype="int32", end=qo_len)[:, None] + col_idx = paddle.arange(dtype="int32", end=kv_len)[None, :] + abs_row_positions = current_chunk_start + row_idx + mask = abs_row_positions - window_left <= col_idx + else: + row_idx = paddle.arange(dtype="int32", end=qo_len)[:, None] + col_idx = paddle.arange(dtype="int32", end=kv_len)[None, :] + mask = row_idx - window_left <= col_idx + logits = logits.masked_fill( + mask=mask.unsqueeze(axis=0).unsqueeze(axis=0) == 0, value=float("-inf") + ) p = sink_softmax(logits, sink) - - # Compute output if mode == "incremental": - # Incremental mode output o_ref = ( - torch.einsum( - "bhml,blhd->bhd", - p, # [batch_size, num_heads, 1, kv_len] - v.float(), # [batch_size, kv_len, num_heads, head_dim] - ) + paddle.einsum("bhml,blhd->bhd", p, v.astype(dtype="float32")) .contiguous() .to(q) ) - else: # prefill or chunk mode - # Prefill/Chunk mode output + else: o_ref = ( - torch.einsum( + paddle.einsum( "bhmn,bnhd->bmhd", - p, # [batch_size, num_heads, qo_len, kv_len] - v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(), + p, + v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).astype( + dtype="float32" + ), ) .contiguous() .view(batch_size * qo_len, num_qo_heads, head_dim_vo) .to(q) ) - return o_ref diff --git a/tests/test_activation.py b/tests/test_activation.py index 3854d7f576..bb67512cb5 100644 --- a/tests/test_activation.py +++ b/tests/test_activation.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,9 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch import flashinfer from flashinfer.utils import get_compute_capability @@ -39,13 +39,13 @@ def warmup_jit(): @pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) @pytest.mark.parametrize("enable_pdl", [True, False]) def test_fused_silu_mul(dim, batch_size, seq_len, enable_pdl): - x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) - major, _ = get_compute_capability(x.device) + x = paddle.randn(shape=[batch_size, seq_len, 2 * dim]).to(0).to("float16") + major, _ = get_compute_capability(x.place) if major < 9 and enable_pdl: pytest.skip("PDL is only available for Hopper and later GPUs") - y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim]) + y_ref = x[..., dim:] * paddle.nn.functional.silu(x=x[..., :dim]) y = flashinfer.activation.silu_and_mul(x, enable_pdl=enable_pdl) - torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=y_ref, y=y, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) @@ -53,13 +53,13 @@ def test_fused_silu_mul(dim, batch_size, seq_len, enable_pdl): @pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) @pytest.mark.parametrize("enable_pdl", [True, False]) def test_fused_gelu_tanh_mul(dim, batch_size, seq_len, enable_pdl): - x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) - major, _ = get_compute_capability(x.device) + x = paddle.randn(shape=[batch_size, seq_len, 2 * dim]).to(0).to("float16") + major, _ = get_compute_capability(x.place) if major < 9 and enable_pdl: pytest.skip("PDL is only available for Hopper and later GPUs") - y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh") + y_ref = x[..., dim:] * paddle.nn.functional.gelu(x=x[..., :dim], approximate=True) y = flashinfer.activation.gelu_tanh_and_mul(x, enable_pdl=enable_pdl) - torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=y_ref, y=y, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) @@ -67,13 +67,13 @@ def test_fused_gelu_tanh_mul(dim, batch_size, seq_len, enable_pdl): @pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) @pytest.mark.parametrize("enable_pdl", [True, False]) def test_fused_gelu_mul(dim, batch_size, seq_len, enable_pdl): - x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) - major, _ = get_compute_capability(x.device) + x = paddle.randn(shape=[batch_size, seq_len, 2 * dim]).to(0).to("float16") + major, _ = get_compute_capability(x.place) if major < 9 and enable_pdl: pytest.skip("PDL is only available for Hopper and later GPUs") - y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none") + y_ref = x[..., dim:] * paddle.nn.functional.gelu(x=x[..., :dim], approximate=False) y = flashinfer.activation.gelu_and_mul(x, enable_pdl=enable_pdl) - torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=y_ref, y=y, rtol=0.001, atol=0.001).item(), "" if __name__ == "__main__": diff --git a/tests/test_alibi.py b/tests/test_alibi.py index 06ea3a769f..2e8a6342e7 100644 --- a/tests/test_alibi.py +++ b/tests/test_alibi.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2023 by FlashInfer team. @@ -13,11 +15,10 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch from alibi_reference import alibi_attention -from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules +from jit_utils import (gen_decode_attention_modules, + gen_prefill_attention_modules) import flashinfer @@ -26,21 +27,10 @@ def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0, 2], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + ["float16"], ["float16"], [128, 256], [0, 2], [False], [False] ) + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0, 2], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions + ["float16"], ["float16"], [128, 256], [0, 2], [False], [False], [False] ), verbose=False, ) @@ -50,19 +40,14 @@ def warmup_jit(): @pytest.mark.parametrize("seq_len", [1, 9, 81, 729]) @pytest.mark.parametrize("num_heads", [4, 8, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) -def test_single_decode_alibi( - seq_len, - num_heads, - head_dim, -): - q = torch.randn(num_heads, head_dim, device="cuda:0", dtype=torch.float16) - k = torch.randn(seq_len, num_heads, head_dim, device="cuda:0", dtype=torch.float16) - v = torch.randn(seq_len, num_heads, head_dim, device="cuda:0", dtype=torch.float16) - +def test_single_decode_alibi(seq_len, num_heads, head_dim): + q = paddle.randn(shape=[num_heads, head_dim], dtype="float16") + k = paddle.randn(shape=[seq_len, num_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[seq_len, num_heads, head_dim], dtype="float16") o = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ALIBI") - mask = torch.ones(1, seq_len, dtype=torch.bool, device="cuda:0") - o_ref = alibi_attention(q.unsqueeze(0), k, v, mask).squeeze(0) - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + mask = paddle.ones(shape=[1, seq_len], dtype="bool") + o_ref = alibi_attention(q.unsqueeze(axis=0), k, v, mask).squeeze(0) + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("q_len", [1, 17, 81, 987]) @@ -70,27 +55,20 @@ def test_single_decode_alibi( @pytest.mark.parametrize("num_heads", [4, 8, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("causal", [False, True]) -def test_single_prefill_alibi( - q_len, - kv_len, - num_heads, - head_dim, - causal, -): +def test_single_prefill_alibi(q_len, kv_len, num_heads, head_dim, causal): if causal and q_len > kv_len: pytest.skip("Causal attention requires q_len <= kv_len") - q = torch.randn(q_len, num_heads, head_dim, device="cuda:0", dtype=torch.float16) - k = torch.randn(kv_len, num_heads, head_dim, device="cuda:0", dtype=torch.float16) - v = torch.randn(kv_len, num_heads, head_dim, device="cuda:0", dtype=torch.float16) - + q = paddle.randn(shape=[q_len, num_heads, head_dim], dtype="float16") + k = paddle.randn(shape=[kv_len, num_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[kv_len, num_heads, head_dim], dtype="float16") o = flashinfer.single_prefill_with_kv_cache( q, k, v, causal=causal, pos_encoding_mode="ALIBI" ) - mask = torch.ones(q_len, kv_len, dtype=torch.bool, device="cuda:0") + mask = paddle.ones(shape=[q_len, kv_len], dtype="bool") if causal: - mask = torch.tril(mask, diagonal=kv_len - q_len) + mask = paddle.tril(x=mask, diagonal=kv_len - q_len) o_ref = 
alibi_attention(q, k, v, mask) - torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2) + assert paddle.allclose(x=o, y=o_ref, rtol=0.01, atol=0.01).item(), "" if __name__ == "__main__": diff --git a/tests/test_attention_sink.py b/tests/test_attention_sink.py index 254e4af231..277deb4caf 100644 --- a/tests/test_attention_sink.py +++ b/tests/test_attention_sink.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,24 +19,22 @@ See the License for the specific language governing permissions and limitations under the License. """ - import math import pytest -import torch from sink_attention_reference import sink_attention_unified import flashinfer -from flashinfer.jit.utils import filename_safe_dtype_map from flashinfer.jit.attention import gen_batch_prefill_attention_sink_module from flashinfer.jit.attention.variants import attention_sink_decl +from flashinfer.jit.utils import filename_safe_dtype_map from flashinfer.utils import is_sm90a_supported @pytest.fixture(autouse=True, scope="module") def warmup_jit(): jit_specs = [] - for dtype in [torch.float16, torch.bfloat16]: + for dtype in ["float16", "bfloat16"]: for backend in ["fa2", "fa3"]: for use_swa in [True, False]: for head_dim in [128]: @@ -40,29 +44,27 @@ def warmup_jit(): dtype_q=dtype, dtype_kv=dtype, dtype_o=dtype, - dtype_idx=torch.int32, + dtype_idx="int32", head_dim_qk=head_dim, head_dim_vo=head_dim, pos_encoding_mode=0, use_sliding_window=use_swa, ) ) - flashinfer.jit.build_jit_specs(jit_specs) yield -# Wrapper functions for backward compatibility def sink_attention_ref( batch_size: int, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - sink: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + sink: paddle.Tensor, window_left: int, causal: bool, sm_scale: float, -) -> torch.Tensor: +) -> paddle.Tensor: """Backward compatible wrapper for prefill mode.""" return sink_attention_unified( q, @@ -78,14 +80,14 @@ def sink_attention_ref( def sink_attention_incremental_ref( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - sink: torch.Tensor, + q: paddle.Tensor, + k_cache: paddle.Tensor, + v_cache: paddle.Tensor, + sink: paddle.Tensor, window_left: int, causal: bool, sm_scale: float, -) -> torch.Tensor: +) -> paddle.Tensor: """Backward compatible wrapper for incremental mode.""" return sink_attention_unified( q, k_cache, v_cache, sink, window_left, causal, sm_scale, mode="incremental" @@ -94,14 +96,14 @@ def sink_attention_incremental_ref( def sink_attention_chunk_ref( batch_size: int, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - sink: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + sink: paddle.Tensor, window_left: int, causal: bool, sm_scale: float, -) -> torch.Tensor: +) -> paddle.Tensor: """Wrapper for chunk prefill mode.""" return sink_attention_unified( q, @@ -117,16 +119,16 @@ def sink_attention_chunk_ref( def sink_attention_varlen_ref( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - sink: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + sink: paddle.Tensor, window_left: int, causal: bool, sm_scale: float, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, -) -> torch.Tensor: + qo_indptr: paddle.Tensor, + kv_indptr: paddle.Tensor, +) -> paddle.Tensor: """Wrapper for variable length sequences mode.""" return sink_attention_unified( q, @@ -142,7 +144,7 @@ def sink_attention_varlen_ref( ) 
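For context on the reference these tests compare against: sink_softmax (in sink_attention_reference.py above) appends one learned per-head sink logit as an extra column, normalizes over kv_len + 1 entries, and then drops that column, so each probability is exp(s_i) divided by the sum of exp(s_j) plus the exponentiated sink logit. A minimal paddle sketch of that step, assuming einops' paddle backend is available (as the file's einops import suggests) and using the broadcast pattern from the original torch reference:

    import einops
    import paddle

    def sink_softmax_sketch(logits: paddle.Tensor, sink: paddle.Tensor) -> paddle.Tensor:
        # logits: (batch, heads, q_len, kv_len); sink: (heads,)
        sink_col = einops.repeat(sink, "h -> b h m 1", b=logits.shape[0], m=logits.shape[2])
        extended = paddle.concat(x=[logits, sink_col], axis=-1)    # (b, h, m, kv_len + 1)
        probs = paddle.nn.functional.softmax(x=extended, axis=-1)  # denominator gains exp(sink)
        return probs[..., :-1]                                     # drop the sink column
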
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("batch_size", [1, 4, 16]) @pytest.mark.parametrize("seq_len", [1, 4, 16, 128]) @pytest.mark.parametrize("num_qo_heads", [32]) @@ -153,33 +155,29 @@ def sink_attention_varlen_ref( def test_attention_sink( dtype, batch_size, seq_len, num_qo_heads, num_kv_heads, window_left, causal, backend ): - torch.manual_seed(42) - device = torch.device("cuda:0") + paddle.seed(seed=42) + device = device2str("cuda:0") if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") jit_args = ( - f"batch_prefill_attention_sink_{filename_safe_dtype_map[dtype]}_swa_{window_left >= 0}_{backend}", # uri - dtype, # dtype_q - dtype, # dtype_kv - dtype, # dtype_o - torch.int32, # idtype - 128, # hidden_dim_qk - 128, # hidden_dim_vo - ["sink"], # additional_tensor_names - ["float"], # additional_tensor_dtypes - ["sm_scale"], # additional_scalar_names - ["double"], # additional_scalar_dtypes + f"batch_prefill_attention_sink_{filename_safe_dtype_map[dtype]}_swa_{window_left >= 0}_{backend}", + dtype, + dtype, + dtype, + "int32", + 128, + 128, + ["sink"], + ["float"], + ["sm_scale"], + ["double"], "AttentionSink", attention_sink_decl[backend], ) - jit_kwargs = { - "use_sliding_window": window_left >= 0, - } + jit_kwargs = {"use_sliding_window": window_left >= 0} sm_scale = 1.0 / math.sqrt(128) - torch.manual_seed(42) - float_workspace_buffer = torch.empty( - 128 * 1024 * 1024, dtype=torch.uint8, device=device - ) + paddle.seed(seed=42) + float_workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="uint8") wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", @@ -187,15 +185,13 @@ def test_attention_sink( jit_args=jit_args, jit_kwargs=jit_kwargs, ) - qo_indptr_host = torch.arange( - 0, batch_size * seq_len + 1, seq_len, dtype=torch.int32 + qo_indptr_host = paddle.arange( + start=0, end=batch_size * seq_len + 1, step=seq_len, dtype="int32" ) - kv_indptr_host = torch.arange( - 0, batch_size * seq_len + 1, seq_len, dtype=torch.int32 + kv_indptr_host = paddle.arange( + start=0, end=batch_size * seq_len + 1, step=seq_len, dtype="int32" ) - head_dim = 128 - wrapper.plan( qo_indptr_host, kv_indptr_host, @@ -207,38 +203,16 @@ def test_attention_sink( q_data_type=dtype, kv_data_type=dtype, ) - - q = torch.randn( - batch_size * seq_len, - num_qo_heads, - head_dim, - dtype=dtype, - device=device, - ) - k = torch.randn( - batch_size * seq_len, - num_kv_heads, - head_dim, - dtype=dtype, - device=device, - ) - v = torch.randn( - batch_size * seq_len, - num_kv_heads, - head_dim, - dtype=dtype, - device=device, - ) - - sink = torch.rand(num_qo_heads, device=device, dtype=torch.float32) * 5 - + q = paddle.randn(shape=[batch_size * seq_len, num_qo_heads, head_dim], dtype=dtype) + k = paddle.randn(shape=[batch_size * seq_len, num_kv_heads, head_dim], dtype=dtype) + v = paddle.randn(shape=[batch_size * seq_len, num_kv_heads, head_dim], dtype=dtype) + sink = paddle.rand(shape=num_qo_heads, dtype="float32") * 5 o = wrapper.run(q, k, v, sink, sm_scale) o_ref = sink_attention_ref(batch_size, q, k, v, sink, window_left, causal, sm_scale) - if dtype == torch.float16: - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" else: - torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2) - + 
assert paddle.allclose(x=o, y=o_ref, rtol=0.01, atol=0.01).item(), "" wrapper_paged = flashinfer.BatchPrefillWithPagedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", @@ -246,12 +220,10 @@ def test_attention_sink( jit_args=jit_args, jit_kwargs=jit_kwargs, ) - kv_indices_host = torch.arange( - 0, - batch_size * seq_len, - dtype=torch.int32, + kv_indices_host = paddle.arange(start=0, end=batch_size * seq_len, dtype="int32") + paged_kv_last_page_len_host = paddle.full( + shape=(batch_size,), fill_value=1, dtype="int32" ) - paged_kv_last_page_len_host = torch.full((batch_size,), 1, dtype=torch.int32) wrapper_paged.plan( qo_indptr_host, kv_indptr_host, @@ -268,43 +240,32 @@ def test_attention_sink( non_blocking=True, ) o_paged = wrapper_paged.run(q, (k, v), sink, sm_scale) - if dtype == torch.float16: - torch.testing.assert_close(o_paged, o_ref, rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose(x=o_paged, y=o_ref, rtol=0.001, atol=0.001).item(), "" else: - torch.testing.assert_close(o_paged, o_ref, rtol=1e-2, atol=1e-2) - - # Test with non-contiguous KV indices (production scenario) + assert paddle.allclose(x=o_paged, y=o_ref, rtol=0.01, atol=0.01).item(), "" total_pages = batch_size * seq_len - if total_pages > 1: # Only test fragmentation when we have multiple pages - # Create a fragmented page allocation pattern + if total_pages > 1: import random - random.seed(42 + total_pages) # Deterministic but varied seed - all_pages = list(range(0, total_pages * 2)) # Larger page pool + random.seed(42 + total_pages) + all_pages = list(range(0, total_pages * 2)) occupied_pages = set( random.sample(all_pages, min(total_pages, len(all_pages) // 2)) ) available_pages = [p for p in all_pages if p not in occupied_pages] - - # Allocate non-contiguous pages - kv_indices_fragmented = torch.tensor( - available_pages[:total_pages], dtype=torch.int32, device=device + kv_indices_fragmented = paddle.to_tensor( + data=available_pages[:total_pages], dtype="int32", place=device ) - - # Create new paged KV cache with larger capacity - k_paged_frag = torch.randn( - total_pages * 2, 1, num_kv_heads, head_dim, dtype=dtype, device=device + k_paged_frag = paddle.randn( + shape=[total_pages * 2, 1, num_kv_heads, head_dim], dtype=dtype ) - v_paged_frag = torch.randn( - total_pages * 2, 1, num_kv_heads, head_dim, dtype=dtype, device=device + v_paged_frag = paddle.randn( + shape=[total_pages * 2, 1, num_kv_heads, head_dim], dtype=dtype ) - - # Copy K,V data to fragmented pages for i, page_idx in enumerate(kv_indices_fragmented): k_paged_frag[page_idx, 0] = k[i] v_paged_frag[page_idx, 0] = v[i] - - # Test with fragmented indices wrapper_paged_frag = flashinfer.BatchPrefillWithPagedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", @@ -330,15 +291,17 @@ def test_attention_sink( o_paged_frag = wrapper_paged_frag.run( q, (k_paged_frag, v_paged_frag), sink, sm_scale ) - - # Verify fragmented result matches reference - if dtype == torch.float16: - torch.testing.assert_close(o_paged_frag, o_ref, rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose( + x=o_paged_frag, y=o_ref, rtol=0.001, atol=0.001 + ).item(), "" else: - torch.testing.assert_close(o_paged_frag, o_ref, rtol=1e-2, atol=1e-2) + assert paddle.allclose( + x=o_paged_frag, y=o_ref, rtol=0.01, atol=0.01 + ).item(), "" -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("batch_size", [1, 4, 16]) 
@pytest.mark.parametrize("initial_seq_len", [32, 128]) @pytest.mark.parametrize("num_generation_steps", [1, 2, 4]) @@ -362,22 +325,19 @@ def test_attention_sink_incremental_generation( Test incremental generation scenario: q_len=1, kv_len grows gradually Simulate the token-by-token generation process in real large model inference """ - torch.manual_seed(42) - device = torch.device("cuda:0") + paddle.seed(seed=42) + device = device2str("cuda:0") if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") head_dim = 128 sm_scale = 1.0 / math.sqrt(head_dim) - - torch.manual_seed(42) - - # Create JIT arguments + paddle.seed(seed=42) jit_args = ( f"batch_prefill_attention_sink_{filename_safe_dtype_map[dtype]}_swa_{window_left >= 0}_{backend}", dtype, dtype, dtype, - torch.int32, + "int32", head_dim, head_dim, ["sink"], @@ -387,56 +347,30 @@ def test_attention_sink_incremental_generation( "AttentionSink", attention_sink_decl[backend], ) - jit_kwargs = { - "use_sliding_window": window_left >= 0, - } - - float_workspace_buffer = torch.empty( - 128 * 1024 * 1024, dtype=torch.uint8, device=device - ) - - # Initialize KV cache - simulate state after prefill phase - k_cache = torch.randn( - batch_size, initial_seq_len, num_kv_heads, head_dim, dtype=dtype, device=device + jit_kwargs = {"use_sliding_window": window_left >= 0} + float_workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="uint8") + k_cache = paddle.randn( + shape=[batch_size, initial_seq_len, num_kv_heads, head_dim], dtype=dtype ) - v_cache = torch.randn( - batch_size, initial_seq_len, num_kv_heads, head_dim, dtype=dtype, device=device + v_cache = paddle.randn( + shape=[batch_size, initial_seq_len, num_kv_heads, head_dim], dtype=dtype ) - - sink = torch.rand(num_qo_heads, device=device, dtype=torch.float32) * 5 - + sink = paddle.rand(shape=num_qo_heads, dtype="float32") * 5 k_accumulated = v_accumulated = None - # Simulate incremental generation process for step in range(num_generation_steps): current_kv_len = initial_seq_len + step - - # Current generated new token (q_len=1) - q_new = torch.randn( - batch_size, num_qo_heads, head_dim, dtype=dtype, device=device - ) - - # K,V for newly generated token - k_new = torch.randn( - batch_size, 1, num_kv_heads, head_dim, dtype=dtype, device=device - ) - v_new = torch.randn( - batch_size, 1, num_kv_heads, head_dim, dtype=dtype, device=device - ) - - # Update KV cache + q_new = paddle.randn(shape=[batch_size, num_qo_heads, head_dim], dtype=dtype) + k_new = paddle.randn(shape=[batch_size, 1, num_kv_heads, head_dim], dtype=dtype) + v_new = paddle.randn(shape=[batch_size, 1, num_kv_heads, head_dim], dtype=dtype) if step == 0: k_cache_current = k_cache v_cache_current = v_cache else: - k_cache_current = torch.cat([k_cache, k_accumulated], dim=1) - v_cache_current = torch.cat([v_cache, v_accumulated], dim=1) - - # Calculate reference result + k_cache_current = paddle.concat(x=[k_cache, k_accumulated], axis=1) + v_cache_current = paddle.concat(x=[v_cache, v_accumulated], axis=1) o_ref = sink_attention_incremental_ref( q_new, k_cache_current, v_cache_current, sink, window_left, causal, sm_scale ) - - # Use flashinfer to calculate result (need format conversion to adapt to existing API) wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", @@ -444,15 +378,13 @@ def test_attention_sink_incremental_generation( jit_args=jit_args, jit_kwargs=jit_kwargs, ) - - # Set correct indptr: q_len=1 for each batch, 
kv_len=current_kv_len for each batch - qo_indptr_host = torch.arange( - 0, batch_size + 1, dtype=torch.int32 - ) # [0, 1, 2, ..., batch_size] - kv_indptr_host = torch.arange( - 0, batch_size * current_kv_len + 1, current_kv_len, dtype=torch.int32 + qo_indptr_host = paddle.arange(start=0, end=batch_size + 1, dtype="int32") + kv_indptr_host = paddle.arange( + start=0, + end=batch_size * current_kv_len + 1, + step=current_kv_len, + dtype="int32", ) - wrapper.plan( qo_indptr_host, kv_indptr_host, @@ -464,27 +396,18 @@ def test_attention_sink_incremental_generation( q_data_type=dtype, kv_data_type=dtype, ) - - # Convert to format expected by flashinfer [total_q_len, num_heads, head_dim] - q_flashinfer = q_new.view( - batch_size, num_qo_heads, head_dim - ) # [batch_size, num_heads, head_dim] + q_flashinfer = q_new.view(batch_size, num_qo_heads, head_dim) k_flashinfer = k_cache_current.view( batch_size * current_kv_len, num_kv_heads, head_dim ) v_flashinfer = v_cache_current.view( batch_size * current_kv_len, num_kv_heads, head_dim ) - o = wrapper.run(q_flashinfer, k_flashinfer, v_flashinfer, sink, sm_scale) - - # Verify results - if dtype == torch.float16: - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" else: - torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2) - - # Also test with BatchPrefillWithPagedKVCacheWrapper + assert paddle.allclose(x=o, y=o_ref, rtol=0.01, atol=0.01).item(), "" wrapper_paged = flashinfer.BatchPrefillWithPagedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", @@ -492,12 +415,12 @@ def test_attention_sink_incremental_generation( jit_args=jit_args, jit_kwargs=jit_kwargs, ) - kv_indices_host = torch.arange( - 0, - batch_size * current_kv_len, - dtype=torch.int32, + kv_indices_host = paddle.arange( + start=0, end=batch_size * current_kv_len, dtype="int32" + ) + paged_kv_last_page_len_host = paddle.full( + shape=(batch_size,), fill_value=1, dtype="int32" ) - paged_kv_last_page_len_host = torch.full((batch_size,), 1, dtype=torch.int32) wrapper_paged.plan( qo_indptr_host, kv_indptr_host, @@ -516,43 +439,34 @@ def test_attention_sink_incremental_generation( o_paged = wrapper_paged.run( q_flashinfer, (k_flashinfer, v_flashinfer), sink, sm_scale ) - if dtype == torch.float16: - torch.testing.assert_close(o_paged, o_ref, rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose( + x=o_paged, y=o_ref, rtol=0.001, atol=0.001 + ).item(), "" else: - torch.testing.assert_close(o_paged, o_ref, rtol=1e-2, atol=1e-2) - - # Test with non-contiguous KV indices for incremental generation + assert paddle.allclose(x=o_paged, y=o_ref, rtol=0.01, atol=0.01).item(), "" total_pages = batch_size * current_kv_len - if total_pages > 1: # Only test fragmentation when we have multiple pages - # Create fragmented page allocation pattern + if total_pages > 1: import random - random.seed(42 + step + current_kv_len) # Vary seed with step and length + random.seed(42 + step + current_kv_len) all_pages = list(range(0, total_pages * 2)) occupied_pages = set( random.sample(all_pages, min(total_pages, len(all_pages) // 2)) ) available_pages = [p for p in all_pages if p not in occupied_pages] - - # Allocate non-contiguous pages - kv_indices_fragmented = torch.tensor( - available_pages[:total_pages], dtype=torch.int32, device=device + kv_indices_fragmented = paddle.to_tensor( + data=available_pages[:total_pages], dtype="int32", place=device ) - - # Create fragmented paged KV 
cache - k_paged_frag = torch.randn( - total_pages * 2, 1, num_kv_heads, head_dim, dtype=dtype, device=device + k_paged_frag = paddle.randn( + shape=[total_pages * 2, 1, num_kv_heads, head_dim], dtype=dtype ) - v_paged_frag = torch.randn( - total_pages * 2, 1, num_kv_heads, head_dim, dtype=dtype, device=device + v_paged_frag = paddle.randn( + shape=[total_pages * 2, 1, num_kv_heads, head_dim], dtype=dtype ) - - # Copy K,V data to fragmented pages for i, page_idx in enumerate(kv_indices_fragmented): k_paged_frag[page_idx, 0] = k_flashinfer[i] v_paged_frag[page_idx, 0] = v_flashinfer[i] - - # Test with fragmented indices wrapper_paged_frag = flashinfer.BatchPrefillWithPagedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", @@ -578,27 +492,26 @@ def test_attention_sink_incremental_generation( o_paged_frag = wrapper_paged_frag.run( q_flashinfer, (k_paged_frag, v_paged_frag), sink, sm_scale ) - - # Verify fragmented result matches reference - if dtype == torch.float16: - torch.testing.assert_close(o_paged_frag, o_ref, rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose( + x=o_paged_frag, y=o_ref, rtol=0.001, atol=0.001 + ).item(), "" else: - torch.testing.assert_close(o_paged_frag, o_ref, rtol=1e-2, atol=1e-2) - - # Accumulate new K,V for next step + assert paddle.allclose( + x=o_paged_frag, y=o_ref, rtol=0.01, atol=0.01 + ).item(), "" if step == 0: k_accumulated = k_new v_accumulated = v_new else: - k_accumulated = torch.cat([k_accumulated, k_new], dim=1) - v_accumulated = torch.cat([v_accumulated, v_new], dim=1) - + k_accumulated = paddle.concat(x=[k_accumulated, k_new], axis=1) + v_accumulated = paddle.concat(x=[v_accumulated, v_new], axis=1) print( f"Step {step}: q_len=1, kv_len={current_kv_len}, both RaggedKV and PagedKV wrappers passed!" 
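The checks in these tests compare against the reference with assert paddle.allclose(...).item(), ""; a small helper in the same spirit, shown only as an illustrative sketch (not part of the patch), that keeps the same tolerance semantics and also reports the maximum error when the comparison fails:

    import paddle

    def assert_close(x: paddle.Tensor, y: paddle.Tensor, rtol: float, atol: float) -> None:
        # same rtol/atol semantics as the inline asserts above
        if not paddle.allclose(x=x, y=y, rtol=rtol, atol=atol).item():
            max_err = paddle.max(paddle.abs(x.astype("float32") - y.astype("float32"))).item()
            raise AssertionError(f"tensors not close: max |x - y| = {max_err} (rtol={rtol}, atol={atol})")
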
) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("batch_size", [1, 4, 16]) @pytest.mark.parametrize("chunk_size", [128, 256]) @pytest.mark.parametrize("historical_len", [256, 512]) @@ -623,28 +536,24 @@ def test_attention_sink_chunk_prefill( Simulate chunk-based processing of long sequences where current chunk attends to all historical tokens plus current chunk tokens """ - torch.manual_seed(42) - device = torch.device("cuda:0") + paddle.seed(seed=42) + device = device2str("cuda:0") if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") - # Skip invalid combinations if chunk_size >= historical_len: pytest.skip( "chunk_size should be smaller than historical_len for meaningful chunk prefill test" ) - head_dim = 128 sm_scale = 1.0 / math.sqrt(head_dim) - torch.manual_seed(42) + paddle.seed(seed=42) total_kv_len = historical_len + chunk_size - - # Create JIT arguments jit_args = ( f"batch_prefill_attention_sink_{filename_safe_dtype_map[dtype]}_swa_{window_left >= 0}_{backend}", dtype, dtype, dtype, - torch.int32, + "int32", head_dim, head_dim, ["sink"], @@ -654,36 +563,21 @@ def test_attention_sink_chunk_prefill( "AttentionSink", attention_sink_decl[backend], ) - jit_kwargs = { - "use_sliding_window": window_left >= 0, - } - - float_workspace_buffer = torch.empty( - 128 * 1024 * 1024, dtype=torch.uint8, device=device - ) - - # Create input tensors for chunk prefill scenario - # q represents current chunk: [batch_size * chunk_size, num_heads, head_dim] - q_chunk = torch.randn( - batch_size * chunk_size, num_qo_heads, head_dim, dtype=dtype, device=device + jit_kwargs = {"use_sliding_window": window_left >= 0} + float_workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="uint8") + q_chunk = paddle.randn( + shape=[batch_size * chunk_size, num_qo_heads, head_dim], dtype=dtype ) - - # k, v represent all tokens (historical + current chunk) - k_all = torch.randn( - batch_size * total_kv_len, num_kv_heads, head_dim, dtype=dtype, device=device + k_all = paddle.randn( + shape=[batch_size * total_kv_len, num_kv_heads, head_dim], dtype=dtype ) - v_all = torch.randn( - batch_size * total_kv_len, num_kv_heads, head_dim, dtype=dtype, device=device + v_all = paddle.randn( + shape=[batch_size * total_kv_len, num_kv_heads, head_dim], dtype=dtype ) - - sink = torch.rand(num_qo_heads, device=device, dtype=torch.float32) * 5 - - # Calculate reference result using chunk prefill mode + sink = paddle.rand(shape=num_qo_heads, dtype="float32") * 5 o_ref = sink_attention_chunk_ref( batch_size, q_chunk, k_all, v_all, sink, window_left, causal, sm_scale ) - - # Test with flashinfer wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", @@ -691,15 +585,12 @@ def test_attention_sink_chunk_prefill( jit_args=jit_args, jit_kwargs=jit_kwargs, ) - - # Set up indices for chunk prefill - qo_indptr_host = torch.arange( - 0, batch_size * chunk_size + 1, chunk_size, dtype=torch.int32 + qo_indptr_host = paddle.arange( + start=0, end=batch_size * chunk_size + 1, step=chunk_size, dtype="int32" ) - kv_indptr_host = torch.arange( - 0, batch_size * total_kv_len + 1, total_kv_len, dtype=torch.int32 + kv_indptr_host = paddle.arange( + start=0, end=batch_size * total_kv_len + 1, step=total_kv_len, dtype="int32" ) - wrapper.plan( qo_indptr_host, kv_indptr_host, @@ -711,16 +602,11 @@ def test_attention_sink_chunk_prefill( 
q_data_type=dtype, kv_data_type=dtype, ) - o = wrapper.run(q_chunk, k_all, v_all, sink, sm_scale) - - # Verify results - if dtype == torch.float16: - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" else: - torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2) - - # Also test with BatchPrefillWithPagedKVCacheWrapper + assert paddle.allclose(x=o, y=o_ref, rtol=0.01, atol=0.01).item(), "" wrapper_paged = flashinfer.BatchPrefillWithPagedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", @@ -728,12 +614,12 @@ def test_attention_sink_chunk_prefill( jit_args=jit_args, jit_kwargs=jit_kwargs, ) - kv_indices_host = torch.arange( - 0, - batch_size * total_kv_len, - dtype=torch.int32, + kv_indices_host = paddle.arange( + start=0, end=batch_size * total_kv_len, dtype="int32" + ) + paged_kv_last_page_len_host = paddle.full( + shape=(batch_size,), fill_value=1, dtype="int32" ) - paged_kv_last_page_len_host = torch.full((batch_size,), 1, dtype=torch.int32) wrapper_paged.plan( qo_indptr_host, kv_indptr_host, @@ -750,45 +636,32 @@ def test_attention_sink_chunk_prefill( non_blocking=True, ) o_paged = wrapper_paged.run(q_chunk, (k_all, v_all), sink, sm_scale) - if dtype == torch.float16: - torch.testing.assert_close(o_paged, o_ref, rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose(x=o_paged, y=o_ref, rtol=0.001, atol=0.001).item(), "" else: - torch.testing.assert_close(o_paged, o_ref, rtol=1e-2, atol=1e-2) - - # Test with non-contiguous KV indices for chunk prefill + assert paddle.allclose(x=o_paged, y=o_ref, rtol=0.01, atol=0.01).item(), "" total_pages = batch_size * total_kv_len - if total_pages > 1: # Only test fragmentation when we have multiple pages - # Create fragmented page allocation pattern + if total_pages > 1: import random - random.seed( - 42 + batch_size + total_kv_len - ) # Vary seed with batch and total length + random.seed(42 + batch_size + total_kv_len) all_pages = list(range(0, total_pages * 2)) occupied_pages = set( random.sample(all_pages, min(total_pages, len(all_pages) // 2)) ) available_pages = [p for p in all_pages if p not in occupied_pages] - - # Allocate non-contiguous pages - kv_indices_fragmented = torch.tensor( - available_pages[:total_pages], dtype=torch.int32, device=device + kv_indices_fragmented = paddle.to_tensor( + data=available_pages[:total_pages], dtype="int32", place=device ) - - # Create fragmented paged KV cache - k_paged_frag = torch.randn( - total_pages * 2, 1, num_kv_heads, head_dim, dtype=dtype, device=device + k_paged_frag = paddle.randn( + shape=[total_pages * 2, 1, num_kv_heads, head_dim], dtype=dtype ) - v_paged_frag = torch.randn( - total_pages * 2, 1, num_kv_heads, head_dim, dtype=dtype, device=device + v_paged_frag = paddle.randn( + shape=[total_pages * 2, 1, num_kv_heads, head_dim], dtype=dtype ) - - # Copy K,V data to fragmented pages for i, page_idx in enumerate(kv_indices_fragmented): k_paged_frag[page_idx, 0] = k_all[i] v_paged_frag[page_idx, 0] = v_all[i] - - # Test with fragmented indices wrapper_paged_frag = flashinfer.BatchPrefillWithPagedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", @@ -814,24 +687,23 @@ def test_attention_sink_chunk_prefill( o_paged_frag = wrapper_paged_frag.run( q_chunk, (k_paged_frag, v_paged_frag), sink, sm_scale ) - - # Verify fragmented result matches reference - if dtype == torch.float16: - torch.testing.assert_close(o_paged_frag, o_ref, rtol=1e-3, atol=1e-3) + if dtype 
== "float16": + assert paddle.allclose( + x=o_paged_frag, y=o_ref, rtol=0.001, atol=0.001 + ).item(), "" else: - torch.testing.assert_close(o_paged_frag, o_ref, rtol=1e-2, atol=1e-2) - + assert paddle.allclose( + x=o_paged_frag, y=o_ref, rtol=0.01, atol=0.01 + ).item(), "" print( - f"Chunk prefill test passed: q_len={chunk_size}, kv_len={total_kv_len}, " - f"batch_size={batch_size}, causal={causal}" + f"Chunk prefill test passed: q_len={chunk_size}, kv_len={total_kv_len}, batch_size={batch_size}, causal={causal}" ) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize( "indptr_config", [ - # (qo_indptr, kv_indptr, description) ( [0, 32, 64, 128, 256], [0, 128, 256, 512, 1024], @@ -867,27 +739,21 @@ def test_attention_sink_varlen( Test variable length sequences within a batch. Each request in the batch can have different query and key/value lengths. """ - torch.manual_seed(42) - device = torch.device("cuda:0") + paddle.seed(seed=42) + device = device2str("cuda:0") if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") - # Unpack the indptr configuration qo_indptr, kv_indptr, description = indptr_config - - # Validate that qo_indptr and kv_indptr have same batch size if len(qo_indptr) != len(kv_indptr): pytest.skip( f"qo_indptr and kv_indptr must have same batch size for {description}" ) - batch_size = len(qo_indptr) - 1 total_qo_len = qo_indptr[-1] total_kv_len = kv_indptr[-1] head_dim = 128 sm_scale = 1.0 / math.sqrt(head_dim) - torch.manual_seed(42) - - # Check if any request has qo_len > kv_len for causal case + paddle.seed(seed=42) if causal: for i in range(batch_size): qo_len_i = qo_indptr[i + 1] - qo_indptr[i] @@ -896,56 +762,37 @@ def test_attention_sink_varlen( pytest.skip( "qo_len > kv_len not supported for causal attention in varlen mode" ) - - # Create input tensors - q = torch.randn(total_qo_len, num_qo_heads, head_dim, dtype=dtype, device=device) - k = torch.randn(total_kv_len, num_kv_heads, head_dim, dtype=dtype, device=device) - v = torch.randn(total_kv_len, num_kv_heads, head_dim, dtype=dtype, device=device) - - qo_indptr_tensor = torch.tensor(qo_indptr, dtype=torch.int32, device=device) - kv_indptr_tensor = torch.tensor(kv_indptr, dtype=torch.int32, device=device) - - sink = torch.rand(num_qo_heads, device=device, dtype=torch.float32) * 5 - - # Test the variable length reference implementation + q = paddle.randn(shape=[total_qo_len, num_qo_heads, head_dim], dtype=dtype) + k = paddle.randn(shape=[total_kv_len, num_kv_heads, head_dim], dtype=dtype) + v = paddle.randn(shape=[total_kv_len, num_kv_heads, head_dim], dtype=dtype) + qo_indptr_tensor = paddle.to_tensor(data=qo_indptr, dtype="int32", place=device) + kv_indptr_tensor = paddle.to_tensor(data=kv_indptr, dtype="int32", place=device) + sink = paddle.rand(shape=num_qo_heads, dtype="float32") * 5 o_ref = sink_attention_varlen_ref( q, k, v, sink, window_left, causal, sm_scale, qo_indptr_tensor, kv_indptr_tensor ) - - # Verify output shape - assert o_ref.shape == ( + assert tuple(o_ref.shape) == ( total_qo_len, num_qo_heads, head_dim, - ), f"Expected shape ({total_qo_len}, {num_qo_heads}, {head_dim}), got {o_ref.shape}" - - # Test against FlashInfer kernel for verification - # Create JIT arguments for attention sink + ), f"Expected shape ({total_qo_len}, {num_qo_heads}, {head_dim}), got {tuple(o_ref.shape)}" jit_args = ( - 
f"batch_prefill_attention_sink_{filename_safe_dtype_map[dtype]}_swa_{window_left >= 0}_{backend}", # uri - dtype, # dtype_q - dtype, # dtype_kv - dtype, # dtype_o - torch.int32, # idtype - head_dim, # hidden_dim_qk - head_dim, # hidden_dim_vo - ["sink"], # additional_tensor_names - ["float"], # additional_tensor_dtypes - ["sm_scale"], # additional_scalar_names - ["double"], # additional_scalar_dtypes + f"batch_prefill_attention_sink_{filename_safe_dtype_map[dtype]}_swa_{window_left >= 0}_{backend}", + dtype, + dtype, + dtype, + "int32", + head_dim, + head_dim, + ["sink"], + ["float"], + ["sm_scale"], + ["double"], "AttentionSink", attention_sink_decl[backend], ) - jit_kwargs = { - "use_sliding_window": window_left >= 0, - } - - # Create workspace buffer - float_workspace_buffer = torch.empty( - 128 * 1024 * 1024, dtype=torch.uint8, device=device - ) - - # Test with BatchPrefillWithRaggedKVCacheWrapper + jit_kwargs = {"use_sliding_window": window_left >= 0} + float_workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="uint8") wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", @@ -953,7 +800,6 @@ def test_attention_sink_varlen( jit_args=jit_args, jit_kwargs=jit_kwargs, ) - wrapper.plan( qo_indptr_tensor, kv_indptr_tensor, @@ -965,16 +811,11 @@ def test_attention_sink_varlen( q_data_type=dtype, kv_data_type=dtype, ) - o = wrapper.run(q, k, v, sink, sm_scale) - - # Compare varlen reference result with FlashInfer kernel result - if dtype == torch.float16: - torch.testing.assert_close(o_ref, o, rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose(x=o_ref, y=o, rtol=0.001, atol=0.001).item(), "" else: - torch.testing.assert_close(o_ref, o, rtol=1e-2, atol=1e-2) - - # Also test with BatchPrefillWithPagedKVCacheWrapper + assert paddle.allclose(x=o_ref, y=o, rtol=0.01, atol=0.01).item(), "" wrapper_paged = flashinfer.BatchPrefillWithPagedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", @@ -982,9 +823,9 @@ def test_attention_sink_varlen( jit_args=jit_args, jit_kwargs=jit_kwargs, ) - kv_indices_host = torch.arange(0, total_kv_len, dtype=torch.int32, device=device) - paged_kv_last_page_len_host = torch.full( - (batch_size,), 1, dtype=torch.int32, device=device + kv_indices_host = paddle.arange(start=0, end=total_kv_len, dtype="int32") + paged_kv_last_page_len_host = paddle.full( + shape=(batch_size,), fill_value=1, dtype="int32" ) wrapper_paged.plan( qo_indptr_tensor, @@ -1002,45 +843,32 @@ def test_attention_sink_varlen( non_blocking=True, ) o_paged = wrapper_paged.run(q, (k, v), sink, sm_scale) - if dtype == torch.float16: - torch.testing.assert_close(o_ref, o_paged, rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose(x=o_ref, y=o_paged, rtol=0.001, atol=0.001).item(), "" else: - torch.testing.assert_close(o_ref, o_paged, rtol=1e-2, atol=1e-2) - - # Test with non-contiguous KV indices for variable length sequences + assert paddle.allclose(x=o_ref, y=o_paged, rtol=0.01, atol=0.01).item(), "" total_pages = total_kv_len - if total_pages > 1: # Only test fragmentation when we have multiple pages - # Create fragmented page allocation pattern + if total_pages > 1: import random - random.seed( - 42 + batch_size + total_kv_len - ) # Vary seed with batch and total length + random.seed(42 + batch_size + total_kv_len) all_pages = list(range(0, total_pages * 2)) occupied_pages = set( random.sample(all_pages, min(total_pages, len(all_pages) // 2)) ) available_pages = [p for p in all_pages if p not in 
occupied_pages] - - # Allocate non-contiguous pages - kv_indices_fragmented = torch.tensor( - available_pages[:total_pages], dtype=torch.int32, device=device + kv_indices_fragmented = paddle.to_tensor( + data=available_pages[:total_pages], dtype="int32", place=device ) - - # Create fragmented paged KV cache - k_paged_frag = torch.randn( - total_pages * 2, 1, num_kv_heads, head_dim, dtype=dtype, device=device + k_paged_frag = paddle.randn( + shape=[total_pages * 2, 1, num_kv_heads, head_dim], dtype=dtype ) - v_paged_frag = torch.randn( - total_pages * 2, 1, num_kv_heads, head_dim, dtype=dtype, device=device + v_paged_frag = paddle.randn( + shape=[total_pages * 2, 1, num_kv_heads, head_dim], dtype=dtype ) - - # Copy K,V data to fragmented pages for i, page_idx in enumerate(kv_indices_fragmented): k_paged_frag[page_idx, 0] = k[i] v_paged_frag[page_idx, 0] = v[i] - - # Test with fragmented indices wrapper_paged_frag = flashinfer.BatchPrefillWithPagedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", backend=backend, jit_args=jit_args ) @@ -1062,24 +890,22 @@ def test_attention_sink_varlen( o_paged_frag = wrapper_paged_frag.run( q, (k_paged_frag, v_paged_frag), sink, sm_scale ) - - # Verify fragmented result matches reference - if dtype == torch.float16: - torch.testing.assert_close(o_ref, o_paged_frag, rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose( + x=o_ref, y=o_paged_frag, rtol=0.001, atol=0.001 + ).item(), "" else: - torch.testing.assert_close(o_ref, o_paged_frag, rtol=1e-2, atol=1e-2) - + assert paddle.allclose( + x=o_ref, y=o_paged_frag, rtol=0.01, atol=0.01 + ).item(), "" print( - f"Variable length test passed: {description}, batch_size={batch_size}, " - f"qo_lens={[qo_indptr[i + 1] - qo_indptr[i] for i in range(batch_size)]}, " - f"kv_lens={[kv_indptr[i + 1] - kv_indptr[i] for i in range(batch_size)]}, " - f"causal={causal}" + f"Variable length test passed: {description}, batch_size={batch_size}, qo_lens={[(qo_indptr[i + 1] - qo_indptr[i]) for i in range(batch_size)]}, kv_lens={[(kv_indptr[i + 1] - kv_indptr[i]) for i in range(batch_size)]}, causal={causal}" ) if __name__ == "__main__": test_attention_sink( - torch.float16, + "float16", batch_size=128, seq_len=1024, num_qo_heads=32, diff --git a/tests/test_attention_sink_blackwell.py b/tests/test_attention_sink_blackwell.py index 3dfca6ea0b..7a4b5742d7 100644 --- a/tests/test_attention_sink_blackwell.py +++ b/tests/test_attention_sink_blackwell.py @@ -1,3 +1,10 @@ +import sys + + +import einops +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,16 +20,13 @@ See the License for the specific language governing permissions and limitations under the License. 
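Note on the converted assertions: they repeat the pattern assert paddle.allclose(...).item(), "" with dtype-dependent tolerances (1e-3 for float16, 1e-2 for bfloat16), and the empty failure messages drop the diff information that torch.testing.assert_close used to report. A small helper along the following lines could centralize that; the name assert_close is hypothetical and is not part of this patch.

import paddle

def assert_close(actual, expected, dtype):
    # Tolerances mirror the converted tests: 1e-3 for float16, 1e-2 otherwise.
    rtol, atol = (1e-3, 1e-3) if dtype == "float16" else (1e-2, 1e-2)
    max_diff = paddle.max(
        paddle.abs(actual.astype("float32") - expected.astype("float32"))
    ).item()
    assert paddle.allclose(
        x=actual, y=expected, rtol=rtol, atol=atol
    ).item(), f"mismatch beyond rtol={rtol}, atol={atol}: max abs diff {max_diff}"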
""" - -import einops import pytest -import torch from sink_attention_reference import sink_attention_unified import flashinfer -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("batch_size", [1, 4, 16]) @pytest.mark.parametrize("page_size", [32]) @pytest.mark.parametrize("seq_len", [32, 128, 1024]) @@ -30,50 +34,28 @@ @pytest.mark.parametrize("num_kv_heads", [8, 32]) @pytest.mark.parametrize("head_dim", [64, 128]) def test_blackwell_trtllm_gen_decode_attention_sink( - dtype, - batch_size, - page_size, - seq_len, - num_qo_heads, - num_kv_heads, - head_dim, + dtype, batch_size, page_size, seq_len, num_qo_heads, num_kv_heads, head_dim ): seed = 0 - torch.manual_seed(seed) + paddle.seed(seed=seed) device = "cuda:0" - - seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device=device) - + seq_lens = paddle.full(shape=(batch_size,), fill_value=seq_len, dtype="int32") blocks_per_seq = (seq_lens + page_size - 1) // page_size - max_num_blocks_per_seq = torch.max(blocks_per_seq).item() - - # Generate unique block IDs for all sequences - block_tables = torch.arange( - (batch_size * max_num_blocks_per_seq), dtype=torch.int32, device=device + max_num_blocks_per_seq = paddle.max(x=blocks_per_seq).item() + block_tables = paddle.arange( + dtype="int32", end=batch_size * max_num_blocks_per_seq ).reshape(batch_size, max_num_blocks_per_seq) - - # Create separate K and V caches num_tokens = seq_len * batch_size num_blocks = (num_tokens + page_size - 1) // page_size - q = torch.randn( - batch_size, - num_qo_heads, - head_dim, - dtype=dtype, - device=device, + q = paddle.randn(shape=[batch_size, num_qo_heads, head_dim], dtype=dtype) + k_cache = paddle.randn( + shape=[num_blocks, num_kv_heads, page_size, head_dim], dtype=dtype ) - - k_cache = torch.randn( - num_blocks, num_kv_heads, page_size, head_dim, dtype=dtype, device=device - ) - v_cache = torch.randn( - num_blocks, num_kv_heads, page_size, head_dim, dtype=dtype, device=device + v_cache = paddle.randn( + shape=[num_blocks, num_kv_heads, page_size, head_dim], dtype=dtype ) - - sink = torch.rand(num_qo_heads, device=device, dtype=torch.float32) * 5 - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - + sink = paddle.rand(shape=num_qo_heads, dtype="float32") * 5 + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") output = flashinfer.decode.trtllm_batch_decode_with_kv_cache( q.contiguous(), (k_cache, v_cache), @@ -81,13 +63,12 @@ def test_blackwell_trtllm_gen_decode_attention_sink( block_tables, seq_lens, seq_len, - 1.0, # bmm1_scale - 1.0, # bmm2_scale - -1, # window_left + 1.0, + 1.0, + -1, out_dtype=dtype, sinks=sink, ) - k = einops.rearrange( k_cache, "(b num_pages_per_b) h p d -> b (num_pages_per_b p) h d", @@ -98,29 +79,17 @@ def test_blackwell_trtllm_gen_decode_attention_sink( "(b num_pages_per_b) h p d -> b (num_pages_per_b p) h d", num_pages_per_b=max_num_blocks_per_seq, ) - - o_ref = sink_attention_unified( - q, - k, - v, - sink, - -1, - False, - 1.0, - mode="incremental", - ) - - if dtype == torch.float16: - atol, rtol = 1e-3, 1e-3 - elif dtype == torch.bfloat16: - atol, rtol = 1e-2, 1e-2 + o_ref = sink_attention_unified(q, k, v, sink, -1, False, 1.0, mode="incremental") + if dtype == "float16": + atol, rtol = 0.001, 0.001 + elif dtype == "bfloat16": + atol, rtol = 0.01, 0.01 else: raise ValueError(f"Unsupported dtype: {dtype}") - - torch.testing.assert_close(o_ref, output, 
atol=atol, rtol=rtol) + assert paddle.allclose(x=o_ref, y=output, atol=atol, rtol=rtol).item(), "" -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("batch_size", [1, 4, 16]) @pytest.mark.parametrize("page_size", [32]) @pytest.mark.parametrize("seq_len", [32, 128, 1024]) @@ -128,56 +97,30 @@ def test_blackwell_trtllm_gen_decode_attention_sink( @pytest.mark.parametrize("num_kv_heads", [8, 32]) @pytest.mark.parametrize("head_dim", [64, 128]) def test_blackwell_trtllm_gen_context_attention_sink( - dtype, - batch_size, - page_size, - seq_len, - num_qo_heads, - num_kv_heads, - head_dim, + dtype, batch_size, page_size, seq_len, num_qo_heads, num_kv_heads, head_dim ): seed = 0 - torch.manual_seed(seed) + paddle.seed(seed=seed) device = "cuda:0" - - seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device=device) - + seq_lens = paddle.full(shape=(batch_size,), fill_value=seq_len, dtype="int32") blocks_per_seq = (seq_lens + page_size - 1) // page_size - max_num_blocks_per_seq = torch.max(blocks_per_seq).item() - - # Generate unique block IDs for all sequences - block_tables = torch.arange( - (batch_size * max_num_blocks_per_seq), dtype=torch.int32, device=device + max_num_blocks_per_seq = paddle.max(x=blocks_per_seq).item() + block_tables = paddle.arange( + dtype="int32", end=batch_size * max_num_blocks_per_seq ).reshape(batch_size, max_num_blocks_per_seq) - - # Create separate K and V caches num_tokens = seq_len * batch_size num_blocks = (num_tokens + page_size - 1) // page_size - q = torch.randn( - num_tokens, - num_qo_heads, - head_dim, - dtype=dtype, - device=device, - ) - - k_cache = torch.randn( - num_blocks, num_kv_heads, page_size, head_dim, dtype=dtype, device=device - ) - v_cache = torch.randn( - num_blocks, num_kv_heads, page_size, head_dim, dtype=dtype, device=device + q = paddle.randn(shape=[num_tokens, num_qo_heads, head_dim], dtype=dtype) + k_cache = paddle.randn( + shape=[num_blocks, num_kv_heads, page_size, head_dim], dtype=dtype ) - - sink = torch.rand(num_qo_heads, device=device, dtype=torch.float32) * 5 - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - q_indptr = ( - torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) * seq_len + v_cache = paddle.randn( + shape=[num_blocks, num_kv_heads, page_size, head_dim], dtype=dtype ) - kv_indptr = ( - torch.arange(0, num_blocks + 1, dtype=torch.int32, device=device) * page_size - ) - + sink = paddle.rand(shape=num_qo_heads, dtype="float32") * 5 + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") + q_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * seq_len + kv_indptr = paddle.arange(start=0, end=num_blocks + 1, dtype="int32") * page_size output = flashinfer.prefill.trtllm_batch_context_with_kv_cache( q.contiguous(), (k_cache, v_cache), @@ -186,44 +129,25 @@ def test_blackwell_trtllm_gen_context_attention_sink( seq_lens, seq_len, seq_len, - 1.0, # bmm1_scale - 1.0, # bmm2_scale + 1.0, + 1.0, batch_size, q_indptr, kv_indptr, - -1, # window_left + -1, out_dtype=dtype, sinks=sink, ) - - k = einops.rearrange( - k_cache, - "num_pages h p d -> (num_pages p) h d", - ) - v = einops.rearrange( - v_cache, - "num_pages h p d -> (num_pages p) h d", - ) - - print(q.shape, k.shape, v.shape) - + k = einops.rearrange(k_cache, "num_pages h p d -> (num_pages p) h d") + v = einops.rearrange(v_cache, "num_pages h p d -> (num_pages p) h d") + 
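For reference, the einops rearranges in these tests unpack the block-paged K/V caches back into a [batch, seq, heads, head_dim] layout; the same rearrangement can be written with plain paddle reshape/transpose. The sketch below covers the batched decode-test pattern "(b num_pages_per_b) h p d -> b (num_pages_per_b p) h d"; the helper name unpage_kv is hypothetical and the shapes are the ones used in these tests.

import paddle

def unpage_kv(cache, batch_size, pages_per_seq):
    # cache: [batch_size * pages_per_seq, num_kv_heads, page_size, head_dim]
    _, h, p, d = tuple(cache.shape)
    x = cache.reshape([batch_size, pages_per_seq, h, p, d])
    x = x.transpose(perm=[0, 1, 3, 2, 4])  # move page tokens next to their pages
    # -> [batch_size, pages_per_seq * page_size, num_kv_heads, head_dim]
    return x.reshape([batch_size, pages_per_seq * p, h, d])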
print(tuple(q.shape), tuple(k.shape), tuple(v.shape)) o_ref = sink_attention_unified( - q, - k, - v, - sink, - -1, - True, - 1.0, - mode="prefill", - batch_size=batch_size, + q, k, v, sink, -1, True, 1.0, mode="prefill", batch_size=batch_size ) - - if dtype == torch.float16: - atol, rtol = 1e-3, 1e-3 - elif dtype == torch.bfloat16: - atol, rtol = 1e-2, 1e-2 + if dtype == "float16": + atol, rtol = 0.001, 0.001 + elif dtype == "bfloat16": + atol, rtol = 0.01, 0.01 else: raise ValueError(f"Unsupported dtype: {dtype}") - - torch.testing.assert_close(o_ref, output, atol=atol, rtol=rtol) + assert paddle.allclose(x=o_ref, y=output, atol=atol, rtol=rtol).item(), "" diff --git a/tests/test_batch_attention.py b/tests/test_batch_attention.py index c4ff93e561..66e05f20ce 100644 --- a/tests/test_batch_attention.py +++ b/tests/test_batch_attention.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,64 +19,55 @@ See the License for the specific language governing permissions and limitations under the License. """ - import numpy as np import pytest -import torch +from jit_utils import (gen_persistent_batch_attention_modules, + gen_prefill_attention_modules) import flashinfer -from jit_utils import ( - gen_persistent_batch_attention_modules, - gen_prefill_attention_modules, -) @pytest.fixture(autouse=True, scope="module") def warmup_jit(): flashinfer.jit.build_jit_specs( gen_persistent_batch_attention_modules( - [torch.float16, torch.bfloat16], # q_dtypes - [torch.float16, torch.bfloat16], # kv_dtypes - [64, 128, 256], # head_dims - [False, True], # use_logits_soft_cap + ["float16", "bfloat16"], + ["float16", "bfloat16"], + [64, 128, 256], + [False, True], ) + gen_prefill_attention_modules( - [torch.float16, torch.bfloat16], # q_dtypes - [torch.float16, torch.bfloat16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False, True], # use_logits_soft_caps - [False], # use_fp16_qk_reductions + ["float16", "bfloat16"], + ["float16", "bfloat16"], + [64, 128, 256], + [0], + [False], + [False, True], + [False], ), verbose=False, ) -# ------------------------- Configuration generation function ----------------------------- # def _build_seq_len_configs(): """ Reproduce the sequence length configurations from the original benchmark (including random cases). Returns: List[List[Tuple[int,int]]] -> Each element is a list of (kv_len, qo_len) pairs. 
""" np.random.seed(42) - torch.manual_seed(42) - + paddle.seed(seed=42) seq_len_configs = [ [(8190, 7939)], - [(2, 235)] - + [(1, 13353)], # corner case with a large number of masked out tokens + [(2, 235)] + [(1, 13353)], [(67, 1)], [(182, 1)], [(2011, 1)], - [(2048, 1)] * 77, # decode-only - [(4099, 129)] * 2, # prefill-only + [(2048, 1)] * 77, + [(4099, 129)] * 2, [(600, 1)] * 132 * 2 + [(5000, 3)] * 128, - [(1024, 1)] * 100 + [(8192, 17)] * 8, # speculative decode - [(766, 2)] * 99 + [(1024, 512)] * 1, # chunked prefill + [(1024, 1)] * 100 + [(8192, 17)] * 8, + [(766, 2)] * 99 + [(1024, 512)] * 1, ] - - # Construct random seqlen tests bsz, stride, sparsity = 256, 16, 0.05 full_kv_len = np.random.randint(1000, 11000, size=bsz) seq_len = [] @@ -81,7 +78,6 @@ def _build_seq_len_configs(): kv_len, qo_len = int(full_kv_len[i] * sparsity), 1 seq_len.append((kv_len, qo_len)) seq_len_configs.append(seq_len) - return seq_len_configs @@ -93,7 +89,7 @@ def _run_attention( num_qo_heads=1, head_dim=128, layout="NHD", - test_dtype=torch.bfloat16, + test_dtype="bfloat16", logits_soft_cap=0.0, device="cuda", causal=True, @@ -101,48 +97,37 @@ def _run_attention( """ Run both implementations and return (output_old, lse_old, output_new, lse_new) """ - dev = torch.device(device) - seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=dev) - q_lens = torch.tensor(qo_lens, dtype=torch.int32, device=dev) - - seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int() - - q_indptr = torch.cat( - [torch.tensor([0], device=dev), torch.cumsum(q_lens, 0)], dim=0 - ).int() - kv_indptr = torch.cat( - [torch.tensor([0], device=dev), torch.cumsum(seq_lens_blocks, 0)], dim=0 - ).int() - + dev = device2str(device) + seq_lens = paddle.to_tensor(data=kv_lens, dtype="int32", place=dev) + q_lens = paddle.to_tensor(data=qo_lens, dtype="int32", place=dev) + seq_lens_blocks = paddle.ceil(x=seq_lens / page_block_size).astype(dtype="int32") + q_indptr = paddle.concat( + x=[paddle.to_tensor(data=[0], place=dev), paddle.cumsum(x=q_lens, axis=0)], + axis=0, + ).astype(dtype="int32") + kv_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=dev), + paddle.cumsum(x=seq_lens_blocks, axis=0), + ], + axis=0, + ).astype(dtype="int32") num_blocks = kv_indptr[-1].item() - - q = torch.rand( - q_indptr[-1].item(), num_qo_heads, head_dim, dtype=test_dtype, device=dev + q = paddle.rand( + shape=[q_indptr[-1].item(), num_qo_heads, head_dim], dtype=test_dtype ) if layout == "NHD": - kv_data = torch.randn( - num_blocks, - 2, - page_block_size, - num_kv_heads, - head_dim, + kv_data = paddle.randn( + shape=[num_blocks, 2, page_block_size, num_kv_heads, head_dim], dtype=test_dtype, - device=dev, ) elif layout == "HND": - kv_data = torch.randn( - num_blocks, - 2, - num_kv_heads, - page_block_size, - head_dim, + kv_data = paddle.randn( + shape=[num_blocks, 2, num_kv_heads, page_block_size, head_dim], dtype=test_dtype, - device=dev, ) - - # --------- old scheduler --------- # wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=dev), + paddle.empty(shape=128 * 1024 * 1024, dtype="uint8"), kv_layout=layout, backend="fa2", ) @@ -150,7 +135,7 @@ def _run_attention( wrapper_old.plan( q_indptr, kv_indptr, - torch.arange(num_blocks, device=dev).int(), + paddle.arange(end=num_blocks).astype(dtype="int32"), last_page_len, num_qo_heads, num_kv_heads, @@ -162,13 +147,11 @@ def _run_attention( logits_soft_cap=logits_soft_cap, ) out_old, lse_old = wrapper_old.run(q, kv_data, 
return_lse=True) - - # --------- new / mixed scheduler --------- # wrapper = flashinfer.BatchAttention(kv_layout=layout) wrapper.plan( q_indptr, kv_indptr, - torch.arange(num_blocks, device=dev).int(), + paddle.arange(end=num_blocks).astype(dtype="int32"), seq_lens, num_qo_heads, num_kv_heads, @@ -181,13 +164,11 @@ def _run_attention( logits_soft_cap=logits_soft_cap, ) out_new, lse_new = wrapper.run(q, kv_data, logits_soft_cap=logits_soft_cap) - - torch.cuda.synchronize() - torch.testing.assert_close(out_old, out_new, rtol=1e-2, atol=1e-2) - torch.testing.assert_close(lse_old, lse_new, rtol=1e-2, atol=1e-2) + paddle.device.synchronize() + assert paddle.allclose(x=out_old, y=out_new, rtol=0.01, atol=0.01).item(), "" + assert paddle.allclose(x=lse_old, y=lse_new, rtol=0.01, atol=0.01).item(), "" -# ------------------------- PyTest test case ----------------------------- # @pytest.mark.parametrize("seq_len_pairs", _build_seq_len_configs()) @pytest.mark.parametrize("page_block_size", [1, 8, 16]) @pytest.mark.parametrize("num_kv_heads", [8, 1, 4]) @@ -195,7 +176,7 @@ def _run_attention( @pytest.mark.parametrize("head_dim", [64, 128, 256]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("layout", ["HND", "NHD"]) -@pytest.mark.parametrize("test_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("test_dtype", ["bfloat16", "float16"]) @pytest.mark.parametrize("logits_soft_cap", [0.0, 50.0]) def test_batch_attention_correctness( seq_len_pairs, @@ -211,7 +192,6 @@ def test_batch_attention_correctness( num_qo_heads = num_kv_heads * gqa_group_size kv_lens = [p[0] for p in seq_len_pairs] qo_lens = [p[1] for p in seq_len_pairs] - _run_attention( kv_lens=kv_lens, qo_lens=qo_lens, diff --git a/tests/test_batch_decode_kernels.py b/tests/test_batch_decode_kernels.py index e5124e43fa..6601a5d3e2 100644 --- a/tests/test_batch_decode_kernels.py +++ b/tests/test_batch_decode_kernels.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2023 by FlashInfer team. @@ -13,10 +19,9 @@ See the License for the specific language governing permissions and limitations under the License. 
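Note for the CUDA-graph decode test in this file and the CUDA-graph prefill test in tests/test_batch_prefill_kernels.py: the capture code is still unconverted and is left behind as lines marked with ">>>>>>" (torch.cuda.CUDAGraph() / torch.cuda.graph(g)). A possible Paddle-side replacement is sketched below. It is only a sketch: it assumes paddle.device.cuda.graphs.CUDAGraph with capture_begin/capture_end/replay is available in the target Paddle build, and the helper name capture_run is hypothetical.

import paddle
from paddle.device.cuda.graphs import CUDAGraph

def capture_run(wrapper, q, kv_data, warmup_iters=3):
    # Warm up on a side stream, as the converted tests already do.
    s = paddle.device.Stream()
    s.wait_stream(paddle.device.current_stream())
    with paddle.device.stream_guard(stream=s):
        for _ in range(warmup_iters):
            o = wrapper.run(q, kv_data)
    paddle.device.current_stream().wait_stream(s)
    # Capture one run; g.replay() later re-executes the recorded kernels
    # against the same input/output buffers.
    g = CUDAGraph()
    g.capture_begin()
    o = wrapper.run(q, kv_data)
    g.capture_end()
    return g, o

After re-planning the wrapper with new metadata, g.replay() would then take the place of the g.replay() calls that the original torch version of the test performs.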
""" - import pytest -import torch -from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules +from jit_utils import (gen_decode_attention_modules, + gen_prefill_attention_modules) import flashinfer @@ -25,27 +30,21 @@ def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( - [torch.float16], # q_dtypes - [ - torch.float16, - torch.float8_e4m3fn, - ], # kv_dtypes - [128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + ["float16"], + ["float16", paddle.float8_e4m3fn], + [128, 256], + [0, 1], + [False], + [False], ) + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [ - torch.float16, - torch.float8_e4m3fn, - ], # kv_dtypes - [128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions + ["float16"], + ["float16", paddle.float8_e4m3fn], + [128, 256], + [0, 1], + [False], + [False], + [False], ), verbose=False, ) @@ -62,8 +61,8 @@ def warmup_jit(): @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) @pytest.mark.parametrize("logits_soft_cap", [0.0]) @pytest.mark.parametrize("return_lse", [True]) -@pytest.mark.parametrize("q_dtype", [torch.float16]) -@pytest.mark.parametrize("kv_dtype", [torch.float16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("q_dtype", ["float16"]) +@pytest.mark.parametrize("kv_dtype", ["float16", paddle.float8_e4m3fn]) @pytest.mark.parametrize("contiguous_kv", [True]) def test_batch_decode_with_paged_kv_cache( batch_size, @@ -80,7 +79,7 @@ def test_batch_decode_with_paged_kv_cache( kv_dtype, contiguous_kv, ): - q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype) + q = paddle.randn(shape=[batch_size, num_qo_heads, head_dim], dtype=q_dtype) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size if kv_layout == "HND": @@ -93,28 +92,27 @@ def test_batch_decode_with_paged_kv_cache( tmp.append(2) tmp.append(v) kv_shape = tmp - kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") + kv_data_fp32 = paddle.randn(shape=kv_shape, dtype="float32") kv_data = kv_data_fp32.to(kv_dtype) kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :] kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :] - # actual data is stored in non-contiguous memory assert ( - kv_data.stride(-4) - != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1] + kv_data.get_strides()[-4] + != tuple(kv_data.shape)[-3] + * tuple(kv_data.shape)[-2] + * tuple(kv_data.shape)[-1] ) else: - kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") + kv_data_fp32 = paddle.randn(shape=kv_shape, dtype="float32") kv_data = kv_data_fp32.to(kv_dtype) kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) - * num_pages_per_seq + paddle.arange(start=0, end=batch_size + 1, dtype="int32") * num_pages_per_seq ) - kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) - kv_last_page_len = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" + kv_indices = paddle.arange(start=0, end=total_num_pages, dtype="int32") + kv_last_page_len = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" ) - - workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + workspace_buffer = paddle.empty(shape=32 * 1024 * 1024, 
dtype="int8") wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) @@ -135,15 +133,14 @@ def test_batch_decode_with_paged_kv_cache( o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) - for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] qi = q[i] - ki = torch.cat( - [ + ki = paddle.concat( + x=[ kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 0] - .permute(*perm_dims) + .transpose(perm=perm_dims) .reshape(-1, num_kv_heads, head_dim), ( kv_data_fp32[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]] @@ -153,12 +150,12 @@ def test_batch_decode_with_paged_kv_cache( .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], - dim=0, + axis=0, ).to(kv_dtype) - vi = torch.cat( - [ + vi = paddle.concat( + x=[ kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] - .permute(*perm_dims) + .transpose(perm=perm_dims) .reshape(-1, num_kv_heads, head_dim), ( kv_data_fp32[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]] @@ -168,7 +165,7 @@ def test_batch_decode_with_paged_kv_cache( .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], - dim=0, + axis=0, ).to(kv_dtype) o_ref_i = flashinfer.decode.single_decode_with_kv_cache( qi, @@ -177,12 +174,10 @@ def test_batch_decode_with_paged_kv_cache( pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) - torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) - - # test user-allocated output - o_buffer = torch.empty_like(o) + assert paddle.allclose(x=o[i], y=o_ref_i, rtol=0.001, atol=0.001).item(), "" + o_buffer = paddle.empty_like(x=o) wrapper.run(q, kv_data, out=o_buffer) - torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o, y=o_buffer, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [12, 17, 128]) @@ -195,8 +190,8 @@ def test_batch_decode_with_paged_kv_cache( @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) @pytest.mark.parametrize("logits_soft_cap", [0.0]) @pytest.mark.parametrize("return_lse", [True]) -@pytest.mark.parametrize("q_dtype", [torch.float16]) -@pytest.mark.parametrize("kv_dtype", [torch.float16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("q_dtype", ["float16"]) +@pytest.mark.parametrize("kv_dtype", ["float16", paddle.float8_e4m3fn]) @pytest.mark.parametrize("contiguous_kv", [True]) def test_batch_decode_with_tuple_paged_kv_cache( batch_size, @@ -213,7 +208,7 @@ def test_batch_decode_with_tuple_paged_kv_cache( kv_dtype, contiguous_kv, ): - q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype) + q = paddle.randn(shape=[batch_size, num_qo_heads, head_dim], dtype=q_dtype) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size if kv_layout == "HND": @@ -226,36 +221,29 @@ def test_batch_decode_with_tuple_paged_kv_cache( tmp.append(2) tmp.append(v) kv_shape = tmp - kv_data_fp32 = [ - torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") - for _ in range(2) - ] + kv_data_fp32 = [paddle.randn(shape=kv_shape, dtype="float32") for _ in range(2)] kv_data = [kv_data_fp32[i].to(kv_dtype) for i in range(2)] for i in range(2): kv_data_fp32[i] = kv_data_fp32[i][:, 1, :, 1, :, 1, :] kv_data[i] = kv_data[i][:, 1, :, 1, :, 1, :] - # actual data is stored in non-contiguous memory assert ( - kv_data[i].stride(-4) - != kv_data[i].shape[-3] * kv_data[i].shape[-2] * 
kv_data[i].shape[-1] + kv_data[i].get_strides()[-4] + != tuple(kv_data[i].shape)[-3] + * tuple(kv_data[i].shape)[-2] + * tuple(kv_data[i].shape)[-1] ) else: - kv_data_fp32 = [ - torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") - for _ in range(2) - ] + kv_data_fp32 = [paddle.randn(shape=kv_shape, dtype="float32") for _ in range(2)] kv_data = [kv_data_fp32[i].to(kv_dtype) for i in range(2)] kv_data = tuple(kv_data) kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) - * num_pages_per_seq + paddle.arange(start=0, end=batch_size + 1, dtype="int32") * num_pages_per_seq ) - kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) - kv_last_page_len = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" + kv_indices = paddle.arange(start=0, end=total_num_pages, dtype="int32") + kv_last_page_len = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" ) - - workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + workspace_buffer = paddle.empty(shape=32 * 1024 * 1024, dtype="int8") wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) @@ -276,16 +264,15 @@ def test_batch_decode_with_tuple_paged_kv_cache( o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) - k_cache, v_cache = kv_data_fp32 for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] qi = q[i] - ki = torch.cat( - [ + ki = paddle.concat( + x=[ k_cache[kv_indptr[i] : kv_indptr[i + 1] - 1] - .permute(*perm_dims) + .transpose(perm=perm_dims) .reshape(-1, num_kv_heads, head_dim), ( k_cache[kv_indptr[i + 1] - 1, :, : kv_last_page_len[i]] @@ -295,13 +282,13 @@ def test_batch_decode_with_tuple_paged_kv_cache( .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], - dim=0, + axis=0, ).to(kv_dtype) - vi = torch.cat( - [ + vi = paddle.concat( + x=[ v_cache[kv_indptr[i] : kv_indptr[i + 1] - 1] - .to(torch.float32) # torch.cat does not support some fp8 types - .permute(*perm_dims) + .to("float32") + .transpose(perm=perm_dims) .reshape(-1, num_kv_heads, head_dim), ( v_cache[kv_indptr[i + 1] - 1, :, : kv_last_page_len[i]] @@ -311,7 +298,7 @@ def test_batch_decode_with_tuple_paged_kv_cache( .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], - dim=0, + axis=0, ).to(kv_dtype) o_ref_i = flashinfer.decode.single_decode_with_kv_cache( qi, @@ -320,7 +307,7 @@ def test_batch_decode_with_tuple_paged_kv_cache( pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) - torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o[i], y=o_ref_i, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [12, 17, 128]) @@ -331,8 +318,8 @@ def test_batch_decode_with_tuple_paged_kv_cache( @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) -@pytest.mark.parametrize("q_dtype", [torch.float16]) -@pytest.mark.parametrize("kv_dtype", [torch.float16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("q_dtype", ["float16"]) +@pytest.mark.parametrize("kv_dtype", ["float16", paddle.float8_e4m3fn]) @pytest.mark.parametrize("contiguous_kv", [True]) def test_cuda_graph_batch_decode_with_paged_kv_cache( batch_size, @@ -347,7 +334,7 @@ def 
test_cuda_graph_batch_decode_with_paged_kv_cache( kv_dtype, contiguous_kv, ): - q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype) + q = paddle.randn(shape=[batch_size, num_qo_heads, head_dim], dtype=q_dtype) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size if kv_layout == "HND": @@ -360,40 +347,28 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( tmp.append(2) tmp.append(v) kv_shape = tmp - kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") + kv_data_fp32 = paddle.randn(shape=kv_shape, dtype="float32") kv_data = kv_data_fp32.to(kv_dtype) kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :] kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :] - # actual data is stored in non-contiguous memory assert ( - kv_data.stride(-4) - != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1] + kv_data.get_strides()[-4] + != tuple(kv_data.shape)[-3] + * tuple(kv_data.shape)[-2] + * tuple(kv_data.shape)[-1] ) else: - kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") + kv_data_fp32 = paddle.randn(shape=kv_shape, dtype="float32") kv_data = kv_data_fp32.to(kv_dtype) - kv_indptr_host_warmup = torch.arange( - 0, batch_size + 1, device="cuda:0", dtype=torch.int32 + kv_indptr_host_warmup = paddle.arange(start=0, end=batch_size + 1, dtype="int32") + kv_indices_host_warmup = paddle.arange(start=0, end=batch_size, dtype="int32") + kv_last_page_len_host_warmup = paddle.full( + shape=(batch_size,), fill_value=page_size, dtype="int32" ) - kv_indices_host_warmup = torch.arange( - 0, batch_size, device="cuda:0", dtype=torch.int32 - ) - kv_last_page_len_host_warmup = torch.full( - (batch_size,), page_size, dtype=torch.int32 - ) - - # NOTE(Zihao): allocate more space than needed for testing - kv_indptr_device_buffer = torch.empty( - batch_size + 1, device="cuda:0", dtype=torch.int32 - ) - kv_indices_device_buffer = torch.empty( - total_num_pages, device="cuda:0", dtype=torch.int32 - ) - kv_last_page_device_buffer = torch.empty( - batch_size, device="cuda:0", dtype=torch.int32 - ) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + kv_indptr_device_buffer = paddle.empty(shape=batch_size + 1, dtype="int32") + kv_indices_device_buffer = paddle.empty(shape=total_num_pages, dtype="int32") + kv_last_page_device_buffer = paddle.empty(shape=batch_size, dtype="int32") + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper = flashinfer.decode.CUDAGraphBatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_indptr_device_buffer, @@ -413,25 +388,25 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( pos_encoding_mode=pos_encoding_mode, q_data_type=q_dtype, ) - # warmup - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): + s = paddle.device.Stream() + s.wait_stream(paddle.device.current_stream()) + with paddle.device.stream_guard(stream=s): for _ in range(3): o = wrapper.run(q, kv_data) - torch.cuda.current_stream().wait_stream(s) - - # capture - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): + paddle.device.current_stream().wait_stream(s) +>>>>>> g = torch.cuda.CUDAGraph() +>>>>>> with torch.cuda.graph(g): o = wrapper.run(q, kv_data) - - # replay multiple times for i in range(1, min(4, num_pages_per_seq)): - kv_indptr_host = torch.arange(0, batch_size + 1).int() * i - kv_indices_host = torch.arange(0, i * batch_size).int() - kv_last_page_len_host = 
torch.full((batch_size,), page_size, dtype=torch.int32) - + kv_indptr_host = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") * i + ) + kv_indices_host = paddle.arange(start=0, end=i * batch_size).astype( + dtype="int32" + ) + kv_last_page_len_host = paddle.full( + shape=(batch_size,), fill_value=page_size, dtype="int32" + ) wrapper.plan( kv_indptr_host, kv_indices_host, @@ -445,14 +420,14 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( q_data_type=q_dtype, ) g.replay() - - # replay again - kv_indptr_host = torch.arange(0, batch_size + 1).int() * num_pages_per_seq - kv_indices_host = torch.arange(0, total_num_pages).int() - kv_last_page_len_host = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + kv_indptr_host = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") + * num_pages_per_seq + ) + kv_indices_host = paddle.arange(start=0, end=total_num_pages).astype(dtype="int32") + kv_last_page_len_host = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" ) - wrapper.plan( kv_indptr_host, kv_indices_host, @@ -466,18 +441,16 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( q_data_type=q_dtype, ) g.replay() - - # compute ground truth and compare kv_indptr = kv_indptr_host.to(0) kv_last_page_len = kv_last_page_len_host.to(0) for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] qi = q[i] - ki = torch.cat( - [ + ki = paddle.concat( + x=[ kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 0] - .permute(*perm_dims) + .transpose(perm=perm_dims) .reshape(-1, num_kv_heads, head_dim), ( kv_data_fp32[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]] @@ -487,12 +460,12 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], - dim=0, + axis=0, ).to(kv_dtype) - vi = torch.cat( - [ + vi = paddle.concat( + x=[ kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] - .permute(*perm_dims) + .transpose(perm=perm_dims) .reshape(-1, num_kv_heads, head_dim), ( kv_data_fp32[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]] @@ -502,59 +475,23 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], - dim=0, + axis=0, ).to(kv_dtype) o_ref_i = flashinfer.decode.single_decode_with_kv_cache( qi, ki, vi, pos_encoding_mode=pos_encoding_mode ) - torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o[i], y=o_ref_i, rtol=0.001, atol=0.001).item(), "" if __name__ == "__main__": test_batch_decode_with_paged_kv_cache( - 256, - 54, - 8, - 8, - 8, - 128, - "NHD", - "NONE", - 0.0, - False, - torch.float16, - torch.float16, - True, + 256, 54, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, "float16", "float16", True ) test_batch_decode_with_tuple_paged_kv_cache( - 256, - 54, - 8, - 8, - 8, - 128, - "NHD", - "NONE", - 0.0, - False, - torch.float16, - torch.float16, - True, + 256, 54, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, "float16", "float16", True ) test_batch_decode_with_paged_kv_cache( - 12, - 2048, - 8, - 8, - 8, - 128, - "NHD", - "NONE", - 0.0, - False, - torch.float16, - torch.float16, - True, + 12, 2048, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, "float16", "float16", True ) test_batch_decode_with_paged_kv_cache( 12, @@ -567,15 +504,15 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( "NONE", 0.0, True, - torch.float16, - 
torch.float8_e5m2, + "float16", +>>>>>> paddle.float8_e5m2, True, ) test_cuda_graph_batch_decode_with_paged_kv_cache( - 12, 2048, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16, True + 12, 2048, 8, 8, 8, 128, "NHD", "NONE", "float16", "float16", True ) test_cuda_graph_batch_decode_with_paged_kv_cache( - 128, 54, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16, True + 128, 54, 8, 8, 8, 128, "NHD", "NONE", "float16", "float16", True ) test_batch_decode_with_paged_kv_cache( 12, @@ -588,10 +525,10 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( "NONE", 0.0, True, - torch.float16, - torch.float8_e5m2, + "float16", +>>>>>> paddle.float8_e5m2, True, ) test_cuda_graph_batch_decode_with_paged_kv_cache( - 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16, torch.float8_e5m2, True +>>>>>> 12, 54, 8, 8, 8, 128, "HND", "NONE", "float16", paddle.float8_e5m2, True ) diff --git a/tests/test_batch_prefill.py b/tests/test_batch_prefill.py index b107a25d2a..3cd2030d3d 100644 --- a/tests/test_batch_prefill.py +++ b/tests/test_batch_prefill.py @@ -1,37 +1,32 @@ +import paddle import pytest -import torch from flashinfer import BatchPrefillWithPagedKVCacheWrapper -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) def test_kv_scale_forwarding_effect(dtype): - torch.manual_seed(42) - + paddle.seed(seed=42) H_QO, H_KV, N_CTX, HEAD_DIM, PAGE_SIZE = 1, 1, 8, 64, 16 max_num_pages = (N_CTX + PAGE_SIZE - 1) // PAGE_SIZE - - # Create paged KV cache - k_cache = torch.randn( - max_num_pages, PAGE_SIZE, H_KV, HEAD_DIM, dtype=dtype, device="cuda" + k_cache = paddle.randn( + shape=[max_num_pages, PAGE_SIZE, H_KV, HEAD_DIM], dtype=dtype ) - v_cache = torch.randn( - max_num_pages, PAGE_SIZE, H_KV, HEAD_DIM, dtype=dtype, device="cuda" + v_cache = paddle.randn( + shape=[max_num_pages, PAGE_SIZE, H_KV, HEAD_DIM], dtype=dtype ) - paged_kv_cache = (k_cache, v_cache) - - # Create query tensor and indptrs - q = torch.randn(N_CTX, H_QO, HEAD_DIM, dtype=dtype, device="cuda") - qo_indptr = torch.tensor([0, N_CTX], dtype=torch.int32, device="cuda") - paged_kv_indptr = torch.tensor([0, max_num_pages], dtype=torch.int32, device="cuda") - paged_kv_indices = torch.arange(max_num_pages, dtype=torch.int32, device="cuda") - paged_kv_last_page_len = torch.tensor( - [N_CTX % PAGE_SIZE or PAGE_SIZE], dtype=torch.int32, device="cuda" + paged_kv_cache = k_cache, v_cache + q = paddle.randn(shape=[N_CTX, H_QO, HEAD_DIM], dtype=dtype) + qo_indptr = paddle.to_tensor(data=[0, N_CTX], dtype="int32", place="gpu") + paged_kv_indptr = paddle.to_tensor( + data=[0, max_num_pages], dtype="int32", place="gpu" ) - - workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda") + paged_kv_indices = paddle.arange(dtype="int32", end=max_num_pages) + paged_kv_last_page_len = paddle.to_tensor( + data=[N_CTX % PAGE_SIZE or PAGE_SIZE], dtype="int32", place="gpu" + ) + workspace_buffer = paddle.empty(shape=16 * 1024 * 1024, dtype="uint8") wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer) - wrapper.plan( qo_indptr, paged_kv_indptr, @@ -45,45 +40,35 @@ def test_kv_scale_forwarding_effect(dtype): q_data_type=dtype, kv_data_type=dtype, ) - out1, _ = wrapper.forward_return_lse(q, paged_kv_cache, k_scale=0.1, v_scale=0.1) out2, _ = wrapper.forward_return_lse(q, paged_kv_cache, k_scale=2.0, v_scale=2.0) + assert not paddle.allclose( + x=out1, y=out2, atol=0.001 + ).item(), "Output should change when k_scale/v_scale values are different. 
This may indicate that the arguments are not passed correctly." - assert not torch.allclose(out1, out2, atol=1e-3), ( - "Output should change when k_scale/v_scale values are different. " - "This may indicate that the arguments are not passed correctly." - ) - - -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_kv_scale_forwarding_math_property(dtype: torch.dtype): - torch.manual_seed(0) - # ---------------- parameters ---------------- +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +def test_kv_scale_forwarding_math_property(dtype: paddle.dtype): + paddle.seed(seed=0) N_CTX, PAGE_SIZE = 128, 16 - H_QO, H_KV, HEAD_DIM = 1, 1, 64 # Explicitly specify H_QO + H_QO, H_KV, HEAD_DIM = 1, 1, 64 max_num_pages = (N_CTX + PAGE_SIZE - 1) // PAGE_SIZE - - # ---------------- paged KV cache ---------------- - k_cache = torch.randn( - max_num_pages, PAGE_SIZE, H_KV, HEAD_DIM, dtype=dtype, device="cuda" + k_cache = paddle.randn( + shape=[max_num_pages, PAGE_SIZE, H_KV, HEAD_DIM], dtype=dtype ) - v_cache = torch.randn_like(k_cache) - paged_kv_cache = (k_cache, v_cache) - - # ---------------- query and indptr ---------------- - q = torch.randn(N_CTX, H_QO, HEAD_DIM, dtype=dtype, device="cuda") - qo_indptr = torch.tensor([0, N_CTX], dtype=torch.int32, device="cuda") - paged_kv_indptr = torch.tensor([0, max_num_pages], dtype=torch.int32, device="cuda") - paged_kv_indices = torch.arange(max_num_pages, dtype=torch.int32, device="cuda") - paged_kv_last_page_len = torch.tensor( - [N_CTX % PAGE_SIZE or PAGE_SIZE], dtype=torch.int32, device="cuda" + v_cache = paddle.randn(shape=k_cache.shape, dtype=k_cache.dtype) + paged_kv_cache = k_cache, v_cache + q = paddle.randn(shape=[N_CTX, H_QO, HEAD_DIM], dtype=dtype) + qo_indptr = paddle.to_tensor(data=[0, N_CTX], dtype="int32", place="gpu") + paged_kv_indptr = paddle.to_tensor( + data=[0, max_num_pages], dtype="int32", place="gpu" ) - - # ---------------- wrapper ---------------- - workspace = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda") + paged_kv_indices = paddle.arange(dtype="int32", end=max_num_pages) + paged_kv_last_page_len = paddle.to_tensor( + data=[N_CTX % PAGE_SIZE or PAGE_SIZE], dtype="int32", place="gpu" + ) + workspace = paddle.empty(shape=16 * 1024 * 1024, dtype="uint8") wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace) - wrapper.plan( qo_indptr, paged_kv_indptr, @@ -97,24 +82,20 @@ def test_kv_scale_forwarding_math_property(dtype: torch.dtype): q_data_type=dtype, kv_data_type=dtype, ) - - # ---------------- scale factors ---------------- - k_scale = torch.tensor(0.5, dtype=torch.float32, device="cuda") - v_scale = torch.tensor(2.0, dtype=torch.float32, device="cuda") - - # -------- case 1: k_scale only ---------- + k_scale = paddle.to_tensor(data=0.5, dtype="float32", place="gpu") + v_scale = paddle.to_tensor(data=2.0, dtype="float32", place="gpu") out1, _ = wrapper.forward_return_lse(q, paged_kv_cache, k_scale=k_scale) out1_ref, _ = wrapper.forward_return_lse(q * k_scale, paged_kv_cache) - torch.testing.assert_close(out1, out1_ref, rtol=1e-2, atol=1e-3) - - # -------- case 2: v_scale only ---------- + assert paddle.allclose(x=out1, y=out1_ref, rtol=0.01, atol=0.001).item(), "" out2, _ = wrapper.forward_return_lse(q, paged_kv_cache, v_scale=v_scale) out2_ref, _ = wrapper.forward_return_lse(q, paged_kv_cache) - torch.testing.assert_close(out2, out2_ref * v_scale, rtol=1e-2, atol=1e-3) - - # -------- case 3: both k_scale and v_scale ---------- + assert paddle.allclose( + x=out2, y=out2_ref * 
v_scale, rtol=0.01, atol=0.001 + ).item(), "" out3, _ = wrapper.forward_return_lse( q, paged_kv_cache, k_scale=k_scale, v_scale=v_scale ) out3_ref, _ = wrapper.forward_return_lse(q * k_scale, paged_kv_cache) - torch.testing.assert_close(out3, out3_ref * v_scale, rtol=1e-2, atol=1e-3) + assert paddle.allclose( + x=out3, y=out3_ref * v_scale, rtol=0.01, atol=0.001 + ).item(), "" diff --git a/tests/test_batch_prefill_kernels.py b/tests/test_batch_prefill_kernels.py index dcd636976e..9e94592b28 100644 --- a/tests/test_batch_prefill_kernels.py +++ b/tests/test_batch_prefill_kernels.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2023 by FlashInfer team. @@ -13,10 +19,8 @@ See the License for the specific language governing permissions and limitations under the License. """ - import numpy import pytest -import torch from jit_utils import gen_prefill_attention_modules import flashinfer @@ -26,17 +30,13 @@ def warmup_jit(): flashinfer.jit.build_jit_specs( gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [ - torch.float16, - torch.float8_e4m3fn, - torch.float8_e5m2, - ], # kv_dtypes - [128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions + ["float16"], + ["float16", paddle.float8_e4m3fn, paddle.float8_e5m2], + [128, 256], + [0, 1], + [False], + [False], + [False], ), verbose=False, ) @@ -75,14 +75,12 @@ def test_batch_prefill_with_paged_kv_cache( ): if qo_len > kv_len and causal: pytest.skip("qo_len > kv_len and causal is not supported") - q = torch.randn( - batch_size * qo_len, - num_qo_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, + q = paddle.randn( + shape=[batch_size * qo_len, num_qo_heads, head_dim], dtype="float16" + ) + q_indptr_cpu = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") * qo_len ) - q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size if kv_layout == "HND": @@ -95,25 +93,28 @@ def test_batch_prefill_with_paged_kv_cache( tmp.append(2) tmp.append(v) kv_shape = tmp - kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") - kv_data = kv_data_fp32.half() + kv_data_fp32 = paddle.randn(shape=kv_shape, dtype="float32") + kv_data = kv_data_fp32.astype(dtype="float16") kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :] kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :] - # actual data is stored in non-contiguous memory assert ( - kv_data.stride(-4) - != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1] + kv_data.get_strides()[-4] + != tuple(kv_data.shape)[-3] + * tuple(kv_data.shape)[-2] + * tuple(kv_data.shape)[-1] ) else: - kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") - kv_data = kv_data_fp32.half() - kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq - kv_indices_cpu = torch.arange(0, total_num_pages).int() - kv_last_page_len_cpu = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + kv_data_fp32 = paddle.randn(shape=kv_shape, dtype="float32") + kv_data = kv_data_fp32.astype(dtype="float16") + kv_indptr_cpu = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") + * num_pages_per_seq ) - - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + kv_indices_cpu = paddle.arange(start=0, 
end=total_num_pages).astype(dtype="int32") + kv_last_page_len_cpu = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" + ) + workspace_buffer = paddle.empty(shape=256 * 1024 * 1024, dtype="int8") if not use_cuda_graph: q_indptr_gpu = q_indptr_cpu.to(0) kv_indptr_gpu = kv_indptr_cpu.to(0) @@ -139,24 +140,14 @@ def test_batch_prefill_with_paged_kv_cache( o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) - - # test with pre-allocated output - o_buffer = torch.empty_like(o) + o_buffer = paddle.empty_like(x=o) wrapper.run(q, kv_data, out=o_buffer) - torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o, y=o_buffer, rtol=0.001, atol=0.001).item(), "" else: - q_indptr_buffer = torch.empty( - batch_size + 1, device="cuda:0", dtype=torch.int32 - ) - kv_indptr_buffer = torch.empty( - batch_size + 1, device="cuda:0", dtype=torch.int32 - ) - kv_indices_buffer = torch.empty( - total_num_pages, device="cuda:0", dtype=torch.int32 - ) - kv_last_page_len_buffer = torch.empty( - batch_size, device="cuda:0", dtype=torch.int32 - ) + q_indptr_buffer = paddle.empty(shape=batch_size + 1, dtype="int32") + kv_indptr_buffer = paddle.empty(shape=batch_size + 1, dtype="int32") + kv_indices_buffer = paddle.empty(shape=total_num_pages, dtype="int32") + kv_last_page_len_buffer = paddle.empty(shape=batch_size, dtype="int32") wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout, @@ -166,13 +157,16 @@ def test_batch_prefill_with_paged_kv_cache( paged_kv_indices_buf=kv_indices_buffer, paged_kv_last_page_len_buf=kv_last_page_len_buffer, ) - q_indptr_warmup = torch.arange(0, batch_size + 1).int() * qo_len - kv_indptr_warmup = torch.arange(0, batch_size + 1).int() - kv_indices_warmup = torch.arange(0, batch_size).int() - kv_last_page_len_warmup = torch.full( - (batch_size,), page_size, dtype=torch.int32 + q_indptr_warmup = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") * qo_len + ) + kv_indptr_warmup = paddle.arange(start=0, end=batch_size + 1).astype( + dtype="int32" + ) + kv_indices_warmup = paddle.arange(start=0, end=batch_size).astype(dtype="int32") + kv_last_page_len_warmup = paddle.full( + shape=(batch_size,), fill_value=page_size, dtype="int32" ) - wrapper.plan( q_indptr_warmup, kv_indptr_warmup, @@ -186,25 +180,21 @@ def test_batch_prefill_with_paged_kv_cache( pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) - - # warmup - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): + s = paddle.device.Stream() + s.wait_stream(paddle.device.current_stream()) + with paddle.device.stream_guard(stream=s): for _ in range(3): if return_lse: o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) - torch.cuda.current_stream().wait_stream(s) - # capture - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): + paddle.device.current_stream().wait_stream(s) +>>>>>> g = torch.cuda.CUDAGraph() +>>>>>> with torch.cuda.graph(g): if return_lse: o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) - wrapper.plan( q_indptr_cpu, kv_indptr_cpu, @@ -218,17 +208,15 @@ def test_batch_prefill_with_paged_kv_cache( pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) - g.replay() - for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] qi = 
q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] - ki = torch.cat( - [ + ki = paddle.concat( + x=[ kv_data_fp32[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 0] - .permute(*perm_dims) + .transpose(perm=perm_dims) .reshape(-1, num_kv_heads, head_dim), ( kv_data_fp32[ @@ -242,12 +230,12 @@ def test_batch_prefill_with_paged_kv_cache( .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], - dim=0, - ).half() - vi = torch.cat( - [ + axis=0, + ).astype(dtype="float16") + vi = paddle.concat( + x=[ kv_data_fp32[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 1] - .permute(*perm_dims) + .transpose(perm=perm_dims) .reshape(-1, num_kv_heads, head_dim), ( kv_data_fp32[ @@ -261,8 +249,8 @@ def test_batch_prefill_with_paged_kv_cache( .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], - dim=0, - ).half() + axis=0, + ).astype(dtype="float16") o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache( qi, ki, @@ -272,7 +260,7 @@ def test_batch_prefill_with_paged_kv_cache( logits_soft_cap=logits_soft_cap, ) o_i = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] - torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o_i, y=o_ref_i, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [12, 17, 128]) @@ -307,14 +295,12 @@ def test_batch_prefill_with_tuple_paged_kv_cache( ): if qo_len > kv_len and causal: pytest.skip("qo_len > kv_len and causal is not supported") - q = torch.randn( - batch_size * qo_len, - num_qo_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, + q = paddle.randn( + shape=[batch_size * qo_len, num_qo_heads, head_dim], dtype="float16" + ) + q_indptr_cpu = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") * qo_len ) - q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size if kv_layout == "HND": @@ -327,33 +313,30 @@ def test_batch_prefill_with_tuple_paged_kv_cache( tmp.append(2) tmp.append(v) kv_shape = tmp - kv_data_fp32 = [ - torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") - for _ in range(2) - ] - kv_data = [kv_data_fp32[i].half() for i in range(2)] + kv_data_fp32 = [paddle.randn(shape=kv_shape, dtype="float32") for _ in range(2)] + kv_data = [kv_data_fp32[i].astype(dtype="float16") for i in range(2)] for i in range(2): kv_data_fp32[i] = kv_data_fp32[i][:, 1, :, 1, :, 1, :] kv_data[i] = kv_data[i][:, 1, :, 1, :, 1, :] - # actual data is stored in non-contiguous memory assert ( - kv_data[i].stride(-4) - != kv_data[i].shape[-3] * kv_data[i].shape[-2] * kv_data[i].shape[-1] + kv_data[i].get_strides()[-4] + != tuple(kv_data[i].shape)[-3] + * tuple(kv_data[i].shape)[-2] + * tuple(kv_data[i].shape)[-1] ) else: - kv_data_fp32 = [ - torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") - for _ in range(2) - ] - kv_data = [kv_data_fp32[i].half() for i in range(2)] + kv_data_fp32 = [paddle.randn(shape=kv_shape, dtype="float32") for _ in range(2)] + kv_data = [kv_data_fp32[i].astype(dtype="float16") for i in range(2)] kv_data = tuple(kv_data) - kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq - kv_indices_cpu = torch.arange(0, total_num_pages).int() - kv_last_page_len_cpu = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + kv_indptr_cpu = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") + * num_pages_per_seq ) - - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + 
kv_indices_cpu = paddle.arange(start=0, end=total_num_pages).astype(dtype="int32") + kv_last_page_len_cpu = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" + ) + workspace_buffer = paddle.empty(shape=256 * 1024 * 1024, dtype="int8") if not use_cuda_graph: q_indptr_gpu = q_indptr_cpu.to(0) kv_indptr_gpu = kv_indptr_cpu.to(0) @@ -380,18 +363,10 @@ def test_batch_prefill_with_tuple_paged_kv_cache( else: o = wrapper.run(q, kv_data) else: - q_indptr_buffer = torch.empty( - batch_size + 1, device="cuda:0", dtype=torch.int32 - ) - kv_indptr_buffer = torch.empty( - batch_size + 1, device="cuda:0", dtype=torch.int32 - ) - kv_indices_buffer = torch.empty( - total_num_pages, device="cuda:0", dtype=torch.int32 - ) - kv_last_page_len_buffer = torch.empty( - batch_size, device="cuda:0", dtype=torch.int32 - ) + q_indptr_buffer = paddle.empty(shape=batch_size + 1, dtype="int32") + kv_indptr_buffer = paddle.empty(shape=batch_size + 1, dtype="int32") + kv_indices_buffer = paddle.empty(shape=total_num_pages, dtype="int32") + kv_last_page_len_buffer = paddle.empty(shape=batch_size, dtype="int32") wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout, @@ -401,11 +376,15 @@ def test_batch_prefill_with_tuple_paged_kv_cache( paged_kv_indices_buf=kv_indices_buffer, paged_kv_last_page_len_buf=kv_last_page_len_buffer, ) - q_indptr_warmup = torch.arange(0, batch_size + 1).int() * qo_len - kv_indptr_warmup = torch.arange(0, batch_size + 1).int() - kv_indices_warmup = torch.arange(0, batch_size).int() - kv_last_page_len_warmup = torch.full( - (batch_size,), page_size, dtype=torch.int32 + q_indptr_warmup = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") * qo_len + ) + kv_indptr_warmup = paddle.arange(start=0, end=batch_size + 1).astype( + dtype="int32" + ) + kv_indices_warmup = paddle.arange(start=0, end=batch_size).astype(dtype="int32") + kv_last_page_len_warmup = paddle.full( + shape=(batch_size,), fill_value=page_size, dtype="int32" ) wrapper.plan( q_indptr_warmup, @@ -420,25 +399,21 @@ def test_batch_prefill_with_tuple_paged_kv_cache( pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) - - # warmup - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): + s = paddle.device.Stream() + s.wait_stream(paddle.device.current_stream()) + with paddle.device.stream_guard(stream=s): for _ in range(3): if return_lse: o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) - torch.cuda.current_stream().wait_stream(s) - # capture - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): + paddle.device.current_stream().wait_stream(s) +>>>>>> g = torch.cuda.CUDAGraph() +>>>>>> with torch.cuda.graph(g): if return_lse: o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) - wrapper.plan( q_indptr_cpu, kv_indptr_cpu, @@ -452,18 +427,16 @@ def test_batch_prefill_with_tuple_paged_kv_cache( pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) - g.replay() - k_cache, v_cache = kv_data_fp32 for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] - ki = torch.cat( - [ + ki = paddle.concat( + x=[ k_cache[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1] - .permute(*perm_dims) + .transpose(perm=perm_dims) .reshape(-1, num_kv_heads, head_dim), ( 
k_cache[kv_indptr_cpu[i + 1] - 1, :, : kv_last_page_len_cpu[i]] @@ -473,12 +446,12 @@ def test_batch_prefill_with_tuple_paged_kv_cache( .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], - dim=0, - ).half() - vi = torch.cat( - [ + axis=0, + ).astype(dtype="float16") + vi = paddle.concat( + x=[ v_cache[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1] - .permute(*perm_dims) + .transpose(perm=perm_dims) .reshape(-1, num_kv_heads, head_dim), ( v_cache[kv_indptr_cpu[i + 1] - 1, :, : kv_last_page_len_cpu[i]] @@ -488,8 +461,8 @@ def test_batch_prefill_with_tuple_paged_kv_cache( .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], - dim=0, - ).half() + axis=0, + ).astype(dtype="float16") o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache( qi, ki, @@ -499,7 +472,7 @@ def test_batch_prefill_with_tuple_paged_kv_cache( logits_soft_cap=logits_soft_cap, ) o_i = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] - torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o_i, y=o_ref_i, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [12, 17, 128]) @@ -528,16 +501,10 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask( return_lse, contiguous_kv, ): - q = torch.randn( - batch_size * qo_len, - num_qo_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, - ) - q_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len + q = paddle.randn( + shape=[batch_size * qo_len, num_qo_heads, head_dim], dtype="float16" ) + q_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * qo_len num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size if kv_layout == "HND": @@ -550,34 +517,31 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask( tmp.append(2) tmp.append(v) kv_shape = tmp - kv_data = torch.randn(*kv_shape, dtype=torch.float16, device="cuda:0") + kv_data = paddle.randn(shape=kv_shape, dtype="float16") kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :] - # actual data is stored in non-contiguous memory assert ( - kv_data.stride(-4) - != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1] + kv_data.get_strides()[-4] + != tuple(kv_data.shape)[-3] + * tuple(kv_data.shape)[-2] + * tuple(kv_data.shape)[-1] ) else: - kv_data = torch.randn(*kv_shape, dtype=torch.float16, device="cuda:0") + kv_data = paddle.randn(shape=kv_shape, dtype="float16") kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) - * num_pages_per_seq + paddle.arange(start=0, end=batch_size + 1, dtype="int32") * num_pages_per_seq ) - kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) - kv_last_page_len = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" + kv_indices = paddle.arange(start=0, end=total_num_pages, dtype="int32") + kv_last_page_len = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" ) - - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + workspace_buffer = paddle.empty(shape=256 * 1024 * 1024, dtype="int8") wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) - custom_mask = torch.tril( - torch.full((batch_size, qo_len, kv_len), True, device="cuda:0"), - diagonal=(kv_len - qo_len), + custom_mask = paddle.tril( + x=paddle.full(shape=(batch_size, qo_len, kv_len), fill_value=True), + diagonal=kv_len - qo_len, ).reshape(-1) - - # 
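The custom-mask variant of this test expects the custom-mask run to match the causal run because, when the qo tokens are the trailing positions of the kv sequence, causal attention is exactly a lower-triangular mask with diagonal offset kv_len - qo_len, which is what the paddle.tril(..., diagonal=kv_len - qo_len) call constructs. A small NumPy check of that equivalence (illustrative, not part of the test):

    import numpy as np

    qo_len, kv_len = 4, 7
    # causal rule when queries are the last qo_len tokens of the kv sequence:
    # query i may attend to key j iff j <= (kv_len - qo_len) + i
    causal = np.array([[j <= kv_len - qo_len + i for j in range(kv_len)]
                       for i in range(qo_len)])
    # the same mask as a lower-triangular matrix with diagonal offset kv_len - qo_len
    tril = np.tril(np.ones((qo_len, kv_len), dtype=bool), k=kv_len - qo_len)
    assert np.array_equal(causal, tril)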
use custom mask wrapper.plan( q_indptr, kv_indptr, @@ -595,8 +559,6 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask( o_custom, _ = wrapper.run(q, kv_data, return_lse=True) else: o_custom = wrapper.run(q, kv_data) - - # use causal wrapper.plan( q_indptr, kv_indptr, @@ -614,7 +576,7 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask( o_causal, _ = wrapper.run(q, kv_data, return_lse=True) else: o_causal = wrapper.run(q, kv_data) - torch.testing.assert_close(o_custom, o_causal, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o_custom, y=o_causal, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [12, 17, 128]) @@ -642,36 +604,18 @@ def test_batch_prefill_with_ragged_kv_cache( if qo_len > kv_len and causal: pytest.skip("qo_len > kv_len and causal is not supported") kv_layout = "NHD" - q = torch.randn( - batch_size * qo_len, - num_qo_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, - ) - q_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len - ) - - k = torch.randn( - batch_size * kv_len, - num_kv_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, + q = paddle.randn( + shape=[batch_size * qo_len, num_qo_heads, head_dim], dtype="float16" ) - v = torch.randn( - batch_size * kv_len, - num_kv_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, + q_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * qo_len + k = paddle.randn( + shape=[batch_size * kv_len, num_kv_heads, head_dim], dtype="float16" ) - kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * kv_len + v = paddle.randn( + shape=[batch_size * kv_len, num_kv_heads, head_dim], dtype="float16" ) - - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + kv_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * kv_len + workspace_buffer = paddle.empty(shape=256 * 1024 * 1024, dtype="int8") wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, kv_layout ) @@ -689,7 +633,6 @@ def test_batch_prefill_with_ragged_kv_cache( o, _ = wrapper.run(q, k, v, return_lse=True) else: o = wrapper.run(q, k, v) - for i in range(batch_size): o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache( q[q_indptr[i] : q_indptr[i + 1]], @@ -700,7 +643,7 @@ def test_batch_prefill_with_ragged_kv_cache( logits_soft_cap=logits_soft_cap, ) o_i = o[q_indptr[i] : q_indptr[i + 1]] - torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o_i, y=o_ref_i, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [12, 17]) @@ -724,46 +667,25 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( return_lse, ): kv_layout = "NHD" - q = torch.randn( - batch_size * qo_len, - num_qo_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, + q = paddle.randn( + shape=[batch_size * qo_len, num_qo_heads, head_dim], dtype="float16" ) - q_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len + q_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * qo_len + k = paddle.randn( + shape=[batch_size * kv_len, num_kv_heads, head_dim], dtype="float16" ) - - k = torch.randn( - batch_size * kv_len, - num_kv_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, + v = paddle.randn( + shape=[batch_size * kv_len, num_kv_heads, head_dim], dtype="float16" ) - v = torch.randn( - batch_size * kv_len, - num_kv_heads, - head_dim, - device="cuda:0", - 
dtype=torch.float16, - ) - kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * kv_len - ) - - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + kv_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * kv_len + workspace_buffer = paddle.empty(shape=256 * 1024 * 1024, dtype="int8") wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, kv_layout ) - - custom_mask = torch.tril( - torch.full((batch_size, qo_len, kv_len), True, device="cuda:0"), - diagonal=(kv_len - qo_len), + custom_mask = paddle.tril( + x=paddle.full(shape=(batch_size, qo_len, kv_len), fill_value=True), + diagonal=kv_len - qo_len, ).reshape(-1) - - # use custom mask wrapper.plan( q_indptr, kv_indptr, @@ -778,8 +700,6 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( o_custom, _ = wrapper.run(q, k, v, return_lse=True) else: o_custom = wrapper.run(q, k, v) - - # use causal wrapper.plan( q_indptr, kv_indptr, @@ -794,7 +714,7 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( o_causal, _ = wrapper.run(q, k, v, return_lse=True) else: o_causal = wrapper.run(q, k, v) - torch.testing.assert_close(o_custom, o_causal, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o_custom, y=o_causal, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [1]) @@ -832,24 +752,34 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring( logits_soft_cap, return_lse, ): - q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() - q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len + q = ( + paddle.randn(shape=[batch_size * qo_len, num_qo_heads, head_dim]) + .to(0) + .astype(dtype="float16") + ) + q_indptr_cpu = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") * qo_len + ) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( - torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half() + paddle.randn(shape=[total_num_pages, 2, num_kv_heads, page_size, head_dim]) + .to(0) + .astype(dtype="float16") if kv_layout == "HND" - else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim) + else paddle.randn(shape=[total_num_pages, 2, page_size, num_kv_heads, head_dim]) .to(0) - .half() + .astype(dtype="float16") ) - kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq - kv_indices_cpu = torch.arange(0, total_num_pages).int() - kv_last_page_len_cpu = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + kv_indptr_cpu = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") + * num_pages_per_seq ) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + kv_indices_cpu = paddle.arange(start=0, end=total_num_pages).astype(dtype="int32") + kv_last_page_len_cpu = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" + ) + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8").to(0) q_indptr_gpu = q_indptr_cpu.to(0) kv_indptr_gpu = kv_indptr_cpu.to(0) kv_indices_gpu = kv_indices_cpu.to(0) @@ -869,28 +799,31 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring( causal=causal, pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, - prefix_len_ptr=torch.tensor(prefix_len_ptr).to(dtype=torch.uint32).to(0), - token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) - .to(dtype=torch.uint16) + 
prefix_len_ptr=paddle.to_tensor(data=prefix_len_ptr) +>>>>>> .to(dtype=torch.uint32) + .to(0), + token_pos_in_items_ptr=paddle.to_tensor(data=token_pos_in_items_ptr) +>>>>>> .to(dtype=torch.uint16) .to(0), - token_pos_in_items_len=torch.tensor(token_pos_in_items_len) - .to(dtype=torch.uint32) + token_pos_in_items_len=paddle.to_tensor(data=token_pos_in_items_len) +>>>>>> .to(dtype=torch.uint32) + .to(0), + max_item_len_ptr=paddle.to_tensor(data=max_item_len_ptr) +>>>>>> .to(dtype=torch.uint16) .to(0), - max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), ) if return_lse: o, _ = wrapper.run_return_lse(q, kv_data) else: o = wrapper.run(q, kv_data) - for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] - ki = torch.cat( - [ + ki = paddle.concat( + x=[ kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 0] - .permute(*perm_dims) + .transpose(perm=perm_dims) .reshape(-1, num_kv_heads, head_dim), ( kv_data[kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i]] @@ -902,12 +835,12 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring( .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], - dim=0, + axis=0, ) - vi = torch.cat( - [ + vi = paddle.concat( + x=[ kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 1] - .permute(*perm_dims) + .transpose(perm=perm_dims) .reshape(-1, num_kv_heads, head_dim), ( kv_data[kv_indptr_cpu[i + 1] - 1, 1, :, : kv_last_page_len_cpu[i]] @@ -919,80 +852,54 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring( .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], - dim=0, + axis=0, ) def create_2D_multi_item_mask_dense( is_delimiter, sliding_window_size=-1, prefix_cache_len=None ): - # Function to create custom_mask for multi-item scoring - # - # Note, sliding window implementation assumes that candidate_i_size < sliding_window_size < prefix_size - # Args: - # is_delimiter: a boolen torch vec to indicate the delimiter position for creating custom attnetion mask in multi-item scoring - # currently assume qo len and kv len are the same and 1D (bsz=1) case - # sliding_window_size: the window size for sliding window attention, -1 means no sliding window attention delimiter_idx = is_delimiter.nonzero(as_tuple=True)[0] if len(delimiter_idx) == 0: return None else: first_delimiter_pos = delimiter_idx[0] seq_len = len(is_delimiter) - pos = torch.arange(seq_len, device=is_delimiter.device) - - group_ids = torch.cumsum(is_delimiter, 0) - # Get mask for within-group causal attention - within_group_causal = (group_ids.unsqueeze(1) == group_ids.unsqueeze(0)) & ( - pos.unsqueeze(0) <= pos.unsqueeze(1) - ) - # Combine all conditions + pos = paddle.arange(end=seq_len) + group_ids = paddle.cumsum(x=is_delimiter, axis=0) + within_group_causal = ( + group_ids.unsqueeze(axis=1) == group_ids.unsqueeze(axis=0) + ) & (pos.unsqueeze(axis=0) <= pos.unsqueeze(axis=1)) attention_mask = ( ( within_group_causal - | ( - (pos >= first_delimiter_pos).unsqueeze(1) - & (pos < first_delimiter_pos).unsqueeze(0) - ) # Prefix attention + | (pos >= first_delimiter_pos).unsqueeze(axis=1) + & (pos < first_delimiter_pos).unsqueeze(axis=0) ) - & ~is_delimiter.unsqueeze(0) - & ~is_delimiter.unsqueeze(1) - ) # No delimiter attention - + & ~is_delimiter.unsqueeze(axis=0) + & ~is_delimiter.unsqueeze(axis=1) + ) if sliding_window_size > 0 and sliding_window_size < len(is_delimiter): - # 
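The create_2D_multi_item_mask_dense helper builds its within-group causal mask from a cumulative sum over the delimiter indicator: every token between two delimiters gets the same group id, so group equality combined with a positional comparison yields block-diagonal causal attention, with delimiter rows and columns masked out. A minimal NumPy version of just that step (prefix attention and the sliding-window path are omitted):

    import numpy as np

    # True marks a delimiter; tokens between delimiters form one candidate group
    is_delim = np.array([0, 0, 0, 1, 0, 0, 1, 0], dtype=bool)
    pos = np.arange(len(is_delim))
    group = np.cumsum(is_delim)                        # same id within a group
    within_group_causal = (group[:, None] == group[None, :]) & (pos[None, :] <= pos[:, None])
    # delimiters themselves neither attend nor get attended to
    mask = within_group_causal & ~is_delim[None, :] & ~is_delim[:, None]
    print(mask.astype(int))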
Calculate how many positions from right of prefix each token can attend to - - group_size = torch.sum( - within_group_causal & ~is_delimiter.unsqueeze(0), dim=1 + group_size = paddle.sum( + x=within_group_causal & ~is_delimiter.unsqueeze(axis=0), axis=1 ) - - # For prefix: after sliding_window_size position, can see window_size tokens - # For candidate items: can see (sliding_window_size - group_size) tokens from prefix end - prefix_window = torch.where( - pos >= first_delimiter_pos, - sliding_window_size - group_size, - torch.where( - pos < sliding_window_size, - first_delimiter_pos, - sliding_window_size, + prefix_window = paddle.where( + condition=pos >= first_delimiter_pos, + x=sliding_window_size - group_size, + y=paddle.where( + condition=pos < sliding_window_size, + x=first_delimiter_pos, + y=sliding_window_size, ), ) - - # Starting index of attention window relative to token position for candidate item/group - prefix_start = first_delimiter_pos - prefix_window.unsqueeze(1) - + prefix_start = first_delimiter_pos - prefix_window.unsqueeze(axis=1) attention_mask = attention_mask & (pos >= prefix_start) if prefix_cache_len: - patch = torch.ones( - seq_len, - prefix_cache_len, - device=is_delimiter.device, - dtype=torch.bool, - ) - attention_mask = torch.concat([patch, attention_mask], dim=1) - return attention_mask.unsqueeze(0).reshape(-1) + patch = paddle.ones(shape=[seq_len, prefix_cache_len], dtype="bool") + attention_mask = paddle.concat(x=[patch, attention_mask], axis=1) + return attention_mask.unsqueeze(axis=0).reshape(-1) custom_mask = create_2D_multi_item_mask_dense( - is_delimiter=torch.tensor(token_pos_in_items_ptr).to(0) == 0, + is_delimiter=paddle.to_tensor(data=token_pos_in_items_ptr).to(0) == 0, sliding_window_size=-1, prefix_cache_len=prefix_len_ptr, ) @@ -1007,7 +914,7 @@ def create_2D_multi_item_mask_dense( ) o_i_np = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]].cpu().numpy() o_ref_i_np = o_ref_i.cpu().numpy() - numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=0.001, atol=0.001) if __name__ == "__main__": diff --git a/tests/test_blackwell_fmha.py b/tests/test_blackwell_fmha.py index efe0391a00..c689e82f19 100644 --- a/tests/test_blackwell_fmha.py +++ b/tests/test_blackwell_fmha.py @@ -1,8 +1,12 @@ +import sys + + import math +import paddle import pytest -import torch from conftest import VARLEN_INDPTR_PARAMS +from flashinfer.paddle_utils import * import flashinfer from flashinfer.utils import is_sm100a_supported @@ -10,64 +14,72 @@ def attention_ref( batch_size, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, causal: bool, sm_scale: float, -) -> torch.Tensor: - qo_len = q.shape[0] // batch_size - kv_len = k.shape[0] // batch_size - num_qo_heads = q.shape[1] - head_dim_qk = q.shape[2] - head_dim_vo = v.shape[2] +) -> paddle.Tensor: + qo_len = tuple(q.shape)[0] // batch_size + kv_len = tuple(k.shape)[0] // batch_size + num_qo_heads = tuple(q.shape)[1] + head_dim_qk = tuple(q.shape)[2] + head_dim_vo = tuple(v.shape)[2] logits = ( - torch.einsum( + paddle.einsum( "bmhd,bnhd->bhmn", - q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(), - k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(), + q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).astype( + dtype="float32" + ), + k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).astype( + dtype="float32" + ), ) * sm_scale ) - if causal: - mask = 
torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze( - 1 - ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0) + mask = paddle.arange(start=kv_len - qo_len, end=kv_len).unsqueeze( + axis=1 + ) >= paddle.arange(start=0, end=kv_len).unsqueeze(axis=0) else: - mask = torch.ones(qo_len, kv_len, device=q.device) - - logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) - lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2) - p = torch.softmax(logits, dim=-1) + mask = paddle.ones(shape=[qo_len, kv_len]) + logits = logits.masked_fill( + mask=mask.unsqueeze(axis=0).unsqueeze(axis=0) == 0, value=float("-inf") + ) + lse_ref = paddle.logsumexp(x=logits, axis=-1).transpose( + perm=dim2perm(paddle.logsumexp(x=logits, axis=-1).ndim, -1, -2) + ) + p = paddle.nn.functional.softmax(x=logits, axis=-1) o_ref = ( - torch.einsum( + paddle.einsum( "bhmn,bnhd->bmhd", p, - v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(), + v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).astype( + dtype="float32" + ), ) .contiguous() .view(batch_size * qo_len, num_qo_heads, head_dim_vo) .to(q) ) - return o_ref, lse_ref * math.log2(math.e) def attention_varlen_ref( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qo_indptr: paddle.Tensor, + kv_indptr: paddle.Tensor, causal: bool, sm_scale: float, -) -> torch.Tensor: - batch_size = qo_indptr.shape[0] - 1 +) -> paddle.Tensor: + batch_size = tuple(qo_indptr.shape)[0] - 1 nnz_qo = qo_indptr[-1].item() - o = torch.empty(nnz_qo, *q.shape[1:-1], v.shape[-1], device=q.device, dtype=q.dtype) - lse = torch.empty(nnz_qo, q.shape[1], device=q.device, dtype=torch.float32) - + o = paddle.empty( + shape=[nnz_qo, *tuple(q.shape)[1:-1], tuple(v.shape)[-1]], dtype=q.dtype + ) + lse = paddle.empty(shape=[nnz_qo, tuple(q.shape)[1]], dtype="float32") for i in range(batch_size): o_i, lse_i = attention_ref( 1, @@ -77,11 +89,9 @@ def attention_varlen_ref( causal, sm_scale, ) - - lse_i = lse_i.flatten(0, 1) + lse_i = lse_i.flatten(start_axis=0, stop_axis=1) o[qo_indptr[i] : qo_indptr[i + 1]] = o_i lse[qo_indptr[i] : qo_indptr[i + 1]] = lse_i - return o, lse @@ -94,7 +104,7 @@ def attention_varlen_ref( @pytest.mark.parametrize("head_dim_vo", [128]) @pytest.mark.parametrize("sm_scale", [1.0, 1.0 / math.sqrt(192), 1.0 / math.sqrt(128)]) @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) def test_blackwell_cutlass_fmha( batch_size, qo_len, @@ -109,28 +119,22 @@ def test_blackwell_cutlass_fmha( ): if qo_len > kv_len and causal: pytest.skip("qo_len > kv_len and causal is not supported") - - if not is_sm100a_supported(torch.device("cuda")): + if not is_sm100a_supported(device2str("cuda")): pytest.skip("SM100A is not supported on this device") - torch.manual_seed(42) - q = torch.randn( - batch_size * qo_len, num_qo_heads, head_dim_qk, dtype=dtype, device="cuda" - ) - qo_indptr = ( - torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len + paddle.seed(seed=42) + q = paddle.randn( + shape=[batch_size * qo_len, num_qo_heads, head_dim_qk], dtype=dtype ) - k = torch.randn( - batch_size * kv_len, num_kv_heads, head_dim_qk, dtype=dtype, device="cuda" + qo_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * qo_len + k = paddle.randn( + shape=[batch_size * kv_len, num_kv_heads, head_dim_qk], 
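attention_ref returns lse_ref scaled by log2(e): logsumexp is computed in natural log, while the value the test later compares it against is in base 2, and log2(S) = ln(S) * log2(e). A one-line NumPy check of that conversion (illustrative only):

    import numpy as np

    x = np.random.randn(5)
    lse_nat = np.log(np.exp(x).sum())        # natural-log logsumexp
    lse_base2 = np.log2(np.exp(x).sum())     # base-2 logsumexp
    assert np.isclose(lse_nat * np.log2(np.e), lse_base2)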
dtype=dtype ) - v = torch.randn( - batch_size * kv_len, num_kv_heads, head_dim_vo, dtype=dtype, device="cuda" + v = paddle.randn( + shape=[batch_size * kv_len, num_kv_heads, head_dim_vo], dtype=dtype ) - kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len - ) - + kv_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * kv_len wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + paddle.empty(shape=128 * 1024 * 1024, dtype="uint8"), kv_layout="NHD", backend="cutlass", ) @@ -147,21 +151,18 @@ def test_blackwell_cutlass_fmha( kv_data_type=dtype, ) o, lse = wrapper.run(q, k, v, return_lse=True) - gqa_group_ratio = num_qo_heads // num_kv_heads - k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1) - v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1) + k_repeated = paddle.repeat_interleave(x=k, repeats=gqa_group_ratio, axis=1) + v_repeated = paddle.repeat_interleave(x=v, repeats=gqa_group_ratio, axis=1) o_ref, lse_ref = attention_ref( batch_size, q, k_repeated, v_repeated, causal, sm_scale ) - - lse_ref = lse_ref.flatten(0, 1) - if dtype == torch.half: - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + lse_ref = lse_ref.flatten(start_axis=0, stop_axis=1) + if dtype == "float16": + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" else: - torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2) - - torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o, y=o_ref, rtol=0.01, atol=0.01).item(), "" + assert paddle.allclose(x=lse, y=lse_ref, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("indptr", VARLEN_INDPTR_PARAMS) @@ -171,7 +172,7 @@ def test_blackwell_cutlass_fmha( @pytest.mark.parametrize("head_dim_vo", [128]) @pytest.mark.parametrize("sm_scale", [1.0 / math.sqrt(128)]) @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) def test_blackwell_cutlass_varlen( indptr, num_qo_heads, @@ -182,18 +183,17 @@ def test_blackwell_cutlass_varlen( causal, dtype, ): - if not is_sm100a_supported(torch.device("cuda")): + if not is_sm100a_supported(device2str("cuda")): pytest.skip("SM100A is not supported on this device") - torch.manual_seed(42) - qkv = torch.randn( - indptr[-1], - ( + paddle.seed(seed=42) + qkv = paddle.randn( + shape=[ + indptr[-1], num_qo_heads * head_dim_qk + num_kv_heads * head_dim_qk - + num_kv_heads * head_dim_vo - ), + + num_kv_heads * head_dim_vo, + ], dtype=dtype, - device="cuda", ) q = qkv[:, : num_qo_heads * head_dim_qk].view(indptr[-1], num_qo_heads, head_dim_qk) k = qkv[ @@ -204,15 +204,13 @@ def test_blackwell_cutlass_varlen( v = qkv[:, num_qo_heads * head_dim_qk + num_kv_heads * head_dim_qk :].view( indptr[-1], num_kv_heads, head_dim_vo ) - qo_indptr = torch.tensor(indptr, device="cuda", dtype=torch.int32) + qo_indptr = paddle.to_tensor(data=indptr, dtype="int32", place="gpu") kv_indptr = qo_indptr - wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + paddle.empty(shape=128 * 1024 * 1024, dtype="uint8"), kv_layout="NHD", backend="cutlass", ) - wrapper.plan( qo_indptr, kv_indptr, @@ -226,21 +224,17 @@ def test_blackwell_cutlass_varlen( kv_data_type=dtype, ) o, lse = wrapper.run(q, k, v, return_lse=True) - gqa_group_ratio = num_qo_heads // num_kv_heads - 
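For the GQA case the reference path expands each KV head so that every query head in a group gets its own copy before calling the MHA reference. repeat_interleave repeats consecutively, so expanded head h corresponds to original KV head h // group_ratio. A NumPy analogue of that expansion (np.repeat has the same semantics here; shapes are illustrative):

    import numpy as np

    num_qo_heads, num_kv_heads, head_dim, n = 8, 2, 4, 3
    group = num_qo_heads // num_kv_heads           # query heads per KV head
    k = np.random.randn(n, num_kv_heads, head_dim)
    k_rep = np.repeat(k, group, axis=1)            # -> (n, num_qo_heads, head_dim)
    assert k_rep.shape == (n, num_qo_heads, head_dim)
    # expanded head 5 maps back to KV head 5 // group == 1
    assert np.array_equal(k_rep[:, 5], k[:, 5 // group])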
k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1) - v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1) - + k_repeated = paddle.repeat_interleave(x=k, repeats=gqa_group_ratio, axis=1) + v_repeated = paddle.repeat_interleave(x=v, repeats=gqa_group_ratio, axis=1) o_ref, lse_ref = attention_varlen_ref( q, k_repeated, v_repeated, qo_indptr, kv_indptr, causal, sm_scale ) - - if dtype == torch.half: - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" else: - torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2) - - torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o, y=o_ref, rtol=0.01, atol=0.01).item(), "" + assert paddle.allclose(x=lse, y=lse_ref, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("qo_indptr_list", [[0, 10, 20, 30, 40, 50, 60, 100]]) @@ -250,7 +244,7 @@ def test_blackwell_cutlass_varlen( @pytest.mark.parametrize("head_dim_qk", [192, 128]) @pytest.mark.parametrize("head_dim_vo", [128]) @pytest.mark.parametrize("sm_scale", [1.0 / math.sqrt(128)]) -@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) def test_blackwell_cutlass_qo_kv_varlen( qo_indptr_list, kv_indptr_list, @@ -262,40 +256,19 @@ def test_blackwell_cutlass_qo_kv_varlen( dtype, ): causal = False - if not is_sm100a_supported(torch.device("cuda")): + if not is_sm100a_supported(device2str("cuda")): pytest.skip("SM100A is not supported on this device") - torch.manual_seed(42) - q = torch.randn( - qo_indptr_list[-1], - num_qo_heads, - head_dim_qk, - dtype=dtype, - device="cuda", - ) - k = torch.randn( - kv_indptr_list[-1], - num_kv_heads, - head_dim_qk, - dtype=dtype, - device="cuda", - ) - v = torch.randn( - kv_indptr_list[-1], - num_kv_heads, - head_dim_vo, - dtype=dtype, - device="cuda", - ) - - qo_indptr = torch.tensor(qo_indptr_list, device="cuda", dtype=torch.int32) - kv_indptr = torch.tensor(kv_indptr_list, device="cuda", dtype=torch.int32) - + paddle.seed(seed=42) + q = paddle.randn(shape=[qo_indptr_list[-1], num_qo_heads, head_dim_qk], dtype=dtype) + k = paddle.randn(shape=[kv_indptr_list[-1], num_kv_heads, head_dim_qk], dtype=dtype) + v = paddle.randn(shape=[kv_indptr_list[-1], num_kv_heads, head_dim_vo], dtype=dtype) + qo_indptr = paddle.to_tensor(data=qo_indptr_list, dtype="int32", place="gpu") + kv_indptr = paddle.to_tensor(data=kv_indptr_list, dtype="int32", place="gpu") wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + paddle.empty(shape=128 * 1024 * 1024, dtype="uint8"), kv_layout="NHD", backend="cutlass", ) - wrapper.plan( qo_indptr, kv_indptr, @@ -309,37 +282,25 @@ def test_blackwell_cutlass_qo_kv_varlen( kv_data_type=dtype, ) o, lse = wrapper.run(q, k, v, return_lse=True) - gqa_group_ratio = num_qo_heads // num_kv_heads - k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1) - v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1) - + k_repeated = paddle.repeat_interleave(x=k, repeats=gqa_group_ratio, axis=1) + v_repeated = paddle.repeat_interleave(x=v, repeats=gqa_group_ratio, axis=1) o_ref, lse_ref = attention_varlen_ref( q, k_repeated, v_repeated, qo_indptr, kv_indptr, causal, sm_scale ) - - if dtype == torch.half: - torch.testing.assert_close(o[10:60], o_ref[10:60], rtol=1e-3, atol=1e-3) + if dtype == "float16": + assert paddle.allclose( + 
x=o[10:60], y=o_ref[10:60], rtol=0.001, atol=0.001 + ).item(), "" else: - torch.testing.assert_close(o[10:60], o_ref[10:60], rtol=1e-2, atol=1e-2) - - torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) + assert paddle.allclose( + x=o[10:60], y=o_ref[10:60], rtol=0.01, atol=0.01 + ).item(), "" + assert paddle.allclose(x=lse, y=lse_ref, rtol=0.001, atol=0.001).item(), "" if __name__ == "__main__": - test_blackwell_cutlass_fmha( - 9, - 377, - 977, - 1, - 1, - 192, - 128, - 1, - False, - torch.bfloat16, - ) - + test_blackwell_cutlass_fmha(9, 377, 977, 1, 1, 192, 128, 1, False, "bfloat16") test_blackwell_cutlass_varlen( [0, 1274, 2568, 3915, 5194, 6498, 7839, 8192], 32, @@ -348,9 +309,8 @@ def test_blackwell_cutlass_qo_kv_varlen( 128, 1, True, - torch.bfloat16, + "bfloat16", ) - test_blackwell_cutlass_qo_kv_varlen( [0, 10, 20, 30, 40, 50, 60, 100], [0, 50, 50, 50, 50, 50, 50, 50], @@ -359,5 +319,5 @@ def test_blackwell_cutlass_qo_kv_varlen( 128, 128, 1, - torch.bfloat16, + "bfloat16", ) diff --git a/tests/test_block_sparse.py b/tests/test_block_sparse.py index 08cc2afda9..c97bc7ef6b 100644 --- a/tests/test_block_sparse.py +++ b/tests/test_block_sparse.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,12 +19,11 @@ See the License for the specific language governing permissions and limitations under the License. """ - import numpy as np import pytest import scipy as sp -import torch -from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules +from jit_utils import (gen_decode_attention_modules, + gen_prefill_attention_modules) import flashinfer @@ -27,49 +32,31 @@ def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + ["float16"], ["float16"], [128, 256], [0], [False], [False] ) + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions + ["float16"], ["float16"], [128, 256], [0], [False], [False], [False] ), verbose=False, ) yield -def bsr_attention_ref( - q, - k, - v, - indptr, - indices, - mask_data, -): - M = q.shape[0] - N = k.shape[0] +def bsr_attention_ref(q, k, v, indptr, indices, mask_data): + M = tuple(q.shape)[0] + N = tuple(k.shape)[0] bsr = sp.sparse.bsr_matrix( (mask_data.cpu().numpy(), indices.cpu().numpy(), indptr.cpu().numpy()), shape=(M, N), ) - dense_mask = torch.tensor(bsr.toarray(), dtype=bool, device=q.device) + dense_mask = paddle.to_tensor(data=bsr.toarray(), dtype=bool, place=q.place) o = flashinfer.prefill.single_prefill_with_kv_cache(q, k, v, custom_mask=dense_mask) return o def set_seed(seed: int = 42): - torch.cuda.manual_seed(seed) - torch.manual_seed(seed) + paddle.seed(seed=seed) + paddle.seed(seed=seed) np.random.seed(seed) @@ -86,30 +73,26 @@ def test_block_sparse_attention( ): if num_qo_heads % num_kv_heads != 0: pytest.skip("num_qo_heads must be divisible by num_kv_heads") - set_seed(33) rng = np.random.default_rng() - MB = M // R NB = N // C S = sp.sparse.random(MB, NB, density=0.25, random_state=rng).tocsr() - indptr = torch.from_numpy(S.indptr).to(0) - indices = torch.from_numpy(S.indices).to(0) + indptr = 
paddle.to_tensor(data=S.indptr).to(0) + indices = paddle.to_tensor(data=S.indices).to(0) nnz = S.nnz if mask_inside_block: - data_mask = (torch.rand((nnz, R, C)) > 0.5).to(0) + data_mask = (paddle.rand(shape=(nnz, R, C)) > 0.5).to(0) else: - data_mask = torch.full((nnz, R, C), True, dtype=bool, device=0) - q = torch.randn((M, num_qo_heads, head_dim), dtype=torch.float16, device=0) - k = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device=0) - v = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device=0) - + data_mask = paddle.full(shape=(nnz, R, C), fill_value=True, dtype=bool) + q = paddle.randn(shape=(M, num_qo_heads, head_dim), dtype="float16") + k = paddle.randn(shape=(N, num_kv_heads, head_dim), dtype="float16") + v = paddle.randn(shape=(N, num_kv_heads, head_dim), dtype="float16") o_ref = bsr_attention_ref(q, k, v, indptr, indices, data_mask) - workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=0) + workspace_buffer = paddle.zeros(shape=128 * 1024 * 1024, dtype="uint8") sparse_attention_wrapper = flashinfer.sparse.BlockSparseAttentionWrapper( workspace_buffer ) - sparse_attention_wrapper.plan( indptr, indices, @@ -122,49 +105,44 @@ def test_block_sparse_attention( head_dim, mask=data_mask if mask_inside_block else None, ) - o = sparse_attention_wrapper.run(q, k, v) - torch.testing.assert_close(o_ref, o, atol=1e-2, rtol=1e-3) - - # test with pre-allocated output - o_buffer = torch.empty_like(o) + assert paddle.allclose(x=o_ref, y=o, atol=0.01, rtol=0.001).item(), "" + o_buffer = paddle.empty_like(x=o) sparse_attention_wrapper.run(q, k, v, out=o_buffer) - torch.testing.assert_close(o_ref, o_buffer, atol=1e-2, rtol=1e-3) + assert paddle.allclose(x=o_ref, y=o_buffer, atol=0.01, rtol=0.001).item(), "" def _ref_attention( - q: torch.Tensor, # [gqa_group_size, qo_len, head_dim] - k: torch.Tensor, # [1, kv_len, head_dim] - v: torch.Tensor, # [1, kv_len, head_dim] - block_mask_map: torch.Tensor, # [MB, NB] - block_row_sz: torch.Tensor, # [MB] - block_col_sz: torch.Tensor, # [NB] -) -> torch.Tensor: - # convert block mask map to element mask + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + block_mask_map: paddle.Tensor, + block_row_sz: paddle.Tensor, + block_col_sz: paddle.Tensor, +) -> paddle.Tensor: def _block_mask_to_element_mask( - block_mask_map: torch.Tensor, # [MB, NB] – bool - block_row_sz: torch.Tensor, # [MB] – int (rows per block-row) - block_col_sz: torch.Tensor, # [NB] – int (cols per block-col) - ) -> torch.Tensor: - block_row_sz = block_row_sz.to(block_mask_map.device, dtype=torch.long) - block_col_sz = block_col_sz.to(block_mask_map.device, dtype=torch.long) - expanded_rows = torch.repeat_interleave(block_mask_map, block_row_sz, dim=0) - element_mask = torch.repeat_interleave(expanded_rows, block_col_sz, dim=1) - + block_mask_map: paddle.Tensor, + block_row_sz: paddle.Tensor, + block_col_sz: paddle.Tensor, + ) -> paddle.Tensor: + block_row_sz = block_row_sz.to(block_mask_map.place, dtype="int64") + block_col_sz = block_col_sz.to(block_mask_map.place, dtype="int64") + expanded_rows = paddle.repeat_interleave( + x=block_mask_map, repeats=block_row_sz, axis=0 + ) + element_mask = paddle.repeat_interleave( + x=expanded_rows, repeats=block_col_sz, axis=1 + ) return element_mask dense_mask = _block_mask_to_element_mask( block_mask_map, block_row_sz, block_col_sz - ).to(dtype=torch.bool, device=q.device) - - q = q.transpose(0, 1).contiguous() - k = k.transpose(0, 1).contiguous() - v = v.transpose(0, 1).contiguous() - 
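_block_mask_to_element_mask expands a block-level mask to a per-token mask by repeating each block row by its row size and each block column by its column size. The same expansion in NumPy (np.repeat with a per-index repeat count mirrors the repeat_interleave calls; the block sizes here are made up):

    import numpy as np

    block_mask = np.array([[True, False],
                           [False, True]])        # (MB, NB) block-level mask
    row_sz = np.array([2, 3])                     # rows in each block-row
    col_sz = np.array([1, 2])                     # cols in each block-col
    elem = np.repeat(np.repeat(block_mask, row_sz, axis=0), col_sz, axis=1)
    assert elem.shape == (row_sz.sum(), col_sz.sum())    # (5, 3)
    print(elem.astype(int))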
o = flashinfer.prefill.single_prefill_with_kv_cache( - q, k, v, custom_mask=dense_mask - ) # [qo_len, gqa_group_size, head_dim] - o = o.transpose(0, 1).contiguous() - + ).to(dtype="bool", device=q.place) + q = q.transpose(perm=dim2perm(q.ndim, 0, 1)).contiguous() + k = k.transpose(perm=dim2perm(k.ndim, 0, 1)).contiguous() + v = v.transpose(perm=dim2perm(v.ndim, 0, 1)).contiguous() + o = flashinfer.prefill.single_prefill_with_kv_cache(q, k, v, custom_mask=dense_mask) + o = o.transpose(perm=dim2perm(o.ndim, 0, 1)).contiguous() return o @@ -190,63 +168,56 @@ def test_variable_block_sparse_attention_wrapper( pytest.skip("seq_len must be greater than num_blocks_row") if seq_len // num_blocks_col < 1: pytest.skip("seq_len must be greater than num_blocks_col") - set_seed(330) def random_partition_batch( seq_len: int, num_blocks: int, bsz: int, - device: torch.device | str = "cpu", - dtype: torch.dtype = torch.int32, - ) -> torch.Tensor: + device: (str | str) = "cpu", + dtype: paddle.dtype = "int32", + ) -> paddle.Tensor: assert seq_len >= num_blocks - sizes = torch.empty((bsz, num_blocks), dtype=dtype, device=device) + sizes = paddle.empty(shape=(bsz, num_blocks), dtype=dtype) for i in range(bsz): - cut_pts = torch.randperm(seq_len - 1, device=device)[: num_blocks - 1] + 1 - cut_pts, _ = torch.sort(cut_pts) - row_sizes = torch.diff( - torch.cat( - ( - torch.tensor([0], device=device), + cut_pts = paddle.randperm(n=seq_len - 1)[: num_blocks - 1] + 1 + cut_pts, _ = paddle.sort(x=cut_pts), paddle.argsort(x=cut_pts) + row_sizes = paddle.diff( + x=paddle.concat( + x=( + paddle.to_tensor(data=[0], place=device), cut_pts, - torch.tensor([seq_len], device=device), + paddle.to_tensor(data=[seq_len], place=device), ) ) ) sizes[i] = row_sizes - - assert sizes.min() >= 1 - assert sizes.max() <= seq_len - assert torch.all(sizes.sum(dim=-1) == seq_len) - + assert sizes._min() >= 1 + assert sizes._max() <= seq_len + assert paddle.all(x=sizes.sum(axis=-1) == seq_len) return sizes.to(device=device) def _test_variable_block_sparse_attention( num_qo_heads: int, num_kv_heads: int, head_dim: int, - block_mask_map: torch.Tensor, - block_row_sz: torch.Tensor, - block_col_sz: torch.Tensor, + block_mask_map: paddle.Tensor, + block_row_sz: paddle.Tensor, + block_col_sz: paddle.Tensor, device: str = "cuda:0", - dtype: torch.dtype = torch.float16, + dtype: paddle.dtype = "float16", ): - # qkv: HND - qo_len = block_row_sz.sum(dim=1)[0].item() - kv_len = block_col_sz.sum(dim=1)[0].item() - assert torch.all(block_col_sz.sum(dim=1) == block_col_sz.sum(dim=1)[0]) - assert torch.all(block_row_sz.sum(dim=1) == block_row_sz.sum(dim=1)[0]) - - q = torch.randn(num_qo_heads, qo_len, head_dim, device=device, dtype=dtype) - k = torch.randn(num_kv_heads, kv_len, head_dim, device=device, dtype=dtype) - v = torch.randn(num_kv_heads, kv_len, head_dim, device=device, dtype=dtype) - - float_workspace_buffer = torch.empty(128 * 1024 * 1024, device=device) + qo_len = block_row_sz.sum(axis=1)[0].item() + kv_len = block_col_sz.sum(axis=1)[0].item() + assert paddle.all(x=block_col_sz.sum(axis=1) == block_col_sz.sum(axis=1)[0]) + assert paddle.all(x=block_row_sz.sum(axis=1) == block_row_sz.sum(axis=1)[0]) + q = paddle.randn(shape=[num_qo_heads, qo_len, head_dim], dtype=dtype) + k = paddle.randn(shape=[num_kv_heads, kv_len, head_dim], dtype=dtype) + v = paddle.randn(shape=[num_kv_heads, kv_len, head_dim], dtype=dtype) + float_workspace_buffer = paddle.empty(shape=128 * 1024 * 1024) wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper( 
float_workspace_buffer, backend="auto" ) - wrapper.plan( block_mask_map=block_mask_map, block_row_sz=block_row_sz, @@ -256,10 +227,9 @@ def _test_variable_block_sparse_attention( head_dim=head_dim, q_data_type=dtype, ) - - o: torch.Tensor = wrapper.run(q, k, v) # [num_qo_heads, qo_len, head_dim] - o = o.reshape(num_kv_heads, -1, *o.shape[-2:]) - q = q.reshape(num_kv_heads, -1, *q.shape[-2:]) + o: paddle.Tensor = wrapper.run(q, k, v) + o = o.reshape(num_kv_heads, -1, *tuple(o.shape)[-2:]) + q = q.reshape(num_kv_heads, -1, *tuple(q.shape)[-2:]) for kv_head_idx in range(num_kv_heads): o_ref = _ref_attention( q[kv_head_idx], @@ -269,7 +239,9 @@ def _test_variable_block_sparse_attention( block_row_sz[kv_head_idx], block_col_sz[kv_head_idx], ) - torch.testing.assert_close(o[kv_head_idx], o_ref, atol=1e-2, rtol=1e-2) + assert paddle.allclose( + x=o[kv_head_idx], y=o_ref, atol=0.01, rtol=0.01 + ).item(), "" block_row_sz = random_partition_batch( seq_len, num_blocks_row, num_kv_heads, device="cuda:0" @@ -278,20 +250,14 @@ def _test_variable_block_sparse_attention( seq_len, num_blocks_col, num_kv_heads, device="cuda:0" ) block_mask_map = ( - torch.rand(num_kv_heads, num_blocks_row, num_blocks_col) > block_density - ).to(device="cuda:0") - + paddle.rand(shape=[num_kv_heads, num_blocks_row, num_blocks_col]) + > block_density + ).to(device="gpu:0") _test_variable_block_sparse_attention( - num_qo_heads, - num_kv_heads, - head_dim, - block_mask_map, - block_row_sz, - block_col_sz, + num_qo_heads, num_kv_heads, head_dim, block_mask_map, block_row_sz, block_col_sz ) if __name__ == "__main__": - # This test verifies the INT32_T overflow issue. for seq_len in [16 * 1024, 32 * 1024, 40 * 1024, 48 * 1024, 64 * 1024]: test_block_sparse_attention(128, 128, seq_len, seq_len, 1, 1, 128, False) diff --git a/tests/test_block_sparse_indices_to_vector_sparse_offsets.py b/tests/test_block_sparse_indices_to_vector_sparse_offsets.py index cf2ef003cc..72cd8913cd 100644 --- a/tests/test_block_sparse_indices_to_vector_sparse_offsets.py +++ b/tests/test_block_sparse_indices_to_vector_sparse_offsets.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2023 by FlashInfer team. @@ -13,9 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. 
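random_partition_batch in the test above splits seq_len into num_blocks positive parts by drawing num_blocks - 1 distinct cut points in (0, seq_len), sorting them, and taking successive differences; distinct sorted cut points guarantee every part is at least 1 and the parts sum to seq_len. A NumPy sketch of one row of that partition (names are illustrative, standalone from the paddle version):

    import numpy as np

    def random_partition(seq_len, num_blocks, rng):
        # num_blocks - 1 distinct cut points strictly inside (0, seq_len)
        cuts = np.sort(rng.permutation(seq_len - 1)[: num_blocks - 1] + 1)
        sizes = np.diff(np.concatenate(([0], cuts, [seq_len])))
        assert sizes.min() >= 1 and sizes.sum() == seq_len
        return sizes

    print(random_partition(16, 4, np.random.default_rng(0)))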
""" - import pytest -import torch import flashinfer.page @@ -31,25 +31,20 @@ def test_block_sparse_indices_to_vector_sparse_offsets( if batch_size * kv_len > 1048576: pytest.skip("skip large test") num_blocks_per_row = (kv_len + block_size - 1) // block_size - - block_sparse_indices = torch.arange( - batch_size * num_blocks_per_row, device="cuda", dtype=torch.int32 - ) - block_sparse_indptr = torch.arange( - 0, - batch_size * num_blocks_per_row + 1, - num_blocks_per_row, - device="cuda", - dtype=torch.int32, + block_sparse_indices = paddle.arange( + dtype="int32", end=batch_size * num_blocks_per_row ) - vector_sparse_offsets_buf = torch.zeros( - batch_size * kv_len, device="cuda", dtype=torch.int32 + block_sparse_indptr = paddle.arange( + start=0, + end=batch_size * num_blocks_per_row + 1, + step=num_blocks_per_row, + dtype="int32", ) - vector_sparse_indptr = torch.arange( - 0, batch_size * kv_len + 1, kv_len, device="cuda", dtype=torch.int32 + vector_sparse_offsets_buf = paddle.zeros(shape=batch_size * kv_len, dtype="int32") + vector_sparse_indptr = paddle.arange( + start=0, end=batch_size * kv_len + 1, step=kv_len, dtype="int32" ) - kv_lens = torch.full((batch_size,), kv_len, device="cuda", dtype=torch.int32) - + kv_lens = paddle.full(shape=(batch_size,), fill_value=kv_len, dtype="int32") vector_sparse_offsets = ( flashinfer.page.block_sparse_indices_to_vector_sparse_offsets( block_sparse_indices, @@ -62,8 +57,6 @@ def test_block_sparse_indices_to_vector_sparse_offsets( block_size, ) ) - - # Check that the output is correct for i in range(batch_size): indices_i = block_sparse_indices[ i * num_blocks_per_row : (i + 1) * num_blocks_per_row @@ -71,13 +64,12 @@ def test_block_sparse_indices_to_vector_sparse_offsets( output_i = vector_sparse_offsets[ vector_sparse_indptr[i] : vector_sparse_indptr[i + 1] ].cpu() - output_ref_i = ( - indices_i[torch.arange(0, kv_len, dtype=torch.int32) // block_size] + indices_i[paddle.arange(start=0, end=kv_len, dtype="int32") // block_size] * stride_block - + (torch.arange(0, kv_len, dtype=torch.int32) % block_size) * stride_n + + paddle.arange(start=0, end=kv_len, dtype="int32") % block_size * stride_n ) - torch.testing.assert_close(output_i, output_ref_i) + assert paddle.allclose(x=output_i, y=output_ref_i).item(), "" if __name__ == "__main__": diff --git a/tests/test_bmm_fp8.py b/tests/test_bmm_fp8.py index 35d45150bb..58e55fadfd 100644 --- a/tests/test_bmm_fp8.py +++ b/tests/test_bmm_fp8.py @@ -1,47 +1,52 @@ +import sys + + +import paddle import pytest -import torch -import torch.nn.functional as F +from flashinfer.paddle_utils import * from flashinfer import autotune, bmm_fp8 -def to_float8(x, dtype=torch.float8_e4m3fn): - finfo = torch.finfo(dtype) - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) +def to_float8(x, dtype=paddle.float8_e4m3fn): + finfo = paddle.finfo(dtype=dtype) + min_val, max_val = tuple( + [ + paddle.amin(x, axis=None, keepdim=False), + paddle.max(x, axis=None, keepdim=False), + ] + ) + amax = paddle.maximum(x=min_val.abs(), y=max_val.abs()).clip(min=1e-12) scale = finfo.max / amax - x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) - return x_scl_sat.to(dtype), scale.float().reciprocal() + x_scl_sat = (x * scale).clip(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.astype(dtype="float32").reciprocal() @pytest.mark.parametrize("b", [1, 16]) @pytest.mark.parametrize("m", [48, 128]) @pytest.mark.parametrize("n", [80, 64]) @pytest.mark.parametrize("k", [64, 
256]) -@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) -@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) -@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("input_dtype", [paddle.float8_e4m3fn, paddle.float8_e5m2]) +@pytest.mark.parametrize("mat2_dtype", [paddle.float8_e4m3fn, paddle.float8_e5m2]) +@pytest.mark.parametrize("res_dtype", ["bfloat16", "float16"]) @pytest.mark.parametrize("backend", ["cudnn", "cublas", "cutlass", "auto"]) @pytest.mark.parametrize("auto_tuning", [True, False]) def test_bmm_fp8(b, m, n, k, input_dtype, mat2_dtype, res_dtype, backend, auto_tuning): - if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2: +>>>>>> if input_dtype == paddle.float8_e5m2 and mat2_dtype == paddle.float8_e5m2: pytest.skip("Invalid combination: both input and mat2 are e5m2") - if input_dtype == torch.float8_e5m2 or mat2_dtype == torch.float8_e5m2: +>>>>>> if input_dtype == paddle.float8_e5m2 or mat2_dtype == paddle.float8_e5m2: if backend == "cutlass": pytest.skip("Invalid combination: cutlass does not support e5m2") if auto_tuning and backend != "cutlass": pytest.skip("Invalid combination: auto_tuning only supported for cutlass") - - input = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16) + input = paddle.randn(shape=[b, m, k], dtype="bfloat16") input_fp8, input_inv_s = to_float8(input, dtype=input_dtype) - - # mat2 row major -> column major - mat2 = torch.randn([b, n, k], device="cuda", dtype=torch.bfloat16).transpose(-2, -1) + mat2 = paddle.randn(shape=[b, n, k], dtype="bfloat16").transpose( + perm=dim2perm(paddle.randn(shape=[b, n, k], dtype="bfloat16").ndim, -2, -1) + ) mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype) - reference = torch.bmm(input, mat2) - - res = torch.empty([b, m, n], device="cuda", dtype=res_dtype) - + reference = paddle.bmm(x=input, y=mat2) + res = paddle.empty(shape=[b, m, n], dtype=res_dtype) with autotune(auto_tuning): bmm_fp8( input_fp8, @@ -52,8 +57,9 @@ def test_bmm_fp8(b, m, n, k, input_dtype, mat2_dtype, res_dtype, backend, auto_t res, backend=backend, ) - - cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) + cos_sim = paddle.nn.functional.cosine_similarity( + x1=reference.reshape(-1), x2=res.reshape(-1), axis=0 + ) assert cos_sim > 0.99 diff --git a/tests/test_create_ipc_buffer.py b/tests/test_create_ipc_buffer.py index bdf4cf2419..ca58de6df6 100644 --- a/tests/test_create_ipc_buffer.py +++ b/tests/test_create_ipc_buffer.py @@ -1,45 +1,33 @@ -# adapted from vllm - import ctypes import os import subprocess import sys +import paddle import pytest -import torch -import torch.distributed as dist import flashinfer.comm as comm from flashinfer.comm import CudaRTLibrary def _run_ipc_test(): - # Only run if we're inside a distributed context - if not dist.is_initialized(): - dist.init_process_group(backend="nccl") - - rank = dist.get_rank() - world_size = dist.get_world_size() - + if not paddle.distributed.is_initialized(): + paddle.distributed.init_parallel_env() + rank = paddle.distributed.get_rank() + world_size = paddle.distributed.get_world_size() cudart = CudaRTLibrary() cudart.cudaSetDevice(rank) - buffer_size_in_bytes = 1024 byte_value = rank - pointers = comm.create_shared_buffer(buffer_size_in_bytes) print(f"Rank {rank} init ipc buffer {pointers}", flush=True) - - dist.barrier() - torch.cuda.synchronize() - + paddle.distributed.barrier() + paddle.device.synchronize() for p in 
pointers: pointer = ctypes.c_void_p(p + rank * (buffer_size_in_bytes // world_size)) cudart.cudaMemset(pointer, byte_value, buffer_size_in_bytes // world_size) - - dist.barrier() - torch.cuda.synchronize() - + paddle.distributed.barrier() + paddle.device.synchronize() host_data = (ctypes.c_char * buffer_size_in_bytes)() for p in pointers: for cur_rank in range(world_size): @@ -50,24 +38,17 @@ def _run_ipc_test(): host_data, offset_pointer, buffer_size_in_bytes // world_size ) for i in range(buffer_size_in_bytes // world_size): - assert ord(host_data[i]) == cur_rank, ( - f"Rank {rank} failed to verify buffer {p}. " - f"Expected {cur_rank}, got {ord(host_data[i])}" - ) - + assert ( + ord(host_data[i]) == cur_rank + ), f"Rank {rank} failed to verify buffer {p}. Expected {cur_rank}, got {ord(host_data[i])}" print(f"Rank {rank} verified all buffers.\n", flush=True) - - dist.barrier() - torch.cuda.synchronize() + paddle.distributed.barrier() + paddle.device.synchronize() comm.free_shared_buffer(pointers) -# ------------------------------- -# Pytest Entrypoint (main test) -# ------------------------------- @pytest.mark.parametrize("world_size", [2, 4]) def test_ipc_distributed(world_size): - # Spawn self with torchrun script = os.path.abspath(__file__) result = subprocess.run( ["torchrun", f"--nproc_per_node={world_size}", script, "--run_ipc_test"], @@ -77,9 +58,6 @@ def test_ipc_distributed(world_size): assert result.returncode == 0 -# ------------------------------- -# Actual Test Logic (called by subprocess) -# ------------------------------- if __name__ == "__main__": if "--run_ipc_test" in sys.argv: _run_ipc_test() diff --git a/tests/test_cudnn_decode.py b/tests/test_cudnn_decode.py index ddf8798cec..1582ea586f 100644 --- a/tests/test_cudnn_decode.py +++ b/tests/test_cudnn_decode.py @@ -1,7 +1,7 @@ import math +import paddle import pytest -import torch import flashinfer @@ -13,36 +13,21 @@ @pytest.mark.parametrize("num_qo_heads", [32]) @pytest.mark.parametrize("is_cuda_graph_compatible", [True, False]) def test_cudnn_decode( - batch_size, - s_kv, - page_size, - num_kv_heads, - num_qo_heads, - is_cuda_graph_compatible, + batch_size, s_kv, page_size, num_kv_heads, num_qo_heads, is_cuda_graph_compatible ): - # test set up basics seed = 0 - torch.manual_seed(seed) + paddle.seed(seed=seed) device = "cuda:0" - s_qo = 1 head_dim = 128 - - # Initialize Q tensor - # Since the number of tokens is 1, batch size is the token count - q = torch.randn( - batch_size, num_qo_heads, head_dim, device=device, dtype=torch.bfloat16 - ) - - # Initialize KV Cache + q = paddle.randn(shape=[batch_size, num_qo_heads, head_dim], dtype="bfloat16") num_pages_per_seq = (s_kv + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size - - kv_cache_shape = (total_num_pages, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, dtype=torch.bfloat16).to(device) + kv_cache_shape = total_num_pages, 2, num_kv_heads, page_size, head_dim + kv_cache = paddle.randn(shape=kv_cache_shape, dtype="bfloat16").to(device) kv_cache = kv_cache.as_strided( - kv_cache.shape, - ( + shape=tuple(kv_cache.shape), + stride=( 2 * page_size * num_kv_heads * head_dim, page_size * num_kv_heads * head_dim, head_dim, @@ -52,38 +37,37 @@ def test_cudnn_decode( ) k_cache_view = kv_cache[:, 0, :, :, :] v_cache_view = kv_cache[:, 1, :, :, :] - v_cache = v_cache_view.as_strided( - v_cache_view.shape, - (2 * page_size * num_kv_heads * head_dim, head_dim, num_kv_heads * head_dim, 1), + 
shape=tuple(v_cache_view.shape), + stride=( + 2 * page_size * num_kv_heads * head_dim, + head_dim, + num_kv_heads * head_dim, + 1, + ), ) k_cache = k_cache_view.as_strided( - k_cache_view.shape, - (2 * page_size * num_kv_heads * head_dim, head_dim, num_kv_heads * head_dim, 1), + shape=tuple(k_cache_view.shape), + stride=( + 2 * page_size * num_kv_heads * head_dim, + head_dim, + num_kv_heads * head_dim, + 1, + ), ) - - # Now initialize the page tables - block_tables = torch.tensor( - [ - [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + block_tables = paddle.to_tensor( + data=[ + [(k + i * num_pages_per_seq) for k in range(num_pages_per_seq)] for i in range(batch_size) ], - dtype=torch.int, - device=device, + dtype="int32", + place=device, ) - - # Initialize scale - scale = float(1.0 / (head_dim**0.5)) - - # Actual sequence lengths (should be randomized across batches. ) - actual_seq_lens_kv = torch.randint( - 0, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device + scale = float(1.0 / head_dim**0.5) + actual_seq_lens_kv = paddle.randint( + low=0, high=s_kv + 1, shape=(batch_size, 1, 1, 1), dtype="int32" ) - - ragged_q = torch.arange(0, batch_size + 1, device=device) * ( - num_qo_heads * head_dim - ) - + ragged_q = paddle.arange(start=0, end=batch_size + 1) * (num_qo_heads * head_dim) workspace_buffer_size = math.ceil( ( batch_size * s_qo * num_qo_heads * head_dim * 4 @@ -91,13 +75,8 @@ def test_cudnn_decode( ) / (1024 * 1024) ) * (1024 * 1024) - workspace_buffer_size = max(workspace_buffer_size, 128 * 1024 * 1024) - - workspace_buffer = torch.empty( - workspace_buffer_size, dtype=torch.int8, device=device - ) - + workspace_buffer = paddle.empty(shape=workspace_buffer_size, dtype="int8") output = flashinfer.decode.cudnn_batch_decode_with_kv_cache( q, k_cache, @@ -111,50 +90,39 @@ def test_cudnn_decode( batch_offsets_q=ragged_q, batch_offsets_o=ragged_q, ) - actual_seq_lens_kv_device = actual_seq_lens_kv.to(device) - kv_indptr = ( - torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum( - (actual_seq_lens_kv_device.flatten() + page_size - 1) // page_size, - dim=0, + paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum( + x=(actual_seq_lens_kv_device.flatten() + page_size - 1) + // page_size, + axis=0, ), ] ) - .int() + .astype(dtype="int32") .to(device) ) - - # kv_indices - kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32) + kv_indices = paddle.zeros(shape=kv_indptr[-1], dtype="int32") for i in range(len(kv_indptr) - 1): start_idx = kv_indptr[i] end_idx = kv_indptr[i + 1] - kv_indices[start_idx:end_idx] = torch.arange( - i * num_pages_per_seq, - i * num_pages_per_seq + (end_idx - start_idx), - device=device, + kv_indices[start_idx:end_idx] = paddle.arange( + start=i * num_pages_per_seq, + end=i * num_pages_per_seq + (end_idx - start_idx), ) - - # kv_last_page_len kv_last_page_len = ( - torch.where( - actual_seq_lens_kv_device.flatten() % page_size == 0, - torch.full((batch_size,), page_size, device=device), - actual_seq_lens_kv_device.flatten() % page_size, + paddle.where( + condition=actual_seq_lens_kv_device.flatten() % page_size == 0, + x=paddle.full(shape=(batch_size,), fill_value=page_size), + y=actual_seq_lens_kv_device.flatten() % page_size, ) - .int() + .astype(dtype="int32") .to(device) ) - - # Workspace buffer - workspace_buffer_ref = torch.empty( - 128 * 1024 * 1024, dtype=torch.int8, device=device - ) - + workspace_buffer_ref = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper = 
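# The decode test above derives its paged-KV metadata from per-sequence lengths:
# kv_indptr is the exclusive prefix sum of pages per sequence, kv_indices lists the
# page ids of each sequence back to back (the test places sequence i's pages at
# offset i * num_pages_per_seq; this sketch simply numbers them contiguously), and
# kv_last_page_len is how many entries the final page of each sequence holds.
# paged_kv_metadata below is a hypothetical helper, not a FlashInfer API.
import numpy as np

def paged_kv_metadata(seq_lens, page_size):
    seq_lens = np.asarray(seq_lens)
    pages_per_seq = (seq_lens + page_size - 1) // page_size
    kv_indptr = np.concatenate([[0], np.cumsum(pages_per_seq)]).astype(np.int32)
    kv_indices = np.arange(kv_indptr[-1], dtype=np.int32)   # contiguous page ids
    last_len = seq_lens % page_size
    kv_last_page_len = np.where(last_len == 0, page_size, last_len).astype(np.int32)
    return kv_indptr, kv_indices, kv_last_page_len

indptr, indices, last_len = paged_kv_metadata([5, 16, 17], page_size=16)
# -> indptr [0, 1, 2, 4], indices [0, 1, 2, 3], last page lengths [5, 16, 1]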
flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer_ref, "HND") wrapper.plan( kv_indptr, @@ -164,9 +132,7 @@ def test_cudnn_decode( num_kv_heads, head_dim, page_size, - q_data_type=torch.bfloat16, + q_data_type="bfloat16", ) - output_ref = wrapper.run(q, kv_cache) - - torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2) + assert paddle.allclose(x=output, y=output_ref, rtol=0.01, atol=0.01).item(), "" diff --git a/tests/test_cudnn_prefill.py b/tests/test_cudnn_prefill.py index a4db634098..d8ba8a6ce2 100644 --- a/tests/test_cudnn_prefill.py +++ b/tests/test_cudnn_prefill.py @@ -1,5 +1,9 @@ +import sys + + +import paddle import pytest -import torch +from flashinfer.paddle_utils import * import flashinfer @@ -27,40 +31,32 @@ def test_cudnn_prefill( head_dim = 128 if s_qo > s_kv: pytest.skip("s_qo > s_kv, skipping test") - - # test set up basics seed = 1 - torch.manual_seed(seed) + paddle.seed(seed=seed) device = "cuda:0" - - actual_seq_lens_q = torch.randint( - 1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device + actual_seq_lens_q = paddle.randint( + low=1, high=s_qo + 1, shape=(batch_size, 1, 1, 1), dtype="int32" ) - actual_seq_lens_kv = torch.randint( - s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device + actual_seq_lens_kv = paddle.randint( + low=s_qo, high=s_kv + 1, shape=(batch_size, 1, 1, 1), dtype="int32" ) - - cumsum_s_qo = torch.sum(actual_seq_lens_q) - q = torch.randn( - cumsum_s_qo, num_qo_heads, head_dim, device=device, dtype=torch.bfloat16 - ) - - q_indptr = torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_q.view(-1), dim=0) * head_dim * num_qo_heads, + cumsum_s_qo = paddle.sum(x=actual_seq_lens_q) + q = paddle.randn(shape=[cumsum_s_qo, num_qo_heads, head_dim], dtype="bfloat16") + q_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_q.view(-1), axis=0) + * head_dim + * num_qo_heads, ] - ).int() - - # Initialize KV Cache + ).astype(dtype="int32") num_pages_per_seq = (s_kv + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size - - kv_cache_shape = (total_num_pages, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, dtype=torch.bfloat16).to(device) + kv_cache_shape = total_num_pages, 2, num_kv_heads, page_size, head_dim + kv_cache = paddle.randn(shape=kv_cache_shape, dtype="bfloat16").to(device) kv_cache = kv_cache.as_strided( - kv_cache.shape, - ( + shape=tuple(kv_cache.shape), + stride=( 2 * page_size * num_kv_heads * head_dim, page_size * num_kv_heads * head_dim, head_dim, @@ -70,59 +66,55 @@ def test_cudnn_prefill( ) k_cache_view = kv_cache[:, 0, :, :, :] v_cache_view = kv_cache[:, 1, :, :, :] - v_cache = v_cache_view.as_strided( - v_cache_view.shape, - (2 * page_size * num_kv_heads * head_dim, head_dim, num_kv_heads * head_dim, 1), + shape=tuple(v_cache_view.shape), + stride=( + 2 * page_size * num_kv_heads * head_dim, + head_dim, + num_kv_heads * head_dim, + 1, + ), ) k_cache = k_cache_view.as_strided( - k_cache_view.shape, - (2 * page_size * num_kv_heads * head_dim, head_dim, num_kv_heads * head_dim, 1), + shape=tuple(k_cache_view.shape), + stride=( + 2 * page_size * num_kv_heads * head_dim, + head_dim, + num_kv_heads * head_dim, + 1, + ), ) - - kv_indptr = torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum( - (actual_seq_lens_kv.flatten() + page_size - 1) // page_size, - dim=0, + kv_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], 
place=device), + paddle.cumsum( + x=(actual_seq_lens_kv.flatten() + page_size - 1) // page_size, axis=0 ), ] - ).int() - - # kv_indices - kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32) + ).astype(dtype="int32") + kv_indices = paddle.zeros(shape=kv_indptr[-1], dtype="int32") for i in range(len(kv_indptr) - 1): start_idx = kv_indptr[i] end_idx = kv_indptr[i + 1] - kv_indices[start_idx:end_idx] = torch.arange( - i * num_pages_per_seq, - i * num_pages_per_seq + (end_idx - start_idx), - device=device, + kv_indices[start_idx:end_idx] = paddle.arange( + start=i * num_pages_per_seq, + end=i * num_pages_per_seq + (end_idx - start_idx), ) - - # kv_last_page_len - kv_last_page_len = torch.where( - actual_seq_lens_kv.flatten() % page_size == 0, - torch.full((batch_size,), page_size, device=device), - actual_seq_lens_kv.flatten() % page_size, - ).int() - - # Now initialize the page tables - block_tables = torch.tensor( - [ - [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + kv_last_page_len = paddle.where( + condition=actual_seq_lens_kv.flatten() % page_size == 0, + x=paddle.full(shape=(batch_size,), fill_value=page_size), + y=actual_seq_lens_kv.flatten() % page_size, + ).astype(dtype="int32") + block_tables = paddle.to_tensor( + data=[ + [(k + i * num_pages_per_seq) for k in range(num_pages_per_seq)] for i in range(batch_size) ], - dtype=torch.int, - device=device, + dtype="int32", + place=device, ) - - # Initialize scale - scale = float(1.0 / (head_dim**0.5)) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - + scale = float(1.0 / head_dim**0.5) + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper_cudnn = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, "NHD", backend="cudnn" ) @@ -137,7 +129,7 @@ def test_cudnn_prefill( page_size, pos_encoding_mode="NONE", causal=causal, - q_data_type=torch.bfloat16, + q_data_type="bfloat16", seq_lens=actual_seq_lens_kv, seq_lens_q=actual_seq_lens_q, sm_scale=scale, @@ -145,21 +137,14 @@ def test_cudnn_prefill( max_sequence_kv=s_kv, block_tables=block_tables, ) - output = wrapper_cudnn.run(q, (k_cache, v_cache)) - - qo_indptr = torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_q.view(-1), dim=0), + qo_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_q.view(-1), axis=0), ] - ).int() - - # Workspace buffer - workspace_buffer_ref = torch.empty( - 128 * 1024 * 1024, dtype=torch.int8, device=device - ) - + ).astype(dtype="int32") + workspace_buffer_ref = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer_ref, "HND" ) @@ -174,9 +159,7 @@ def test_cudnn_prefill( page_size, pos_encoding_mode="NONE", causal=causal, - q_data_type=torch.bfloat16, + q_data_type="bfloat16", ) - output_ref = wrapper.run(q, kv_cache) - - torch.testing.assert_close(output, output_ref, atol=2e-3, rtol=1e-2) + assert paddle.allclose(x=output, y=output_ref, atol=0.002, rtol=0.01).item(), "" diff --git a/tests/test_cudnn_prefill_deepseek.py b/tests/test_cudnn_prefill_deepseek.py index 8362934ece..e4056f06d1 100644 --- a/tests/test_cudnn_prefill_deepseek.py +++ b/tests/test_cudnn_prefill_deepseek.py @@ -1,5 +1,9 @@ +import sys + + +import paddle import pytest -import torch +from flashinfer.paddle_utils import * import flashinfer @@ -15,97 +19,66 @@ def test_cudnn_prefill_deepseek( ): if s_qo > s_kv: pytest.skip("s_qo 
> s_kv, skipping test as causal") - head_dim_qk = 192 head_dim_vo = 128 - return_lse = True - - # test set up basics seed = 0 - torch.manual_seed(seed) + paddle.seed(seed=seed) device = "cuda:0" - - actual_seq_lens_q = torch.randint( - 1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device - ) - - actual_seq_lens_kv = torch.randint( - s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device + actual_seq_lens_q = paddle.randint( + low=1, high=s_qo + 1, shape=(batch_size, 1, 1, 1), dtype="int32" ) - - cumsum_s_qo = torch.sum(actual_seq_lens_q) - - q = torch.randn( - cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16 + actual_seq_lens_kv = paddle.randint( + low=s_qo, high=s_kv + 1, shape=(batch_size, 1, 1, 1), dtype="int32" ) - - q_indptr = torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_q.view(-1), dim=0) + cumsum_s_qo = paddle.sum(x=actual_seq_lens_q) + q = paddle.randn(shape=[cumsum_s_qo, num_qo_heads, head_dim_qk], dtype="bfloat16") + q_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_q.view(-1), axis=0) * head_dim_qk * num_qo_heads, ] - ).int() - - k_indptr = torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_kv.view(-1), dim=0) + ).astype(dtype="int32") + k_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_kv.view(-1), axis=0) * head_dim_qk * num_kv_heads, ] - ).int() - - v_indptr = torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_kv.view(-1), dim=0) + ).astype(dtype="int32") + v_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_kv.view(-1), axis=0) * head_dim_vo * num_kv_heads, ] - ).int() - - o_indptr = torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_q.view(-1), dim=0) + ).astype(dtype="int32") + o_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_q.view(-1), axis=0) * head_dim_vo * num_qo_heads, ] - ).int() - - batch_offsets_stats = torch.cat( - [ - torch.zeros( - 1, device=actual_seq_lens_q.device, dtype=actual_seq_lens_q.dtype - ), - torch.cumsum(actual_seq_lens_q.flatten(), dim=0) * num_qo_heads, + ).astype(dtype="int32") + batch_offsets_stats = paddle.concat( + x=[ + paddle.zeros(shape=[1], dtype=actual_seq_lens_q.dtype), + paddle.cumsum(x=actual_seq_lens_q.flatten(), axis=0) * num_qo_heads, ] ).cuda() - - k_cache = torch.randn( - batch_size * s_kv, - num_kv_heads, - head_dim_qk, - device=device, - dtype=torch.bfloat16, + k_cache = paddle.randn( + shape=[batch_size * s_kv, num_kv_heads, head_dim_qk], dtype="bfloat16" ) - v_cache = torch.randn( - batch_size * s_kv, - num_kv_heads, - head_dim_vo, - device=device, - dtype=torch.bfloat16, + v_cache = paddle.randn( + shape=[batch_size * s_kv, num_kv_heads, head_dim_vo], dtype="bfloat16" ) - - # Initialize scale - scale = float(1.0 / (head_dim_qk**0.5)) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - - # output = torch.zeros_like(q) + scale = float(1.0 / head_dim_qk**0.5) + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") output, lse = flashinfer.prefill.cudnn_batch_prefill_with_kv_cache( q, k_cache, @@ -125,30 +98,20 @@ def test_cudnn_prefill_deepseek( batch_offsets_stats=batch_offsets_stats, is_cuda_graph_compatible=True, ) - - qo_indptr = torch.cat( - [ - 
torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_q.view(-1), dim=0), + qo_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_q.view(-1), axis=0), ] - ).int() - - # kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * s_kv - - # Create kv_indptr as cumulative sum of actual_seq_lens_kv - kv_indptr = torch.cat( - [ - torch.tensor( - [0], - device=device, - ), - torch.cumsum(actual_seq_lens_kv.view(-1), dim=0), + ).astype(dtype="int32") + kv_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_kv.view(-1), axis=0), ] - ).int() - + ).astype(dtype="int32") wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), - kv_layout="NHD", + paddle.empty(shape=128 * 1024 * 1024, dtype="uint8"), kv_layout="NHD" ) wrapper.plan( qo_indptr, @@ -159,14 +122,8 @@ def test_cudnn_prefill_deepseek( head_dim_vo=head_dim_vo, causal=causal, sm_scale=scale, - q_data_type=torch.bfloat16, - kv_data_type=torch.bfloat16, + q_data_type="bfloat16", + kv_data_type="bfloat16", ) output_ref, lse_ref = wrapper.run(q, k_cache, v_cache, return_lse=True) - - torch.testing.assert_close( - output, - output_ref, - atol=1e-2, - rtol=1e-2, - ) + assert paddle.allclose(x=output, y=output_ref, atol=0.01, rtol=0.01).item(), "" diff --git a/tests/test_cute_dsl_blockscaled_gemm.py b/tests/test_cute_dsl_blockscaled_gemm.py index acd59bbd09..dd918210ee 100644 --- a/tests/test_cute_dsl_blockscaled_gemm.py +++ b/tests/test_cute_dsl_blockscaled_gemm.py @@ -1,26 +1,25 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ This is the test file for MaskedBatchedMatmulCuteDSL kernel. `test_blockscaled_gemm_python_interface` is the python interface test. For pytorch DLFW, refer to this. 
""" - from typing import Tuple import cutlass import cutlass.cute as cute import cutlass.torch as cutlass_torch import pytest -import torch from cutlass.cute.runtime import from_dlpack from flashinfer.cute_dsl.blockscaled_gemm import ( - Sm100BlockScaledPersistentDenseGemmKernel, # not used in python interface - grouped_gemm_nt_masked, # deepgemm-like python interface for DLFW integration - create_scale_factor_tensor, -) -from flashinfer.cute_dsl.utils import ( - get_cutlass_dtype, - is_cute_dsl_available, -) + Sm100BlockScaledPersistentDenseGemmKernel, create_scale_factor_tensor, + grouped_gemm_nt_masked) +from flashinfer.cute_dsl.utils import get_cutlass_dtype, is_cute_dsl_available @pytest.mark.skipif( @@ -57,7 +56,7 @@ @pytest.mark.parametrize("mma_tiler_mn", [(128, 128)]) @pytest.mark.parametrize("cluster_shape_mn", [(1, 1)]) @pytest.mark.parametrize("sm_count", [132, None]) -@pytest.mark.parametrize("tolerance", [1e-01]) +@pytest.mark.parametrize("tolerance", [0.1]) @pytest.mark.parametrize("iterations", [3]) def test_blockscaled_gemm_python_interface( lm: Tuple[int, int], @@ -77,8 +76,8 @@ def test_blockscaled_gemm_python_interface( tolerance: float, iterations: int, ): - torch.manual_seed(42) - device = torch.device("cuda:0") + paddle.seed(seed=42) + device = device2str("cuda:0") l, m = lm k, n = kn if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( @@ -99,13 +98,10 @@ def test_blockscaled_gemm_python_interface( pytest.skip( f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" ) - if not (a_major == "k" and b_major == "k" and c_major == "n"): - # not supported since we try to align deepgemm for now pytest.skip( f"Skip non deepgemm-like cases {a_major}, {b_major}, {c_major}. 
Might be added later" ) - a_ref = cutlass_torch.matrix( l, m, k, a_major == "m", cutlass.Float32, device=device ) @@ -115,59 +111,41 @@ def test_blockscaled_gemm_python_interface( c_ref = cutlass_torch.matrix( l, m, n, c_major == "m", cutlass.Float32, device=device ) - a_tensor, a_torch = cutlass_torch.cute_tensor_like( - a_ref, - get_cutlass_dtype(ab_dtype), - is_dynamic_layout=True, - assumed_align=16, + a_ref, get_cutlass_dtype(ab_dtype), is_dynamic_layout=True, assumed_align=16 ) b_tensor, b_torch = cutlass_torch.cute_tensor_like( - b_ref, - get_cutlass_dtype(ab_dtype), - is_dynamic_layout=True, - assumed_align=16, + b_ref, get_cutlass_dtype(ab_dtype), is_dynamic_layout=True, assumed_align=16 ) c_tensor, c_torch = cutlass_torch.cute_tensor_like( - c_ref, - get_cutlass_dtype(c_dtype), - is_dynamic_layout=True, - assumed_align=16, - ) - alpha_tensor = ( - torch.randn(l, dtype=torch.float32, device=device) if fuse_alpha else None + c_ref, get_cutlass_dtype(c_dtype), is_dynamic_layout=True, assumed_align=16 ) - - # for deepgemm-like python interface + alpha_tensor = paddle.randn(shape=l, dtype="float32") if fuse_alpha else None if ab_dtype == "float4_e2m1fn": - m, k, l = a_torch.shape - n, k, l = b_torch.shape - # slice into half after flatten - half_len_a = a_torch.numel() // 2 - half_len_b = b_torch.numel() // 2 + m, k, l = tuple(a_torch.shape) + n, k, l = tuple(b_torch.shape) + half_len_a = a_torch.size // 2 + half_len_b = b_torch.size // 2 a_torch = ( - a_torch.permute(2, 0, 1) + a_torch.transpose(perm=[2, 0, 1]) .flatten()[:half_len_a] .reshape(l, m, k // 2) - .permute(1, 2, 0) + .transpose(perm=[1, 2, 0]) ) b_torch = ( - b_torch.permute(2, 0, 1) + b_torch.transpose(perm=[2, 0, 1]) .flatten()[:half_len_b] .reshape(l, n, k // 2) - .permute(1, 2, 0) + .transpose(perm=[1, 2, 0]) ) - sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor( l, m, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device ) sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor( l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device ) - masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device) - + masked_m_tensor = paddle.randint(low=0, high=m, shape=(l,), dtype="int32") for _ in range(iterations): - # deepgemm-like python interface: fp4 packed, for DLFW integration grouped_gemm_nt_masked( (a_torch, sfa_torch), (b_torch, sfb_torch), @@ -183,55 +161,45 @@ def test_blockscaled_gemm_python_interface( alpha_dtype=alpha_dtype, sm_count=sm_count, ) - - # compute ref output if not fuse_alpha: - alpha_tensor = torch.ones(l, dtype=torch.float32, device=device) - res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref) - res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref) - ref = torch.einsum("mkl,nkl->mnl", res_a, res_b) - ref = torch.einsum("mnl,l->mnl", ref, alpha_tensor) - - # Convert c back to f32 for comparison. 
+ alpha_tensor = paddle.ones(shape=l, dtype="float32") + res_a = paddle.einsum("mkl,mkl->mkl", a_ref, sfa_ref) + res_b = paddle.einsum("nkl,nkl->nkl", b_ref, sfb_ref) + ref = paddle.einsum("mkl,nkl->mnl", res_a, res_b) + ref = paddle.einsum("mnl,l->mnl", ref, alpha_tensor) cute.testing.convert( c_tensor, from_dlpack(c_ref, assumed_align=16).mark_layout_dynamic( - leading_dim=(1 if c_major == "n" else 0) + leading_dim=1 if c_major == "n" else 0 ), ) - if c_dtype in ("float32", "float16", "bfloat16"): for i in range(l): - # skip testing c_ref & ref - torch.testing.assert_close( - c_ref[: masked_m_tensor[i].item(), :, i], - ref[: masked_m_tensor[i].item(), :, i], + assert paddle.allclose( + x=c_ref[: masked_m_tensor[i].item(), :, i], + y=ref[: masked_m_tensor[i].item(), :, i], atol=tolerance, - rtol=1e-02, - ) + rtol=0.01, + ).item(), "" elif c_dtype in ("float8_e5m2", "float8_e4m3fn"): - # Convert ref : f32 -> f8 -> f32 - ref_f8_ = torch.empty(*(l, m, n), dtype=torch.uint8, device=device).permute( - 1, 2, 0 - ) + ref_f8_ = paddle.empty(shape=(l, m, n), dtype="uint8").transpose(perm=[1, 2, 0]) ref_f8 = from_dlpack(ref_f8_, assumed_align=16).mark_layout_dynamic( leading_dim=1 ) ref_f8.element_type = get_cutlass_dtype(c_dtype) - ref = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0) + ref = ref.transpose(perm=[2, 0, 1]).contiguous().transpose(perm=[1, 2, 0]) ref_tensor = from_dlpack(ref, assumed_align=16).mark_layout_dynamic( leading_dim=1 ) cute.testing.convert(ref_tensor, ref_f8) cute.testing.convert(ref_f8, ref_tensor) for i in range(l): - # skip testing c_ref & ref - torch.testing.assert_close( - c_ref[: masked_m_tensor[i].item(), :, i], - ref[: masked_m_tensor[i].item(), :, i], + assert paddle.allclose( + x=c_ref[: masked_m_tensor[i].item(), :, i], + y=ref[: masked_m_tensor[i].item(), :, i], atol=tolerance, - rtol=1e-02, - ) + rtol=0.01, + ).item(), "" if __name__ == "__main__": @@ -249,7 +217,7 @@ def test_blockscaled_gemm_python_interface( alpha_dtype="float32", mma_tiler_mn=(128, 128), cluster_shape_mn=(2, 1), - tolerance=1e-01, + tolerance=0.1, iterations=3, sm_count=132, ) diff --git a/tests/test_decode_fp8_calibration_scale.py b/tests/test_decode_fp8_calibration_scale.py index 0408391e16..0d00c00cbc 100644 --- a/tests/test_decode_fp8_calibration_scale.py +++ b/tests/test_decode_fp8_calibration_scale.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2023 by FlashInfer team. @@ -13,9 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. 
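# The reference path for the block-scaled GEMM above applies the dequantization
# scale factors elementwise to each operand, contracts over k, and then scales
# every batch slice by its alpha. A NumPy sketch of that einsum chain with toy
# sizes; the scale tensors are assumed to be already broadcast to the operand
# shapes, as sfa_ref/sfb_ref are in the test.
import numpy as np

m, n, k, l = 4, 6, 8, 2
a, sfa = np.random.randn(m, k, l), np.random.rand(m, k, l)
b, sfb = np.random.randn(n, k, l), np.random.rand(n, k, l)
alpha = np.random.rand(l)

ref = np.einsum("mkl,nkl->mnl", a * sfa, b * sfb)   # dequantize, then contract over k
ref = np.einsum("mnl,l->mnl", ref, alpha)           # per-batch alpha scaling
assert ref.shape == (m, n, l)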
""" - import pytest -import torch import flashinfer @@ -23,10 +23,10 @@ @pytest.mark.parametrize("kv_len", [7, 19, 39, 1170, 39275]) @pytest.mark.parametrize("num_kv_heads", [4]) @pytest.mark.parametrize("num_qo_heads", [4, 32]) -@pytest.mark.parametrize("head_dim", [128]) # [64, 128, 256]) -@pytest.mark.parametrize("kv_layout", ["NHD"]) # ["HND", "NHD"]) -@pytest.mark.parametrize("pos_encoding_mode", ["NONE"]) # , "ROPE_LLAMA", "ALIBI"]) -@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("kv_layout", ["NHD"]) +@pytest.mark.parametrize("pos_encoding_mode", ["NONE"]) +@pytest.mark.parametrize("fp8_dtype", [paddle.float8_e4m3fn]) def test_single_decode_fp8_calibration_scale( kv_len, num_kv_heads, @@ -36,28 +36,26 @@ def test_single_decode_fp8_calibration_scale( pos_encoding_mode, fp8_dtype, ): - torch.manual_seed(42) - q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16).to(0) + paddle.seed(seed=42) + q = paddle.randn(shape=[num_qo_heads, head_dim], dtype="float16").to(0) k = ( - torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16).to(0) + paddle.randn(shape=[kv_len, num_kv_heads, head_dim], dtype="float16").to(0) if kv_layout == "NHD" - else torch.randn(num_kv_heads, kv_len, head_dim).to(0) + else paddle.randn(shape=[num_kv_heads, kv_len, head_dim]).to(0) ) v = ( - 0.1 * torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16).to(0) + 0.1 + * paddle.randn(shape=[kv_len, num_kv_heads, head_dim], dtype="float16").to(0) if kv_layout == "NHD" - else 0.1 * torch.randn(num_kv_heads, kv_len, head_dim).to(0) + else 0.1 * paddle.randn(shape=[num_kv_heads, kv_len, head_dim]).to(0) ) - o_fp16 = flashinfer.single_decode_with_kv_cache( q, k, v, kv_layout=kv_layout, pos_encoding_mode=pos_encoding_mode ) - k_scale = k.amax().item() / 256 v_scale = v.amax().item() / 256 k_fp8 = (k / k_scale).to(fp8_dtype) v_fp8 = (v / v_scale).to(fp8_dtype) - o_fp8 = flashinfer.single_decode_with_kv_cache( q, k_fp8, @@ -67,8 +65,7 @@ def test_single_decode_fp8_calibration_scale( k_scale=k_scale, v_scale=v_scale, ) - - torch.testing.assert_close(o_fp16, o_fp8, atol=1e-2, rtol=2e-2) + assert paddle.allclose(x=o_fp16, y=o_fp8, atol=0.01, rtol=0.02).item(), "" @pytest.mark.parametrize("batch_size", [12, 17]) @@ -79,7 +76,7 @@ def test_single_decode_fp8_calibration_scale( @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("dtype", [paddle.float8_e4m3fn, paddle.float8_e5m2]) def test_batch_decode_with_paged_kv_cache_fp8_calibration_scale( batch_size, kv_len, @@ -91,28 +88,32 @@ def test_batch_decode_with_paged_kv_cache_fp8_calibration_scale( pos_encoding_mode, dtype, ): - torch.manual_seed(42) - q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16).to(0) + paddle.seed(seed=42) + q = paddle.randn(shape=[batch_size, num_qo_heads, head_dim], dtype="float16").to(0) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( 0.1 - * torch.randn( - total_num_pages, 2, num_kv_heads, page_size, head_dim, dtype=torch.float16 + * paddle.randn( + shape=[total_num_pages, 2, num_kv_heads, page_size, head_dim], + dtype="float16", ).to(0) if kv_layout == "HND" else 0.1 - * torch.randn( - total_num_pages, 2, page_size, num_kv_heads, 
head_dim, dtype=torch.float16 + * paddle.randn( + shape=[total_num_pages, 2, page_size, num_kv_heads, head_dim], + dtype="float16", ).to(0) ) - kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq - kv_indices = torch.arange(0, total_num_pages).to(0).int() - kv_last_page_len = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + kv_indptr = ( + paddle.arange(start=0, end=batch_size + 1).to(0).astype(dtype="int32") + * num_pages_per_seq + ) + kv_indices = paddle.arange(start=0, end=total_num_pages).to(0).astype(dtype="int32") + kv_last_page_len = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" ).to(0) - - workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + workspace_buffer = paddle.empty(shape=32 * 1024 * 1024, dtype="int8").to(0) wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout) wrapper.plan( kv_indptr, @@ -123,19 +124,16 @@ def test_batch_decode_with_paged_kv_cache_fp8_calibration_scale( head_dim, page_size, pos_encoding_mode=pos_encoding_mode, - data_type=torch.float16, - q_data_type=torch.float16, + data_type="float16", + q_data_type="float16", ) o_fp16 = wrapper.run(q, kv_data) - - k_data, v_data = torch.chunk(kv_data, 2, dim=1) + k_data, v_data = paddle.chunk(x=kv_data, chunks=2, axis=1) k_scale = k_data.amax().item() / 256 v_scale = v_data.amax().item() / 256 - k_fp8 = (k_data / k_scale).to(dtype) v_fp8 = (v_data / v_scale).to(dtype) - kv_data_fp8 = torch.cat([k_fp8, v_fp8], dim=1) - + kv_data_fp8 = paddle.concat(x=[k_fp8, v_fp8], axis=1) wrapper.plan( kv_indptr, kv_indices, @@ -146,17 +144,16 @@ def test_batch_decode_with_paged_kv_cache_fp8_calibration_scale( page_size, pos_encoding_mode=pos_encoding_mode, data_type=dtype, - q_data_type=torch.float16, + q_data_type="float16", ) o_fp8 = wrapper.run(q, kv_data_fp8.to(dtype), k_scale=k_scale, v_scale=v_scale) - - torch.testing.assert_close(o_fp16, o_fp8, atol=1e-2, rtol=2e-1) + assert paddle.allclose(x=o_fp16, y=o_fp8, atol=0.01, rtol=0.2).item(), "" if __name__ == "__main__": test_single_decode_fp8_calibration_scale( - 1170, 4, 32, 128, "NHD", "NONE", torch.float8_e4m3fn + 1170, 4, 32, 128, "NHD", "NONE", paddle.float8_e4m3fn ) test_batch_decode_with_paged_kv_cache_fp8_calibration_scale( - 12, 54, 1, 4, 4, 128, "NHD", "NONE", torch.float8_e5m2 +>>>>>> 12, 54, 1, 4, 4, 128, "NHD", "NONE", paddle.float8_e5m2 ) diff --git a/tests/test_decode_prefill_lse.py b/tests/test_decode_prefill_lse.py index e6a238d8a9..b074b892ef 100644 --- a/tests/test_decode_prefill_lse.py +++ b/tests/test_decode_prefill_lse.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,25 +15,27 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - -import torch - import flashinfer def test_mlc_failed_case(): kv_layout = "HND" - kv_indptr_1 = torch.tensor([0, 0, 9]).int().to(0) - kv_indices_1 = torch.tensor([3, 4, 5, 6, 7, 8, 9, 10, 11]).int().to(0) - kv_last_page_len_1 = torch.tensor([0, 1]).int().to(0) + kv_indptr_1 = paddle.to_tensor(data=[0, 0, 9]).astype(dtype="int32").to(0) + kv_indices_1 = ( + paddle.to_tensor(data=[3, 4, 5, 6, 7, 8, 9, 10, 11]).astype(dtype="int32").to(0) + ) + kv_last_page_len_1 = paddle.to_tensor(data=[0, 1]).astype(dtype="int32").to(0) num_qo_heads = 32 num_kv_heads = 32 page_size = 16 head_dim = 128 - q = torch.randn(2, num_qo_heads, head_dim).to(0).half() - kv_data = torch.randn(12, 2, num_kv_heads, page_size, head_dim).to(0).half() - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + q = paddle.randn(shape=[2, num_qo_heads, head_dim]).to(0).astype(dtype="float16") + kv_data = ( + paddle.randn(shape=[12, 2, num_kv_heads, page_size, head_dim]) + .to(0) + .astype(dtype="float16") + ) + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8").to(0) wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout) wrapper.plan( kv_indptr_1, @@ -42,11 +46,10 @@ def test_mlc_failed_case(): head_dim, page_size, pos_encoding_mode="NONE", - data_type=torch.float16, - q_data_type=torch.float16, + data_type="float16", + q_data_type="float16", ) o_1, lse_1 = wrapper.run_return_lse(q, kv_data) - wrapper_tensor_cores = flashinfer.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout, use_tensor_cores=True ) @@ -59,16 +62,14 @@ def test_mlc_failed_case(): head_dim, page_size, pos_encoding_mode="NONE", - data_type=torch.float16, - q_data_type=torch.float16, + data_type="float16", + q_data_type="float16", ) o_1_tc, lse_1_tc = wrapper_tensor_cores.run_return_lse(q, kv_data) - print(lse_1, lse_1_tc) print(o_1, o_1_tc) - - torch.testing.assert_close(lse_1, lse_1_tc, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(o_1, o_1_tc, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=lse_1, y=lse_1_tc, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=o_1, y=o_1_tc, rtol=0.001, atol=0.001).item(), "" if __name__ == "__main__": diff --git a/tests/test_deepseek_mla.py b/tests/test_deepseek_mla.py index e48ac259ee..39f0b7bf7e 100644 --- a/tests/test_deepseek_mla.py +++ b/tests/test_deepseek_mla.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2023 by FlashInfer team. @@ -13,20 +19,16 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import math import pytest -import torch from conftest import clear_cuda_cache import flashinfer from flashinfer.jit import build_jit_specs -from flashinfer.jit.attention import ( - gen_batch_mla_module, - gen_batch_prefill_module, - gen_single_prefill_module, -) +from flashinfer.jit.attention import (gen_batch_mla_module, + gen_batch_prefill_module, + gen_single_prefill_module) from flashinfer.utils import is_sm90a_supported, is_sm100a_supported @@ -35,15 +37,14 @@ def warmup_jit(): try: modules = [] for backend in ["fa2", "fa3"]: - if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): + if backend == "fa3" and not is_sm90a_supported(device2str("cuda")): continue - modules.append( gen_single_prefill_module( backend, - torch.float16, - torch.float16, - torch.float16, + "float16", + "float16", + "float16", 192, 128, 0, @@ -52,18 +53,16 @@ def warmup_jit(): False, ) ) - for backend in ["fa2", "fa3"]: - if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): + if backend == "fa3" and not is_sm90a_supported(device2str("cuda")): continue - modules.append( gen_batch_prefill_module( backend, - torch.float16, - torch.float16, - torch.float16, - torch.int32, + "float16", + "float16", + "float16", + "int32", 192, 128, 0, @@ -72,27 +71,16 @@ def warmup_jit(): False, ) ) - for backend in ["fa2", "fa3"]: - if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): + if backend == "fa3" and not is_sm90a_supported(device2str("cuda")): continue - modules.append( gen_batch_mla_module( - backend, - torch.float16, - torch.float16, - torch.float16, - torch.int32, - 512, - 64, - False, + backend, "float16", "float16", "float16", "int32", 512, 64, False ) ) - build_jit_specs(modules, verbose=False) except Exception as e: - # abort the test session if warmup fails pytest.exit(str(e)) finally: yield @@ -100,47 +88,54 @@ def warmup_jit(): def attention_ref( batch_size, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, causal: bool, sm_scale: float, -) -> torch.Tensor: - qo_len = q.shape[0] // batch_size - kv_len = k.shape[0] // batch_size - num_qo_heads = q.shape[1] - head_dim_qk = q.shape[2] - head_dim_vo = v.shape[2] +) -> paddle.Tensor: + qo_len = tuple(q.shape)[0] // batch_size + kv_len = tuple(k.shape)[0] // batch_size + num_qo_heads = tuple(q.shape)[1] + head_dim_qk = tuple(q.shape)[2] + head_dim_vo = tuple(v.shape)[2] logits = ( - torch.einsum( + paddle.einsum( "bmhd,bnhd->bhmn", - q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(), - k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(), + q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).astype( + dtype="float32" + ), + k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).astype( + dtype="float32" + ), ) * sm_scale ) - if causal: - mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze( - 1 - ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0) + mask = paddle.arange(start=kv_len - qo_len, end=kv_len).unsqueeze( + axis=1 + ) >= paddle.arange(start=0, end=kv_len).unsqueeze(axis=0) else: - mask = torch.ones(qo_len, kv_len, device=q.device) - - logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) - lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2) - p = torch.softmax(logits, dim=-1) + mask = paddle.ones(shape=[qo_len, kv_len]) + logits = logits.masked_fill( + mask=mask.unsqueeze(axis=0).unsqueeze(axis=0) == 0, value=float("-inf") + ) + lse_ref = 
paddle.logsumexp(x=logits, axis=-1).transpose( + perm=dim2perm(paddle.logsumexp(x=logits, axis=-1).ndim, -1, -2) + ) + p = paddle.nn.functional.softmax(x=logits, axis=-1) o_ref = ( - torch.einsum( + paddle.einsum( "bhmn,bnhd->bmhd", p, - v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(), + v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).astype( + dtype="float32" + ), ) .contiguous() .view(batch_size * qo_len, num_qo_heads, head_dim_vo) .to(q) ) - return o_ref, lse_ref * math.log2(math.e) @@ -149,33 +144,29 @@ def attention_ref( @pytest.mark.parametrize("num_heads", [4, 32, 128]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("backend", ["fa2", "fa3"]) -@pytest.mark.parametrize("dtype", [torch.half]) +@pytest.mark.parametrize("dtype", ["float16"]) def test_single_prefill_with_kv_cache( - kv_len, - qo_len, - num_heads, - causal, - backend, - dtype, + kv_len, qo_len, num_heads, causal, backend, dtype ): - device = torch.device("cuda:0") + device = device2str("cuda:0") clear_cuda_cache(device) if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") - torch.manual_seed(42) + paddle.seed(seed=42) head_dim_qk = 192 head_dim_vo = 128 - q = torch.randn(qo_len, num_heads, head_dim_qk, dtype=dtype, device=device) - k = torch.randn(kv_len, num_heads, head_dim_qk, dtype=dtype, device=device) - v = torch.randn(kv_len, num_heads, head_dim_vo, dtype=dtype, device=device) + q = paddle.randn(shape=[qo_len, num_heads, head_dim_qk], dtype=dtype) + k = paddle.randn(shape=[kv_len, num_heads, head_dim_qk], dtype=dtype) + v = paddle.randn(shape=[kv_len, num_heads, head_dim_vo], dtype=dtype) o, lse = flashinfer.single_prefill_with_kv_cache( q, k, v, causal=causal, backend=backend, return_lse=True ) - sm_scale = 1.0 / (head_dim_qk**0.5) - + sm_scale = 1.0 / head_dim_qk**0.5 o_ref, lse_ref = attention_ref(1, q, k, v, causal, sm_scale) - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(lse, lse_ref.squeeze(0), rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose( + x=lse, y=lse_ref.squeeze(axis=0), rtol=0.001, atol=0.001 + ).item(), "" @pytest.mark.parametrize("batch_size", [12, 17]) @@ -184,42 +175,24 @@ def test_single_prefill_with_kv_cache( @pytest.mark.parametrize("num_heads", [4, 32, 128]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("backend", ["fa2", "fa3"]) -@pytest.mark.parametrize("dtype", [torch.half]) +@pytest.mark.parametrize("dtype", ["float16"]) def test_batch_prefill_with_ragged_kv_cache( - batch_size, - kv_len, - qo_len, - num_heads, - causal, - backend, - dtype, + batch_size, kv_len, qo_len, num_heads, causal, backend, dtype ): - device = torch.device("cuda:0") + device = device2str("cuda:0") clear_cuda_cache(device) if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") - torch.manual_seed(42) + paddle.seed(seed=42) kv_layout = "NHD" head_dim_qk = 192 head_dim_vo = 128 - q = torch.randn( - batch_size * qo_len, num_heads, head_dim_qk, dtype=dtype, device=device - ) - q_indptr = ( - torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * qo_len - ) - - k = torch.zeros( - batch_size * kv_len, num_heads, head_dim_qk, dtype=dtype, device=device - ) - v = torch.zeros( - batch_size * kv_len, num_heads, head_dim_vo, dtype=dtype, device=device - ) - kv_indptr = ( - torch.arange(0, batch_size + 1, device=device, 
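# The attention_ref helper above builds a causal mask in which query row i (shifted
# so the last query lines up with the last key) may attend to keys 0..kv_len-qo_len+i,
# takes a natural-log logsumexp over the masked logits, and multiplies by log2(e)
# because the kernels report LSE in base 2. A single-head NumPy sketch of the same math:
import math
import numpy as np

def attention_ref_1h(q, k, v, sm_scale, causal=True):
    qo_len, kv_len = q.shape[0], k.shape[0]
    logits = (q @ k.T) * sm_scale
    if causal:
        # row i may attend to keys up to (kv_len - qo_len + i)
        allowed = np.arange(kv_len)[None, :] <= (kv_len - qo_len + np.arange(qo_len))[:, None]
        logits = np.where(allowed, logits, -np.inf)
    row_max = logits.max(-1, keepdims=True)
    lse = np.log(np.exp(logits - row_max).sum(-1)) + row_max[:, 0]
    p = np.exp(logits - lse[:, None])
    return p @ v, lse * math.log2(math.e)   # base-2 LSE, as attention_ref returns

q = np.random.randn(3, 16); k = np.random.randn(5, 16); v = np.random.randn(5, 16)
out, lse2 = attention_ref_1h(q, k, v, sm_scale=1.0 / math.sqrt(16))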
dtype=torch.int32) * kv_len - ) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + q = paddle.randn(shape=[batch_size * qo_len, num_heads, head_dim_qk], dtype=dtype) + q_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * qo_len + k = paddle.zeros(shape=[batch_size * kv_len, num_heads, head_dim_qk], dtype=dtype) + v = paddle.zeros(shape=[batch_size * kv_len, num_heads, head_dim_vo], dtype=dtype) + kv_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * kv_len + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, kv_layout, backend=backend ) @@ -233,37 +206,32 @@ def test_batch_prefill_with_ragged_kv_cache( causal=causal, ) o, lse = wrapper.run_return_lse(q, k, v) - - sm_scale = 1.0 / (head_dim_qk**0.5) + sm_scale = 1.0 / head_dim_qk**0.5 o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale) - - lse_ref = lse_ref.flatten(0, 1) - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) - - # test with pre-allocated output - o_buffer = torch.empty_like(o) - lse_buffer = torch.empty_like(lse) + lse_ref = lse_ref.flatten(start_axis=0, stop_axis=1) + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=lse, y=lse_ref, rtol=0.001, atol=0.001).item(), "" + o_buffer = paddle.empty_like(x=o) + lse_buffer = paddle.empty_like(x=lse) wrapper.run(q, k, v, out=o_buffer, lse=lse_buffer) - torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o, y=o_buffer, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=lse, y=lse_buffer, rtol=0.001, atol=0.001).item(), "" def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads): - bs_page_num, page_size, ckv_dim = ckv.shape + bs_page_num, page_size, ckv_dim = tuple(ckv.shape) page_num = bs_page_num // batch_size - _, _, kpe_dim = kpe.shape + _, _, kpe_dim = tuple(kpe.shape) ckv = ckv.view(batch_size, page_num * page_size, ckv_dim) kpe = kpe.view(batch_size, page_num * page_size, kpe_dim) ckv = ckv[:, :kv_len, :] kpe = kpe[:, :kv_len, :] k = ( - torch.cat([ckv, kpe], dim=-1) + paddle.concat(x=[ckv, kpe], axis=-1) .view(-1, 1, ckv_dim + kpe_dim) - .repeat_interleave(num_heads, dim=1) + .repeat_interleave(repeats=num_heads, axis=1) ) - v = ckv.repeat_interleave(num_heads, dim=1) - + v = ckv.repeat_interleave(repeats=num_heads, axis=1) return k, v @@ -276,7 +244,7 @@ def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads): @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("page_size", [1]) @pytest.mark.parametrize("backend", ["fa2", "fa3"]) -@pytest.mark.parametrize("dtype", [torch.half]) +@pytest.mark.parametrize("dtype", ["float16"]) def test_batch_mla_varlen_page_attention( batch_size, kv_len_0, @@ -289,75 +257,63 @@ def test_batch_mla_varlen_page_attention( backend, dtype, ): - device = torch.device("cuda:0") + device = device2str("cuda:0") clear_cuda_cache(device) if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") if causal and qo_len > min(kv_len_0, kv_len_1, kv_len_2): pytest.skip("qo_len > kv_len not supported for causal attention") num_different_kv_len = 3 - kv_lens = torch.tensor([kv_len_0, kv_len_1, kv_len_2], dtype=torch.int32) - 
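# generate_kv_from_cache above expands the MLA compressed cache back into regular
# multi-head K/V for the reference path: the 512-dim compressed KV and the 64-dim
# positional part are concatenated into a 576-dim key that is shared (repeated)
# across all heads, while V reuses the 512-dim part alone. A NumPy shape sketch:
import numpy as np

kv_len, num_heads, ckv_dim, kpe_dim = 10, 16, 512, 64
ckv = np.random.randn(kv_len, ckv_dim)
kpe = np.random.randn(kv_len, kpe_dim)

k = np.repeat(np.concatenate([ckv, kpe], axis=-1)[:, None, :], num_heads, axis=1)
v = np.repeat(ckv[:, None, :], num_heads, axis=1)
assert k.shape == (kv_len, num_heads, ckv_dim + kpe_dim)   # (10, 16, 576)
assert v.shape == (kv_len, num_heads, ckv_dim)             # (10, 16, 512)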
torch.manual_seed(42) + kv_lens = paddle.to_tensor(data=[kv_len_0, kv_len_1, kv_len_2], dtype="int32") + paddle.seed(seed=42) head_dim_ckv = 512 head_dim_kpe = 64 - q_nope = torch.randn( - num_different_kv_len * batch_size * qo_len, - num_heads, - head_dim_ckv, + q_nope = paddle.randn( + shape=[num_different_kv_len * batch_size * qo_len, num_heads, head_dim_ckv], dtype=dtype, - device=device, ) - q_pe = torch.randn( - num_different_kv_len * batch_size * qo_len, - num_heads, - head_dim_kpe, + q_pe = paddle.randn( + shape=[num_different_kv_len * batch_size * qo_len, num_heads, head_dim_kpe], dtype=dtype, - device=device, ) - pages_nums = torch.tensor( - [math.ceil(kv_len / page_size) for kv_len in kv_lens], - dtype=torch.int32, + pages_nums = paddle.to_tensor( + data=[math.ceil(kv_len / page_size) for kv_len in kv_lens], dtype="int32" ) - pages_nums_indptr = torch.zeros(num_different_kv_len + 1, dtype=torch.int32) - pages_nums_indptr[1:] = pages_nums.cumsum(0) + pages_nums_indptr = paddle.zeros(shape=num_different_kv_len + 1, dtype="int32") + pages_nums_indptr[1:] = pages_nums.cumsum(axis=0) pages_nums_sum = pages_nums_indptr[-1] - ckv = torch.randn( - batch_size * pages_nums_sum, - page_size, - head_dim_ckv, - dtype=dtype, - device=device, + ckv = paddle.randn( + shape=[batch_size * pages_nums_sum, page_size, head_dim_ckv], dtype=dtype ) - kpe = torch.randn( - batch_size * pages_nums_sum, - page_size, - head_dim_kpe, - dtype=dtype, - device=device, + kpe = paddle.randn( + shape=[batch_size * pages_nums_sum, page_size, head_dim_kpe], dtype=dtype ) - sm_scale = 1.0 / ((128 + 64) ** 0.5) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + sm_scale = 1.0 / (128 + 64) ** 0.5 + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( workspace_buffer, backend=backend ) q_indptr = ( - torch.arange( - 0, num_different_kv_len * batch_size + 1, device=device, dtype=torch.int32 - ) + paddle.arange(start=0, end=num_different_kv_len * batch_size + 1, dtype="int32") * qo_len ) - kv_indptr = torch.cat( - [ - torch.arange(0, batch_size + 1).unsqueeze(-1).int() * pages_nums_sum - + pages_nums_indptr[i] + kv_indptr = paddle.concat( + x=[ + ( + paddle.arange(start=0, end=batch_size + 1) + .unsqueeze(axis=-1) + .astype(dtype="int32") + * pages_nums_sum + + pages_nums_indptr[i] + ) for i in range(num_different_kv_len) ], - dim=-1, + axis=-1, ).flatten() - kv_indices = torch.arange( - 0, batch_size * pages_nums_sum, device=device, dtype=torch.int32 + kv_indices = paddle.arange(start=0, end=batch_size * pages_nums_sum, dtype="int32") + kv_lens = paddle.to_tensor(data=kv_lens, dtype="int32", place=device).tile( + repeat_times=batch_size ) - kv_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device).repeat(batch_size) wrapper.plan( q_indptr, kv_indptr, @@ -373,15 +329,16 @@ def test_batch_mla_varlen_page_attention( ckv.dtype, ) o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) - q_rows = ( - torch.arange(0, num_different_kv_len * qo_len)[None, :] - + torch.arange(0, batch_size)[:, None] * num_different_kv_len * qo_len - ).int() + paddle.arange(start=0, end=num_different_kv_len * qo_len)[None, :] + + paddle.arange(start=0, end=batch_size)[:, None] + * num_different_kv_len + * qo_len + ).astype(dtype="int32") kv_rows = ( - torch.arange(0, pages_nums_sum)[None, :] - + torch.arange(0, batch_size)[:, None] * pages_nums_sum - ).int() + paddle.arange(start=0, end=pages_nums_sum)[None, :] + + 
paddle.arange(start=0, end=batch_size)[:, None] * pages_nums_sum + ).astype(dtype="int32") q_rows_arr = [ q_rows[:, i * qo_len : (i + 1) * qo_len].flatten() for i in range(num_different_kv_len) @@ -394,13 +351,11 @@ def test_batch_mla_varlen_page_attention( k, v = generate_kv_from_cache( ckv[kv_rows_arr[i]], kpe[kv_rows_arr[i]], kv_lens[i], batch_size, num_heads ) - q = torch.cat([q_nope, q_pe], dim=-1)[q_rows_arr[i]] + q = paddle.concat(x=[q_nope, q_pe], axis=-1)[q_rows_arr[i]] o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale) - lse_ref = lse_ref.flatten(0, 1) + lse_ref = lse_ref.flatten(start_axis=0, stop_axis=1) o_i = o[q_rows_arr[i]] - torch.testing.assert_close(o_i, o_ref, rtol=1e-3, atol=1e-3) - # if kv_lens[i] != 0: - # torch.testing.assert_close(lse_i, lse_ref, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o_i, y=o_ref, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157]) @@ -410,55 +365,45 @@ def test_batch_mla_varlen_page_attention( @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("page_size", [16, 32]) @pytest.mark.parametrize("backend", ["fa2", "fa3"]) -@pytest.mark.parametrize("dtype", [torch.half]) +@pytest.mark.parametrize("dtype", ["float16"]) def test_batch_mla_oob_kv_nan( batch_size, kv_len, qo_len, num_heads, causal, page_size, backend, dtype ): - device = torch.device("cuda:0") + device = device2str("cuda:0") clear_cuda_cache(device) if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") if causal and qo_len > kv_len: pytest.skip("qo_len > kv_len not supported for causal attention") - torch.manual_seed(42) + paddle.seed(seed=42) head_dim_ckv = 512 head_dim_kpe = 64 - q_nope = torch.randn( - batch_size * qo_len, num_heads, head_dim_ckv, dtype=dtype, device=device + q_nope = paddle.randn( + shape=[batch_size * qo_len, num_heads, head_dim_ckv], dtype=dtype ) - q_pe = torch.randn( - batch_size * qo_len, num_heads, head_dim_kpe, dtype=dtype, device=device + q_pe = paddle.randn( + shape=[batch_size * qo_len, num_heads, head_dim_kpe], dtype=dtype ) pages_num = math.ceil(kv_len / page_size) - ckv = torch.randn( - batch_size * pages_num, page_size, head_dim_ckv, dtype=dtype, device=device + ckv = paddle.randn( + shape=[batch_size * pages_num, page_size, head_dim_ckv], dtype=dtype ) - kpe = torch.randn( - batch_size * pages_num, page_size, head_dim_kpe, dtype=dtype, device=device + kpe = paddle.randn( + shape=[batch_size * pages_num, page_size, head_dim_kpe], dtype=dtype ) - - # Fill oob positions with nan for i in range(batch_size): last_page_len = kv_len - (pages_num - 1) * page_size ckv[(i + 1) * pages_num - 1, last_page_len:, :] = float("nan") kpe[(i + 1) * pages_num - 1, last_page_len:, :] = float("nan") - - sm_scale = 1.0 / ((128 + 64) ** 0.5) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + sm_scale = 1.0 / (128 + 64) ** 0.5 + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( workspace_buffer, backend=backend ) - q_indptr = ( - torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * qo_len - ) - kv_indptr = ( - torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * pages_num - ) - kv_indices = torch.arange( - 0, batch_size * pages_num, device=device, dtype=torch.int32 - ) - kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device=device) - + q_indptr = paddle.arange(start=0, 
end=batch_size + 1, dtype="int32") * qo_len + kv_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * pages_num + kv_indices = paddle.arange(start=0, end=batch_size * pages_num, dtype="int32") + kv_lens = paddle.full(shape=(batch_size,), fill_value=kv_len, dtype="int32") wrapper.plan( q_indptr, kv_indptr, @@ -474,15 +419,13 @@ def test_batch_mla_oob_kv_nan( ckv.dtype, ) o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) - k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads) - - q = torch.cat([q_nope, q_pe], dim=-1) + q = paddle.concat(x=[q_nope, q_pe], axis=-1) o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale) - lse_ref = lse_ref.flatten(0, 1) - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + lse_ref = lse_ref.flatten(start_axis=0, stop_axis=1) + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" if kv_len != 0: - torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=lse, y=lse_ref, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [1, 3, 5, 7, 157]) @@ -493,7 +436,7 @@ def test_batch_mla_oob_kv_nan( @pytest.mark.parametrize("page_size", [1, 16]) @pytest.mark.parametrize("backend", ["fa2", "fa3"]) @pytest.mark.parametrize("use_cuda_graph", [False]) -@pytest.mark.parametrize("dtype", [torch.half]) +@pytest.mark.parametrize("dtype", ["float16"]) def test_batch_mla_page_attention( batch_size, kv_len, @@ -505,64 +448,47 @@ def test_batch_mla_page_attention( use_cuda_graph, dtype, ): - device = torch.device("cuda:0") + device = device2str("cuda:0") clear_cuda_cache(device) if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") if causal and qo_len > kv_len: pytest.skip("qo_len > kv_len not supported for causal attention") - torch.manual_seed(42) + paddle.seed(seed=42) head_dim_ckv = 512 head_dim_kpe = 64 - q_nope = torch.randn( - batch_size * qo_len, num_heads, head_dim_ckv, dtype=dtype, device=device + q_nope = paddle.randn( + shape=[batch_size * qo_len, num_heads, head_dim_ckv], dtype=dtype ) - q_pe = torch.randn( - batch_size * qo_len, num_heads, head_dim_kpe, dtype=dtype, device=device + q_pe = paddle.randn( + shape=[batch_size * qo_len, num_heads, head_dim_kpe], dtype=dtype ) pages_num = math.ceil(kv_len / page_size) - ckv = torch.randn( - batch_size * pages_num, - page_size, - head_dim_ckv, - dtype=dtype, - device=device, + ckv = paddle.randn( + shape=[batch_size * pages_num, page_size, head_dim_ckv], dtype=dtype ) - kpe = torch.randn( - batch_size * pages_num, - page_size, - head_dim_kpe, - dtype=dtype, - device=device, + kpe = paddle.randn( + shape=[batch_size * pages_num, page_size, head_dim_kpe], dtype=dtype ) - sm_scale = 1.0 / ((128 + 64) ** 0.5) # use head dimension before matrix absorption - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + sm_scale = 1.0 / (128 + 64) ** 0.5 + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( workspace_buffer, backend=backend, use_cuda_graph=True, - qo_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device=device), - kv_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device=device), - kv_indices=torch.empty(1048576, dtype=torch.int32, device=device), - kv_len_arr=torch.empty(batch_size, dtype=torch.int32, device=device), - ) - q_indptr = ( - torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * qo_len - ) 
- kv_indptr = ( - torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * pages_num - ) - kv_indices = torch.arange( - 0, batch_size * pages_num, device=device, dtype=torch.int32 - ) - kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device=device) - + qo_indptr=paddle.empty(shape=batch_size + 1, dtype="int32"), + kv_indptr=paddle.empty(shape=batch_size + 1, dtype="int32"), + kv_indices=paddle.empty(shape=[1048576], dtype="int32"), + kv_len_arr=paddle.empty(shape=batch_size, dtype="int32"), + ) + q_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * qo_len + kv_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * pages_num + kv_indices = paddle.arange(start=0, end=batch_size * pages_num, dtype="int32") + kv_lens = paddle.full(shape=(batch_size,), fill_value=kv_len, dtype="int32") if use_cuda_graph: - kv_indptr_warmup = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) - kv_indices_warmup = torch.arange( - 0, batch_size, device=device, dtype=torch.int32 - ) - kv_lens_warmup = torch.full((batch_size,), 0, dtype=torch.int32, device=device) + kv_indptr_warmup = paddle.zeros(shape=batch_size + 1, dtype="int32") + kv_indices_warmup = paddle.arange(start=0, end=batch_size, dtype="int32") + kv_lens_warmup = paddle.full(shape=(batch_size,), fill_value=0, dtype="int32") wrapper.plan( q_indptr, kv_indptr_warmup, @@ -577,20 +503,15 @@ def test_batch_mla_page_attention( q_nope.dtype, ckv.dtype, ) - - # warmup - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): + s = paddle.device.Stream() + s.wait_stream(paddle.device.current_stream()) + with paddle.device.stream_guard(stream=s): for _ in range(3): o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) - torch.cuda.current_stream().wait_stream(s) - - # capture - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): + paddle.device.current_stream().wait_stream(s) +>>>>>> g = torch.cuda.CUDAGraph() +>>>>>> with torch.cuda.graph(g): o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) - wrapper.plan( q_indptr, kv_indptr, @@ -606,97 +527,72 @@ def test_batch_mla_page_attention( ckv.dtype, ) if use_cuda_graph: - o.fill_(0) - lse.fill_(0) + o.fill_(value=0) + lse.fill_(value=0) g.replay() else: o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) - k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads) - - q = torch.cat([q_nope, q_pe], dim=-1) + q = paddle.concat(x=[q_nope, q_pe], axis=-1) o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale) - lse_ref = lse_ref.flatten(0, 1) - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + lse_ref = lse_ref.flatten(start_axis=0, stop_axis=1) + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" if kv_len != 0: - torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) - - # test with pre-allocated output - o_buffer = torch.empty_like(o) - lse_buffer = torch.empty_like(lse) + assert paddle.allclose(x=lse, y=lse_ref, rtol=0.001, atol=0.001).item(), "" + o_buffer = paddle.empty_like(x=o) + lse_buffer = paddle.empty_like(x=lse) wrapper.run(q_nope, q_pe, ckv, kpe, out=o_buffer, lse=lse_buffer) - torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o, y=o_buffer, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=lse, y=lse_buffer, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", 
[1, 2, 4]) @pytest.mark.parametrize("max_seq_len", [128, 1024, 4096]) @pytest.mark.parametrize("page_size", [1, 16, 128]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half]) +@pytest.mark.parametrize("dtype", ["bfloat16", "float16"]) def test_cutlass_mla(batch_size, max_seq_len, page_size, dtype): - device = torch.device("cuda:0") + device = device2str("cuda:0") clear_cuda_cache(device) if not is_sm100a_supported(device): pytest.skip("Cutlass MLA is not supported on this device") - - torch.manual_seed(42) - + paddle.seed(seed=42) num_local_heads = 128 head_dim_ckv = 512 head_dim_kpe = 64 total_page_num = 8192 - - # NOTE(Zihao): use larger scale to detect bugs such as - # https://github.com/flashinfer-ai/flashinfer/pull/1055 q_nope_pe = ( - torch.randn( - batch_size, - num_local_heads, - head_dim_ckv + head_dim_kpe, + paddle.randn( + shape=[batch_size, num_local_heads, head_dim_ckv + head_dim_kpe], dtype=dtype, - device=device, ) * 100 ) - ckv_kpe = torch.randn( - total_page_num, - page_size, - head_dim_ckv + head_dim_kpe, - dtype=dtype, - device=device, + ckv_kpe = paddle.randn( + shape=[total_page_num, page_size, head_dim_ckv + head_dim_kpe], dtype=dtype ) - kv_lens = torch.full((batch_size,), max_seq_len, dtype=torch.int32, device=device) + kv_lens = paddle.full(shape=(batch_size,), fill_value=max_seq_len, dtype="int32") page_num_per_batch = (max_seq_len + page_size - 1) // page_size - # Cutlass MLA requires small pages (< 128) are packed into a 128 page. assert page_num_per_batch % (128 // page_size) == 0 - page_table = torch.randint( - 0, - total_page_num, - (batch_size, page_num_per_batch), - dtype=torch.int32, - device=device, + page_table = paddle.randint( + low=0, + high=total_page_num, + shape=(batch_size, page_num_per_batch), + dtype="int32", ) - mla_ref = flashinfer.mla.BatchMLAPagedAttentionWrapper( - torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device), backend="fa2" + paddle.empty(shape=128 * 1024 * 1024, dtype="int8"), backend="fa2" ) - - # for decode, each query length is 1 - q_indptr = torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) - kv_lens = torch.full((batch_size,), max_seq_len, dtype=torch.int32, device=device) + q_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") + kv_lens = paddle.full(shape=(batch_size,), fill_value=max_seq_len, dtype="int32") kv_indptr = ( - torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) - * page_num_per_batch + paddle.arange(start=0, end=batch_size + 1, dtype="int32") * page_num_per_batch ) kv_indices = page_table.flatten() - q_nope = q_nope_pe[..., :head_dim_ckv] q_pe = q_nope_pe[..., head_dim_ckv:] ckv = ckv_kpe[..., :head_dim_ckv] kpe = ckv_kpe[..., head_dim_ckv:] - - # use head dimension before matrix absorption - sm_scale = 1.0 / ((128 + 64) ** 0.5) + sm_scale = 1.0 / (128 + 64) ** 0.5 mla_ref.plan( q_indptr, kv_indptr, @@ -706,27 +602,20 @@ def test_cutlass_mla(batch_size, max_seq_len, page_size, dtype): head_dim_ckv, head_dim_kpe, page_size, - False, # causal + False, sm_scale, q_nope.dtype, ckv.dtype, ) - o_ref = mla_ref.run(q_nope, q_pe, ckv, kpe, return_lse=False) - mla_ans = flashinfer.mla.BatchMLAPagedAttentionWrapper( - torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device), - backend="cutlass", + paddle.empty(shape=128 * 1024 * 1024, dtype="int8"), backend="cutlass" ) o_ans = mla_ans.run(q_nope, q_pe, ckv, kpe, kv_len=kv_lens, page_table=page_table) - torch.testing.assert_close(o_ans, o_ref, rtol=1e-2, atol=1e-2) + assert paddle.allclose(x=o_ans, 
y=o_ref, rtol=0.01, atol=0.01).item(), "" if __name__ == "__main__": test_batch_mla_varlen_page_attention( - 1, 65, 65, 65, 1, 128, True, 64, "fa2", torch.half + 1, 65, 65, 65, 1, 128, True, 64, "fa2", "float16" ) - # test_batch_mla_varlen_page_attention( - # 155, 1024, 8, 128, 128, 16, False, 1, "fa3", torch.half - # ) - # test_batch_mla_page_attention(1, 1024, 128, 128, False, 1, "fa2", True, torch.half) diff --git a/tests/test_fp4_quantize.py b/tests/test_fp4_quantize.py index 77bc6470d8..f95e890540 100644 --- a/tests/test_fp4_quantize.py +++ b/tests/test_fp4_quantize.py @@ -1,36 +1,32 @@ +import sys + + import functools +import paddle import pytest -import torch +from flashinfer.paddle_utils import * from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant -from flashinfer import ( - block_scale_interleave, - e2m1_and_ufp8sf_scale_to_float, - fp4_quantize, - mxfp4_quantize, - mxfp4_dequantize, -) +from flashinfer import (block_scale_interleave, e2m1_and_ufp8sf_scale_to_float, + fp4_quantize, mxfp4_dequantize, mxfp4_quantize) from flashinfer.utils import is_sm100a_supported -DTYPES = [torch.float16, torch.bfloat16] -# The batch dimension doesn't need to be multiple of 128 +DTYPES = ["float16", "bfloat16"] SHAPES = [(128, 64), (256, 128), (120, 64), (200, 256)] SEEDS = [42] CUDA_DEVICES = ["cuda:0"] - FLOAT4_E2M1_MAX = 6.0 -FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max - +FLOAT8_E4M3_MAX = paddle.finfo(dtype=paddle.float8_e4m3fn).max BLOCK_SIZE = 16 def swizzle_sf( - unswizzled_sf: torch.Tensor, + unswizzled_sf: paddle.Tensor, original_row: int, original_col: int, scaling_vector_size: int = 16, -) -> torch.Tensor: +) -> paddle.Tensor: """ Inverse of `unswizzle_sf`. Converts an unswizzled tensor back to swizzled form. @@ -45,44 +41,35 @@ def swizzle_sf( """ unswizzled_sf = unswizzled_sf.contiguous() factor = scaling_vector_size * 4 - padded_row = ((original_row + 128 - 1) // 128) * 128 # Next multiple of 128 - padded_col = ((original_col + factor - 1) // factor) * factor # Next multiple of 64 - - # Pad the input tensor to [padded_row, padded_col // scaling_vector_size] + padded_row = (original_row + 128 - 1) // 128 * 128 + padded_col = (original_col + factor - 1) // factor * factor pad_rows = padded_row - original_row pad_cols = (padded_col - original_col) // scaling_vector_size - padded_sf = torch.nn.functional.pad( - unswizzled_sf, - (0, pad_cols, 0, pad_rows), + padded_sf = paddle.nn.functional.pad( + x=unswizzled_sf, + pad=(0, pad_cols, 0, pad_rows), mode="constant", value=0, + pad_from_left_axis=False, ).contiguous() - - # Reshape and transpose to reverse unswizzle_sf num_m_tiles = padded_row // 128 num_k_tiles = padded_col // factor - sf_reshaped = padded_sf.view(num_m_tiles, 4, 32, num_k_tiles, 4) # Reverse reshape - sf_swizzled = sf_reshaped.transpose( - 1, 3 - ) # Reverse transpose [num_m_tiles, num_k_tiles, 32, 4, 4] - sf_swizzled = sf_swizzled.reshape( - padded_row, padded_col // scaling_vector_size - ) # Flatten to [128, 64] - + sf_reshaped = padded_sf.view(num_m_tiles, 4, 32, num_k_tiles, 4) + sf_swizzled = sf_reshaped.transpose(perm=dim2perm(sf_reshaped.ndim, 1, 3)) + sf_swizzled = sf_swizzled.reshape(padded_row, padded_col // scaling_vector_size) return sf_swizzled.contiguous() def unswizzle_sf( - sf: torch.Tensor, row: int, col: int, scaling_vector_size: int = 16 -) -> torch.Tensor: + sf: paddle.Tensor, row: int, col: int, scaling_vector_size: int = 16 +) -> paddle.Tensor: factor = scaling_vector_size * 4 num_m_tiles = (row + 128 - 1) // 128 
num_k_tiles = (col + factor - 1) // factor - # SF layout [num_m_tiles, num_k_tiles, 32 (m_tile column major), 4 (m_tile column major), 4(k_tile)] sf_reshaped = sf.view(num_m_tiles, num_k_tiles, 32, 4, 4) - sf_unswizzle = sf_reshaped.transpose(1, 3) + sf_unswizzle = sf_reshaped.transpose(perm=dim2perm(sf_reshaped.ndim, 1, 3)) sf_unswizzle = sf_unswizzle.reshape(num_m_tiles * 32 * 4, num_k_tiles * 4) - sf_unswizzle_sliced = sf_unswizzle[:row, : (col // scaling_vector_size)] + sf_unswizzle_sliced = sf_unswizzle[:row, : col // scaling_vector_size] return sf_unswizzle_sliced.contiguous() @@ -92,25 +79,25 @@ def unswizzle_sf( @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("sf_use_ue8m0", [False, True]) @pytest.mark.parametrize("is_swizzled", [False, True]) -@torch.inference_mode() +@paddle.no_grad() def test_fp4_quantization( - dtype: torch.dtype, + dtype: paddle.dtype, shape: tuple[int, int], seed: int, device: str, sf_use_ue8m0: bool, is_swizzled: bool, ) -> None: - if not is_sm100a_supported(torch.device(device)): + if not is_sm100a_supported(device2str(device)): pytest.skip("Nvfp4 Requires compute capability of 10 or above") - torch.set_default_device(device) - torch.manual_seed(seed) + paddle.device.set_device(device=device2str(device)) + paddle.seed(seed=seed) m, n = shape sf_vec_size = 32 if sf_use_ue8m0 else 16 - x = torch.randn((m, n), dtype=dtype) - tensor_amax = torch.abs(x).max().to(torch.float32) + x = paddle.randn(shape=(m, n), dtype=dtype) + tensor_amax = paddle.abs(x=x)._max().to("float32") if sf_use_ue8m0: - global_scale = torch.tensor(1.0, dtype=torch.float32) + global_scale = paddle.to_tensor(data=1.0, dtype="float32") else: global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax out_ref, scale_ref = ref_fp4_quant(x, global_scale, sf_vec_size, sf_use_ue8m0) @@ -119,56 +106,42 @@ def test_fp4_quantization( ) assert n % sf_vec_size == 0, f"cols needs to be {sf_vec_size} divisible" if sf_use_ue8m0: - out_scale = (out_scale.to(torch.int32) << 23).view(torch.float32) + out_scale = (out_scale.to("int32") << 23).view("float32") else: - out_scale = out_scale.view(torch.float8_e4m3fn).to(torch.float32) + out_scale = out_scale.view(paddle.float8_e4m3fn).to("float32") if is_swizzled: scale_ans = recover_swizzled_scales( - out_scale.reshape(-1, n // sf_vec_size), - m, - n, - sf_vec_size, + out_scale.reshape(-1, n // sf_vec_size), m, n, sf_vec_size ) else: scale_ans = out_scale out_ans = cast_from_fp4(out).reshape(m, n) - torch.testing.assert_close(out_ans, out_ref, rtol=1e0, atol=1e-1) - torch.testing.assert_close(scale_ans, scale_ref, rtol=1e-1, atol=1e-1) + assert paddle.allclose(x=out_ans, y=out_ref, rtol=1.0, atol=0.1).item(), "" + assert paddle.allclose(x=scale_ans, y=scale_ref, rtol=0.1, atol=0.1).item(), "" @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() +@paddle.no_grad() def test_scale_swizzling( - dtype: torch.dtype, - shape: tuple[int, int], - seed: int, - device: str, + dtype: paddle.dtype, shape: tuple[int, int], seed: int, device: str ) -> None: - if not is_sm100a_supported(torch.device("cuda")): + if not is_sm100a_supported(device2str("cuda")): pytest.skip("Nvfp4 Requires compute capability of 10 or above") - torch.set_default_device(device) - torch.manual_seed(seed) + paddle.device.set_device(device=device2str(device)) + paddle.seed(seed=seed) m, n = shape - x = torch.randn((m, n), 
dtype=dtype) - tensor_amax = torch.abs(x).max().to(torch.float32) + x = paddle.randn(shape=(m, n), dtype=dtype) + tensor_amax = paddle.abs(x=x)._max().to("float32") global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax - _, unswizzled_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False, False) _, swizzled_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False, True) assert n % BLOCK_SIZE == 0, f"cols needs to be {BLOCK_SIZE} divisible" - recovered_unswizzled_scale = unswizzle_sf( - swizzle_sf(unswizzled_scale, m, n), - m, - n, - ) - - # We don't expect the following since padding: - # swizzle_sf(unswizzled_scale) == swizzled_scale + recovered_unswizzled_scale = unswizzle_sf(swizzle_sf(unswizzled_scale, m, n), m, n) ref_unswizzled_scale = unswizzle_sf(swizzled_scale, m, n) - assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) + assert_equal = functools.partial(paddle.allclose, rtol=0, atol=0) assert_equal(recovered_unswizzled_scale, unswizzled_scale) assert_equal(ref_unswizzled_scale, unswizzled_scale) @@ -176,47 +149,30 @@ def test_scale_swizzling( @pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_block_scale_interleave( - shape: tuple[int, int], - seed: int, - device: str, -) -> None: +@paddle.no_grad() +def test_block_scale_interleave(shape: tuple[int, int], seed: int, device: str) -> None: """Test the block_scale_interleave function directly.""" - if not is_sm100a_supported(torch.device("cuda")): + if not is_sm100a_supported(device2str("cuda")): pytest.skip("Nvfp4 Requires compute capability of 10 or above") - torch.set_default_device(device) - torch.manual_seed(seed) - + paddle.device.set_device(device=device2str(device)) + paddle.seed(seed=seed) m, n = shape sf_vec_size = BLOCK_SIZE - - # Create a test scale factors tensor with uint8 dtype - # The shape should be [m, n // sf_vec_size] for scale factors - scale_shape = (m, n // sf_vec_size) - unswizzled_sf = torch.randint(0, 256, scale_shape, dtype=torch.uint8, device=device) - - # Test the swizzling function + scale_shape = m, n // sf_vec_size + unswizzled_sf = paddle.randint(low=0, high=256, shape=scale_shape, dtype="uint8") swizzled_sf = block_scale_interleave(unswizzled_sf) - - # Compare against the reference implementation ref_swizzled_sf = swizzle_sf(unswizzled_sf, m, n, sf_vec_size) - - # Basic checks - assert swizzled_sf.dtype == torch.uint8, f"Expected uint8, got {swizzled_sf.dtype}" - assert swizzled_sf.device == unswizzled_sf.device, "Device mismatch" - - # Check that the output has the expected padded shape + assert swizzled_sf.dtype == "uint8", f"Expected uint8, got {swizzled_sf.dtype}" + assert swizzled_sf.place == unswizzled_sf.place, "Device mismatch" factor = sf_vec_size * 4 - padded_row = ((m + 128 - 1) // 128) * 128 # Next multiple of 128 - padded_col = ((n + factor - 1) // factor) * factor # Next multiple of 64 - expected_shape = (padded_row, padded_col // sf_vec_size) + padded_row = (m + 128 - 1) // 128 * 128 + padded_col = (n + factor - 1) // factor * factor + expected_shape = padded_row, padded_col // sf_vec_size expected_size = expected_shape[0] * expected_shape[1] - - assert expected_size == swizzled_sf.shape[0], ( - f"Expected size {expected_size}, got {swizzled_sf.shape[0]}" - ) - assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) + assert ( + expected_size == tuple(swizzled_sf.shape)[0] + ), f"Expected size {expected_size}, got 
{tuple(swizzled_sf.shape)[0]}" + assert_equal = functools.partial(paddle.allclose, rtol=0, atol=0) assert_equal(swizzled_sf.reshape(expected_shape), ref_swizzled_sf) @@ -224,37 +180,24 @@ def test_block_scale_interleave( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("sf_use_ue8m0", [True, False]) -@torch.inference_mode() +@paddle.no_grad() def test_e2m1_dequantization( - shape: tuple[int, int], - seed: int, - device: str, - sf_use_ue8m0: bool, + shape: tuple[int, int], seed: int, device: str, sf_use_ue8m0: bool ) -> None: """Test roundtrip: fp4_quantize -> e2m1_and_ufp8sf_scale_to_float.""" - if not is_sm100a_supported(torch.device("cuda")): + if not is_sm100a_supported(device2str("cuda")): pytest.skip("Nvfp4 Requires compute capability of 10 or above") - torch.set_default_device(device) - torch.manual_seed(seed) - - # Create a reasonable test tensor + paddle.device.set_device(device=device2str(device)) + paddle.seed(seed=seed) m, n = shape - x = torch.randn((m, n), dtype=torch.float16) - - # Calculate global scale as in the other tests - tensor_amax = torch.abs(x).max().to(torch.float32) + x = paddle.randn(shape=(m, n), dtype="float16") + tensor_amax = paddle.abs(x=x)._max().to("float32") global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax - - # Test with default common settings is_sf_swizzled_layout = True block_size = 32 if sf_use_ue8m0 else 16 - - # Step 1: Quantize with fp4_quantize quantized_tensor, scale_factors = fp4_quantize( x, global_scale, block_size, sf_use_ue8m0, is_sf_swizzled_layout ) - - # Step 2: Dequantize with e2m1_and_ufp8sf_scale_to_float ufp8_type = 0 if sf_use_ue8m0 else 1 dequantized_tensor = e2m1_and_ufp8sf_scale_to_float( quantized_tensor, @@ -264,53 +207,35 @@ def test_e2m1_dequantization( ufp8_type=ufp8_type, is_sf_swizzled_layout=is_sf_swizzled_layout, ) - - # Move back to device for comparison dequantized_tensor = dequantized_tensor.to(device) - x_float32 = x.to(torch.float32) - - # Step 3: Compare results - assert dequantized_tensor.shape == x.shape, ( - f"Shape mismatch: expected {x.shape}, got {dequantized_tensor.shape}" - ) - assert dequantized_tensor.dtype == torch.float32, ( - f"Expected float32, got {dequantized_tensor.dtype}" - ) - - # Check for invalid values - assert not torch.isnan(dequantized_tensor).any(), ( - "Dequantized tensor contains NaN values" - ) - assert not torch.isinf(dequantized_tensor).any(), ( - "Dequantized tensor contains Inf values" - ) - - # Compare with original - should be reasonably close since FP4 is designed to preserve important values - torch.testing.assert_close( - dequantized_tensor, - x_float32, - rtol=0.3, - atol=0.5, # Reasonable tolerance for FP4 quantization - msg="Quantize -> dequantize roundtrip failed", - ) + x_float32 = x.to("float32") + assert tuple(dequantized_tensor.shape) == tuple( + x.shape + ), f"Shape mismatch: expected {tuple(x.shape)}, got {tuple(dequantized_tensor.shape)}" + assert ( + dequantized_tensor.dtype == "float32" + ), f"Expected float32, got {dequantized_tensor.dtype}" + assert ( + not paddle.isnan(x=dequantized_tensor).astype("bool").any() + ), "Dequantized tensor contains NaN values" + assert ( + not paddle.isinf(x=dequantized_tensor).astype("bool").any() + ), "Dequantized tensor contains Inf values" + assert paddle.allclose( + x=dequantized_tensor, y=x_float32, rtol=0.3, atol=0.5 + ).item(), "Quantize -> dequantize roundtrip failed" @pytest.mark.parametrize("device", CUDA_DEVICES) def 
test_mxfp4_quantize_roundtrip(device: str): - if not is_sm100a_supported(torch.device(device)): + if not is_sm100a_supported(device2str(device)): pytest.skip("Nvfp4 Requires compute capability of 10 or above") - x = torch.randn((128, 64), device="cuda", dtype=torch.bfloat16) / 10 - + x = paddle.randn(shape=(128, 64), dtype="bfloat16") / 10 quant_a, sfs = mxfp4_quantize(x) dq_a = mxfp4_dequantize(quant_a, sfs) - - torch.testing.assert_close( - dq_a.cpu().to(torch.float32), - x.cpu().to(torch.float32), - rtol=0.3, - atol=0.5, - msg="Quantize -> dequantize mxfp4 roundtrip failed", - ) + assert paddle.allclose( + x=dq_a.cpu().to("float32"), y=x.cpu().to("float32"), rtol=0.3, atol=0.5 + ).item(), "Quantize -> dequantize mxfp4 roundtrip failed" if __name__ == "__main__": diff --git a/tests/test_fp4_tensor_torch_cute.py b/tests/test_fp4_tensor_torch_cute.py index ba53480ae9..468b53f849 100644 --- a/tests/test_fp4_tensor_torch_cute.py +++ b/tests/test_fp4_tensor_torch_cute.py @@ -1,8 +1,12 @@ -import pytest +import sys + + import cutlass import cutlass.cute as cute -import torch +import paddle +import pytest from cutlass.cute.runtime import make_ptr +from flashinfer.paddle_utils import * from flashinfer.cute_dsl.utils import is_cute_dsl_available @@ -25,16 +29,12 @@ def copy_torch_fp4_tensor(a_ptr: cute.Pointer, b_ptr: cute.Pointer): def test_fp4_tensor_torch_cute(): if not is_cute_dsl_available(): pytest.skip("cute-dsl is not available") - - a = torch.randint( - 0, 128, size=(3, 4), dtype=torch.uint8, device=torch.device("cuda:0") - ) - b = torch.zeros_like(a) - a_view = a.view(torch.float4_e2m1fn_x2) - b_view = b.view(torch.float4_e2m1fn_x2) + a = paddle.randint(low=0, high=128, shape=(3, 4), dtype="uint8") + b = paddle.zeros_like(x=a) +>>>>>> a_view = a.view(torch.float4_e2m1fn_x2) +>>>>>> b_view = b.view(torch.float4_e2m1fn_x2) print(f"a_view: \n{a_view}") print("") - a_ptr = make_ptr( cutlass.Float4E2M1FN, a_view.data_ptr(), @@ -48,6 +48,6 @@ def test_fp4_tensor_torch_cute(): assumed_align=16, ) copy_torch_fp4_tensor(a_ptr, b_ptr) - torch.testing.assert_close(a, b) + assert paddle.allclose(x=a, y=b).item(), "" print("Results verified successfully!") print(f"Result: \n{b_view}") diff --git a/tests/test_fp8_prefill.py b/tests/test_fp8_prefill.py index 414173f452..0d46ecea20 100644 --- a/tests/test_fp8_prefill.py +++ b/tests/test_fp8_prefill.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,9 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch import flashinfer @@ -28,7 +28,7 @@ @pytest.mark.parametrize("num_qo_heads", [4, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("dtype", [paddle.float8_e4m3fn, paddle.float8_e5m2]) def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale( batch_size, qo_len, @@ -40,31 +40,37 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale( kv_layout, dtype, ): - torch.manual_seed(42) - q = torch.randn( - batch_size * qo_len, num_qo_heads, head_dim, dtype=torch.float16 + paddle.seed(seed=42) + q = paddle.randn( + shape=[batch_size * qo_len, num_qo_heads, head_dim], dtype="float16" ).to(0) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( 0.05 - * torch.randn( - total_num_pages, 2, num_kv_heads, page_size, head_dim, dtype=torch.float16 + * paddle.randn( + shape=[total_num_pages, 2, num_kv_heads, page_size, head_dim], + dtype="float16", ).to(0) if kv_layout == "HND" else 0.05 - * torch.randn( - total_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16 + * paddle.randn( + shape=[total_num_pages, 2, page_size, num_kv_heads, head_dim], + dtype="float16", ).to(0) ) - qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len - kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq - kv_indices = torch.arange(0, total_num_pages).to(0).int() - kv_last_page_len = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + qo_indptr = ( + paddle.arange(start=0, end=batch_size + 1).to(0).astype(dtype="int32") * qo_len + ) + kv_indptr = ( + paddle.arange(start=0, end=batch_size + 1).to(0).astype(dtype="int32") + * num_pages_per_seq + ) + kv_indices = paddle.arange(start=0, end=total_num_pages).to(0).astype(dtype="int32") + kv_last_page_len = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" ).to(0) - - workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + workspace_buffer = paddle.empty(shape=32 * 1024 * 1024, dtype="int8").to(0) wrapper_f16 = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) @@ -77,18 +83,16 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale( num_kv_heads, head_dim, page_size, - q_data_type=torch.float16, - kv_data_type=torch.float16, + q_data_type="float16", + kv_data_type="float16", ) o_fp16 = wrapper_f16.run(q, kv_data) - k_data, v_data = torch.chunk(kv_data, 2, dim=1) + k_data, v_data = paddle.chunk(x=kv_data, chunks=2, axis=1) k_scale = k_data.amax().item() / 256 v_scale = v_data.amax().item() / 256 - k_fp8 = (k_data / k_scale).to(dtype) v_fp8 = (v_data / v_scale).to(dtype) - kv_data_fp8 = torch.cat([k_fp8, v_fp8], dim=1) - + kv_data_fp8 = paddle.concat(x=[k_fp8, v_fp8], axis=1) wrapper_f8 = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) @@ -101,17 +105,11 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale( num_kv_heads, head_dim, page_size, - q_data_type=torch.float16, + q_data_type="float16", kv_data_type=dtype, ) - o_fp8 = wrapper_f8.run( - q, - kv_data_fp8.to(dtype), - k_scale=k_scale, - v_scale=v_scale, - ) - - torch.testing.assert_close(o_fp16, o_fp8, atol=1e-2, rtol=2e-1) + o_fp8 = wrapper_f8.run(q, kv_data_fp8.to(dtype), k_scale=k_scale, v_scale=v_scale) + assert 
paddle.allclose(x=o_fp16, y=o_fp8, atol=0.01, rtol=0.2).item(), "" @pytest.mark.parametrize("batch_size", [12, 17]) @@ -121,7 +119,7 @@ def test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale( @pytest.mark.parametrize("num_qo_heads", [4, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("dtype", [paddle.float8_e4m3fn, paddle.float8_e5m2]) def test_batch_decode_with_prefill_with_paged_kv_cache( batch_size, kv_len, @@ -132,29 +130,33 @@ def test_batch_decode_with_prefill_with_paged_kv_cache( kv_layout, dtype, ): - torch.manual_seed(42) - q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16).to(0) + paddle.seed(seed=42) + q = paddle.randn(shape=[batch_size, num_qo_heads, head_dim], dtype="float16").to(0) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( 0.1 - * torch.randn( - total_num_pages, 2, num_kv_heads, page_size, head_dim, dtype=torch.float16 + * paddle.randn( + shape=[total_num_pages, 2, num_kv_heads, page_size, head_dim], + dtype="float16", ).to(0) if kv_layout == "HND" else 0.1 - * torch.randn( - total_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16 + * paddle.randn( + shape=[total_num_pages, 2, page_size, num_kv_heads, head_dim], + dtype="float16", ).to(0) ).to(dtype) - qo_indptr = torch.arange(0, batch_size + 1).to(0).int() - kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq - kv_indices = torch.arange(0, total_num_pages).to(0).int() - kv_last_page_len = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + qo_indptr = paddle.arange(start=0, end=batch_size + 1).to(0).astype(dtype="int32") + kv_indptr = ( + paddle.arange(start=0, end=batch_size + 1).to(0).astype(dtype="int32") + * num_pages_per_seq + ) + kv_indices = paddle.arange(start=0, end=total_num_pages).to(0).astype(dtype="int32") + kv_last_page_len = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" ).to(0) - - workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + workspace_buffer = paddle.empty(shape=32 * 1024 * 1024, dtype="int8").to(0) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) @@ -167,11 +169,10 @@ def test_batch_decode_with_prefill_with_paged_kv_cache( num_kv_heads, head_dim, page_size, - q_data_type=torch.float16, + q_data_type="float16", kv_data_type=dtype, ) o_fp8 = wrapper.run(q, kv_data) - decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) @@ -183,18 +184,17 @@ def test_batch_decode_with_prefill_with_paged_kv_cache( num_kv_heads, head_dim, page_size, - q_data_type=torch.float16, + q_data_type="float16", kv_data_type=dtype, ) o_decode_fp8 = decode_wrapper.run(q, kv_data) - - torch.testing.assert_close(o_decode_fp8, o_fp8, atol=1e-2, rtol=1e-2) + assert paddle.allclose(x=o_decode_fp8, y=o_fp8, atol=0.01, rtol=0.01).item(), "" if __name__ == "__main__": test_batch_prefill_with_paged_kv_cache_fp8_calibration_scale( - 12, 7, 54, 1, 4, 4, 128, "NHD", torch.float8_e5m2 +>>>>>> 12, 7, 54, 1, 4, 4, 128, "NHD", paddle.float8_e5m2 ) test_batch_decode_with_prefill_with_paged_kv_cache( - 12, 54, 1, 4, 4, 128, "NHD", torch.float8_e5m2 +>>>>>> 12, 54, 1, 4, 4, 128, "NHD", paddle.float8_e5m2 ) diff --git a/tests/test_fp8_quantize.py 
b/tests/test_fp8_quantize.py index 50352eacc1..5000a9fca8 100644 --- a/tests/test_fp8_quantize.py +++ b/tests/test_fp8_quantize.py @@ -1,47 +1,43 @@ +import sys + + +import paddle import pytest -import torch +from flashinfer.paddle_utils import * from flashinfer import mxfp8_dequantize_host, mxfp8_quantize @pytest.mark.parametrize("m", [1, 1024]) @pytest.mark.parametrize("k", [1024]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) @pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_mxfp8_quantize_torch(m, k, dtype, is_sf_swizzled_layout, device): - a = 16 * torch.randn([m, k], dtype=dtype).to(device).contiguous() - + a = 16 * paddle.randn(shape=[m, k], dtype=dtype).to(device).contiguous() if device == "cpu": - a = a.float() - + a = a.astype(dtype="float32") a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout) - if device == "cuda": a_fp8 = a_fp8.cpu() a_sf = a_sf.cpu() - a_pt = mxfp8_dequantize_host( - a_fp8.view(torch.uint8), - a_sf.view(torch.uint8).reshape(-1), - is_sf_swizzled_layout, + a_fp8.view("uint8"), a_sf.view("uint8").reshape(-1), is_sf_swizzled_layout ) - if device == "cuda": a_pt = a_pt.cuda() - - torch.cuda.synchronize() + paddle.device.synchronize() def check_accuracy(a, b, atol, rtol, percent): - if torch.any(torch.isnan(a)): + if paddle.any(x=paddle.isnan(x=a)): raise Exception("NaN in a") - if torch.any(torch.isnan(b)): + if paddle.any(x=paddle.isnan(x=b)): raise Exception("NaN in b") - assert a.shape == b.shape - left = torch.abs(a - b) - right = atol + rtol * torch.abs(b) - count = torch.sum(left > right) - mismatch_percent = count / a.numel() + assert tuple(a.shape) == tuple(b.shape) + left = paddle.abs(x=a - b) + right = atol + rtol * paddle.abs(x=b) + count = paddle.sum(x=left > right) + mismatch_percent = count / a.size if mismatch_percent > 1 - percent: raise Exception( "Mismatch percentage is %f for rtol %f" % (mismatch_percent, rtol) @@ -51,15 +47,15 @@ def check_accuracy(a, b, atol, rtol, percent): def mxfp8_quantize_check_accuracy(a, b, atol, rtol, percent): - if torch.any(torch.isnan(a)): + if paddle.any(x=paddle.isnan(x=a)): raise Exception("NaN in a") - if torch.any(torch.isnan(b)): + if paddle.any(x=paddle.isnan(x=b)): raise Exception("NaN in b") - assert a.shape == b.shape - left = torch.abs(a - b) - right = atol + rtol * torch.abs(b) - count = torch.sum(left > right) - mismatch_percent = count / a.numel() + assert tuple(a.shape) == tuple(b.shape) + left = paddle.abs(x=a - b) + right = atol + rtol * paddle.abs(x=b) + count = paddle.sum(x=left > right) + mismatch_percent = count / a.size if mismatch_percent > 1 - percent: raise Exception( "Mismatch percentage is %f for rtol %f" % (mismatch_percent, rtol) @@ -68,75 +64,57 @@ def mxfp8_quantize_check_accuracy(a, b, atol, rtol, percent): @pytest.mark.parametrize("m", [1, 2, 16, 1024]) @pytest.mark.parametrize("k", [512, 1024]) -@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) def test_mxfp8_quantize_torch_host(m, k, dtype, is_sf_swizzled_layout): - torch.random.manual_seed(0) - a = (torch.randn([m, k], dtype=torch.float) * 16).cpu().contiguous() - + paddle.seed(seed=0) + a = (paddle.randn(shape=[m, k], dtype="float32") * 16).cpu().contiguous() a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout) - a_pt = 
mxfp8_dequantize_host( - a_fp8.view(torch.uint8), a_sf.view(torch.uint8), is_sf_swizzled_layout + a_fp8.view("uint8"), a_sf.view("uint8"), is_sf_swizzled_layout ) - - torch.cuda.synchronize() - + paddle.device.synchronize() mxfp8_quantize_check_accuracy(a_pt, a, 8, 0, 0.999) @pytest.mark.parametrize("m", [1, 2, 16, 1024]) @pytest.mark.parametrize("k", [512, 1024]) -@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) def test_mxfp8_quantize_torch_device(m, k, dtype, is_sf_swizzled_layout): - torch.random.manual_seed(0) - a = (torch.randn([m, k], dtype=torch.float) * 16).to(dtype).cuda().contiguous() - + paddle.seed(seed=0) + a = (paddle.randn(shape=[m, k], dtype="float32") * 16).to(dtype).cuda().contiguous() a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, 32) a_pt = mxfp8_dequantize_host( - a_fp8.cpu().view(torch.uint8), - a_sf.cpu().view(torch.uint8), - is_sf_swizzled_layout, + a_fp8.cpu().view("uint8"), a_sf.cpu().view("uint8"), is_sf_swizzled_layout ) - - torch.cuda.synchronize() + paddle.device.synchronize() mxfp8_quantize_check_accuracy( - a_pt.cpu().to(torch.float32), a.cpu().to(torch.float32), 8, 0, 0.999 + a_pt.cpu().to("float32"), a.cpu().to("float32"), 8, 0, 0.999 ) @pytest.mark.parametrize("m", [1, 2, 16, 1024]) @pytest.mark.parametrize("k", [1568]) -@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) @pytest.mark.parametrize("alignment", [64, 128]) def test_mxfp8_quantize_alignment_torch_device( m, k, dtype, is_sf_swizzled_layout, alignment ): - torch.random.manual_seed(0) - a = (torch.randn([m, k], dtype=torch.float) * 16).to(dtype).cuda().contiguous() - padded_k = ((k + alignment - 1) // alignment) * alignment - - # Quantize it on device. + paddle.seed(seed=0) + a = (paddle.randn(shape=[m, k], dtype="float32") * 16).to(dtype).cuda().contiguous() + padded_k = (k + alignment - 1) // alignment * alignment a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, alignment) - assert a_fp8.shape[1] == padded_k - - # Dequantize it on host. + assert tuple(a_fp8.shape)[1] == padded_k a_pt = mxfp8_dequantize_host( - a_fp8.cpu().view(torch.uint8), - a_sf.cpu().view(torch.uint8), - is_sf_swizzled_layout, + a_fp8.cpu().view("uint8"), a_sf.cpu().view("uint8"), is_sf_swizzled_layout ) - - # Check if the bits of paddings are zero. 
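# --- editorial aside (not part of the patch) ----------------------------------
# The alignment test above rounds k up to the next multiple of `alignment` and
# then checks that the padded tail of the quantized output is all zeros. For the
# parametrized k=1568 with alignment=128, the expected padding works out to:
padded_k = (1568 + 128 - 1) // 128 * 128
assert padded_k == 1664        # next multiple of 128 above 1568
assert padded_k - 1568 == 96   # number of zero-padded columns checked just below
# -------------------------------------------------------------------------------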
- paddings = a_fp8.view(torch.int8)[:, k:] - assert torch.all(paddings == 0), "Paddings should be zero" - - torch.cuda.synchronize() - + paddings = a_fp8.view("int8")[:, k:] + assert paddle.all(x=paddings == 0), "Paddings should be zero" + paddle.device.synchronize() mxfp8_quantize_check_accuracy( - a_pt[:, :k].cpu().to(torch.float32), a.cpu().to(torch.float32), 8, 0, 0.999 + a_pt[:, :k].cpu().to("float32"), a.cpu().to("float32"), 8, 0, 0.999 ) diff --git a/tests/test_green_ctx.py b/tests/test_green_ctx.py index 4863dd5c51..42bcc65f3a 100644 --- a/tests/test_green_ctx.py +++ b/tests/test_green_ctx.py @@ -1,5 +1,9 @@ +import sys + + +import paddle import pytest -import torch +from flashinfer.paddle_utils import * import flashinfer.green_ctx as green_ctx @@ -7,15 +11,10 @@ @pytest.mark.parametrize("device", ["cuda:0"]) @pytest.mark.parametrize("num_groups", [1, 2, 3]) @pytest.mark.parametrize("min_count", [16, 32]) -def test_green_ctx_creation( - device: str, - num_groups: int, - min_count: int, -): +def test_green_ctx_creation(device: str, num_groups: int, min_count: int): streams, resources = green_ctx.split_device_green_ctx( - torch.device(device), num_groups, min_count + device2str(device), num_groups, min_count ) - assert len(resources) == num_groups + 1 for resource in resources[:-1]: sm_count = resource.sm.smCount @@ -25,104 +24,65 @@ def test_green_ctx_creation( @pytest.mark.parametrize("device", ["cuda:0"]) @pytest.mark.parametrize("num_groups", [1, 2, 3]) @pytest.mark.parametrize("min_count", [16, 32]) -def test_green_ctx_kernel_execution( - device: str, - num_groups: int, - min_count: int, -): +def test_green_ctx_kernel_execution(device: str, num_groups: int, min_count: int): streams, resources = green_ctx.split_device_green_ctx( - torch.device(device), num_groups, min_count + device2str(device), num_groups, min_count ) num_partitions = num_groups + 1 assert len(streams) == num_partitions assert len(resources) == num_partitions - for stream in streams: - with torch.cuda.stream(stream): - x = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) - y = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16) + with paddle.device.stream_guard(stream=stream): + x = paddle.randn(shape=[8192, 8192], dtype="bfloat16") + y = paddle.randn(shape=[8192, 8192], dtype="bfloat16") z = x @ y - print(z.shape) + print(tuple(z.shape)) @pytest.mark.parametrize("device", ["cuda:0"]) -@pytest.mark.parametrize( - "sm_counts", - [ - [16, 16, 16], - [8, 16, 24], - [32], - [8, 8, 8, 8], - ], -) -def test_split_device_green_ctx_by_sm_count_creation( - device: str, - sm_counts: list, -): +@pytest.mark.parametrize("sm_counts", [[16, 16, 16], [8, 16, 24], [32], [8, 8, 8, 8]]) +def test_split_device_green_ctx_by_sm_count_creation(device: str, sm_counts: list): streams, resources = green_ctx.split_device_green_ctx_by_sm_count( - torch.device(device), sm_counts + device2str(device), sm_counts ) num_partitions = len(sm_counts) + 1 assert len(resources) == num_partitions assert len(streams) == num_partitions - - # Check that each partition has the expected SM count for i, expected_sm_count in enumerate(sm_counts): actual_sm_count = resources[i].sm.smCount assert actual_sm_count >= expected_sm_count @pytest.mark.parametrize("device", ["cuda:0"]) -@pytest.mark.parametrize( - "sm_counts", - [ - [16, 16, 16], - [8, 16, 24], - [32], - ], -) +@pytest.mark.parametrize("sm_counts", [[16, 16, 16], [8, 16, 24], [32]]) def test_split_device_green_ctx_by_sm_count_kernel_execution( - device: str, - sm_counts: list, 
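# --- editorial aside (not part of the patch) ----------------------------------
# These tests replace torch.device(device) with device2str(device) from
# flashinfer.paddle_utils; that helper's implementation is not part of this
# diff. It presumably maps torch-style device strings to the "gpu:N" spelling
# Paddle uses (it might equally return a Paddle place object). A hypothetical
# string-based version, for illustration only:
def device2str(device: str) -> str:
    # Paddle names CUDA devices "gpu:N"; leave other strings (e.g. "cpu") alone.
    return device.replace("cuda", "gpu")
# -------------------------------------------------------------------------------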
+ device: str, sm_counts: list ): streams, resources = green_ctx.split_device_green_ctx_by_sm_count( - torch.device(device), sm_counts + device2str(device), sm_counts ) num_partitions = len(sm_counts) + 1 assert len(streams) == num_partitions assert len(resources) == num_partitions - for i, stream in enumerate(streams): - with torch.cuda.stream(stream): - x = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) - y = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) + with paddle.device.stream_guard(stream=stream): + x = paddle.randn(shape=[4096, 4096], dtype="bfloat16") + y = paddle.randn(shape=[4096, 4096], dtype="bfloat16") z = x @ y - print(f"Partition {i}: {z.shape}") + print(f"Partition {i}: {tuple(z.shape)}") @pytest.mark.parametrize("device", ["cuda:0"]) -@pytest.mark.parametrize( - "sm_counts", - [ - [1, 2, 3, 4], # Should be aligned to minimum requirements - [7, 8, 9, 10], # Should be aligned to 8 for compute capability 9+ - [15, 16, 17, 18], # Should be aligned to 8 for compute capability 9+ - ], -) -def test_split_device_green_ctx_by_sm_count_alignment( - device: str, - sm_counts: list, -): +@pytest.mark.parametrize("sm_counts", [[1, 2, 3, 4], [7, 8, 9, 10], [15, 16, 17, 18]]) +def test_split_device_green_ctx_by_sm_count_alignment(device: str, sm_counts: list): _, resources = green_ctx.split_device_green_ctx_by_sm_count( - torch.device(device), sm_counts + device2str(device), sm_counts ) - - for resource in resources[:-1]: # Exclude remaining SMs + for resource in resources[:-1]: sm_count = resource.sm.smCount assert sm_count > 0 - min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint( - *green_ctx.get_compute_capability(torch.device(device)) + *green_ctx.get_compute_capability(device2str(device)) ) assert sm_count >= min_sm_count assert sm_count % sm_alignment == 0 diff --git a/tests/test_group_gemm.py b/tests/test_group_gemm.py index 7e0a02c610..2ae09cafef 100644 --- a/tests/test_group_gemm.py +++ b/tests/test_group_gemm.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,21 +19,19 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch import flashinfer from flashinfer.utils import determine_gemm_backend, is_sm90a_supported -DTYPES = [torch.float16] +DTYPES = ["float16"] CUDA_DEVICES = ["cuda:0"] @pytest.fixture(autouse=True, scope="module") def warmup_jit(): jit_specs = [flashinfer.gemm.gen_gemm_module()] - if is_sm90a_supported(torch.device("cuda:0")): + if is_sm90a_supported(device2str("cuda:0")): jit_specs.append(flashinfer.gemm.gen_gemm_sm90_module()) flashinfer.jit.build_jit_specs(jit_specs, verbose=False) yield @@ -53,71 +57,73 @@ def test_segment_gemm( device, backend, ): - torch.manual_seed(42) + paddle.seed(seed=42) if batch_size * num_rows_per_batch > 8192: pytest.skip("batch_size * num_rows_per_batch too large for test.") - latest_supported_backend = determine_gemm_backend(torch.device(device)) + latest_supported_backend = determine_gemm_backend(device2str(device)) if backend == "sm90" and latest_supported_backend == "sm80": pytest.skip("sm90 backend not supported on this device.") - torch.manual_seed(42) - workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device=device) + paddle.seed(seed=42) + workspace_buffer = paddle.empty(shape=32 * 1024 * 1024, dtype="int8") segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend=backend) - x = torch.randn(batch_size * num_rows_per_batch, d_in, dtype=dtype, device=device) + x = paddle.randn(shape=[batch_size * num_rows_per_batch, d_in], dtype=dtype) if use_weight_indices: num_weights = 1024 if column_major: - weight = torch.randn(num_weights, d_out, d_in, dtype=dtype, device=device) + weight = paddle.randn(shape=[num_weights, d_out, d_in], dtype=dtype) else: - weight = torch.randn(num_weights, d_in, d_out, dtype=dtype, device=device) + weight = paddle.randn(shape=[num_weights, d_in, d_out], dtype=dtype) + elif column_major: + weight = paddle.randn(shape=[batch_size, d_out, d_in], dtype=dtype) else: - if column_major: - weight = torch.randn(batch_size, d_out, d_in, dtype=dtype, device=device) - else: - weight = torch.randn(batch_size, d_in, d_out, dtype=dtype, device=device) + weight = paddle.randn(shape=[batch_size, d_in, d_out], dtype=dtype) y = segment_gemm.run( x, weight, batch_size, weight_column_major=column_major, - seg_lens=torch.full((batch_size,), num_rows_per_batch, dtype=torch.int64), - weight_indices=( - (torch.arange(0, batch_size, device=device) % num_weights) - if use_weight_indices - else None + seg_lens=paddle.full( + shape=(batch_size,), fill_value=num_rows_per_batch, dtype="int64" ), + weight_indices=paddle.arange(start=0, end=batch_size) % num_weights + if use_weight_indices + else None, ) - if use_weight_indices: for i in range(batch_size): - torch.testing.assert_close( - y[i * num_rows_per_batch : (i + 1) * num_rows_per_batch], - torch.matmul( - x[i * num_rows_per_batch : (i + 1) * num_rows_per_batch].float(), - ( - weight[i % num_weights].float().T - if column_major - else weight[i % num_weights].float() + assert paddle.allclose( + x=y[i * num_rows_per_batch : (i + 1) * num_rows_per_batch], + y=paddle.matmul( + x=x[i * num_rows_per_batch : (i + 1) * num_rows_per_batch].astype( + dtype="float32" ), + y=weight[i % num_weights].astype(dtype="float32").T + if column_major + else weight[i % num_weights].astype(dtype="float32"), ).to(dtype), - rtol=1e-3, - atol=1e-3, - ) + rtol=0.001, + atol=0.001, + ).item(), "" else: - torch.testing.assert_close( - y, - torch.matmul( - x.view(batch_size, num_rows_per_batch, d_in).float(), - weight.float().transpose(-1, -2) if column_major else 
weight.float(), + assert paddle.allclose( + x=y, + y=paddle.matmul( + x=x.view(batch_size, num_rows_per_batch, d_in).astype(dtype="float32"), + y=weight.astype(dtype="float32").transpose( + perm=dim2perm(weight.astype(dtype="float32").ndim, -1, -2) + ) + if column_major + else weight.astype(dtype="float32"), ) .view(batch_size * num_rows_per_batch, d_out) .to(dtype), - rtol=1e-3, - atol=2e-3, - ) + rtol=0.001, + atol=0.002, + ).item(), "" if __name__ == "__main__": - test_segment_gemm(199, 17, 128, 1024, False, False, torch.float16, "cuda:0", "auto") - test_segment_gemm(199, 17, 128, 1024, False, True, torch.float16, "cuda:0", "auto") - test_segment_gemm(199, 17, 128, 1024, True, False, torch.float16, "cuda:0", "auto") - test_segment_gemm(199, 17, 128, 1024, True, True, torch.float16, "cuda:0", "auto") + test_segment_gemm(199, 17, 128, 1024, False, False, "float16", "cuda:0", "auto") + test_segment_gemm(199, 17, 128, 1024, False, True, "float16", "cuda:0", "auto") + test_segment_gemm(199, 17, 128, 1024, True, False, "float16", "cuda:0", "auto") + test_segment_gemm(199, 17, 128, 1024, True, True, "float16", "cuda:0", "auto") diff --git a/tests/test_groupwise_scaled_gemm_fp8.py b/tests/test_groupwise_scaled_gemm_fp8.py index 32ffa0c573..923c7a9b4f 100755 --- a/tests/test_groupwise_scaled_gemm_fp8.py +++ b/tests/test_groupwise_scaled_gemm_fp8.py @@ -1,3 +1,6 @@ +import einops +import paddle + """ Copyright (c) 2025 by FlashInfer team. @@ -13,20 +16,14 @@ See the License for the specific language governing permissions and limitations under the License. """ - import math import pytest -import torch -from einops import einsum -from flashinfer.gemm import ( - batch_deepgemm_fp8_nt_groupwise, - gemm_fp8_nt_blockscaled, - gemm_fp8_nt_groupwise, - group_deepgemm_fp8_nt_groupwise, - group_gemm_fp8_nt_groupwise, -) +from flashinfer.gemm import (batch_deepgemm_fp8_nt_groupwise, + gemm_fp8_nt_blockscaled, gemm_fp8_nt_groupwise, + group_deepgemm_fp8_nt_groupwise, + group_gemm_fp8_nt_groupwise) from flashinfer.testing.utils import dequantize_fp8, quantize_fp8 @@ -34,40 +31,29 @@ @pytest.mark.parametrize("n", [128, 256, 512, 4096, 8192]) @pytest.mark.parametrize("k", [128, 256, 512, 4096, 8192]) @pytest.mark.parametrize("scale_major_mode", ["MN", "K"]) -@pytest.mark.parametrize("out_dtype", [torch.bfloat16]) -def test_fp8_blockscale_gemm( - m, - n, - k, - scale_major_mode, - out_dtype, -): - torch.random.manual_seed(0) +@pytest.mark.parametrize("out_dtype", ["bfloat16"]) +def test_fp8_blockscale_gemm(m, n, k, scale_major_mode, out_dtype): + paddle.seed(seed=0) tile_size = 128 - - a_val = torch.randn((m, k), dtype=torch.float, device="cuda") - b_val = torch.randn((n, k), dtype=torch.float, device="cuda") / math.sqrt(k) - + a_val = paddle.randn(shape=(m, k), dtype="float32") + b_val = paddle.randn(shape=(n, k), dtype="float32") / math.sqrt(k) if scale_major_mode == "K": - a_scale_shape = (m // tile_size, k // tile_size) - b_scale_shape = (n // tile_size, k // tile_size) + a_scale_shape = m // tile_size, k // tile_size + b_scale_shape = n // tile_size, k // tile_size else: - a_scale_shape = (k // tile_size, m // tile_size) - b_scale_shape = (k // tile_size, n // tile_size) - a_tile_shape = (tile_size, tile_size) - b_tile_shape = (tile_size, tile_size) - + a_scale_shape = k // tile_size, m // tile_size + b_scale_shape = k // tile_size, n // tile_size + a_tile_shape = tile_size, tile_size + b_tile_shape = tile_size, tile_size a_fp8, a_scale = quantize_fp8(a_val, a_scale_shape, a_tile_shape, scale_major_mode) 
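# --- editorial aside (not part of the patch) ----------------------------------
# The reference output in these GEMM tests is einops.einsum(a_dequant, b_dequant,
# "m k, n k -> m n"), i.e. an NT GEMM where B is stored as [n, k] and consumed
# transposed. A self-contained sketch of that equivalence (shapes chosen purely
# for illustration):
import paddle
import einops

a = paddle.randn(shape=(4, 8), dtype="float32")  # [m, k]
b = paddle.randn(shape=(6, 8), dtype="float32")  # [n, k]
ref = einops.einsum(a, b, "m k, n k -> m n")
assert paddle.allclose(x=ref, y=a @ b.T, rtol=1e-05, atol=1e-05).item()
# -------------------------------------------------------------------------------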
b_fp8, b_scale = quantize_fp8(b_val, b_scale_shape, b_tile_shape, scale_major_mode) - a_dequant = dequantize_fp8(a_fp8, a_scale, scale_major_mode) b_dequant = dequantize_fp8(b_fp8, b_scale, scale_major_mode) - ref_c = einsum(a_dequant, b_dequant, "m k, n k -> m n").to(out_dtype) - + ref_c = einops.einsum(a_dequant, b_dequant, "m k, n k -> m n").to(out_dtype) c = gemm_fp8_nt_blockscaled( a_fp8, b_fp8, a_scale, b_scale, scale_major_mode, out_dtype=out_dtype ) - torch.testing.assert_close(c, ref_c, atol=1e-2, rtol=1e-2) + assert paddle.allclose(x=c, y=ref_c, atol=0.01, rtol=0.01).item(), "" @pytest.mark.parametrize("m", [128, 256, 512, 4096, 8192]) @@ -75,45 +61,32 @@ def test_fp8_blockscale_gemm( @pytest.mark.parametrize("k", [128, 256, 512, 4096, 8192]) @pytest.mark.parametrize("scale_major_mode", ["MN", "K"]) @pytest.mark.parametrize("backend", ["cutlass", "trtllm"]) -def test_fp8_groupwise_gemm( - m, - n, - k, - scale_major_mode, - backend, -): +def test_fp8_groupwise_gemm(m, n, k, scale_major_mode, backend): if backend == "trtllm": if scale_major_mode != "MN": pytest.skip("trtllm only supports MN scale_major_mode") if k < 256: pytest.skip("k < 256") - - torch.random.manual_seed(0) + paddle.seed(seed=0) tile_size = 128 - out_dtype = torch.bfloat16 - - a_val = torch.randn((m, k), dtype=torch.float, device="cuda") - b_val = torch.randn((n, k), dtype=torch.float, device="cuda") / math.sqrt(k) - + out_dtype = "bfloat16" + a_val = paddle.randn(shape=(m, k), dtype="float32") + b_val = paddle.randn(shape=(n, k), dtype="float32") / math.sqrt(k) if scale_major_mode == "K": - a_scale_shape = (m, k // tile_size) - b_scale_shape = (n // tile_size, k // tile_size) + a_scale_shape = m, k // tile_size + b_scale_shape = n // tile_size, k // tile_size else: - a_scale_shape = (k // tile_size, m) - b_scale_shape = (k // tile_size, n // tile_size) - a_tile_shape = (1, tile_size) - b_tile_shape = (tile_size, tile_size) - + a_scale_shape = k // tile_size, m + b_scale_shape = k // tile_size, n // tile_size + a_tile_shape = 1, tile_size + b_tile_shape = tile_size, tile_size a_fp8, a_scale = quantize_fp8(a_val, a_scale_shape, a_tile_shape, scale_major_mode) b_fp8, b_scale = quantize_fp8(b_val, b_scale_shape, b_tile_shape, scale_major_mode) - a_dequant = dequantize_fp8(a_fp8, a_scale, scale_major_mode) b_dequant = dequantize_fp8(b_fp8, b_scale, scale_major_mode) - ref_c = einsum(a_dequant, b_dequant, "m k, n k -> m n").to(out_dtype) - + ref_c = einops.einsum(a_dequant, b_dequant, "m k, n k -> m n").to(out_dtype) if backend == "trtllm": b_scale = b_scale.t().contiguous() - c = gemm_fp8_nt_groupwise( a_fp8, b_fp8, @@ -123,7 +96,7 @@ def test_fp8_groupwise_gemm( out_dtype=out_dtype, backend=backend, ) - torch.testing.assert_close(c, ref_c, atol=1e-2, rtol=1e-2) + assert paddle.allclose(x=c, y=ref_c, atol=0.01, rtol=0.01).item(), "" @pytest.mark.parametrize("m", [4, 128, 256, 512, 4096, 8192]) @@ -131,40 +104,25 @@ def test_fp8_groupwise_gemm( @pytest.mark.parametrize("k", [128, 256, 512, 4096, 8192]) @pytest.mark.parametrize("group_size", [1, 2, 4, 8]) @pytest.mark.parametrize("scale_major_mode", ["MN", "K"]) -@pytest.mark.parametrize("out_dtype", [torch.bfloat16]) -def test_fp8_groupwise_group_gemm( - m, - n, - k, - group_size, - scale_major_mode, - out_dtype, -): - torch.random.manual_seed(0) +@pytest.mark.parametrize("out_dtype", ["bfloat16"]) +def test_fp8_groupwise_group_gemm(m, n, k, group_size, scale_major_mode, out_dtype): + paddle.seed(seed=0) tile_size = 128 - - a_val = torch.randn((group_size * m, k), 
dtype=torch.float, device="cuda") - b_val = torch.randn( - (group_size, n, k), dtype=torch.float, device="cuda" - ) / math.sqrt(k) - + a_val = paddle.randn(shape=(group_size * m, k), dtype="float32") + b_val = paddle.randn(shape=(group_size, n, k), dtype="float32") / math.sqrt(k) if scale_major_mode == "K": - a_scale_shape = (group_size * m, k // tile_size) - b_scale_shape = (group_size, n // tile_size, k // tile_size) + a_scale_shape = group_size * m, k // tile_size + b_scale_shape = group_size, n // tile_size, k // tile_size else: - a_scale_shape = (k // tile_size, m * group_size) - b_scale_shape = (group_size, k // tile_size, n // tile_size) - a_tile_shape = (1, tile_size) - b_tile_shape = (1, tile_size, tile_size) - + a_scale_shape = k // tile_size, m * group_size + b_scale_shape = group_size, k // tile_size, n // tile_size + a_tile_shape = 1, tile_size + b_tile_shape = 1, tile_size, tile_size a_fp8, a_scale = quantize_fp8(a_val, a_scale_shape, a_tile_shape, scale_major_mode) b_fp8, b_scale = quantize_fp8(b_val, b_scale_shape, b_tile_shape, scale_major_mode) - a_dequant = dequantize_fp8(a_fp8, a_scale, scale_major_mode) b_dequant = dequantize_fp8(b_fp8, b_scale, scale_major_mode) - - m_indptr = torch.arange(0, group_size + 1, dtype=torch.int32, device="cuda") * m - + m_indptr = paddle.arange(start=0, end=group_size + 1, dtype="int32") * m out = group_gemm_fp8_nt_groupwise( a_fp8, b_fp8, @@ -175,105 +133,75 @@ def test_fp8_groupwise_group_gemm( out_dtype=out_dtype, ) ref_c = ( - einsum( - a_dequant.view((group_size, m, k)), - b_dequant, - "b m k, b n k -> b m n", + einops.einsum( + a_dequant.view((group_size, m, k)), b_dequant, "b m k, b n k -> b m n" ) .view((group_size * m, n)) .to(out_dtype) ) - torch.testing.assert_close(out, ref_c, atol=1e-2, rtol=1e-2) + assert paddle.allclose(x=out, y=ref_c, atol=0.01, rtol=0.01).item(), "" @pytest.mark.parametrize("m", [128, 256, 512, 1024]) @pytest.mark.parametrize("nk", [(128, 512), (512, 128), (4096, 7168), (7168, 2048)]) @pytest.mark.parametrize("group_size", [1, 4, 8, 64, 128, 256]) -@pytest.mark.parametrize("out_dtype", [torch.bfloat16]) -def test_fp8_groupwise_group_deepgemm( - m, - nk, - group_size, - out_dtype, -): - torch.random.manual_seed(0) +@pytest.mark.parametrize("out_dtype", ["bfloat16"]) +def test_fp8_groupwise_group_deepgemm(m, nk, group_size, out_dtype): + paddle.seed(seed=0) m_per_group = m // group_size if m_per_group < 128: return n, k = nk - a = torch.randn((m, k), device="cuda", dtype=torch.float32) - b = torch.randn((group_size, n, k), device="cuda", dtype=torch.float32) - m_indptr = torch.empty((m,), device="cuda", dtype=torch.int32) + a = paddle.randn(shape=(m, k), dtype="float32") + b = paddle.randn(shape=(group_size, n, k), dtype="float32") + m_indptr = paddle.empty(shape=(m,), dtype="int32") a_fp8, a_scale = quantize_fp8(a, (m, k // 128), (1, 128), "K") b_fp8, b_scale = quantize_fp8( b, (group_size, n // 128, k // 128), (1, 128, 128), "K" ) a_dequant = dequantize_fp8(a_fp8, a_scale, "K") b_dequant = dequantize_fp8(b_fp8, b_scale, "K") - - ref = torch.empty((m, n), device="cuda", dtype=out_dtype) - + ref = paddle.empty(shape=(m, n), dtype=out_dtype) for i in range(group_size): r = slice(i * m_per_group, (i + 1) * m_per_group) m_indptr[r] = i ref[r] = a_dequant[r] @ b_dequant[i].t() - out = group_deepgemm_fp8_nt_groupwise( - a_fp8, - b_fp8, - a_scale, - b_scale, - m_indptr, - out_dtype=out_dtype, + a_fp8, b_fp8, a_scale, b_scale, m_indptr, out_dtype=out_dtype ) - torch.testing.assert_close(out, ref, atol=3e-2, 
rtol=3e-2) + assert paddle.allclose(x=out, y=ref, atol=0.03, rtol=0.03).item(), "" @pytest.mark.parametrize("m", [128, 256, 512, 1024]) @pytest.mark.parametrize("nk", [(128, 512), (512, 128), (4096, 7168), (7168, 2048)]) @pytest.mark.parametrize("group_size", [1, 4, 8, 64, 128, 256]) -@pytest.mark.parametrize("out_dtype", [torch.bfloat16]) -def test_fp8_groupwise_batch_deepgemm_masked( - m, - nk, - group_size, - out_dtype, -): - torch.random.manual_seed(0) +@pytest.mark.parametrize("out_dtype", ["bfloat16"]) +def test_fp8_groupwise_batch_deepgemm_masked(m, nk, group_size, out_dtype): + paddle.seed(seed=0) n, k = nk - a = torch.randn((group_size, m, k), device="cuda", dtype=torch.float32) - b = torch.randn((group_size, n, k), device="cuda", dtype=torch.float32) - masked_m = torch.randint(0, m, (group_size,), device="cuda", dtype=torch.int32) - + a = paddle.randn(shape=(group_size, m, k), dtype="float32") + b = paddle.randn(shape=(group_size, n, k), dtype="float32") + masked_m = paddle.randint(low=0, high=m, shape=(group_size,), dtype="int32") a_fp8, a_scale = quantize_fp8(a, (group_size, m, k // 128), (1, 1, 128), "K") b_fp8, b_scale = quantize_fp8( b, (group_size, n // 128, k // 128), (1, 128, 128), "K" ) - a_dequant = dequantize_fp8(a_fp8, a_scale, "K") b_dequant = dequantize_fp8(b_fp8, b_scale, "K") - ref = torch.einsum("bmk,bnk->bmn", a_dequant, b_dequant).to(out_dtype) - - expected_m = min(int(masked_m.float().mean()) + 1, m) - + ref = paddle.einsum("bmk,bnk->bmn", a_dequant, b_dequant).to(out_dtype) + expected_m = min(int(masked_m.astype(dtype="float32").mean()) + 1, m) out = batch_deepgemm_fp8_nt_groupwise( - a_fp8, - b_fp8, - a_scale, - b_scale, - masked_m, - expected_m, - out_dtype=out_dtype, + a_fp8, b_fp8, a_scale, b_scale, masked_m, expected_m, out_dtype=out_dtype ) for i in range(group_size): - torch.testing.assert_close( - out[i][: masked_m[i]], ref[i][: masked_m[i]], atol=3e-2, rtol=3e-2 - ) + assert paddle.allclose( + x=out[i][: masked_m[i]], y=ref[i][: masked_m[i]], atol=0.03, rtol=0.03 + ).item(), "" if __name__ == "__main__": - test_fp8_blockscale_gemm(8192, 8192, 8192, "MN", torch.bfloat16) - test_fp8_groupwise_gemm(8192, 8192, 8192, "K", torch.bfloat16) - test_fp8_groupwise_group_gemm(4, 128, 256, 2, "MN", torch.bfloat16) - test_fp8_groupwise_group_deepgemm(256, (128, 512), 4, torch.bfloat16) - test_fp8_groupwise_batch_deepgemm_masked(256, (128, 512), 8, torch.bfloat16) + test_fp8_blockscale_gemm(8192, 8192, 8192, "MN", "bfloat16") + test_fp8_groupwise_gemm(8192, 8192, 8192, "K", "bfloat16") + test_fp8_groupwise_group_gemm(4, 128, 256, 2, "MN", "bfloat16") + test_fp8_groupwise_group_deepgemm(256, (128, 512), 4, "bfloat16") + test_fp8_groupwise_batch_deepgemm_masked(256, (128, 512), 8, "bfloat16") diff --git a/tests/test_groupwise_scaled_gemm_mxfp4.py b/tests/test_groupwise_scaled_gemm_mxfp4.py index e34d7feea7..c45049c122 100644 --- a/tests/test_groupwise_scaled_gemm_mxfp4.py +++ b/tests/test_groupwise_scaled_gemm_mxfp4.py @@ -1,3 +1,10 @@ +import sys + + +import einops +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,19 +20,14 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import math from enum import Enum, auto from itertools import product import pytest -import torch -from einops import einsum, rearrange -from flashinfer.fp4_quantization import ( - _pad_scale_factors, - get_fp4_quantization_module, -) +from flashinfer.fp4_quantization import (_pad_scale_factors, + get_fp4_quantization_module) from flashinfer.gemm import group_gemm_mxfp4_nt_groupwise @@ -36,9 +38,9 @@ class QuantMode(Enum): def swizzle_blockscale( - unswizzled_sf: torch.Tensor, b: int, m: int, n: int, sf_vec_size: int = 32 -) -> torch.Tensor: - r"""Swizzle block scale tensor for MXFP4/MXFP8 format. + unswizzled_sf: paddle.Tensor, b: int, m: int, n: int, sf_vec_size: int = 32 +) -> paddle.Tensor: + """Swizzle block scale tensor for MXFP4/MXFP8 format. This function swizzles the block scale tensor to optimize memory access patterns for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128. @@ -53,27 +55,24 @@ def swizzle_blockscale( Returns: torch.Tensor: Swizzled tensor with the same shape as input. """ - assert unswizzled_sf.dtype == torch.uint8, ( - f"Input dtype must be uint8, got {unswizzled_sf.dtype}" - ) + assert ( + unswizzled_sf.dtype == "uint8" + ), f"Input dtype must be uint8, got {unswizzled_sf.dtype}" assert unswizzled_sf.ndim == 3, f"Input must be 3D, got {unswizzled_sf.ndim}" - assert unswizzled_sf.shape[0] == b, ( - f"Batch dimension must equal b, got {unswizzled_sf.shape[0]} != {b}" - ) + assert ( + tuple(unswizzled_sf.shape)[0] == b + ), f"Batch dimension must equal b, got {tuple(unswizzled_sf.shape)[0]} != {b}" padded_input_sf_chunked = [ _pad_scale_factors(unswizzled_sf[i], m, n, sf_vec_size) for i in range(b) ] - padded_input_sf = torch.stack(padded_input_sf_chunked) - out = get_fp4_quantization_module().nvfp4_block_scale_interleave_sm100( - padded_input_sf - ) - out = out.view(padded_input_sf.shape) + padded_input_sf = paddle.stack(x=padded_input_sf_chunked) + out = get_fp4_quantization_module().block_scale_interleave_sm100(padded_input_sf) + out = out.view(tuple(padded_input_sf.shape)) return out -# Vanilla implementation only for unit test def quantize_e2m1(x): - r""" + """ Quantizes a tensor to FP4. Args: @@ -82,27 +81,26 @@ def quantize_e2m1(x): Returns: torch.Tensor: The quantized tensor. """ - assert x.shape[-1] % 2 == 0 - x = x.clamp(-6, 6) - x_sign_bit = torch.lt(x, 0) - x_abs = torch.abs(x) - log_x_quant = torch.floor(torch.log2(x_abs)).clamp(0, 2) - x_quant_e_fp32 = torch.exp2(log_x_quant) + assert tuple(x.shape)[-1] % 2 == 0 + x = x.clip(min=-6, max=6) + x_sign_bit = paddle.less_than(x=x, y=paddle.to_tensor(0)) + x_abs = paddle.abs(x=x) + log_x_quant = paddle.floor(x=paddle.log2(x=x_abs)).clip(min=0, max=2) + x_quant_e_fp32 = 2.0**log_x_quant m_scale = 2 - x_quant_m_scaled_fp32 = torch.round(x_abs * m_scale / x_quant_e_fp32) - mask = torch.ge(x_quant_m_scaled_fp32, m_scale) + x_quant_m_scaled_fp32 = paddle.round(x_abs * m_scale / x_quant_e_fp32) + mask = paddle.greater_equal(x=x_quant_m_scaled_fp32, y=paddle.to_tensor(m_scale)) x_quant_data_raw_e = log_x_quant + mask x_quant_data_raw_m = x_quant_m_scaled_fp32 - mask * m_scale x_quant_data_raw = ( x_sign_bit * 8 + x_quant_data_raw_e * m_scale + x_quant_data_raw_m - ).to(torch.uint8) + ).to("uint8") x_quant_data = x_quant_data_raw[..., ::2] + x_quant_data_raw[..., 1::2] * 16 return x_quant_data -# Vanilla implementation only for unit test def dequantize_e2m1(x): - r""" + """ Dequantizes a tensor from FP4. 
Args: @@ -113,49 +111,51 @@ def dequantize_e2m1(x): """ x_quant_data_raw_1 = x % 16 x_quant_data_raw_2 = x // 16 - x_quant_data_raw = torch.stack( - [x_quant_data_raw_1, x_quant_data_raw_2], dim=-1 - ).flatten(start_dim=-2) + x_quant_data_raw = paddle.stack( + x=[x_quant_data_raw_1, x_quant_data_raw_2], axis=-1 + ).flatten(start_axis=-2) x_sign_bit = x_quant_data_raw // 8 x = x_quant_data_raw % 8 m_scale = 2 x_quant_data_raw_e = x // m_scale x_quant_data_raw_m = x % m_scale - mask = torch.gt(x_quant_data_raw_e, 0).to(torch.float32) + mask = paddle.greater_than(x=x_quant_data_raw_e, y=paddle.to_tensor(0)).to( + "float32" + ) log_x_quant = x_quant_data_raw_e - mask x_quant_m_scaled_fp32 = x_quant_data_raw_m + mask * m_scale - x_dequant_abs = x_quant_m_scaled_fp32 / m_scale * torch.exp2(log_x_quant) + x_dequant_abs = x_quant_m_scaled_fp32 / m_scale * 2.0**log_x_quant x_dequant = (0.5 - x_sign_bit) * 2 * x_dequant_abs return x_dequant def gemm_mxfp8_mxfp4_nt_groupwise_ref( - A, B, As, Bs, tile_size, n, k, output_dtype=torch.bfloat16 + A, B, As, Bs, tile_size, n, k, output_dtype="bfloat16" ): - r""" - A: (m, k), torch.float8_e4m3fn or torch.float8_e5m2 + """ + A: (m, k), paddle.float8_e4m3fn or paddle.float8_e5m2 B: (n // 2, k), e2m1 packed as torch.uint8 A_scale: (m, k // tile_size), ue8m0 saved as torch.uint8 B_scale: (n, k // tile_size), ue8m0 saved as torch.uint8 """ ue8m0_bias = 127 - A_f32 = A.to(torch.float32) + A_f32 = A.to("float32") B_f32 = dequantize_e2m1(B) - A_f32_reshape = rearrange(A_f32, "m (k b) -> m k b", b=tile_size) - A_f32_scale_reshape = A_f32_reshape * rearrange( - torch.exp2(As.to(torch.float32) - ue8m0_bias), "m k -> m k 1" + A_f32_reshape = einops.rearrange(A_f32, "m (k b) -> m k b", b=tile_size) + A_f32_scale_reshape = A_f32_reshape * einops.rearrange( + 2.0 ** (As.to("float32") - ue8m0_bias), "m k -> m k 1" ) - A_f32_scale = rearrange(A_f32_scale_reshape, "m k b -> m (k b)")[:, :k] - B_f32_reshape = rearrange(B_f32, "n (k b) -> n k b", b=tile_size) - B_f32_scale_reshape = B_f32_reshape * rearrange( - torch.exp2(Bs.to(torch.float32) - ue8m0_bias), "n k -> n k 1" + A_f32_scale = einops.rearrange(A_f32_scale_reshape, "m k b -> m (k b)")[:, :k] + B_f32_reshape = einops.rearrange(B_f32, "n (k b) -> n k b", b=tile_size) + B_f32_scale_reshape = B_f32_reshape * einops.rearrange( + 2.0 ** (Bs.to("float32") - ue8m0_bias), "n k -> n k 1" ) - B_f32_scale = rearrange(B_f32_scale_reshape, "n k b -> n (k b)")[:n, :k] - return einsum(A_f32_scale, B_f32_scale, "m k, n k -> m n").to(output_dtype) + B_f32_scale = einops.rearrange(B_f32_scale_reshape, "n k b -> n (k b)")[:n, :k] + return einops.einsum(A_f32_scale, B_f32_scale, "m k, n k -> m n").to(output_dtype) def quantize_tensor(x, tile_size, n_padded, k_padded, quant_mode): - r""" + """ Quantizes a tensor to MXFP4 or MXFP8. Args: @@ -169,72 +169,62 @@ def quantize_tensor(x, tile_size, n_padded, k_padded, quant_mode): tuple: A tuple containing the quantized tensor and the calculated scales. """ - # 1. 
Initial Setup ue8m0_bias = 127 if quant_mode == QuantMode.MXFP8_E4M3: - fp8_info = torch.finfo(torch.float8_e4m3fn) - quant_amax = torch.tensor(fp8_info.max, dtype=torch.float32, device=x.device) + fp8_info = paddle.finfo(dtype=paddle.float8_e4m3fn) + quant_amax = paddle.to_tensor(data=fp8_info.max, dtype="float32", place=x.place) elif quant_mode == QuantMode.MXFP8_E5M2: - fp8_info = torch.finfo(torch.float8_e5m2) - quant_amax = torch.tensor(fp8_info.max, dtype=torch.float32, device=x.device) +>>>>>> fp8_info = paddle.finfo(dtype=paddle.float8_e5m2) + quant_amax = paddle.to_tensor(data=fp8_info.max, dtype="float32", place=x.place) elif quant_mode == QuantMode.MXFP4: - quant_amax = torch.tensor(6, dtype=torch.float32, device=x.device) + quant_amax = paddle.to_tensor(data=6, dtype="float32", place=x.place) else: raise ValueError(f"Unsupported quantization mode: {quant_mode}") - if n_padded is not None and x.shape[-2] != n_padded: - x = torch.cat( - [ + if n_padded is not None and tuple(x.shape)[-2] != n_padded: + x = paddle.concat( + x=[ x, - torch.zeros( - (*x.shape[:-2], n_padded - x.shape[-2], x.shape[-1]), + paddle.zeros( + shape=( + *tuple(x.shape)[:-2], + n_padded - tuple(x.shape)[-2], + tuple(x.shape)[-1], + ), dtype=x.dtype, - device=x.device, ), ], - dim=-2, + axis=-2, ) - if x.shape[-1] != k_padded: - x = torch.cat( - [ + if tuple(x.shape)[-1] != k_padded: + x = paddle.concat( + x=[ x, - torch.zeros( - (*x.shape[:-1], k_padded - x.shape[-1]), + paddle.zeros( + shape=(*tuple(x.shape)[:-1], k_padded - tuple(x.shape)[-1]), dtype=x.dtype, - device=x.device, ), ], - dim=-1, + axis=-1, ) - - # 2. Tiling and Scale Calculation - x_tiled = x.unflatten(-1, (-1, tile_size)) + x_tiled = x.unflatten(axis=-1, shape=(-1, tile_size)) x_tiled_abs = x_tiled.abs() log2_x_scale = ( - torch.floor(torch.log2(x_tiled_abs.amax(dim=-1))) - - torch.floor(torch.log2(quant_amax)) - ).clamp(-ue8m0_bias, ue8m0_bias) - - # 3. 
Final Quantization - # Divide the original tensor by the broadcasted scales - x_tiled_quant = ( - torch.exp2(torch.log2(x_tiled_abs) - log2_x_scale[..., None]).clamp( - 0, quant_amax - ) - * x_tiled.sign() - ) - x_quant = x_tiled_quant.flatten(-2, -1) - - # Convert the result to the target format + paddle.floor(x=paddle.log2(x=x_tiled_abs.amax(axis=-1))) + - paddle.floor(x=paddle.log2(x=quant_amax)) + ).clip(min=-ue8m0_bias, max=ue8m0_bias) + x_tiled_quant = (2.0 ** paddle.log2(x=x_tiled_abs) - log2_x_scale[..., None]).clip( + min=0, max=quant_amax + ) * x_tiled.sign() + x_quant = x_tiled_quant.flatten(start_axis=-2, stop_axis=-1) if quant_mode == QuantMode.MXFP8_E4M3: - x_quant_data = x_quant.to(torch.float8_e4m3fn) + x_quant_data = x_quant.to(paddle.float8_e4m3fn) elif quant_mode == QuantMode.MXFP8_E5M2: - x_quant_data = x_quant.to(torch.float8_e5m2) +>>>>>> x_quant_data = x_quant.to(paddle.float8_e5m2) elif quant_mode == QuantMode.MXFP4: x_quant_data = quantize_e2m1(x_quant) else: raise ValueError(f"Unsupported quantization mode: {quant_mode}") - x_scale_data = (log2_x_scale + ue8m0_bias).to(torch.uint8) - + x_scale_data = (log2_x_scale + ue8m0_bias).to("uint8") return x_quant_data, x_scale_data @@ -242,31 +232,20 @@ def quantize_tensor(x, tile_size, n_padded, k_padded, quant_mode): @pytest.mark.parametrize("n", [128, 256, 512, 2879, 4096, 8192]) @pytest.mark.parametrize("k", [128, 256, 512, 2880, 4096, 8192]) @pytest.mark.parametrize("group_size", [1, 2, 4, 8]) -@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) -@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) -def test_mxfp8_mxfp4_groupwise_group_gemm( - m, - n, - k, - group_size, - fp8_dtype, - out_dtype, -): - torch.random.manual_seed(0) +@pytest.mark.parametrize("fp8_dtype", [paddle.float8_e4m3fn, paddle.float8_e5m2]) +@pytest.mark.parametrize("out_dtype", ["bfloat16", "float16"]) +def test_mxfp8_mxfp4_groupwise_group_gemm(m, n, k, group_size, fp8_dtype, out_dtype): + paddle.seed(seed=0) tile_size = 32 alignment_n = 8 alignment_k = 128 - - a_val = torch.randn((group_size * m, k), dtype=torch.float32, device="cuda") - b_val = torch.randn( - (group_size, n, k), dtype=torch.float32, device="cuda" - ) / math.sqrt(k) + a_val = paddle.randn(shape=(group_size * m, k), dtype="float32") + b_val = paddle.randn(shape=(group_size, n, k), dtype="float32") / math.sqrt(k) n_padded = (n + alignment_n - 1) // alignment_n * alignment_n k_padded = (k + alignment_k - 1) // alignment_k * alignment_k - - if fp8_dtype == torch.float8_e4m3fn: + if fp8_dtype == paddle.float8_e4m3fn: a_quant_mode = QuantMode.MXFP8_E4M3 - elif fp8_dtype == torch.float8_e5m2: +>>>>>> elif fp8_dtype == paddle.float8_e5m2: a_quant_mode = QuantMode.MXFP8_E5M2 else: raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}") @@ -274,19 +253,18 @@ def test_mxfp8_mxfp4_groupwise_group_gemm( b_fp4, b_scale = quantize_tensor( b_val, tile_size, n_padded, k_padded, QuantMode.MXFP4 ) - a_scale_swizzled = swizzle_blockscale( - a_scale.unflatten(0, (group_size, m)), group_size, m, k_padded, tile_size - ).flatten(0, 1) + a_scale.unflatten(axis=0, shape=(group_size, m)), + group_size, + m, + k_padded, + tile_size, + ).flatten(start_axis=0, stop_axis=1) b_scale_swizzled = swizzle_blockscale( b_scale, group_size, n_padded, k_padded, tile_size ) - - group_arange = torch.arange(0, group_size + 1, dtype=torch.int32, device="cuda") + group_arange = paddle.arange(start=0, end=group_size + 1, dtype="int32") m_indptr = group_arange * m - - # Pad 
a_scale_swizzled according to the function compute_sm100_cutlass_group_gemm_args - # in group_gemm_mxfp4_groupwise_sm100.cuh alignment_m_sf = 128 m_indptr_padded = ( (m_indptr + group_arange * (alignment_m_sf - 1)) @@ -294,21 +272,21 @@ def test_mxfp8_mxfp4_groupwise_group_gemm( * alignment_m_sf ) m_sf = m_indptr_padded[1:] - m_indptr_padded[:-1] - a_scale_chunked = a_scale_swizzled.chunk(group_size, dim=0) + a_scale_chunked = a_scale_swizzled.chunk(chunks=group_size, axis=0) a_scale_chunked = [ - torch.cat( - [ + paddle.concat( + x=[ x, - torch.zeros( - m_sf[i] - x.shape[0], *x.shape[1:], dtype=x.dtype, device=x.device + paddle.zeros( + shape=[m_sf[i] - tuple(x.shape)[0], *tuple(x.shape)[1:]], + dtype=x.dtype, ), ] ) for i, x in enumerate(a_scale_chunked) ] - a_scale_swizzled = torch.cat(a_scale_chunked) - - out_ref = torch.empty((group_size * m, n), dtype=out_dtype, device="cuda") + a_scale_swizzled = paddle.concat(x=a_scale_chunked) + out_ref = paddle.empty(shape=(group_size * m, n), dtype=out_dtype) for i in range(group_size): out_ref[m * i : m * (i + 1)] = gemm_mxfp8_mxfp4_nt_groupwise_ref( a_fp8[m * i : m * (i + 1)], @@ -320,7 +298,6 @@ def test_mxfp8_mxfp4_groupwise_group_gemm( k, out_dtype, ) - mma_sm_list = [1, 2] tile_m_list = [128] tile_n_list = [64, 128, 192, 256] @@ -342,12 +319,12 @@ def test_mxfp8_mxfp4_groupwise_group_gemm( swap_ab=swap_ab, out_dtype=out_dtype, )[:, :n] - torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2) + assert paddle.allclose(x=out, y=out_ref, atol=0.01, rtol=0.01).item(), "" if __name__ == "__main__": - for fp8_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - for out_dtype in [torch.bfloat16, torch.float16]: + for fp8_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: + for out_dtype in ["bfloat16", "float16"]: test_mxfp8_mxfp4_groupwise_group_gemm( 4, 2879, 2880, 2, fp8_dtype, out_dtype ) diff --git a/tests/test_hopper.py b/tests/test_hopper.py index 916bdfad92..1e44634bb1 100644 --- a/tests/test_hopper.py +++ b/tests/test_hopper.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,9 +19,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch import flashinfer from flashinfer.utils import is_sm90a_supported @@ -30,30 +34,22 @@ def test_single_prefill( seq_len, num_qo_heads, num_kv_heads, causal, head_dim, logits_soft_cap ): - if not is_sm90a_supported(torch.device("cuda")): + if not is_sm90a_supported(device2str("cuda")): pytest.skip("SM90A is not supported") - if num_qo_heads % num_kv_heads != 0: pytest.skip("num_qo_heads must be divisible by num_kv_heads") - torch.random.manual_seed(123) - q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") - k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - + paddle.seed(seed=123) + q = paddle.randn(shape=[seq_len, num_qo_heads, head_dim], dtype="float16") + k = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") o_sm80, lse_sm80 = flashinfer.single_prefill_with_kv_cache_return_lse( - q, - k, - v, - causal=causal, - logits_soft_cap=logits_soft_cap, - backend="fa2", + q, k, v, causal=causal, logits_soft_cap=logits_soft_cap, backend="fa2" ) - o_sm90, lse_sm90 = flashinfer.single_prefill_with_kv_cache_return_lse( q, k, v, causal=causal, logits_soft_cap=logits_soft_cap, backend="fa3" ) - torch.testing.assert_close(lse_sm80, lse_sm90, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(o_sm80, o_sm90, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=lse_sm80, y=lse_sm90, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=o_sm80, y=o_sm90, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [1, 4, 8, 16]) @@ -61,42 +57,38 @@ def test_single_prefill( @pytest.mark.parametrize("num_qo_heads", [1, 4, 8]) @pytest.mark.parametrize("num_kv_heads", [1, 4, 8]) @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("head_dim", [128]) # [64, 128, 256]) +@pytest.mark.parametrize("head_dim", [128]) @pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) def test_batch_ragged_prefill( batch_size, seq_len, num_qo_heads, num_kv_heads, causal, head_dim, logits_soft_cap ): - if not is_sm90a_supported(torch.device("cuda")): + if not is_sm90a_supported(device2str("cuda")): pytest.skip("SM90A is not supported") - if num_qo_heads % num_kv_heads != 0: pytest.skip("num_qo_heads must be divisible by num_kv_heads") - torch.random.manual_seed(42) - q = torch.randn( - batch_size * seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" - ) - k = torch.randn( - batch_size * seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + paddle.seed(seed=42) + q = paddle.randn( + shape=[batch_size * seq_len, num_qo_heads, head_dim], dtype="float16" ) - v = torch.randn( - batch_size * seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda" + k = paddle.randn( + shape=[batch_size * seq_len, num_kv_heads, head_dim], dtype="float16" ) - - workspace_buffer = torch.empty( - 256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0" + v = paddle.randn( + shape=[batch_size * seq_len, num_kv_heads, head_dim], dtype="float16" ) - + workspace_buffer = paddle.empty(shape=256 * 1024 * 1024, dtype="uint8") wrapper_sm80 = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, backend="fa2" ) - wrapper_sm90 = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, backend="fa3" ) - - qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() - kv_indptr = 
torch.arange(0, batch_size * seq_len + 1, seq_len).int() - + qo_indptr = paddle.arange( + start=0, end=batch_size * seq_len + 1, step=seq_len + ).astype(dtype="int32") + kv_indptr = paddle.arange( + start=0, end=batch_size * seq_len + 1, step=seq_len + ).astype(dtype="int32") wrapper_sm80.plan( qo_indptr, kv_indptr, @@ -107,7 +99,6 @@ def test_batch_ragged_prefill( logits_soft_cap=logits_soft_cap, ) o_sm80, lse_sm80 = wrapper_sm80.run_return_lse(q, k, v) - wrapper_sm90.plan( qo_indptr, kv_indptr, @@ -118,56 +109,39 @@ def test_batch_ragged_prefill( logits_soft_cap=logits_soft_cap, ) o_sm90, lse_sm90 = wrapper_sm90.run_return_lse(q, k, v) - - torch.testing.assert_close(lse_sm80, lse_sm90, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(o_sm80, o_sm90, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=lse_sm80, y=lse_sm90, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=o_sm80, y=o_sm90, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [1, 4, 8, 16]) @pytest.mark.parametrize("seq_len", [11, 99, 1763, 9999, 32767]) @pytest.mark.parametrize("num_heads", [4, 32, 128]) @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) -def test_deepseek_prefill( - batch_size, - seq_len, - num_heads, - causal, - dtype, -): - if not is_sm90a_supported(torch.device("cuda")): +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +def test_deepseek_prefill(batch_size, seq_len, num_heads, causal, dtype): + if not is_sm90a_supported(device2str("cuda")): pytest.skip("SM90A is not supported") - if batch_size * seq_len > 131072: pytest.skip() head_dim_qk = 192 head_dim_vo = 128 - torch.random.manual_seed(42) - q = torch.randn( - batch_size * seq_len, num_heads, head_dim_qk, dtype=dtype, device="cuda" - ) - k = torch.randn( - batch_size * seq_len, num_heads, head_dim_qk, dtype=dtype, device="cuda" - ) - v = torch.randn( - batch_size * seq_len, num_heads, head_dim_vo, dtype=dtype, device="cuda" - ) - - workspace_buffer = torch.empty( - 256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0" - ) - + paddle.seed(seed=42) + q = paddle.randn(shape=[batch_size * seq_len, num_heads, head_dim_qk], dtype=dtype) + k = paddle.randn(shape=[batch_size * seq_len, num_heads, head_dim_qk], dtype=dtype) + v = paddle.randn(shape=[batch_size * seq_len, num_heads, head_dim_vo], dtype=dtype) + workspace_buffer = paddle.empty(shape=256 * 1024 * 1024, dtype="uint8") wrapper_sm80 = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, backend="fa2" ) - wrapper_sm90 = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, backend="fa3" ) - - qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() - kv_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() - + qo_indptr = paddle.arange( + start=0, end=batch_size * seq_len + 1, step=seq_len + ).astype(dtype="int32") + kv_indptr = paddle.arange( + start=0, end=batch_size * seq_len + 1, step=seq_len + ).astype(dtype="int32") wrapper_sm80.plan( qo_indptr, kv_indptr, @@ -180,7 +154,6 @@ def test_deepseek_prefill( kv_data_type=dtype, ) o_sm80, lse_sm80 = wrapper_sm80.run_return_lse(q, k, v) - wrapper_sm90.plan( qo_indptr, kv_indptr, @@ -193,14 +166,13 @@ def test_deepseek_prefill( kv_data_type=dtype, ) o_sm90, lse_sm90 = wrapper_sm90.run_return_lse(q, k, v) - - torch.testing.assert_close(lse_sm80, lse_sm90, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(o_sm80, o_sm90, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=lse_sm80, y=lse_sm90, 
rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=o_sm80, y=o_sm90, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [1, 4, 8, 16]) @pytest.mark.parametrize("seq_len", [11, 12, 99, 1763, 9999, 32767]) -@pytest.mark.parametrize("page_size", [1]) # [1, 16]) +@pytest.mark.parametrize("page_size", [1]) @pytest.mark.parametrize("num_qo_heads", [1, 4, 8]) @pytest.mark.parametrize("num_kv_heads", [1, 4, 8]) @pytest.mark.parametrize("causal", [False, True]) @@ -216,54 +188,43 @@ def test_batch_paged_prefill( head_dim, logits_soft_cap, ): - if not is_sm90a_supported(torch.device("cuda")): + if not is_sm90a_supported(device2str("cuda")): pytest.skip("SM90A is not supported") - if num_qo_heads % num_kv_heads != 0: pytest.skip("num_qo_heads must be divisible by num_kv_heads") - torch.random.manual_seed(42) - q = torch.randn( - batch_size * seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda" + paddle.seed(seed=42) + q = paddle.randn( + shape=[batch_size * seq_len, num_qo_heads, head_dim], dtype="float16" ) num_pages_per_request = (seq_len + page_size - 1) // page_size - k = torch.randn( - batch_size * num_pages_per_request, - page_size, - num_kv_heads, - head_dim, - dtype=torch.half, - device="cuda", + k = paddle.randn( + shape=[batch_size * num_pages_per_request, page_size, num_kv_heads, head_dim], + dtype="float16", ) - v = torch.randn( - batch_size * num_pages_per_request, - page_size, - num_kv_heads, - head_dim, - dtype=torch.half, - device="cuda", + v = paddle.randn( + shape=[batch_size * num_pages_per_request, page_size, num_kv_heads, head_dim], + dtype="float16", ) - - workspace_buffer = torch.empty( - 256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0" - ) - + workspace_buffer = paddle.empty(shape=256 * 1024 * 1024, dtype="uint8") wrapper_sm80 = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, backend="fa2" ) - wrapper_sm90 = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, backend="fa3" ) - last_page_len = seq_len - (num_pages_per_request - 1) * page_size - qo_indptr = torch.arange(0, batch_size * seq_len + 1, seq_len).int() - kv_indptr = torch.arange( - 0, batch_size * num_pages_per_request + 1, num_pages_per_request - ).int() - # NOTE(Zihao): pad 256 elements to avoid out-of-bound because we didn't check the boundary in the kernel - kv_indices = torch.arange(0, batch_size * num_pages_per_request + 256).int() - last_page_len = torch.full((batch_size,), last_page_len, dtype=torch.int32) - + qo_indptr = paddle.arange( + start=0, end=batch_size * seq_len + 1, step=seq_len + ).astype(dtype="int32") + kv_indptr = paddle.arange( + start=0, end=batch_size * num_pages_per_request + 1, step=num_pages_per_request + ).astype(dtype="int32") + kv_indices = paddle.arange( + start=0, end=batch_size * num_pages_per_request + 256 + ).astype(dtype="int32") + last_page_len = paddle.full( + shape=(batch_size,), fill_value=last_page_len, dtype="int32" + ) wrapper_sm80.plan( qo_indptr, kv_indptr, @@ -277,7 +238,6 @@ def test_batch_paged_prefill( logits_soft_cap=logits_soft_cap, ) o_sm80, lse_sm80 = wrapper_sm80.run_return_lse(q, (k, v)) - wrapper_sm90.plan( qo_indptr, kv_indptr, @@ -291,9 +251,8 @@ def test_batch_paged_prefill( logits_soft_cap=logits_soft_cap, ) o_sm90, lse_sm90 = wrapper_sm90.run_return_lse(q, (k, v)) - - torch.testing.assert_close(lse_sm80, lse_sm90, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(o_sm80, o_sm90, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=lse_sm80, y=lse_sm90, rtol=0.001, 
atol=0.001).item(), "" + assert paddle.allclose(x=o_sm80, y=o_sm90, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [1]) @@ -329,29 +288,38 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring_fa3( logits_soft_cap, return_lse, ): - q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() - q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len + q = ( + paddle.randn(shape=[batch_size * qo_len, num_qo_heads, head_dim]) + .to(0) + .astype(dtype="float16") + ) + q_indptr_cpu = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") * qo_len + ) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( - torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half() + paddle.randn(shape=[total_num_pages, 2, num_kv_heads, page_size, head_dim]) + .to(0) + .astype(dtype="float16") if kv_layout == "HND" - else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim) + else paddle.randn(shape=[total_num_pages, 2, page_size, num_kv_heads, head_dim]) .to(0) - .half() + .astype(dtype="float16") ) - kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq - kv_indices_cpu = torch.arange(0, total_num_pages).int() - kv_last_page_len_cpu = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + kv_indptr_cpu = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") + * num_pages_per_seq ) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + kv_indices_cpu = paddle.arange(start=0, end=total_num_pages).astype(dtype="int32") + kv_last_page_len_cpu = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" + ) + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8").to(0) q_indptr_gpu = q_indptr_cpu.to(0) kv_indptr_gpu = kv_indptr_cpu.to(0) kv_indices_gpu = kv_indices_cpu.to(0) kv_last_page_len_gpu = kv_last_page_len_cpu.to(0) - wrapper_fa2 = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout, backend="fa2" ) @@ -366,17 +334,20 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring_fa3( page_size, causal=causal, logits_soft_cap=logits_soft_cap, - prefix_len_ptr=torch.tensor(prefix_len_ptr).to(dtype=torch.uint32).to(0), - token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) - .to(dtype=torch.uint16) + prefix_len_ptr=paddle.to_tensor(data=prefix_len_ptr) +>>>>>> .to(dtype=torch.uint32) + .to(0), + token_pos_in_items_ptr=paddle.to_tensor(data=token_pos_in_items_ptr) +>>>>>> .to(dtype=torch.uint16) + .to(0), + token_pos_in_items_len=paddle.to_tensor(data=token_pos_in_items_len) +>>>>>> .to(dtype=torch.uint32) .to(0), - token_pos_in_items_len=torch.tensor(token_pos_in_items_len) - .to(dtype=torch.uint32) + max_item_len_ptr=paddle.to_tensor(data=max_item_len_ptr) +>>>>>> .to(dtype=torch.uint16) .to(0), - max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), ) o_fa2, lse_fa2 = wrapper_fa2.run_return_lse(q, kv_data) - wrapper_fa3 = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout, backend="fa3" ) @@ -391,20 +362,22 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring_fa3( page_size, causal=causal, logits_soft_cap=logits_soft_cap, - prefix_len_ptr=torch.tensor(prefix_len_ptr).to(dtype=torch.uint32).to(0), - token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) - .to(dtype=torch.uint16) + 
prefix_len_ptr=paddle.to_tensor(data=prefix_len_ptr) +>>>>>> .to(dtype=torch.uint32) + .to(0), + token_pos_in_items_ptr=paddle.to_tensor(data=token_pos_in_items_ptr) +>>>>>> .to(dtype=torch.uint16) + .to(0), + token_pos_in_items_len=paddle.to_tensor(data=token_pos_in_items_len) +>>>>>> .to(dtype=torch.uint32) .to(0), - token_pos_in_items_len=torch.tensor(token_pos_in_items_len) - .to(dtype=torch.uint32) + max_item_len_ptr=paddle.to_tensor(data=max_item_len_ptr) +>>>>>> .to(dtype=torch.uint16) .to(0), - max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), ) - o_fa3, lse_fa3 = wrapper_fa3.run_return_lse(q, kv_data) - - torch.testing.assert_close(lse_fa2, lse_fa3, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(o_fa2, o_fa3, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=lse_fa2, y=lse_fa3, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=o_fa2, y=o_fa3, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [2]) @@ -460,29 +433,38 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring_fa3_bsz2( logits_soft_cap, return_lse, ): - q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() - q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len + q = ( + paddle.randn(shape=[batch_size * qo_len, num_qo_heads, head_dim]) + .to(0) + .astype(dtype="float16") + ) + q_indptr_cpu = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") * qo_len + ) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( - torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half() + paddle.randn(shape=[total_num_pages, 2, num_kv_heads, page_size, head_dim]) + .to(0) + .astype(dtype="float16") if kv_layout == "HND" - else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim) + else paddle.randn(shape=[total_num_pages, 2, page_size, num_kv_heads, head_dim]) .to(0) - .half() + .astype(dtype="float16") ) - kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq - kv_indices_cpu = torch.arange(0, total_num_pages).int() - kv_last_page_len_cpu = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + kv_indptr_cpu = ( + paddle.arange(start=0, end=batch_size + 1).astype(dtype="int32") + * num_pages_per_seq ) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + kv_indices_cpu = paddle.arange(start=0, end=total_num_pages).astype(dtype="int32") + kv_last_page_len_cpu = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" + ) + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8").to(0) q_indptr_gpu = q_indptr_cpu.to(0) kv_indptr_gpu = kv_indptr_cpu.to(0) kv_indices_gpu = kv_indices_cpu.to(0) kv_last_page_len_gpu = kv_last_page_len_cpu.to(0) - wrapper_fa2 = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout, backend="fa2" ) @@ -497,17 +479,20 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring_fa3_bsz2( page_size, causal=causal, logits_soft_cap=logits_soft_cap, - prefix_len_ptr=torch.tensor(prefix_len_ptr).to(dtype=torch.uint32).to(0), - token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) - .to(dtype=torch.uint16) + prefix_len_ptr=paddle.to_tensor(data=prefix_len_ptr) +>>>>>> .to(dtype=torch.uint32) + .to(0), + token_pos_in_items_ptr=paddle.to_tensor(data=token_pos_in_items_ptr) +>>>>>> .to(dtype=torch.uint16) .to(0), - 
token_pos_in_items_len=torch.tensor(token_pos_in_items_len) - .to(dtype=torch.uint32) + token_pos_in_items_len=paddle.to_tensor(data=token_pos_in_items_len) +>>>>>> .to(dtype=torch.uint32) + .to(0), + max_item_len_ptr=paddle.to_tensor(data=max_item_len_ptr) +>>>>>> .to(dtype=torch.uint16) .to(0), - max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), ) o_fa2, lse_fa2 = wrapper_fa2.run_return_lse(q, kv_data) - wrapper_fa3 = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout, backend="fa3" ) @@ -522,25 +507,23 @@ def test_batch_prefill_with_paged_kv_cache_multi_item_scoring_fa3_bsz2( page_size, causal=causal, logits_soft_cap=logits_soft_cap, - prefix_len_ptr=torch.tensor(prefix_len_ptr).to(dtype=torch.uint32).to(0), - token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr) - .to(dtype=torch.uint16) + prefix_len_ptr=paddle.to_tensor(data=prefix_len_ptr) +>>>>>> .to(dtype=torch.uint32) .to(0), - token_pos_in_items_len=torch.tensor(token_pos_in_items_len) - .to(dtype=torch.uint32) + token_pos_in_items_ptr=paddle.to_tensor(data=token_pos_in_items_ptr) +>>>>>> .to(dtype=torch.uint16) + .to(0), + token_pos_in_items_len=paddle.to_tensor(data=token_pos_in_items_len) +>>>>>> .to(dtype=torch.uint32) + .to(0), + max_item_len_ptr=paddle.to_tensor(data=max_item_len_ptr) +>>>>>> .to(dtype=torch.uint16) .to(0), - max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0), ) - o_fa3, lse_fa3 = wrapper_fa3.run_return_lse(q, kv_data) - - torch.testing.assert_close(lse_fa2, lse_fa3, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(o_fa2, o_fa3, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=lse_fa2, y=lse_fa3, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=o_fa2, y=o_fa3, rtol=0.001, atol=0.001).item(), "" if __name__ == "__main__": - # test_batch_prefill(14, 64, 32, 32, False, 128) - # test_batch_prefill(1, 32767, 8, 8, True, 128) - # test_single_prefill(64, 1, 1, False, 256) - # test_batch_paged_prefill(2, 32768, 1, 1, 1, False, 128) test_batch_paged_prefill(16, 32767, 1, 8, 8, True, 128, 0) diff --git a/tests/test_hopper_fp8_attention.py b/tests/test_hopper_fp8_attention.py index 08ecf26dfd..db2d9e1923 100644 --- a/tests/test_hopper_fp8_attention.py +++ b/tests/test_hopper_fp8_attention.py @@ -1,91 +1,73 @@ +import sys + + from typing import Tuple import numpy as np +import paddle import pytest import scipy as sp -import torch +from flashinfer.paddle_utils import * import flashinfer def per_head_symmetric_quant( - x: torch.Tensor, quant_dtype: torch.dtype -) -> Tuple[torch.Tensor, torch.Tensor]: - # x: [seq_len, num_heads, head_dim] - assert quant_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + x: paddle.Tensor, quant_dtype: paddle.dtype +) -> Tuple[paddle.Tensor, paddle.Tensor]: + assert quant_dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2] - def get_dtype_minmax(dtype: torch.dtype) -> Tuple[float, float]: - if dtype == torch.float8_e4m3fn: + def get_dtype_minmax(dtype: paddle.dtype) -> Tuple[float, float]: + if dtype == paddle.float8_e4m3fn: return -448.0, 448.0 - elif dtype == torch.float8_e5m2: +>>>>>> elif dtype == paddle.float8_e5m2: return -57344, 57344 else: raise ValueError(f"Unsupported quantization dtype: {dtype}") o_min_val, o_max_val = get_dtype_minmax(quant_dtype) - x_max_val = x.abs().amax(dim=(0, 2)).to(dtype=torch.float32) - - s_out = torch.clamp(x_max_val / o_max_val, min=1e-6) + x_max_val = x.abs().amax(axis=(0, 2)).to(dtype="float32") + s_out = paddle.clip(x=x_max_val / o_max_val, 
min=1e-06) s_out_broadcast = s_out.view(1, -1, 1) - - q_x_out = torch.clamp( - x / s_out_broadcast, - min=o_min_val, - max=o_max_val, - ).to(dtype=quant_dtype) - - assert not torch.any(torch.isnan(q_x_out)) - assert not torch.any(torch.isnan(s_out)) - + q_x_out = paddle.clip(x=x / s_out_broadcast, min=o_min_val, max=o_max_val).to( + dtype=quant_dtype + ) + assert not paddle.any(x=paddle.isnan(x=q_x_out)) + assert not paddle.any(x=paddle.isnan(x=s_out)) return q_x_out, s_out -def bsr_attention_ref( - q, - k, - v, - indptr, - indices, - mask_data, -): - M = q.shape[0] - N = k.shape[0] +def bsr_attention_ref(q, k, v, indptr, indices, mask_data): + M = tuple(q.shape)[0] + N = tuple(k.shape)[0] bsr = sp.sparse.bsr_matrix( (mask_data.cpu().numpy(), indices.cpu().numpy(), indptr.cpu().numpy()), shape=(M, N), ) - dense_mask = torch.tensor(bsr.toarray(), dtype=bool, device=q.device) + dense_mask = paddle.to_tensor(data=bsr.toarray(), dtype=bool, place=q.place) o = flashinfer.prefill.single_prefill_with_kv_cache( q, k, v, custom_mask=dense_mask, backend="fa2" ) return o -# Test single_prefill correctness: MSE should be below threshold @pytest.mark.parametrize("seq_len", [117, 509, 1011, 2372, 7777, 12315]) @pytest.mark.parametrize("num_heads", [24, 32]) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("head_dim", [64, 128, 256]) -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("dtype", [paddle.float8_e4m3fn, paddle.float8_e5m2]) def test_single_prefill(seq_len, num_heads, causal, head_dim, dtype): - # Prepare inputs - o_dtype = torch.half + o_dtype = "float16" num_qo_heads = num_kv_heads = num_heads - q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda") - k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda") - - # Reference output + q = paddle.randn(shape=[seq_len, num_qo_heads, head_dim], dtype="float16") + k = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") o_ref = flashinfer.single_prefill_with_kv_cache( q, k, v, causal=causal, backend="fa3" ) - - # Quantize q_fp8, s_q = per_head_symmetric_quant(q, quant_dtype=dtype) k_fp8, s_k = per_head_symmetric_quant(k, quant_dtype=dtype) v_fp8, s_v = per_head_symmetric_quant(v, quant_dtype=dtype) - - # FP8 output o_fp8 = flashinfer.single_prefill_with_kv_cache( q_fp8, k_fp8, @@ -97,14 +79,12 @@ def test_single_prefill(seq_len, num_heads, causal, head_dim, dtype): backend="fa3", o_dtype=o_dtype, ) - - # Compute MSE and assert - # NOTE: This is not a strict correctness guarantee - mse = torch.mean((o_ref.float() - o_fp8.float()) ** 2) + mse = paddle.mean( + x=(o_ref.astype(dtype="float32") - o_fp8.astype(dtype="float32")) ** 2 + ) assert mse < 1.0, f"MSE too high: {mse.item()}" -# Test block sparse attention correctness: MSE should be below threshold @pytest.mark.parametrize("R", [1, 4, 16]) @pytest.mark.parametrize("C", [1, 4, 16]) @pytest.mark.parametrize("M", [256, 512, 1024, 4096]) @@ -112,41 +92,31 @@ def test_single_prefill(seq_len, num_heads, causal, head_dim, dtype): @pytest.mark.parametrize("num_heads", [1, 8, 24, 32]) @pytest.mark.parametrize("head_dim", [64, 128, 256]) @pytest.mark.parametrize("mask_inside_block", [False]) -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("dtype", 
[paddle.float8_e4m3fn, paddle.float8_e5m2]) def test_block_sparse_attention( R, C, M, N, num_heads, head_dim, mask_inside_block, dtype ): - # print args print( - f"Testing block sparse attention with R={R}, C={C}, M={M}, N={N}, num_heads={num_heads}, " - f"head_dim={head_dim}, mask_inside_block={mask_inside_block}, dtype={dtype}" + f"Testing block sparse attention with R={R}, C={C}, M={M}, N={N}, num_heads={num_heads}, head_dim={head_dim}, mask_inside_block={mask_inside_block}, dtype={dtype}" ) - # setup random seed for reproducibility - torch.manual_seed(0) + paddle.seed(seed=0) np.random.seed(0) - # Build sparse mask MB = M // R NB = N // C rng = np.random.default_rng(seed=0) S = sp.sparse.random(MB, NB, density=0.25, random_state=rng).tocsr() - indptr = torch.from_numpy(S.indptr).cuda() - indices = torch.from_numpy(S.indices).cuda() + indptr = paddle.to_tensor(data=S.indptr).cuda() + indices = paddle.to_tensor(data=S.indices).cuda() nnz = S.nnz if mask_inside_block: - data_mask = (torch.rand((nnz, R, C)) > 0.5).to(torch.bool).cuda() + data_mask = (paddle.rand(shape=(nnz, R, C)) > 0.5).to("bool").cuda() else: - data_mask = torch.ones((nnz, R, C), dtype=torch.bool, device="cuda") - - # Random inputs - q = torch.randn((M, num_heads, head_dim), dtype=torch.float16, device="cuda") - k = torch.randn((N, num_heads, head_dim), dtype=torch.float16, device="cuda") - v = torch.randn((N, num_heads, head_dim), dtype=torch.float16, device="cuda") - - # Reference output via dense mask + data_mask = paddle.ones(shape=(nnz, R, C), dtype="bool") + q = paddle.randn(shape=(M, num_heads, head_dim), dtype="float16") + k = paddle.randn(shape=(N, num_heads, head_dim), dtype="float16") + v = paddle.randn(shape=(N, num_heads, head_dim), dtype="float16") o_ref = bsr_attention_ref(q, k, v, indptr, indices, data_mask) - - # Plan and run BlockSparseAttention - workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda") + workspace_buffer = paddle.zeros(shape=128 * 1024 * 1024, dtype="uint8") sparse_wrapper = flashinfer.sparse.BlockSparseAttentionWrapper( workspace_buffer, backend="fa3" ) @@ -163,16 +133,15 @@ def test_block_sparse_attention( mask=data_mask if mask_inside_block else None, q_data_type=dtype, kv_data_type=dtype, - o_data_type=torch.float16, + o_data_type="float16", ) q_fp8, s_q = per_head_symmetric_quant(q, quant_dtype=dtype) k_fp8, s_k = per_head_symmetric_quant(k, quant_dtype=dtype) v_fp8, s_v = per_head_symmetric_quant(v, quant_dtype=dtype) o = sparse_wrapper.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v) - - # Compute MSE and assert - # NOTE: This is not a strict correctness guarantee - mse = torch.mean((o_ref.float() - o.float()) ** 2) + mse = paddle.mean( + x=(o_ref.astype(dtype="float32") - o.astype(dtype="float32")) ** 2 + ) assert mse < 1.0, f"Block sparse MSE too high: {mse.item()}" @@ -184,7 +153,7 @@ def test_block_sparse_attention( for num_heads in [8]: for head_dim in [256]: for mask_inside_block in [False]: - for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + for dtype in [paddle.float8_e4m3fn, paddle.float8_e5m2]: test_block_sparse_attention( R, C, diff --git a/tests/test_jit_example.py b/tests/test_jit_example.py index cf54a38a5e..1cbb74350d 100644 --- a/tests/test_jit_example.py +++ b/tests/test_jit_example.py @@ -1,22 +1,24 @@ +import sys + + import functools import math +import paddle import pytest -import torch +from flashinfer.paddle_utils import * import flashinfer from flashinfer.decode import single_decode_with_kv_cache_with_jit_module -from 
flashinfer.jit.attention import ( - gen_customize_single_decode_module, - gen_customize_single_prefill_module, -) +from flashinfer.jit.attention import (gen_customize_single_decode_module, + gen_customize_single_prefill_module) from flashinfer.prefill import single_prefill_with_kv_cache_with_jit_module from flashinfer.utils import MaskMode, is_sm90a_supported def test_single_decode_mask(): - torch.manual_seed(42) - variant_decl = r""" + paddle.seed(seed=42) + variant_decl = """ struct SingleDecodeWithCustomMask : AttentionVariantBase { static constexpr bool use_softmax = true; @@ -46,39 +48,43 @@ def test_single_decode_mask(): }; """ jit_module = gen_customize_single_decode_module( - "single_decode_custom_mask", # uri - torch.float16, # dtype_q - torch.float16, # dtype_kv - torch.float16, # dtype_o - 128, # head_dim_qk - 128, # head_dim_vo - ["custom_mask"], # additional_tensor_names - ["uint8_t"], # additional_tensor_dtypes - ["sm_scale"], # # additional_scalar_names - ["double"], # additional_scalar_dtypes + "single_decode_custom_mask", + "float16", + "float16", + "float16", + 128, + 128, + ["custom_mask"], + ["uint8_t"], + ["sm_scale"], + ["double"], "SingleDecodeWithCustomMask", variant_decl, ).build_and_load() - f = functools.partial(single_decode_with_kv_cache_with_jit_module, jit_module) - - q = torch.randn(32, 128, dtype=torch.float16, device="cuda") - k = torch.randn(254, 32, 128, dtype=torch.float16, device="cuda") - v = torch.randn(254, 32, 128, dtype=torch.float16, device="cuda") + q = paddle.randn(shape=[32, 128], dtype="float16") + k = paddle.randn(shape=[254, 32, 128], dtype="float16") + v = paddle.randn(shape=[254, 32, 128], dtype="float16") sm_scale = 1.0 / math.sqrt(128) - - custom_mask = torch.randint(0, 2, (254,), dtype=torch.uint8, device="cuda") + custom_mask = paddle.randint(low=0, high=2, shape=(254,), dtype="uint8") packed_custom_mask = flashinfer.packbits(custom_mask, bitorder="little") - o = f(q, k, v, packed_custom_mask, sm_scale) - - p = torch.einsum("hd,nhd->hn", q.float(), k.float()) * sm_scale - p[:, torch.nonzero(torch.logical_not(custom_mask)).squeeze()] = -float("inf") - o_ref = torch.einsum("hn,nhd->hd", torch.softmax(p, dim=-1), v.float()).half() - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + p = ( + paddle.einsum( + "hd,nhd->hn", q.astype(dtype="float32"), k.astype(dtype="float32") + ) + * sm_scale + ) + p[:, paddle.nonzero(x=paddle.logical_not(x=custom_mask)).squeeze()] = -float("inf") + o_ref = paddle.einsum( + "hn,nhd->hd", + paddle.nn.functional.softmax(x=p, axis=-1), + v.astype(dtype="float32"), + ).astype(dtype="float16") + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" -flash_sigmoid_sm80_decl = r""" +flash_sigmoid_sm80_decl = """ struct FlashSigmoid : AttentionVariantBase { static constexpr bool use_softmax = false; @@ -106,8 +112,7 @@ def test_single_decode_mask(): }) }; """ - -flash_sigmoid_sm90_decl = r""" +flash_sigmoid_sm90_decl = """ struct FlashSigmoid : AttentionVariantBase { float logits_scale_log2, sigmoid_bias_log2e; // Init @@ -131,43 +136,46 @@ def test_single_decode_mask(): def test_flash_sigmoid(): - torch.manual_seed(42) + paddle.seed(seed=42) variant_decl = flash_sigmoid_sm80_decl jit_module = gen_customize_single_prefill_module( - "fa2", # backend - "single_prefill_flash_sigmoid", # uri - torch.float16, # dtype_q - torch.float16, # dtype_kv - torch.float16, # dtype_o - 128, # head_dim_qk - 128, # head_dim_vo - [], # additional_tensor_names - [], # additional_tensor_dtypes - 
["logits_scale", "sigmoid_bias"], # additional_scalar_names - ["double", "double"], # additional_scalar_dtypes + "fa2", + "single_prefill_flash_sigmoid", + "float16", + "float16", + "float16", + 128, + 128, + [], + [], + ["logits_scale", "sigmoid_bias"], + ["double", "double"], "FlashSigmoid", variant_decl, ).build_and_load() - f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module) - - q = torch.randn(128, 8, 128, dtype=torch.float16, device="cuda") - k = torch.randn(1027, 8, 128, dtype=torch.float16, device="cuda") - v = torch.randn(1027, 8, 128, dtype=torch.float16, device="cuda") + q = paddle.randn(shape=[128, 8, 128], dtype="float16") + k = paddle.randn(shape=[1027, 8, 128], dtype="float16") + v = paddle.randn(shape=[1027, 8, 128], dtype="float16") logits_scale = 1.0 / math.sqrt(128) sigmoid_bias = 0.25 o = f(q, k, v, logits_scale, sigmoid_bias, mask_mode=MaskMode.NON_CAUSAL.value) - - p = torch.sigmoid( - torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * logits_scale + sigmoid_bias + p = paddle.nn.functional.sigmoid( + x=paddle.einsum( + "mhd,nhd->hmn", q.astype(dtype="float32"), k.astype(dtype="float32") + ) + * logits_scale + + sigmoid_bias ) - o_ref = torch.einsum("hmn,nhd->mhd", p, v.float()).half() - torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2) + o_ref = paddle.einsum("hmn,nhd->mhd", p, v.astype(dtype="float32")).astype( + dtype="float16" + ) + assert paddle.allclose(x=o, y=o_ref, rtol=0.02, atol=0.02).item(), "" def test_dump_logits(): - torch.manual_seed(42) - variant_decl = r""" + paddle.seed(seed=42) + variant_decl = """ struct DumpLogits : AttentionVariantBase { static constexpr bool use_softmax = true; @@ -193,80 +201,84 @@ def test_dump_logits(): }; """ jit_module = gen_customize_single_prefill_module( - "fa2", # backend - "single_prefill_dump_logits", # uri - torch.float16, # dtype_q - torch.float16, # dtype_kv - torch.float16, # dtype_o - 128, # head_dim_qk - 128, # head_dim_vo - ["output_logits"], # additional_tensor_names - ["float"], # additional_tensor_dtypes - ["sm_scale"], # additional_scalar_names - ["double"], # additional_scalar_dtypes + "fa2", + "single_prefill_dump_logits", + "float16", + "float16", + "float16", + 128, + 128, + ["output_logits"], + ["float"], + ["sm_scale"], + ["double"], "DumpLogits", variant_decl, ).build_and_load() - f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module) - - q = torch.randn(128, 32, 128, dtype=torch.float16, device="cuda") - k = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda") - v = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda") - logits = torch.empty(32, 128, 1023, dtype=torch.float32, device="cuda") + q = paddle.randn(shape=[128, 32, 128], dtype="float16") + k = paddle.randn(shape=[1023, 32, 128], dtype="float16") + v = paddle.randn(shape=[1023, 32, 128], dtype="float16") + logits = paddle.empty(shape=[32, 128, 1023], dtype="float32") sm_scale = 1.0 / math.sqrt(128) o = f(q, k, v, logits, sm_scale, mask_mode=MaskMode.NON_CAUSAL.value) - - p = torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * sm_scale - o_ref = torch.einsum("hmn,nhd->mhd", torch.softmax(p, dim=-1), v.float()).half() - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(logits, p, rtol=2e-2, atol=2e-2) + p = ( + paddle.einsum( + "mhd,nhd->hmn", q.astype(dtype="float32"), k.astype(dtype="float32") + ) + * sm_scale + ) + o_ref = paddle.einsum( + "hmn,nhd->mhd", + paddle.nn.functional.softmax(x=p, axis=-1), + 
v.astype(dtype="float32"), + ).astype(dtype="float16") + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=logits, y=p, rtol=0.02, atol=0.02).item(), "" @pytest.mark.parametrize("use_tensor_cores", [False, True]) def test_batch_decode_flash_sigmoid(use_tensor_cores): - torch.manual_seed(42) + paddle.seed(seed=42) variant_decl = flash_sigmoid_sm80_decl jit_args = ( - f"batch_decode_flash_sigmoid_sm80_{use_tensor_cores}", # uri - torch.float16, # dtype_q - torch.float16, # dtype_kv - torch.float16, # dtype_o - torch.int32, # idtype - 128, # hidden_dim_qk - 128, # hidden_dim_vo - [], # additional_tensor_names - [], # additional_tensor_dtypes - ["logits_scale", "sigmoid_bias"], # additional_scalar_names - ["double", "double"], # additional_scalar_dtypes + f"batch_decode_flash_sigmoid_sm80_{use_tensor_cores}", + "float16", + "float16", + "float16", + "int32", + 128, + 128, + [], + [], + ["logits_scale", "sigmoid_bias"], + ["double", "double"], "FlashSigmoid", variant_decl, ) - - float_workspace_buffer = torch.empty( - 128 * 1024 * 1024, dtype=torch.uint8, device="cuda" - ) + float_workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="uint8") wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", use_tensor_cores=use_tensor_cores, jit_args=jit_args, ) - batch_size = 128 seq_len_per_request = 1024 - kv_indptr_host = torch.arange( - 0, batch_size * seq_len_per_request + 1, seq_len_per_request, dtype=torch.int32 + kv_indptr_host = paddle.arange( + start=0, + end=batch_size * seq_len_per_request + 1, + step=seq_len_per_request, + dtype="int32", ) page_size = 1 - kv_indices_host = torch.arange( - 0, batch_size * seq_len_per_request, dtype=torch.int32 + kv_indices_host = paddle.arange( + start=0, end=batch_size * seq_len_per_request, dtype="int32" ) - last_page_len_host = torch.full((batch_size,), 1, dtype=torch.int32) + last_page_len_host = paddle.full(shape=(batch_size,), fill_value=1, dtype="int32") num_qo_heads = 32 num_kv_heads = 32 head_dim = 128 - wrapper.plan( kv_indptr_host, kv_indices_host, @@ -275,101 +287,85 @@ def test_batch_decode_flash_sigmoid(use_tensor_cores): num_kv_heads, head_dim, page_size, - q_data_type=torch.float16, - kv_data_type=torch.float16, + q_data_type="float16", + kv_data_type="float16", ) - - q = torch.randn( - batch_size, - num_qo_heads, - head_dim, - dtype=torch.float16, - device="cuda", - ) - k_cache = torch.randn( - batch_size * seq_len_per_request, - num_kv_heads, - head_dim, - dtype=torch.float16, - device="cuda", + q = paddle.randn(shape=[batch_size, num_qo_heads, head_dim], dtype="float16") + k_cache = paddle.randn( + shape=[batch_size * seq_len_per_request, num_kv_heads, head_dim], + dtype="float16", ) - v_cache = torch.randn( - batch_size * seq_len_per_request, - num_kv_heads, - head_dim, - dtype=torch.float16, - device="cuda", + v_cache = paddle.randn( + shape=[batch_size * seq_len_per_request, num_kv_heads, head_dim], + dtype="float16", ) - logits_scale = 1.0 / math.sqrt(128) sigmoid_bias = 0.25 - o = wrapper.run(q, (k_cache, v_cache), logits_scale, sigmoid_bias) - p = torch.sigmoid( - torch.einsum( + p = paddle.nn.functional.sigmoid( + x=paddle.einsum( "bhd,bnhd->bhn", - q.view(batch_size, num_qo_heads, head_dim).float(), + q.view(batch_size, num_qo_heads, head_dim).astype(dtype="float32"), k_cache.view( batch_size, seq_len_per_request, num_kv_heads, head_dim - ).float(), + ).astype(dtype="float32"), ) * logits_scale + sigmoid_bias ) o_ref = ( - 
torch.einsum( + paddle.einsum( "bhn,bnhd->bhd", p, v_cache.view( batch_size, seq_len_per_request, num_kv_heads, head_dim - ).float(), + ).astype(dtype="float32"), ) - .half() + .astype(dtype="float16") .reshape(batch_size, num_qo_heads, head_dim) ) - - torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2) + assert paddle.allclose(x=o, y=o_ref, rtol=0.02, atol=0.02).item(), "" def test_batch_prefill_flash_sigmoid(): - torch.manual_seed(42) + paddle.seed(seed=42) variant_decl = flash_sigmoid_sm80_decl jit_args = ( - "batch_prefill_flash_sigmoid_sm80", # uri - torch.float16, # dtype_q - torch.float16, # dtype_kv - torch.float16, # dtype_o - torch.int32, # idtype - 128, # hidden_dim_qk - 128, # hidden_dim_vo - [], # additional_tensor_names - [], # additional_tensor_dtypes - ["logits_scale", "sigmoid_bias"], # additional_scalar_names - ["double", "double"], # additional_scalar_dtypes + "batch_prefill_flash_sigmoid_sm80", + "float16", + "float16", + "float16", + "int32", + 128, + 128, + [], + [], + ["logits_scale", "sigmoid_bias"], + ["double", "double"], "FlashSigmoid", variant_decl, ) - - float_workspace_buffer = torch.empty( - 128 * 1024 * 1024, dtype=torch.uint8, device="cuda" - ) + float_workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="uint8") wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", backend="fa2", jit_args=jit_args ) - batch_size = 128 seq_len_per_request = 1024 - qo_indptr_host = torch.arange( - 0, batch_size * seq_len_per_request + 1, seq_len_per_request, dtype=torch.int32 - ) - kv_indptr_host = torch.arange( - 0, batch_size * seq_len_per_request + 1, seq_len_per_request, dtype=torch.int32 + qo_indptr_host = paddle.arange( + start=0, + end=batch_size * seq_len_per_request + 1, + step=seq_len_per_request, + dtype="int32", + ) + kv_indptr_host = paddle.arange( + start=0, + end=batch_size * seq_len_per_request + 1, + step=seq_len_per_request, + dtype="int32", ) - num_qo_heads = 32 num_kv_heads = 32 head_dim = 128 - wrapper.plan( qo_indptr_host, kv_indptr_host, @@ -377,45 +373,33 @@ def test_batch_prefill_flash_sigmoid(): num_kv_heads, head_dim, causal=False, - q_data_type=torch.float16, - kv_data_type=torch.float16, + q_data_type="float16", + kv_data_type="float16", ) - - q = torch.randn( - batch_size * seq_len_per_request, - num_qo_heads, - head_dim, - dtype=torch.float16, - device="cuda", + q = paddle.randn( + shape=[batch_size * seq_len_per_request, num_qo_heads, head_dim], + dtype="float16", ) - k = torch.randn( - batch_size * seq_len_per_request, - num_kv_heads, - head_dim, - dtype=torch.float16, - device="cuda", + k = paddle.randn( + shape=[batch_size * seq_len_per_request, num_kv_heads, head_dim], + dtype="float16", ) - v = torch.randn( - batch_size * seq_len_per_request, - num_kv_heads, - head_dim, - dtype=torch.float16, - device="cuda", + v = paddle.randn( + shape=[batch_size * seq_len_per_request, num_kv_heads, head_dim], + dtype="float16", ) logits_scale = 1.0 / math.sqrt(128) sigmoid_bias = 0.25 - o = wrapper.run(q, k, v, logits_scale, sigmoid_bias) - wrapper_paged = flashinfer.BatchPrefillWithPagedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", backend="fa2", jit_args=jit_args ) - kv_indices_host = torch.arange( - 0, - batch_size * seq_len_per_request, - dtype=torch.int32, + kv_indices_host = paddle.arange( + start=0, end=batch_size * seq_len_per_request, dtype="int32" + ) + paged_kv_last_page_len_host = paddle.full( + shape=(batch_size,), fill_value=1, dtype="int32" ) - 
paged_kv_last_page_len_host = torch.full((batch_size,), 1, dtype=torch.int32) wrapper_paged.plan( qo_indptr_host, kv_indptr_host, @@ -427,71 +411,75 @@ def test_batch_prefill_flash_sigmoid(): 1, ) o_paged = wrapper_paged.run(q, (k, v), logits_scale, sigmoid_bias) - - p = torch.sigmoid( - torch.einsum( + p = paddle.nn.functional.sigmoid( + x=paddle.einsum( "bmhd,bnhd->bhmn", - q.view(batch_size, seq_len_per_request, num_qo_heads, head_dim).float(), - k.view(batch_size, seq_len_per_request, num_kv_heads, head_dim).float(), + q.view(batch_size, seq_len_per_request, num_qo_heads, head_dim).astype( + dtype="float32" + ), + k.view(batch_size, seq_len_per_request, num_kv_heads, head_dim).astype( + dtype="float32" + ), ) * logits_scale + sigmoid_bias ) o_ref = ( - torch.einsum( + paddle.einsum( "bhmn,bnhd->bmhd", p, - v.view(batch_size, seq_len_per_request, num_kv_heads, head_dim).float(), + v.view(batch_size, seq_len_per_request, num_kv_heads, head_dim).astype( + dtype="float32" + ), ) - .half() + .astype(dtype="float16") .reshape(batch_size * seq_len_per_request, num_qo_heads, head_dim) ) - torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2) - torch.testing.assert_close(o_paged, o_ref, rtol=2e-2, atol=2e-2) + assert paddle.allclose(x=o, y=o_ref, rtol=0.02, atol=0.02).item(), "" + assert paddle.allclose(x=o_paged, y=o_ref, rtol=0.02, atol=0.02).item(), "" def test_batch_prefill_sm90_flash_sigmoid(): - if not is_sm90a_supported(torch.device("cuda")): + if not is_sm90a_supported(device2str("cuda")): pytest.skip("SM90A is not supported") - - torch.manual_seed(42) + paddle.seed(seed=42) variant_decl = flash_sigmoid_sm90_decl jit_args = ( - "batch_prefill_flash_sigmoid", # uri - torch.float16, # dtype_q - torch.float16, # dtype_kv - torch.float16, # dtype_o - torch.int32, # idtype - 128, # hidden_dim_qk - 128, # hidden_dim_vo - [], # additional_tensor_names - [], # additional_tensor_dtypes - ["logits_scale", "sigmoid_bias"], # additional_scalar_names - ["double", "double"], # additional_scalar_dtypes + "batch_prefill_flash_sigmoid", + "float16", + "float16", + "float16", + "int32", + 128, + 128, + [], + [], + ["logits_scale", "sigmoid_bias"], + ["double", "double"], "FlashSigmoid", variant_decl, ) - - float_workspace_buffer = torch.empty( - 128 * 1024 * 1024, dtype=torch.uint8, device="cuda" - ) + float_workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="uint8") wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", backend="fa3", jit_args=jit_args ) - batch_size = 128 seq_len_per_request = 1024 - qo_indptr_host = torch.arange( - 0, batch_size * seq_len_per_request + 1, seq_len_per_request, dtype=torch.int32 - ) - kv_indptr_host = torch.arange( - 0, batch_size * seq_len_per_request + 1, seq_len_per_request, dtype=torch.int32 + qo_indptr_host = paddle.arange( + start=0, + end=batch_size * seq_len_per_request + 1, + step=seq_len_per_request, + dtype="int32", + ) + kv_indptr_host = paddle.arange( + start=0, + end=batch_size * seq_len_per_request + 1, + step=seq_len_per_request, + dtype="int32", ) - num_qo_heads = 32 num_kv_heads = 32 head_dim = 128 - wrapper.plan( qo_indptr_host, kv_indptr_host, @@ -499,44 +487,33 @@ def test_batch_prefill_sm90_flash_sigmoid(): num_kv_heads, head_dim, causal=False, - q_data_type=torch.float16, - kv_data_type=torch.float16, + q_data_type="float16", + kv_data_type="float16", ) - - q = torch.randn( - batch_size * seq_len_per_request, - num_qo_heads, - head_dim, - dtype=torch.float16, - device="cuda", + q = 
paddle.randn( + shape=[batch_size * seq_len_per_request, num_qo_heads, head_dim], + dtype="float16", ) - k = torch.randn( - batch_size * seq_len_per_request, - num_kv_heads, - head_dim, - dtype=torch.float16, - device="cuda", + k = paddle.randn( + shape=[batch_size * seq_len_per_request, num_kv_heads, head_dim], + dtype="float16", ) - v = torch.randn( - batch_size * seq_len_per_request, - num_kv_heads, - head_dim, - dtype=torch.float16, - device="cuda", + v = paddle.randn( + shape=[batch_size * seq_len_per_request, num_kv_heads, head_dim], + dtype="float16", ) logits_scale = 1.0 / math.sqrt(128) sigmoid_bias = 0.25 - o = wrapper.run(q, k, v, logits_scale, sigmoid_bias) wrapper_paged = flashinfer.BatchPrefillWithPagedKVCacheWrapper( float_workspace_buffer, kv_layout="NHD", backend="fa3", jit_args=jit_args ) - kv_indices_host = torch.arange( - 0, - batch_size * seq_len_per_request, - dtype=torch.int32, + kv_indices_host = paddle.arange( + start=0, end=batch_size * seq_len_per_request, dtype="int32" + ) + paged_kv_last_page_len_host = paddle.full( + shape=(batch_size,), fill_value=1, dtype="int32" ) - paged_kv_last_page_len_host = torch.full((batch_size,), 1, dtype=torch.int32) wrapper_paged.plan( qo_indptr_host, kv_indptr_host, @@ -548,32 +525,37 @@ def test_batch_prefill_sm90_flash_sigmoid(): 1, ) o_paged = wrapper_paged.run(q, (k, v), logits_scale, sigmoid_bias) - - p = torch.sigmoid( - torch.einsum( + p = paddle.nn.functional.sigmoid( + x=paddle.einsum( "bmhd,bnhd->bhmn", - q.view(batch_size, seq_len_per_request, num_qo_heads, head_dim).float(), - k.view(batch_size, seq_len_per_request, num_kv_heads, head_dim).float(), + q.view(batch_size, seq_len_per_request, num_qo_heads, head_dim).astype( + dtype="float32" + ), + k.view(batch_size, seq_len_per_request, num_kv_heads, head_dim).astype( + dtype="float32" + ), ) * logits_scale + sigmoid_bias ) o_ref = ( - torch.einsum( + paddle.einsum( "bhmn,bnhd->bmhd", p, - v.view(batch_size, seq_len_per_request, num_kv_heads, head_dim).float(), + v.view(batch_size, seq_len_per_request, num_kv_heads, head_dim).astype( + dtype="float32" + ), ) - .half() + .astype(dtype="float16") .reshape(batch_size * seq_len_per_request, num_qo_heads, head_dim) ) - torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2) - torch.testing.assert_close(o_paged, o_ref, rtol=2e-2, atol=2e-2) + assert paddle.allclose(x=o, y=o_ref, rtol=0.02, atol=0.02).item(), "" + assert paddle.allclose(x=o_paged, y=o_ref, rtol=0.02, atol=0.02).item(), "" def test_debug_print_logits(): - torch.manual_seed(42) - variant_decl = r""" + paddle.seed(seed=42) + variant_decl = """ struct DebugPrintLogits : AttentionVariantBase { static constexpr bool use_softmax = true; @@ -592,7 +574,7 @@ def test_debug_print_logits(): REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { if (logits >= 5) { - printf("Large logits at qo_idx=%d, kv_idx=%d, qo_head_idx=%d, kv_head_idx=%d: %.3f\n", + printf("Large logits at qo_idx=%d, kv_idx=%d, qo_head_idx=%d, kv_head_idx=%d: %.3f\\n", qo_idx, kv_idx, qo_head_idx, kv_head_idx, float(logits)); } return logits; @@ -600,40 +582,45 @@ def test_debug_print_logits(): }; """ jit_module = gen_customize_single_prefill_module( - "fa2", # backend - "batch_prefill_debug_print_logits", # uri - torch.float16, # dtype_q - torch.float16, # dtype_kv - torch.float16, # dtype_o - 128, # hidden_dim_qk - 128, # hidden_dim_vo - [], # additional_tensor_names - [], # additional_tensor_dtypes - ["sm_scale"], # additional_scalar_names - ["double"], 
# additional_scalar_dtypes + "fa2", + "batch_prefill_debug_print_logits", + "float16", + "float16", + "float16", + 128, + 128, + [], + [], + ["sm_scale"], + ["double"], "DebugPrintLogits", variant_decl, ).build_and_load() - f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module) - - q = torch.randn(128, 32, 128, dtype=torch.float16, device="cuda") - k = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda") - v = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda") + q = paddle.randn(shape=[128, 32, 128], dtype="float16") + k = paddle.randn(shape=[1023, 32, 128], dtype="float16") + v = paddle.randn(shape=[1023, 32, 128], dtype="float16") sm_scale = 1.0 / math.sqrt(128) o = f(q, k, v, sm_scale, mask_mode=MaskMode.NON_CAUSAL.value) - - p = torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * sm_scale - o_ref = torch.einsum("hmn,nhd->mhd", torch.softmax(p, dim=-1), v.float()).half() - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + p = ( + paddle.einsum( + "mhd,nhd->hmn", q.astype(dtype="float32"), k.astype(dtype="float32") + ) + * sm_scale + ) + o_ref = paddle.einsum( + "hmn,nhd->mhd", + paddle.nn.functional.softmax(x=p, axis=-1), + v.astype(dtype="float32"), + ).astype(dtype="float16") + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" def test_sm90_debug_print_logits(): - if not is_sm90a_supported(torch.device("cuda")): + if not is_sm90a_supported(device2str("cuda")): pytest.skip("SM90A is not supported") - - torch.manual_seed(42) - variant_decl = r""" + paddle.seed(seed=42) + variant_decl = """ struct DebugPrintLogits : AttentionVariantBase { float sm_scale_log2; int qo_len, kv_len; @@ -664,7 +651,7 @@ def test_sm90_debug_print_logits(): "kv_idx=%-5d " "sm_scale_log2=%-12.5f " "logits=%-12.5f " - "\n", + "\\n", qo_idx, kv_idx, sm_scale_log2, @@ -676,32 +663,38 @@ def test_sm90_debug_print_logits(): }; """ jit_module = gen_customize_single_prefill_module( - "fa3", # backend - "debug_print_logits", # uri - torch.float16, # dtype_q - torch.float16, # dtype_kv - torch.float16, # dtype_o - 128, # hidden_dim_qk - 128, # hidden_dim_vo - [], # additional_tensor_names - [], # additional_tensor_dtypes - ["sm_scale"], # additional_scalar_names - ["double"], # additional_scalar_dtypes + "fa3", + "debug_print_logits", + "float16", + "float16", + "float16", + 128, + 128, + [], + [], + ["sm_scale"], + ["double"], "DebugPrintLogits", variant_decl, ).build_and_load() - f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module) - - q = torch.randn(16, 2, 128, dtype=torch.float16, device="cuda") - k = torch.randn(16, 1, 128, dtype=torch.float16, device="cuda") - v = torch.randn(16, 1, 128, dtype=torch.float16, device="cuda") + q = paddle.randn(shape=[16, 2, 128], dtype="float16") + k = paddle.randn(shape=[16, 1, 128], dtype="float16") + v = paddle.randn(shape=[16, 1, 128], dtype="float16") sm_scale = 1.0 / math.sqrt(128) o = f(q, k, v, sm_scale, mask_mode=MaskMode.NON_CAUSAL.value) - - p = torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * sm_scale - o_ref = torch.einsum("hmn,nhd->mhd", torch.softmax(p, dim=-1), v.float()).half() - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + p = ( + paddle.einsum( + "mhd,nhd->hmn", q.astype(dtype="float32"), k.astype(dtype="float32") + ) + * sm_scale + ) + o_ref = paddle.einsum( + "hmn,nhd->mhd", + paddle.nn.functional.softmax(x=p, axis=-1), + v.astype(dtype="float32"), + ).astype(dtype="float16") + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, 
atol=0.001).item(), "" if __name__ == "__main__": diff --git a/tests/test_jit_warmup.py b/tests/test_jit_warmup.py index cd5c430664..eca412539d 100644 --- a/tests/test_jit_warmup.py +++ b/tests/test_jit_warmup.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,9 +15,6 @@ See the License for the specific language governing permissions and limitations under the License. """ - -import torch - import flashinfer from flashinfer.utils import PosEncodingMode @@ -29,28 +28,28 @@ def test_warmpup_llama(): flashinfer.quantization.gen_quantization_module(), flashinfer.page.gen_page_module(), flashinfer.decode.gen_batch_decode_module( - torch.float16, - torch.float16, - torch.float16, - torch.int32, - 128, # head_dim_qk - 128, # head_dim_vo + "float16", + "float16", + "float16", + "int32", + 128, + 128, PosEncodingMode.NONE.value, - False, # use_sliding_window - False, # use_logits_soft_cap + False, + False, ), flashinfer.prefill.gen_batch_prefill_module( - "fa2", # backend - torch.float16, - torch.float16, - torch.float16, - torch.int32, - 128, # head_dim_qk - 128, # head_dim_vo + "fa2", + "float16", + "float16", + "float16", + "int32", + 128, + 128, PosEncodingMode.NONE.value, - False, # use_sliding_window - False, # use_logits_soft_cap - False, # use_fp16_qk_reduction + False, + False, + False, ), ], verbose=False, @@ -66,41 +65,41 @@ def test_warmpup_llama_sm90(): flashinfer.quantization.gen_quantization_module(), flashinfer.page.gen_page_module(), flashinfer.decode.gen_batch_decode_module( - torch.float16, - torch.float16, - torch.float16, - torch.int32, - 128, # head_dim_qk - 128, # head_dim_vo + "float16", + "float16", + "float16", + "int32", + 128, + 128, PosEncodingMode.NONE.value, - False, # use_sliding_window - False, # use_logits_soft_cap + False, + False, ), flashinfer.prefill.gen_batch_prefill_module( - "fa2", # backend - torch.float16, - torch.float16, - torch.float16, - torch.int32, - 128, # head_dim_qk - 128, # head_dim_vo + "fa2", + "float16", + "float16", + "float16", + "int32", + 128, + 128, PosEncodingMode.NONE.value, - False, # use_sliding_window - False, # use_logits_soft_cap - False, # use_fp16_qk_reduction + False, + False, + False, ), flashinfer.prefill.gen_batch_prefill_module( - "fa3", # backend - torch.float16, - torch.float16, - torch.float16, - torch.int32, - 128, # head_dim_qk - 128, # head_dim_vo + "fa3", + "float16", + "float16", + "float16", + "int32", + 128, + 128, PosEncodingMode.NONE.value, - False, # use_sliding_window - False, # use_logits_soft_cap - False, # use_fp16_qk_reduction + False, + False, + False, ), ], verbose=False, diff --git a/tests/test_logits_cap.py b/tests/test_logits_cap.py index 8dc723adc7..6668d3bf8c 100644 --- a/tests/test_logits_cap.py +++ b/tests/test_logits_cap.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,12 +15,11 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import math import pytest -import torch -from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules +from jit_utils import (gen_decode_attention_modules, + gen_prefill_attention_modules) import flashinfer @@ -27,21 +28,10 @@ def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False, True], # use_logits_soft_caps + ["float16"], ["float16"], [128, 256], [0], [False], [False, True] ) + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False, True], # use_logits_soft_caps - [False], # use_fp16_qk_reductions + ["float16"], ["float16"], [128, 256], [0], [False], [False, True], [False] ), verbose=False, ) @@ -49,31 +39,29 @@ def warmup_jit(): def attention_logits_soft_cap_torch(q, k, v, soft_cap): - q_len, num_heads, head_dim = q.shape - scores = torch.einsum("qhd,khd->qkh", q.float(), k.float()) + q_len, num_heads, head_dim = tuple(q.shape) + scores = paddle.einsum( + "qhd,khd->qkh", q.astype(dtype="float32"), k.astype(dtype="float32") + ) scores *= 1.0 / math.sqrt(head_dim) - scores = soft_cap * torch.tanh(scores / soft_cap) - attn = torch.softmax(scores, dim=1) - return torch.einsum("ovh,vhd->ohd", attn, v.float()).to(q) + scores = soft_cap * paddle.nn.functional.tanh(x=scores / soft_cap) + attn = paddle.nn.functional.softmax(x=scores, axis=1) + return paddle.einsum("ovh,vhd->ohd", attn, v.astype(dtype="float32")).to(q) @pytest.mark.parametrize("seq_len", [1, 9, 81, 729, 33001]) @pytest.mark.parametrize("num_heads", [4, 8, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("soft_cap", [1.0, 30.0, 50.0]) -def test_single_decode_logits_soft_cap( - seq_len, - num_heads, - head_dim, - soft_cap, -): - q = torch.randn(num_heads, head_dim, device="cuda:0", dtype=torch.float16) - k = torch.randn(seq_len, num_heads, head_dim, device="cuda:0", dtype=torch.float16) - v = torch.randn(seq_len, num_heads, head_dim, device="cuda:0", dtype=torch.float16) - +def test_single_decode_logits_soft_cap(seq_len, num_heads, head_dim, soft_cap): + q = paddle.randn(shape=[num_heads, head_dim], dtype="float16") + k = paddle.randn(shape=[seq_len, num_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[seq_len, num_heads, head_dim], dtype="float16") o = flashinfer.single_decode_with_kv_cache(q, k, v, logits_soft_cap=soft_cap) - o_ref = attention_logits_soft_cap_torch(q.unsqueeze(0), k, v, soft_cap).squeeze(0) - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + o_ref = attention_logits_soft_cap_torch( + q.unsqueeze(axis=0), k, v, soft_cap + ).squeeze(axis=0) + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("q_len", [1, 17, 81, 987]) @@ -81,20 +69,13 @@ def test_single_decode_logits_soft_cap( @pytest.mark.parametrize("num_heads", [4, 8, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("soft_cap", [1.0, 30.0, 50.0]) -def test_single_prefill_logits_soft_cap( - q_len, - kv_len, - num_heads, - head_dim, - soft_cap, -): - q = torch.randn(q_len, num_heads, head_dim, device="cuda:0", dtype=torch.float16) - k = torch.randn(kv_len, num_heads, head_dim, device="cuda:0", dtype=torch.float16) - v = torch.randn(kv_len, num_heads, head_dim, device="cuda:0", dtype=torch.float16) - 
+def test_single_prefill_logits_soft_cap(q_len, kv_len, num_heads, head_dim, soft_cap): + q = paddle.randn(shape=[q_len, num_heads, head_dim], dtype="float16") + k = paddle.randn(shape=[kv_len, num_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[kv_len, num_heads, head_dim], dtype="float16") o = flashinfer.single_prefill_with_kv_cache(q, k, v, logits_soft_cap=soft_cap) o_ref = attention_logits_soft_cap_torch(q, k, v, soft_cap) - torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2) + assert paddle.allclose(x=o, y=o_ref, rtol=0.01, atol=0.01).item(), "" if __name__ == "__main__": diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 35565e90da..969ed13142 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -1,23 +1,15 @@ import numpy as np +import paddle import pytest -import torch import flashinfer -from flashinfer.logits_processor import ( - LogitsPipe, - MinP, - Sample, - Softmax, - Temperature, - TensorType, - TopK, - TopP, -) +from flashinfer.logits_processor import (LogitsPipe, MinP, Sample, Softmax, + Temperature, TensorType, TopK, TopP) def normal_distribution(std): def normal_noise(shape, device): - return torch.randn(shape, device=device) * std + return paddle.randn(shape=shape) * std normal_noise.__name__ = f"normal_distribution(std={std})" return normal_noise @@ -25,21 +17,21 @@ def normal_noise(shape, device): def gumbel_distribution(beta): def gumbel_noise(shape, device): - U = torch.rand(shape, device=device) + U = paddle.rand(shape=shape) eps = 1e-20 - return torch.log(-torch.log(U + eps) + eps) / beta + return paddle.log(x=-paddle.log(x=U + eps) + eps) / beta gumbel_noise.__name__ = f"gumbel_distribution(beta={beta})" return gumbel_noise def set_random_seed(seed=42): - torch.manual_seed(seed) + paddle.seed(seed=seed) np.random.seed(seed) def get_generators(): - gen1 = torch.Generator("cuda:0") + gen1 = paddle.framework.core.default_cpu_generator() gen1.manual_seed(42) gen2 = gen1.clone_state() return gen1, gen2 @@ -52,11 +44,7 @@ class TestLogitsPipeCompilation: @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + [normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) @pytest.mark.parametrize("temperature", [1.0, 0.5, 0.1]) def test_temperature_softmax( @@ -64,522 +52,477 @@ def test_temperature_softmax( ): set_random_seed(42) logits = distribution((batch_size, vocab_size), "cuda:0") - pipe_compiled = LogitsPipe([Temperature(), Softmax()], compile=True) pipe_no_compile = LogitsPipe([Temperature(), Softmax()], compile=False) - probs_compiled = pipe_compiled(logits, temperature=temperature) probs_no_compile = pipe_no_compile(logits, temperature=temperature) - - assert torch.allclose(probs_compiled, probs_no_compile, atol=1e-5) + assert paddle.allclose(x=probs_compiled, y=probs_no_compile, atol=1e-05).item() @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + [normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) @pytest.mark.parametrize("zero_ratio", [0.0, 0.5, 0.9]) def test_probs_sample_freq(self, vocab_size, distribution, zero_ratio): set_random_seed(42) num_trials = 5000000 - logits = distribution((1, vocab_size), "cuda:0") - zero_indices = torch.randperm(vocab_size)[: 
int(vocab_size * zero_ratio)] + zero_indices = paddle.randperm(n=vocab_size)[: int(vocab_size * zero_ratio)] logits[:, zero_indices] = -float("inf") - probs = torch.softmax(logits, dim=-1) - + probs = paddle.nn.functional.softmax(x=logits, axis=-1) pipe_compiled = LogitsPipe( [Sample()], compile=True, input_type=TensorType.PROBS ) - counter_compiled = torch.zeros(vocab_size, dtype=torch.int32, device="cuda:0") - + counter_compiled = paddle.zeros(shape=vocab_size, dtype="int32") samples_compiled = pipe_compiled( - probs, indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0") + probs, indices=paddle.zeros(shape=num_trials, dtype="int32") ) - counter_compiled.scatter_add_( - 0, samples_compiled.long(), torch.ones_like(samples_compiled) + counter_compiled.put_along_axis_( + axis=0, + indices=samples_compiled.astype(dtype="int64"), + values=paddle.ones_like(x=samples_compiled), + reduce="add", ) - freq_compiled = counter_compiled.float() / num_trials - + freq_compiled = counter_compiled.astype(dtype="float32") / num_trials pipe_no_compile = LogitsPipe( [Sample()], compile=False, input_type=TensorType.PROBS ) - counter_no_compile = torch.zeros(vocab_size, dtype=torch.int32, device="cuda:0") + counter_no_compile = paddle.zeros(shape=vocab_size, dtype="int32") samples_no_compile = pipe_no_compile( - probs, indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0") + probs, indices=paddle.zeros(shape=num_trials, dtype="int32") ) - counter_no_compile.scatter_add_( - 0, samples_no_compile.long(), torch.ones_like(samples_no_compile) + counter_no_compile.put_along_axis_( + axis=0, + indices=samples_no_compile.astype(dtype="int64"), + values=paddle.ones_like(x=samples_no_compile), + reduce="add", ) - freq_no_compile = counter_no_compile.float() / num_trials - - # check if the zero indices are never sampled - assert torch.all(counter_compiled[zero_indices] == 0) and torch.all( - counter_no_compile[zero_indices] == 0 + freq_no_compile = counter_no_compile.astype(dtype="float32") / num_trials + assert paddle.all(x=counter_compiled[zero_indices] == 0) and paddle.all( + x=counter_no_compile[zero_indices] == 0 ) - - # check if sampled results follow given distribution - similarity_compiled = torch.cosine_similarity(freq_compiled, probs) - similarity_no_compile = torch.cosine_similarity(freq_no_compile, probs) - assert similarity_compiled > 0.99, f"Compiled similarity: {similarity_compiled}" - assert similarity_no_compile > 0.99, ( - f"Non-compiled similarity: {similarity_no_compile}" + similarity_compiled = paddle.nn.functional.cosine_similarity( + x1=freq_compiled, x2=probs ) - - # check if compiled and non-compiled results are similar - freq_similarity = torch.cosine_similarity(freq_compiled, freq_no_compile, dim=0) - assert freq_similarity > 0.99, ( - f"Compiled vs non-compiled similarity: {freq_similarity}" + similarity_no_compile = paddle.nn.functional.cosine_similarity( + x1=freq_no_compile, x2=probs ) + assert similarity_compiled > 0.99, f"Compiled similarity: {similarity_compiled}" + assert ( + similarity_no_compile > 0.99 + ), f"Non-compiled similarity: {similarity_no_compile}" + freq_similarity = paddle.nn.functional.cosine_similarity( + x1=freq_compiled, x2=freq_no_compile, axis=0 + ) + assert ( + freq_similarity > 0.99 + ), f"Compiled vs non-compiled similarity: {freq_similarity}" @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + 
[normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) def test_logits_sample_freq(self, vocab_size, distribution): set_random_seed(42) num_trials = 5000000 - logits = distribution((1, vocab_size), "cuda:0") - probs = torch.softmax(logits, dim=-1) - + probs = paddle.nn.functional.softmax(x=logits, axis=-1) pipe_compiled = LogitsPipe( [Sample()], compile=True, input_type=TensorType.LOGITS ) - counter_compiled = torch.zeros(vocab_size, dtype=torch.int32, device="cuda:0") - + counter_compiled = paddle.zeros(shape=vocab_size, dtype="int32") samples_compiled = pipe_compiled( - logits, indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0") + logits, indices=paddle.zeros(shape=num_trials, dtype="int32") ) - counter_compiled.scatter_add_( - 0, samples_compiled.long(), torch.ones_like(samples_compiled) + counter_compiled.put_along_axis_( + axis=0, + indices=samples_compiled.astype(dtype="int64"), + values=paddle.ones_like(x=samples_compiled), + reduce="add", ) - freq_compiled = counter_compiled.float() / num_trials - + freq_compiled = counter_compiled.astype(dtype="float32") / num_trials pipe_no_compile = LogitsPipe( [Sample()], compile=False, input_type=TensorType.LOGITS ) - counter_no_compile = torch.zeros(vocab_size, dtype=torch.int32, device="cuda:0") + counter_no_compile = paddle.zeros(shape=vocab_size, dtype="int32") samples_no_compile = pipe_no_compile( - logits, indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0") + logits, indices=paddle.zeros(shape=num_trials, dtype="int32") ) - counter_no_compile.scatter_add_( - 0, samples_no_compile.long(), torch.ones_like(samples_no_compile) + counter_no_compile.put_along_axis_( + axis=0, + indices=samples_no_compile.astype(dtype="int64"), + values=paddle.ones_like(x=samples_no_compile), + reduce="add", ) - freq_no_compile = counter_no_compile.float() / num_trials - - # check if sampled results follow given distribution - similarity_compiled = torch.cosine_similarity(freq_compiled, probs) - similarity_no_compile = torch.cosine_similarity(freq_no_compile, probs) - assert similarity_compiled > 0.99, f"Compiled similarity: {similarity_compiled}" - assert similarity_no_compile > 0.99, ( - f"Non-compiled similarity: {similarity_no_compile}" + freq_no_compile = counter_no_compile.astype(dtype="float32") / num_trials + similarity_compiled = paddle.nn.functional.cosine_similarity( + x1=freq_compiled, x2=probs ) - - # check if compiled and non-compiled results are similar - freq_similarity = torch.cosine_similarity(freq_compiled, freq_no_compile, dim=0) - assert freq_similarity > 0.99, ( - f"Compiled vs non-compiled similarity: {freq_similarity}" + similarity_no_compile = paddle.nn.functional.cosine_similarity( + x1=freq_no_compile, x2=probs + ) + assert similarity_compiled > 0.99, f"Compiled similarity: {similarity_compiled}" + assert ( + similarity_no_compile > 0.99 + ), f"Non-compiled similarity: {similarity_no_compile}" + freq_similarity = paddle.nn.functional.cosine_similarity( + x1=freq_compiled, x2=freq_no_compile, axis=0 ) + assert ( + freq_similarity > 0.99 + ), f"Compiled vs non-compiled similarity: {freq_similarity}" @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + [normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) @pytest.mark.parametrize("k", [10, 100, 500]) def test_probs_top_k_sample_freq(self, vocab_size, distribution, k): if 
k > vocab_size: pytest.skip("k should be less than vocab_size") - set_random_seed(42) num_trials = 5000000 - logits = distribution((1, vocab_size), "cuda:0") - probs = torch.softmax(logits, dim=-1) - - sorted_prob, _ = torch.sort(probs, descending=True) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) + sorted_prob, _ = paddle.sort(x=probs, descending=True), paddle.argsort( + x=probs, descending=True + ) pivot = sorted_prob[:, k - 1] - mask = (probs >= pivot.unsqueeze(-1)).int() + mask = (probs >= pivot.unsqueeze(axis=-1)).astype(dtype="int32") masked_probs = probs.clone() masked_probs[mask == 0] = 0 - pipe_compiled = LogitsPipe( [TopK(), Sample()], compile=True, input_type=TensorType.PROBS ) - counter_compiled = torch.zeros(vocab_size, dtype=torch.int32, device="cuda:0") - + counter_compiled = paddle.zeros(shape=vocab_size, dtype="int32") samples_compiled = pipe_compiled( - probs, - indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0"), - top_k=k, + probs, indices=paddle.zeros(shape=num_trials, dtype="int32"), top_k=k ) - counter_compiled.scatter_add_( - 0, samples_compiled.long(), torch.ones_like(samples_compiled) + counter_compiled.put_along_axis_( + axis=0, + indices=samples_compiled.astype(dtype="int64"), + values=paddle.ones_like(x=samples_compiled), + reduce="add", ) - freq_compiled = counter_compiled.float() / num_trials - + freq_compiled = counter_compiled.astype(dtype="float32") / num_trials pipe_no_compile = LogitsPipe( [TopK(), Sample()], compile=False, input_type=TensorType.PROBS ) - counter_no_compile = torch.zeros(vocab_size, dtype=torch.int32, device="cuda:0") + counter_no_compile = paddle.zeros(shape=vocab_size, dtype="int32") samples_no_compile = pipe_no_compile( - probs, - indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0"), - top_k=k, + probs, indices=paddle.zeros(shape=num_trials, dtype="int32"), top_k=k ) - counter_no_compile.scatter_add_( - 0, samples_no_compile.long(), torch.ones_like(samples_no_compile) + counter_no_compile.put_along_axis_( + axis=0, + indices=samples_no_compile.astype(dtype="int64"), + values=paddle.ones_like(x=samples_no_compile), + reduce="add", ) - freq_no_compile = counter_no_compile.float() / num_trials - - # check if the top-k thresholding is properly applied - assert torch.all(mask[torch.arange(1), samples_compiled] == 1) - assert torch.all(mask[torch.arange(1), samples_no_compile] == 1) - - similarity_compiled = torch.cosine_similarity(freq_compiled, masked_probs) - similarity_no_compile = torch.cosine_similarity(freq_no_compile, masked_probs) - # check if the sampled results follow given distribution - assert similarity_compiled > 0.99, f"Compiled similarity: {similarity_compiled}" - assert similarity_no_compile > 0.99, ( - f"Non-compiled similarity: {similarity_no_compile}" + freq_no_compile = counter_no_compile.astype(dtype="float32") / num_trials + assert paddle.all(x=mask[paddle.arange(end=1), samples_compiled] == 1) + assert paddle.all(x=mask[paddle.arange(end=1), samples_no_compile] == 1) + similarity_compiled = paddle.nn.functional.cosine_similarity( + x1=freq_compiled, x2=masked_probs ) - - # check if compiled and non-compiled results are similar - freq_similarity = torch.cosine_similarity(freq_compiled, freq_no_compile, dim=0) - assert freq_similarity > 0.99, ( - f"Compiled vs non-compiled similarity: {freq_similarity}" + similarity_no_compile = paddle.nn.functional.cosine_similarity( + x1=freq_no_compile, x2=masked_probs + ) + assert similarity_compiled > 0.99, f"Compiled similarity: 
{similarity_compiled}" + assert ( + similarity_no_compile > 0.99 + ), f"Non-compiled similarity: {similarity_no_compile}" + freq_similarity = paddle.nn.functional.cosine_similarity( + x1=freq_compiled, x2=freq_no_compile, axis=0 ) + assert ( + freq_similarity > 0.99 + ), f"Compiled vs non-compiled similarity: {freq_similarity}" @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + [normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) @pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) def test_probs_top_p_sample_freq(self, vocab_size, distribution, p): set_random_seed(42) num_trials = 5000000 - eps = 1e-4 - + eps = 0.0001 logits = distribution((1, vocab_size), "cuda:0") - probs = torch.softmax(logits, dim=-1) - - sorted_prob, indices = torch.sort(probs, descending=False) - cdf = torch.cumsum(sorted_prob, dim=-1) - mask = torch.zeros(1, vocab_size, dtype=torch.int32, device="cuda:0") - mask.scatter_add_(1, indices, (cdf > (1 - p) - eps).int()) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) + sorted_prob, indices = paddle.sort(x=probs, descending=False), paddle.argsort( + x=probs, descending=False + ) + cdf = paddle.cumsum(x=sorted_prob, axis=-1) + mask = paddle.zeros(shape=[1, vocab_size], dtype="int32") + mask.put_along_axis_( + axis=1, + indices=indices, + values=(cdf > 1 - p - eps).astype(dtype="int32"), + reduce="add", + ) masked_probs = probs.clone() masked_probs[mask == 0] = 0 - - pipe_compiled = LogitsPipe( - [TopP(), Sample()], - compile=True, - ) - counter_compiled = torch.zeros(vocab_size, dtype=torch.int32, device="cuda:0") + pipe_compiled = LogitsPipe([TopP(), Sample()], compile=True) + counter_compiled = paddle.zeros(shape=vocab_size, dtype="int32") samples_compiled = pipe_compiled( - probs, - indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0"), - top_p=p, + probs, indices=paddle.zeros(shape=num_trials, dtype="int32"), top_p=p ) - counter_compiled.scatter_add_( - 0, samples_compiled.long(), torch.ones_like(samples_compiled) + counter_compiled.put_along_axis_( + axis=0, + indices=samples_compiled.astype(dtype="int64"), + values=paddle.ones_like(x=samples_compiled), + reduce="add", ) - freq_compiled = counter_compiled.float() / num_trials - + freq_compiled = counter_compiled.astype(dtype="float32") / num_trials pipe_no_compile = LogitsPipe( [TopP(), Sample()], compile=False, input_type=TensorType.PROBS ) - counter_no_compile = torch.zeros(vocab_size, dtype=torch.int32, device="cuda:0") + counter_no_compile = paddle.zeros(shape=vocab_size, dtype="int32") samples_no_compile = pipe_no_compile( - probs, - indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0"), - top_p=p, + probs, indices=paddle.zeros(shape=num_trials, dtype="int32"), top_p=p ) - counter_no_compile.scatter_add_( - 0, samples_no_compile.long(), torch.ones_like(samples_no_compile) + counter_no_compile.put_along_axis_( + axis=0, + indices=samples_no_compile.astype(dtype="int64"), + values=paddle.ones_like(x=samples_no_compile), + reduce="add", ) - freq_no_compile = counter_no_compile.float() / num_trials - - # check if the top-p thresholding is properly applied - assert torch.all(mask[torch.arange(1), samples_compiled] == 1) - assert torch.all(mask[torch.arange(1), samples_no_compile] == 1) - - # check if the sampled results follow given distribution - similarity_compiled = torch.cosine_similarity(freq_compiled, masked_probs) 
- similarity_no_compile = torch.cosine_similarity(freq_no_compile, masked_probs) - assert similarity_compiled > 0.99, f"Compiled similarity: {similarity_compiled}" - assert similarity_no_compile > 0.99, ( - f"Non-compiled similarity: {similarity_no_compile}" + freq_no_compile = counter_no_compile.astype(dtype="float32") / num_trials + assert paddle.all(x=mask[paddle.arange(end=1), samples_compiled] == 1) + assert paddle.all(x=mask[paddle.arange(end=1), samples_no_compile] == 1) + similarity_compiled = paddle.nn.functional.cosine_similarity( + x1=freq_compiled, x2=masked_probs ) - - # check if compiled and non-compiled results are similar - freq_similarity = torch.cosine_similarity(freq_compiled, freq_no_compile, dim=0) - assert freq_similarity > 0.99, ( - f"Compiled vs non-compiled similarity: {freq_similarity}" + similarity_no_compile = paddle.nn.functional.cosine_similarity( + x1=freq_no_compile, x2=masked_probs ) + assert similarity_compiled > 0.99, f"Compiled similarity: {similarity_compiled}" + assert ( + similarity_no_compile > 0.99 + ), f"Non-compiled similarity: {similarity_no_compile}" + freq_similarity = paddle.nn.functional.cosine_similarity( + x1=freq_compiled, x2=freq_no_compile, axis=0 + ) + assert ( + freq_similarity > 0.99 + ), f"Compiled vs non-compiled similarity: {freq_similarity}" @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + [normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) @pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1]) def test_probs_min_p_sample_freq(self, vocab_size, distribution, p): set_random_seed(42) num_trials = 5000000 - logits = distribution((1, vocab_size), "cuda:0") - probs = torch.softmax(logits, dim=-1) - - sorted_prob, indices = torch.sort(probs, descending=False) - top_probs = sorted_prob[:, -1].unsqueeze(-1) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) + sorted_prob, indices = paddle.sort(x=probs, descending=False), paddle.argsort( + x=probs, descending=False + ) + top_probs = sorted_prob[:, -1].unsqueeze(axis=-1) scaled_p = p * top_probs - - mask = torch.zeros(1, vocab_size, dtype=torch.int32, device="cuda:0") - mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int()) + mask = paddle.zeros(shape=[1, vocab_size], dtype="int32") + mask.put_along_axis_( + axis=1, + indices=indices, + values=(sorted_prob >= scaled_p).astype(dtype="int32"), + reduce="add", + ) masked_probs = probs.clone() masked_probs[mask == 0] = 0 - - pipe_compiled = LogitsPipe( - [MinP(), Sample()], - compile=True, - ) - counter_compiled = torch.zeros(vocab_size, dtype=torch.int32, device="cuda:0") + pipe_compiled = LogitsPipe([MinP(), Sample()], compile=True) + counter_compiled = paddle.zeros(shape=vocab_size, dtype="int32") samples_compiled = pipe_compiled( - probs, - indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0"), - min_p=p, - ) - counter_compiled.scatter_add_( - 0, samples_compiled.long(), torch.ones_like(samples_compiled) + probs, indices=paddle.zeros(shape=num_trials, dtype="int32"), min_p=p ) - freq_compiled = counter_compiled.float() / num_trials - - pipe_no_compile = LogitsPipe( - [MinP(), Sample()], - compile=False, + counter_compiled.put_along_axis_( + axis=0, + indices=samples_compiled.astype(dtype="int64"), + values=paddle.ones_like(x=samples_compiled), + reduce="add", ) - counter_no_compile = torch.zeros(vocab_size, dtype=torch.int32, device="cuda:0") 
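+        # put_along_axis_(..., reduce="add") takes over the role of torch's scatter_add_
+        # here: it accumulates a histogram of how often each token id was sampled.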
+ freq_compiled = counter_compiled.astype(dtype="float32") / num_trials + pipe_no_compile = LogitsPipe([MinP(), Sample()], compile=False) + counter_no_compile = paddle.zeros(shape=vocab_size, dtype="int32") samples_no_compile = pipe_no_compile( - probs, - indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0"), - min_p=p, + probs, indices=paddle.zeros(shape=num_trials, dtype="int32"), min_p=p ) - counter_no_compile.scatter_add_( - 0, samples_no_compile.long(), torch.ones_like(samples_no_compile) + counter_no_compile.put_along_axis_( + axis=0, + indices=samples_no_compile.astype(dtype="int64"), + values=paddle.ones_like(x=samples_no_compile), + reduce="add", ) - freq_no_compile = counter_no_compile.float() / num_trials - - # check if the min-p thresholding is properly applied - assert torch.all(mask[torch.arange(1), samples_compiled] == 1) - assert torch.all(mask[torch.arange(1), samples_no_compile] == 1) - - # check if the sampled results follow given distribution - similarity_compiled = torch.cosine_similarity(freq_compiled, masked_probs) - similarity_no_compile = torch.cosine_similarity(freq_no_compile, masked_probs) - assert similarity_compiled > 0.99, f"Compiled similarity: {similarity_compiled}" - assert similarity_no_compile > 0.99, ( - f"Non-compiled similarity: {similarity_no_compile}" + freq_no_compile = counter_no_compile.astype(dtype="float32") / num_trials + assert paddle.all(x=mask[paddle.arange(end=1), samples_compiled] == 1) + assert paddle.all(x=mask[paddle.arange(end=1), samples_no_compile] == 1) + similarity_compiled = paddle.nn.functional.cosine_similarity( + x1=freq_compiled, x2=masked_probs ) - - # check if compiled and non-compiled results are similar - freq_similarity = torch.cosine_similarity(freq_compiled, freq_no_compile, dim=0) - assert freq_similarity > 0.99, ( - f"Compiled vs non-compiled similarity: {freq_similarity}" + similarity_no_compile = paddle.nn.functional.cosine_similarity( + x1=freq_no_compile, x2=masked_probs ) + assert similarity_compiled > 0.99, f"Compiled similarity: {similarity_compiled}" + assert ( + similarity_no_compile > 0.99 + ), f"Non-compiled similarity: {similarity_no_compile}" + freq_similarity = paddle.nn.functional.cosine_similarity( + x1=freq_compiled, x2=freq_no_compile, axis=0 + ) + assert ( + freq_similarity > 0.99 + ), f"Compiled vs non-compiled similarity: {freq_similarity}" @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + [normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) @pytest.mark.parametrize("p", [0.1, 0.5]) def test_probs_top_k_top_p_joint_sample_freq(self, vocab_size, distribution, p): set_random_seed(42) num_trials = 5000000 - eps = 1e-4 - + eps = 0.0001 if p == 0.1: k = int(vocab_size * 0.5) elif p == 0.5: k = int(vocab_size * 0.1) else: raise ValueError("p not recognized") - logits = distribution((1, vocab_size), "cuda:0") - probs = torch.softmax(logits, dim=-1) - - sorted_prob_asc, idx_asc = torch.sort(probs, descending=False) - cdf = torch.cumsum(sorted_prob_asc, dim=-1) - mask_top_p = torch.zeros(1, vocab_size, dtype=torch.int32, device="cuda:0") - mask_top_p.scatter_add_(1, idx_asc, (cdf > (1 - p) - eps).int()) - - sorted_prob_desc, _ = torch.sort(probs, descending=True) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) + sorted_prob_asc, idx_asc = paddle.sort( + x=probs, descending=False + ), paddle.argsort(x=probs, 
descending=False) + cdf = paddle.cumsum(x=sorted_prob_asc, axis=-1) + mask_top_p = paddle.zeros(shape=[1, vocab_size], dtype="int32") + mask_top_p.put_along_axis_( + axis=1, + indices=idx_asc, + values=(cdf > 1 - p - eps).astype(dtype="int32"), + reduce="add", + ) + sorted_prob_desc, _ = paddle.sort(x=probs, descending=True), paddle.argsort( + x=probs, descending=True + ) pivot = sorted_prob_desc[:, k - 1] - mask_top_k = (probs >= pivot.unsqueeze(-1)).int() - - mask = torch.minimum(mask_top_k, mask_top_p) + mask_top_k = (probs >= pivot.unsqueeze(axis=-1)).astype(dtype="int32") + mask = paddle.minimum(x=mask_top_k, y=mask_top_p) masked_probs = probs.clone() masked_probs[mask == 0] = 0 - pipe_compiled = LogitsPipe( - [ - TopK(joint_topk_topp=True), - TopP(), - Sample(), - ], + [TopK(joint_topk_topp=True), TopP(), Sample()], compile=True, input_type=TensorType.PROBS, ) - counter_compiled = torch.zeros(vocab_size, dtype=torch.int32, device="cuda:0") + counter_compiled = paddle.zeros(shape=vocab_size, dtype="int32") samples_compiled = pipe_compiled( probs, - indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0"), + indices=paddle.zeros(shape=num_trials, dtype="int32"), top_k=k, top_p=p, ) - counter_compiled.scatter_add_( - 0, samples_compiled.long(), torch.ones_like(samples_compiled) + counter_compiled.put_along_axis_( + axis=0, + indices=samples_compiled.astype(dtype="int64"), + values=paddle.ones_like(x=samples_compiled), + reduce="add", ) - freq_compiled = counter_compiled.float() / num_trials - + freq_compiled = counter_compiled.astype(dtype="float32") / num_trials pipe_no_compile = LogitsPipe( - [ - TopK(), - TopP(), - Sample(), - ], - compile=False, - input_type=TensorType.PROBS, + [TopK(), TopP(), Sample()], compile=False, input_type=TensorType.PROBS ) samples_no_compile = pipe_no_compile( probs, - indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0"), + indices=paddle.zeros(shape=num_trials, dtype="int32"), top_k=k, top_p=p, ) - - # check if the top-k-top-p thresholding is properly applied - assert torch.all(mask[torch.arange(1), samples_compiled] == 1) - assert torch.all(mask[torch.arange(1), samples_no_compile] == 1) - - # check if the sampled results follow given distribution - # we don't check the non-compiled results because joint topk-topp yeilds different results from topk then topp - # same for the compile-non-compile similarity as well - similarity_compiled = torch.cosine_similarity(freq_compiled, masked_probs) + assert paddle.all(x=mask[paddle.arange(end=1), samples_compiled] == 1) + assert paddle.all(x=mask[paddle.arange(end=1), samples_no_compile] == 1) + similarity_compiled = paddle.nn.functional.cosine_similarity( + x1=freq_compiled, x2=masked_probs + ) assert similarity_compiled > 0.99, f"Compiled similarity: {similarity_compiled}" @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + [normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) @pytest.mark.parametrize("p", [0.1, 0.5]) def test_logits_top_k_top_p_joint_sample_freq(self, vocab_size, distribution, p): set_random_seed(42) num_trials = 5000000 - eps = 1e-4 - + eps = 0.0001 if p == 0.1: k = int(vocab_size * 0.5) elif p == 0.5: k = int(vocab_size * 0.1) else: raise ValueError("p not recognized") - logits = distribution((1, vocab_size), "cuda:0") - probs = torch.softmax(logits, dim=-1) - - sorted_prob_asc, idx_asc = 
torch.sort(probs, descending=False) - cdf = torch.cumsum(sorted_prob_asc, dim=-1) - mask_top_p = torch.zeros(1, vocab_size, dtype=torch.int32, device="cuda:0") - mask_top_p.scatter_add_(1, idx_asc, (cdf > (1 - p) - eps).int()) - - sorted_prob_desc, _ = torch.sort(probs, descending=True) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) + sorted_prob_asc, idx_asc = paddle.sort( + x=probs, descending=False + ), paddle.argsort(x=probs, descending=False) + cdf = paddle.cumsum(x=sorted_prob_asc, axis=-1) + mask_top_p = paddle.zeros(shape=[1, vocab_size], dtype="int32") + mask_top_p.put_along_axis_( + axis=1, + indices=idx_asc, + values=(cdf > 1 - p - eps).astype(dtype="int32"), + reduce="add", + ) + sorted_prob_desc, _ = paddle.sort(x=probs, descending=True), paddle.argsort( + x=probs, descending=True + ) pivot = sorted_prob_desc[:, k - 1] - mask_top_k = (probs >= pivot.unsqueeze(-1)).int() - - mask = torch.minimum(mask_top_k, mask_top_p) + mask_top_k = (probs >= pivot.unsqueeze(axis=-1)).astype(dtype="int32") + mask = paddle.minimum(x=mask_top_k, y=mask_top_p) masked_probs = probs.clone() masked_probs[mask == 0] = 0 - pipe_compiled = LogitsPipe( - [ - Softmax(), - TopK(joint_topk_topp=True), - TopP(), - Sample(), - ], + [Softmax(), TopK(joint_topk_topp=True), TopP(), Sample()], compile=True, input_type=TensorType.LOGITS, ) - counter_compiled = torch.zeros(vocab_size, dtype=torch.int32, device="cuda:0") + counter_compiled = paddle.zeros(shape=vocab_size, dtype="int32") samples_compiled = pipe_compiled( logits, - indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0"), + indices=paddle.zeros(shape=num_trials, dtype="int32"), top_k=k, top_p=p, ) - counter_compiled.scatter_add_( - 0, samples_compiled.long(), torch.ones_like(samples_compiled) + counter_compiled.put_along_axis_( + axis=0, + indices=samples_compiled.astype(dtype="int64"), + values=paddle.ones_like(x=samples_compiled), + reduce="add", ) - freq_compiled = counter_compiled.float() / num_trials - + freq_compiled = counter_compiled.astype(dtype="float32") / num_trials pipe_no_compile = LogitsPipe( - [ - Softmax(), - TopK(), - TopP(), - Sample(), - ], + [Softmax(), TopK(), TopP(), Sample()], compile=False, input_type=TensorType.LOGITS, ) samples_no_compile = pipe_no_compile( logits, - indices=torch.zeros(num_trials, dtype=torch.int32, device="cuda:0"), + indices=paddle.zeros(shape=num_trials, dtype="int32"), top_k=k, top_p=p, ) - - # check if the top-k-top-p thresholding is properly applied - assert torch.all(mask[torch.arange(1), samples_compiled] == 1) - assert torch.all(mask[torch.arange(1), samples_no_compile] == 1) - - # check if the sampled results follow given distribution - # we don't check the non-compiled results because joint topk-topp yeilds different results from topk then topp - # same for the compile-non-compile similarity as well - similarity_compiled = torch.cosine_similarity(freq_compiled, masked_probs) + assert paddle.all(x=mask[paddle.arange(end=1), samples_compiled] == 1) + assert paddle.all(x=mask[paddle.arange(end=1), samples_no_compile] == 1) + similarity_compiled = paddle.nn.functional.cosine_similarity( + x1=freq_compiled, x2=masked_probs + ) assert similarity_compiled > 0.99, f"Compiled similarity: {similarity_compiled}" @@ -594,52 +537,39 @@ def test_temperature_softmax( self, batch_size, vocab_size, temperature, temperature_arr ): set_random_seed(42) - - logits = torch.randn(batch_size, vocab_size, device="cuda:0") - + logits = paddle.randn(shape=[batch_size, vocab_size]) if temperature_arr: 
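+            # temperature may also be given as a per-request tensor of length batch_size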
- temperature = torch.rand(batch_size, device="cuda:0") - + temperature = paddle.rand(shape=batch_size) samples_direct = flashinfer.sampling.softmax( logits=logits, temperature=temperature ) - pipe = LogitsPipe([Temperature(), Softmax()]) samples_pipe = pipe(logits, temperature=temperature) - - assert torch.allclose(samples_pipe, samples_direct, atol=1e-5) + assert paddle.allclose(x=samples_pipe, y=samples_direct, atol=1e-05).item() @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) def test_topp(self, batch_size, vocab_size, p): set_random_seed(42) - - probs = torch.rand(batch_size, vocab_size, device="cuda:0") - probs = probs / probs.sum(dim=-1, keepdim=True) - + probs = paddle.rand(shape=[batch_size, vocab_size]) + probs = probs / probs.sum(axis=-1, keepdim=True) samples_direct = flashinfer.sampling.top_p_renorm_probs(probs, p) - pipe = LogitsPipe([TopP()]) samples_pipe = pipe(probs, top_p=p) - - assert torch.all(samples_pipe == samples_direct) + assert paddle.all(x=samples_pipe == samples_direct) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("k", [10, 100, 500]) def test_probs_topk(self, batch_size, vocab_size, k): set_random_seed(42) - - probs = torch.rand(batch_size, vocab_size, device="cuda:0") - probs = probs / probs.sum(dim=-1, keepdim=True) - + probs = paddle.rand(shape=[batch_size, vocab_size]) + probs = probs / probs.sum(axis=-1, keepdim=True) samples_direct = flashinfer.sampling.top_k_renorm_probs(probs, k) - pipe = LogitsPipe([TopK()], input_type=TensorType.PROBS) samples_pipe = pipe(probs, top_k=k) - - assert torch.all(samples_pipe == samples_direct) + assert paddle.all(x=samples_pipe == samples_direct) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @@ -648,57 +578,43 @@ def test_probs_topk(self, batch_size, vocab_size, k): def test_logits_topk(self, batch_size, vocab_size, k, neginf_input): if k > vocab_size: pytest.skip("k should be less than vocab_size") - set_random_seed(42) - - logits = torch.randn(batch_size, vocab_size, device="cuda:0") - + logits = paddle.randn(shape=[batch_size, vocab_size]) if neginf_input: - num_neginf = torch.randint(1, vocab_size * batch_size, (1,)).item() - idxs = torch.randperm(batch_size * vocab_size, device="cuda:0")[:num_neginf] + num_neginf = paddle.randint( + low=1, high=vocab_size * batch_size, shape=(1,) + ).item() + idxs = paddle.randperm(n=batch_size * vocab_size)[:num_neginf] logits[idxs // vocab_size, idxs % vocab_size] = -float("inf") - samples_direct = flashinfer.sampling.top_k_mask_logits(logits, k) - pipe = LogitsPipe([TopK()], input_type=TensorType.LOGITS) samples_pipe = pipe(logits, top_k=k) - - assert torch.all(samples_pipe == samples_direct) + assert paddle.all(x=samples_pipe == samples_direct) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) def test_probs_sample(self, batch_size, vocab_size): set_random_seed(42) - - probs = torch.rand(batch_size, vocab_size, device="cuda:0") - probs = probs / probs.sum(dim=-1, keepdim=True) - + probs = paddle.rand(shape=[batch_size, vocab_size]) + probs = probs / probs.sum(axis=-1, keepdim=True) gen1, gen2 = get_generators() - samples_direct = flashinfer.sampling.sampling_from_probs(probs, generator=gen1) - pipe = LogitsPipe([Sample()], 
input_type=TensorType.PROBS) samples_pipe = pipe(probs, generator=gen2) - - assert torch.all(samples_pipe == samples_direct) + assert paddle.all(x=samples_pipe == samples_direct) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) def test_logits_sample(self, batch_size, vocab_size): set_random_seed(42) - - logits = torch.randn(batch_size, vocab_size, device="cuda:0") - + logits = paddle.randn(shape=[batch_size, vocab_size]) gen1, gen2 = get_generators() - samples_direct = flashinfer.sampling.sampling_from_logits( logits, generator=gen1 ) - pipe = LogitsPipe([Sample()], input_type=TensorType.LOGITS) samples_pipe = pipe(logits, generator=gen2) - - assert torch.all(samples_pipe == samples_direct) + assert paddle.all(x=samples_pipe == samples_direct) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @@ -706,178 +622,136 @@ def test_logits_sample(self, batch_size, vocab_size): def test_probs_topk_sample(self, batch_size, vocab_size, k): if k > vocab_size: pytest.skip("k should be less than vocab_size") - set_random_seed(42) - - probs = torch.rand(batch_size, vocab_size, device="cuda:0") - probs = probs / probs.sum(dim=-1, keepdim=True) - + probs = paddle.rand(shape=[batch_size, vocab_size]) + probs = probs / probs.sum(axis=-1, keepdim=True) gen1, gen2 = get_generators() - samples_direct = flashinfer.sampling.top_k_sampling_from_probs( probs, k, generator=gen1 ) - pipe = LogitsPipe([TopK(), Sample()], input_type=TensorType.PROBS) samples_pipe = pipe(probs, top_k=k, generator=gen2) - - assert torch.all(samples_pipe == samples_direct) + assert paddle.all(x=samples_pipe == samples_direct) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) def test_probs_topp_sample(self, batch_size, vocab_size, p): set_random_seed(42) - - probs = torch.rand(batch_size, vocab_size, device="cuda:0") - probs = probs / probs.sum(dim=-1, keepdim=True) - + probs = paddle.rand(shape=[batch_size, vocab_size]) + probs = probs / probs.sum(axis=-1, keepdim=True) gen1, gen2 = get_generators() - samples_direct = flashinfer.sampling.top_p_sampling_from_probs( probs, p, generator=gen1 ) - pipe = LogitsPipe([TopP(), Sample()]) samples_pipe = pipe(probs, top_p=p, generator=gen2) - - assert torch.all(samples_pipe == samples_direct) + assert paddle.all(x=samples_pipe == samples_direct) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1]) def test_probs_minp_sample(self, batch_size, vocab_size, p): set_random_seed(42) - - probs = torch.rand(batch_size, vocab_size, device="cuda:0") - probs = probs / probs.sum(dim=-1, keepdim=True) - + probs = paddle.rand(shape=[batch_size, vocab_size]) + probs = probs / probs.sum(axis=-1, keepdim=True) gen1, gen2 = get_generators() - samples_direct = flashinfer.sampling.min_p_sampling_from_probs( probs, p, generator=gen1 ) - pipe = LogitsPipe([MinP(), Sample()]) samples_pipe = pipe(probs, min_p=p, generator=gen2) - - assert torch.all(samples_pipe == samples_direct) + assert paddle.all(x=samples_pipe == samples_direct) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5]) def test_joint_probs_topk_topp_sample(self, batch_size, vocab_size, p): 
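+        # LogitsPipe([TopK(joint_topk_topp=True), TopP(), Sample()]) should reproduce
+        # flashinfer.sampling.top_k_top_p_sampling_from_probs(..., filter_apply_order="joint")
+        # exactly when both draw from the same generator state.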
set_random_seed(42) - if p == 0.1: k = int(vocab_size * 0.5) elif p == 0.5: k = int(vocab_size * 0.1) else: raise ValueError("p not recognized") - - probs = torch.rand(batch_size, vocab_size, device="cuda:0") - probs = probs / probs.sum(dim=-1, keepdim=True) - + probs = paddle.rand(shape=[batch_size, vocab_size]) + probs = probs / probs.sum(axis=-1, keepdim=True) gen1, gen2 = get_generators() - samples_direct = flashinfer.sampling.top_k_top_p_sampling_from_probs( probs, k, p, filter_apply_order="joint", generator=gen1 ) - pipe = LogitsPipe( [TopK(joint_topk_topp=True), TopP(), Sample()], input_type=TensorType.PROBS ) - samples_pipe = pipe(probs, top_k=k, top_p=p, generator=gen2) - - assert torch.all(samples_pipe == samples_direct) + assert paddle.all(x=samples_pipe == samples_direct) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5]) def test_sequential_probs_topk_topp_sample(self, batch_size, vocab_size, p): set_random_seed(42) - if p == 0.1: k = int(vocab_size * 0.5) elif p == 0.5: k = int(vocab_size * 0.1) else: raise ValueError("p not recognized") - - probs = torch.rand(batch_size, vocab_size, device="cuda:0") - probs = probs / probs.sum(dim=-1, keepdim=True) - + probs = paddle.rand(shape=[batch_size, vocab_size]) + probs = probs / probs.sum(axis=-1, keepdim=True) gen1, gen2 = get_generators() - samples_direct = flashinfer.sampling.top_k_top_p_sampling_from_probs( probs, k, p, filter_apply_order="top_k_first", generator=gen1 ) - pipe = LogitsPipe([TopK(), TopP(), Sample()], input_type=TensorType.PROBS) samples_pipe = pipe(probs, top_k=k, top_p=p, generator=gen2) - - assert torch.all(samples_pipe == samples_direct) + assert paddle.all(x=samples_pipe == samples_direct) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5]) def test_joint_logits_topk_topp_sample(self, batch_size, vocab_size, p): set_random_seed(42) - if p == 0.1: k = int(vocab_size * 0.5) elif p == 0.5: k = int(vocab_size * 0.1) else: raise ValueError("p not recognized") - - logits = torch.randn(batch_size, vocab_size, device="cuda:0") - + logits = paddle.randn(shape=[batch_size, vocab_size]) gen1, gen2 = get_generators() - samples_direct = flashinfer.sampling.top_k_top_p_sampling_from_logits( logits, k, p, filter_apply_order="joint", generator=gen1 ) - pipe = LogitsPipe( [Softmax(), TopK(joint_topk_topp=True), TopP(), Sample()], input_type=TensorType.LOGITS, ) samples_pipe = pipe(logits, top_k=k, top_p=p, generator=gen2) - - assert torch.all(samples_pipe == samples_direct) + assert paddle.all(x=samples_pipe == samples_direct) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5]) def test_sequential_logits_topk_topp_sample(self, batch_size, vocab_size, p): set_random_seed(42) - if p == 0.1: k = int(vocab_size * 0.5) elif p == 0.5: k = int(vocab_size * 0.1) else: raise ValueError("p not recognized") - - logits = torch.randn(batch_size, vocab_size, device="cuda:0") - + logits = paddle.randn(shape=[batch_size, vocab_size]) gen1, gen2 = get_generators() - samples_direct = flashinfer.sampling.top_k_top_p_sampling_from_logits( logits, k, p, filter_apply_order="top_k_first", generator=gen1 ) - topk_mask_pipe = LogitsPipe([TopK()], input_type=TensorType.LOGITS) topp_pipe = LogitsPipe([Softmax(), TopP(), Sample()]) - 
samples_pipe = topp_pipe( topk_mask_pipe(logits, top_k=k), top_p=p, generator=gen2 ) - - assert torch.all(samples_pipe == samples_direct) + assert paddle.all(x=samples_pipe == samples_direct) if __name__ == "__main__": diff --git a/tests/test_mla_decode_kernel.py b/tests/test_mla_decode_kernel.py index 40a61d82a3..2483abe63e 100644 --- a/tests/test_mla_decode_kernel.py +++ b/tests/test_mla_decode_kernel.py @@ -1,13 +1,15 @@ +import sys + + from typing import Optional, Tuple -import torch -import torch.nn.functional as F -from torch import nn +import paddle +from flashinfer.paddle_utils import * import flashinfer -def wmape(target: torch.Tensor, preds: torch.Tensor): +def wmape(target: paddle.Tensor, preds: paddle.Tensor): sum_abs_error = (preds - target).abs().sum().detach().item() sum_scale = target.abs().sum().detach().item() return sum_abs_error / sum_scale @@ -16,297 +18,258 @@ def wmape(target: torch.Tensor, preds: torch.Tensor): from rope_reference import * -class DeepseekV2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): +class DeepseekV2RMSNorm(paddle.nn.Layer): + def __init__(self, hidden_size, eps=1e-06): """ DeepseekV2RMSNorm is equivalent to T5LayerNorm """ super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) + self.weight = paddle.base.framework.EagerParamBase.from_tensor( + tensor=paddle.ones(shape=hidden_size) + ) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states.to("float32") + variance = hidden_states.pow(y=2).mean(axis=-1, keepdim=True) + hidden_states = hidden_states * paddle.rsqrt(x=variance + self.variance_epsilon) return (self.weight * hidden_states).to(input_dtype) -class DeepseekV2AttentionVanilla(nn.Module): +class DeepseekV2AttentionVanilla(paddle.nn.Layer): def __init__(self): super().__init__() - self.hidden_size = 5120 self.num_heads = 128 - self.q_lora_rank = 1536 self.qk_rope_head_dim = 64 self.kv_lora_rank = 512 self.v_head_dim = 128 self.qk_nope_head_dim = 128 - self.q_head_dim = 192 # 192 = config.qk_nope_head_dim + config.qk_rope_head_dim - + self.q_head_dim = 192 self.rope_theta = 10000 - - # W^DQ ~ [5120, 1536] - self.q_a_proj = nn.Linear( - self.hidden_size, - self.q_lora_rank, - bias=False, + self.q_a_proj = paddle.nn.Linear( + in_features=self.hidden_size, out_features=self.q_lora_rank, bias_attr=False ) - torch.nn.init.normal_(self.q_a_proj.weight) - + init_Normal = paddle.nn.initializer.Normal() + init_Normal(self.q_a_proj.weight) self.q_a_layernorm = DeepseekV2RMSNorm(self.q_lora_rank) - - # W^UQ & W^QR = [1536, 128*(128+64)] - self.q_b_proj = nn.Linear( - self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + self.q_b_proj = paddle.nn.Linear( + in_features=self.q_lora_rank, + out_features=self.num_heads * self.q_head_dim, + bias_attr=False, ) - torch.nn.init.normal_(self.q_b_proj.weight) - - # We don't need these modules, since we already have cached k_pe and compressed_kv tensor. 
- # self.kv_a_proj_with_mqa = nn.Linear( # [,5120]-->[, 512+64] W^DKV & W^KR = [5120, 512+64] - # self.hidden_size, - # self.kv_lora_rank + self.qk_rope_head_dim, - # bias=False, - # ) - # self.kv_a_layernorm = DeepseekV2RMSNorm(self.kv_lora_rank) - - # W^UK & W^UV ~ [512, 128*(128+128)] - self.kv_b_proj = nn.Linear( - self.kv_lora_rank, - self.num_heads + init_Normal = paddle.nn.initializer.Normal() + init_Normal(self.q_b_proj.weight) + self.kv_b_proj = paddle.nn.Linear( + in_features=self.kv_lora_rank, + out_features=self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), - bias=False, + bias_attr=False, ) - torch.nn.init.normal_(self.kv_b_proj.weight) - - # W^O ~ [128*128, 5120] - self.o_proj = nn.Linear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, + init_Normal = paddle.nn.initializer.Normal() + init_Normal(self.kv_b_proj.weight) + self.o_proj = paddle.nn.Linear( + in_features=self.num_heads * self.v_head_dim, + out_features=self.hidden_size, + bias_attr=False, ) - torch.nn.init.normal_(self.o_proj.weight) - - self.softmax_scale = self.q_head_dim ** (-0.5) + init_Normal = paddle.nn.initializer.Normal() + init_Normal(self.o_proj.weight) + self.softmax_scale = self.q_head_dim**-0.5 def run_decode( self, - hidden_states: torch.Tensor, - compressed_kv_normed_cache: torch.Tensor, - k_pe_cache: torch.Tensor, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + hidden_states: paddle.Tensor, + compressed_kv_normed_cache: paddle.Tensor, + k_pe_cache: paddle.Tensor, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + bsz, q_len, _ = tuple(hidden_states.shape) if q_len != 1: raise ValueError( - f"Only support decode, but got hidden_states[{hidden_states.size()}]" + f"Only support decode, but got hidden_states[{tuple(hidden_states.shape)}]" ) - - ckv_bsz, kv_len, ckv_dim = compressed_kv_normed_cache.size() + ckv_bsz, kv_len, ckv_dim = tuple(compressed_kv_normed_cache.shape) if ckv_bsz != bsz or ckv_dim != self.kv_lora_rank: raise ValueError( - f"Unexpected shape: compressed_kv_normed_cache[{compressed_kv_normed_cache.size()}]" + f"Unexpected shape: compressed_kv_normed_cache[{tuple(compressed_kv_normed_cache.shape)}]" ) - - kpe_bsz, kpe_len, kpe_dim = k_pe_cache.size() + kpe_bsz, kpe_len, kpe_dim = tuple(k_pe_cache.shape) if kpe_bsz != bsz or kpe_dim != self.qk_rope_head_dim or kv_len != kpe_len: - raise ValueError(f"Unexpected shape: k_pe_cache[{k_pe_cache.size()}]") - + raise ValueError(f"Unexpected shape: k_pe_cache[{tuple(k_pe_cache.shape)}]") q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - # q_nope ~ [bsz, q_len, 128] q_pe ~ [bsz, q_len, 64] - q_nope, q_pe = torch.split( - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose( + perm=dim2perm( + q.view(bsz, q_len, self.num_heads, self.q_head_dim).ndim, 1, 2 + ) + ) + q_nope, q_pe = paddle_split( + x=q, num_or_sections=[self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1 + ) + k_pe = k_pe_cache.view(bsz, kv_len, 1, self.qk_rope_head_dim).transpose( + perm=dim2perm( + k_pe_cache.view(bsz, kv_len, 1, self.qk_rope_head_dim).ndim, 1, 2 + ) ) - - k_pe = k_pe_cache.view(bsz, kv_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(compressed_kv_normed_cache) .view(bsz, kv_len, self.num_heads, self.qk_nope_head_dim 
+ self.v_head_dim) - .transpose(1, 2) + .transpose( + perm=dim2perm( + self.kv_b_proj(compressed_kv_normed_cache) + .view( + bsz, + kv_len, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + .ndim, + 1, + 2, + ) + ) ) - # k_nope ~ [bsz, num_heads, kv_len, 128] value_states ~ [bsz, num_heads, kv_len, 128] - k_nope, value_states = torch.split( - kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + k_nope, value_states = paddle_split( + x=kv, num_or_sections=[self.qk_nope_head_dim, self.v_head_dim], axis=-1 ) - if k_nope.size() != (bsz, self.num_heads, kv_len, self.qk_nope_head_dim): - raise ValueError(f"k_nope[{k_nope.size()}]") - if value_states.size() != (bsz, self.num_heads, kv_len, self.v_head_dim): - raise ValueError(f"value_states[{value_states.size()}]") - + if tuple(k_nope.shape) != (bsz, self.num_heads, kv_len, self.qk_nope_head_dim): + raise ValueError(f"k_nope[{tuple(k_nope.shape)}]") + if tuple(value_states.shape) != (bsz, self.num_heads, kv_len, self.v_head_dim): + raise ValueError(f"value_states[{tuple(value_states.shape)}]") freqs_cis = precompute_freqs_cis( self.qk_rope_head_dim, kv_len, self.rope_theta, use_scaled=False - ).to(q_pe.device) + ).to(q_pe.place) q_pe, k_pe = apply_rotary_emb( - q_pe.transpose(1, 2).repeat(1, kv_len, 1, 1), - k_pe.transpose(1, 2), + q_pe.transpose(perm=dim2perm(q_pe.ndim, 1, 2)).tile( + repeat_times=[1, kv_len, 1, 1] + ), + k_pe.transpose(perm=dim2perm(k_pe.ndim, 1, 2)), freqs_cis, ) - q_pe = q_pe[:, -1:, :, :].transpose(1, 2) - k_pe = k_pe.transpose(1, 2) - - # Concat q_nope and q_pe to produce a new Q tensor with head_dim = 192 - query_states = q.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + q_pe = q_pe[:, -1:, :, :].transpose( + perm=dim2perm(q_pe[:, -1:, :, :].ndim, 1, 2) + ) + k_pe = k_pe.transpose(perm=dim2perm(k_pe.ndim, 1, 2)) + query_states = paddle.empty( + shape=[bsz, self.num_heads, q_len, self.q_head_dim], dtype=q.dtype + ) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - - # Concat k_nope and k_pe to produce a new K tensor with head_dim = 192 - key_states = k_pe.new_empty(bsz, self.num_heads, kv_len, self.q_head_dim) + key_states = paddle.empty( + shape=[bsz, self.num_heads, kv_len, self.q_head_dim], dtype=k_pe.dtype + ) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - attn_weights = ( - torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + paddle.matmul( + x=query_states, + y=key_states.transpose(perm=dim2perm(key_states.ndim, 2, 3)), + ) + * self.softmax_scale ) - - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 + attn_weights = paddle.nn.functional.softmax( + x=attn_weights, axis=-1, dtype="float32" ).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2).reshape( - bsz, q_len, self.num_heads * self.v_head_dim - ) - + attn_output = paddle.matmul(x=attn_weights, y=value_states) + attn_output = attn_output.transpose( + perm=dim2perm(attn_output.ndim, 1, 2) + ).reshape(bsz, q_len, self.num_heads * self.v_head_dim) output = self.o_proj(attn_output) - return output -class DeepseekV2AttentionMatAbsorbDecode(nn.Module): +class DeepseekV2AttentionMatAbsorbDecode(paddle.nn.Layer): def __init__(self, mla_vanilla: DeepseekV2AttentionVanilla): super().__init__() - - self.hidden_size = mla_vanilla.hidden_size # 5120 - self.num_heads = mla_vanilla.num_heads # 128 - - 
self.q_lora_rank = mla_vanilla.q_lora_rank # 1536 - self.qk_rope_head_dim = mla_vanilla.qk_rope_head_dim # 64 - self.kv_lora_rank = mla_vanilla.kv_lora_rank # 512 - self.v_head_dim = mla_vanilla.v_head_dim # 128 - self.qk_nope_head_dim = mla_vanilla.qk_nope_head_dim # 128 - self.q_head_dim = ( - mla_vanilla.q_head_dim - ) # qk_nope_head_dim + qk_rope_head_dim # 128+64=192 - + self.hidden_size = mla_vanilla.hidden_size + self.num_heads = mla_vanilla.num_heads + self.q_lora_rank = mla_vanilla.q_lora_rank + self.qk_rope_head_dim = mla_vanilla.qk_rope_head_dim + self.kv_lora_rank = mla_vanilla.kv_lora_rank + self.v_head_dim = mla_vanilla.v_head_dim + self.qk_nope_head_dim = mla_vanilla.qk_nope_head_dim + self.q_head_dim = mla_vanilla.q_head_dim self.softmax_scale = mla_vanilla.softmax_scale - self.rope_theta = mla_vanilla.rope_theta - - # W^DQ ~ [5120, 1536] - self.W_DQ = mla_vanilla.q_a_proj.weight.transpose(0, 1) - + self.W_DQ = mla_vanilla.q_a_proj.weight.transpose( + perm=dim2perm(mla_vanilla.q_a_proj.weight.ndim, 0, 1) + ) self.q_a_layernorm = DeepseekV2RMSNorm(self.q_lora_rank) - - # W_UQ ~ [1536, 128, 128] - W_UQ, W_QR = torch.split( - mla_vanilla.q_b_proj.weight.t().view( + W_UQ, W_QR = paddle_split( + x=mla_vanilla.q_b_proj.weight.t().view( self.q_lora_rank, self.num_heads, self.q_head_dim ), - [self.qk_nope_head_dim, self.qk_rope_head_dim], - -1, + num_or_sections=[self.qk_nope_head_dim, self.qk_rope_head_dim], + axis=-1, ) - # W_UQ ~ [1536, 128*64] self.W_QR = W_QR.reshape( self.q_lora_rank, self.num_heads * self.qk_rope_head_dim ) - - # W_UK ~ [512, 128, 128] W_UV ~ [512, 128, 128] - W_UK, W_UV = torch.split( - mla_vanilla.kv_b_proj.weight.t().view( + W_UK, W_UV = paddle_split( + x=mla_vanilla.kv_b_proj.weight.t().view( self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim + self.v_head_dim, ), - [self.qk_nope_head_dim, self.v_head_dim], - -1, + num_or_sections=[self.qk_nope_head_dim, self.v_head_dim], + axis=-1, + ) + self.W_UQ_UK = paddle.einsum("q n d, l n d -> q n l", W_UQ, W_UK).flatten( + start_axis=1 ) - - # Now we merge W_UQ and W_UK (absorb W_UK into W_UQ) - # q~q_lora_rank n~num_heads d~qk_nope_head_dim l~kv_lora_rank - self.W_UQ_UK = torch.einsum("q n d, l n d -> q n l", W_UQ, W_UK).flatten( - start_dim=1 - ) # [1536, 65536] - W_O = mla_vanilla.o_proj.weight.view( self.hidden_size, self.num_heads, self.v_head_dim ) - - # Merge W_UV and W_O (absorb W_UV into W_O) - # l~kv_lora_rank n~num_heads d~v_head_dim h~hidden_size - self.W_UV_O = torch.einsum("l n d, h n d -> n l h", W_UV, W_O).flatten( - start_dim=0, end_dim=1 - ) # [65536, 5120] + self.W_UV_O = paddle.einsum("l n d, h n d -> n l h", W_UV, W_O).flatten( + start_axis=0, stop_axis=1 + ) def run_proof_of_concept( self, - hidden_states: torch.Tensor, - compressed_kv_normed_cache: torch.Tensor, - k_pe_cache: torch.Tensor, + hidden_states: paddle.Tensor, + compressed_kv_normed_cache: paddle.Tensor, + k_pe_cache: paddle.Tensor, use_flashinfer_kernel: bool, convert_float16: bool, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - c_Q = torch.matmul(hidden_states, self.W_DQ) - # c_Q ~ [bsz, q_lora_rank:1536] + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + c_Q = paddle.matmul(x=hidden_states, y=self.W_DQ) c_Q = self.q_a_layernorm(c_Q) - - q_pe = torch.matmul( - c_Q, - self.W_QR, # c_Q ~ [bsz, q_lora_rank~1536] - ) # W_QR ~ [1536, 128*64] - # q_pe ~ [bsz, 128, 64] + q_pe = paddle.matmul(x=c_Q, y=self.W_QR) q_pe = q_pe.reshape(bsz, self.num_heads, 
self.qk_rope_head_dim) - - q_nope = torch.matmul(c_Q, self.W_UQ_UK) # W_UQ_UK~[1536, 128*512] - # q_nope ~ [bsz, 128, 512] + q_nope = paddle.matmul(x=c_Q, y=self.W_UQ_UK) q_nope = q_nope.reshape(bsz, self.num_heads, self.kv_lora_rank) - - q_kv_dtype = torch.float16 + q_kv_dtype = "float16" if convert_float16: q_nope = q_nope.to(q_kv_dtype) q_pe = q_pe.to(q_kv_dtype) compressed_kv_normed_cache = compressed_kv_normed_cache.to(q_kv_dtype) k_pe_cache = k_pe_cache.to(q_kv_dtype) - if not use_flashinfer_kernel: freqs_cis = precompute_freqs_cis( self.qk_rope_head_dim, kv_len, self.rope_theta, use_scaled=False - ).to(k_pe_cache.device) + ).to(k_pe_cache.place) q_pe, k_pe_cache = apply_rotary_emb( - q_pe.unsqueeze(1).repeat(1, kv_len, 1, 1), - k_pe_cache.unsqueeze(2), + q_pe.unsqueeze(axis=1).tile(repeat_times=[1, kv_len, 1, 1]), + k_pe_cache.unsqueeze(axis=2), freqs_cis, ) - q_pe = q_pe[:, -1:, :, :].squeeze(1) - k_pe_cache = k_pe_cache.squeeze(2) - - # attn_weights_pe ~ [bsz, 128, kv_len] - attn_weights_pe = torch.matmul( - q_pe, # [bsz, num_heads, qk_rope_head_dim] - k_pe_cache.transpose( - 1, 2 - ), # [bsz, kv_len, 64] view(bsz, kv_len, self.qk_rope_head_dim) + q_pe = q_pe[:, -1:, :, :].squeeze(axis=1) + k_pe_cache = k_pe_cache.squeeze(axis=2) + attn_weights_pe = paddle.matmul( + x=q_pe, y=k_pe_cache.transpose(perm=dim2perm(k_pe_cache.ndim, 1, 2)) ) - # attn_weights_nope ~ [bsz, 128, kv_len] - attn_weights_nope = torch.matmul( - q_nope, # [bsz, 128, 512] - compressed_kv_normed_cache.transpose(1, 2), # view(bsz, kv_len, 512) + attn_weights_nope = paddle.matmul( + x=q_nope, + y=compressed_kv_normed_cache.transpose( + perm=dim2perm(compressed_kv_normed_cache.ndim, 1, 2) + ), ) - attn_weights = (attn_weights_pe + attn_weights_nope) * self.softmax_scale - - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 + attn_weights = paddle.nn.functional.softmax( + x=attn_weights, axis=-1, dtype="float32" ).to(q_nope.dtype) - - # attn_output ~ {attn_output.shape}") # [bsz, 128, 512] - attn_output = torch.matmul( - attn_weights, # [bsz, 128, kv_len] - compressed_kv_normed_cache, # [bsz, kv_len, 512] - ) - + attn_output = paddle.matmul(x=attn_weights, y=compressed_kv_normed_cache) else: print("Now use MLA decode kernel!\n") if kv_len % page_size != 0: @@ -315,31 +278,35 @@ def run_proof_of_concept( ) freqs_cis = precompute_freqs_cis( self.qk_rope_head_dim, kv_len, self.rope_theta, use_scaled=False - ).to(k_pe_cache.device) + ).to(k_pe_cache.place) q_pe, k_pe_cache = apply_rotary_emb( - q_pe.unsqueeze(1).repeat(1, kv_len, 1, 1), - k_pe_cache.unsqueeze(2), + q_pe.unsqueeze(axis=1).tile(repeat_times=[1, kv_len, 1, 1]), + k_pe_cache.unsqueeze(axis=2), freqs_cis, ) - q_pe = q_pe[:, -1:, :, :].squeeze(1).contiguous() - k_pe_cache = k_pe_cache.squeeze(2) + q_pe = q_pe[:, -1:, :, :].squeeze(axis=1).contiguous() + k_pe_cache = k_pe_cache.squeeze(axis=2) num_pages_per_seq = kv_len // page_size total_num_pages = num_pages_per_seq * bsz - - kv_indptr = torch.arange(0, bsz + 1).to(dev_id).int() * num_pages_per_seq - kv_indices = torch.arange(0, total_num_pages).to(dev_id).int() - kv_last_page_len = torch.full((bsz,), page_size, dtype=torch.int32).to( - dev_id + kv_indptr = ( + paddle.arange(start=0, end=bsz + 1).to(dev_id).astype(dtype="int32") + * num_pages_per_seq ) - + kv_indices = ( + paddle.arange(start=0, end=total_num_pages) + .to(dev_id) + .astype(dtype="int32") + ) + kv_last_page_len = paddle.full( + shape=(bsz,), fill_value=page_size, dtype="int32" + ).to(dev_id) paged_ckv_cache = 
compressed_kv_normed_cache.reshape( total_num_pages, page_size, self.kv_lora_rank ) paged_kpe_cache = k_pe_cache.reshape( total_num_pages, page_size, self.qk_rope_head_dim ) - - workspace_buffer = torch.empty(64 * 1024 * 1024, dtype=torch.int8).to( + workspace_buffer = paddle.empty(shape=64 * 1024 * 1024, dtype="int8").to( dev_id ) wrapper = flashinfer.BatchDecodeMlaWithPagedKVCacheWrapper( @@ -362,130 +329,120 @@ def run_proof_of_concept( data_type=q_kv_dtype, q_data_type=q_kv_dtype, ) - attn_output = wrapper.run(q_nope, q_pe, paged_ckv_cache, paged_kpe_cache) - - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): + s = paddle.device.Stream() + s.wait_stream(paddle.device.current_stream()) + with paddle.device.stream_guard(stream=s): for _ in range(3): o, lse = wrapper.run( q_nope, q_pe, paged_ckv_cache, paged_kpe_cache, return_lse=True ) - torch.cuda.current_stream().wait_stream(s) - - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): + paddle.device.current_stream().wait_stream(s) +>>>>>> g = torch.cuda.CUDAGraph() +>>>>>> with torch.cuda.graph(g): attn_output = wrapper.run( q_nope, q_pe, paged_ckv_cache, paged_kpe_cache ) g.replay() - - # output ~ [bsz, 5120] - output = torch.matmul( - attn_output.to(self.W_UV_O.dtype).reshape( + output = paddle.matmul( + x=attn_output.to(self.W_UV_O.dtype).reshape( bsz, self.num_heads * self.kv_lora_rank ), - self.W_UV_O, - ) # W_UV_O ~ [65536, 5120] - + y=self.W_UV_O, + ) return output if __name__ == "__main__": dev_id = 0 - - torch.manual_seed(666) - torch.set_grad_enabled(False) - - mla_vanilla = DeepseekV2AttentionVanilla().cuda(device=dev_id) - + paddle.seed(seed=666) + paddle.set_grad_enabled(mode=False) + mla_vanilla = DeepseekV2AttentionVanilla().cuda(device_id=device2int(dev_id)) bsz = 6 kv_len = 640 page_size = 16 - - hidden_states = torch.randn([bsz, 1, mla_vanilla.hidden_size]).to(dev_id) - compressed_kv_normed_cache = torch.randn( - [bsz, kv_len, mla_vanilla.kv_lora_rank] + hidden_states = paddle.randn(shape=[bsz, 1, mla_vanilla.hidden_size]).to(dev_id) + compressed_kv_normed_cache = paddle.randn( + shape=[bsz, kv_len, mla_vanilla.kv_lora_rank] ).to(dev_id) - k_pe_cache = torch.randn([bsz, kv_len, mla_vanilla.qk_rope_head_dim]).to(dev_id) - + k_pe_cache = paddle.randn(shape=[bsz, kv_len, mla_vanilla.qk_rope_head_dim]).to( + dev_id + ) output_vanilla = mla_vanilla.run_decode( hidden_states, compressed_kv_normed_cache, k_pe_cache ) - - mla_mat_absorb = DeepseekV2AttentionMatAbsorbDecode(mla_vanilla).cuda(device=dev_id) + mla_mat_absorb = DeepseekV2AttentionMatAbsorbDecode(mla_vanilla).cuda( + device_id=device2int(dev_id) + ) output_mat_absorbed_use_torch_f32 = mla_mat_absorb.run_proof_of_concept( - hidden_states.squeeze(1), + hidden_states.squeeze(axis=1), compressed_kv_normed_cache, k_pe_cache, use_flashinfer_kernel=False, convert_float16=False, ) output_mat_absorbed_use_torch_f16 = mla_mat_absorb.run_proof_of_concept( - hidden_states.squeeze(1), + hidden_states.squeeze(axis=1), compressed_kv_normed_cache, k_pe_cache, use_flashinfer_kernel=False, convert_float16=True, ) output_mat_absorbed_use_flashinfer = mla_mat_absorb.run_proof_of_concept( - hidden_states.squeeze(1), + hidden_states.squeeze(axis=1), compressed_kv_normed_cache, k_pe_cache, use_flashinfer_kernel=True, convert_float16=True, ) - - cos_use_torch_f32 = F.cosine_similarity( - output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f32.reshape(-1), dim=0 + cos_use_torch_f32 = paddle.nn.functional.cosine_similarity( + 
x1=output_vanilla.reshape(-1), + x2=output_mat_absorbed_use_torch_f32.reshape(-1), + axis=0, ) print(f"cos_use_torch_f32 = {cos_use_torch_f32}") assert cos_use_torch_f32 > 0.99 - wmape_use_torch_f32 = wmape( output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f32.reshape(-1) ) print(f"wmape_use_torch_f32 = {wmape_use_torch_f32}") assert wmape_use_torch_f32 < 0.02 - - mse_use_torch_f32 = F.mse_loss( - output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f32.reshape(-1) + mse_use_torch_f32 = paddle.nn.functional.mse_loss( + input=output_vanilla.reshape(-1), + label=output_mat_absorbed_use_torch_f32.reshape(-1), ) print(f"mse_use_torch_f32={mse_use_torch_f32}\n") - - cos_use_torch_f16 = F.cosine_similarity( - output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1), dim=0 + cos_use_torch_f16 = paddle.nn.functional.cosine_similarity( + x1=output_vanilla.reshape(-1), + x2=output_mat_absorbed_use_torch_f16.reshape(-1), + axis=0, ) print(f"cos_use_torch_f16 = {cos_use_torch_f16}") assert cos_use_torch_f16 > 0.99 - wmape_use_torch_f16 = wmape( output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1) ) print(f"wmape_use_torch_f16 = {wmape_use_torch_f16}") assert wmape_use_torch_f16 < 0.03 - - mse_use_torch_f16 = F.mse_loss( - output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1) + mse_use_torch_f16 = paddle.nn.functional.mse_loss( + input=output_vanilla.reshape(-1), + label=output_mat_absorbed_use_torch_f16.reshape(-1), ) print(f"mse_use_torch_f16 = {mse_use_torch_f16}\n") - - cos_use_flashinfer = F.cosine_similarity( - output_vanilla.reshape(-1), - output_mat_absorbed_use_flashinfer.reshape(-1), - dim=0, + cos_use_flashinfer = paddle.nn.functional.cosine_similarity( + x1=output_vanilla.reshape(-1), + x2=output_mat_absorbed_use_flashinfer.reshape(-1), + axis=0, ) print(f"cos_use_flashinfer = {cos_use_flashinfer}") assert cos_use_flashinfer > 0.99 - wmape_use_flashinfer = wmape( output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1) ) print(f"wmape_use_flashinfer = {wmape_use_flashinfer}") assert wmape_use_flashinfer < 0.02 - - mse_use_flashinfer = F.mse_loss( - output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1) + mse_use_flashinfer = paddle.nn.functional.mse_loss( + input=output_vanilla.reshape(-1), + label=output_mat_absorbed_use_flashinfer.reshape(-1), ) print(f"mse_use_flashinfer = {mse_use_flashinfer}") diff --git a/tests/test_mla_page.py b/tests/test_mla_page.py index b2734643b2..8c0fcbbdc9 100644 --- a/tests/test_mla_page.py +++ b/tests/test_mla_page.py @@ -1,8 +1,12 @@ +import sys + + import math from typing import List +import paddle import pytest -import torch +from flashinfer.paddle_utils import * import flashinfer @@ -11,7 +15,7 @@ def calculate_last_page_len(kv_len: List[int], page_size: int): - return [len % page_size if len % page_size != 0 else page_size for len in kv_len] + return [(len % page_size if len % page_size != 0 else page_size) for len in kv_len] kv_len_configs = [ @@ -28,33 +32,32 @@ def calculate_last_page_len(kv_len: List[int], page_size: int): @pytest.mark.parametrize("page_size", [1, 16, 64]) def test_append_mla_paged_kv_cache(kv_len, page_size): nnz_kv = sum(kv_len) - ckv_append = torch.randn(nnz_kv, CKV_DIM, dtype=torch.float16, device="cuda:0") - kpe_append = torch.randn(nnz_kv, KPE_DIM, dtype=torch.float16, device="cuda:0") - num_pages_per_req = torch.tensor( - [math.ceil(len / page_size) for len in kv_len], - dtype=torch.int32, - device="cuda:0", + 
ckv_append = paddle.randn(shape=[nnz_kv, CKV_DIM], dtype="float16") + kpe_append = paddle.randn(shape=[nnz_kv, KPE_DIM], dtype="float16") + num_pages_per_req = paddle.to_tensor( + data=[math.ceil(len / page_size) for len in kv_len], + dtype="int32", + place="gpu:0", ) - kv_append_length = torch.tensor(kv_len, dtype=torch.int32, device="cuda:0") - kv_append_indptr = torch.cat( - [torch.zeros(1).int().to(0), torch.cumsum(kv_append_length, dim=0)] - ).int() - + kv_append_length = paddle.to_tensor(data=kv_len, dtype="int32", place="gpu:0") + kv_append_indptr = paddle.concat( + x=[ + paddle.zeros(shape=[1]).astype(dtype="int32").to(0), + paddle.cumsum(x=kv_append_length, axis=0), + ] + ).astype(dtype="int32") max_num_pages = sum(num_pages_per_req) - ckv_cache = torch.zeros( - max_num_pages, page_size, CKV_DIM, dtype=torch.float16, device="cuda:0" - ) - kpe_cache = torch.zeros( - max_num_pages, page_size, KPE_DIM, dtype=torch.float16, device="cuda:0" - ) - kv_page_indptr = torch.cat( - [torch.zeros(1).int().to(0), torch.cumsum(num_pages_per_req, dim=0)] - ).int() - kv_page_indices = torch.arange( - sum(num_pages_per_req), dtype=torch.int32, device="cuda:0" - ) - kv_last_page_len = torch.tensor( - calculate_last_page_len(kv_len, page_size), dtype=torch.int32, device="cuda:0" + ckv_cache = paddle.zeros(shape=[max_num_pages, page_size, CKV_DIM], dtype="float16") + kpe_cache = paddle.zeros(shape=[max_num_pages, page_size, KPE_DIM], dtype="float16") + kv_page_indptr = paddle.concat( + x=[ + paddle.zeros(shape=[1]).astype(dtype="int32").to(0), + paddle.cumsum(x=num_pages_per_req, axis=0), + ] + ).astype(dtype="int32") + kv_page_indices = paddle.arange(dtype="int32", end=sum(num_pages_per_req)) + kv_last_page_len = paddle.to_tensor( + data=calculate_last_page_len(kv_len, page_size), dtype="int32", place="gpu:0" ) batch_indices, positions = flashinfer.get_batch_indices_positions( kv_append_indptr, @@ -72,38 +75,38 @@ def test_append_mla_paged_kv_cache(kv_len, page_size): kv_page_indptr, kv_last_page_len, ) - ckv_cache = ckv_cache.view(-1, CKV_DIM) kpe_cache = kpe_cache.view(-1, KPE_DIM) - acc_kv_len = 0 acc_padding_kv_len = 0 for i in range(len(kv_len)): - assert torch.all( - torch.isclose( - kpe_append[acc_kv_len : acc_kv_len + kv_len[i]], - kpe_cache[acc_padding_kv_len : acc_padding_kv_len + kv_len[i]], + assert paddle.all( + x=paddle.isclose( + x=kpe_append[acc_kv_len : acc_kv_len + kv_len[i]], + y=kpe_cache[acc_padding_kv_len : acc_padding_kv_len + kv_len[i]], ) ) - assert torch.all( - torch.isclose( - ckv_append[acc_kv_len : acc_kv_len + kv_len[i]], - ckv_cache[acc_padding_kv_len : acc_padding_kv_len + kv_len[i]], + assert paddle.all( + x=paddle.isclose( + x=ckv_append[acc_kv_len : acc_kv_len + kv_len[i]], + y=ckv_cache[acc_padding_kv_len : acc_padding_kv_len + kv_len[i]], ) ) assert bool( - torch.all( - ckv_cache[ - acc_padding_kv_len + kv_len[i] : acc_padding_kv_len + paddle.all( + x=ckv_cache[ + acc_padding_kv_len + + kv_len[i] : acc_padding_kv_len + num_pages_per_req[i] * page_size ] == 0 ) ) assert bool( - torch.all( - kpe_cache[ - acc_padding_kv_len + kv_len[i] : acc_padding_kv_len + paddle.all( + x=kpe_cache[ + acc_padding_kv_len + + kv_len[i] : acc_padding_kv_len + num_pages_per_req[i] * page_size ] == 0 diff --git a/tests/test_mm_fp4.py b/tests/test_mm_fp4.py index d53e3143ca..74a47116c9 100644 --- a/tests/test_mm_fp4.py +++ b/tests/test_mm_fp4.py @@ -1,24 +1,22 @@ +import sys + + +import paddle import pytest -import torch -import torch.nn.functional as F +from flashinfer.paddle_utils 
import * -from flashinfer import ( - SfLayout, - autotune, - mm_fp4, - nvfp4_quantize, -) +from flashinfer import SfLayout, autotune, mm_fp4, nvfp4_quantize @pytest.mark.parametrize("m", [1, 48, 128, 256, 512]) @pytest.mark.parametrize("n", [128, 256, 512]) @pytest.mark.parametrize("k", [128, 256, 512]) -@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("res_dtype", ["bfloat16", "float16"]) @pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"]) @pytest.mark.parametrize("use_128x4_sf_layout", [False, True]) @pytest.mark.parametrize("auto_tuning", [False, True]) def test_mm_fp4(m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning): - if backend == "trtllm" and res_dtype == torch.float16: + if backend == "trtllm" and res_dtype == "float16": print("Skipping test for trtllm fp4 with float16") return if not use_128x4_sf_layout and backend != "trtllm": @@ -27,29 +25,21 @@ def test_mm_fp4(m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning): if auto_tuning and backend == "cudnn": print("Skipping test for cudnn fp4 with auto_tuning=True") return - - input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) - mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) + input = paddle.randn(shape=[m, k], dtype="bfloat16") + mat2 = paddle.randn(shape=[n, k], dtype="bfloat16") a_sf_layout = SfLayout.layout_128x4 if use_128x4_sf_layout else SfLayout.layout_8x4 - - global_sf_input = (448 * 6) / input.float().abs().nan_to_num().max() - global_sf_mat2 = (448 * 6) / mat2.float().abs().nan_to_num().max() - + global_sf_input = 448 * 6 / input.astype(dtype="float32").abs().nan_to_num()._max() + global_sf_mat2 = 448 * 6 / mat2.astype(dtype="float32").abs().nan_to_num()._max() input_fp4, input_inv_s = nvfp4_quantize( input, global_sf_input, sfLayout=a_sf_layout, do_shuffle=False ) - - # for trtllm, we need to shuffle mat2 because we swap A, B. do_shuffle_b = backend == "trtllm" mat2_fp4, mat2_inv_s = nvfp4_quantize( mat2, global_sf_mat2, sfLayout=SfLayout.layout_128x4, do_shuffle=do_shuffle_b ) - - reference = torch.mm(input, mat2.T) - + reference = paddle.mm(input=input, mat2=mat2.T) alpha = 1.0 / (global_sf_input * global_sf_mat2) - res = torch.empty([m, n], device="cuda", dtype=res_dtype) - + res = paddle.empty(shape=[m, n], dtype=res_dtype) with autotune(auto_tuning): mm_fp4( input_fp4, @@ -62,8 +52,9 @@ def test_mm_fp4(m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning): use_8x4_sf_layout=not use_128x4_sf_layout, backend=backend, ) - - cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) + cos_sim = paddle.nn.functional.cosine_similarity( + x1=reference.reshape(-1), x2=res.reshape(-1), axis=0 + ) assert cos_sim > 0.97 diff --git a/tests/test_mnnvl_memory.py b/tests/test_mnnvl_memory.py index bbda852f06..d85def2679 100644 --- a/tests/test_mnnvl_memory.py +++ b/tests/test_mnnvl_memory.py @@ -1,22 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +import sys + + import socket +import paddle import pynvml import pytest -import torch +from flashinfer.paddle_utils import * from flashinfer.comm.mapping import Mapping from flashinfer.comm.mnnvl import MnnvlMemory, MpiComm @@ -32,7 +22,6 @@ class TestMnnvlMemory: @pytest.fixture(autouse=True) def setup(self): - # get num of task per node hostname = socket.gethostname() self.comm = MpiComm() self.world_size = self.comm.Get_size() @@ -44,11 +33,11 @@ def setup(self): assert uniform_ntasks, "Not all nodes has same ntasks_per_node" self.local_world_size = local_ntasks_per_node self.local_rank = self.rank % self.local_world_size - local_dev_count = torch.cuda.device_count() - assert self.local_world_size <= local_dev_count, ( - "ntasks_per_node should be less than local device count" - ) - torch.cuda.set_device(self.local_rank) + local_dev_count = paddle.device.cuda.device_count() + assert ( + self.local_world_size <= local_dev_count + ), "ntasks_per_node should be less than local device count" + paddle.device.set_device(device=device2str(self.local_rank)) MnnvlMemory.initialize() self.mapping = Mapping( self.world_size, self.rank, self.local_world_size, tp_size=self.world_size @@ -64,24 +53,21 @@ def align_memory(size: int): reason="Mnnvl memory is not supported on this platform", ) def test_mnnvl_memory(self): - # allocate un-aligned memory allocate0_size = 4 * 1024 * 1024 - 3 * 1024 mnnvl_memory0 = MnnvlMemory(self.mapping, allocate0_size) allocate0_size_aligned = TestMnnvlMemory.align_memory(allocate0_size) assert MnnvlMemory.current_mem_offset == allocate0_size_aligned - - tensor0 = mnnvl_memory0.as_torch_strided_tensor(torch.int32) + tensor0 = mnnvl_memory0.as_torch_strided_tensor("int32") numel_per_rank = allocate0_size // 4 - tensor0[(self.rank + 1) % self.world_size] = torch.arange( - start=self.rank, end=self.rank + numel_per_rank, device="cuda" + tensor0[(self.rank + 1) % self.world_size] = paddle.arange( + start=self.rank, end=self.rank + numel_per_rank ) self.comm.Barrier() for r in range(self.world_size): - torch.equal( - tensor0[(r + 1) % self.world_size], - torch.arange(start=r, end=r + numel_per_rank, device="cuda"), - ) - + paddle.equal_all( + x=tensor0[(r + 1) % self.world_size], + y=paddle.arange(start=r, end=r + numel_per_rank), + ).item() allocate1_size = 30 * 1024 * 1024 - 2 * 1024 mnnvl_memory1 = MnnvlMemory(self.mapping, allocate1_size) allocate1_size_aligned = TestMnnvlMemory.align_memory(allocate1_size) @@ -89,32 +75,25 @@ def test_mnnvl_memory(self): MnnvlMemory.current_mem_offset == allocate0_size_aligned + allocate1_size_aligned ) - tensor1 = mnnvl_memory1.as_torch_strided_tensor(torch.float32) + tensor1 = mnnvl_memory1.as_torch_strided_tensor("float32") numel_per_rank = allocate1_size // 4 - tensor1[(self.rank + 5) % self.world_size] = torch.arange( - start=self.rank, - end=self.rank + numel_per_rank, - dtype=torch.float32, - device="cuda", + tensor1[(self.rank + 5) % self.world_size] = paddle.arange( + start=self.rank, end=self.rank + numel_per_rank, dtype="float32" ) self.comm.Barrier() for r in range(self.world_size): - torch.equal( - tensor1[(r + 5) 
% self.world_size], - torch.arange( - start=r, end=r + numel_per_rank, dtype=torch.float32, device="cuda" - ), - ) + paddle.equal_all( + x=tensor1[(r + 5) % self.world_size], + y=paddle.arange(start=r, end=r + numel_per_rank, dtype="float32"), + ).item() self.comm.Barrier() del tensor0, mnnvl_memory0 self.comm.Barrier() - large_allocation2_size = 768 * 1024 * 1024 large_mnnvl_memory2 = MnnvlMemory(self.mapping, large_allocation2_size) allocate2_size_aligned = TestMnnvlMemory.align_memory(large_allocation2_size) assert MnnvlMemory.current_mem_offset == allocate2_size_aligned - assert large_mnnvl_memory2.rank_stride == (1 << 30) - + assert large_mnnvl_memory2.rank_stride == 1 << 30 del tensor1 @pytest.mark.skipif( @@ -122,70 +101,57 @@ def test_mnnvl_memory(self): reason="Mnnvl memory is not supported on this platform", ) def test_moe_alltoall_multi_rank_single_gpu(self): - torch.cuda.set_device(self.rank) + paddle.device.set_device(device=device2str(self.rank)) max_world_size = 8 - assert self.world_size <= max_world_size, ( - f"should run with world_size at most {max_world_size}" + assert ( + self.world_size <= max_world_size + ), f"should run with world_size at most {max_world_size}" + paddle.seed(seed=self.world_size) + input_entry_per_rank, vector_dim, dtype = 128, 256, "float16" + input_tensor = paddle.randn( + shape=[input_entry_per_rank * self.world_size, vector_dim], dtype=dtype ) - torch.manual_seed(self.world_size) - input_entry_per_rank, vector_dim, dtype = 128, 256, torch.float16 - - # Create a random input tensor - input_tensor = torch.randn( - input_entry_per_rank * self.world_size, - vector_dim, - dtype=dtype, - device=torch.device("cuda"), + ref_output_tensor = paddle.zeros( + shape=[input_entry_per_rank * self.world_size, vector_dim], dtype=dtype ) - ref_output_tensor = torch.zeros( - input_entry_per_rank * self.world_size, - vector_dim, - dtype=dtype, - device=torch.device("cuda"), + target_rank_ids = paddle.randint( + low=0, + high=self.world_size, + shape=(input_entry_per_rank * self.world_size,), + dtype="int32", ) - target_rank_ids = torch.randint( - 0, - self.world_size, - (input_entry_per_rank * self.world_size,), - dtype=torch.int32, - device=torch.device("cuda"), + input_tensors_all_ranks = list( + paddle_split(x=input_tensor, num_or_sections=input_entry_per_rank) ) - - input_tensors_all_ranks = list(torch.split(input_tensor, input_entry_per_rank)) target_rank_ids_all_ranks = list( - torch.split(target_rank_ids, input_entry_per_rank) + paddle_split(x=target_rank_ids, num_or_sections=input_entry_per_rank) ) - send_ids_all_ranks = [] send_counts_all_ranks = [] send_cumsum_all_ranks = [] send_start_end_all_ranks = [] - - # each rank do its own local compute to get how to send data to other ranks. 
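# A minimal sketch of the send-side bookkeeping that the loop below performs on
# each rank, assuming paddle is importable; the helper name and example values
# are illustrative, not part of FlashInfer's API. Entries are ordered by
# destination rank, the per-destination counts are padded so every rank appears
# at least once (then the padding is subtracted), and the cumulative sum gives
# the send offsets.
import paddle

def compute_send_layout(local_target_rank_ids, world_size):
    targets = local_target_rank_ids.astype("int32")
    # Order local entries by destination rank.
    send_ids = paddle.argsort(targets).astype("int32")
    sorted_targets = paddle.sort(targets)
    # Pad with one entry per destination so unique() reports every rank,
    # then remove the padding from the counts.
    padded = paddle.concat(
        [sorted_targets, paddle.arange(end=world_size, dtype="int32")]
    )
    _, counts = paddle.unique(padded, return_counts=True)
    send_counts = counts.astype("int32") - 1
    send_cumsum = paddle.cumsum(send_counts, axis=0).astype("int32")
    return send_ids, send_counts, send_cumsum

# Example: targets [2, 0, 2, 3, 0, 2] with world_size=4 gives
# send_counts [2, 0, 3, 1] and send_cumsum [2, 2, 5, 6].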
for rank in range(self.world_size): send_start_end = [] local_target_rank_ids = target_rank_ids_all_ranks[rank] - sorted_local_target_rank_ids, local_send_id = torch.sort( - local_target_rank_ids - ) - local_send_id = local_send_id.to(torch.int32) - padded_sorted_local_target_rank_ids = torch.cat( - ( + sorted_local_target_rank_ids, local_send_id = paddle.sort( + x=local_target_rank_ids + ), paddle.argsort(x=local_target_rank_ids) + local_send_id = local_send_id.to("int32") + padded_sorted_local_target_rank_ids = paddle.concat( + x=( sorted_local_target_rank_ids, - torch.arange( - self.world_size, dtype=torch.int32, device=torch.device("cuda") - ), + paddle.arange(dtype="int32", end=self.world_size), ) ) - unique_target_rank_ids, local_send_counts = torch.unique( - padded_sorted_local_target_rank_ids, return_counts=True - ) - local_send_counts = local_send_counts.to(torch.int32) - assert unique_target_rank_ids.numel() == self.world_size, ( - "unique_target_rank_ids must be equal to world_size" + unique_target_rank_ids, local_send_counts = paddle.unique( + x=padded_sorted_local_target_rank_ids, return_counts=True ) - local_send_counts -= 1 # remove padding - local_send_cumsum = torch.cumsum(local_send_counts, dim=0).to(torch.int32) + local_send_counts = local_send_counts.to("int32") + assert ( + unique_target_rank_ids.size == self.world_size + ), "unique_target_rank_ids must be equal to world_size" + local_send_counts -= 1 + local_send_cumsum = paddle.cumsum(x=local_send_counts, axis=0).to("int32") send_ids_all_ranks.append(local_send_id) send_counts_all_ranks.append(local_send_counts) send_cumsum_all_ranks.append(local_send_cumsum) @@ -198,21 +164,14 @@ def test_moe_alltoall_multi_rank_single_gpu(self): ) ) send_start_end_all_ranks.append(send_start_end) - recv_ids_all_ranks = [] recv_cumsum_all_ranks = [] - ref_output_tensors_all_ranks = [] - total_recv_all_ranks_cpu = [] output_indice_offset = 0 - output_start_current_rank = 0 - # each rank do compute based on other ranks' send counts to get how to receive data from other ranks. 
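# A minimal sketch of the receive-side bookkeeping derived in the loop below,
# assuming the per-rank send counts computed above are available on every rank;
# the helper name is illustrative. Rank r receives from rank s exactly what
# rank s sends to rank r, so the receive counts are the r-th column of the
# send-count matrix, and their cumulative sum gives the receive offsets.
import paddle

def compute_recv_layout(send_counts_all_ranks, rank):
    # send_counts_all_ranks[s][r] == number of rows rank s sends to rank r.
    recv_counts = paddle.stack(
        [counts[rank] for counts in send_counts_all_ranks]
    ).astype("int32")
    recv_cumsum = paddle.cumsum(recv_counts, axis=0).astype("int32")
    total_recv = int(recv_cumsum[-1])
    recv_ids = paddle.arange(end=total_recv, dtype="int32")
    return recv_counts, recv_cumsum, recv_ids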
for rank in range(self.world_size): - local_recv_counts = torch.zeros( - self.world_size, dtype=torch.int32, device=torch.device("cuda") - ) + local_recv_counts = paddle.zeros(shape=self.world_size, dtype="int32") for other_rank in range(self.world_size): local_recv_counts[other_rank] = send_counts_all_ranks[other_rank][rank] local_recv_count_pair = local_recv_counts[other_rank].cpu().item() @@ -225,7 +184,7 @@ def test_moe_alltoall_multi_rank_single_gpu(self): ] ] output_indice_offset += local_recv_count_pair - local_recv_cumsum = torch.cumsum(local_recv_counts, dim=0).to(torch.int32) + local_recv_cumsum = paddle.cumsum(x=local_recv_counts, axis=0).to("int32") recv_cumsum_all_ranks.append(local_recv_cumsum) total_recv_count = local_recv_cumsum[-1].cpu() total_recv_all_ranks_cpu.append(total_recv_count) @@ -236,11 +195,8 @@ def test_moe_alltoall_multi_rank_single_gpu(self): ] ) output_start_current_rank += total_recv_count - local_recv_ids = torch.arange( - total_recv_count, dtype=torch.int32, device=torch.device("cuda") - ) + local_recv_ids = paddle.arange(dtype="int32", end=total_recv_count) recv_ids_all_ranks.append(local_recv_ids) - alltoall_info = MoEAlltoallInfo( None, send_cumsum_all_ranks[self.rank], @@ -248,13 +204,10 @@ def test_moe_alltoall_multi_rank_single_gpu(self): recv_cumsum_all_ranks[self.rank], recv_ids_all_ranks[self.rank], None, - ref_output_tensors_all_ranks[self.rank].shape[0], + tuple(ref_output_tensors_all_ranks[self.rank].shape)[0], ) - alltoall_workspace = MnnvlMoe.get_moe_workspaces(self.mapping) - self.comm.Barrier() - output = MnnvlMoe.mnnvl_moe_alltoallv( input_tensors_all_ranks[self.rank], alltoall_info, @@ -262,9 +215,7 @@ def test_moe_alltoall_multi_rank_single_gpu(self): self.rank, self.world_size, ) - self.comm.Barrier() - - torch.testing.assert_close( - output, ref_output_tensors_all_ranks[self.rank], atol=1e-5, rtol=1e-5 - ) + assert paddle.allclose( + x=output, y=ref_output_tensors_all_ranks[self.rank], atol=1e-05, rtol=1e-05 + ).item(), "" diff --git a/tests/test_non_contiguous_decode.py b/tests/test_non_contiguous_decode.py index 8ffd4d6176..fc87d35d60 100644 --- a/tests/test_non_contiguous_decode.py +++ b/tests/test_non_contiguous_decode.py @@ -1,6 +1,11 @@ +import sys + + +import paddle import pytest -import torch -from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules +from jit_utils import (gen_decode_attention_modules, + gen_prefill_attention_modules) +from flashinfer.paddle_utils import * import flashinfer @@ -9,21 +14,10 @@ def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + ["float16"], ["float16"], [64, 128, 256], [0], [False], [False] ) + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions + ["float16"], ["float16"], [64, 128, 256], [0], [False], [False], [False] ), verbose=False, ) @@ -37,12 +31,7 @@ def warmup_jit(): @pytest.mark.parametrize("num_qo_heads", [4, 8]) @pytest.mark.parametrize("head_dim", [64, 128, 256]) def test_batch_paged_decode_packed_input( - batch_size, - page_size, - seq_len, - num_kv_heads, - num_qo_heads, - head_dim, + batch_size, page_size, seq_len, num_kv_heads, num_qo_heads, 
head_dim ): if num_qo_heads % num_kv_heads != 0: pytest.skip("num_qo_heads must be a multiple of num_kv_heads") @@ -50,28 +39,23 @@ def test_batch_paged_decode_packed_input( num_pages_per_req = (seq_len + page_size - 1) // page_size num_pages = batch_size * num_pages_per_req last_page_len = (seq_len - 1) % page_size + 1 - k_cache = torch.randn( - size=(num_pages, page_size, num_kv_heads, head_dim), - dtype=torch.float16, - device="cuda:0", - ) - v_cache = torch.randn_like(k_cache) - paged_kv_cache = (k_cache, v_cache) - workspace_buffer = torch.empty( - (256 * 1024 * 1024,), dtype=torch.uint8, device="cuda:0" + k_cache = paddle.randn( + shape=(num_pages, page_size, num_kv_heads, head_dim), dtype="float16" ) - paged_kv_indptr = torch.tensor( - [i * num_pages_per_req for i in range(batch_size + 1)], - dtype=torch.int32, - device="cuda:0", + v_cache = paddle.randn(shape=k_cache.shape, dtype=k_cache.dtype) + paged_kv_cache = k_cache, v_cache + workspace_buffer = paddle.empty(shape=(256 * 1024 * 1024,), dtype="uint8") + paged_kv_indptr = paddle.to_tensor( + data=[(i * num_pages_per_req) for i in range(batch_size + 1)], + dtype="int32", + place="gpu:0", ) - paged_kv_indices = torch.tensor( - list(range(num_pages)), dtype=torch.int32, device="cuda:0" + paged_kv_indices = paddle.to_tensor( + data=list(range(num_pages)), dtype="int32", place="gpu:0" ) - paged_kv_last_page_len = torch.tensor( - [last_page_len for _ in range(batch_size)], dtype=torch.int32, device="cuda:0" + paged_kv_last_page_len = paddle.to_tensor( + data=[last_page_len for _ in range(batch_size)], dtype="int32", place="gpu:0" ) - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer) wrapper.plan( indptr=paged_kv_indptr, @@ -82,11 +66,8 @@ def test_batch_paged_decode_packed_input( head_dim=head_dim, page_size=page_size, ) - - qkv_packed = torch.randn( - size=(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim), - dtype=torch.float16, - device="cuda:0", + qkv_packed = paddle.randn( + shape=(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim), dtype="float16" ) qkv_split_idx = ( num_qo_heads * head_dim, @@ -97,4 +78,6 @@ def test_batch_paged_decode_packed_input( q = q.view(-1, num_qo_heads, head_dim) o_packed = wrapper.run(q, paged_kv_cache) o_contiguous = wrapper.run(q.contiguous(), paged_kv_cache) - torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3) + assert paddle.allclose( + x=o_packed, y=o_contiguous, rtol=0.001, atol=0.001 + ).item(), "" diff --git a/tests/test_non_contiguous_prefill.py b/tests/test_non_contiguous_prefill.py index b095bcf308..07ce6a5f56 100644 --- a/tests/test_non_contiguous_prefill.py +++ b/tests/test_non_contiguous_prefill.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,9 +19,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch from jit_utils import gen_prefill_attention_modules import flashinfer @@ -25,13 +29,7 @@ def warmup_jit(): flashinfer.jit.build_jit_specs( gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions + ["float16"], ["float16"], [64, 128, 256], [0], [False], [False], [False] ), verbose=False, ) @@ -48,11 +46,8 @@ def test_single_prefill_packed_input( ): if num_qo_heads % num_kv_heads != 0: pytest.skip("num_qo_heads must be a multiple of num_kv_heads") - qkv_packed = torch.randn( - seq_len, - (num_qo_heads + 2 * num_kv_heads) * head_dim, - dtype=torch.float16, - device="cuda:0", + qkv_packed = paddle.randn( + shape=[seq_len, (num_qo_heads + 2 * num_kv_heads) * head_dim], dtype="float16" ) q = qkv_packed[:, : num_qo_heads * head_dim].reshape( seq_len, num_qo_heads, head_dim @@ -63,13 +58,13 @@ def test_single_prefill_packed_input( v = qkv_packed[:, (num_qo_heads + num_kv_heads) * head_dim :].reshape( seq_len, num_kv_heads, head_dim ) - o_packed = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=causal) o_contiguous = flashinfer.single_prefill_with_kv_cache( q.contiguous(), k.contiguous(), v.contiguous(), causal=causal ) - - torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3) + assert paddle.allclose( + x=o_packed, y=o_contiguous, rtol=0.001, atol=0.001 + ).item(), "" @pytest.mark.parametrize("batch_size", [1, 19, 99]) @@ -84,11 +79,8 @@ def test_batch_ragged_prefill_packed_input( if num_qo_heads % num_kv_heads != 0: pytest.skip("num_qo_heads must be a multiple of num_kv_heads") nnz = batch_size * seq_len - qkv_packed = torch.randn( - nnz, - (num_qo_heads + 2 * num_kv_heads) * head_dim, - dtype=torch.float16, - device="cuda:0", + qkv_packed = paddle.randn( + shape=[nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim], dtype="float16" ) q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) k = qkv_packed[ @@ -97,22 +89,22 @@ def test_batch_ragged_prefill_packed_input( v = qkv_packed[:, (num_qo_heads + num_kv_heads) * head_dim :].reshape( nnz, num_kv_heads, head_dim ) - qo_indptr = torch.tensor( - [i * seq_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + qo_indptr = paddle.to_tensor( + data=[(i * seq_len) for i in range(batch_size + 1)], + dtype="int32", + place="gpu:0", ) kv_indptr = qo_indptr - - workspace_buffer = torch.empty( - (256 * 1024 * 1024,), dtype=torch.uint8, device="cuda:0" - ) + workspace_buffer = paddle.empty(shape=(256 * 1024 * 1024,), dtype="uint8") wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer) wrapper.plan( qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, causal=causal ) o_packed = wrapper.run(q, k, v) o_contiguous = wrapper.run(q.contiguous(), k.contiguous(), v.contiguous()) - - torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3) + assert paddle.allclose( + x=o_packed, y=o_contiguous, rtol=0.001, atol=0.001 + ).item(), "" @pytest.mark.parametrize("batch_size", [1, 19, 99]) @@ -123,44 +115,35 @@ def test_batch_ragged_prefill_packed_input( @pytest.mark.parametrize("head_dim", [64, 128, 256]) @pytest.mark.parametrize("causal", [True, False]) def test_batch_paged_prefill_packed_input( - batch_size, - page_size, - seq_len, - num_kv_heads, - num_qo_heads, - head_dim, - causal, + batch_size, page_size, seq_len, 
num_kv_heads, num_qo_heads, head_dim, causal ): if num_qo_heads % num_kv_heads != 0: pytest.skip("num_qo_heads must be a multiple of num_kv_heads") - nnz = batch_size * seq_len num_pages_per_req = (seq_len + page_size - 1) // page_size num_pages = batch_size * num_pages_per_req last_page_len = (seq_len - 1) % page_size + 1 - k_cache = torch.randn( - size=(num_pages, page_size, num_kv_heads, head_dim), - dtype=torch.float16, - device="cuda:0", - ) - v_cache = torch.randn_like(k_cache) - paged_kv_cache = (k_cache, v_cache) - workspace_buffer = torch.empty( - (256 * 1024 * 1024,), dtype=torch.uint8, device="cuda:0" + k_cache = paddle.randn( + shape=(num_pages, page_size, num_kv_heads, head_dim), dtype="float16" ) - qo_indptr = torch.tensor( - [i * seq_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + v_cache = paddle.randn(shape=k_cache.shape, dtype=k_cache.dtype) + paged_kv_cache = k_cache, v_cache + workspace_buffer = paddle.empty(shape=(256 * 1024 * 1024,), dtype="uint8") + qo_indptr = paddle.to_tensor( + data=[(i * seq_len) for i in range(batch_size + 1)], + dtype="int32", + place="gpu:0", ) - paged_kv_indptr = torch.tensor( - [i * num_pages_per_req for i in range(batch_size + 1)], - dtype=torch.int32, - device="cuda:0", + paged_kv_indptr = paddle.to_tensor( + data=[(i * num_pages_per_req) for i in range(batch_size + 1)], + dtype="int32", + place="gpu:0", ) - paged_kv_indices = torch.tensor( - list(range(num_pages)), dtype=torch.int32, device="cuda:0" + paged_kv_indices = paddle.to_tensor( + data=list(range(num_pages)), dtype="int32", place="gpu:0" ) - paged_kv_last_page_len = torch.tensor( - [last_page_len for _ in range(batch_size)], dtype=torch.int32, device="cuda:0" + paged_kv_last_page_len = paddle.to_tensor( + data=[last_page_len for _ in range(batch_size)], dtype="int32", place="gpu:0" ) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer) wrapper.plan( @@ -174,11 +157,8 @@ def test_batch_paged_prefill_packed_input( page_size=page_size, causal=causal, ) - - qkv_packed = torch.randn( - size=(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim), - dtype=torch.float16, - device="cuda:0", + qkv_packed = paddle.randn( + shape=(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim), dtype="float16" ) qkv_split_idx = ( num_qo_heads * head_dim, @@ -186,11 +166,12 @@ def test_batch_paged_prefill_packed_input( num_kv_heads * head_dim, ) q, _, _ = qkv_packed.split(qkv_split_idx, dim=-1) - # pretend that we have already appended k/v to paged_kv table q = q.view(-1, num_qo_heads, head_dim) o_packed = wrapper.run(q, paged_kv_cache) o_contiguous = wrapper.run(q.contiguous(), paged_kv_cache) - torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=2e-3) + assert paddle.allclose( + x=o_packed, y=o_contiguous, rtol=0.001, atol=0.002 + ).item(), "" if __name__ == "__main__": diff --git a/tests/test_norm.py b/tests/test_norm.py index 83ad0f412c..d66d29695f 100644 --- a/tests/test_norm.py +++ b/tests/test_norm.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,125 +15,114 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch import flashinfer from flashinfer.utils import get_compute_capability -def llama_rms_norm(x, w, eps=1e-6): +def llama_rms_norm(x, w, eps=1e-06): orig_dtype = x.dtype - x = x.float() - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + eps) - x = x * w.float() + x = x.astype(dtype="float32") + variance = x.pow(y=2).mean(axis=-1, keepdim=True) + x = x * paddle.rsqrt(x=variance + eps) + x = x * w.astype(dtype="float32") x = x.to(orig_dtype) return x -def gemma_rms_norm(x, w, eps=1e-6): +def gemma_rms_norm(x, w, eps=1e-06): orig_dtype = x.dtype - x = x.float() - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + eps) - x = x * (1.0 + w.float()) + x = x.astype(dtype="float32") + variance = x.pow(y=2).mean(axis=-1, keepdim=True) + x = x * paddle.rsqrt(x=variance + eps) + x = x * (1.0 + w.astype(dtype="float32")) x = x.to(orig_dtype) return x -def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6): +def gemma_fused_add_rms_norm(x, residual, w, eps=1e-06): orig_dtype = x.dtype x = x + residual residual = x - x = x.float() - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + eps) - x = x * (1.0 + w.float()) + x = x.astype(dtype="float32") + variance = x.pow(y=2).mean(axis=-1, keepdim=True) + x = x * paddle.rsqrt(x=variance + eps) + x = x * (1.0 + w.astype(dtype="float32")) x = x.to(orig_dtype) return x, residual def fused_add_rms_norm(x, residual, weight, eps): orig_dtype = x.dtype - x = x.to(torch.float32) - x = x + residual.to(torch.float32) + x = x.to("float32") + x = x + residual.to("float32") residual = x.to(orig_dtype) - - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + eps) - x = (x * weight.float()).to(orig_dtype) + variance = x.pow(y=2).mean(axis=-1, keepdim=True) + x = x * paddle.rsqrt(x=variance + eps) + x = (x * weight.astype(dtype="float32")).to(orig_dtype) return x, residual @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) -@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("dtype", ["float16"]) @pytest.mark.parametrize("specify_out", [True, False]) @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("contiguous", [True, False]) def test_norm(batch_size, hidden_size, dtype, specify_out, enable_pdl, contiguous): if contiguous: - x = torch.randn(batch_size, hidden_size).to(0).to(dtype) + x = paddle.randn(shape=[batch_size, hidden_size]).to(0).to(dtype) else: - x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype) + x = paddle.randn(shape=[batch_size, hidden_size * 2]).to(dtype) x = x[:, :hidden_size] - - major, _ = get_compute_capability(x.device) + major, _ = get_compute_capability(x.place) if major < 9 and enable_pdl: pytest.skip("PDL is only available for Hopper and later GPUs") - - w = torch.randn(hidden_size).to(0).to(dtype) - + w = paddle.randn(shape=hidden_size).to(0).to(dtype) y_ref = llama_rms_norm(x, w) if specify_out: - y = torch.empty_like(x) + y = paddle.empty_like(x=x) flashinfer.norm.rmsnorm(x, w, out=y, enable_pdl=enable_pdl) else: y = flashinfer.norm.rmsnorm(x, w, enable_pdl=enable_pdl) - - torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=y_ref, y=y, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 
16384]) -@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("dtype", ["float16"]) @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("contiguous", [True, False]) def test_fused_add_rmsnorm(batch_size, hidden_size, dtype, enable_pdl, contiguous): - eps = 1e-6 - + eps = 1e-06 if contiguous: - x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") + x = paddle.randn(shape=[batch_size, hidden_size], dtype=dtype) else: - x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype) + x = paddle.randn(shape=[batch_size, hidden_size * 2]).to(dtype) x = x[:, :hidden_size] - - major, _ = get_compute_capability(x.device) + major, _ = get_compute_capability(x.place) if major < 9 and enable_pdl: pytest.skip("PDL is only available for Hopper and later GPUs") - - residual = torch.randn_like(x) - weight = torch.randn(hidden_size, dtype=dtype, device="cuda") - + residual = paddle.randn(shape=x.shape, dtype=x.dtype) + weight = paddle.randn(shape=hidden_size, dtype=dtype) x_native, residual_native = fused_add_rms_norm( x.clone(), residual.clone(), weight, eps ) - x_fused = x.clone() residual_fused = residual.clone() flashinfer.fused_add_rmsnorm( x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl ) - - torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=x_fused, y=x_native, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose( + x=residual_fused, y=residual_native, rtol=0.001, atol=0.001 + ).item(), "" @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) -@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("dtype", ["float16"]) @pytest.mark.parametrize("specify_out", [True, False]) @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("contiguous", [True, False]) @@ -139,64 +130,55 @@ def test_gemma_norm( batch_size, hidden_size, dtype, specify_out, enable_pdl, contiguous ): if contiguous: - x = torch.randn(batch_size, hidden_size).to(0).to(dtype) + x = paddle.randn(shape=[batch_size, hidden_size]).to(0).to(dtype) else: - x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype) + x = paddle.randn(shape=[batch_size, hidden_size * 2]).to(dtype) x = x[:, :hidden_size] - - major, _ = get_compute_capability(x.device) + major, _ = get_compute_capability(x.place) if major < 9 and enable_pdl: pytest.skip("PDL is only available for Hopper and later GPUs") - - w = torch.randn(hidden_size).to(0).to(dtype) - + w = paddle.randn(shape=hidden_size).to(0).to(dtype) y_ref = gemma_rms_norm(x, w) if specify_out: - y = torch.empty_like(x) + y = paddle.empty_like(x=x) flashinfer.norm.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl) else: y = flashinfer.norm.gemma_rmsnorm(x, w, enable_pdl=enable_pdl) - - torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=y_ref, y=y, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) -@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("dtype", ["float16"]) @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("contiguous", [True, False]) def test_gemma_fused_add_rmsnorm( batch_size, hidden_size, dtype, enable_pdl, 
contiguous ): - eps = 1e-6 - + eps = 1e-06 if contiguous: - x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") + x = paddle.randn(shape=[batch_size, hidden_size], dtype=dtype) else: - x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype) + x = paddle.randn(shape=[batch_size, hidden_size * 2]).to(dtype) x = x[:, :hidden_size] - - major, _ = get_compute_capability(x.device) + major, _ = get_compute_capability(x.place) if major < 9 and enable_pdl: pytest.skip("PDL is only available for Hopper and later GPUs") - - residual = torch.randn_like(x) - weight = torch.randn(hidden_size, dtype=dtype, device="cuda") - + residual = paddle.randn(shape=x.shape, dtype=x.dtype) + weight = paddle.randn(shape=hidden_size, dtype=dtype) x_native, residual_native = gemma_fused_add_rms_norm( x.clone(), residual.clone(), weight, eps ) - x_fused = x.clone() residual_fused = residual.clone() flashinfer.gemma_fused_add_rmsnorm( x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl ) - - torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=x_fused, y=x_native, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose( + x=residual_fused, y=residual_native, rtol=0.001, atol=0.001 + ).item(), "" if __name__ == "__main__": - # test_norm(1, 1024, torch.float16, False, True) - test_fused_add_rmsnorm(1, 16384, torch.float16, True, True) + test_fused_add_rmsnorm(1, 16384, "float16", True, True) diff --git a/tests/test_nvshmem_allreduce.py b/tests/test_nvshmem_allreduce.py index 31f8d73ea1..0dcb52485a 100644 --- a/tests/test_nvshmem_allreduce.py +++ b/tests/test_nvshmem_allreduce.py @@ -1,11 +1,14 @@ +import sys + + import logging import multiprocessing as mp import socket from typing import Any +import paddle import pytest -import torch -import torch.distributed as dist +from flashinfer.paddle_utils import * from flashinfer.comm.nvshmem_allreduce import NVSHMEMAllReduce @@ -14,55 +17,48 @@ def _run_correctness_worker(world_size, rank, distributed_init_port): assert rank >= 0 - torch.cuda.set_device(rank) - device = torch.device("cuda", rank) + paddle.device.set_device(device=device2str(rank)) + device = device2str("cuda", rank) distributed_init_method = f"tcp://localhost:{distributed_init_port}" - dist.init_process_group( +>>>>>> torch.distributed.init_process_group( backend="cpu:gloo,cuda:nccl", rank=rank, world_size=world_size, device_id=device, init_method=distributed_init_method, ) - group = dist.group.WORLD - num_ranks = torch.distributed.get_world_size() - rank_id = torch.distributed.get_rank() - +>>>>>> group = torch.distributed.group.WORLD + num_ranks = paddle.distributed.get_world_size() + rank_id = paddle.distributed.get_rank() batch_sizes = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] max_batch_size = 4096 hidden_dim = 8192 test_loop = 10 - tensor_dtype = torch.bfloat16 + tensor_dtype = "bfloat16" nvshmem_allreduce = NVSHMEMAllReduce( - rank_id, - num_ranks, - max_batch_size * hidden_dim, - tensor_dtype, - device, - group, + rank_id, num_ranks, max_batch_size * hidden_dim, tensor_dtype, device, group ) - try: for batch_size in batch_sizes: for _ in range(test_loop): tensor_size = batch_size * hidden_dim - inp1 = torch.full( - [tensor_size], rank_id, dtype=tensor_dtype, device=device + inp1 = paddle.full( + shape=[tensor_size], fill_value=rank_id, dtype=tensor_dtype ) inp1_ref = inp1.clone() - out1 = torch.empty_like(inp1) + out1 = 
paddle.empty_like(x=inp1) nvshmem_allreduce.all_reduce(inp1, out1) - torch.distributed.all_reduce(inp1_ref, group=group) - torch.cuda.synchronize() - torch.testing.assert_close(out1, inp1_ref) - torch.distributed.barrier(group) + paddle.distributed.all_reduce(tensor=inp1_ref, group=group) + paddle.device.synchronize() + assert paddle.allclose(x=out1, y=inp1_ref).item(), "" + paddle.distributed.barrier(group=group) except Exception as e: print(f"Rank {rank_id}: Exception during test: {e}") raise finally: - torch.distributed.barrier(group) + paddle.distributed.barrier(group=group) nvshmem_allreduce.shutdown() - torch.distributed.destroy_process_group(group) +>>>>>> torch.distributed.destroy_process_group(group) def get_open_port() -> int: @@ -80,33 +76,28 @@ def multi_process_parallel( world_size: int, test_target: Any, target_args: tuple = () ) -> None: mp.set_start_method("spawn", force=True) - procs = [] distributed_init_port = get_open_port() for i in range(world_size): proc_args = (world_size, i, distributed_init_port) + target_args proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") - proc.start() + """Not Support auto convert *.start, please judge whether it is Pytorch API and convert by yourself""" +>>>>>> proc.start() procs.append(proc) - for i in range(world_size): procs[i].join() - assert procs[i].exitcode == 0, ( - f"Process {i} failed with exit code {procs[i].exitcode}" - ) + assert ( + procs[i].exitcode == 0 + ), f"Process {i} failed with exit code {procs[i].exitcode}" @pytest.mark.parametrize("world_size", [8]) def test_nvshmem_allreduce(world_size): - available_gpus = torch.cuda.device_count() + available_gpus = paddle.device.cuda.device_count() if world_size > available_gpus: raise ValueError( f"world_size {world_size} is greater than available_gpus {available_gpus}" ) print(f"Running test for world_size={world_size}") - multi_process_parallel( - world_size, - _run_correctness_worker, - target_args=(), - ) + multi_process_parallel(world_size, _run_correctness_worker, target_args=()) print(f"NVSHMEM allreduce tp = {world_size}: OK") diff --git a/tests/test_page.py b/tests/test_page.py index 0b75f3423c..ff48b6947b 100644 --- a/tests/test_page.py +++ b/tests/test_page.py @@ -1,5 +1,5 @@ +import paddle import pytest -import torch import flashinfer @@ -9,42 +9,59 @@ def test_append_paged_kv_cache(contiguous): nnz_kv = 100 num_kv_heads = 32 head_dim = 128 - if contiguous: - k_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0) - v_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0) + k_append = ( + paddle.randn(shape=[nnz_kv, num_kv_heads, head_dim]) + .astype(dtype="float16") + .to(0) + ) + v_append = ( + paddle.randn(shape=[nnz_kv, num_kv_heads, head_dim]) + .astype(dtype="float16") + .to(0) + ) else: - kv_append = torch.randn(nnz_kv, 2, num_kv_heads, head_dim).half().to(0) + kv_append = ( + paddle.randn(shape=[nnz_kv, 2, num_kv_heads, head_dim]) + .astype(dtype="float16") + .to(0) + ) k_append = kv_append[:, 0] v_append = kv_append[:, 1] - # 45 + 8 + 25 + 22 = nnz_kv - kv_append_length = torch.tensor([45, 8, 25, 22], dtype=torch.int32, device="cuda:0") - kv_append_indptr = torch.cat( - [torch.zeros(1).int().to(0), torch.cumsum(kv_append_length, dim=0)] - ).int() - + kv_append_length = paddle.to_tensor( + data=[45, 8, 25, 22], dtype="int32", place="gpu:0" + ) + kv_append_indptr = paddle.concat( + x=[ + paddle.zeros(shape=[1]).astype(dtype="int32").to(0), + paddle.cumsum(x=kv_append_length, axis=0), + ] + ).astype(dtype="int32") 
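# --- editor's note: illustrative sketch, not part of the original diff ---
# The hunk above drops the inline comments that documented the paged-KV
# bookkeeping ("45 + 8 + 25 + 22 = nnz_kv", "45 = (3 - 1) * 16 + 13", ...).
# The arithmetic is: each request needs ceil(len / page_size) pages, and its
# last page holds (len - 1) % page_size + 1 entries. A minimal standalone
# Paddle reproduction of the numbers used in this test (page_size = 16):
import paddle

page_size = 16
kv_append_length = paddle.to_tensor([45, 8, 25, 22], dtype="int32")  # sums to nnz_kv = 100
num_pages_per_req = (kv_append_length + page_size - 1) // page_size  # -> [3, 1, 2, 2]
kv_last_page_len = (kv_append_length - 1) % page_size + 1            # -> [13, 8, 9, 6]
assert int(kv_append_length.sum()) == 100   # nnz_kv
assert int(num_pages_per_req.sum()) == 8    # the test uses the first 8 pages of the cache
# kv_append_indptr / kv_page_indptr are then exclusive prefix sums of these lengths,
# built as in the hunk above with paddle.concat + paddle.cumsum.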
max_num_pages = 1000 page_size = 16 paged_kv_cache = ( - torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim).half().to(0) + paddle.randn(shape=[max_num_pages, 2, page_size, num_kv_heads, head_dim]) + .astype(dtype="float16") + .to(0) + ) + num_pages_per_req = paddle.to_tensor( + data=[3, 1, 2, 2], dtype="int32", place="gpu:0" + ) + kv_page_indptr = paddle.concat( + x=[ + paddle.zeros(shape=[1]).astype(dtype="int32").to(0), + paddle.cumsum(x=num_pages_per_req, axis=0), + ] + ).astype(dtype="int32") + kv_page_indices = paddle.arange(dtype="int32", end=8) + kv_last_page_len = paddle.to_tensor( + data=[13, 8, 9, 6], dtype="int32", place="gpu:0" ) - num_pages_per_req = torch.tensor([3, 1, 2, 2], dtype=torch.int32, device="cuda:0") - kv_page_indptr = torch.cat( - [torch.zeros(1).int().to(0), torch.cumsum(num_pages_per_req, dim=0)] - ).int() - # use first 8 pages in the paged-kv - kv_page_indices = torch.arange(8, dtype=torch.int32, device="cuda:0") - # 45 = (3 - 1) * 16 + 13 - # 8 = (1 - 1) * 16 + 8 - # 25 = (2 - 1) * 16 + 9 - # 22 = (2 - 1) * 16 + 6 - kv_last_page_len = torch.tensor([13, 8, 9, 6], dtype=torch.int32, device="cuda:0") batch_indices, positions = flashinfer.get_batch_indices_positions( kv_append_indptr, flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size), nnz_kv, ) - flashinfer.append_paged_kv_cache( k_append, v_append, diff --git a/tests/test_pod_kernels.py b/tests/test_pod_kernels.py index fcbdc0affb..1941263e39 100644 --- a/tests/test_pod_kernels.py +++ b/tests/test_pod_kernels.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2023 by FlashInfer team. @@ -13,10 +15,9 @@ See the License for the specific language governing permissions and limitations under the License. """ - import pytest -import torch -from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules +from jit_utils import (gen_decode_attention_modules, + gen_prefill_attention_modules) import flashinfer from flashinfer.jit.attention.pytorch import gen_pod_module @@ -26,38 +27,25 @@ def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_fp16_qk_reductions + ["float16"], ["float16"], [128], [0], [False], [False] ) + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [ - torch.float16, - ], # kv_dtypes - [128], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_cap - [False], # use_fp16_qk_reductions + ["float16"], ["float16"], [128], [0], [False], [False], [False] ) + [ gen_pod_module( - torch.float16, # dtype_q - torch.float16, # dtype_kv - torch.float16, # dtype_o - 128, # head_dim - 0, # pos_encoding_mode_p - False, # use_sliding_window_p - False, # use_logits_soft_cap_p - False, # use_fp16_qk_reduction - torch.int32, # dtype_idx - 0, # pos_encoding_mode_d - False, # use_sliding_window_d - False, # use_logits_soft_cap_d + "float16", + "float16", + "float16", + 128, + 0, + False, + False, + False, + "int32", + 0, + False, + False, ) ], verbose=False, @@ -76,20 +64,17 @@ def warmup_jit(): @pytest.mark.parametrize("num_qo_heads", [8, 32]) @pytest.mark.parametrize("head_dim", [128]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE"]) -@pytest.mark.parametrize("q_dtype", [torch.float16]) -@pytest.mark.parametrize("kv_dtype", [torch.float16]) +@pytest.mark.parametrize("q_dtype", ["float16"]) 
+@pytest.mark.parametrize("kv_dtype", ["float16"]) @pytest.mark.parametrize("contiguous_kv", [True]) def test_pod_with_paged_kv_cache( - # Prefill params kv_len_p, qo_len_p, causal, - # Decode params batch_size_d, kv_len_d, page_size_d, kv_layout_d, - # Shared params num_kv_heads, num_qo_heads, head_dim, @@ -100,28 +85,13 @@ def test_pod_with_paged_kv_cache( ): if causal and qo_len_p > kv_len_p: pytest.skip("Causal prefill with qo_len_p > kv_len_p is not supported") - # Prefill inputs - q_p = torch.randn( - qo_len_p, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) - k_p = torch.randn( - kv_len_p, num_kv_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) - v_p = torch.randn( - kv_len_p, num_kv_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) - # Generate prefill reference output + q_p = paddle.randn(shape=[qo_len_p, num_qo_heads, head_dim], dtype="float16") + k_p = paddle.randn(shape=[kv_len_p, num_kv_heads, head_dim], dtype="float16") + v_p = paddle.randn(shape=[kv_len_p, num_kv_heads, head_dim], dtype="float16") o_ref_p = flashinfer.prefill.single_prefill_with_kv_cache( - q_p, - k_p, - v_p, - causal=causal, - pos_encoding_mode=pos_encoding_mode, - ) - # Decode inputs - q_d = torch.randn( - batch_size_d, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 + q_p, k_p, v_p, causal=causal, pos_encoding_mode=pos_encoding_mode ) + q_d = paddle.randn(shape=[batch_size_d, num_qo_heads, head_dim], dtype="float16") num_pages_per_seq = (kv_len_d + page_size_d - 1) // page_size_d total_num_pages = num_pages_per_seq * batch_size_d if kv_layout_d == "HND": @@ -134,34 +104,29 @@ def test_pod_with_paged_kv_cache( tmp.append(2) tmp.append(v_d) kv_shape = tmp - kv_data_fp32 = torch.randn(*kv_shape, device="cuda:0", dtype=torch.float32) + kv_data_fp32 = paddle.randn(shape=kv_shape, dtype="float32") kv_data = kv_data_fp32.to(kv_dtype) kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :] kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :] - # actual data is stored in non-contiguous memory assert ( - kv_data.stride(-4) - != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1] + kv_data.get_strides()[-4] + != tuple(kv_data.shape)[-3] + * tuple(kv_data.shape)[-2] + * tuple(kv_data.shape)[-1] ) else: - kv_data_fp32 = torch.randn(*kv_shape, device="cuda:0", dtype=torch.float32) + kv_data_fp32 = paddle.randn(shape=kv_shape, dtype="float32") kv_data = kv_data_fp32.to(kv_dtype) kv_indptr_d = ( - torch.arange(0, batch_size_d + 1, device="cuda:0", dtype=torch.int32) - * num_pages_per_seq - ) - kv_indices_d = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) - kv_last_page_len = torch.full( - (batch_size_d,), - (kv_len_d - 1) % page_size_d + 1, - device="cuda:0", - dtype=torch.int32, + paddle.arange(start=0, end=batch_size_d + 1, dtype="int32") * num_pages_per_seq ) - - # Generate decode reference output - decode_workspace_buffer = torch.empty( - 32 * 1024 * 1024, device="cuda:0", dtype=torch.int8 + kv_indices_d = paddle.arange(start=0, end=total_num_pages, dtype="int32") + kv_last_page_len = paddle.full( + shape=(batch_size_d,), + fill_value=(kv_len_d - 1) % page_size_d + 1, + dtype="int32", ) + decode_workspace_buffer = paddle.empty(shape=32 * 1024 * 1024, dtype="int8") decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( decode_workspace_buffer, kv_layout_d ) @@ -178,12 +143,8 @@ def test_pod_with_paged_kv_cache( q_data_type=q_dtype, ) o_ref_d = decode_wrapper.run(q_d, kv_data) - - workspace_buffer = torch.empty(32 * 1024 * 1024, 
device="cuda:0", dtype=torch.int8) - pod_wrapper = flashinfer.PODWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout_d, - ) + workspace_buffer = paddle.empty(shape=32 * 1024 * 1024, dtype="int8") + pod_wrapper = flashinfer.PODWithPagedKVCacheWrapper(workspace_buffer, kv_layout_d) pod_wrapper.plan( kv_indptr_d, kv_indices_d, @@ -196,7 +157,6 @@ def test_pod_with_paged_kv_cache( data_type=kv_dtype, q_data_type=q_dtype, ) - o_p, o_d = pod_wrapper.run( q_p, k_p, @@ -206,72 +166,61 @@ def test_pod_with_paged_kv_cache( pos_encoding_mode_p=pos_encoding_mode, causal_p=causal, ) - # Prefill is run with batch size 1 - torch.testing.assert_close( - o_p, o_ref_p, rtol=1e-3, atol=1e-3, msg="Prefill mismatch" - ) - # Decode uses all batches at once. - torch.testing.assert_close( - o_d, o_ref_d, rtol=1e-3, atol=1e-3, msg="Decode mismatch" - ) + assert paddle.allclose( + x=o_p, y=o_ref_p, rtol=0.001, atol=0.001 + ).item(), "Prefill mismatch" + assert paddle.allclose( + x=o_d, y=o_ref_d, rtol=0.001, atol=0.001 + ).item(), "Decode mismatch" if __name__ == "__main__": test_pod_with_paged_kv_cache( - # Prefill params 128, 128, True, - # Decode params 80, 12288, 16, "NHD", - # Other shared params 8, 8, 128, "NONE", - torch.float16, - torch.float16, + "float16", + "float16", True, ) test_pod_with_paged_kv_cache( - # Prefill params 12288, 12288, True, - # Decode params 220, 12288, 16, "NHD", - # Other shared params 4, 16, 128, "NONE", - torch.float16, - torch.float16, + "float16", + "float16", True, ) test_pod_with_paged_kv_cache( - # Prefill params 16384, 16384, True, - # Decode params 250, 12288, 16, "NHD", - # Other shared params 4, 16, 128, "NONE", - torch.float16, - torch.float16, + "float16", + "float16", True, ) print("POD test(s) passed!") diff --git a/tests/test_quantization.py b/tests/test_quantization.py index b551ef4f40..5264637903 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,51 +15,48 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - from typing import Literal import numpy import pytest -import torch import flashinfer -def numpy_packbits_ref(x_cpu: torch.Tensor, bitorder: Literal["big", "little"]): +def numpy_packbits_ref(x_cpu: paddle.Tensor, bitorder: Literal["big", "little"]): x_np = x_cpu.numpy() x_packed = numpy.packbits(x_np, bitorder=bitorder) - return torch.tensor(x_packed) + return paddle.to_tensor(data=x_packed) @pytest.mark.parametrize("num_elements", [1, 10, 99, 128, 999, 5000, 131072, 999999]) @pytest.mark.parametrize("bitorder", ["big", "little"]) def test_packbits(num_elements, bitorder): - torch.manual_seed(42) - x_cpu = torch.rand(num_elements) < 0.5 + paddle.seed(seed=42) + x_cpu = paddle.rand(shape=num_elements) < 0.5 x_gpu = x_cpu.to(0) x_packed_ref = numpy_packbits_ref(x_cpu, bitorder) x_packed = flashinfer.quantization.packbits(x_gpu, bitorder) - - assert torch.equal(x_packed_ref.cpu(), x_packed.cpu()) + assert paddle.equal_all(x=x_packed_ref.cpu(), y=x_packed.cpu()).item() @pytest.mark.parametrize("batch_size", [1, 10, 99, 128, 777, 999]) @pytest.mark.parametrize("bitorder", ["big", "little"]) def test_segment_packbits(batch_size, bitorder): - torch.manual_seed(42) - old_indptr = torch.cumsum(torch.arange(batch_size + 1), 0).to(0) + paddle.seed(seed=42) + old_indptr = paddle.cumsum(x=paddle.arange(end=batch_size + 1), axis=0).to(0) num_elements = old_indptr[-1].item() - x_cpu = torch.rand(num_elements) < 0.5 + x_cpu = paddle.rand(shape=num_elements) < 0.5 x_gpu = x_cpu.to(0) - y_gpu, new_indptr = flashinfer.quantization.segment_packbits( x_gpu, old_indptr, bitorder ) - for i in range(batch_size): x_segment_i = x_gpu[old_indptr[i] : old_indptr[i + 1]] y_segment_i_ref = flashinfer.packbits(x_segment_i, bitorder) - assert torch.equal(y_gpu[new_indptr[i] : new_indptr[i + 1]], y_segment_i_ref) + assert paddle.equal_all( + x=y_gpu[new_indptr[i] : new_indptr[i + 1]], y=y_segment_i_ref + ).item() if __name__ == "__main__": diff --git a/tests/test_rope.py b/tests/test_rope.py index 839293ed3d..a1a2d98ea5 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,9 +19,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch from rope_reference import * import flashinfer @@ -43,30 +47,27 @@ def test_rope( ): rotary_dim = int(head_dim * partial_rotary_factor) nnz = batch_size * qkv_len - qkv_packed = torch.randn( - nnz, - (num_qo_heads + 2 * num_kv_heads) * head_dim, - dtype=torch.float16, - device="cuda:0", + qkv_packed = paddle.randn( + shape=[nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim], dtype="float16" ) q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) k = qkv_packed[ :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim ].reshape(nnz, num_kv_heads, head_dim) - indptr = torch.tensor( - [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + indptr = paddle.to_tensor( + data=[(i * qkv_len) for i in range(batch_size + 1)], + dtype="int32", + place="gpu:0", ) - offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") - - # reference implementation + offsets = paddle.full(shape=(batch_size,), fill_value=offset, dtype="int32") if llama_version == "llama": freqs_cis = precompute_freqs_cis( rotary_dim, qkv_len + offset, 10000.0, use_scaled=False, device="cuda:0" - ).to("cuda:0") + ).to("gpu:0") else: freqs_cis = precompute_freqs_cis( - rotary_dim, qkv_len + offset, 5e5, use_scaled=True, device="cuda:0" - ).to("cuda:0") + rotary_dim, qkv_len + offset, 500000.0, use_scaled=True, device="cuda:0" + ).to("gpu:0") q_rot_ref, k_rot_ref = apply_rotary_emb( q.reshape(batch_size, qkv_len, num_qo_heads, head_dim)[..., :rotary_dim], k.reshape(batch_size, qkv_len, num_kv_heads, head_dim)[..., :rotary_dim], @@ -78,14 +79,12 @@ def test_rope( k_pass_ref = k.reshape(batch_size, qkv_len, num_kv_heads, head_dim)[ ..., rotary_dim: ] - q_rope_ref = torch.cat([q_rot_ref, q_pass_ref], dim=-1).reshape( + q_rope_ref = paddle.concat(x=[q_rot_ref, q_pass_ref], axis=-1).reshape( nnz, num_qo_heads, head_dim ) - k_rope_ref = torch.cat([k_rot_ref, k_pass_ref], dim=-1).reshape( + k_rope_ref = paddle.concat(x=[k_rot_ref, k_pass_ref], axis=-1).reshape( nnz, num_kv_heads, head_dim ) - - # flashinfer implementation if llama_version == "llama": if inplace: flashinfer.apply_rope_inplace( @@ -95,7 +94,7 @@ def test_rope( offsets, rotary_dim=rotary_dim, interleave=True, - rope_theta=1e4, + rope_theta=10000.0, ) q_rope, k_rope = q, k else: @@ -106,34 +105,31 @@ def test_rope( offsets, rotary_dim=rotary_dim, interleave=True, - rope_theta=1e4, + rope_theta=10000.0, ) + elif inplace: + flashinfer.apply_llama31_rope_inplace( + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=True, + rope_theta=500000.0, + ) + q_rope, k_rope = q, k else: - if inplace: - flashinfer.apply_llama31_rope_inplace( - q, - k, - indptr, - offsets, - rotary_dim=rotary_dim, - interleave=True, - rope_theta=5e5, - ) - q_rope, k_rope = q, k - else: - q_rope, k_rope = flashinfer.apply_llama31_rope( - q, - k, - indptr, - offsets, - rotary_dim=rotary_dim, - interleave=True, - rope_theta=5e5, - ) - - # compare - torch.testing.assert_close(q_rope_ref, q_rope, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(k_rope_ref, k_rope, rtol=1e-3, atol=1e-3) + q_rope, k_rope = flashinfer.apply_llama31_rope( + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=True, + rope_theta=500000.0, + ) + assert paddle.allclose(x=q_rope_ref, y=q_rope, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=k_rope_ref, y=k_rope, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @@ -146,7 +142,7 @@ 
def test_rope( @pytest.mark.parametrize("partial_rotary_factor", [0.25, 0.5, 0.75, 1.0]) @pytest.mark.parametrize("inplace", [False, True]) @pytest.mark.parametrize("interleave", [True, False]) -@pytest.mark.parametrize("idtype", [torch.int32, torch.int64]) +@pytest.mark.parametrize("idtype", ["int32", "int64"]) def test_rope_pos_ids( batch_size, qkv_len, @@ -162,28 +158,23 @@ def test_rope_pos_ids( ): rotary_dim = int(head_dim * partial_rotary_factor) nnz = batch_size * qkv_len - qkv_packed = torch.randn( - nnz, - (num_qo_heads + 2 * num_kv_heads) * head_dim, - dtype=torch.float16, - device="cuda:0", + qkv_packed = paddle.randn( + shape=[nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim], dtype="float16" ) q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) k = qkv_packed[ :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim ].reshape(nnz, num_kv_heads, head_dim) - indptr = torch.tensor( - [i * qkv_len for i in range(batch_size + 1)], dtype=idtype, device="cuda:0" + indptr = paddle.to_tensor( + data=[(i * qkv_len) for i in range(batch_size + 1)], dtype=idtype, place="gpu:0" ) - offsets = torch.full((batch_size,), offset, dtype=idtype, device="cuda:0") - - pos_ids = torch.cat( - [ - torch.arange(offset, qkv_len + offset, dtype=idtype) + offsets = paddle.full(shape=(batch_size,), fill_value=offset, dtype=idtype) + pos_ids = paddle.concat( + x=[ + paddle.arange(start=offset, end=qkv_len + offset, dtype=idtype) for _ in range(batch_size) ] - ).to("cuda:0") - + ).to("gpu:0") if llama_version == "llama": if inplace: q_clone, k_clone = q.clone(), k.clone() @@ -194,7 +185,7 @@ def test_rope_pos_ids( offsets, rotary_dim=rotary_dim, interleave=interleave, - rope_theta=1e4, + rope_theta=10000.0, ) q_rope, k_rope = q, k flashinfer.apply_rope_pos_ids_inplace( @@ -203,7 +194,7 @@ def test_rope_pos_ids( pos_ids, rotary_dim=rotary_dim, interleave=interleave, - rope_theta=1e4, + rope_theta=10000.0, ) q_rope_pos_ids, k_rope_pos_ids = q_clone, k_clone else: @@ -214,72 +205,71 @@ def test_rope_pos_ids( offsets, rotary_dim=rotary_dim, interleave=interleave, - rope_theta=1e4, + rope_theta=10000.0, ) - q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_rope_pos_ids( q, k, pos_ids, rotary_dim=rotary_dim, interleave=interleave, - rope_theta=1e4, + rope_theta=10000.0, ) + elif inplace: + q_clone, k_clone = q.clone(), k.clone() + flashinfer.apply_llama31_rope_inplace( + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=interleave, + rope_theta=500000.0, + ) + q_rope, k_rope = q, k + flashinfer.apply_llama31_rope_pos_ids_inplace( + q_clone, + k_clone, + pos_ids, + rotary_dim=rotary_dim, + interleave=interleave, + rope_theta=500000.0, + ) + q_rope_pos_ids, k_rope_pos_ids = q_clone, k_clone else: - if inplace: - q_clone, k_clone = q.clone(), k.clone() - flashinfer.apply_llama31_rope_inplace( - q, - k, - indptr, - offsets, - rotary_dim=rotary_dim, - interleave=interleave, - rope_theta=5e5, - ) - q_rope, k_rope = q, k - flashinfer.apply_llama31_rope_pos_ids_inplace( - q_clone, - k_clone, - pos_ids, - rotary_dim=rotary_dim, - interleave=interleave, - rope_theta=5e5, - ) - q_rope_pos_ids, k_rope_pos_ids = q_clone, k_clone - else: - q_rope, k_rope = flashinfer.apply_llama31_rope( - q, - k, - indptr, - offsets, - rotary_dim=rotary_dim, - interleave=interleave, - rope_theta=5e5, - ) - - q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_llama31_rope_pos_ids( - q, - k, - pos_ids, - rotary_dim=rotary_dim, - interleave=interleave, - rope_theta=5e5, - ) - - # compare 
- torch.testing.assert_close(q_rope_pos_ids, q_rope, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(k_rope_pos_ids, k_rope, rtol=1e-3, atol=1e-3) + q_rope, k_rope = flashinfer.apply_llama31_rope( + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=interleave, + rope_theta=500000.0, + ) + q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_llama31_rope_pos_ids( + q, + k, + pos_ids, + rotary_dim=rotary_dim, + interleave=interleave, + rope_theta=500000.0, + ) + assert paddle.allclose( + x=q_rope_pos_ids, y=q_rope, rtol=0.001, atol=0.001 + ).item(), "" + assert paddle.allclose( + x=k_rope_pos_ids, y=k_rope, rtol=0.001, atol=0.001 + ).item(), "" class FlashInferRotaryEmbedding(RotaryEmbedding): def forward_cuda( self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + positions: paddle.Tensor, + query: paddle.Tensor, + key: paddle.Tensor, + offsets: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: flashinfer.apply_rope_with_cos_sin_cache_inplace( positions=positions, query=query, @@ -294,12 +284,12 @@ def forward_cuda( @pytest.mark.parametrize( "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads", [ - (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1), - (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2), - (64, 32, 2048, 8432, True, torch.bfloat16, "cuda", 2, 199, 4, 1), - (64, 64, 32, 8000, False, torch.bfloat16, "cuda", 32, 32, 1, 1), - (64, 64, 32, 8000, False, torch.bfloat16, "cuda", 32, 32, 1, 1), - (256, 128, 4096, 9231, False, torch.bfloat16, "cuda", 3, 231, 4, 2), + (64, 64, 32, 8000, True, "bfloat16", "cuda", 32, 32, 1, 1), + (256, 128, 4096, 10000, True, "bfloat16", "cuda", 2, 512, 4, 2), + (64, 32, 2048, 8432, True, "bfloat16", "cuda", 2, 199, 4, 1), + (64, 64, 32, 8000, False, "bfloat16", "cuda", 32, 32, 1, 1), + (64, 64, 32, 8000, False, "bfloat16", "cuda", 32, 32, 1, 1), + (256, 128, 4096, 9231, False, "bfloat16", "cuda", 3, 231, 4, 2), ], ) def test_rope_cos_sin_cache( @@ -308,7 +298,7 @@ def test_rope_cos_sin_cache( max_position_embeddings: int, base: int, is_neox_style: bool, - dtype: torch.dtype, + dtype: paddle.dtype, device: str, batch_size: int, seq_len: int, @@ -333,62 +323,45 @@ def test_rope_cos_sin_cache( dtype, device, ) - - pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) - query = torch.randn( - batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device + pos_ids = paddle.arange(end=seq_len).tile(repeat_times=batch_size) + query = paddle.randn( + shape=[batch_size * seq_len, num_q_heads * head_size], dtype=dtype ) - key = torch.randn( - batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device + key = paddle.randn( + shape=[batch_size * seq_len, num_kv_heads * head_size], dtype=dtype ) - query_ref, key_ref = query.clone(), key.clone() query_flashinfer, key_flashinfer = query.clone(), key.clone() - query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref) query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda( pos_ids, query_flashinfer, key_flashinfer ) - - torch.testing.assert_close( - query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2 - ) - torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2) + assert paddle.allclose( + x=query_ref_out, y=query_flashinfer_out, atol=0.01, rtol=0.01 + 
).item(), "" + assert paddle.allclose( + x=key_ref_out, y=key_flashinfer_out, atol=0.01, rtol=0.01 + ).item(), "" @pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047]) -@pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) -def test_mla_rope_quantize( - num_tokens, - input_dtype, - quant_dtype, -): +@pytest.mark.parametrize("input_dtype", ["float16", "bfloat16"]) +@pytest.mark.parametrize("quant_dtype", [paddle.float8_e4m3fn, paddle.float8_e5m2]) +def test_mla_rope_quantize(num_tokens, input_dtype, quant_dtype): device = "cuda:0" num_qo_heads = 128 - q_in = torch.randn(num_tokens, num_qo_heads, 576, dtype=input_dtype, device=device) - k_in = torch.randn(num_tokens, 576, dtype=input_dtype, device=device) - pos_ids = torch.arange(num_tokens, device=device) - - # baseline + q_in = paddle.randn(shape=[num_tokens, num_qo_heads, 576], dtype=input_dtype) + k_in = paddle.randn(shape=[num_tokens, 576], dtype=input_dtype) + pos_ids = paddle.arange(end=num_tokens) rope_flashinfer = FlashInferRotaryEmbedding( - 576, - 64, - 4096, - 10000, - False, # is_neox_style - input_dtype, - device, + 576, 64, 4096, 10000, False, input_dtype, device ) - q_out_f16_ref, k_out_f16_ref = rope_flashinfer.forward_native(pos_ids, q_in, k_in) q_out_f8_ref, k_out_f8_ref = map( - lambda x: x.to(quant_dtype), - (q_out_f16_ref, k_out_f16_ref), + lambda x: x.to(quant_dtype), (q_out_f16_ref, k_out_f16_ref) ) - - q_out = torch.empty_like(q_in, dtype=quant_dtype) - k_out = torch.empty_like(k_in, dtype=quant_dtype) + q_out = paddle.empty_like(x=q_in, dtype=quant_dtype) + k_out = paddle.empty_like(x=k_in, dtype=quant_dtype) flashinfer.rope.mla_rope_quantize_fp8( q_in[..., :64], k_in[..., :64], @@ -404,19 +377,19 @@ def test_mla_rope_quantize( quant_scale_q=1.0, quant_scale_kv=1.0, ) - - torch.testing.assert_close( - q_out_f8_ref.float(), q_out.float(), atol=1e-2, rtol=2e-1 - ) - torch.testing.assert_close( - k_out_f8_ref.float(), k_out.float(), atol=1e-2, rtol=2e-1 - ) + assert paddle.allclose( + x=q_out_f8_ref.astype(dtype="float32"), + y=q_out.astype(dtype="float32"), + atol=0.01, + rtol=0.2, + ).item(), "" + assert paddle.allclose( + x=k_out_f8_ref.astype(dtype="float32"), + y=k_out.astype(dtype="float32"), + atol=0.01, + rtol=0.2, + ).item(), "" if __name__ == "__main__": - # test_rope(2, 1, 8, 8, 1, 128, "llama", 1.0, False) - # test_rope_pos_ids(2, 1, 8, 8, 1, 128, "llama31", 1.0, False) - # test_rope_cos_sin_cache( - # 64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1 - # ) - test_mla_rope_quantize(1, 1, torch.float16, torch.float8_e4m3fn) + test_mla_rope_quantize(1, 1, "float16", paddle.float8_e4m3fn) diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 333a24bce8..51939c1954 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,16 +19,14 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch import flashinfer def normal_distribution(std): def normal_noise(shape, device): - return torch.randn(shape, device=device) * std + return paddle.randn(shape=shape) * std normal_noise.__name__ = f"normal_distribution(std={std})" return normal_noise @@ -30,9 +34,9 @@ def normal_noise(shape, device): def gumbel_distribution(beta): def gumbel_noise(shape, device): - U = torch.rand(shape, device=device) + U = paddle.rand(shape=shape) eps = 1e-20 - return torch.log(-torch.log(U + eps) + eps) / beta + return paddle.log(x=-paddle.log(x=U + eps) + eps) / beta gumbel_noise.__name__ = f"gumbel_distribution(beta={beta})" return gumbel_noise @@ -42,11 +46,7 @@ def gumbel_noise(shape, device): @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + [normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) @pytest.mark.parametrize("temperature", [1.0, 0.5, 0.1]) @pytest.mark.parametrize("temperature_arr", [True, False]) @@ -54,172 +54,171 @@ def gumbel_noise(shape, device): def test_softmax( batch_size, vocab_size, distribution, temperature, temperature_arr, neg_inf_input ): - torch.manual_seed(42) + paddle.seed(seed=42) logits = distribution((batch_size, vocab_size), "cuda:0") if neg_inf_input: - # assign random logits to -inf - num_inf = torch.randint(0, logits.numel() - 1, (), device=logits.device).item() - inf_idx = torch.randperm(logits.numel(), device=logits.device)[:num_inf] - logits.view(-1).index_fill_(0, inf_idx, float("-inf")) - + num_inf = paddle.randint(low=0, high=logits.size - 1, shape=()).item() + inf_idx = paddle.randperm(n=logits.size)[:num_inf] + logits.view(-1).index_fill_(axis=0, index=inf_idx, value=float("-inf")) if temperature_arr: - temperature_arr = torch.full((batch_size,), temperature, device="cuda:0") + temperature_arr = paddle.full(shape=(batch_size,), fill_value=temperature) probs = flashinfer.sampling.softmax(logits, temperature=temperature_arr) - logits_scaled = logits / temperature_arr.unsqueeze(-1) + logits_scaled = logits / temperature_arr.unsqueeze(axis=-1) else: probs = flashinfer.sampling.softmax(logits, temperature=temperature) logits_scaled = logits / temperature - - probs_ref = torch.softmax(logits_scaled, dim=-1) - - assert torch.allclose(probs, probs_ref, atol=1e-5) + probs_ref = paddle.nn.functional.softmax(x=logits_scaled, axis=-1) + assert paddle.allclose(x=probs, y=probs_ref, atol=1e-05).item() @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + [normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) @pytest.mark.parametrize("zero_ratio", [0.0, 0.5, 0.9]) def test_sampling_freq(vocab_size, distribution, zero_ratio): - torch.manual_seed(42) + paddle.seed(seed=42) num_trials = 5000000 logits = distribution((1, vocab_size), "cuda:0") - zero_indices = torch.randperm(vocab_size)[: int(vocab_size * zero_ratio)] + zero_indices = paddle.randperm(n=vocab_size)[: int(vocab_size * zero_ratio)] logits[:, zero_indices] = -float("inf") - probs = torch.softmax(logits, dim=-1) - counter = torch.zeros(vocab_size, dtype=torch.int32, device=logits.device) - + probs = paddle.nn.functional.softmax(x=logits, axis=-1) + counter = paddle.zeros(shape=vocab_size, dtype="int32") samples = flashinfer.sampling.sampling_from_probs( - 
probs, indices=torch.zeros(num_trials, dtype=torch.int32, device=logits.device) + probs, indices=paddle.zeros(shape=num_trials, dtype="int32") ) - counter.scatter_add_(0, samples.long(), torch.ones_like(samples)) - freq = counter.float() / num_trials - - assert torch.all(counter[zero_indices] == 0) - similarity = torch.cosine_similarity(freq, probs) + counter.put_along_axis_( + axis=0, + indices=samples.astype(dtype="int64"), + values=paddle.ones_like(x=samples), + reduce="add", + ) + freq = counter.astype(dtype="float32") / num_trials + assert paddle.all(x=counter[zero_indices] == 0) + similarity = paddle.nn.functional.cosine_similarity(x1=freq, x2=probs) assert similarity > 0.99, f"similarity: {similarity}" @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + [normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) @pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) def test_top_p_sampling_freq(vocab_size, distribution, p): - # use torch profiler to check the performance of the code - torch.manual_seed(42) + paddle.seed(seed=42) logits = distribution((1, vocab_size), "cuda:0") - probs = torch.softmax(logits, dim=-1) - sorted_prob, indices = torch.sort(probs, descending=False) - cdf = torch.cumsum(sorted_prob, dim=-1) - mask = torch.zeros(1, vocab_size, dtype=torch.int32, device=logits.device) - mask.scatter_add_(1, indices, (cdf > (1 - p)).int()) - + probs = paddle.nn.functional.softmax(x=logits, axis=-1) + sorted_prob, indices = paddle.sort(x=probs, descending=False), paddle.argsort( + x=probs, descending=False + ) + cdf = paddle.cumsum(x=sorted_prob, axis=-1) + mask = paddle.zeros(shape=[1, vocab_size], dtype="int32") + mask.put_along_axis_( + axis=1, + indices=indices, + values=(cdf > 1 - p).astype(dtype="int32"), + reduce="add", + ) renorm_probs = flashinfer.sampling.top_p_renorm_probs(probs, p) - counter = torch.zeros(vocab_size, dtype=torch.int32, device=logits.device) + counter = paddle.zeros(shape=vocab_size, dtype="int32") num_trials = 5000000 samples = flashinfer.sampling.top_p_sampling_from_probs( - probs, - p, - indices=torch.zeros(num_trials, dtype=torch.int32, device=logits.device), + probs, p, indices=paddle.zeros(shape=num_trials, dtype="int32") + ) + counter.put_along_axis_( + axis=0, + indices=samples.astype(dtype="int64"), + values=paddle.ones_like(x=samples), + reduce="add", ) - counter.scatter_add_(0, samples.long(), torch.ones_like(samples)) - freq = counter.float() / num_trials - assert torch.all(mask[torch.arange(1), samples] == 1) - similarity = torch.cosine_similarity(freq, renorm_probs) + freq = counter.astype(dtype="float32") / num_trials + assert paddle.all(x=mask[paddle.arange(end=1), samples] == 1) + similarity = paddle.nn.functional.cosine_similarity(x1=freq, x2=renorm_probs) assert similarity > 0.99, f"similarity: {similarity}" @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + [normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) @pytest.mark.parametrize("k", [10, 100, 500]) def test_top_k_sampling_freq(vocab_size, distribution, k): if k > vocab_size: pytest.skip("k should be less than vocab_size") - torch.manual_seed(42) + paddle.seed(seed=42) logits = distribution((1, vocab_size), "cuda:0") - probs = torch.softmax(logits, dim=-1) - 
sorted_prob, _ = torch.sort(probs, descending=True) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) + sorted_prob, _ = paddle.sort(x=probs, descending=True), paddle.argsort( + x=probs, descending=True + ) pivot = sorted_prob[:, k - 1] - mask = (probs >= pivot.unsqueeze(-1)).int() - + mask = (probs >= pivot.unsqueeze(axis=-1)).astype(dtype="int32") renorm_probs = flashinfer.sampling.top_k_renorm_probs(probs, k) - counter = torch.zeros(vocab_size, dtype=torch.int32, device=logits.device) + counter = paddle.zeros(shape=vocab_size, dtype="int32") num_trials = 5000000 samples = flashinfer.sampling.top_k_sampling_from_probs( - probs, - k, - indices=torch.zeros(num_trials, dtype=torch.int32, device=logits.device), + probs, k, indices=paddle.zeros(shape=num_trials, dtype="int32") + ) + counter.put_along_axis_( + axis=0, + indices=samples.astype(dtype="int64"), + values=paddle.ones_like(x=samples), + reduce="add", ) - counter.scatter_add_(0, samples.long(), torch.ones_like(samples)) - freq = counter.float() / num_trials - assert torch.all(mask[torch.arange(1), samples] == 1) - similarity = torch.cosine_similarity(freq, renorm_probs) + freq = counter.astype(dtype="float32") / num_trials + assert paddle.all(x=mask[paddle.arange(end=1), samples] == 1) + similarity = paddle.nn.functional.cosine_similarity(x1=freq, x2=renorm_probs) assert similarity > 0.99, f"similarity: {similarity}" @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) def test_sampling(batch_size, vocab_size): - torch.manual_seed(42) - pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") - normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) - + paddle.seed(seed=42) + pre_norm_prob = paddle.rand(shape=[batch_size, vocab_size]) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(axis=-1, keepdim=True) num_trails = 5000 for _ in range(num_trails): samples = flashinfer.sampling.sampling_from_probs(normalized_prob) - assert torch.all(samples < vocab_size) and torch.all(samples >= 0) + assert paddle.all(x=samples < vocab_size) and paddle.all(x=samples >= 0) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) def test_sampling_from_logits(batch_size, vocab_size): - torch.manual_seed(42) - logits = torch.randn(batch_size, vocab_size, device="cuda:0") + paddle.seed(seed=42) + logits = paddle.randn(shape=[batch_size, vocab_size]) num_trails = 5000 for _ in range(num_trails): samples = flashinfer.sampling.sampling_from_logits(logits) - assert torch.all(samples < vocab_size) and torch.all(samples >= 0) + assert paddle.all(x=samples < vocab_size) and paddle.all(x=samples >= 0) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize( "distribution", - [ - normal_distribution(1), - normal_distribution(5), - gumbel_distribution(0.1), - ], + [normal_distribution(1), normal_distribution(5), gumbel_distribution(0.1)], ) def test_sampling_from_logits_freq(vocab_size, distribution): - torch.manual_seed(42) + paddle.seed(seed=42) num_trials = 5000000 logits = distribution((1, vocab_size), "cuda:0") - probs = torch.softmax(logits, dim=-1) - counter = torch.zeros(vocab_size, dtype=torch.int32, device=logits.device) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) + counter = paddle.zeros(shape=vocab_size, dtype="int32") samples = flashinfer.sampling.sampling_from_logits( - logits, indices=torch.zeros(num_trials, dtype=torch.int32, 
device=logits.device) + logits, indices=paddle.zeros(shape=num_trials, dtype="int32") + ) + counter.put_along_axis_( + axis=0, + indices=samples.astype(dtype="int64"), + values=paddle.ones_like(x=samples), + reduce="add", ) - counter.scatter_add_(0, samples.long(), torch.ones_like(samples)) - freq = counter.float() / num_trials - similarity = torch.cosine_similarity(freq, probs) + freq = counter.astype(dtype="float32") / num_trials + similarity = paddle.nn.functional.cosine_similarity(x1=freq, x2=probs) assert similarity > 0.99, f"similarity: {similarity}" @@ -227,20 +226,26 @@ def test_sampling_from_logits_freq(vocab_size, distribution): @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) def test_top_p_sampling(batch_size, vocab_size, p): - torch.manual_seed(42) - eps = 1e-4 - pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") - normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) - sorted_prob, indices = torch.sort(normalized_prob, descending=False) - cdf = torch.cumsum(sorted_prob, dim=-1) - mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0") - mask.scatter_add_(1, indices, (cdf > (1 - p) - eps).int()) - + paddle.seed(seed=42) + eps = 0.0001 + pre_norm_prob = paddle.rand(shape=[batch_size, vocab_size]) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(axis=-1, keepdim=True) + sorted_prob, indices = paddle.sort( + x=normalized_prob, descending=False + ), paddle.argsort(x=normalized_prob, descending=False) + cdf = paddle.cumsum(x=sorted_prob, axis=-1) + mask = paddle.zeros(shape=[batch_size, vocab_size], dtype="int32") + mask.put_along_axis_( + axis=1, + indices=indices, + values=(cdf > 1 - p - eps).astype(dtype="int32"), + reduce="add", + ) num_trails = 1000 for _ in range(num_trails): samples = flashinfer.sampling.top_p_sampling_from_probs(normalized_prob, p) - assert torch.all(samples < vocab_size) and torch.all(samples >= 0) - assert torch.all(mask[torch.arange(batch_size), samples] == 1) + assert paddle.all(x=samples < vocab_size) and paddle.all(x=samples >= 0) + assert paddle.all(x=mask[paddle.arange(end=batch_size), samples] == 1) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @@ -249,20 +254,21 @@ def test_top_p_sampling(batch_size, vocab_size, p): def test_top_k_sampling(batch_size, vocab_size, k): if k > vocab_size: pytest.skip("k should be less than vocab_size") - torch.manual_seed(42) - pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") - normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) - sorted_prob, _ = torch.sort(normalized_prob, descending=True) + paddle.seed(seed=42) + pre_norm_prob = paddle.rand(shape=[batch_size, vocab_size]) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(axis=-1, keepdim=True) + sorted_prob, _ = paddle.sort(x=normalized_prob, descending=True), paddle.argsort( + x=normalized_prob, descending=True + ) pivot = sorted_prob[:, k - 1] - mask = (normalized_prob >= pivot.unsqueeze(-1)).int() - + mask = (normalized_prob >= pivot.unsqueeze(axis=-1)).astype(dtype="int32") num_trails = 1000 for _ in range(num_trails): samples = flashinfer.sampling.top_k_sampling_from_probs(normalized_prob, k) - assert torch.all(samples < vocab_size) and torch.all(samples >= 0) - assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[ - torch.arange(batch_size), samples - ] + assert paddle.all(x=samples < vocab_size) and paddle.all(x=samples >= 0) + assert paddle.all( + 
x=mask[paddle.arange(end=batch_size), samples] == 1 + ), normalized_prob[paddle.arange(end=batch_size), samples] @pytest.mark.parametrize("batch_size", [1, 99, 989]) @@ -271,48 +277,51 @@ def test_top_k_sampling(batch_size, vocab_size, k): def test_top_k_sampling_with_variable_k(batch_size, vocab_size, k): if k > vocab_size: pytest.skip("k should be less than vocab_size") - torch.manual_seed(42) - pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") - normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) - sorted_prob, _ = torch.sort(normalized_prob, descending=True) - k = torch.randint(1, k + 1, (batch_size,), device="cuda:0") - pivot = sorted_prob[torch.arange(batch_size), k - 1] - mask = (normalized_prob >= pivot.unsqueeze(-1)).int() - + paddle.seed(seed=42) + pre_norm_prob = paddle.rand(shape=[batch_size, vocab_size]) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(axis=-1, keepdim=True) + sorted_prob, _ = paddle.sort(x=normalized_prob, descending=True), paddle.argsort( + x=normalized_prob, descending=True + ) + k = paddle.randint(low=1, high=k + 1, shape=(batch_size,)) + pivot = sorted_prob[paddle.arange(end=batch_size), k - 1] + mask = (normalized_prob >= pivot.unsqueeze(axis=-1)).astype(dtype="int32") num_trails = 1000 for _ in range(num_trails): samples = flashinfer.sampling.top_k_sampling_from_probs(normalized_prob, k) - assert torch.all(samples < vocab_size) and torch.all(samples >= 0) - assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[ - torch.arange(batch_size), samples - ] + assert paddle.all(x=samples < vocab_size) and paddle.all(x=samples >= 0) + assert paddle.all( + x=mask[paddle.arange(end=batch_size), samples] == 1 + ), normalized_prob[paddle.arange(end=batch_size), samples] @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1]) def test_min_p_sampling(batch_size, vocab_size, p): - torch.manual_seed(42) - pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") - normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) - sorted_prob, indices = torch.sort(normalized_prob, descending=False) - # scale min-p - top_probs = sorted_prob[:, -1].unsqueeze(-1) + paddle.seed(seed=42) + pre_norm_prob = paddle.rand(shape=[batch_size, vocab_size]) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(axis=-1, keepdim=True) + sorted_prob, indices = paddle.sort( + x=normalized_prob, descending=False + ), paddle.argsort(x=normalized_prob, descending=False) + top_probs = sorted_prob[:, -1].unsqueeze(axis=-1) scaled_p = p * top_probs - # min-p mask - mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0") - mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int()) - min_p_tensor = torch.full((batch_size,), p, device="cuda:0") - + mask = paddle.zeros(shape=[batch_size, vocab_size], dtype="int32") + mask.put_along_axis_( + axis=1, + indices=indices, + values=(sorted_prob >= scaled_p).astype(dtype="int32"), + reduce="add", + ) + min_p_tensor = paddle.full(shape=(batch_size,), fill_value=p) num_trails = 1000 for _ in range(num_trails): samples = flashinfer.sampling.min_p_sampling_from_probs( - normalized_prob, - min_p_tensor, + normalized_prob, min_p_tensor ) - - assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[ - torch.nonzero(mask[torch.arange(batch_size), samples] == 0) + assert paddle.all(x=mask[paddle.arange(end=batch_size), 
samples] == 1), samples[ + paddle.nonzero(x=mask[paddle.arange(end=batch_size), samples] == 0) ] @@ -320,42 +329,44 @@ def test_min_p_sampling(batch_size, vocab_size, p): @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5]) def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): - torch.manual_seed(42) + paddle.seed(seed=42) if p == 0.1: k = int(vocab_size * 0.5) elif p == 0.5: k = int(vocab_size * 0.1) else: raise ValueError("p not recognized") - eps = 1e-4 - pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") - normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) - # top-p mask - sorted_prob, indices = torch.sort(normalized_prob, descending=False) - cdf = torch.cumsum(sorted_prob, dim=-1) - mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0") - mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int()) - # top-k mask - sorted_prob, _ = torch.sort(normalized_prob, descending=True) + eps = 0.0001 + pre_norm_prob = paddle.rand(shape=[batch_size, vocab_size]) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(axis=-1, keepdim=True) + sorted_prob, indices = paddle.sort( + x=normalized_prob, descending=False + ), paddle.argsort(x=normalized_prob, descending=False) + cdf = paddle.cumsum(x=sorted_prob, axis=-1) + mask_top_p = paddle.zeros(shape=[batch_size, vocab_size], dtype="int32") + mask_top_p.put_along_axis_( + axis=1, + indices=indices, + values=(cdf > 1 - p - eps).astype(dtype="int32"), + reduce="add", + ) + sorted_prob, _ = paddle.sort(x=normalized_prob, descending=True), paddle.argsort( + x=normalized_prob, descending=True + ) pivot = sorted_prob[:, k - 1] - mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int() - # overall mask - mask = torch.minimum(mask_top_p, mask_top_k) - top_p_tensor = torch.full((batch_size,), p, device="cuda:0") - top_k_tensor = torch.full((batch_size,), k, device="cuda:0") - + mask_top_k = (normalized_prob >= pivot.unsqueeze(axis=-1)).astype(dtype="int32") + mask = paddle.minimum(x=mask_top_p, y=mask_top_k) + top_p_tensor = paddle.full(shape=(batch_size,), fill_value=p) + top_k_tensor = paddle.full(shape=(batch_size,), fill_value=k) num_trails = 1000 for _ in range(num_trails): samples = flashinfer.sampling.top_k_top_p_sampling_from_probs( - normalized_prob, - top_k_tensor, - top_p_tensor, - filter_apply_order="joint", + normalized_prob, top_k_tensor, top_p_tensor, filter_apply_order="joint" ) - assert torch.all(samples < vocab_size) and torch.all(samples >= 0) - assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[ - torch.arange(batch_size), samples - ] + assert paddle.all(x=samples < vocab_size) and paddle.all(x=samples >= 0) + assert paddle.all( + x=mask[paddle.arange(end=batch_size), samples] == 1 + ), normalized_prob[paddle.arange(end=batch_size), samples] @pytest.mark.parametrize("batch_size", [1, 99, 989]) @@ -363,30 +374,30 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): @pytest.mark.parametrize("k", [100]) @pytest.mark.parametrize("p", [0.1, 0.5]) def test_top_k_top_p_sampling_from_probs_logits_alignment(batch_size, vocab_size, k, p): - torch.manual_seed(42) - logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5 - generator_logits = torch.Generator("cuda:0") + paddle.seed(seed=42) + logits = paddle.randn(shape=[batch_size, vocab_size]) * 5 + generator_logits = paddle.framework.core.default_cpu_generator() generator_probs = 
generator_logits.clone_state() samples = flashinfer.sampling.top_k_top_p_sampling_from_logits( logits, k, p, filter_apply_order="top_k_first", generator=generator_logits ) samples_ref = flashinfer.sampling.top_k_top_p_sampling_from_probs( - torch.softmax(logits, dim=-1), + paddle.nn.functional.softmax(x=logits, axis=-1), k, p, filter_apply_order="top_k_first", generator=generator_probs, ) - assert torch.all(samples == samples_ref) + assert paddle.all(x=samples == samples_ref) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5]) def test_top_k_top_p_joint_sampling_from_logits(batch_size, vocab_size, p): - torch.manual_seed(42) - logits = torch.rand(batch_size, vocab_size, device="cuda:0") * 5 - generator_logits = torch.Generator("cuda:0") + paddle.seed(seed=42) + logits = paddle.rand(shape=[batch_size, vocab_size]) * 5 + generator_logits = paddle.framework.core.default_cpu_generator() generator_probs = generator_logits.clone_state() if p == 0.1: k = int(vocab_size * 0.5) @@ -394,45 +405,46 @@ def test_top_k_top_p_joint_sampling_from_logits(batch_size, vocab_size, p): k = int(vocab_size * 0.1) else: raise ValueError("p not recognized") - samples = flashinfer.sampling.top_k_top_p_sampling_from_logits( logits, k, p, filter_apply_order="joint", generator=generator_logits ) - samples_ref = flashinfer.sampling.top_k_top_p_sampling_from_probs( - torch.softmax(logits, dim=-1), + paddle.nn.functional.softmax(x=logits, axis=-1), k, p, filter_apply_order="joint", generator=generator_probs, ) - assert torch.all(samples == samples_ref) + assert paddle.all(x=samples == samples_ref) @pytest.mark.parametrize("batch_size", [1, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5, 0.9, 1.0]) def test_top_p_renorm_probs(batch_size, vocab_size, p): - torch.manual_seed(42) - pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") - normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) - sorted_prob, indices = torch.sort(normalized_prob, descending=False) - cdf = torch.cumsum(sorted_prob, dim=-1) - mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0") - mask.scatter_add_(1, indices, (cdf >= (1 - p)).int()) + paddle.seed(seed=42) + pre_norm_prob = paddle.rand(shape=[batch_size, vocab_size]) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(axis=-1, keepdim=True) + sorted_prob, indices = paddle.sort( + x=normalized_prob, descending=False + ), paddle.argsort(x=normalized_prob, descending=False) + cdf = paddle.cumsum(x=sorted_prob, axis=-1) + mask = paddle.zeros(shape=[batch_size, vocab_size], dtype="int32") + mask.put_along_axis_( + axis=1, + indices=indices, + values=(cdf >= 1 - p).astype(dtype="int32"), + reduce="add", + ) renorm_prob_ground_truth = normalized_prob.clone() renorm_prob_ground_truth[mask == 0] = 0 renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( - dim=-1, keepdim=True + axis=-1, keepdim=True ) - renorm_prob = flashinfer.sampling.top_p_renorm_probs(normalized_prob, p) - torch.testing.assert_close( - renorm_prob_ground_truth, - renorm_prob, - rtol=1e-3, - atol=1e-3, - ) + assert paddle.allclose( + x=renorm_prob_ground_truth, y=renorm_prob, rtol=0.001, atol=0.001 + ).item(), "" @pytest.mark.parametrize("batch_size", [1, 99, 989]) @@ -441,26 +453,24 @@ def test_top_p_renorm_probs(batch_size, vocab_size, p): def test_top_k_renorm_probs(batch_size, 
vocab_size, k): if k > vocab_size: pytest.skip("k should be less than vocab_size") - torch.manual_seed(42) - pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0") - normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) - sorted_prob, _ = torch.sort(normalized_prob, descending=True) + paddle.seed(seed=42) + pre_norm_prob = paddle.rand(shape=[batch_size, vocab_size]) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(axis=-1, keepdim=True) + sorted_prob, _ = paddle.sort(x=normalized_prob, descending=True), paddle.argsort( + x=normalized_prob, descending=True + ) pivot = sorted_prob[:, k - 1] - mask = (normalized_prob >= pivot.unsqueeze(-1)).int() + mask = (normalized_prob >= pivot.unsqueeze(axis=-1)).astype(dtype="int32") renorm_prob_ground_truth = normalized_prob.clone() renorm_prob_ground_truth[mask == 0] = 0 renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( - dim=-1, keepdim=True + axis=-1, keepdim=True ) - renorm_prob = flashinfer.sampling.top_k_renorm_probs(normalized_prob, k) for i in range(batch_size): - torch.testing.assert_close( - renorm_prob_ground_truth[i], - renorm_prob[i], - rtol=1e-3, - atol=1e-3, - ) + assert paddle.allclose( + x=renorm_prob_ground_truth[i], y=renorm_prob[i], rtol=0.001, atol=0.001 + ).item(), "" @pytest.mark.parametrize("batch_size", [1, 99, 989]) @@ -470,23 +480,21 @@ def test_top_k_renorm_probs(batch_size, vocab_size, k): def test_top_k_mask_logits(batch_size, vocab_size, k, neginf_input): if k > vocab_size: pytest.skip("k should be less than vocab_size") - torch.manual_seed(42) - logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5 + paddle.seed(seed=42) + logits = paddle.randn(shape=[batch_size, vocab_size]) * 5 if neginf_input: - num_neginf = torch.randint(1, vocab_size * batch_size, (1,)).item() - idxs = torch.randperm(batch_size * vocab_size, device="cuda:0")[:num_neginf] + num_neginf = paddle.randint( + low=1, high=vocab_size * batch_size, shape=(1,) + ).item() + idxs = paddle.randperm(n=batch_size * vocab_size)[:num_neginf] logits[idxs // vocab_size, idxs % vocab_size] = -float("inf") - probs = torch.softmax(logits, dim=-1) + probs = paddle.nn.functional.softmax(x=logits, axis=-1) masked_logits = flashinfer.sampling.top_k_mask_logits(logits, k) - renormed_probs = torch.softmax(masked_logits, dim=-1) + renormed_probs = paddle.nn.functional.softmax(x=masked_logits, axis=-1) renormed_probs_ref = flashinfer.sampling.top_k_renorm_prob(probs, k) - - torch.testing.assert_close( - renormed_probs, - renormed_probs_ref, - rtol=1e-3, - atol=1e-3, - ) + assert paddle.allclose( + x=renormed_probs, y=renormed_probs_ref, rtol=0.001, atol=0.001 + ).item(), "" @pytest.mark.parametrize("batch_size", [1, 99, 989]) @@ -494,41 +502,41 @@ def test_top_k_mask_logits(batch_size, vocab_size, k, neginf_input): @pytest.mark.parametrize("num_speculate_tokens", [1, 3, 5, 7]) @pytest.mark.parametrize("onehot_target", [False, True]) def test_chain_speculative_sampling( - batch_size, - vocab_size, - num_speculate_tokens, - onehot_target, + batch_size, vocab_size, num_speculate_tokens, onehot_target ): - pre_norm_draft_prob = torch.rand( - batch_size, num_speculate_tokens, vocab_size, device="cuda:0" + pre_norm_draft_prob = paddle.rand( + shape=[batch_size, num_speculate_tokens, vocab_size] ) normalized_draft_prob = pre_norm_draft_prob / pre_norm_draft_prob.sum( - dim=-1, keepdim=True + axis=-1, keepdim=True ) - draft_token_ids = torch.randint( - vocab_size, (batch_size, num_speculate_tokens), 
device="cuda:0" + draft_token_ids = paddle.randint( + low=0, high=vocab_size, shape=(batch_size, num_speculate_tokens) ) if not onehot_target: - pre_norm_target_prob = torch.rand( - batch_size, num_speculate_tokens + 1, vocab_size, device="cuda:0" + pre_norm_target_prob = paddle.rand( + shape=[batch_size, num_speculate_tokens + 1, vocab_size] ) target_onehot_prob = pre_norm_target_prob / pre_norm_target_prob.sum( - dim=-1, keepdim=True + axis=-1, keepdim=True ) else: - target_token_ids = torch.randint( - vocab_size, (batch_size, num_speculate_tokens + 1), device="cuda:0" + target_token_ids = paddle.randint( + low=0, high=vocab_size, shape=(batch_size, num_speculate_tokens + 1) ) target_token_ids[..., :num_speculate_tokens] = draft_token_ids - target_onehot_prob = torch.zeros( - (batch_size, num_speculate_tokens + 1, vocab_size), device="cuda:0" + target_onehot_prob = paddle.zeros( + shape=(batch_size, num_speculate_tokens + 1, vocab_size) ) - target_onehot_prob.scatter_(2, target_token_ids.unsqueeze(-1), 1) - - # NOTE(Zihao): this is a very simple test that only checks whether output is valid or not. - for trials in range(10): # noqa: B007 - accepted_num = torch.zeros(batch_size, dtype=torch.int32, device="cuda:0") - emitted_num = torch.zeros(batch_size, dtype=torch.int32, device="cuda:0") + target_onehot_prob.put_along_axis_( + axis=2, + indices=target_token_ids.unsqueeze(axis=-1), + values=1, + broadcast=False, + ) + for trials in range(10): + accepted_num = paddle.zeros(shape=batch_size, dtype="int32") + emitted_num = paddle.zeros(shape=batch_size, dtype="int32") ( output_token_ids, accepted_num, @@ -541,33 +549,26 @@ def test_chain_speculative_sampling( emitted_num, ) if onehot_target: - assert torch.all(output_token_ids == target_token_ids) + assert paddle.all(x=output_token_ids == target_token_ids) else: - assert torch.all(output_token_ids[output_token_ids >= 0] < vocab_size) - assert output_token_ids.shape == (batch_size, num_speculate_tokens + 1) + assert paddle.all(x=output_token_ids[output_token_ids >= 0] < vocab_size) + assert tuple(output_token_ids.shape) == ( + batch_size, + num_speculate_tokens + 1, + ) matches = output_token_ids[..., :-1] != draft_token_ids for row in range(batch_size): - mismatch_idx = torch.nonzero(matches[row], as_tuple=True)[0] + paddle.utils.try_import("warnings").warn( + "Now, the return shape is inconsistent with torch when as_tuple is True" + ) + mismatch_idx = paddle.nonzero(x=matches[row], as_tuple=True)[0] if len(mismatch_idx) > 0: - # mismatch_idx should be contiguous - assert torch.all(mismatch_idx[1:] == mismatch_idx[:-1] + 1) - # from the second mismatched token on, the output tokens should be -1 - assert torch.all(output_token_ids[row, mismatch_idx[0] + 1 :] == -1) - - assert torch.all(emitted_num + 1 == (output_token_ids != -1).sum(dim=1)) + assert paddle.all(x=mismatch_idx[1:] == mismatch_idx[:-1] + 1) + assert paddle.all( + x=output_token_ids[row, mismatch_idx[0] + 1 :] == -1 + ) + assert paddle.all(x=emitted_num + 1 == (output_token_ids != -1).sum(axis=1)) if __name__ == "__main__": - # test_sampling_freq(128256, gumbel_distribution(0.1), 0.5) test_sampling_from_logits_freq(128256, gumbel_distribution(0.1)) - # test_top_p_sampling_freq(128256, gumbel_distribution(0.1), 0.5) - # test_top_k_sampling_freq(1, 128256, 10) - # test_sampling(19, 500) - # test_sampling(1, 111) - # test_top_p_sampling(3, 111, 0.9) - # test_top_k_sampling(3, 111, 10) - # test_top_p_renorm_probs(3, 111, 0.9) - # test_top_k_renorm_probs(3, 111, 10) - # 
test_top_k_mask_logits(99, 989, 10) - # test_chain_speculative_sampling(3, 111, 3, False) - # test_chain_speculative_sampling(3, 111, 3, True) diff --git a/tests/test_shared_prefix_kernels.py b/tests/test_shared_prefix_kernels.py index 414aaa875b..f1d0a212ee 100644 --- a/tests/test_shared_prefix_kernels.py +++ b/tests/test_shared_prefix_kernels.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2023 by FlashInfer team. @@ -13,10 +15,9 @@ See the License for the specific language governing permissions and limitations under the License. """ - import pytest -import torch -from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules +from jit_utils import (gen_decode_attention_modules, + gen_prefill_attention_modules) import flashinfer @@ -25,21 +26,10 @@ def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + ["float16"], ["float16"], [128, 256], [0], [False], [False] ) + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [128, 256], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions + ["float16"], ["float16"], [128, 256], [0], [False], [False], [False] ), verbose=False, ) @@ -73,35 +63,71 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache( assert shared_kv_len % page_size == 0 kv_layout = "NHD" if stage == "append": - q = torch.randn(batch_size * unique_kv_len, num_heads, head_dim).to(0).half() - q_indptr = torch.arange(0, batch_size + 1).to(0).int() * unique_kv_len + q = ( + paddle.randn(shape=[batch_size * unique_kv_len, num_heads, head_dim]) + .to(0) + .astype(dtype="float16") + ) + q_indptr = ( + paddle.arange(start=0, end=batch_size + 1).to(0).astype(dtype="int32") + * unique_kv_len + ) else: - q = torch.randn(batch_size, num_heads, head_dim).to(0).half() - q_indptr = torch.arange(0, batch_size + 1).to(0).int() - k_shared = torch.randn(shared_kv_len, num_heads, head_dim).to(0).half() - v_shared = torch.randn(shared_kv_len, num_heads, head_dim).to(0).half() - k_unique = torch.randn(batch_size * unique_kv_len, num_heads, head_dim).to(0).half() - v_unique = torch.randn(batch_size * unique_kv_len, num_heads, head_dim).to(0).half() - + q = ( + paddle.randn(shape=[batch_size, num_heads, head_dim]) + .to(0) + .astype(dtype="float16") + ) + q_indptr = ( + paddle.arange(start=0, end=batch_size + 1).to(0).astype(dtype="int32") + ) + k_shared = ( + paddle.randn(shape=[shared_kv_len, num_heads, head_dim]) + .to(0) + .astype(dtype="float16") + ) + v_shared = ( + paddle.randn(shape=[shared_kv_len, num_heads, head_dim]) + .to(0) + .astype(dtype="float16") + ) + k_unique = ( + paddle.randn(shape=[batch_size * unique_kv_len, num_heads, head_dim]) + .to(0) + .astype(dtype="float16") + ) + v_unique = ( + paddle.randn(shape=[batch_size * unique_kv_len, num_heads, head_dim]) + .to(0) + .astype(dtype="float16") + ) kv_data = ( - torch.zeros( - ceil_div(shared_kv_len, page_size) - + batch_size * ceil_div(unique_kv_len, page_size), - 2, - page_size, - num_heads, - head_dim, + paddle.zeros( + shape=[ + ceil_div(shared_kv_len, page_size) + + batch_size * ceil_div(unique_kv_len, page_size), + 2, + page_size, + num_heads, + head_dim, + ] ) .to(0) - .half() + .astype(dtype="float16") ) - shared_kv_indices = torch.arange(0, 
ceil_div(shared_kv_len, page_size)).to(0).int() - shared_append_indptr = torch.arange(0, 2).to(0).int() * shared_kv_len - shared_kv_indptr = torch.arange(0, 2).to(0).int() * ceil_div( - shared_kv_len, page_size + shared_kv_indices = ( + paddle.arange(start=0, end=ceil_div(shared_kv_len, page_size)) + .to(0) + .astype(dtype="int32") + ) + shared_append_indptr = ( + paddle.arange(start=0, end=2).to(0).astype(dtype="int32") * shared_kv_len ) - shared_last_page_len = torch.full( - (1,), (shared_kv_len - 1) % page_size + 1, dtype=torch.int32 + shared_kv_indptr = paddle.arange(start=0, end=2).to(0).astype( + dtype="int32" + ) * ceil_div(shared_kv_len, page_size) + shared_last_page_len = paddle.full( + shape=(1,), fill_value=(shared_kv_len - 1) % page_size + 1, dtype="int32" ).to(0) flashinfer.append_paged_kv_cache( k_shared, @@ -109,23 +135,28 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache( *flashinfer.get_batch_indices_positions( shared_append_indptr, flashinfer.get_seq_lens(shared_kv_indptr, shared_last_page_len, page_size), - k_shared.shape[0], + tuple(k_shared.shape)[0], ), kv_data, shared_kv_indices, shared_kv_indptr, shared_last_page_len, - kv_layout, + kv_layout ) - unique_kv_indices = torch.arange( - 0, batch_size * ceil_div(unique_kv_len, page_size) - ).to(0).int() + ceil_div(shared_kv_len, page_size) - unique_append_indptr = torch.arange(0, batch_size + 1).to(0).int() * unique_kv_len - unique_kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * ceil_div( - unique_kv_len, page_size + unique_kv_indices = paddle.arange( + start=0, end=batch_size * ceil_div(unique_kv_len, page_size) + ).to(0).astype(dtype="int32") + ceil_div(shared_kv_len, page_size) + unique_append_indptr = ( + paddle.arange(start=0, end=batch_size + 1).to(0).astype(dtype="int32") + * unique_kv_len ) - unique_last_page_len = torch.full( - (batch_size,), (unique_kv_len - 1) % page_size + 1, dtype=torch.int32 + unique_kv_indptr = paddle.arange(start=0, end=batch_size + 1).to(0).astype( + dtype="int32" + ) * ceil_div(unique_kv_len, page_size) + unique_last_page_len = paddle.full( + shape=(batch_size,), + fill_value=(unique_kv_len - 1) % page_size + 1, + dtype="int32", ).to(0) flashinfer.append_paged_kv_cache( k_unique, @@ -133,37 +164,37 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache( *flashinfer.get_batch_indices_positions( unique_append_indptr, flashinfer.get_seq_lens(unique_kv_indptr, unique_last_page_len, page_size), - k_unique.shape[0], + tuple(k_unique.shape)[0], ), kv_data, unique_kv_indices, unique_kv_indptr, unique_last_page_len, - kv_layout, + kv_layout ) - if stage == "decode": multi_level_wrapper = flashinfer.MultiLevelCascadeAttentionWrapper( - 2, torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout + 2, paddle.empty(shape=32 * 1024 * 1024, dtype="int8").to(0), kv_layout ) shared_prefix_decode_wrapper = ( flashinfer.BatchDecodeWithSharedPrefixPagedKVCacheWrapper( - torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout + paddle.empty(shape=32 * 1024 * 1024, dtype="int8").to(0), kv_layout ) ) else: multi_level_wrapper = flashinfer.MultiLevelCascadeAttentionWrapper( - 2, torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout + 2, paddle.empty(shape=32 * 1024 * 1024, dtype="int8").to(0), kv_layout ) shared_prefix_prefill_wrapper = ( flashinfer.BatchPrefillWithSharedPrefixPagedKVCacheWrapper( - torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout + paddle.empty(shape=32 * 1024 * 1024, dtype="int8").to(0), kv_layout ) ) - - qo_indptr_top = 
torch.tensor([0, q.shape[0]], dtype=torch.int32).to(0) + qo_indptr_top = paddle.to_tensor(data=[0, tuple(q.shape)[0]], dtype="int32").to(0) if stage == "decode": - qo_indptr_bottom = torch.arange(0, batch_size + 1, dtype=torch.int32).to(0) + qo_indptr_bottom = paddle.arange(start=0, end=batch_size + 1, dtype="int32").to( + 0 + ) multi_level_wrapper.plan( [qo_indptr_top, qo_indptr_bottom], [shared_kv_indptr, unique_kv_indptr], @@ -177,7 +208,8 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache( o_multi_level = multi_level_wrapper.run(q, kv_data) else: qo_indptr_bottom = ( - torch.arange(0, batch_size + 1, dtype=torch.int32).to(0) * unique_kv_len + paddle.arange(start=0, end=batch_size + 1, dtype="int32").to(0) + * unique_kv_len ) multi_level_wrapper.plan( [qo_indptr_top, qo_indptr_bottom], @@ -191,7 +223,6 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache( causal=causal, ) o_multi_level = multi_level_wrapper.run(q, kv_data) - if stage == "decode": shared_prefix_decode_wrapper.begin_forward( unique_kv_indptr, @@ -219,8 +250,9 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache( o_two_level = shared_prefix_prefill_wrapper.forward( q, k_shared, v_shared, kv_data, causal=causal ) - - torch.testing.assert_close(o_multi_level, o_two_level, rtol=1e-3, atol=1e-3) + assert paddle.allclose( + x=o_multi_level, y=o_two_level, rtol=0.001, atol=0.001 + ).item(), "" @pytest.mark.parametrize("seed", [0]) @@ -229,48 +261,51 @@ def test_merge_state_in_place_with_mask(seed, num_tries): seq_len = 512 num_heads = 32 head_dim = 128 - va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0") - sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") - vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0") - sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") + va = ( + paddle.randn(shape=[seq_len, num_heads, head_dim]) + .astype(dtype="float16") + .to("gpu:0") + ) + sa = paddle.randn(shape=[seq_len, num_heads], dtype="float32").to("gpu:0") + vb = ( + paddle.randn(shape=[seq_len, num_heads, head_dim]) + .astype(dtype="float16") + .to("gpu:0") + ) + sb = paddle.randn(shape=[seq_len, num_heads], dtype="float32").to("gpu:0") va_orginal = va.clone() sa_original = sa.clone() - - # No mask. flashinfer.merge_state_in_place(va, sa, vb, sb) va_merged_ref = va.clone() sa_merged_ref = sa.clone() - assert not torch.allclose(va_merged_ref, va_orginal) - assert not torch.allclose(sa_merged_ref, sa_original) - - # Mask with all 1s. Should be identical to no mask. - mask = torch.ones(seq_len, dtype=torch.bool).to("cuda:0") + assert not paddle.allclose(x=va_merged_ref, y=va_orginal).item() + assert not paddle.allclose(x=sa_merged_ref, y=sa_original).item() + mask = paddle.ones(shape=seq_len, dtype="bool").to("gpu:0") va = va_orginal.clone() sa = sa_original.clone() flashinfer.merge_state_in_place(va, sa, vb, sb, mask=mask) va_merged = va sa_merged = sa - torch.testing.assert_close(va_merged, va_merged_ref, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(sa_merged, sa_merged_ref, rtol=1e-3, atol=1e-3) - - # Mask with all zeros. Input and output should be identical. 
- mask = torch.zeros(seq_len, dtype=torch.bool).to("cuda:0") + assert paddle.allclose( + x=va_merged, y=va_merged_ref, rtol=0.001, atol=0.001 + ).item(), "" + assert paddle.allclose( + x=sa_merged, y=sa_merged_ref, rtol=0.001, atol=0.001 + ).item(), "" + mask = paddle.zeros(shape=seq_len, dtype="bool").to("gpu:0") va = va_orginal.clone() sa = sa_original.clone() flashinfer.merge_state_in_place(va, sa, vb, sb, mask=mask) va_merged = va sa_merged = sa - torch.testing.assert_close(va_merged, va_orginal, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(sa_merged, sa_original, rtol=1e-3, atol=1e-3) - - # Test some random masks. - randgen = torch.Generator(device="cuda:0") + assert paddle.allclose(x=va_merged, y=va_orginal, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose( + x=sa_merged, y=sa_original, rtol=0.001, atol=0.001 + ).item(), "" + randgen = paddle.framework.core.default_cpu_generator() randgen.manual_seed(seed) for _ in range(num_tries): - rand_mask = ( - torch.rand(seq_len, generator=randgen, dtype=torch.float32, device="cuda:0") - > 0.5 - ).to(dtype=torch.bool) + rand_mask = (paddle.rand(shape=seq_len, dtype="float32") > 0.5).to(dtype="bool") true_indices = rand_mask.nonzero() false_indices = (rand_mask == 0).nonzero() va = va_orginal.clone() @@ -278,31 +313,30 @@ def test_merge_state_in_place_with_mask(seed, num_tries): flashinfer.merge_state_in_place(va, sa, vb, sb, mask=rand_mask) va_merged = va sa_merged = sa - - torch.testing.assert_close( - va_merged[false_indices], - va_orginal[false_indices], - rtol=1e-3, - atol=1e-3, - ) - torch.testing.assert_close( - sa_merged[false_indices], - sa_original[false_indices], - rtol=1e-3, - atol=1e-3, - ) - torch.testing.assert_close( - va_merged[true_indices], - va_merged_ref[true_indices], - rtol=1e-3, - atol=1e-3, - ) - torch.testing.assert_close( - sa_merged[true_indices], - sa_merged_ref[true_indices], - rtol=1e-3, - atol=1e-3, - ) + assert paddle.allclose( + x=va_merged[false_indices], + y=va_orginal[false_indices], + rtol=0.001, + atol=0.001, + ).item(), "" + assert paddle.allclose( + x=sa_merged[false_indices], + y=sa_original[false_indices], + rtol=0.001, + atol=0.001, + ).item(), "" + assert paddle.allclose( + x=va_merged[true_indices], + y=va_merged_ref[true_indices], + rtol=0.001, + atol=0.001, + ).item(), "" + assert paddle.allclose( + x=sa_merged[true_indices], + y=sa_merged_ref[true_indices], + rtol=0.001, + atol=0.001, + ).item(), "" if __name__ == "__main__": diff --git a/tests/test_single_prefill.py b/tests/test_single_prefill.py index d08a63d3c1..91c1c69658 100644 --- a/tests/test_single_prefill.py +++ b/tests/test_single_prefill.py @@ -1,56 +1,46 @@ import math +import paddle import pytest -import torch import flashinfer def build_causal_mask(qo_len, kv_len): - i = torch.arange(qo_len).unsqueeze(1).to("cuda:0") - j = torch.arange(kv_len).unsqueeze(0).to("cuda:0") + i = paddle.arange(end=qo_len).unsqueeze(axis=1).to("gpu:0") + j = paddle.arange(end=kv_len).unsqueeze(axis=0).to("gpu:0") offset = kv_len - qo_len - - mask = (j - offset > i).to(torch.bool) + mask = (j - offset > i).to("bool") return mask -def _repeat_kv(t: torch.Tensor, num_groups: int) -> torch.Tensor: - return t.repeat_interleave(num_groups, dim=1) +def _repeat_kv(t: paddle.Tensor, num_groups: int) -> paddle.Tensor: + return t.repeat_interleave(repeats=num_groups, axis=1) def single_prefill_with_kv_cache_ref( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - causal: bool = False, + q: paddle.Tensor, k: paddle.Tensor, v: paddle.Tensor, 
causal: bool = False ): - Lq, Hq, D = q.shape - Lk, Hkv, _ = k.shape - assert (Lk, Hkv, D) == v.shape + Lq, Hq, D = tuple(q.shape) + Lk, Hkv, _ = tuple(k.shape) + assert (Lk, Hkv, D) == tuple(v.shape) assert Hq % Hkv == 0 - groups = Hq // Hkv k_states = _repeat_kv(k, groups) v_states = _repeat_kv(v, groups) - - q_t = q.permute(1, 0, 2) # (Hq, Lq, D) - k_t = k_states.permute(1, 2, 0) # (Hq, D, Lk) - v_t = v_states.permute(1, 0, 2) # (Hq, Lk, D) - + q_t = q.transpose(perm=[1, 0, 2]) + k_t = k_states.transpose(perm=[1, 2, 0]) + v_t = v_states.transpose(perm=[1, 0, 2]) scale = 1.0 / math.sqrt(D) - attn_scores = torch.bmm(q_t, k_t) * scale # (Hq, Lq, Lk) - + attn_scores = paddle.bmm(x=q_t, y=k_t) * scale if causal: - # apply causal mask causal_mask = build_causal_mask(Lq, Lk) - attn_scores = attn_scores.masked_fill(causal_mask, float("-inf")) - - attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype) - - attn_output = torch.bmm(attn_weights, v_t) # (Hq, Lq, D) - attn_output = attn_output.permute(1, 0, 2).contiguous() # (Lq, Hq, D) - + attn_scores = attn_scores.masked_fill(mask=causal_mask, value=float("-inf")) + attn_weights = paddle.nn.functional.softmax( + x=attn_scores, axis=-1, dtype="float32" + ).to(q.dtype) + attn_output = paddle.bmm(x=attn_weights, y=v_t) + attn_output = attn_output.transpose(perm=[1, 0, 2]).contiguous() return attn_output @@ -62,42 +52,17 @@ def single_prefill_with_kv_cache_ref( @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE"]) def test_sinqle_prefill_with_paged_kv_cache( - kv_len, - qo_len, - num_kv_heads, - num_qo_heads, - head_dim, - causal, - pos_encoding_mode, + kv_len, qo_len, num_kv_heads, num_qo_heads, head_dim, causal, pos_encoding_mode ): - torch.manual_seed(0) - torch.cuda.manual_seed(0) + paddle.seed(seed=0) + paddle.seed(seed=0) if qo_len > kv_len and causal: pytest.skip("qo_len > kv_len and causal is not supported") - q = torch.randn( - qo_len, - num_qo_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, - ) - k = torch.randn( - kv_len, - num_kv_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, - ) - v = torch.randn( - kv_len, - num_kv_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, - ) + q = paddle.randn(shape=[qo_len, num_qo_heads, head_dim], dtype="float16") + k = paddle.randn(shape=[kv_len, num_kv_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[kv_len, num_kv_heads, head_dim], dtype="float16") o = flashinfer.prefill.single_prefill_with_kv_cache( q, k, v, causal=causal, pos_encoding_mode=pos_encoding_mode, backend="fa2" ) - o_ref = single_prefill_with_kv_cache_ref(q, k, v, causal=causal) - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o, y=o_ref, rtol=0.001, atol=0.001).item(), "" diff --git a/tests/test_sliding_window.py b/tests/test_sliding_window.py index f4b7f31e68..88bea2f31e 100644 --- a/tests/test_sliding_window.py +++ b/tests/test_sliding_window.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,10 +19,9 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch -from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules +from jit_utils import (gen_decode_attention_modules, + gen_prefill_attention_modules) import flashinfer @@ -25,21 +30,16 @@ def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False, True], # use_sliding_windows - [False], # use_logits_soft_caps + ["float16"], ["float16"], [64, 128, 256], [0], [False, True], [False] ) + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0], # pos_encoding_modes - [False, True], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions + ["float16"], + ["float16"], + [64, 128, 256], + [0], + [False, True], + [False], + [False], ), verbose=False, ) @@ -54,21 +54,14 @@ def warmup_jit(): def test_single_decode_sliding_window( seq_len, window_left, num_kv_heads, num_qo_heads, head_dim ): - q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0") - k = torch.randn( - seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" - ) - v = torch.randn( - seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" - ) - + q = paddle.randn(shape=[num_qo_heads, head_dim], dtype="float16") + k = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") k_sliced = k[-(window_left + 1) :] v_sliced = v[-(window_left + 1) :] - o_ref = flashinfer.single_decode_with_kv_cache(q, k_sliced, v_sliced) o = flashinfer.single_decode_with_kv_cache(q, k, v, window_left=window_left) - - torch.testing.assert_close(o.cpu(), o_ref.cpu(), rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o.cpu(), y=o_ref.cpu(), rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [1, 3, 13, 32]) @@ -81,37 +74,23 @@ def test_single_decode_sliding_window( def test_batch_decode_sliding_window( batch_size, kv_len, window_left, num_kv_heads, num_qo_heads, head_dim, page_size ): - q = torch.randn( - batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0" - ) + q = paddle.randn(shape=[batch_size, num_qo_heads, head_dim], dtype="float16") num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size - k_data = torch.randn( - total_num_pages, - page_size, - num_kv_heads, - head_dim, - dtype=torch.float16, - device="cuda:0", + k_data = paddle.randn( + shape=[total_num_pages, page_size, num_kv_heads, head_dim], dtype="float16" ) - v_data = torch.randn( - total_num_pages, - page_size, - num_kv_heads, - head_dim, - dtype=torch.float16, - device="cuda:0", + v_data = paddle.randn( + shape=[total_num_pages, page_size, num_kv_heads, head_dim], dtype="float16" ) kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) - * num_pages_per_seq + paddle.arange(start=0, end=batch_size + 1, dtype="int32") * num_pages_per_seq ) - kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) - kv_last_page_len = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" + kv_indices = paddle.arange(start=0, end=total_num_pages, dtype="int32") + kv_last_page_len = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" ) - - workspace_buffer 
= torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + workspace_buffer = paddle.empty(shape=32 * 1024 * 1024, dtype="int8") wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD") wrapper.plan( kv_indptr, @@ -124,34 +103,30 @@ def test_batch_decode_sliding_window( window_left=window_left, ) o = wrapper.run(q, (k_data, v_data)) - for i in range(batch_size): qi = q[i] - ki = torch.cat( - [ + ki = paddle.concat( + x=[ k_data[kv_indptr[i] : kv_indptr[i + 1] - 1].reshape( -1, num_kv_heads, head_dim ), k_data[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :], ], - dim=0, + axis=0, ) - vi = torch.cat( - [ + vi = paddle.concat( + x=[ v_data[kv_indptr[i] : kv_indptr[i + 1] - 1].reshape( -1, num_kv_heads, head_dim ), v_data[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :], ], - dim=0, + axis=0, ) o_ref_i = flashinfer.single_decode_with_kv_cache( - qi, - ki, - vi, - window_left=window_left, + qi, ki, vi, window_left=window_left ) - torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o[i], y=o_ref_i, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("seq_len", [1, 3, 19, 99, 199, 1999]) @@ -162,20 +137,18 @@ def test_batch_decode_sliding_window( def test_single_decode_prefill_sliding_window_match( seq_len, window_left, num_kv_heads, num_qo_heads, head_dim ): - q = torch.randn(1, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0") - k = torch.randn( - seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" - ) - v = torch.randn( - seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" - ) + q = paddle.randn(shape=[1, num_qo_heads, head_dim], dtype="float16") + k = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") o = flashinfer.single_prefill_with_kv_cache( q, k, v, window_left=window_left, causal=True ) o_decoded = flashinfer.single_decode_with_kv_cache( q[0], k, v, window_left=window_left ) - torch.testing.assert_close(o.cpu()[0], o_decoded.cpu(), rtol=1e-3, atol=1e-3) + assert paddle.allclose( + x=o.cpu()[0], y=o_decoded.cpu(), rtol=0.001, atol=0.001 + ).item(), "" @pytest.mark.parametrize("seq_len", [99, 199, 1999]) @@ -186,25 +159,17 @@ def test_single_decode_prefill_sliding_window_match( def test_single_prefill_sliding_window( seq_len, window_left, num_kv_heads, num_qo_heads, head_dim ): - q = torch.randn( - seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0" - ) - k = torch.randn( - seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" - ) - v = torch.randn( - seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" - ) - - row_idx = torch.arange(seq_len, dtype=torch.int32, device="cuda:0")[:, None] - col_idx = torch.arange(seq_len, dtype=torch.int32, device="cuda:0")[None, :] + q = paddle.randn(shape=[seq_len, num_qo_heads, head_dim], dtype="float16") + k = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") + v = paddle.randn(shape=[seq_len, num_kv_heads, head_dim], dtype="float16") + row_idx = paddle.arange(dtype="int32", end=seq_len)[:, None] + col_idx = paddle.arange(dtype="int32", end=seq_len)[None, :] mask = (row_idx >= col_idx) & (row_idx - window_left <= col_idx) - o_ref = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask) o = flashinfer.single_prefill_with_kv_cache( q, k, v, window_left=window_left, causal=True ) - torch.testing.assert_close(o.cpu(), o_ref.cpu(), rtol=1e-3, atol=1e-3) + 
assert paddle.allclose(x=o.cpu(), y=o_ref.cpu(), rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [12, 17]) @@ -225,44 +190,26 @@ def test_batch_paged_prefill_sliding_window( head_dim, page_size, ): - q = torch.randn( - batch_size * qo_len, - num_qo_heads, - head_dim, - dtype=torch.float16, - device="cuda:0", - ) - q_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len + q = paddle.randn( + shape=[batch_size * qo_len, num_qo_heads, head_dim], dtype="float16" ) + q_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * qo_len num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size - k_data = torch.randn( - total_num_pages, - page_size, - num_kv_heads, - head_dim, - dtype=torch.float16, - device="cuda:0", + k_data = paddle.randn( + shape=[total_num_pages, page_size, num_kv_heads, head_dim], dtype="float16" ) - v_data = torch.randn( - total_num_pages, - page_size, - num_kv_heads, - head_dim, - dtype=torch.float16, - device="cuda:0", + v_data = paddle.randn( + shape=[total_num_pages, page_size, num_kv_heads, head_dim], dtype="float16" ) kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) - * num_pages_per_seq + paddle.arange(start=0, end=batch_size + 1, dtype="int32") * num_pages_per_seq ) - kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) - kv_last_page_len = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" + kv_indices = paddle.arange(start=0, end=total_num_pages, dtype="int32") + kv_last_page_len = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" ) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") wrapper.plan( q_indptr, @@ -276,36 +223,32 @@ def test_batch_paged_prefill_sliding_window( window_left=window_left, causal=True, ) - o = wrapper.run( - q, - (k_data, v_data), - ) - + o = wrapper.run(q, (k_data, v_data)) for i in range(batch_size): qi = q[q_indptr[i] : q_indptr[i + 1]] - ki = torch.cat( - [ + ki = paddle.concat( + x=[ k_data[kv_indptr[i] : kv_indptr[i + 1] - 1].reshape( -1, num_kv_heads, head_dim ), k_data[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :], ], - dim=0, + axis=0, ) - vi = torch.cat( - [ + vi = paddle.concat( + x=[ v_data[kv_indptr[i] : kv_indptr[i + 1] - 1].reshape( -1, num_kv_heads, head_dim ), v_data[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :], ], - dim=0, + axis=0, ) o_ref_i = flashinfer.single_prefill_with_kv_cache( qi, ki, vi, window_left=window_left, causal=True, backend="fa2" ) o_i = o[q_indptr[i] : q_indptr[i + 1]] - torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o_i, y=o_ref_i, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [12, 17]) @@ -318,34 +261,18 @@ def test_batch_paged_prefill_sliding_window( def test_batch_ragged_prefill_sliding_window( batch_size, kv_len, qo_len, window_left, num_kv_heads, num_qo_heads, head_dim ): - q = torch.randn( - batch_size * qo_len, - num_qo_heads, - head_dim, - dtype=torch.float16, - device="cuda:0", - ) - q_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len - ) - k = torch.randn( - batch_size * kv_len, - num_kv_heads, - head_dim, - 
dtype=torch.float16, - device="cuda:0", + q = paddle.randn( + shape=[batch_size * qo_len, num_qo_heads, head_dim], dtype="float16" ) - v = torch.randn( - batch_size * kv_len, - num_kv_heads, - head_dim, - dtype=torch.float16, - device="cuda:0", + q_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * qo_len + k = paddle.randn( + shape=[batch_size * kv_len, num_kv_heads, head_dim], dtype="float16" ) - kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * kv_len + v = paddle.randn( + shape=[batch_size * kv_len, num_kv_heads, head_dim], dtype="float16" ) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + kv_indptr = paddle.arange(start=0, end=batch_size + 1, dtype="int32") * kv_len + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD") wrapper.plan( q_indptr, @@ -357,20 +284,15 @@ def test_batch_ragged_prefill_sliding_window( causal=True, ) o = wrapper.run(q, k, v) - for i in range(batch_size): qi = q[q_indptr[i] : q_indptr[i + 1]] ki = k[kv_indptr[i] : kv_indptr[i + 1]] vi = v[kv_indptr[i] : kv_indptr[i + 1]] o_ref_i = flashinfer.single_prefill_with_kv_cache( - qi, - ki, - vi, - window_left=window_left, - causal=True, + qi, ki, vi, window_left=window_left, causal=True ) o_i = o[q_indptr[i] : q_indptr[i + 1]] - torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o_i, y=o_ref_i, rtol=0.001, atol=0.001).item(), "" if __name__ == "__main__": diff --git a/tests/test_sm_constraint_gemm.py b/tests/test_sm_constraint_gemm.py index ee131219a9..511c167130 100644 --- a/tests/test_sm_constraint_gemm.py +++ b/tests/test_sm_constraint_gemm.py @@ -1,19 +1,18 @@ +import paddle import pytest -import torch import flashinfer import flashinfer.triton def torch_gemm(a, b, c, alpha, beta): - x = torch.matmul(a, b.T) + x = paddle.matmul(x=a, y=b.T) c = alpha * x + beta * c return c def torch_addmm(a, b, c, alpha=1.0, beta=0.0): - # Transpose b to match torch_gemm's matmul(a, b.T) - C = torch.addmm(c, a, b.T, beta=beta, alpha=alpha) + C = paddle.addmm(input=c, x=a, y=b.T, beta=beta, alpha=alpha) return C @@ -24,37 +23,27 @@ def torch_addmm(a, b, c, alpha=1.0, beta=0.0): @pytest.mark.parametrize("beta", [0.0, 0.5, 2.0]) @pytest.mark.parametrize("num_sms", [1, 16, 64, 128, 132, 133]) @pytest.mark.parametrize( - "dtype", [torch.float8_e4m3fn, torch.float16, torch.bfloat16, torch.float32] + "dtype", [paddle.float8_e4m3fn, "float16", "bfloat16", "float32"] ) -@pytest.mark.parametrize( - "EPILOGUE_SUBTILE", [True, False] -) # only for descriptor persistent +@pytest.mark.parametrize("EPILOGUE_SUBTILE", [True, False]) def test_sm_constraint_gemm(M, N, K, alpha, beta, num_sms, dtype, EPILOGUE_SUBTILE): - out_dtype = dtype if dtype != torch.float8_e4m3fn else torch.bfloat16 - a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) - b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype) + out_dtype = dtype if dtype != paddle.float8_e4m3fn else "bfloat16" + a = paddle.randn(shape=(M, K), dtype="float16").to(dtype) + b = paddle.randn(shape=(K, N), dtype="float16").to(dtype) b = b.T.contiguous() - c = torch.randn((M, N), device="cuda", dtype=out_dtype) + c = paddle.randn(shape=(M, N), dtype=out_dtype) c_unmodified = c.clone() c0 = c.clone() c1 = c.clone() - - # torch gemm c_torch = torch_gemm(a.to(out_dtype), b.to(out_dtype), c.to(out_dtype), alpha, beta) - - # triton gemm: 
persistent c_persistent = flashinfer.triton.sm_constraint_gemm.gemm_persistent( a, b.T, c=c, alpha=alpha, beta=beta, num_sms=num_sms ) - - # triton gemm: naive c_naive = flashinfer.triton.sm_constraint_gemm.gemm( a, b.T, c=c0, alpha=alpha, beta=beta ) - c_descriptor = None - # triton gemm: descriptor persistent - if dtype != torch.float32: + if dtype != "float32": c_descriptor = flashinfer.triton.sm_constraint_gemm.gemm_descriptor_persistent( a, b, @@ -64,89 +53,67 @@ def test_sm_constraint_gemm(M, N, K, alpha, beta, num_sms, dtype, EPILOGUE_SUBTI num_sms=num_sms, EPILOGUE_SUBTILE=EPILOGUE_SUBTILE, ) - - torch_atol = 20.0 if out_dtype == torch.bfloat16 else 1.0 - - in_place_persistent = c_persistent.data_ptr() == c.data_ptr() and torch.allclose( - c_persistent.to(out_dtype), c.to(out_dtype) + torch_atol = 20.0 if out_dtype == "bfloat16" else 1.0 + in_place_persistent = ( + c_persistent.data_ptr() == c.data_ptr() + and paddle.allclose(x=c_persistent.to(out_dtype), y=c.to(out_dtype)).item() ) - assert in_place_persistent # modified in place - - in_place_naive = c_naive.data_ptr() == c0.data_ptr() and torch.allclose( - c_naive.to(out_dtype), c0.to(out_dtype) + assert in_place_persistent + in_place_naive = ( + c_naive.data_ptr() == c0.data_ptr() + and paddle.allclose(x=c_naive.to(out_dtype), y=c0.to(out_dtype)).item() ) - assert in_place_naive # modified in place - + assert in_place_naive if c_descriptor is not None: in_place_descriptor = ( c_descriptor.data_ptr() == c1.data_ptr() - and torch.allclose(c_descriptor.to(out_dtype), c1.to(out_dtype)) + and paddle.allclose(x=c_descriptor.to(out_dtype), y=c1.to(out_dtype)).item() ) - assert in_place_descriptor # modified in place - - # torch results vs triton results - torch_vs_triton_persistent = torch.allclose( - c_torch.to(out_dtype), c_persistent.to(out_dtype), atol=torch_atol - ) + assert in_place_descriptor + torch_vs_triton_persistent = paddle.allclose( + x=c_torch.to(out_dtype), y=c_persistent.to(out_dtype), atol=torch_atol + ).item() if not torch_vs_triton_persistent: print_all_on_failure( a, b, c_unmodified, c_torch, c_naive, c_persistent, c_descriptor, out_dtype ) print("compare c_torch and c_persistent") print_max_diff_on_failure(c_torch, c_persistent, out_dtype) - assert torch_vs_triton_persistent # value is correct - + assert torch_vs_triton_persistent if c_descriptor is not None: - torch_vs_triton_descriptor = torch.allclose( - c_torch.to(out_dtype), c_descriptor.to(out_dtype), atol=torch_atol - ) + torch_vs_triton_descriptor = paddle.allclose( + x=c_torch.to(out_dtype), y=c_descriptor.to(out_dtype), atol=torch_atol + ).item() if not torch_vs_triton_descriptor: print_all_on_failure( - a, - b, - c_unmodified, - c_torch, - c_naive, - c_persistent, - c_descriptor, + a, b, c_unmodified, c_torch, c_naive, c_persistent, c_descriptor ) print("compare c_torch and c_descriptor") print_max_diff_on_failure(c_torch, c_descriptor, out_dtype) - assert torch_vs_triton_descriptor # value is correct - - # triton naive results vs each other - triton_atol = 10.0 if out_dtype == torch.bfloat16 else 1.0 - naive_vs_persistent = torch.allclose( - c_naive.to(out_dtype), c_persistent.to(out_dtype), atol=triton_atol - ) + assert torch_vs_triton_descriptor + triton_atol = 10.0 if out_dtype == "bfloat16" else 1.0 + naive_vs_persistent = paddle.allclose( + x=c_naive.to(out_dtype), y=c_persistent.to(out_dtype), atol=triton_atol + ).item() if not naive_vs_persistent: print_all_on_failure( a, b, c_unmodified, c_torch, c_naive, c_persistent, c_descriptor, out_dtype ) 
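# Hedged sketch (not part of the applied diff): a few lines below, the
# converter leaves `torch.unravel_index(...)` behind a ">>>>>>" marker because
# it found no direct Paddle equivalent. One possible replacement is a small
# row-major unravel helper; `_unravel_index` is a hypothetical name, not a
# FlashInfer or Paddle API.
def _unravel_index(flat_index, shape):
    # Convert a flat index into per-dimension coordinates (row-major order).
    coords = []
    for dim in reversed(tuple(shape)):
        coords.append(flat_index % dim)
        flat_index = flat_index // dim
    return tuple(reversed(coords))
# Possible usage inside print_max_diff_on_failure:
#     max_diff_index = _unravel_index(max_diff_index, tuple(target1.shape))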
print("compare c_naive and c_persistent") print_max_diff_on_failure(c_naive, c_persistent, out_dtype) - - assert naive_vs_persistent # value is correct - + assert naive_vs_persistent if c_descriptor is not None: - descriptor_atol = 10.0 if out_dtype == torch.bfloat16 else 1.0 - naive_vs_descriptor = torch.allclose( - c_naive.to(out_dtype), c_descriptor.to(out_dtype), atol=descriptor_atol - ) + descriptor_atol = 10.0 if out_dtype == "bfloat16" else 1.0 + naive_vs_descriptor = paddle.allclose( + x=c_naive.to(out_dtype), y=c_descriptor.to(out_dtype), atol=descriptor_atol + ).item() if not naive_vs_descriptor: print_all_on_failure( - a, - b, - c_unmodified, - c_torch, - c_naive, - c_persistent, - c_descriptor, + a, b, c_unmodified, c_torch, c_naive, c_persistent, c_descriptor ) print("compare c_naive and c_descriptor") print_max_diff_on_failure(c_naive, c_descriptor, out_dtype) - - assert naive_vs_descriptor # value is correct + assert naive_vs_descriptor def print_all_on_failure( @@ -164,13 +131,13 @@ def print_all_on_failure( def print_max_diff_on_failure(target1, target2, out_dtype): - max_diff = torch.max(torch.abs(target1.to(out_dtype) - target2.to(out_dtype))) + max_diff = paddle.max(x=paddle.abs(x=target1.to(out_dtype) - target2.to(out_dtype))) print(f"max diff: {max_diff}") - max_diff_index = torch.argmax( - torch.abs(target1.to(out_dtype) - target2.to(out_dtype)) + max_diff_index = paddle.argmax( + x=paddle.abs(x=target1.to(out_dtype) - target2.to(out_dtype)) ) print(f"max diff index: {max_diff_index}") if target1.dim() > 1: - max_diff_index = torch.unravel_index(max_diff_index, target1.shape) +>>>>>> max_diff_index = torch.unravel_index(max_diff_index, tuple(target1.shape)) print(f"target1[max_diff_index]: {target1[max_diff_index]}") print(f"target2[max_diff_index]: {target2[max_diff_index]}") diff --git a/tests/test_tensor_cores_decode.py b/tests/test_tensor_cores_decode.py index 49df0a5d94..6585f49192 100644 --- a/tests/test_tensor_cores_decode.py +++ b/tests/test_tensor_cores_decode.py @@ -1,3 +1,5 @@ +import paddle + """ Copyright (c) 2024 by FlashInfer team. @@ -13,10 +15,9 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch -from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules +from jit_utils import (gen_decode_attention_modules, + gen_prefill_attention_modules) import flashinfer @@ -25,21 +26,10 @@ def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps + ["float16"], ["float16"], [64, 128, 256], [0, 1], [False], [False] ) + gen_prefill_attention_modules( - [torch.float16], # q_dtypes - [torch.float16], # kv_dtypes - [64, 128, 256], # head_dims - [0, 1], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_caps - [False], # use_fp16_qk_reductions + ["float16"], ["float16"], [64, 128, 256], [0, 1], [False], [False], [False] ), verbose=False, ) @@ -61,34 +51,24 @@ def test_single_decode_tensor_cores( pos_encoding_mode: str, ): num_qo_heads = num_kv_heads * group_size - q = torch.randn(num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16) + q = paddle.randn(shape=[num_qo_heads, head_dim], dtype="float16") k = ( - torch.randn( - num_kv_heads, kv_len, head_dim, device="cuda:0", dtype=torch.float16 - ) + paddle.randn(shape=[num_kv_heads, kv_len, head_dim], dtype="float16") if kv_layout == "HND" - else torch.randn( - kv_len, num_kv_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) + else paddle.randn(shape=[kv_len, num_kv_heads, head_dim], dtype="float16") ) v = ( - torch.randn( - num_kv_heads, kv_len, head_dim, device="cuda:0", dtype=torch.float16 - ) + paddle.randn(shape=[num_kv_heads, kv_len, head_dim], dtype="float16") if kv_layout == "HND" - else torch.randn( - kv_len, num_kv_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) + else paddle.randn(shape=[kv_len, num_kv_heads, head_dim], dtype="float16") ) - o = flashinfer.single_decode_with_kv_cache( q, k, v, kv_layout, pos_encoding_mode, use_tensor_cores=False ) o_tensor_cores = flashinfer.single_decode_with_kv_cache( q, k, v, kv_layout, pos_encoding_mode, use_tensor_cores=True ) - - torch.testing.assert_close(o, o_tensor_cores, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o, y=o_tensor_cores, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [12, 17]) @@ -110,44 +90,30 @@ def test_batch_decode_tensor_cores( pos_encoding_mode: str, ): num_qo_heads = num_kv_heads * group_size - q = torch.randn( - batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) + q = paddle.randn(shape=[batch_size, num_qo_heads, head_dim], dtype="float16") num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( - torch.randn( - total_num_pages, - 2, - num_kv_heads, - page_size, - head_dim, - device="cuda:0", - dtype=torch.float16, + paddle.randn( + shape=[total_num_pages, 2, num_kv_heads, page_size, head_dim], + dtype="float16", ) / 10 if kv_layout == "HND" - else torch.randn( - total_num_pages, - 2, - page_size, - num_kv_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, + else paddle.randn( + shape=[total_num_pages, 2, page_size, num_kv_heads, head_dim], + dtype="float16", ) / 10 ) kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) - * num_pages_per_seq + paddle.arange(start=0, end=batch_size + 1, dtype="int32") * num_pages_per_seq ) - kv_indices = torch.arange(0, total_num_pages, device="cuda:0", 
dtype=torch.int32) - kv_last_page_len = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" + kv_indices = paddle.arange(start=0, end=total_num_pages, dtype="int32") + kv_last_page_len = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" ) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout) wrapper.plan( kv_indptr, @@ -158,11 +124,10 @@ def test_batch_decode_tensor_cores( head_dim, page_size, pos_encoding_mode=pos_encoding_mode, - data_type=torch.float16, - q_data_type=torch.float16, + data_type="float16", + q_data_type="float16", ) o, lse = wrapper.run(q, kv_data, return_lse=True) - wrapper_tensor_cores = flashinfer.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout, use_tensor_cores=True ) @@ -175,15 +140,14 @@ def test_batch_decode_tensor_cores( head_dim, page_size, pos_encoding_mode=pos_encoding_mode, - data_type=torch.float16, - q_data_type=torch.float16, + data_type="float16", + q_data_type="float16", ) o_tensor_cores, lse_tensor_cores = wrapper_tensor_cores.run( q, kv_data, return_lse=True ) - - torch.testing.assert_close(o, o_tensor_cores, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(lse, lse_tensor_cores, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o, y=o_tensor_cores, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=lse, y=lse_tensor_cores, rtol=0.001, atol=0.001).item(), "" @pytest.mark.parametrize("batch_size", [12, 17]) @@ -205,46 +169,30 @@ def test_batch_decode_tensor_cores_cuda_graph( pos_encoding_mode: str, ): num_qo_heads = num_kv_heads * group_size - q = torch.randn( - batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) + q = paddle.randn(shape=[batch_size, num_qo_heads, head_dim], dtype="float16") num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( - torch.randn( - total_num_pages, - 2, - num_kv_heads, - page_size, - head_dim, - device="cuda:0", - dtype=torch.float16, + paddle.randn( + shape=[total_num_pages, 2, num_kv_heads, page_size, head_dim], + dtype="float16", ) / 10 if kv_layout == "HND" - else torch.randn( - total_num_pages, - 2, - page_size, - num_kv_heads, - head_dim, - device="cuda:0", - dtype=torch.float16, + else paddle.randn( + shape=[total_num_pages, 2, page_size, num_kv_heads, head_dim], + dtype="float16", ) / 10 ) kv_indptr = ( - torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) - * num_pages_per_seq + paddle.arange(start=0, end=batch_size + 1, dtype="int32") * num_pages_per_seq ) - kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) - kv_last_page_len = torch.full( - (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" + kv_indices = paddle.arange(start=0, end=total_num_pages, dtype="int32") + kv_last_page_len = paddle.full( + shape=(batch_size,), fill_value=(kv_len - 1) % page_size + 1, dtype="int32" ) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") - - # cuda cores wrapper + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout, @@ -262,26 +210,19 @@ def test_batch_decode_tensor_cores_cuda_graph( head_dim, page_size, 
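# Hedged sketch (not part of the applied diff): in the capture sections below,
# `torch.cuda.CUDAGraph()` and `with torch.cuda.graph(g):` are left behind
# ">>>>>>" markers. A possible Paddle counterpart is shown here, assuming
# paddle.device.cuda.graphs.CUDAGraph with capture_begin()/capture_end()/
# replay(); availability and exact signature may differ across Paddle versions.
from paddle.device.cuda.graphs import CUDAGraph

g = CUDAGraph()
g.capture_begin()
o, lse = wrapper.run(q, kv_data, return_lse=True)  # the call being captured
g.capture_end()
g.replay()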
pos_encoding_mode=pos_encoding_mode, - data_type=torch.float16, - q_data_type=torch.float16, + data_type="float16", + q_data_type="float16", ) - # warmup - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): + s = paddle.device.Stream() + s.wait_stream(paddle.device.current_stream()) + with paddle.device.stream_guard(stream=s): for _ in range(3): o, lse = wrapper.run(q, kv_data, return_lse=True) - torch.cuda.current_stream().wait_stream(s) - - # capture - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): + paddle.device.current_stream().wait_stream(s) +>>>>>> g = torch.cuda.CUDAGraph() +>>>>>> with torch.cuda.graph(g): o, lse = wrapper.run(q, kv_data, return_lse=True) - - # replay g.replay() - - # cuda cores wrapper wrapper_tensor_cores = flashinfer.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout, @@ -300,28 +241,22 @@ def test_batch_decode_tensor_cores_cuda_graph( head_dim, page_size, pos_encoding_mode=pos_encoding_mode, - data_type=torch.float16, - q_data_type=torch.float16, + data_type="float16", + q_data_type="float16", ) - # warmup - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): + s = paddle.device.Stream() + s.wait_stream(paddle.device.current_stream()) + with paddle.device.stream_guard(stream=s): for _ in range(3): o_tensor_cores, lse_tensor_cores = wrapper_tensor_cores.run( q, kv_data, return_lse=True ) - torch.cuda.current_stream().wait_stream(s) - - # capture - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): + paddle.device.current_stream().wait_stream(s) +>>>>>> g = torch.cuda.CUDAGraph() +>>>>>> with torch.cuda.graph(g): o_tensor_cores, lse_tensor_cores = wrapper_tensor_cores.run( q, kv_data, return_lse=True ) - - # replay g.replay() - - torch.testing.assert_close(o, o_tensor_cores, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(lse, lse_tensor_cores, rtol=1e-3, atol=1e-3) + assert paddle.allclose(x=o, y=o_tensor_cores, rtol=0.001, atol=0.001).item(), "" + assert paddle.allclose(x=lse, y=lse_tensor_cores, rtol=0.001, atol=0.001).item(), "" diff --git a/tests/test_triton_cascade.py b/tests/test_triton_cascade.py index ae8f2b423e..7cc80a178b 100644 --- a/tests/test_triton_cascade.py +++ b/tests/test_triton_cascade.py @@ -1,5 +1,5 @@ +import paddle import pytest -import torch import flashinfer import flashinfer.triton @@ -9,34 +9,44 @@ @pytest.mark.parametrize("num_heads", [32]) @pytest.mark.parametrize("head_dim", [128]) def test_merge_state(seq_len, num_heads, head_dim): - va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0") - sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") - vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0") - sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") + va = ( + paddle.randn(shape=[seq_len, num_heads, head_dim]) + .astype(dtype="float16") + .to("gpu:0") + ) + sa = paddle.randn(shape=[seq_len, num_heads], dtype="float32").to("gpu:0") + vb = ( + paddle.randn(shape=[seq_len, num_heads, head_dim]) + .astype(dtype="float16") + .to("gpu:0") + ) + sb = paddle.randn(shape=[seq_len, num_heads], dtype="float32").to("gpu:0") v_merged, s_merged = flashinfer.triton.cascade.merge_state(va, sa, vb, sb) v_merged_std, s_merged_std = flashinfer.merge_state(va, sa, vb, sb) - - assert torch.allclose(v_merged, v_merged_std, atol=1e-2) - assert torch.allclose(s_merged, s_merged_std, atol=1e-2) + assert paddle.allclose(x=v_merged, y=v_merged_std, atol=0.01).item() + assert 
paddle.allclose(x=s_merged, y=s_merged_std, atol=0.01).item() @pytest.mark.parametrize("seq_len", [2048]) @pytest.mark.parametrize("num_heads", [32]) @pytest.mark.parametrize("head_dim", [128]) def test_merge_state_in_place(seq_len, num_heads, head_dim): - v = torch.randn(seq_len, num_heads, head_dim).half() + v = paddle.randn(shape=[seq_len, num_heads, head_dim]).astype(dtype="float16") v_std = v.clone() - v, v_std = v.to("cuda:0"), v_std.to("cuda:0") - s = torch.randn(seq_len, num_heads, dtype=torch.float32) + v, v_std = v.to("gpu:0"), v_std.to("gpu:0") + s = paddle.randn(shape=[seq_len, num_heads], dtype="float32") s_std = s.clone() - s, s_std = s.to("cuda:0"), s_std.to("cuda:0") - v_other = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0") - s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") + s, s_std = s.to("gpu:0"), s_std.to("gpu:0") + v_other = ( + paddle.randn(shape=[seq_len, num_heads, head_dim]) + .astype(dtype="float16") + .to("gpu:0") + ) + s_other = paddle.randn(shape=[seq_len, num_heads], dtype="float32").to("gpu:0") flashinfer.merge_state_in_place(v_std, s_std, v_other, s_other) flashinfer.triton.cascade.merge_state_in_place(v, s, v_other, s_other) - - assert torch.allclose(v, v_std, atol=1e-2) - assert torch.allclose(s, s_std, atol=1e-2) + assert paddle.allclose(x=v, y=v_std, atol=0.01).item() + assert paddle.allclose(x=s, y=s_std, atol=0.01).item() @pytest.mark.parametrize("seq_len", [2048]) @@ -44,13 +54,18 @@ def test_merge_state_in_place(seq_len, num_heads, head_dim): @pytest.mark.parametrize("head_dim", [128]) @pytest.mark.parametrize("num_states", [100]) def test_merge_states(seq_len, num_states, num_heads, head_dim): - v = torch.randn(seq_len, num_states, num_heads, head_dim).half().to("cuda:0") - s = torch.randn(seq_len, num_states, num_heads, dtype=torch.float32).to("cuda:0") + v = ( + paddle.randn(shape=[seq_len, num_states, num_heads, head_dim]) + .astype(dtype="float16") + .to("gpu:0") + ) + s = paddle.randn(shape=[seq_len, num_states, num_heads], dtype="float32").to( + "gpu:0" + ) v_merged_std, s_merged_std = flashinfer.merge_states(v, s) v_merged, s_merged = flashinfer.triton.cascade.merge_states(v, s) - - assert torch.allclose(v_merged, v_merged_std, atol=1e-2) - assert torch.allclose(s_merged, s_merged_std, atol=1e-2) + assert paddle.allclose(x=v_merged, y=v_merged_std, atol=0.01).item() + assert paddle.allclose(x=s_merged, y=s_merged_std, atol=0.01).item() @pytest.mark.parametrize("seq_len", [2048]) @@ -58,24 +73,28 @@ def test_merge_states(seq_len, num_states, num_heads, head_dim): @pytest.mark.parametrize("head_dim", [128]) def test_variable_length_merge_states(seq_len, num_heads, head_dim): max_index_sets = 512 - lengths = torch.randint(low=1, high=max_index_sets, size=(seq_len,)) + lengths = paddle.randint(low=1, high=max_index_sets, shape=(seq_len,)) indptr = [0] for i in range(seq_len): indptr.append(indptr[-1] + lengths[i]) - v = torch.randn(indptr[-1], num_heads, head_dim).half().to("cuda:0") - s = torch.randn(indptr[-1], num_heads, dtype=torch.float32).to("cuda:0") - indptr = torch.tensor(indptr, dtype=torch.int32).to("cuda:0") + v = ( + paddle.randn(shape=[indptr[-1], num_heads, head_dim]) + .astype(dtype="float16") + .to("gpu:0") + ) + s = paddle.randn(shape=[indptr[-1], num_heads], dtype="float32").to("gpu:0") + indptr = paddle.to_tensor(data=indptr, dtype="int32").to("gpu:0") v_merged, s_merged = flashinfer.triton.cascade.variable_length_merge_states( v, s, indptr ) for i in range(seq_len): sub_v = 
v[indptr[i] : indptr[i + 1]] sub_s = s[indptr[i] : indptr[i + 1]] - sub_v = torch.unsqueeze(sub_v, 0) - sub_s = torch.unsqueeze(sub_s, 0) + sub_v = paddle.unsqueeze(x=sub_v, axis=0) + sub_s = paddle.unsqueeze(x=sub_s, axis=0) v_merged_std, s_merged_std = flashinfer.merge_states(sub_v, sub_s) - v_merged_std = torch.squeeze(v_merged_std, 0) - s_merged_std = torch.squeeze(s_merged_std, 0) - assert v_merged[i].shape == v_merged_std.shape - assert torch.allclose(v_merged[i], v_merged_std, atol=1e-2) - assert torch.allclose(s_merged[i], s_merged_std, atol=1e-2) + v_merged_std = paddle.squeeze(x=v_merged_std, axis=0) + s_merged_std = paddle.squeeze(x=s_merged_std, axis=0) + assert tuple(v_merged[i].shape) == tuple(v_merged_std.shape) + assert paddle.allclose(x=v_merged[i], y=v_merged_std, atol=0.01).item() + assert paddle.allclose(x=s_merged[i], y=s_merged_std, atol=0.01).item() diff --git a/tests/test_trtllm_allreduce.py b/tests/test_trtllm_allreduce.py index 7b9ddfc743..ee6eab6ff1 100644 --- a/tests/test_trtllm_allreduce.py +++ b/tests/test_trtllm_allreduce.py @@ -1,10 +1,13 @@ +import sys + + import multiprocessing as mp import socket from typing import Any +import paddle import pytest -import torch -import torch.distributed as dist +from flashinfer.paddle_utils import * import flashinfer.comm as comm @@ -20,36 +23,26 @@ If new trt-llm source kernels are available (function name starts with "trtllm_"), we would recommend using them. """ - maxBatchSize = 1 maxBeamWidth = 3 maxTokenNum = 128 -maxHiddenSize = 4096 # max hidden size for all reduce +maxHiddenSize = 4096 RANDOM_SEED = 42 def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) + device = device2str(f"cuda:{rank}") + paddle.device.set_device(device=device2str(device)) distributed_init_method = f"tcp://localhost:{distributed_init_port}" - dist.init_process_group( - backend="nccl", - init_method=distributed_init_method, - rank=rank, - world_size=world_size, - ) - group = dist.group.WORLD - + paddle.distributed.init_parallel_env() +>>>>>> group = torch.distributed.group.WORLD try: - device = torch.device(f"cuda:{rank}") + device = device2str(f"cuda:{rank}") token_nums = [64, 128] strategy_codes = [ comm.AllReduceStrategyType.ONESHOT, comm.AllReduceStrategyType.TWOSHOT, ] - - # below are the recommended hidden sizes for custom all-reduce in trtllm test - # hidden_size should be in range [256, 8192], and maxHiddenSize should be 8192 hidden_sizes = [1024, 4096] config_codes = [ 0, @@ -60,17 +53,8 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): comm.AllReduceFusionOp.NONE, comm.AllReduceFusionOp.RESIDUAL_RMS_NORM, comm.AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM, - # Below are not enabled for custom all-reduce in trtllm test, skip - # comm.AllReduceFusionOp.LAST_PROCESS_FOR_UB, - # comm.AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8, - # comm.AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4, - # comm.AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8, - # comm.AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4, - # comm.AllReduceFusionOp.MOE_ALLREDUCE_RESIDUAL_RMS_NORM, ] launch_with_pdls = [True, False] - - # create ipc memory workspace = comm.trtllm_create_ipc_workspace_for_all_reduce( rank=rank, tp_size=world_size, @@ -78,10 +62,7 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): hidden_dim=maxHiddenSize, group=group, ) - - test_loop = 2 # could be any number - - # NOTE: the barrier flag 
should be initialized to 1, and incremented by 1 for each AR + test_loop = 2 flag_value = 1 for token_num in token_nums: for hidden_size in hidden_sizes: @@ -95,38 +76,30 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): and fusion_op_code == comm.AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM ): - # skip twoshot pre-post norm: not supported in trtllm test continue print( f"test RANK {rank}: {world_size}-{dtype}-{strategy_code}-{config_code}-{fusion_op_code}-{launch_with_pdl}-{hidden_size} start" ) - torch.cuda.synchronize() + paddle.device.synchronize() for _ in range(test_loop): message_size = token_num * hidden_size - inp1 = torch.randn( - message_size, dtype=dtype, device=device - ) + inp1 = paddle.randn(shape=message_size, dtype=dtype) inp1_ref = inp1.clone() - out1 = torch.empty_like(inp1) - - # init params for each fusion op - bias = torch.randn( - hidden_size, dtype=dtype, device=device + out1 = paddle.empty_like(x=inp1) + bias = paddle.randn(shape=hidden_size, dtype=dtype) + residual = paddle.randn( + shape=message_size, dtype=dtype ) - residual = torch.randn( - message_size, dtype=dtype, device=device + weight = paddle.randn( + shape=hidden_size, dtype=dtype ) - weight = torch.randn( - hidden_size, dtype=dtype, device=device + weight_pre_residual_norm = paddle.randn( + shape=hidden_size, dtype=dtype ) - weight_pre_residual_norm = torch.randn( - hidden_size, dtype=dtype, device=device + eps = 1e-06 + intermediate_buffer = paddle.zeros( + shape=message_size, dtype=dtype ) - eps = 1e-6 - intermediate_buffer = torch.zeros( - message_size, dtype=dtype, device=device - ) - comm.trtllm_custom_all_reduce( inp=inp1, out=out1, @@ -138,14 +111,14 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): config_code=config_code, launch_with_pdl=launch_with_pdl, flag_value=flag_value, - peer_comm_buffer_ptrs=torch.tensor( - workspace[0], dtype=torch.int64 + peer_comm_buffer_ptrs=paddle.to_tensor( + data=workspace[0], dtype="int64" ), - peer_barrier_ptrs_in=torch.tensor( - workspace[2], dtype=torch.int64 + peer_barrier_ptrs_in=paddle.to_tensor( + data=workspace[2], dtype="int64" ), - peer_barrier_ptrs_out=torch.tensor( - workspace[3], dtype=torch.int64 + peer_barrier_ptrs_out=paddle.to_tensor( + data=workspace[3], dtype="int64" ), bias=bias, residual=residual, @@ -153,92 +126,83 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): weight_pre_residual_norm=weight_pre_residual_norm, eps=eps, intermediate_buffer=intermediate_buffer, - lamport_peer_comm_buffer_ptrs_0=torch.tensor( - workspace[4], dtype=torch.int64 + lamport_peer_comm_buffer_ptrs_0=paddle.to_tensor( + data=workspace[4], dtype="int64" ), - lamport_peer_comm_buffer_ptrs_1=torch.tensor( - workspace[5], dtype=torch.int64 + lamport_peer_comm_buffer_ptrs_1=paddle.to_tensor( + data=workspace[5], dtype="int64" ), - lamport_peer_comm_buffer_ptrs_2=torch.tensor( - workspace[6], dtype=torch.int64 + lamport_peer_comm_buffer_ptrs_2=paddle.to_tensor( + data=workspace[6], dtype="int64" ), ) - dist.all_reduce(inp1_ref, group=group) - - tolerance = 1e-2 if dtype == torch.float16 else 8e-2 - + paddle.distributed.all_reduce( + tensor=inp1_ref, group=group + ) + tolerance = 0.01 if dtype == "float16" else 0.08 if fusion_op_code == comm.AllReduceFusionOp.NONE: - torch.testing.assert_close( - out1, inp1_ref, atol=tolerance, rtol=3e-2 - ) + assert paddle.allclose( + x=out1, + y=inp1_ref, + atol=tolerance, + rtol=0.03, + ).item(), "" elif ( fusion_op_code == 
comm.AllReduceFusionOp.RESIDUAL_RMS_NORM ): - # cache intermediate_buffer to inter_buffer inter_buffer = intermediate_buffer.clone() - - # residual and bias ref = inp1_ref.clone() - ref_float = ref.to(torch.float32) - residual_float = residual.to(torch.float32) - bias_float = bias.to(torch.float32) - - for i in range(ref.numel()): + ref_float = ref.to("float32") + residual_float = residual.to("float32") + bias_float = bias.to("float32") + for i in range(ref.size): ref_float[i] += ( residual_float[i] + bias_float[i % hidden_size] ) ref_half = ref_float.to(dtype) - torch.testing.assert_close( - inter_buffer, - ref_half, + assert paddle.allclose( + x=inter_buffer, + y=ref_half, atol=tolerance, - rtol=3e-2, - ) - - # RMSNorm over hidden size + rtol=0.03, + ).item(), "" ref_float = ref_float.view( token_num, hidden_size ) - normed_float = torch.empty_like(ref_float) - - mean_sq = torch.mean( - ref_float * ref_float, dim=-1, keepdim=True + normed_float = paddle.empty_like(x=ref_float) + mean_sq = paddle.mean( + x=ref_float * ref_float, + axis=-1, + keepdim=True, ) - denom = torch.sqrt(mean_sq + eps) + denom = paddle.sqrt(x=mean_sq + eps) normed_float = ref_float / denom normed_float = normed_float * weight.to( - torch.float32 + "float32" ) normed_half = normed_float.to(dtype) - torch.testing.assert_close( - out1, - normed_half.view(-1), + assert paddle.allclose( + x=out1, + y=normed_half.view(-1), atol=tolerance, - rtol=3e-2, - ) - + rtol=0.03, + ).item(), "" elif ( fusion_op_code == comm.AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM ): - # NOTE(yingyi): bugfix todo, the test invokes nccl timeout for now pass - flag_value += 1 if pass_flag: print( f"test RANK {rank}: {world_size}-{dtype}-{strategy_code}-{config_code}-{fusion_op_code}-{launch_with_pdl}-{hidden_size} passed" ) - # torch.cuda.synchronize() - # # you might want to enable this barrier for a better log output, but it's not mandatory across allReduce calls finally: - dist.barrier(group=group) - + paddle.distributed.barrier(group=group) comm.trtllm_destroy_ipc_workspace_for_all_reduce(workspace, group=group) - - dist.destroy_process_group(group=group) +>>>>>> torch.distributed.destroy_process_group(group=group) def get_open_port() -> int: @@ -253,40 +217,33 @@ def get_open_port() -> int: def multi_process_parallel( - world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = () + world_size: int, dtype: paddle.dtype, test_target: Any, target_args: tuple = () ) -> None: mp.set_start_method("spawn", force=True) - procs = [] distributed_init_port = get_open_port() for i in range(world_size): proc_args = (world_size, i, dtype, distributed_init_port) + target_args proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") - proc.start() + """Not Support auto convert *.start, please judge whether it is Pytorch API and convert by yourself""" +>>>>>> proc.start() procs.append(proc) - for i in range(world_size): procs[i].join() - assert procs[i].exitcode == 0, ( - f"Process {i} failed with exit code {procs[i].exitcode}" - ) + assert ( + procs[i].exitcode == 0 + ), f"Process {i} failed with exit code {procs[i].exitcode}" @pytest.mark.parametrize("world_size", [2, 4]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) def test_trtllm_custom_allreduce(world_size, dtype): - torch.manual_seed(RANDOM_SEED) - available_gpus = torch.cuda.device_count() + paddle.seed(seed=RANDOM_SEED) + available_gpus = paddle.device.cuda.device_count() if 
world_size > available_gpus: raise ValueError( f"world_size {world_size} is greater than available_gpus {available_gpus}" ) print(f"Running test for world_size={world_size}") - - multi_process_parallel( - world_size, - dtype, - _run_correctness_worker, - target_args=(), - ) + multi_process_parallel(world_size, dtype, _run_correctness_worker, target_args=()) print(f"custom allreduce tp = {world_size}: OK") diff --git a/tests/test_trtllm_allreduce_fusion.py b/tests/test_trtllm_allreduce_fusion.py index d4229cec23..d6c98059d6 100644 --- a/tests/test_trtllm_allreduce_fusion.py +++ b/tests/test_trtllm_allreduce_fusion.py @@ -1,40 +1,32 @@ +import sys + + import multiprocessing as mp import socket from typing import Any import numpy as np +import paddle import pytest -import torch -import torch.distributed as dist +from flashinfer.paddle_utils import * import flashinfer.comm as comm -# todo(Yingyi): add benchmark and quant test - -# Usage: test var kOneShotMaxTokenNum = 128 MIN_TOKEN_NUM = 1 MAX_TOKEN_NUM = 2048 SF_VEC_SIZE = 16 - -# temp var -SCALE_FACTOR_RANGE = (-1, 1) +SCALE_FACTOR_RANGE = -1, 1 def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_init_port): - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) + device = device2str(f"cuda:{rank}") + paddle.device.set_device(device=device2str(device)) distributed_init_method = f"tcp://localhost:{distributed_init_port}" - dist.init_process_group( - backend="nccl", - init_method=distributed_init_method, - rank=rank, - world_size=world_size, - ) - group = dist.group.WORLD - + paddle.distributed.init_parallel_env() +>>>>>> group = torch.distributed.group.WORLD try: - device = torch.device(f"cuda:{rank}") + device = device2str(f"cuda:{rank}") token_nums = [1, 128, 1024, 2048] pattern_codes = [ comm.AllReduceFusionPattern.kAllReduce, @@ -53,23 +45,19 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini use_oneshots = [True, False, None] trigger_completion_at_ends = [True, False] fp32_accs = [True, False] - - lamport_use_fp32 = dtype == torch.float32 - - # create workspace for allreduce fusion - ipc_handles, workspace_tensor = ( - comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( - rank, - world_size, - MAX_TOKEN_NUM, - hidden_dim, - group=group, - use_fp32_lamport=lamport_use_fp32, - ) + lamport_use_fp32 = dtype == "float32" + ( + ipc_handles, + workspace_tensor, + ) = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, + world_size, + MAX_TOKEN_NUM, + hidden_dim, + group=group, + use_fp32_lamport=lamport_use_fp32, ) - test_loop = 5 - for token_num in token_nums: for pattern_code in pattern_codes: for swizzled_layout_code in swizzled_layout_codes: @@ -79,87 +67,74 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini for fp32_acc in fp32_accs: if token_num < world_size and not use_oneshot: continue - if dtype == torch.float32 and ( + if dtype == "float32" and ( pattern_code == comm.AllReduceFusionPattern.kARResidualRMSNormOutFP4Quant or pattern_code == comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant ): continue - - dist.barrier(group=group) + paddle.distributed.barrier(group=group) test_passed = True print( f"test RANK {rank}: token{token_num}-hidden_dim{hidden_dim}-dtype{dtype}-pattern{pattern_code}-layout{swizzled_layout_code}-pdl{launch_with_pdl} start" ) - dist.barrier(group=group) - torch.cuda.synchronize() - + paddle.distributed.barrier(group=group) + paddle.device.synchronize() message_size = token_num * 
hidden_dim - - allreduce_in = torch.randn( - message_size, dtype=dtype, device=device + allreduce_in = paddle.randn( + shape=message_size, dtype=dtype ) allreduce_in_clone = allreduce_in.clone() - - all_reduce_out = torch.zeros( - message_size, dtype=dtype, device=device + all_reduce_out = paddle.zeros( + shape=message_size, dtype=dtype ) - - residual_in = torch.randn( - message_size, dtype=dtype, device=device + residual_in = paddle.randn( + shape=message_size, dtype=dtype ) residual_in_clone = residual_in.clone() - - residual_out = torch.empty_like(residual_in) - norm_out = torch.empty_like(residual_in) - quant_out = torch.empty( - message_size, dtype=dtype, device=device + residual_out = paddle.empty_like(x=residual_in) + norm_out = paddle.empty_like(x=residual_in) + quant_out = paddle.empty( + shape=message_size, dtype=dtype ) - scale_out = None - assert hidden_dim % SF_VEC_SIZE == 0, ( - "hidden_dim must be divisible by SF_VEC_SIZE" - ) + assert ( + hidden_dim % SF_VEC_SIZE == 0 + ), "hidden_dim must be divisible by SF_VEC_SIZE" if ( swizzled_layout_code == comm.QuantizationSFLayout.SWIZZLED_128x4 ): - # TODO(Yingyi): check this padded_message_size = ( - (token_num + 127) // 128 * 128 - ) * ((hidden_dim + 63) // 64 * 4) - scale_out = torch.empty( - padded_message_size, - dtype=dtype, - device=device, + (token_num + 127) + // 128 + * 128 + * ((hidden_dim + 63) // 64 * 4) + ) + scale_out = paddle.empty( + shape=padded_message_size, dtype=dtype ) else: - scale_out = torch.empty( - message_size // SF_VEC_SIZE, + scale_out = paddle.empty( + shape=message_size // SF_VEC_SIZE, dtype=dtype, - device=device, ) - - rms_gamma = torch.randn( - hidden_dim, dtype=dtype, device=device + rms_gamma = paddle.randn( + shape=hidden_dim, dtype=dtype ) scale_factor = ( - torch.rand( - 1, dtype=torch.float32, device=device - ) + paddle.rand(shape=[1], dtype="float32") * ( SCALE_FACTOR_RANGE[1] - SCALE_FACTOR_RANGE[0] ) + SCALE_FACTOR_RANGE[0] ) - rms_eps = 1e-3 - - # warmup - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): + rms_eps = 0.001 + s = paddle.device.Stream() + s.wait_stream(paddle.device.current_stream()) + with paddle.device.stream_guard(stream=s): for _ in range(test_loop): comm.trtllm_allreduce_fusion( allreduce_in=allreduce_in, @@ -184,11 +159,8 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini scale_factor=scale_factor, layout_code=swizzled_layout_code, ) - - # NOTE: in real case, you dont have to set all optional params. You could set those required by fusion pattern. 
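# NOTE (editorial, hedged): the ">>>>>>" lines nearby keep the original
# torch.cuda.CUDAGraph capture because the converter has no mapping for it.
# The sketch below is NOT part of the patch; it assumes the installed Paddle
# build ships paddle.device.cuda.graphs.CUDAGraph with capture_begin /
# capture_end / replay (not verified here). `fn` stands in for the kernel
# launches being captured.
import paddle
from paddle.device.cuda.graphs import CUDAGraph

def capture_and_replay(fn, warmup: int = 3) -> CUDAGraph:
    # warm up on a side stream, mirroring the pattern already used in this file
    s = paddle.device.Stream()
    s.wait_stream(paddle.device.current_stream())
    with paddle.device.stream_guard(stream=s):
        for _ in range(warmup):
            fn()
    paddle.device.current_stream().wait_stream(s)
    g = CUDAGraph()
    g.capture_begin()   # start recording subsequent GPU launches into the graph
    fn()
    g.capture_end()     # stop recording
    g.replay()          # re-launch the captured work
    return g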
- # capture - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): +>>>>>> g = torch.cuda.CUDAGraph() +>>>>>> with torch.cuda.graph(g): for _ in range(test_loop): comm.trtllm_allreduce_fusion( allreduce_in=allreduce_in, @@ -213,11 +185,8 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini scale_factor=scale_factor, layout_code=swizzled_layout_code, ) - # replay g.replay() - torch.cuda.synchronize() - - # match shape + paddle.device.synchronize() all_reduce_out = all_reduce_out.view( token_num, hidden_dim ) @@ -225,73 +194,61 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini token_num, hidden_dim ) norm_out = norm_out.view(token_num, hidden_dim) - - torch.cuda.synchronize() - - # calculate reference - # allreduce_out - dist.all_reduce(allreduce_in_clone, group=group) + paddle.device.synchronize() + paddle.distributed.all_reduce( + tensor=allreduce_in_clone, group=group + ) ref_allreduce_out = allreduce_in_clone.clone() ref_allreduce_out = ref_allreduce_out.view( token_num, hidden_dim - ).to(torch.float32) - - # residual_out + ).to("float32") ref_residual_out = ( ref_allreduce_out + residual_in_clone.view( token_num, hidden_dim - ).to(torch.float32) + ).to("float32") ) - - # norm_out variance = ( - ref_residual_out.to(torch.float32) - .pow(2) - .mean(dim=-1, keepdim=True) + ref_residual_out.to("float32") + .pow(y=2) + .mean(axis=-1, keepdim=True) ) - hidden_states = ref_residual_out * torch.rsqrt( - variance + rms_eps + hidden_states = ref_residual_out * paddle.rsqrt( + x=variance + rms_eps ) ref_norm_out = ( - rms_gamma.to(torch.float32) * hidden_states + rms_gamma.to("float32") * hidden_states ) - - # check correctness - tolerance = 8e-2 if dtype == torch.float16 else 8e-1 - # compare allreduce_out + tolerance = 0.08 if dtype == "float16" else 0.8 if ( pattern_code == comm.AllReduceFusionPattern.kAllReduce ): - torch.testing.assert_close( - all_reduce_out.to(torch.float32), - ref_allreduce_out, + assert paddle.allclose( + x=all_reduce_out.to("float32"), + y=ref_allreduce_out, atol=tolerance, - rtol=1e-2, - ) + rtol=0.01, + ).item(), "" elif ( pattern_code == comm.AllReduceFusionPattern.kARResidualRMSNormOutFP8Quant or pattern_code == comm.AllReduceFusionPattern.kARResidualRMSNormOutFP4Quant ): - torch.testing.assert_close( - residual_out.to(torch.float32), - ref_residual_out, + assert paddle.allclose( + x=residual_out.to("float32"), + y=ref_residual_out, atol=tolerance, - rtol=1e-2, - ) - - torch.testing.assert_close( - norm_out.to(torch.float32), - ref_norm_out, + rtol=0.01, + ).item(), "" + assert paddle.allclose( + x=norm_out.to("float32"), + y=ref_norm_out, atol=tolerance, - rtol=1e-2, - ) - - # todo(Yingyi): check quant out - dist.barrier(group=group) + rtol=0.01, + ).item(), "" + paddle.distributed.barrier(group=group) if test_passed: print( f"test RANK {rank}: token{token_num}-hidden_dim{hidden_dim}-dtype{dtype}-pattern{pattern_code}-layout{swizzled_layout_code}-pdl{launch_with_pdl} passed" @@ -301,11 +258,9 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini f"test RANK {rank}: token{token_num}-hidden_dim{hidden_dim}-dtype{dtype}-pattern{pattern_code}-layout{swizzled_layout_code}-pdl{launch_with_pdl} failed" ) finally: - dist.barrier(group=group) - + paddle.distributed.barrier(group=group) comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group) - - dist.destroy_process_group(group=group) +>>>>>> torch.distributed.destroy_process_group(group=group) def 
get_open_port() -> int: @@ -321,13 +276,12 @@ def get_open_port() -> int: def multi_process_parallel( world_size: int, - dtype: torch.dtype, + dtype: paddle.dtype, hidden_dim: int, test_target: Any, target_args: tuple = (), ) -> None: mp.set_start_method("spawn", force=True) - procs = [] distributed_init_port = get_open_port() for i in range(world_size): @@ -339,35 +293,30 @@ def multi_process_parallel( distributed_init_port, ) + target_args proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") - proc.start() + """Not Support auto convert *.start, please judge whether it is Pytorch API and convert by yourself""" +>>>>>> proc.start() procs.append(proc) - for i in range(world_size): procs[i].join() - assert procs[i].exitcode == 0, ( - f"Process {i} failed with exit code {procs[i].exitcode}" - ) + assert ( + procs[i].exitcode == 0 + ), f"Process {i} failed with exit code {procs[i].exitcode}" @pytest.mark.parametrize("world_size", [2, 4, 8]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("hidden_dim", [1024, 2048, 4096, 7168, 8192]) def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim): np.random.seed(42) - torch.manual_seed(42) - torch.cuda.manual_seed_all(42) - available_gpus = torch.cuda.device_count() + paddle.seed(seed=42) + paddle.seed(seed=42) + available_gpus = paddle.device.cuda.device_count() if world_size > available_gpus: raise ValueError( f"world_size {world_size} is greater than available_gpus {available_gpus}" ) print(f"Running test for world_size={world_size}") - multi_process_parallel( - world_size, - dtype, - hidden_dim, - _run_correctness_worker, - target_args=(), + world_size, dtype, hidden_dim, _run_correctness_worker, target_args=() ) print(f"allreduce fusion tp = {world_size}: OK") diff --git a/tests/test_trtllm_alltoall.py b/tests/test_trtllm_alltoall.py index a219a7ffa0..eb8d4becba 100644 --- a/tests/test_trtllm_alltoall.py +++ b/tests/test_trtllm_alltoall.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2024 by FlashInfer team. @@ -13,9 +19,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch import flashinfer.comm.trtllm_alltoall as tllm_alltoall @@ -27,59 +31,48 @@ def setup_test_environment(): """Set up test environment and warm up JIT compilation.""" global has_setup_max_sm_count if not has_setup_max_sm_count: - # Set up SM count once for all tests - sm_count = torch.cuda.get_device_properties(0).multi_processor_count - max_sm_count = sm_count // 8 # Maximum world size is 8 + sm_count = paddle.device.cuda.get_device_properties( + device="gpu:0" + ).multi_processor_count + max_sm_count = sm_count // 8 tllm_alltoall.set_moe_max_usable_sm_count(max_sm_count) has_setup_max_sm_count = True - - torch.manual_seed(0x1234) + paddle.seed(seed=4660) yield -# Single GPU test parameters SINGLE_GPU_PARAMS = [ - (902, 701, 32768, 100, torch.float16), # Large data, float16 - (101, 75, 288, 10, torch.float16), # Medium data, float16 - (10, 5, 8, 1, torch.float16), # Small data, float16 - (902, 701, 7168, 100, torch.bfloat16), # Large data, bfloat16 - (101, 75, 288, 10, torch.bfloat16), # Medium data, bfloat16 + (902, 701, 32768, 100, "float16"), + (101, 75, 288, 10, "float16"), + (10, 5, 8, 1, "float16"), + (902, 701, 7168, 100, "bfloat16"), + (101, 75, 288, 10, "bfloat16"), ] - MULTI_RANK_PARAMS = [ - (2, 5, 8, torch.float16), # Small input, 2 ranks - (4, 901, 32768, torch.bfloat16), # Large input, 4 ranks - (8, 16384, 128, torch.float16), # Many small vectors, 8 ranks + (2, 5, 8, "float16"), + (4, 901, 32768, "bfloat16"), + (8, 16384, 128, "float16"), ] - PREPARE_INDICES_PARAMS = [ - (0, 8, 256, 4, 3, False), # Rank 0, small config - (1, 8, 256, 4, 3, True), # Rank 1, small config with real cumsum - (7, 8, 256, 8, 1025, False), # High rank, medium config - (7, 64, 1024, 32, 1029, True), # High rank, large config with real cumsum + (0, 8, 256, 4, 3, False), + (1, 8, 256, 4, 3, True), + (7, 8, 256, 8, 1025, False), + (7, 64, 1024, 32, 1029, True), ] - -LOCAL_GATHER_PARAMS = [ - (0, 8, 256, 4, 3), # Rank 0, small config - (7, 8, 256, 8, 32), # High rank, medium config - (7, 64, 1024, 32, 1029), # High rank, large config -] - - -# Real cross-GPU communication test parameters +LOCAL_GATHER_PARAMS = [(0, 8, 256, 4, 3), (7, 8, 256, 8, 32), (7, 64, 1024, 32, 1029)] CROSS_GPU_PARAMS = [ - (2, 100, 256, torch.float16), # 2 GPUs, 2 ranks - (2, 300, 512, torch.bfloat16), # 2 GPUs, 2 ranks, larger data - (4, 150, 256, torch.float16), # 4 GPUs, 4 ranks (if available) - (4, 400, 512, torch.float16), # 4 GPUs, 4 ranks, larger data + (2, 100, 256, "float16"), + (2, 300, 512, "bfloat16"), + (4, 150, 256, "float16"), + (4, 400, 512, "float16"), ] def get_available_gpu_count(): """Get the number of available GPUs.""" - if not torch.cuda.is_available(): + if not paddle.device.cuda.device_count() >= 1: return 0 - return torch.cuda.device_count() + return paddle.device.cuda.device_count() def requires_gpus(min_gpus): @@ -102,40 +95,21 @@ def test_moe_alltoall_single_gpu( input_entry_count, output_entry_count, vector_dim, send_recv_count, dtype ): """Test MOE alltoall communication on single GPU.""" - torch.cuda.set_device(0) - # Create a random input tensor - input_tensor = torch.randn( - input_entry_count, vector_dim, dtype=dtype, device=torch.device("cuda") - ) - output_tensor = torch.zeros( - output_entry_count, vector_dim, dtype=dtype, device=torch.device("cuda") - ) - - send_cumsum = ( - torch.ones((1,), dtype=torch.int32, device=torch.device("cuda")) - * send_recv_count - ) - recv_cumsum = ( - torch.ones((1,), dtype=torch.int32, device=torch.device("cuda")) - * 
send_recv_count - ) - send_indices = torch.randperm( - input_entry_count, dtype=torch.int32, device=torch.device("cuda") - )[:send_recv_count] - recv_indices = torch.randperm( - output_entry_count, dtype=torch.int32, device=torch.device("cuda") - )[:send_recv_count] - - ref_output_tensor = torch.zeros( - output_entry_count, vector_dim, dtype=dtype, device=torch.device("cuda") + paddle.device.set_device(device="gpu:0") + input_tensor = paddle.randn(shape=[input_entry_count, vector_dim], dtype=dtype) + output_tensor = paddle.zeros(shape=[output_entry_count, vector_dim], dtype=dtype) + send_cumsum = paddle.ones(shape=(1,), dtype="int32") * send_recv_count + recv_cumsum = paddle.ones(shape=(1,), dtype="int32") * send_recv_count + send_indices = paddle.randperm(n=input_entry_count, dtype="int32")[:send_recv_count] + recv_indices = paddle.randperm(n=output_entry_count, dtype="int32")[ + :send_recv_count + ] + ref_output_tensor = paddle.zeros( + shape=[output_entry_count, vector_dim], dtype=dtype ) ref_output_tensor[recv_indices] = input_tensor[send_indices] - workspace_size = tllm_alltoall.get_moe_commworkspace_size_per_rank(1) - all_workspaces = torch.zeros( - 1, workspace_size, dtype=torch.uint64, device=torch.device("cuda") - ) - +>>>>>> all_workspaces = paddle.zeros(shape=[1, workspace_size], dtype=torch.uint64) tllm_alltoall.moe_comm( input_tensor, send_cumsum, @@ -147,8 +121,9 @@ def test_moe_alltoall_single_gpu( 0, 1, ) - - torch.testing.assert_close(output_tensor, ref_output_tensor, atol=1e-5, rtol=1e-5) + assert paddle.allclose( + x=output_tensor, y=ref_output_tensor, atol=1e-05, rtol=1e-05 + ).item(), "" @pytest.mark.parametrize( @@ -158,99 +133,75 @@ def test_moe_alltoall_multi_rank_single_gpu( world_size, input_entry_per_rank, vector_dim, dtype ): """Test MOE alltoall communication with multiple ranks on single GPU.""" - torch.cuda.set_device(0) + paddle.device.set_device(device="gpu:0") max_world_size = 8 - assert world_size <= max_world_size, ( - f"should run with world_size at most {max_world_size}" + assert ( + world_size <= max_world_size + ), f"should run with world_size at most {max_world_size}" + input_tensor = paddle.randn( + shape=[input_entry_per_rank * world_size, vector_dim], dtype=dtype ) - - # SM count is now set up globally in the fixture - - # Create a random input tensor - input_tensor = torch.randn( - input_entry_per_rank * world_size, - vector_dim, - dtype=dtype, - device=torch.device("cuda"), + output_tensor = paddle.zeros( + shape=[input_entry_per_rank * world_size, vector_dim], dtype=dtype ) - output_tensor = torch.zeros( - input_entry_per_rank * world_size, - vector_dim, - dtype=dtype, - device=torch.device("cuda"), + ref_output_tensor = paddle.zeros( + shape=[input_entry_per_rank * world_size, vector_dim], dtype=dtype ) - ref_output_tensor = torch.zeros( - input_entry_per_rank * world_size, - vector_dim, - dtype=dtype, - device=torch.device("cuda"), + target_rank_ids = paddle.randint( + low=0, + high=world_size, + shape=(input_entry_per_rank * world_size,), + dtype="int32", ) - target_rank_ids = torch.randint( - 0, - world_size, - (input_entry_per_rank * world_size,), - dtype=torch.int32, - device=torch.device("cuda"), + input_tensors_all_ranks = list( + paddle_split(x=input_tensor, num_or_sections=input_entry_per_rank) + ) + target_rank_ids_all_ranks = list( + paddle_split(x=target_rank_ids, num_or_sections=input_entry_per_rank) ) - - input_tensors_all_ranks = list(torch.split(input_tensor, input_entry_per_rank)) - target_rank_ids_all_ranks = 
list(torch.split(target_rank_ids, input_entry_per_rank)) - send_ids_all_ranks = [] send_counts_all_ranks = [] send_cumsum_all_ranks = [] send_start_end_all_ranks = [] - - # each rank do its own local compute to get how to send data to other ranks. for rank in range(world_size): send_start_end = [] local_target_rank_ids = target_rank_ids_all_ranks[rank] - sorted_local_target_rank_ids, local_send_id = torch.sort(local_target_rank_ids) - local_send_id = local_send_id.to(torch.int32) - padded_sorted_local_target_rank_ids = torch.cat( - ( + sorted_local_target_rank_ids, local_send_id = paddle.sort( + x=local_target_rank_ids + ), paddle.argsort(x=local_target_rank_ids) + local_send_id = local_send_id.to("int32") + padded_sorted_local_target_rank_ids = paddle.concat( + x=( sorted_local_target_rank_ids, - torch.arange( - world_size, dtype=torch.int32, device=torch.device("cuda") - ), + paddle.arange(dtype="int32", end=world_size), ) ) - unique_target_rank_ids, local_send_counts = torch.unique( - padded_sorted_local_target_rank_ids, return_counts=True - ) - local_send_counts = local_send_counts.to(torch.int32) - assert unique_target_rank_ids.numel() == world_size, ( - "unique_target_rank_ids must be equal to world_size" + unique_target_rank_ids, local_send_counts = paddle.unique( + x=padded_sorted_local_target_rank_ids, return_counts=True ) - local_send_counts -= 1 # remove padding - local_send_cumsum = torch.cumsum(local_send_counts, dim=0).to(torch.int32) + local_send_counts = local_send_counts.to("int32") + assert ( + unique_target_rank_ids.size == world_size + ), "unique_target_rank_ids must be equal to world_size" + local_send_counts -= 1 + local_send_cumsum = paddle.cumsum(x=local_send_counts, axis=0).to("int32") send_ids_all_ranks.append(local_send_id) send_counts_all_ranks.append(local_send_counts) send_cumsum_all_ranks.append(local_send_cumsum) local_send_cumsum_cpu = local_send_cumsum.cpu().tolist() for i in range(len(local_send_cumsum_cpu)): send_start_end.append( - ( - local_send_cumsum_cpu[i - 1] if i > 0 else 0, - local_send_cumsum_cpu[i], - ) + (local_send_cumsum_cpu[i - 1] if i > 0 else 0, local_send_cumsum_cpu[i]) ) send_start_end_all_ranks.append(send_start_end) - recv_ids_all_ranks = [] recv_cumsum_all_ranks = [] - output_tensors_all_ranks = [] - total_recv_all_ranks_cpu = [] output_indice_offset = 0 - output_start_current_rank = 0 - # each rank do compute based on other ranks' send counts to get how to receive data from other ranks. 
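# NOTE (editorial, hedged): the ">>>>>>" workspace allocations in this file keep
# dtype=torch.uint64 because, to the best of our knowledge, Paddle exposes no
# unsigned 64-bit tensor dtype. One possible workaround -- assuming the
# underlying flashinfer kernel only needs raw 8-byte storage for the workspace
# pointers and the binding accepts it (an assumption, not verified) -- is to
# allocate int64 storage of the same element width:
import paddle

def alloc_workspace(world_size: int, workspace_size: int) -> paddle.Tensor:
    # int64 has the same 8-byte element width as the uint64 the kernel expects
    return paddle.zeros(shape=[world_size, workspace_size], dtype="int64")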
for rank in range(world_size): - local_recv_counts = torch.zeros( - world_size, dtype=torch.int32, device=torch.device("cuda") - ) + local_recv_counts = paddle.zeros(shape=world_size, dtype="int32") for other_rank in range(world_size): local_recv_counts[other_rank] = send_counts_all_ranks[other_rank][rank] local_recv_count_pair = local_recv_counts[other_rank].cpu().item() @@ -263,7 +214,7 @@ def test_moe_alltoall_multi_rank_single_gpu( ] ] output_indice_offset += local_recv_count_pair - local_recv_cumsum = torch.cumsum(local_recv_counts, dim=0).to(torch.int32) + local_recv_cumsum = paddle.cumsum(x=local_recv_counts, axis=0).to("int32") recv_cumsum_all_ranks.append(local_recv_cumsum) total_recv_count = local_recv_cumsum[-1].cpu() total_recv_all_ranks_cpu.append(total_recv_count) @@ -273,24 +224,16 @@ def test_moe_alltoall_multi_rank_single_gpu( ] ) output_start_current_rank += total_recv_count - local_recv_ids = torch.arange( - total_recv_count, dtype=torch.int32, device=torch.device("cuda") - ) + local_recv_ids = paddle.arange(dtype="int32", end=total_recv_count) recv_ids_all_ranks.append(local_recv_ids) - - cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(world_size)] - + cuda_streams_all_ranks = [paddle.device.Stream() for _ in range(world_size)] workspace_size = tllm_alltoall.get_moe_commworkspace_size_per_rank(world_size) - all_workspaces = torch.zeros( - world_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda") + all_workspaces = paddle.zeros( +>>>>>> shape=[world_size, workspace_size], dtype=torch.uint64 ) - - # Synchronize before starting parallel communication - torch.cuda.synchronize() - - # do alltoall in parallel + paddle.device.synchronize() for rank in range(world_size): - with torch.cuda.stream(cuda_streams_all_ranks[rank]): + with paddle.device.stream_guard(stream=cuda_streams_all_ranks[rank]): tllm_alltoall.moe_comm( input_tensors_all_ranks[rank], send_cumsum_all_ranks[rank], @@ -304,8 +247,9 @@ def test_moe_alltoall_multi_rank_single_gpu( ) for rank in range(world_size): cuda_streams_all_ranks[rank].synchronize() - - torch.testing.assert_close(output_tensor, ref_output_tensor, atol=1e-5, rtol=1e-5) + assert paddle.allclose( + x=output_tensor, y=ref_output_tensor, atol=1e-05, rtol=1e-05 + ).item(), "" @pytest.mark.parametrize( @@ -321,45 +265,37 @@ def test_moe_alltoall_prepare_indices( use_real_rank_token_count_cumsum, ): """Test MOE alltoall prepare indices functionality.""" - torch.cuda.set_device(0) + paddle.device.set_device(device="gpu:0") def generate_references(): rank_token_count = max_token_count_per_rank if use_real_rank_token_count_cumsum: - # Make sure we have at least 1 token in each rank except last rank rank_token_counts = [ - max(1, torch.randint(1, max_token_count_per_rank + 1, (1,)).item()) + max( + 1, + paddle.randint( + low=1, high=max_token_count_per_rank + 1, shape=(1,) + ).item(), + ) for _ in range(ep_size - 1) ] - rank_token_counts.append( - max_token_count_per_rank - ) # last rank has max tokens + rank_token_counts.append(max_token_count_per_rank) real_rank_token_count_cumsum = ( - torch.tensor( - rank_token_counts, dtype=torch.int32, device=torch.device("cuda") + paddle.to_tensor( + data=rank_token_counts, dtype="int32", place=device2str("gpu") ) - .cumsum(dim=0) - .to(torch.int32) + .cumsum(axis=0) + .to("int32") ) rank_token_count = rank_token_counts[ep_rank] else: real_rank_token_count_cumsum = None - - # Generate target rank ids for this rank - target_rank_ids = torch.randint( - 0, - ep_size, - (rank_token_count, 
top_k), - dtype=torch.int32, - device=torch.device("cuda"), + target_rank_ids = paddle.randint( + low=0, high=ep_size, shape=(rank_token_count, top_k), dtype="int32" ) - if not use_real_rank_token_count_cumsum: - gathered_target_rank_ids = torch.zeros( - ep_size * max_token_count_per_rank, - top_k, - dtype=torch.int32, - device=torch.device("cuda"), + gathered_target_rank_ids = paddle.zeros( + shape=[ep_size * max_token_count_per_rank, top_k], dtype="int32" ) gathered_target_rank_ids[ ep_rank * max_token_count_per_rank : ep_rank * max_token_count_per_rank @@ -367,22 +303,22 @@ def generate_references(): ] = target_rank_ids else: total_tokens = real_rank_token_count_cumsum[-1].item() - gathered_target_rank_ids = torch.zeros( - total_tokens, top_k, dtype=torch.int32, device=torch.device("cuda") + gathered_target_rank_ids = paddle.zeros( + shape=[total_tokens, top_k], dtype="int32" ) start_pos = ( 0 if ep_rank == 0 else real_rank_token_count_cumsum[ep_rank - 1].item() ) - gathered_target_rank_ids[start_pos : start_pos + rank_token_count] = ( - target_rank_ids - ) - - return gathered_target_rank_ids, real_rank_token_count_cumsum, target_rank_ids - - gathered_target_rank_ids, real_rank_token_count_cumsum, target_rank_ids = ( - generate_references() - ) + gathered_target_rank_ids[ + start_pos : start_pos + rank_token_count + ] = target_rank_ids + return (gathered_target_rank_ids, real_rank_token_count_cumsum, target_rank_ids) + ( + gathered_target_rank_ids, + real_rank_token_count_cumsum, + target_rank_ids, + ) = generate_references() ( local_gather_indices, send_rank_count_cumsum, @@ -399,99 +335,65 @@ def generate_references(): ep_rank, ep_size, ) - - # Validate shapes - assert local_gather_indices.shape[0] <= max_token_count_per_rank * ep_size - assert send_rank_count_cumsum.shape[0] == ep_size - assert recv_rank_count_cumsum.shape[0] == ep_size - assert send_rank_local_indices.shape[0] <= max_token_count_per_rank * max( - ep_size, top_k - ) - assert recv_rank_local_indices.shape[0] <= max_token_count_per_rank * ep_size - assert backward_recv_rank_local_indices.shape[0] <= max_token_count_per_rank * max( + assert tuple(local_gather_indices.shape)[0] <= max_token_count_per_rank * ep_size + assert tuple(send_rank_count_cumsum.shape)[0] == ep_size + assert tuple(recv_rank_count_cumsum.shape)[0] == ep_size + assert tuple(send_rank_local_indices.shape)[0] <= max_token_count_per_rank * max( ep_size, top_k ) - - # Basic validation - cumulative sums should be non-decreasing - assert torch.all(send_rank_count_cumsum[1:] >= send_rank_count_cumsum[:-1]) - assert torch.all(recv_rank_count_cumsum[1:] >= recv_rank_count_cumsum[:-1]) + assert tuple(recv_rank_local_indices.shape)[0] <= max_token_count_per_rank * ep_size + assert tuple(backward_recv_rank_local_indices.shape)[ + 0 + ] <= max_token_count_per_rank * max(ep_size, top_k) + assert paddle.all(x=send_rank_count_cumsum[1:] >= send_rank_count_cumsum[:-1]) + assert paddle.all(x=recv_rank_count_cumsum[1:] >= recv_rank_count_cumsum[:-1]) @pytest.mark.parametrize( "ep_rank,ep_size,expert_count,top_k,max_token_count_per_rank", LOCAL_GATHER_PARAMS ) def test_moe_local_gather( - ep_rank, - ep_size, - expert_count, - top_k, - max_token_count_per_rank, + ep_rank, ep_size, expert_count, top_k, max_token_count_per_rank ): """Test MOE local gather functionality.""" - torch.cuda.set_device(0) - - # Generate test data using the original method - rank_token_count_cumsum = torch.randint( - 0, - max_token_count_per_rank + 1, - (ep_size,), - dtype=torch.int32, - 
device=torch.device("cuda"), + paddle.device.set_device(device="gpu:0") + rank_token_count_cumsum = paddle.randint( + low=0, high=max_token_count_per_rank + 1, shape=(ep_size,), dtype="int32" ) - rank_token_count_cumsum = torch.cumsum(rank_token_count_cumsum, dim=0).to( - torch.int32 + rank_token_count_cumsum = paddle.cumsum(x=rank_token_count_cumsum, axis=0).to( + "int32" ) local_token_count = rank_token_count_cumsum[ep_size - 1].cpu().item() local_max_token_count = max_token_count_per_rank * ep_size - local_gather_indices = torch.randint( - 0, - max_token_count_per_rank * ep_size, - (local_max_token_count,), - dtype=torch.int32, - device=torch.device("cuda"), + local_gather_indices = paddle.randint( + low=0, + high=max_token_count_per_rank * ep_size, + shape=(local_max_token_count,), + dtype="int32", ) - - gathered_expert_ids = torch.randint( - 0, - expert_count, - (max_token_count_per_rank * ep_size, top_k), - dtype=torch.int32, - device=torch.device("cuda"), + gathered_expert_ids = paddle.randint( + low=0, + high=expert_count, + shape=(max_token_count_per_rank * ep_size, top_k), + dtype="int32", ) - gathered_scales = torch.rand( - (max_token_count_per_rank * ep_size, top_k), - dtype=torch.float32, - device=torch.device("cuda"), + gathered_scales = paddle.rand( + shape=(max_token_count_per_rank * ep_size, top_k), dtype="float32" ) - - ref_local_expert_ids = torch.zeros( - local_max_token_count, top_k, dtype=torch.int32, device=torch.device("cuda") + ref_local_expert_ids = paddle.zeros( + shape=[local_max_token_count, top_k], dtype="int32" ) - ref_local_scales = torch.zeros( - local_max_token_count, - top_k, - dtype=torch.float32, - device=torch.device("cuda"), + ref_local_scales = paddle.zeros( + shape=[local_max_token_count, top_k], dtype="float32" ) - - # compute reference ref_local_expert_ids += expert_count valid_local_gather_indices = local_gather_indices[:local_token_count] ref_local_expert_ids[:local_token_count] = gathered_expert_ids[ valid_local_gather_indices ] ref_local_scales[:local_token_count] = gathered_scales[valid_local_gather_indices] - - local_expert_ids = torch.empty( - local_max_token_count, top_k, dtype=torch.int32, device=torch.device("cuda") - ) - local_scales = torch.empty( - local_max_token_count, - top_k, - dtype=torch.float32, - device=torch.device("cuda"), - ) - + local_expert_ids = paddle.empty(shape=[local_max_token_count, top_k], dtype="int32") + local_scales = paddle.empty(shape=[local_max_token_count, top_k], dtype="float32") tllm_alltoall.moe_local_gather( rank_token_count_cumsum, local_gather_indices, @@ -505,9 +407,8 @@ def test_moe_local_gather( ep_rank, ep_size, ) - - assert torch.equal(local_expert_ids, ref_local_expert_ids) - assert torch.equal(local_scales, ref_local_scales) + assert paddle.equal_all(x=local_expert_ids, y=ref_local_expert_ids).item() + assert paddle.equal_all(x=local_scales, y=ref_local_scales).item() if __name__ == "__main__": diff --git a/tests/test_trtllm_cutlass_fused_moe.py b/tests/test_trtllm_cutlass_fused_moe.py index d680e9eab3..ac685067f6 100644 --- a/tests/test_trtllm_cutlass_fused_moe.py +++ b/tests/test_trtllm_cutlass_fused_moe.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,41 +19,37 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - import pytest -import torch -from torch.nn import functional as F import flashinfer.fused_moe as fused_moe -from flashinfer import ( - fp4_quantize, - mxfp4_dequantize, - mxfp4_quantize, - mxfp8_dequantize_host, - mxfp8_quantize, - mxfp4_dequantize_host, -) +from flashinfer import (fp4_quantize, mxfp4_dequantize, mxfp4_dequantize_host, + mxfp4_quantize, mxfp8_dequantize_host, mxfp8_quantize) FLOAT4_E2M1_MAX = 6.0 -FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max -FP8_DTYPE = torch.float8_e4m3fn +>>>>>>FLOAT8_E4M3_MAX = paddle.finfo(dtype=torch.float8_e4m3fn).max +>>>>>>FP8_DTYPE = torch.float8_e4m3fn -def dynamic_per_tensor_fp8_quant(x: torch.tensor) -> tuple[torch.tensor, torch.tensor]: +def dynamic_per_tensor_fp8_quant( + x: paddle.to_tensor, +) -> tuple[paddle.to_tensor, paddle.to_tensor]: fp8_traits_max = FLOAT8_E4M3_MAX fp8_traits_min = -FLOAT8_E4M3_MAX - fp8_max = torch.tensor(fp8_traits_max).float() - one = torch.tensor(1.0).float() - - x_max = x.abs().max().float() + fp8_max = paddle.to_tensor(data=fp8_traits_max).astype(dtype="float32") + one = paddle.to_tensor(data=1.0).astype(dtype="float32") + x_max = x.abs()._max().astype(dtype="float32") scale = x_max / fp8_max iscale = one / scale - out = (x.float() * iscale).clamp(fp8_traits_min, fp8_traits_max).to(FP8_DTYPE) + out = ( + (x.astype(dtype="float32") * iscale) + .clip(min=fp8_traits_min, max=fp8_traits_max) + .to(FP8_DTYPE) + ) return out, scale.view((1,)) def gen_tensor(shape, dtype, stype=None, scale=1.0): - x = torch.randn(*shape, dtype=dtype).cuda() * scale + x = paddle.randn(shape=shape, dtype=dtype).cuda() * scale return x.to(stype) if stype else x @@ -57,12 +59,12 @@ def cast_to_representable(x): return x -def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): +def convert_swizzled_to_linear(a_sf_swizzled: paddle.Tensor, m, k, block_size): m_tiles = (m + 128 - 1) // 128 f = block_size * 4 k_tiles = (k + f - 1) // f - tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) - tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + tmp = paddle.reshape(x=a_sf_swizzled, shape=(1, m_tiles, k_tiles, 32, 4, 4)) + tmp = paddle.transpose(x=tmp, perm=(0, 1, 4, 3, 2, 5)) out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) return out[0:m, 0:k] @@ -71,51 +73,38 @@ def dequantize_nvfp4_to_dtype( tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 ): """Dequantize the fp4 tensor back to high precision.""" - # Two fp4 values are packed into one uint8. 
- assert tensor_fp4.dtype == torch.uint8 - m, packed_k = tensor_fp4.shape + assert tensor_fp4.dtype == "uint8" + m, packed_k = tuple(tensor_fp4.shape) k = packed_k * 2 tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) - tensor_sf = tensor_sf.view(torch.float8_e4m3fn) +>>>>>> tensor_sf = tensor_sf.view(torch.float8_e4m3fn) tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) - tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale - - # scale the tensor - out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + tensor_sf_dtype = tensor_sf.to("float32") / global_scale + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(axis=-1)).reshape(m, k) return out.to(dtype=dtype) def break_fp4_bytes(a, dtype): - assert a.dtype == torch.uint8 - m, n = a.shape - - # Vectorized nibble processing + assert a.dtype == "uint8" + m, n = tuple(a.shape) a_flat = a.flatten() - high = (a_flat & 0xF0) >> 4 # Upper nibbles - low = a_flat & 0x0F # Lower nibbles - - # Combine nibbles for batch processing - combined = torch.stack((low, high), dim=1).flatten() - - # Vectorized sign and magnitude extraction - signs = (combined & 0x08).to(torch.bool) # Sign bits - abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices - - # Device-aware lookup and sign application - kE2M1ToFloat = torch.tensor( - [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 + high = (a_flat & 240) >> 4 + low = a_flat & 15 + combined = paddle.stack(x=(low, high), axis=1).flatten() + signs = (combined & 8).to("bool") + abs_vals = (combined & 7).to("int64") + kE2M1ToFloat = paddle.to_tensor( + data=[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype="float32" ) - kE2M1 = kE2M1ToFloat.to(device=a.device) - values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) - - # Reshape to final form + kE2M1 = kE2M1ToFloat.to(device=a.place) + values = kE2M1[abs_vals] * paddle.where(condition=signs, x=-1.0, y=1.0) return values.reshape(m, n * 2).to(dtype=dtype) def compute_routing( - router_logits: torch.Tensor, top_k: int -) -> tuple[torch.Tensor, torch.Tensor]: + router_logits: paddle.Tensor, top_k: int +) -> tuple[paddle.Tensor, paddle.Tensor]: """ Compute routing weights and selected experts from router logits. 
@@ -128,44 +117,44 @@ def compute_routing( - routing_weights: Expert weights of shape [batch_size, top_k] - selected_experts: Expert indices of shape [batch_size, top_k] """ - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.float() + routing_weights = paddle.nn.functional.softmax( + x=router_logits, axis=1, dtype="float32" + ) + routing_weights, selected_experts = paddle.topk(x=routing_weights, k=top_k, axis=-1) + routing_weights /= routing_weights.sum(axis=-1, keepdim=True) + routing_weights = routing_weights.astype(dtype="float32") return routing_weights, selected_experts def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids): - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - # score = torch.softmax(score, dim=-1, dtype=torch.float32) - # topk_weight, topk_ids = torch.topk(score, topk) + B, D = tuple(a.shape) + a = a.view(B, -1, D).tile(repeat_times=[1, topk, 1]).reshape(-1, D) + out = paddle.zeros(shape=[B * topk, tuple(w2.shape)[1]], dtype=a.dtype) topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) - # w1 needs to be swapped in terms of gate and up_proj - - for i in range(w1.shape[0]): + for i in range(tuple(w1.shape)[0]): mask = topk_ids == i if mask.sum(): - m = w1[i].shape[0] + m = tuple(w1[i].shape)[0] assert m % 2 == 0 w1_expert, w3_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :] - inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t()) - inter_gs = torch.tensor(1.0).cuda() + inter = paddle.nn.functional.silu(x=a[mask] @ w1_expert.t()) * ( + a[mask] @ w3_expert.t() + ) + inter_gs = paddle.to_tensor(data=1.0).cuda() inter_q, inter_blockscale = fp4_quantize(inter, inter_gs) inter = dequantize_nvfp4_to_dtype( inter_q, inter_blockscale, inter_gs, dtype=inter.dtype, - device=inter.device, + device=inter.place, block_size=16, ).cuda() - out[mask] = inter @ w2[i].transpose(0, 1) + out[mask] = inter @ w2[i].transpose(perm=dim2perm(w2[i].ndim, 0, 1)) return ( - out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) - ).sum(dim=1) + out.view(B, -1, tuple(w2.shape)[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(axis=1) def compute_with_experts( @@ -179,49 +168,37 @@ def compute_with_experts( beta=None, limit=None, ): - results = torch.zeros_like(x) + results = paddle.zeros_like(x=x) for expert_id in range(num_experts): mask = selected_experts == expert_id if not mask.sum(): continue - batch_idx, nth_expert = torch.where(mask) - w31_expert = w31_weight[expert_id] # [2 * intermediate_size, hidden_size] - w2_expert = w2_weight[expert_id] # [hidden_size, intermediate_size] - - # Split w13 into w1 and w3 - w3_expert, w1_expert = torch.chunk(w31_expert, 2, dim=0) - + batch_idx, nth_expert = paddle.where(condition=mask) + w31_expert = w31_weight[expert_id] + w2_expert = w2_weight[expert_id] + w3_expert, w1_expert = paddle.chunk(x=w31_expert, chunks=2, axis=0) expert_inputs = x[batch_idx] if alpha is not None and limit is not None and beta is not None: - # SwiGLUBias x1 = expert_inputs @ w1_expert.t() - x1 = x1.clamp_(min=None, max=limit) - x1_scaled = x1 * torch.sigmoid(alpha * x1) + x1 = x1.clip_(min=None, max=limit) + x1_scaled = x1 * paddle.nn.functional.sigmoid(x=alpha * x1) x2 = expert_inputs @ w3_expert.t() - x2 = x2.clamp_(min=-limit, 
max=limit) + beta - + x2 = x2.clip_(min=-limit, max=limit) + beta inter = x1_scaled * x2 else: - inter = F.silu(expert_inputs @ w1_expert.t()) * ( + inter = paddle.nn.functional.silu(x=expert_inputs @ w1_expert.t()) * ( expert_inputs @ w3_expert.t() ) output = inter @ w2_expert.t() results[batch_idx] += routing_weights[batch_idx, nth_expert, None] * output - return results.view_as(x) + return results.view_as(other=x) -# Test configurations -BATCH_SIZES = [ - 1, -] -HIDDEN_SIZES = [ - 128, -] +BATCH_SIZES = [1] +HIDDEN_SIZES = [128] NUM_EXPERTS = [2] TOP_K_VALUES = [2] -INTERMEDIATE_SIZES = [ - 128, -] +INTERMEDIATE_SIZES = [128] EP_NUM_EXPERTS = [8] EP_TOP_K = [2] @@ -232,36 +209,35 @@ def compute_with_experts( @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) def test_moe(batch_size, hidden_size, num_experts, top_k, intermediate_size): - # Skip invalid configurations if top_k > num_experts: pytest.skip( f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" ) - - torch.manual_seed(42) - x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() / 5 - router_logits = torch.randn(batch_size, num_experts, dtype=torch.float32).cuda() + paddle.seed(seed=42) + x = paddle.randn(shape=[batch_size, hidden_size], dtype="float16").cuda() / 5 + router_logits = paddle.randn( + shape=[batch_size, num_experts], dtype="float32" + ).cuda() w31_weight = ( - torch.randn( - num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16 + paddle.randn( + shape=[num_experts, 2 * intermediate_size, hidden_size], dtype="float16" ).cuda() / 5 ) w2_weight = ( - torch.randn( - num_experts, hidden_size, intermediate_size, dtype=torch.float16 + paddle.randn( + shape=[num_experts, hidden_size, intermediate_size], dtype="float16" ).cuda() / 5 ) - routing_weights, selected_experts = compute_routing(router_logits, top_k) ref_output = compute_with_experts( num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights ) - flash_output = torch.empty_like(ref_output) + flash_output = paddle.empty_like(x=ref_output) flash_output = fused_moe.cutlass_fused_moe( x, - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, w31_weight, w2_weight, @@ -269,8 +245,9 @@ def test_moe(batch_size, hidden_size, num_experts, top_k, intermediate_size): output=flash_output, quant_scales=None, ) - - torch.testing.assert_close(ref_output, flash_output[0], rtol=1e-2, atol=1e-2) + assert paddle.allclose( + x=ref_output, y=flash_output[0], rtol=0.01, atol=0.01 + ).item(), "" @pytest.mark.parametrize("batch_size", BATCH_SIZES) @@ -278,45 +255,43 @@ def test_moe(batch_size, hidden_size, num_experts, top_k, intermediate_size): @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) -@pytest.mark.parametrize("otype, wtype", [(torch.float16, torch.float8_e4m3fn)]) +>>>>>>@pytest.mark.parametrize("otype, wtype", [("float16", torch.float8_e4m3fn)]) def test_moe_fp8( batch_size, hidden_size, num_experts, top_k, intermediate_size, otype, wtype ): - # Skip invalid configurations if top_k > num_experts: pytest.skip( f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" ) - - torch.manual_seed(42) - input_shape = (batch_size, hidden_size) - w31_shape = (num_experts, 2 * intermediate_size, hidden_size) - w2_shape = (num_experts, hidden_size, intermediate_size) + paddle.seed(seed=42) + input_shape = 
batch_size, hidden_size + w31_shape = num_experts, 2 * intermediate_size, hidden_size + w2_shape = num_experts, hidden_size, intermediate_size x = cast_to_representable(gen_tensor(input_shape, otype)) router_logits = gen_tensor((batch_size, num_experts), otype) - - # Create weight tensors w31_weight = gen_tensor(w31_shape, otype, wtype) w2_weight = gen_tensor(w2_shape, otype, wtype) - w31_scales = torch.empty(num_experts, 2, dtype=otype).cuda() - w2_scales = torch.empty(num_experts, 1, dtype=otype).cuda() - + w31_scales = paddle.empty(shape=[num_experts, 2], dtype=otype).cuda() + w2_scales = paddle.empty(shape=[num_experts, 1], dtype=otype).cuda() w31_dequantized = gen_tensor(w31_shape, otype) w2_dequantized = gen_tensor(w2_shape, otype) for expert_id in range(num_experts): w31 = cast_to_representable(gen_tensor(w31_shape[1:], otype, scale=0.1)) w2 = cast_to_representable(gen_tensor(w2_shape[1:], otype, scale=0.09)) - w31_quant, s31 = dynamic_per_tensor_fp8_quant(w31) w2_quant, s2 = dynamic_per_tensor_fp8_quant(w2) - - w31_weight.data[expert_id].copy_(w31_quant) - w2_weight.data[expert_id].copy_(w2_quant) - w31_scales.data[expert_id].copy_(s31) - w2_scales.data[expert_id].copy_(s2) - w31_dequantized.data[expert_id].copy_(torch.mul(w31_quant.to(dtype=otype), s31)) - w2_dequantized.data[expert_id].copy_(torch.mul(w2_quant.to(dtype=otype), s2)) - + paddle.assign(w31_quant, output=w31_weight.data[expert_id]) + paddle.assign(w2_quant, output=w2_weight.data[expert_id]) + paddle.assign(s31, output=w31_scales.data[expert_id]) + paddle.assign(s2, output=w2_scales.data[expert_id]) + paddle.assign( + paddle.multiply(x=w31_quant.to(dtype=otype), y=paddle.to_tensor(s31)), + output=w31_dequantized.data[expert_id], + ) + paddle.assign( + paddle.multiply(x=w2_quant.to(dtype=otype), y=paddle.to_tensor(s2)), + output=w2_dequantized.data[expert_id], + ) routing_weights, selected_experts = compute_routing(router_logits, top_k) ref_output = compute_with_experts( num_experts, @@ -326,21 +301,19 @@ def test_moe_fp8( selected_experts, routing_weights, ) - flash_output = torch.empty_like(ref_output) - # For fp8, the hidden_state expects quantized. 
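# NOTE (editorial, hedged): throughout this patch torch.testing.assert_close is
# rewritten as `assert paddle.allclose(...).item(), ""`, which silently drops the
# mismatch diagnostics that assert_close would print. A small, hypothetical
# helper such as assert_close() below (not part of the patch) keeps a readable
# failure message while using only paddle APIs:
import paddle

def assert_close(x: paddle.Tensor, y: paddle.Tensor, rtol: float, atol: float) -> None:
    if not paddle.allclose(x=x, y=y, rtol=rtol, atol=atol).item():
        # report the largest absolute error to aid debugging
        max_err = (x.astype("float32") - y.astype("float32")).abs().max().item()
        raise AssertionError(
            f"tensors differ: max abs error {max_err} (rtol={rtol}, atol={atol})"
        )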
- _, w1_scales = torch.chunk(w31_scales, 2, dim=-1) + flash_output = paddle.empty_like(x=ref_output) + _, w1_scales = paddle.chunk(x=w31_scales, chunks=2, axis=-1) x_quant, hidden_states_scale = dynamic_per_tensor_fp8_quant(x) - hidden_states_scale = torch.tensor(hidden_states_scale[0]).cuda() + hidden_states_scale = paddle.to_tensor(data=hidden_states_scale[0]).cuda() quant_scales = [ - torch.squeeze(w1_scales * hidden_states_scale).float(), - torch.tensor(1.0).cuda(), - torch.squeeze(1.0 * w2_scales).float(), + paddle.squeeze(x=w1_scales * hidden_states_scale).astype(dtype="float32"), + paddle.to_tensor(data=1.0).cuda(), + paddle.squeeze(x=1.0 * w2_scales).astype(dtype="float32"), hidden_states_scale, ] - _ = fused_moe.cutlass_fused_moe( x_quant, - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, w31_weight, w2_weight, @@ -348,7 +321,7 @@ def test_moe_fp8( quant_scales=quant_scales, output=flash_output, ) - torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) + assert paddle.allclose(x=ref_output, y=flash_output, rtol=0.1, atol=0.1).item(), "" @pytest.mark.parametrize("batch_size", BATCH_SIZES) @@ -358,11 +331,11 @@ def test_moe_fp8( @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) @pytest.mark.parametrize( "otype, wtype", - [(torch.float16, torch.float8_e4m3fn), (torch.bfloat16, torch.float8_e4m3fn)], +>>>>>> [("float16", torch.float8_e4m3fn), ("bfloat16", torch.float8_e4m3fn)], ) @pytest.mark.parametrize("quantized_input", [False, True]) @pytest.mark.skipif( - torch.cuda.get_device_capability()[0] != 10, + paddle.device.cuda.get_device_capability()[0] != 10, reason="NVFP4 is only supported on SM100", ) def test_moe_nvfp4( @@ -375,82 +348,61 @@ def test_moe_nvfp4( wtype, quantized_input, ): - # Skip invalid configurations if top_k > num_experts: pytest.skip( f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" ) - - torch.manual_seed(42) + paddle.seed(seed=42) quant_blocksize = 16 round_up = lambda x, y: (x + y - 1) // y * y e = num_experts m = batch_size n = intermediate_size k = hidden_size - - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10 - w1_cutlass = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous() - + w1 = paddle.randn(shape=(e, 2 * n, k), dtype=otype) / 10 + w1_cutlass = paddle.concat(x=(w1[:, n:, :], w1[:, :n, :]), axis=1).contiguous() sf_w1_2n = round_up(2 * n, 128) sf_w1_k = round_up(k // quant_blocksize, 4) - w1_blockscale = torch.empty( - (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + w1_blockscale = paddle.empty( +>>>>>> shape=(e, sf_w1_2n, sf_w1_k), dtype=torch.float8_e4m3fn ) - w1_blockscale_cutlass = torch.empty( - (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + w1_blockscale_cutlass = paddle.empty( +>>>>>> shape=(e, sf_w1_2n, sf_w1_k), dtype=torch.float8_e4m3fn ) - - w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10 + w2 = paddle.randn(shape=(e, k, n), dtype=otype) / 10 sf_w2_k = round_up(k, 128) sf_w2_n = round_up(n // quant_blocksize, 4) - w2_blockscale = torch.empty( - (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn - ) - w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) - w1_q_cutlass = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) - w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) - w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32) - w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32) - +>>>>>> 
w2_blockscale = paddle.empty(shape=(e, sf_w2_k, sf_w2_n), dtype=torch.float8_e4m3fn) + w1_q = paddle.empty(shape=(e, 2 * n, k // 2), dtype="uint8") + w1_q_cutlass = paddle.empty(shape=(e, 2 * n, k // 2), dtype="uint8") + w2_q = paddle.empty(shape=(e, k, n // 2), dtype="uint8") + w1_gs = paddle.empty(shape=(e,), dtype="float32") + w2_gs = paddle.empty(shape=(e,), dtype="float32") for expert in range(e): - w1_amax = torch.abs(w1).max().to(torch.float32) - w2_amax = torch.abs(w2).max().to(torch.float32) + w1_amax = paddle.abs(x=w1)._max().to("float32") + w2_amax = paddle.abs(x=w2)._max().to("float32") w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax - w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert]) - w1_q_cutlass[expert], w1_blockscale_cutlass[expert] = fp4_quantize( w1_cutlass[expert], w1_gs[expert] ) - w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert]) - - x = torch.randn(m, k, dtype=otype).cuda() - a1_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(x).max().to( - torch.float32 - ).cuda() - a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) - a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) - router_logits = torch.randn(m, e, dtype=otype).cuda() + x = paddle.randn(shape=[m, k], dtype=otype).cuda() + a1_gs = ( + FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / paddle.abs(x=x)._max().to("float32").cuda() + ) + a1_gs = paddle.to_tensor(data=1.0, dtype="float32", place="gpu") + a2_gs = paddle.to_tensor(data=1.0, dtype="float32", place="gpu") + router_logits = paddle.randn(shape=[m, e], dtype=otype).cuda() routing_weights, selected_experts = compute_routing(router_logits, top_k) - - # quant_scales format - # auto const fc1_act_global = quant_scales.value()[0]; - # auto const fc1_weight_block = quant_scales.value()[1]; - # auto const fc1_global = quant_scales.value()[2]; - # auto const fc2_act_global = quant_scales.value()[3]; - # auto const fc2_weight_block = quant_scales.value()[4]; - # auto const fc2_global = quant_scales.value()[5]; - flash_output = torch.zeros_like(x) - + flash_output = paddle.zeros_like(x=x) quant_scales = [ a1_gs, - w1_blockscale.view(torch.int32), + w1_blockscale.view("int32"), 1.0 / (a1_gs * w1_gs), a2_gs, - w2_blockscale.view(torch.int32), + w2_blockscale.view("int32"), 1.0 / (a2_gs * w2_gs), ] hidden_states = x @@ -459,38 +411,34 @@ def test_moe_nvfp4( hidden_states, input_sf = fp4_quantize(x, a1_gs) _ = fused_moe.cutlass_fused_moe( hidden_states, - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, - w1_q.contiguous().view(torch.long), - w2_q.contiguous().view(torch.long), + w1_q.contiguous().view("int64"), + w2_q.contiguous().view("int64"), otype, quant_scales=quant_scales, input_sf=input_sf, output=flash_output, ) - - # Ref check a_fp4, a_scale_interleaved = fp4_quantize(x, a1_gs) - _, m_k = a_fp4.shape + _, m_k = tuple(a_fp4.shape) a_in_dtype = dequantize_nvfp4_to_dtype( a_fp4, a_scale_interleaved, a1_gs, dtype=otype, - device=x.device, + device=x.place, block_size=quant_blocksize, ) - - w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=otype) - w2_d = torch.empty((e, k, n), device="cuda", dtype=otype) - + w1_d = paddle.empty(shape=(e, 2 * n, k), dtype=otype) + w2_d = paddle.empty(shape=(e, k, n), dtype=otype) for idx in range(0, e): w1_d[idx] = dequantize_nvfp4_to_dtype( w1_q[idx], w1_blockscale[idx], w1_gs[idx], dtype=w1.dtype, - device=w1.device, + device=w1.place, 
block_size=quant_blocksize, ) w2_d[idx] = dequantize_nvfp4_to_dtype( @@ -498,18 +446,19 @@ def test_moe_nvfp4( w2_blockscale[idx], w2_gs[idx], dtype=w2.dtype, - device=w2.device, + device=w2.place, block_size=quant_blocksize, ) - - w1_q_cutlass = torch.cat((w1_q[:, n:, :], w1_q[:, :n, :]), dim=1).contiguous() - w1_blockscale_cutlass = torch.cat( - (w1_blockscale[:, n:, :], w1_blockscale[:, :n, :]), dim=1 + w1_q_cutlass = paddle.concat( + x=(w1_q[:, n:, :], w1_q[:, :n, :]), axis=1 + ).contiguous() + w1_blockscale_cutlass = paddle.concat( + x=(w1_blockscale[:, n:, :], w1_blockscale[:, :n, :]), axis=1 ).contiguous() ref_output = torch_moe_nvfp4( a_in_dtype, w1_d, w2_d, top_k, routing_weights, selected_experts ) - torch.testing.assert_close(ref_output, flash_output, rtol=2e-1, atol=2e-1) + assert paddle.allclose(x=ref_output, y=flash_output, rtol=0.2, atol=0.2).item(), "" @pytest.mark.parametrize("batch_size", BATCH_SIZES) @@ -532,62 +481,41 @@ def test_moe_expert_parallel( intermediate_size: Intermediate dimension size activation: Activation function type """ - # This test is specifically for 2 GPUs and 2 experts - # GPU 0 (ep_rank=0) handles expert 0 - # GPU 1 (ep_rank=1) handles expert 1 ep_size = num_experts // 2 - torch.manual_seed(42) - - # Create input tensors - x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() - - # Create weight tensors - each GPU will have one expert + paddle.seed(seed=42) + x = paddle.randn(shape=[batch_size, hidden_size], dtype="float16").cuda() w31_weight = ( - torch.randn( - num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16 + paddle.randn( + shape=[num_experts, 2 * intermediate_size, hidden_size], dtype="float16" ).cuda() / 10 ) w2_weight = ( - torch.randn( - num_experts, hidden_size, intermediate_size, dtype=torch.float16 + paddle.randn( + shape=[num_experts, hidden_size, intermediate_size], dtype="float16" ).cuda() / 10 ) - - selected_experts = torch.stack( - [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)] + selected_experts = paddle.stack( + x=[paddle.randperm(n=num_experts)[:top_k] for _ in range(batch_size)] ).cuda() - - routing_weights = torch.randn((batch_size, top_k)).cuda() - routing_weights = F.softmax(routing_weights, dim=1) + routing_weights = paddle.randn(shape=(batch_size, top_k)).cuda() + routing_weights = paddle.nn.functional.softmax(x=routing_weights, axis=1) ref_output = compute_with_experts( num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights ) - outputs = [] - flash_output = torch.zeros_like(ref_output) + flash_output = paddle.zeros_like(x=ref_output) for ep_rank in range(ep_size): - # Create output tensor for this GPU - out_hidden_states_local = torch.zeros_like(x) - - # Compute expert start and end positions for this rank - experts_per_rank = ( - num_experts // ep_size - ) # 2 GPUs, so each gets half the experts + out_hidden_states_local = paddle.zeros_like(x=x) + experts_per_rank = num_experts // ep_size expert_start = ep_rank * experts_per_rank - expert_end = expert_start + experts_per_rank # if ep_rank < 1 else num_experts - - w31_weight_local = w31_weight[ - expert_start:expert_end, : - ] # Get only the experts for this rank - w2_weight_local = w2_weight[ - expert_start:expert_end, : - ] # Get only the experts for this rank - + expert_end = expert_start + experts_per_rank + w31_weight_local = w31_weight[expert_start:expert_end, :] + w2_weight_local = w2_weight[expert_start:expert_end, :] _ = fused_moe.cutlass_fused_moe( x.contiguous(), - 
selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, w31_weight_local.contiguous(), w2_weight_local.contiguous(), @@ -598,11 +526,9 @@ def test_moe_expert_parallel( output=out_hidden_states_local, ) outputs.append(out_hidden_states_local) - - # Reduce results from all GPUs for ep_rank in range(ep_size): - flash_output += outputs[ep_rank] # [batch_size, num_experts] - torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) + flash_output += outputs[ep_rank] + assert paddle.allclose(x=ref_output, y=flash_output, rtol=0.1, atol=0.1).item(), "" TP_SIZES = [2, 4] @@ -630,74 +556,49 @@ def test_moe_tensor_parallel( intermediate_size: Intermediate dimension size activation: Activation function type """ - # Set random seed for reproducibility - torch.manual_seed(42) + paddle.seed(seed=42) top_k = 2 - # Create input tensors - x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() - - # Create weight tensors + x = paddle.randn(shape=[batch_size, hidden_size], dtype="float16").cuda() w31_weight = ( - torch.randn( - num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16 + paddle.randn( + shape=[num_experts, 2 * intermediate_size, hidden_size], dtype="float16" ).cuda() / 10 ) w2_weight = ( - torch.randn( - num_experts, hidden_size, intermediate_size, dtype=torch.float16 + paddle.randn( + shape=[num_experts, hidden_size, intermediate_size], dtype="float16" ).cuda() / 10 ) - - # Generate unique random expert indices for each token - selected_experts = torch.stack( - [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)] + selected_experts = paddle.stack( + x=[paddle.randperm(n=num_experts)[:top_k] for _ in range(batch_size)] ).cuda() - - routing_weights = torch.randn((batch_size, top_k)).cuda() - routing_weights = F.softmax(routing_weights, dim=1) - - # Run reference implementation (no parallelism) + routing_weights = paddle.randn(shape=(batch_size, top_k)).cuda() + routing_weights = paddle.nn.functional.softmax(x=routing_weights, axis=1) ref_output = compute_with_experts( num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights ) - - # Simulate tensor parallelism on # TP GPUs outputs = [] for tp_rank in range(tp_size): - # Create output tensor for this GPU - out_hidden_states_local = torch.zeros_like(x) - - # Shard w31 along second dimension (intermediate_size) - # First split w31 into w3 and w1 - w3_weight, w1_weight = torch.chunk( - w31_weight, 2, dim=1 - ) # [num_experts, intermediate_size, hidden_size] each - - # Shard w3 and w1 separately + out_hidden_states_local = paddle.zeros_like(x=x) + w3_weight, w1_weight = paddle.chunk(x=w31_weight, chunks=2, axis=1) w3_shard_size = intermediate_size // tp_size w3_start = tp_rank * w3_shard_size w3_end = w3_start + w3_shard_size w3_weight_local = w3_weight[:, w3_start:w3_end, :] - w1_shard_size = intermediate_size // tp_size w1_start = tp_rank * w1_shard_size w1_end = w1_start + w1_shard_size w1_weight_local = w1_weight[:, w1_start:w1_end, :] - - # Stack the sharded weights back together - w31_weight_local = torch.cat([w3_weight_local, w1_weight_local], dim=1) - - # Shard w2 along third dimension (intermediate_size) + w31_weight_local = paddle.concat(x=[w3_weight_local, w1_weight_local], axis=1) w2_shard_size = intermediate_size // tp_size w2_start = tp_rank * w2_shard_size w2_end = w2_start + w2_shard_size w2_weight_local = w2_weight[:, :, w2_start:w2_end] - _ = fused_moe.cutlass_fused_moe( x.contiguous(), - selected_experts.to(torch.int), + 
selected_experts.to("int32"), routing_weights, w31_weight_local.contiguous(), w2_weight_local.contiguous(), @@ -708,10 +609,10 @@ def test_moe_tensor_parallel( output=out_hidden_states_local, ) outputs.append(out_hidden_states_local) - - # All-reduce to sum partial results from all GPUs flash_output = sum(outputs) - torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-2) + assert paddle.allclose( + x=ref_output, y=flash_output, rtol=0.01, atol=0.01 + ).item(), "" @pytest.mark.parametrize("batch_size", BATCH_SIZES) @@ -738,85 +639,57 @@ def test_moe_tensor_expert_parallel( tp_size: Number of GPUs for tensor parallelism intermediate_size: Intermediate dimension size """ - torch.manual_seed(42) - x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() + paddle.seed(seed=42) + x = paddle.randn(shape=[batch_size, hidden_size], dtype="float16").cuda() w31_weight = ( - torch.randn( - num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16 + paddle.randn( + shape=[num_experts, 2 * intermediate_size, hidden_size], dtype="float16" ).cuda() / 10 ) w2_weight = ( - torch.randn( - num_experts, hidden_size, intermediate_size, dtype=torch.float16 + paddle.randn( + shape=[num_experts, hidden_size, intermediate_size], dtype="float16" ).cuda() / 10 ) - - # Generate unique random expert indices for each token - selected_experts = torch.stack( - [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)] + selected_experts = paddle.stack( + x=[paddle.randperm(n=num_experts)[:top_k] for _ in range(batch_size)] ).cuda() - - routing_weights = torch.randn((batch_size, top_k)).cuda() - routing_weights = F.softmax(routing_weights, dim=1) - - # Run reference implementation (no parallelism) + routing_weights = paddle.randn(shape=(batch_size, top_k)).cuda() + routing_weights = paddle.nn.functional.softmax(x=routing_weights, axis=1) ref_output = compute_with_experts( num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights ) - - # Simulate combined parallelism - ep_size = num_experts // 2 # Number of GPUs for expert parallelism + ep_size = num_experts // 2 outputs = [] - - # For each expert parallel rank for ep_rank in range(ep_size): - # Get experts for this rank experts_per_rank = num_experts // ep_size expert_start = ep_rank * experts_per_rank expert_end = expert_start + experts_per_rank - - # Get expert weights for this rank - w31_weight_ep = w31_weight[ - expert_start:expert_end, : - ] # [experts_per_rank, 2*intermediate_size, hidden_size] - w2_weight_ep = w2_weight[ - expert_start:expert_end, : - ] # [experts_per_rank, hidden_size, intermediate_size] - - # For each tensor parallel rank + w31_weight_ep = w31_weight[expert_start:expert_end, :] + w2_weight_ep = w2_weight[expert_start:expert_end, :] for tp_rank in range(tp_size): - # Create output tensor for this GPU - out_hidden_states_local = torch.zeros_like(x) - - # Split w31 into w3 and w1 - w3_weight, w1_weight = torch.chunk(w31_weight_ep, 2, dim=1) - - # Shard w3 and w1 separately + out_hidden_states_local = paddle.zeros_like(x=x) + w3_weight, w1_weight = paddle.chunk(x=w31_weight_ep, chunks=2, axis=1) w3_shard_size = intermediate_size // tp_size w3_start = tp_rank * w3_shard_size w3_end = w3_start + w3_shard_size w3_weight_local = w3_weight[:, w3_start:w3_end, :] - w1_shard_size = intermediate_size // tp_size w1_start = tp_rank * w1_shard_size w1_end = w1_start + w1_shard_size w1_weight_local = w1_weight[:, w1_start:w1_end, :] - - # Stack the sharded weights back together - w31_weight_local = 
torch.cat([w3_weight_local, w1_weight_local], dim=1) - - # Shard w2 along third dimension + w31_weight_local = paddle.concat( + x=[w3_weight_local, w1_weight_local], axis=1 + ) w2_shard_size = intermediate_size // tp_size w2_start = tp_rank * w2_shard_size w2_end = w2_start + w2_shard_size w2_weight_local = w2_weight_ep[:, :, w2_start:w2_end] - - # Call flashinfer implementation with both parallelisms out_hidden_states_local = fused_moe.cutlass_fused_moe( x.contiguous(), - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, w31_weight_local.contiguous(), w2_weight_local.contiguous(), @@ -828,10 +701,10 @@ def test_moe_tensor_expert_parallel( quant_scales=None, ) outputs.append(out_hidden_states_local[0]) - - # All-reduce to sum partial results from all GPUs flash_output = sum(outputs) - torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-2) + assert paddle.allclose( + x=ref_output, y=flash_output, rtol=0.01, atol=0.01 + ).item(), "" def ceil_div(a: int, b: int) -> int: @@ -839,52 +712,57 @@ def ceil_div(a: int, b: int) -> int: def per_block_cast_to_fp8( - x: torch.Tensor, block_size_n: int = 128 -) -> tuple[torch.Tensor, torch.Tensor]: + x: paddle.Tensor, block_size_n: int = 128 +) -> tuple[paddle.Tensor, paddle.Tensor]: assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (ceil_div(m, 128) * 128, ceil_div(n, block_size_n) * block_size_n), + m, n = tuple(x.shape) + x_padded = paddle.zeros( + shape=(ceil_div(m, 128) * 128, ceil_div(n, block_size_n) * block_size_n), dtype=x.dtype, - device=x.device, ) x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() - scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + x_view = x_padded.view(-1, 128, x_padded.shape[1] // 128, block_size_n) + x_amax = ( + x_view.abs() + .astype(dtype="float32") + .amax(axis=(1, 3), keepdim=True) + .clip(min=0.0001) + ) +>>>>>> x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + x_scaled_sub = x_scaled.view_as(other=x_padded)[:m, :n].contiguous() + scales = (x_amax / 448.0).view(x_view.shape[0], x_view.shape[2]) return x_scaled_sub, scales -def per_token_group_quant_fp8(x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn): +>>>>>>def per_token_group_quant_fp8(x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn): """Function to perform per-token-group quantization on an input tensor `x` using native torch.""" - assert x.shape[-1] % group_size == 0, ( - "the last dimension of `x` cannot be divisible by `group_size`" - ) + assert ( + tuple(x.shape)[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" - - finfo = torch.finfo(dtype) + finfo = paddle.finfo(dtype=dtype) fp8_min = finfo.min fp8_max = finfo.max - - x_ = x.reshape(x.numel() // group_size, group_size) - amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_ = x.reshape(x.size // group_size, group_size) + amax = ( + (x_.abs().max(keepdim=True, axis=-1), x_.abs().argmax(keepdim=True, axis=-1))[0] + .clip(min=eps) + .to("float32") + ) x_s = amax / fp8_max - x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) - x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) - + x_q = (x_ / 
x_s).clip(min=fp8_min, max=fp8_max).to(dtype) + x_q = x_q.reshape(tuple(x.shape)) + x_s = x_s.reshape(tuple(x.shape)[:-1] + (tuple(x.shape)[-1] // group_size,)) return x_q, x_s def dequantize_block( - x_quant: torch.Tensor, - scales: torch.Tensor, - dtype: torch.dtype, + x_quant: paddle.Tensor, + scales: paddle.Tensor, + dtype: paddle.dtype, original_shape: tuple, -) -> torch.Tensor: +) -> paddle.Tensor: """ Dequantize a block-quantized tensor. @@ -897,38 +775,31 @@ def dequantize_block( Returns: torch.Tensor: Dequantized tensor """ - # Reshape scales to match block structure - def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor: - # Move target dim to last position if not already last + def transform_dim(a: paddle.Tensor, dim: int = -1) -> paddle.Tensor: if dim != -1: - a = a.transpose(dim, -1) - # Broadcast and reshape - a_broadcasted = a.unsqueeze(-1).expand(*a.shape, 128) - a_reshaped = a_broadcasted.reshape(*a.shape[:-1], a.shape[-1] * 128) - # Move back if needed + a = a.transpose(perm=dim2perm(a.ndim, dim, -1)) + a_broadcasted = a.unsqueeze(axis=-1).expand(shape=[*tuple(a.shape), 128]) + a_reshaped = a_broadcasted.reshape( + *tuple(a.shape)[:-1], tuple(a.shape)[-1] * 128 + ) if dim != -1: - a_reshaped = a_reshaped.transpose(dim, -1) + a_reshaped = a_reshaped.transpose(perm=dim2perm(a_reshaped.ndim, dim, -1)) return a_reshaped - if x_quant.dim() == 2: # For activation tensors [batch_size, hidden_size] - batch_size, hidden_size = x_quant.shape + if x_quant.dim() == 2: + batch_size, hidden_size = tuple(x_quant.shape) num_blocks = (hidden_size + 127) // 128 - scales = scales.view(batch_size, num_blocks, 1).expand(-1, -1, 128) + scales = scales.view(batch_size, num_blocks, 1).expand(shape=[-1, -1, 128]) scales = scales[:, :, : hidden_size % 128] if hidden_size % 128 != 0 else scales - else: # For weight tensors [..., in_dim, out_dim] - *_dims, in_dim, out_dim = x_quant.shape - - # Transform both dimensions - scales = transform_dim(scales, -1) # Last dim - scales = transform_dim(scales, -2) # Second-to-last dim - - # Handle padding + else: + *_dims, in_dim, out_dim = tuple(x_quant.shape) + scales = transform_dim(scales, -1) + scales = transform_dim(scales, -2) if in_dim % 128 != 0: scales = scales[..., : in_dim % 128, :] if out_dim % 128 != 0: scales = scales[..., :, : out_dim % 128] - x_dequant = x_quant.to(dtype) * scales.to(dtype) return x_dequant.view(original_shape) @@ -939,7 +810,7 @@ def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor: @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) @pytest.mark.skipif( - torch.cuda.get_device_capability()[0] != 10, + paddle.device.cuda.get_device_capability()[0] != 10, reason="FP8 block scaling is only supported on SM100", ) def test_moe_fp8_block_scaling( @@ -959,68 +830,64 @@ def test_moe_fp8_block_scaling( intermediate_size: Intermediate dimension size Only support bf16 for hidden_states """ - torch.manual_seed(42) - otype = torch.bfloat16 - - x = torch.randn(batch_size, hidden_size, dtype=otype).cuda() - + paddle.seed(seed=42) + otype = "bfloat16" + x = paddle.randn(shape=[batch_size, hidden_size], dtype=otype).cuda() w31_weight = ( - torch.randn(num_experts, 2 * intermediate_size, hidden_size, dtype=otype).cuda() + paddle.randn( + shape=[num_experts, 2 * intermediate_size, hidden_size], dtype=otype + ).cuda() / 10 ) w2_weight = ( - torch.randn(num_experts, hidden_size, intermediate_size, dtype=otype).cuda() + paddle.randn( + shape=[num_experts, 
hidden_size, intermediate_size], dtype=otype + ).cuda() / 10 ) - - # Generate unique random expert indices for each token - selected_experts = torch.stack( - [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)] + selected_experts = paddle.stack( + x=[paddle.randperm(n=num_experts)[:top_k] for _ in range(batch_size)] ).cuda() - - routing_weights = torch.randn((batch_size, top_k)).cuda() - routing_weights = F.softmax(routing_weights, dim=1) - - # Run reference implementation (no quantization) + routing_weights = paddle.randn(shape=(batch_size, top_k)).cuda() + routing_weights = paddle.nn.functional.softmax(x=routing_weights, axis=1) _ref_output = compute_with_experts( num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights ) - - # Quantize input and weights x_quant, x_scales = per_token_group_quant_fp8(x, group_size=128) - - w31_dequant = torch.empty_like(w31_weight) - w2_dequant = torch.empty_like(w2_weight) - w31_quant = torch.empty_like(w31_weight).to(torch.float8_e4m3fn) - w2_quant = torch.empty_like(w2_weight).to(torch.float8_e4m3fn) - w31_scales = torch.randn( - num_experts, - ceil_div(2 * intermediate_size, 128), - ceil_div(hidden_size, 128), - dtype=torch.float32, + w31_dequant = paddle.empty_like(x=w31_weight) + w2_dequant = paddle.empty_like(x=w2_weight) +>>>>>> w31_quant = paddle.empty_like(x=w31_weight).to(torch.float8_e4m3fn) +>>>>>> w2_quant = paddle.empty_like(x=w2_weight).to(torch.float8_e4m3fn) + w31_scales = paddle.randn( + shape=[ + num_experts, + ceil_div(2 * intermediate_size, 128), + ceil_div(hidden_size, 128), + ], + dtype="float32", ).cuda() - w2_scales = torch.randn( - num_experts, - ceil_div(hidden_size, 128), - ceil_div(intermediate_size, 128), - dtype=torch.float32, + w2_scales = paddle.randn( + shape=[ + num_experts, + ceil_div(hidden_size, 128), + ceil_div(intermediate_size, 128), + ], + dtype="float32", ).cuda() - for expert_id in range(num_experts): w31, w31_s = per_block_cast_to_fp8(w31_weight[expert_id, :]) w2, w2_s = per_block_cast_to_fp8(w2_weight[expert_id, :]) - w31_quant.data[expert_id].copy_(w31) - w31_scales.data[expert_id].copy_(w31_s) - w2_quant.data[expert_id].copy_(w2) - w2_scales.data[expert_id].copy_(w2_s) - # Dequantize for verificationa - x_dequant = dequantize_block(x_quant, x_scales, x.dtype, x.shape) + paddle.assign(w31, output=w31_quant.data[expert_id]) + paddle.assign(w31_s, output=w31_scales.data[expert_id]) + paddle.assign(w2, output=w2_quant.data[expert_id]) + paddle.assign(w2_s, output=w2_scales.data[expert_id]) + x_dequant = dequantize_block(x_quant, x_scales, x.dtype, tuple(x.shape)) w31_dequant = dequantize_block( - w31_quant, w31_scales, w31_weight.dtype, w31_weight.shape + w31_quant, w31_scales, w31_weight.dtype, tuple(w31_weight.shape) + ) + w2_dequant = dequantize_block( + w2_quant, w2_scales, w2_weight.dtype, tuple(w2_weight.shape) ) - w2_dequant = dequantize_block(w2_quant, w2_scales, w2_weight.dtype, w2_weight.shape) - - # Run reference implementation with dequantized tensors _ref_output = compute_with_experts( num_experts, x_dequant, @@ -1029,19 +896,14 @@ def test_moe_fp8_block_scaling( selected_experts, routing_weights, ) - quant_scales = [ - w31_scales, # .view(-1), # W31 scales - w2_scales, # .view(-1), # W2 scales - ] - - # Call flashinfer implementation with block scaling and expect NotImplementedError + quant_scales = [w31_scales, w2_scales] with pytest.raises( NotImplementedError, match="DeepSeek FP8 Block Scaling is not yet implemented in CUTLASS for Blackwell", ): _ = 
fused_moe.cutlass_fused_moe( x.contiguous(), - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, w31_quant.contiguous(), w2_quant.contiguous(), @@ -1060,23 +922,16 @@ def quant_mxfp4_batches(a, num_experts): a_fp4, a_sf = mxfp4_quantize(a[i].cuda()) quant_a.append(a_fp4) sfs.append(a_sf) - - result_quant_a = torch.stack(quant_a) - result_sfs = torch.stack(sfs) - + result_quant_a = paddle.stack(x=quant_a) + result_sfs = paddle.stack(x=sfs) return result_quant_a, result_sfs -def dequant_mxfp4_batches( - mat_fp4: torch.Tensor, - scale_tensor: torch.Tensor, -): - num_batches = mat_fp4.size(0) - +def dequant_mxfp4_batches(mat_fp4: paddle.Tensor, scale_tensor: paddle.Tensor): + num_batches = mat_fp4.shape[0] scale_tensor = scale_tensor.view(num_batches, -1) - - return torch.stack( - [ + return paddle.stack( + x=[ mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :]) for b in range(num_batches) ] @@ -1088,12 +943,12 @@ def dequant_mxfp4_batches( @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) -@pytest.mark.parametrize("otype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("otype", ["float16", "bfloat16"]) @pytest.mark.parametrize( ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)] ) @pytest.mark.skipif( - torch.cuda.get_device_capability()[0] != 10, + paddle.device.cuda.get_device_capability()[0] != 10, reason="MXFP8xMXFP4 is only supported on SM100", ) def test_moe_mxfp8_mxfp4( @@ -1111,57 +966,45 @@ def test_moe_mxfp8_mxfp4( Test MoE with MXFP8 activations and MXFP4 weights. Uses mxfp8_quantize for activations and fp4_quantize for weights. """ - # Skip invalid configurations if top_k > num_experts: pytest.skip( f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" ) - - torch.manual_seed(42) + paddle.seed(seed=42) e = num_experts m = batch_size n = intermediate_size k = hidden_size - - x = torch.randn(m, k, dtype=otype).cuda() - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10 - + x = paddle.randn(shape=[m, k], dtype=otype).cuda() + w1 = paddle.randn(shape=(e, 2 * n, k), dtype=otype) / 10 + w2 = paddle.randn(shape=(e, k, n), dtype=otype) / 10 mxfp8_x, mxfp8_x_sf = mxfp8_quantize(x, True, 32) - mxfp4_w1, mxfp4_w1_scale = quant_mxfp4_batches(w1, e) mxfp4_w2, mxfp4_w2_scale = quant_mxfp4_batches(w2, e) - - router_logits = torch.randn(m, e, dtype=otype).cuda() + router_logits = paddle.randn(shape=[m, e], dtype=otype).cuda() routing_weights, selected_experts = compute_routing(router_logits, top_k) - - fake_input_scale = torch.ones(e, device=x.device) - + fake_input_scale = paddle.ones(shape=e) quant_scales = [ - mxfp4_w1_scale.view(torch.int32), + mxfp4_w1_scale.view("int32"), fake_input_scale, - mxfp4_w2_scale.view(torch.int32), + mxfp4_w2_scale.view("int32"), fake_input_scale, ] - - flash_output = torch.zeros_like(x) - + flash_output = paddle.zeros_like(x=x) if alpha is not None and limit is not None and beta is not None: - alpha_t = torch.ones(e, device=x.device) * alpha - limit_t = torch.ones(e, device=x.device) * limit - beta_t = torch.ones(e, device=x.device) * beta + alpha_t = paddle.ones(shape=e) * alpha + limit_t = paddle.ones(shape=e) * limit + beta_t = paddle.ones(shape=e) * beta else: alpha_t = None limit_t = None beta_t = None - - # Call cutlass_fused_moe with MXFP8 activations and MXFP4 weights _ = 
fused_moe.cutlass_fused_moe( mxfp8_x, - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, - mxfp4_w1.contiguous().view(torch.long), - mxfp4_w2.contiguous().view(torch.long), + mxfp4_w1.contiguous().view("int64"), + mxfp4_w2.contiguous().view("int64"), otype, swiglu_alpha=alpha_t, swiglu_limit=limit_t, @@ -1171,36 +1014,29 @@ def test_moe_mxfp8_mxfp4( use_mxfp8_act_scaling=True, output=flash_output, ) - dq_mxfp8_x = ( mxfp8_dequantize_host( - mxfp8_x.cpu().view(torch.uint8), - mxfp8_x_sf.cpu().view(torch.uint8).reshape(-1), + mxfp8_x.cpu().view("uint8"), + mxfp8_x_sf.cpu().view("uint8").reshape(-1), True, ) .cuda() .to(otype) ) - dq_mfxp4_w1 = ( dequant_mxfp4_batches( - mxfp4_w1.cpu().view(torch.uint8), - mxfp4_w1_scale.cpu().view(torch.uint8).reshape(-1), + mxfp4_w1.cpu().view("uint8"), mxfp4_w1_scale.cpu().view("uint8").reshape(-1) ) .cuda() .to(otype) ) - dq_mfxp4_w2 = ( dequant_mxfp4_batches( - mxfp4_w2.cpu().view(torch.uint8), - mxfp4_w2_scale.cpu().view(torch.uint8).reshape(-1), + mxfp4_w2.cpu().view("uint8"), mxfp4_w2_scale.cpu().view("uint8").reshape(-1) ) .cuda() .to(otype) ) - - # Use original weights for reference computation ref_output = compute_with_experts( e, dq_mxfp8_x, @@ -1212,18 +1048,14 @@ def test_moe_mxfp8_mxfp4( beta, limit, ) + assert paddle.allclose(x=ref_output, y=flash_output, rtol=0.1, atol=0.1).item(), "" - torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) - -def dequant_mxfp4_batches_host( - mat_fp4: torch.Tensor, - scale_tensor: torch.Tensor, -): - return torch.stack( - [ +def dequant_mxfp4_batches_host(mat_fp4: paddle.Tensor, scale_tensor: paddle.Tensor): + return paddle.stack( + x=[ mxfp4_dequantize_host(mat_fp4[b, :, :], scale_tensor[b, :, :]) - for b in range(mat_fp4.size(0)) + for b in range(mat_fp4.shape[0]) ] ) @@ -1237,76 +1069,53 @@ def dequant_mxfp4_batches_host( ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)] ) @pytest.mark.skipif( - torch.cuda.get_device_capability()[0] != 9, + paddle.device.cuda.get_device_capability()[0] != 9, reason="BF16xMXFP4 is only supported on SM90", ) def test_moe_bf16_mxfp4( - batch_size, - hidden_size, - num_experts, - top_k, - intermediate_size, - alpha, - beta, - limit, + batch_size, hidden_size, num_experts, top_k, intermediate_size, alpha, beta, limit ): """ Test MoE with bf16 activations and MXFP4 weights. Uses bf16 for activations and fp4_quantize for weights. 
""" - # Skip invalid configurations if top_k > num_experts: pytest.skip( f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" ) - - torch.manual_seed(42) + paddle.seed(seed=42) e = num_experts m = batch_size n = intermediate_size k = hidden_size - - x = torch.randn(m, k, dtype=torch.bfloat16).cuda() - w1 = torch.randint(0, 256, (e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) - w2 = torch.randint(0, 256, (e, k, n // 2), device="cuda", dtype=torch.uint8) - - w1_scale = torch.randint( - 118, 123, (e, 2 * n, k // 32), device="cuda", dtype=torch.uint8 + x = paddle.randn(shape=[m, k], dtype="bfloat16").cuda() + w1 = paddle.randint(low=0, high=256, shape=(e, 2 * n, k // 2), dtype="uint8") + w2 = paddle.randint(low=0, high=256, shape=(e, k, n // 2), dtype="uint8") + w1_scale = paddle.randint( + low=118, high=123, shape=(e, 2 * n, k // 32), dtype="uint8" ) - w2_scale = torch.randint( - 118, 123, (e, k, n // 32), device="cuda", dtype=torch.uint8 - ) - - router_logits = torch.randn(m, e, dtype=torch.bfloat16).cuda() + w2_scale = paddle.randint(low=118, high=123, shape=(e, k, n // 32), dtype="uint8") + router_logits = paddle.randn(shape=[m, e], dtype="bfloat16").cuda() routing_weights, selected_experts = compute_routing(router_logits, top_k) - - flash_output = torch.zeros_like(x) - + flash_output = paddle.zeros_like(x=x) if alpha is not None and limit is not None and beta is not None: - alpha_t = torch.ones(e, device=x.device) * alpha - limit_t = torch.ones(e, device=x.device) * limit - beta_t = torch.ones(e, device=x.device) * beta + alpha_t = paddle.ones(shape=e) * alpha + limit_t = paddle.ones(shape=e) * limit + beta_t = paddle.ones(shape=e) * beta else: alpha_t = None limit_t = None beta_t = None - - pad_size = hidden_size - x.shape[1] - x_pad = torch.nn.functional.pad(x, (0, pad_size)) - - quant_scales = [ - w1_scale.view(torch.int32), - w2_scale.view(torch.int32), - ] - - # Call cutlass_fused_moe with BF16 activations and MXFP4 weights + pad_size = hidden_size - tuple(x.shape)[1] + x_pad = paddle.nn.functional.pad(x=x, pad=(0, pad_size), pad_from_left_axis=False) + quant_scales = [w1_scale.view("int32"), w2_scale.view("int32")] _ = fused_moe.cutlass_fused_moe( x_pad, - selected_experts.to(torch.int), + selected_experts.to("int32"), routing_weights, - w1.contiguous().view(torch.uint8), - w2.contiguous().view(torch.uint8), - torch.bfloat16, + w1.contiguous().view("uint8"), + w2.contiguous().view("uint8"), + "bfloat16", swiglu_alpha=alpha_t, swiglu_limit=limit_t, swiglu_beta=beta_t, @@ -1314,26 +1123,12 @@ def test_moe_bf16_mxfp4( use_w4_group_scaling=True, output=flash_output, ) - dq_mfxp4_w1 = ( - dequant_mxfp4_batches_host( - w1.cpu(), - w1_scale.cpu(), - ) - .cuda() - .to(torch.bfloat16) + dequant_mxfp4_batches_host(w1.cpu(), w1_scale.cpu()).cuda().to("bfloat16") ) - dq_mfxp4_w2 = ( - dequant_mxfp4_batches_host( - w2.cpu(), - w2_scale.cpu(), - ) - .cuda() - .to(torch.bfloat16) + dequant_mxfp4_batches_host(w2.cpu(), w2_scale.cpu()).cuda().to("bfloat16") ) - - # Use original weights for reference computation ref_output = compute_with_experts( e, x, @@ -1345,8 +1140,7 @@ def test_moe_bf16_mxfp4( beta, limit, ) - - torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) + assert paddle.allclose(x=ref_output, y=flash_output, rtol=0.1, atol=0.1).item(), "" if __name__ == "__main__": diff --git a/tests/test_trtllm_gen_attention.py b/tests/test_trtllm_gen_attention.py index f304ea575e..b48fc40f21 100755 --- a/tests/test_trtllm_gen_attention.py +++ 
b/tests/test_trtllm_gen_attention.py @@ -1,145 +1,122 @@ +import sys + + import math +import paddle import pytest -import torch +from flashinfer.paddle_utils import * from utils_fp4 import cast_from_fp4, recover_swizzled_scales, ref_fp4_quant import flashinfer from flashinfer.utils import FP4Tensor, ceil_div, round_up DTYPE_MAP = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp8": torch.float8_e4m3fn, + "fp16": "float16", + "bf16": "bfloat16", + "fp8": paddle.float8_e4m3fn, "nvfp4": "nvfp4", } - GPU_DEVICE = "cuda:0" - global_workspace_buffer = None def flip_coin(*args, **kwargs): - # Use any test parameters to deterministically decide branch - # This makes test configurations go through different paths param_tuple = args + tuple(sorted(kwargs.items())) hash_value = hash(param_tuple) - return (hash_value % 2) == 0 + return hash_value % 2 == 0 -def to_float8(x, dtype=torch.float8_e4m3fn): - finfo = torch.finfo(dtype) - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) +def to_float8(x, dtype=paddle.float8_e4m3fn): + finfo = paddle.finfo(dtype=dtype) + min_val, max_val = tuple( + [ + paddle.amin(x, axis=None, keepdim=False), + paddle.max(x, axis=None, keepdim=False), + ] + ) + amax = paddle.maximum(x=min_val.abs(), y=max_val.abs()).clip(min=1e-12) scale = finfo.max / amax * 0.1 - x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) - return x_scl_sat.to(dtype), scale.float().reciprocal() + x_scl_sat = (x * scale).clip(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.astype(dtype="float32").reciprocal() def generate_seq_lens(batch_size, max_q_len, max_in_kv_len): - q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32) + q_lens = paddle.randint( + low=1, high=max_q_len + 1, shape=(batch_size,), dtype="int32" + ) q_lens[-1] = max_q_len - in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int) + in_kv_lens = paddle.randint( + low=0, high=max_in_kv_len + 1, shape=(batch_size,), dtype="int32" + ) in_kv_lens[-1] = max_in_kv_len seq_lens = q_lens + in_kv_lens return q_lens, in_kv_lens, seq_lens def generate_cumsum_lens(lens): - return torch.cat( - [ - torch.tensor([0], dtype=torch.int32, device=GPU_DEVICE), - torch.cumsum(lens.to(GPU_DEVICE), dim=0, dtype=torch.int32), + return paddle.concat( + x=[ + paddle.to_tensor(data=[0], dtype="int32", place=GPU_DEVICE), + paddle.cumsum(x=lens.to(GPU_DEVICE), axis=0, dtype="int32"), ] ) def create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype): - q = torch.randn( - torch.sum(q_lens).item(), - num_qo_heads, - head_dim, - dtype=torch.bfloat16 if q_dtype == "fp8" else DTYPE_MAP[q_dtype], - device=GPU_DEVICE, + q = paddle.randn( + shape=[paddle.sum(x=q_lens).item(), num_qo_heads, head_dim], + dtype="bfloat16" if q_dtype == "fp8" else DTYPE_MAP[q_dtype], ) if q_dtype == "fp8": q, q_scale = to_float8(q) - # Reference implementation have functional issue or low precision with fp8, use bfloat16 and fake-quantization instead. 
- ref_q = q.bfloat16() * q_scale + ref_q = q.astype(dtype="bfloat16") * q_scale else: q_scale = 1.0 ref_q = q - return q, q_scale, ref_q def create_kv_cache( batch_size, seq_lens, page_size, num_kv_heads, head_dim, kv_dtype, ref_kv_dtype ): - # Create separate K and V caches - max_seq_len = torch.max(seq_lens).item() + max_seq_len = paddle.max(x=seq_lens).item() num_tokens = max_seq_len * batch_size num_pages = (num_tokens + page_size - 1) // page_size ref_kv_dtype_torch = DTYPE_MAP[ref_kv_dtype] - if kv_dtype != "fp8": # for fp8, create with high precision to generate scale. - assert kv_dtype == ref_kv_dtype, ( - "kv_dtype and ref_kv_dtype must be the same for non-fp8 kv_cache" - ) - - k_cache = torch.randn( - num_pages, - num_kv_heads, - page_size, - head_dim, - dtype=ref_kv_dtype_torch, - device=GPU_DEVICE, + if kv_dtype != "fp8": + assert ( + kv_dtype == ref_kv_dtype + ), "kv_dtype and ref_kv_dtype must be the same for non-fp8 kv_cache" + k_cache = paddle.randn( + shape=[num_pages, num_kv_heads, page_size, head_dim], dtype=ref_kv_dtype_torch ) - v_cache = torch.randn( - num_pages, - num_kv_heads, - page_size, - head_dim, - dtype=ref_kv_dtype_torch, - device=GPU_DEVICE, + v_cache = paddle.randn( + shape=[num_pages, num_kv_heads, page_size, head_dim], dtype=ref_kv_dtype_torch ) - - # Convert K and V separately to fp8 if needed if kv_dtype == "fp8": k_cache, k_scale = to_float8(k_cache) v_cache, v_scale = to_float8(v_cache) - # use high precision and fake-quantization for reference to avoid precision/functional issue - ref_kv_cache = torch.stack( - [ + ref_kv_cache = paddle.stack( + x=[ k_cache.to(ref_kv_dtype_torch) * k_scale, v_cache.to(ref_kv_dtype_torch) * v_scale, ], - dim=1, + axis=1, ) else: k_scale = v_scale = 1.0 - ref_kv_cache = torch.stack([k_cache, v_cache], dim=1) - # Combine K and V into interleaved format for the API - kv_cache = torch.stack([k_cache, v_cache], dim=1) - + ref_kv_cache = paddle.stack(x=[k_cache, v_cache], axis=1) + kv_cache = paddle.stack(x=[k_cache, v_cache], axis=1) return kv_cache, k_scale, v_scale, ref_kv_cache def create_page_table(batch_size, seq_lens, page_size): page_per_seq = (seq_lens + page_size - 1) // page_size - max_num_pages_per_seq = torch.max(page_per_seq).item() - - # Generate random but unique page IDs for all sequences - total_pages_needed = torch.sum(page_per_seq).item() - all_page_ids = torch.randperm( - total_pages_needed, dtype=torch.int32, device=GPU_DEVICE - ) - - # Generate unique page IDs for all sequences - page_tables = torch.zeros( - (batch_size, max_num_pages_per_seq), dtype=torch.int32, device=GPU_DEVICE - ) - - # Populate page tables and track page assignments + max_num_pages_per_seq = paddle.max(x=page_per_seq).item() + total_pages_needed = paddle.sum(x=page_per_seq).item() + all_page_ids = paddle.randperm(n=total_pages_needed, dtype="int32") + page_tables = paddle.zeros(shape=(batch_size, max_num_pages_per_seq), dtype="int32") page_id = 0 for i in range(batch_size): num_pages_needed = page_per_seq[i] @@ -152,38 +129,31 @@ def create_page_table(batch_size, seq_lens, page_size): def create_output(q, o_dtype, create_out_tensor): if o_dtype == "fp8": - o_scale = torch.rand(1).item() * 0.5 + 0.5 # Scale range: 0.5 ~ 1.0 + o_scale = paddle.rand(shape=[1]).item() * 0.5 + 0.5 else: o_scale = 1.0 - o_sf_scale = ( - 300 if o_dtype == "nvfp4" else None - ) # choose a value to make error smaller by testing. 
+ o_sf_scale = 300 if o_dtype == "nvfp4" else None o_sf_vec_size = 16 if o_dtype == "nvfp4" else None - if create_out_tensor: if o_dtype == "nvfp4": - fp4_out_shape = q.shape[:-1] + (ceil_div(q.shape[-1], 2),) - - extra_size = torch.randint(0, 256, (1,)).item() - - fp4_out_scale_shape = ( - round_up(q.shape[0] + extra_size, 128), - round_up(q.shape[1] * q.shape[2] // o_sf_vec_size, 4), - ) - - out_scale_factor = torch.empty( - fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=q.device + fp4_out_shape = tuple(q.shape)[:-1] + (ceil_div(tuple(q.shape)[-1], 2),) + extra_size = paddle.randint(low=0, high=256, shape=(1,)).item() + fp4_out_scale_shape = round_up( + tuple(q.shape)[0] + extra_size, 128 + ), round_up(tuple(q.shape)[1] * tuple(q.shape)[2] // o_sf_vec_size, 4) + out_scale_factor = paddle.empty( + shape=fp4_out_scale_shape, dtype=paddle.float8_e4m3fn ) - rounded_extra_size = fp4_out_scale_shape[0] - q.shape[0] + rounded_extra_size = fp4_out_scale_shape[0] - tuple(q.shape)[0] o_sf_start_index = ( - torch.randint(0, rounded_extra_size, (1,)).item() + paddle.randint(low=0, high=rounded_extra_size, shape=(1,)).item() if rounded_extra_size > 0 else 0 ) - out_data = torch.empty(fp4_out_shape, dtype=torch.uint8, device=q.device) + out_data = paddle.empty(shape=fp4_out_shape, dtype="uint8") out = FP4Tensor(out_data, out_scale_factor, o_sf_start_index) else: - out = torch.empty_like(q, dtype=DTYPE_MAP[o_dtype]) + out = paddle.empty_like(x=q, dtype=DTYPE_MAP[o_dtype]) else: out = None return out, o_scale, o_sf_scale, o_sf_vec_size @@ -200,40 +170,48 @@ def unpack_compare_nvfp4( output_ref, o_sf_scale, o_sf_vec_size, - sf_rtol=2e-1, - sf_atol=2e-1, + sf_rtol=0.2, + sf_atol=0.2, rmse_tol=0.3, ): output_ref, out_scale_factor_ref = ref_fp4_quant( output_ref, o_sf_scale, o_sf_vec_size ) - output_unpacked = cast_from_fp4(output.data) out_scale_factor = recover_swizzled_scales( output.scale, - output_unpacked.shape[0], - math.prod(list(output_unpacked.shape[1:])), + tuple(output_unpacked.shape)[0], + math.prod(list(tuple(output_unpacked.shape)[1:])), o_sf_vec_size, output.scale_start_index, ) - - torch.testing.assert_close( - out_scale_factor.float().reshape(out_scale_factor_ref.shape), - out_scale_factor_ref.float(), + assert paddle.allclose( + x=out_scale_factor.astype(dtype="float32").reshape( + tuple(out_scale_factor_ref.shape) + ), + y=out_scale_factor_ref.astype(dtype="float32"), rtol=sf_rtol, atol=sf_atol, + ).item(), "" + rmse = paddle.sqrt( + x=paddle.mean( + x=( + output_unpacked.astype(dtype="float32") + - output_ref.astype(dtype="float32") + ) + ** 2 + ) ) - rmse = torch.sqrt(torch.mean((output_unpacked.float() - output_ref.float()) ** 2)) assert rmse.item() < rmse_tol return output_unpacked, output_ref -@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND +@pytest.mark.parametrize("kv_layout", ["HND"]) @pytest.mark.parametrize("batch_size", [4, 128, 256]) @pytest.mark.parametrize("page_size", [16, 32, 64]) @pytest.mark.parametrize("num_kv_heads", [2, 4]) @pytest.mark.parametrize("head_grp_size", [1, 5, 8]) -@pytest.mark.parametrize("window_left", [-1]) # todo(Siyuan): add 127 window_left +@pytest.mark.parametrize("window_left", [-1]) @pytest.mark.parametrize( "q_dtype,kv_dtype,o_dtype", [ @@ -258,23 +236,16 @@ def test_trtllm_batch_prefill( kv_dtype, enable_pdl, ): - # Set up test parameters - torch.manual_seed(0) + paddle.seed(seed=0) head_dim = 128 MAX_Q_LEN = 511 MAX_IN_KV_LEN = 2047 - - # Generate random sequence lengths num_qo_heads = num_kv_heads * 
head_grp_size q_lens, in_kv_lens, seq_lens = generate_seq_lens( batch_size, MAX_Q_LEN, MAX_IN_KV_LEN ) - - # Create query tensor and related data q, q_scale, ref_q = create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype) q_indptr = generate_cumsum_lens(q_lens) - - # Create KV cache and related data kv_cache, k_scale, v_scale, ref_kv_cache = create_kv_cache( batch_size, seq_lens, @@ -289,23 +260,16 @@ def test_trtllm_batch_prefill( ) kv_indptr = generate_cumsum_lens(page_per_seq) kv_last_page_len = get_last_page_len(seq_lens, page_size) - - # Create output tensor and related data create_out_tensor = flip_coin( batch_size, page_size, num_kv_heads, head_grp_size, o_dtype ) out, o_scale, o_sf_scale, o_sf_vec_size = create_output( q, o_dtype, create_out_tensor ) - global global_workspace_buffer if global_workspace_buffer is None: - global_workspace_buffer = torch.zeros( - 128 * 1024 * 1024, dtype=torch.int8, device=GPU_DEVICE - ) + global_workspace_buffer = paddle.zeros(shape=128 * 1024 * 1024, dtype="int8") workspace_buffer = global_workspace_buffer - - # Run reference wrapper wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) @@ -327,50 +291,46 @@ def test_trtllm_batch_prefill( } wrapper_ref.plan(**plan_params) output_ref = wrapper_ref.run(ref_q, ref_kv_cache) - - # Run trtllm-gen function call - sm_scale = float(1.0 / (head_dim**0.5)) + sm_scale = float(1.0 / head_dim**0.5) output = flashinfer.prefill.trtllm_batch_context_with_kv_cache( q.contiguous(), kv_cache, workspace_buffer, page_table, seq_lens.to(GPU_DEVICE), - torch.max(q_lens).item(), - torch.max(seq_lens).item(), - q_scale * k_scale * sm_scale, # bmm1_scale - v_scale / o_scale, # bmm2_scale + paddle.max(x=q_lens).item(), + paddle.max(x=seq_lens).item(), + q_scale * k_scale * sm_scale, + v_scale / o_scale, batch_size, q_indptr, kv_indptr, - window_left, # window_left + window_left, out=out, out_dtype=DTYPE_MAP[o_dtype], o_sf_scale=o_sf_scale, o_sf_vec_size=o_sf_vec_size, enable_pdl=enable_pdl, ) - if o_dtype == "nvfp4": output, output_ref = unpack_compare_nvfp4( output, output_ref, o_sf_scale, o_sf_vec_size ) assert o_scale == 1.0 - rtol, atol = 4e-1, 1e0 + rtol, atol = 0.4, 1.0 elif q_dtype == "fp8" and o_dtype == "fp8": - rtol, atol = 5e-2, 7e-2 + rtol, atol = 0.05, 0.07 elif q_dtype == "fp8" and o_dtype in ["bf16", "fp16"]: - rtol, atol = 4e-2, 6e-2 + rtol, atol = 0.04, 0.06 else: - rtol, atol = 1e-2, 1e-2 - - # convert to float32 for fp8 is not supported by assert_close - torch.testing.assert_close( - output.float() * o_scale, output_ref.float(), rtol=rtol, atol=atol - ) - - if o_dtype != "nvfp4": # wrapper api does not support fp4 output yet. - # test wrapper with trtllm-gen backend + rtol, atol = 0.01, 0.01 + assert paddle.allclose( + x=output.astype(dtype="float32") * o_scale, + y=output_ref.astype(dtype="float32"), + rtol=rtol, + atol=atol, + ).item(), "" + if o_dtype != "nvfp4": wrapper_trtllm_gen = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout, backend="trtllm-gen" ) @@ -385,16 +345,18 @@ def test_trtllm_batch_prefill( v_scale=v_scale / o_scale, enable_pdl=enable_pdl, ) - # v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel. 
if v_scale == o_scale == 1.0: - assert (output_wrapper == output).all() + assert (output_wrapper == output).astype("bool").all() else: - torch.testing.assert_close( - output.float(), output_wrapper.float(), rtol=1e-1, atol=1e-1 - ) + assert paddle.allclose( + x=output.astype(dtype="float32"), + y=output_wrapper.astype(dtype="float32"), + rtol=0.1, + atol=0.1, + ).item(), "" -@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND +@pytest.mark.parametrize("kv_layout", ["HND"]) @pytest.mark.parametrize("batch_size", [4, 128, 256]) @pytest.mark.parametrize("page_size", [16, 32, 64]) @pytest.mark.parametrize("num_kv_heads", [2, 4]) @@ -426,22 +388,15 @@ def test_trtllm_batch_decode( kv_dtype, enable_pdl, ): - # Set up test parameters - torch.manual_seed(0) + paddle.seed(seed=0) head_dim = 128 - MAX_Q_LEN = 1 # must be 1 for decode test + MAX_Q_LEN = 1 MAX_IN_KV_LEN = 110 - - # Generate random sequence lengths num_qo_heads = num_kv_heads * head_grp_size q_lens, in_kv_lens, seq_lens = generate_seq_lens( batch_size, MAX_Q_LEN, MAX_IN_KV_LEN ) - - # Create query tensor and related data q, q_scale, ref_q = create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype) - - # Create KV cache and related data kv_cache, k_scale, v_scale, ref_kv_cache = create_kv_cache( batch_size, seq_lens, @@ -456,23 +411,16 @@ def test_trtllm_batch_decode( ) kv_indptr = generate_cumsum_lens(page_per_seq) kv_last_page_len = get_last_page_len(seq_lens, page_size) - - # Create output tensor and related data create_out_tensor = flip_coin( batch_size, page_size, num_kv_heads, head_grp_size, o_dtype ) out, o_scale, o_sf_scale, o_sf_vec_size = create_output( q, o_dtype, create_out_tensor ) - global global_workspace_buffer if global_workspace_buffer is None: - global_workspace_buffer = torch.zeros( - 128 * 1024 * 1024, dtype=torch.int8, device=GPU_DEVICE - ) + global_workspace_buffer = paddle.zeros(shape=128 * 1024 * 1024, dtype="int8") workspace_buffer = global_workspace_buffer - - # Run reference wrapper wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout, use_tensor_cores=True ) @@ -491,47 +439,42 @@ def test_trtllm_batch_decode( } wrapper_ref.plan(**plan_params) output_ref = wrapper_ref.run(ref_q, ref_kv_cache) - - # Run trtllm-gen function call - sm_scale = float(1.0 / (head_dim**0.5)) - + sm_scale = float(1.0 / head_dim**0.5) output = flashinfer.decode.trtllm_batch_decode_with_kv_cache( q.contiguous(), kv_cache, workspace_buffer, page_table, seq_lens.to(GPU_DEVICE), - torch.max(seq_lens).item(), - q_scale * k_scale * sm_scale, # bmm1_scale - v_scale / o_scale, # bmm2_scale - window_left, # window_left + paddle.max(x=seq_lens).item(), + q_scale * k_scale * sm_scale, + v_scale / o_scale, + window_left, out=out, out_dtype=DTYPE_MAP[o_dtype], o_sf_scale=o_sf_scale, o_sf_vec_size=o_sf_vec_size, enable_pdl=enable_pdl, ) - if o_dtype == "nvfp4": output, output_ref = unpack_compare_nvfp4( output, output_ref, o_sf_scale, o_sf_vec_size ) assert o_scale == 1.0 - rtol, atol = 3e-1, 1e0 + rtol, atol = 0.3, 1.0 elif q_dtype == "fp8" and o_dtype == "fp8": - rtol, atol = 5e-2, 7e-2 + rtol, atol = 0.05, 0.07 elif q_dtype == "fp8" and o_dtype in ["bf16", "fp16"]: - rtol, atol = 4e-2, 6e-2 + rtol, atol = 0.04, 0.06 else: - rtol, atol = 1e-2, 1e-2 - - # convert to float32 for fp8 is not supported by assert_close - torch.testing.assert_close( - output.float() * o_scale, output_ref.float(), rtol=rtol, atol=atol - ) - - if o_dtype != "nvfp4": # wrapper api does not support fp4 
output yet. - # test wrapper with trtllm-gen backend + rtol, atol = 0.01, 0.01 + assert paddle.allclose( + x=output.astype(dtype="float32") * o_scale, + y=output_ref.astype(dtype="float32"), + rtol=rtol, + atol=atol, + ).item(), "" + if o_dtype != "nvfp4": wrapper_trtllm_gen = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout, backend="trtllm-gen" ) @@ -546,13 +489,15 @@ def test_trtllm_batch_decode( v_scale=v_scale / o_scale, enable_pdl=enable_pdl, ) - # v_scale, o_scale in wrapper is emulated by multiplying output by v_scale instead of fused into kernel. if v_scale == o_scale == 1.0: - assert (output_wrapper == output).all() + assert (output_wrapper == output).astype("bool").all() else: - torch.testing.assert_close( - output.float(), output_wrapper.float(), rtol=1e-1, atol=1e-1 - ) + assert paddle.allclose( + x=output.astype(dtype="float32"), + y=output_wrapper.astype(dtype="float32"), + rtol=0.1, + atol=0.1, + ).item(), "" @pytest.mark.parametrize("batch_size", [4, 128, 256]) @@ -566,68 +511,43 @@ def test_trtllm_gen_prefill_deepseek( ): if s_qo > s_kv: pytest.skip("s_qo > s_kv, skipping test as causal") - num_qo_heads = num_kv_heads * head_grp_size head_dim_qk = 192 head_dim_vo = 128 - seed = 0 - torch.manual_seed(seed) + paddle.seed(seed=seed) device = "cuda:0" - - actual_seq_lens_q = torch.randint( - 1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device - ) - - actual_seq_lens_kv = torch.randint( - s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device - ) - - cumsum_s_qo = torch.sum(actual_seq_lens_q) - cumsum_s_kv = torch.sum(actual_seq_lens_kv) - - q = torch.randn( - cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16 - ) - - k_cache = torch.randn( - (cumsum_s_kv, num_kv_heads, head_dim_qk), - device=device, - dtype=torch.bfloat16, - ) - v_cache = torch.randn( - (cumsum_s_kv, num_kv_heads, head_dim_vo), - device=device, - dtype=torch.bfloat16, - ) - - # Initialize scale - scale = float(1.0 / (head_dim_qk**0.5)) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - - qo_indptr = torch.cat( - [ - torch.tensor([0], device=device), - torch.cumsum(actual_seq_lens_q.view(-1), dim=0), + actual_seq_lens_q = paddle.randint( + low=1, high=s_qo + 1, shape=(batch_size, 1, 1, 1), dtype="int32" + ) + actual_seq_lens_kv = paddle.randint( + low=s_qo, high=s_kv + 1, shape=(batch_size, 1, 1, 1), dtype="int32" + ) + cumsum_s_qo = paddle.sum(x=actual_seq_lens_q) + cumsum_s_kv = paddle.sum(x=actual_seq_lens_kv) + q = paddle.randn(shape=[cumsum_s_qo, num_qo_heads, head_dim_qk], dtype="bfloat16") + k_cache = paddle.randn( + shape=(cumsum_s_kv, num_kv_heads, head_dim_qk), dtype="bfloat16" + ) + v_cache = paddle.randn( + shape=(cumsum_s_kv, num_kv_heads, head_dim_vo), dtype="bfloat16" + ) + scale = float(1.0 / head_dim_qk**0.5) + workspace_buffer = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") + qo_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_q.view(-1), axis=0), ] - ).int() - - # kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * s_kv - - # Create kv_indptr as cumulative sum of actual_seq_lens_kv - kv_indptr = torch.cat( - [ - torch.tensor( - [0], - device=device, - ), - torch.cumsum(actual_seq_lens_kv.view(-1), dim=0), + ).astype(dtype="int32") + kv_indptr = paddle.concat( + x=[ + paddle.to_tensor(data=[0], place=device), + paddle.cumsum(x=actual_seq_lens_kv.view(-1), axis=0), ] - 
).int() - + ).astype(dtype="int32") wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + paddle.empty(shape=128 * 1024 * 1024, dtype="uint8"), kv_layout="NHD", backend="cutlass", ) @@ -640,12 +560,11 @@ def test_trtllm_gen_prefill_deepseek( head_dim_vo=head_dim_vo, causal=causal, sm_scale=scale, - q_data_type=torch.bfloat16, - kv_data_type=torch.bfloat16, + q_data_type="bfloat16", + kv_data_type="bfloat16", ) output_ref, lse_ref = wrapper.run(q, k_cache, v_cache, return_lse=True) - output = torch.empty_like(output_ref) - + output = paddle.empty_like(x=output_ref) bmm1_scale = scale bmm2_scale = 1.0 output_trtllm, lse_trtllm = flashinfer.prefill.trtllm_ragged_attention_deepseek( @@ -668,18 +587,10 @@ def test_trtllm_gen_prefill_deepseek( True, out=output, ) - torch.testing.assert_close( - output_trtllm, - output_ref, - atol=1e-2, - rtol=1e-2, - ) - torch.testing.assert_close( - lse_trtllm, - lse_ref, - atol=1e-3, - rtol=1e-3, - ) + assert paddle.allclose( + x=output_trtllm, y=output_ref, atol=0.01, rtol=0.01 + ).item(), "" + assert paddle.allclose(x=lse_trtllm, y=lse_ref, atol=0.001, rtol=0.001).item(), "" if __name__ == "__main__": diff --git a/tests/test_trtllm_gen_fused_moe.py b/tests/test_trtllm_gen_fused_moe.py index 86a8714dfc..c301da0dcd 100644 --- a/tests/test_trtllm_gen_fused_moe.py +++ b/tests/test_trtllm_gen_fused_moe.py @@ -1,3 +1,9 @@ +import sys + + +import paddle +from flashinfer.paddle_utils import * + """ Copyright (c) 2025 by FlashInfer team. @@ -13,39 +19,25 @@ See the License for the specific language governing permissions and limitations under the License. """ - from abc import ABC, abstractmethod from enum import IntEnum from typing import Dict import pytest -import torch from cuda.bindings import runtime -from torch.nn import functional as F - -from flashinfer import ( - RoutingMethodType, - GatedActType, - e2m1_and_ufp8sf_scale_to_float, - fp4_quantize, - mxfp8_dequantize_host, - mxfp8_quantize, - next_positive_power_of_2, - reorder_rows_for_gated_act_gemm, - shuffle_matrix_a, -) + +from flashinfer import (GatedActType, RoutingMethodType, + e2m1_and_ufp8sf_scale_to_float, fp4_quantize, + mxfp8_dequantize_host, mxfp8_quantize, + next_positive_power_of_2, + reorder_rows_for_gated_act_gemm, shuffle_matrix_a) from flashinfer.fp4_quantization import block_scale_interleave -from flashinfer.fused_moe import ( - WeightLayout, - convert_to_block_layout, - trtllm_fp4_block_scale_moe, - trtllm_fp8_block_scale_moe, - trtllm_fp8_per_tensor_scale_moe, -) -from flashinfer.fused_moe.core import ( - _maybe_get_cached_w2_permute_indices, - _maybe_get_cached_w3_w1_permute_indices, -) +from flashinfer.fused_moe import (WeightLayout, convert_to_block_layout, + trtllm_fp4_block_scale_moe, + trtllm_fp8_block_scale_moe, + trtllm_fp8_per_tensor_scale_moe) +from flashinfer.fused_moe.core import (_maybe_get_cached_w2_permute_indices, + _maybe_get_cached_w3_w1_permute_indices) def check_cuda(err): @@ -92,38 +84,24 @@ def capture(self, hidden_states_sample, **runtime_args): raise NotImplementedError( f"CUDA graph capture not yet implemented for {type(self.moe_impl)}" ) - - # Create stream err, self.stream = runtime.cudaStreamCreate() check_cuda(err) - - # Get the raw stream pointer for PyTorch stream_ptr = int(self.stream) - torch_stream = torch.cuda.ExternalStream(stream_ptr) - - # Store input tensor reference (will be updated in place during launch) +>>>>>> torch_stream = torch.cuda.ExternalStream(stream_ptr) 
self.input_tensor = hidden_states_sample.clone() - - # Warmup - with torch.cuda.stream(torch_stream): + with paddle.device.stream_guard(stream=torch_stream): for _ in range(1): self._run_moe_computation(runtime_args) - - # Synchronize our stream after warmup err = runtime.cudaStreamSynchronize(self.stream)[0] check_cuda(err) - - # Begin capture err, self.graph = runtime.cudaGraphCreate(0) check_cuda(err) err = runtime.cudaStreamBeginCapture( self.stream, runtime.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal )[0] check_cuda(err) - try: - # Capture computation on our stream - with torch.cuda.stream(torch_stream): + with paddle.device.stream_guard(stream=torch_stream): self.output_tensor = self._run_moe_computation(runtime_args) err, self.graph = runtime.cudaStreamEndCapture(self.stream) check_cuda(err) @@ -138,17 +116,11 @@ def launch(self, hidden_states_new): """Launch captured CUDA graph with new input.""" if not self.is_captured: raise RuntimeError("Graph not captured. Call capture() first.") - - # Update input tensor in place - self.input_tensor.copy_(hidden_states_new) - - # Launch graph + paddle.assign(hidden_states_new, output=self.input_tensor) err = runtime.cudaGraphLaunch(self.graph_exec, self.stream)[0] check_cuda(err) err = runtime.cudaStreamSynchronize(self.stream)[0] check_cuda(err) - - # Return output tensor (automatically updated by graph execution) return self.output_tensor def cleanup(self): @@ -176,7 +148,6 @@ def _run_moe_computation(self, runtime_args): self.config["hidden_states_scale_global"], is_swizzling=False, ) - output = trtllm_fp4_block_scale_moe( routing_logits=runtime_args["expert_logits"], routing_bias=runtime_args["routing_bias"], @@ -207,7 +178,7 @@ def _run_moe_computation(self, runtime_args): gated_act_type=self.config["gated_act_type"], do_finalize=True, ) - return output # Extract tensor from tuple + return output class QuantMode(IntEnum): @@ -220,11 +191,6 @@ class QuantMode(IntEnum): FP8_PER_TENSOR = 5 -# ==================================================================================== -# Abstract Base Class for MoE Implementations -# ==================================================================================== - - class Moe(ABC): """Abstract base class for MoE implementations.""" @@ -294,11 +260,6 @@ def __str__(self): return self.name -# ==================================================================================== -# FP4 Quantization Implementation -# ==================================================================================== - - class FP4Moe(Moe): """ FP4 NvFP4 / MxFP4 MoE implementation with block scaling. 
@@ -318,28 +279,23 @@ def __init__(self, quant_mode: QuantMode): def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP4 format and compute global scale factors.""" - num_experts = gemm1_weights.shape[0] - # Compute global scale factor for hidden states (offline calibration) + num_experts = tuple(gemm1_weights.shape)[0] if self.quant_mode == QuantMode.FP4_NVFP4_NVFP4: - # nvfp4 hidden states hidden_states_scale_global = calculate_fp4_global_scale_factor( - hidden_states_sample, - False, + hidden_states_sample, False ) else: - # mxfp8 / bf16 hidden states hidden_states_scale_global = 1.0 - - # Quantize the weights for FC1 - gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = ( - quant_fp4_batches(gemm1_weights, num_experts, self.is_mxfp4, True) - ) - - # Quantize the weights for FC2 - gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = ( - quant_fp4_batches(gemm2_weights, num_experts, self.is_mxfp4, True) - ) - + ( + gemm1_weights_fp4_bytes, + gemm1_scales_fp4_bytes, + gemm1_scales_global, + ) = quant_fp4_batches(gemm1_weights, num_experts, self.is_mxfp4, True) + ( + gemm2_weights_fp4_bytes, + gemm2_scales_fp4_bytes, + gemm2_scales_global, + ) = quant_fp4_batches(gemm2_weights, num_experts, self.is_mxfp4, True) return { "hidden_states_scale_global": hidden_states_scale_global, "gemm1_weights": gemm1_weights_fp4_bytes, @@ -358,8 +314,8 @@ def quantize_inputs( hidden_states_quant, hidden_states_scale = mxfp8_quantize( hidden_states, is_swizzling ) - hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape( - *hidden_states.shape[:-1], -1 + hidden_states_scale = hidden_states_scale.view(paddle.float8_e4m3fn).reshape( + *tuple(hidden_states.shape)[:-1], -1 ) return { "hidden_states": hidden_states_quant, @@ -367,24 +323,19 @@ def quantize_inputs( } elif self.quant_mode == QuantMode.FP4_NVFP4_NVFP4: """Quantize hidden states to NvFP4 format using pre-computed global scale.""" - ( - hidden_states_fp4_bytes, - hidden_states_scale_fp4_bytes, - _, - ) = quant_fp4( + hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, _ = quant_fp4( hidden_states, hidden_states_scale_global, False, is_swizzling ) hidden_states_scale_fp4_bytes = hidden_states_scale_fp4_bytes.view( - torch.float8_e4m3fn - ).reshape(*hidden_states.shape[:-1], -1) - + paddle.float8_e4m3fn + ).reshape(*tuple(hidden_states.shape)[:-1], -1) return { "hidden_states": hidden_states_fp4_bytes, "hidden_states_scale": hidden_states_scale_fp4_bytes, } - else: # bf16 + else: return { - "hidden_states": hidden_states.to(torch.bfloat16), + "hidden_states": hidden_states.to("bfloat16"), "hidden_states_scale": None, } @@ -401,129 +352,99 @@ def prepare_static_weights_for_kernel( ): """Prepare quantized weights for kernel (done offline with weights).""" use_ue8m0 = self.is_mxfp4 - epilogue_tile_m = 128 # FIXME: this depends on the kernel internals - - # Quantize weights with linear layout for kernels + epilogue_tile_m = 128 _, gemm1_scales_linear_fp4_bytes, _ = quant_fp4_batches( gemm1_weights_orig, num_experts, use_ue8m0, False ) _, gemm2_scales_linear_fp4_bytes, _ = quant_fp4_batches( gemm2_weights_orig, num_experts, use_ue8m0, False ) - - # Convert quantized weights to proper formats - gemm1_weights_fp4 = args.gemm1_weights.view(torch.float8_e4m3fn).reshape( + gemm1_weights_fp4 = args.gemm1_weights.view(paddle.float8_e4m3fn).reshape( num_experts, 2 * intermediate_size, hidden_size // 2 - ) # packed fp4 + ) gemm1_scales_linear_fp4 = 
gemm1_scales_linear_fp4_bytes.view( - torch.float8_e4m3fn - ).reshape( - num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size - ) # fp8 scaling factors - - gemm2_weights_fp4 = args.gemm2_weights.view(torch.float8_e4m3fn).reshape( + paddle.float8_e4m3fn + ).reshape(num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size) + gemm2_weights_fp4 = args.gemm2_weights.view(paddle.float8_e4m3fn).reshape( num_experts, hidden_size, intermediate_size // 2 - ) # packed fp4 + ) gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view( - torch.float8_e4m3fn - ).reshape( - num_experts, hidden_size, intermediate_size // self.sf_vec_size - ) # fp8 scaling factors - - # Using cached permute index calculation can speed up weights preprocessing + paddle.float8_e4m3fn + ).reshape(num_experts, hidden_size, intermediate_size // self.sf_vec_size) gemm1_weights_fp4_shuffled = [] gemm1_scales_fp4_shuffled = [] gemm2_weights_fp4_shuffled = [] gemm2_scales_fp4_shuffled = [] for i in range(num_experts): - # Calculate the permute indices for the following: - # 1. Reorder rows of W1 and scales for fused gated activation - # 2. Shuffle weights and scaling factors for transposed mma output - # for both w3_w1 and w2 weights and scale factors permute_indices = _maybe_get_cached_w3_w1_permute_indices( self._cache_permute_indices, - gemm1_weights_fp4[i].view(torch.uint8), + gemm1_weights_fp4[i].view("uint8"), epilogue_tile_m, ) gemm1_weights_fp4_shuffled.append( gemm1_weights_fp4[i] - .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)] + .view("uint8")[permute_indices.to(gemm1_weights_fp4.place)] .contiguous() ) - permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices( self._cache_permute_indices, - gemm1_scales_linear_fp4[i].view(torch.uint8), + gemm1_scales_linear_fp4[i].view("uint8"), epilogue_tile_m, num_elts_per_sf=16, ) gemm1_scales_fp4_shuffled.append( block_scale_interleave( gemm1_scales_linear_fp4[i] - .view(torch.uint8)[ - permute_sf_indices.to(gemm1_scales_linear_fp4.device) - ] + .view("uint8")[permute_sf_indices.to(gemm1_scales_linear_fp4.place)] .contiguous() ) ) - permute_indices = _maybe_get_cached_w2_permute_indices( self._cache_permute_indices, - gemm2_weights_fp4[i].view(torch.uint8), + gemm2_weights_fp4[i].view("uint8"), epilogue_tile_m, ) gemm2_weights_fp4_shuffled.append( gemm2_weights_fp4[i] - .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)] + .view("uint8")[permute_indices.to(gemm2_weights_fp4.place)] .contiguous() ) - permute_sf_indices = _maybe_get_cached_w2_permute_indices( self._cache_permute_indices, - gemm2_scales_linear_fp4[i].view(torch.uint8), + gemm2_scales_linear_fp4[i].view("uint8"), epilogue_tile_m, num_elts_per_sf=16, ) gemm2_scales_fp4_shuffled.append( block_scale_interleave( gemm2_scales_linear_fp4[i] - .view(torch.uint8)[ - permute_sf_indices.to(gemm2_scales_linear_fp4.device) - ] + .view("uint8")[permute_sf_indices.to(gemm2_scales_linear_fp4.place)] .contiguous() ) ) - - # Stack weights for all experts - gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled) + gemm1_weights_fp4_shuffled = paddle.stack(x=gemm1_weights_fp4_shuffled) gemm1_scales_fp4_shuffled = ( - torch.stack(gemm1_scales_fp4_shuffled) - .view(torch.float8_e4m3fn) + paddle.stack(x=gemm1_scales_fp4_shuffled) + .view(paddle.float8_e4m3fn) .reshape( num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size ) ) - - gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled) + gemm2_weights_fp4_shuffled = 
paddle.stack(x=gemm2_weights_fp4_shuffled) gemm2_scales_fp4_shuffled = ( - torch.stack(gemm2_scales_fp4_shuffled) - .view(torch.float8_e4m3fn) + paddle.stack(x=gemm2_scales_fp4_shuffled) + .view(paddle.float8_e4m3fn) .reshape(num_experts, hidden_size, intermediate_size // self.sf_vec_size) ) - - # Calculate scaling factors that depend on weights scale_c_fc1 = ( args_dequant.c_global_sf * (1.0 / args.gemm1_scales_global) * (1.0 / args.hidden_states_scale_global) ) - scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * ( - 1.0 / args.hidden_states_scale_global + scale_gate_fc1 = ( + 1.0 / args.gemm1_scales_global * (1.0 / args.hidden_states_scale_global) ) - scale_c_fc2 = (1.0 / args_dequant.c_global_sf) * ( - 1.0 / args.gemm2_scales_global - ) - + scale_c_fc2 = 1.0 / args_dequant.c_global_sf * (1.0 / args.gemm2_scales_global) return { "gemm1_weights_fp4_shuffled": gemm1_weights_fp4_shuffled, "gemm1_scales_fp4_shuffled": gemm1_scales_fp4_shuffled, @@ -538,7 +459,6 @@ def call_moe( self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs ): """Call MoE using CUDA graph for maximum performance (create, capture, launch).""" - # Extract runtime arguments expert_logits = kwargs["expert_logits"] routing_bias = kwargs["routing_bias"] num_experts = kwargs["num_experts"] @@ -550,8 +470,6 @@ def call_moe( gated_act_type = kwargs["gated_act_type"] routing_method_type = kwargs["routing_method_type"] tile_tokens_dim = kwargs["tile_tokens_dim"] - - # Create CUDA graph configuration config = { "hidden_states_scale_global": hidden_states_scale_global, "num_experts": num_experts, @@ -564,18 +482,12 @@ def call_moe( "gated_act_type": gated_act_type, "routing_method_type": routing_method_type, } - - runtime_args = { - "expert_logits": expert_logits, - "routing_bias": routing_bias, - } - - # Create, capture and launch CUDA graph in one shot + runtime_args = {"expert_logits": expert_logits, "routing_bias": routing_bias} cuda_graph = CUDAGraphMoE(self, static_data, **config) try: cuda_graph.capture(hidden_states_orig, **runtime_args) output = cuda_graph.launch(hidden_states_orig) - return output[0].to(torch.float) + return output[0].to("float32") finally: cuda_graph.cleanup() @@ -587,36 +499,24 @@ def get_tolerances(self): return {"atol": 0.1, "rtol": 0.85, "percent": 0.925} -# ==================================================================================== -# FP8 Block Scale Quantization Implementation -# ==================================================================================== - - class FP8BlockScaleMoe(Moe): """FP8 MoE implementation with block scaling (DeepSeek style).""" def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP8 with block scaling.""" - num_experts = gemm1_weights.shape[0] - intermediate_size = gemm1_weights.shape[1] // 2 - hidden_size = gemm1_weights.shape[ - 2 - ] # [num_experts, 2*intermediate_size, hidden_size] - - # Quantize weights to FP8 - gemm1_weights_fp8 = gemm1_weights.to(torch.float8_e4m3fn) - gemm1_scales = 2 * torch.rand( - (num_experts, 2 * intermediate_size // 128, hidden_size // 128), - device="cuda", - ).to(torch.float) - - gemm2_weights_fp8 = gemm2_weights.to(torch.float8_e4m3fn) - gemm2_scales = 2 * torch.rand( - (num_experts, hidden_size // 128, intermediate_size // 128), device="cuda" - ).to(torch.float) - + num_experts = tuple(gemm1_weights.shape)[0] + intermediate_size = tuple(gemm1_weights.shape)[1] // 2 + hidden_size = tuple(gemm1_weights.shape)[2] + gemm1_weights_fp8 = 
gemm1_weights.to(paddle.float8_e4m3fn) + gemm1_scales = 2 * paddle.rand( + shape=(num_experts, 2 * intermediate_size // 128, hidden_size // 128) + ).to("float32") + gemm2_weights_fp8 = gemm2_weights.to(paddle.float8_e4m3fn) + gemm2_scales = 2 * paddle.rand( + shape=(num_experts, hidden_size // 128, intermediate_size // 128) + ).to("float32") return { - "hidden_states_scale_global": None, # Block scales computed at runtime + "hidden_states_scale_global": None, "gemm1_weights": gemm1_weights_fp8, "gemm1_scales": gemm1_scales, "gemm1_scales_global": None, @@ -627,10 +527,7 @@ def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): def quantize_inputs(self, hidden_states, hidden_states_scale_global): """For FP8 block scaling, no pre-quantization - everything happens at runtime.""" - return { - "hidden_states": hidden_states, # Keep original - "hidden_states_scale": None, # No pre-computed scales - } + return {"hidden_states": hidden_states, "hidden_states_scale": None} def prepare_static_weights_for_kernel( self, @@ -644,43 +541,34 @@ def prepare_static_weights_for_kernel( weight_processing, ): """Prepare quantized weights for kernel (done offline with weights).""" - - # Use shuffled weights with BlockMajorK layout for better performance use_shuffled_weight = weight_processing["use_shuffled_weight"] weight_layout = weight_processing["layout"] - if use_shuffled_weight: - # FIXME: this depends on the kernel internals epilogue_tile_m = 64 - gemm1_weights_fp8_shuffled = [] gemm2_weights_fp8_shuffled = [] for i in range(num_experts): tmp_weights1 = shuffle_matrix_a( - args.gemm1_weights[i].view(torch.uint8), epilogue_tile_m + args.gemm1_weights[i].view("uint8"), epilogue_tile_m ) tmp_weights2 = shuffle_matrix_a( - args.gemm2_weights[i].view(torch.uint8), epilogue_tile_m + args.gemm2_weights[i].view("uint8"), epilogue_tile_m ) - if weight_layout == WeightLayout.BlockMajorK: block_k = 128 tmp_weights1 = convert_to_block_layout(tmp_weights1, block_k) tmp_weights2 = convert_to_block_layout(tmp_weights2, block_k) - gemm1_weights_fp8_shuffled.append(tmp_weights1) - gemm2_weights_fp8_shuffled.append(tmp_weights2) - kernel_gemm1_weights = torch.stack(gemm1_weights_fp8_shuffled).view( - torch.float8_e4m3fn + kernel_gemm1_weights = paddle.stack(x=gemm1_weights_fp8_shuffled).view( + paddle.float8_e4m3fn ) - kernel_gemm2_weights = torch.stack(gemm2_weights_fp8_shuffled).view( - torch.float8_e4m3fn + kernel_gemm2_weights = paddle.stack(x=gemm2_weights_fp8_shuffled).view( + paddle.float8_e4m3fn ) else: kernel_gemm1_weights = args.gemm1_weights kernel_gemm2_weights = args.gemm2_weights - return { "gemm1_weights": kernel_gemm1_weights, "gemm1_scales": args.gemm1_scales, @@ -707,14 +595,10 @@ def call_moe( routing_method_type = kwargs["routing_method_type"] tile_tokens_dim = kwargs["tile_tokens_dim"] enable_pdl = kwargs.get("enable_pdl") - - # Generate block scales and quantize hidden states at runtime - hidden_states_fp8 = hidden_states_orig.to(torch.float8_e4m3fn) - # Use deterministic scales for testing consistency - hidden_states_scale = 2.0 * torch.ones( - (hidden_size // 128, num_tokens), device="cuda", dtype=torch.float + hidden_states_fp8 = hidden_states_orig.to(paddle.float8_e4m3fn) + hidden_states_scale = 2.0 * paddle.ones( + shape=(hidden_size // 128, num_tokens), dtype="float32" ) - output = trtllm_fp8_block_scale_moe( expert_logits, routing_bias, @@ -738,8 +622,7 @@ def call_moe( weight_layout=static_data["weight_layout"], enable_pdl=enable_pdl, ) - - return 
output.to(torch.float) + return output.to("float32") def compute_reference(self, args): """FP8 block-scale reference implementation.""" @@ -750,29 +633,20 @@ def get_tolerances(self): return {"atol": 0.1, "rtol": 0.85, "percent": 0.925} -# ==================================================================================== -# FP8 Per-Tensor Quantization Implementation -# ==================================================================================== - - class FP8PerTensorMoe(Moe): """FP8 MoE implementation with per-tensor scaling (Llama4 style).""" def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP8 per-tensor and compute global scale factors.""" - # Compute global scale factor for hidden states (offline calibration) hidden_states_global_scale = calculate_fp8_global_scale_factor( hidden_states_sample ) - - # Quantize to FP8 per-tensor gemm1_weights_quant, gemm1_global_scales = quant_fp8_per_tensor_batches( gemm1_weights ) gemm2_weights_quant, gemm2_global_scales = quant_fp8_per_tensor_batches( gemm2_weights ) - return { "hidden_states_scale_global": hidden_states_global_scale, "gemm1_weights": gemm1_weights_quant, @@ -785,15 +659,10 @@ def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): def quantize_inputs(self, hidden_states, hidden_states_scale_global): """Quantize hidden states to FP8 per-tensor using pre-computed global scale.""" - # Quantize to FP8 per-tensor using pre-computed global scale factor hidden_states_quant, _ = quant_fp8_per_tensor( hidden_states, hidden_states_scale_global ) - - return { - "hidden_states": hidden_states_quant, - "hidden_states_scale": None, - } + return {"hidden_states": hidden_states_quant, "hidden_states_scale": None} def prepare_static_weights_for_kernel( self, @@ -807,58 +676,41 @@ def prepare_static_weights_for_kernel( weight_processing, ): """Prepare quantized weights for kernel (done offline with weights).""" - # FIXME: this depends on the kernel internals epilogue_tile_m = 128 - - # Reorder rows of W1 for fused gated activation gemm1_weights_fp8_interleaved = [] for i in range(num_experts): gemm1_weights_fp8_interleaved.append( reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone()) ) - - # Stack weights and scales for all experts - gemm1_weights_fp8_interleaved = torch.stack( - gemm1_weights_fp8_interleaved + gemm1_weights_fp8_interleaved = paddle.stack( + x=gemm1_weights_fp8_interleaved ).reshape(num_experts, 2 * intermediate_size, hidden_size) - - # Shuffle weights and scaling factors for transposed mma output gemm1_weights_fp8_shuffled = [] gemm2_weights_fp8_shuffled = [] for i in range(num_experts): gemm1_weights_fp8_shuffled.append( shuffle_matrix_a( - gemm1_weights_fp8_interleaved[i].view(torch.uint8), epilogue_tile_m + gemm1_weights_fp8_interleaved[i].view("uint8"), epilogue_tile_m ) ) - gemm2_weights_fp8_shuffled.append( - shuffle_matrix_a( - args.gemm2_weights[i].view(torch.uint8), epilogue_tile_m - ) + shuffle_matrix_a(args.gemm2_weights[i].view("uint8"), epilogue_tile_m) ) - - # Stack weights for all experts - gemm1_weights_fp8_shuffled = torch.stack(gemm1_weights_fp8_shuffled).view( - torch.float8_e4m3fn + gemm1_weights_fp8_shuffled = paddle.stack(x=gemm1_weights_fp8_shuffled).view( + paddle.float8_e4m3fn ) - gemm2_weights_fp8_shuffled = torch.stack(gemm2_weights_fp8_shuffled).view( - torch.float8_e4m3fn + gemm2_weights_fp8_shuffled = paddle.stack(x=gemm2_weights_fp8_shuffled).view( + paddle.float8_e4m3fn ) - - # Calculate scaling factors 
that depend on weights scale_c_fc1 = ( args_dequant.c_global_sf * (1.0 / args.gemm1_scales_global) * (1.0 / args.hidden_states_scale_global) ) - scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * ( - 1.0 / args.hidden_states_scale_global - ) - scale_c_fc2 = (1.0 / args_dequant.c_global_sf) * ( - 1.0 / args.gemm2_scales_global + scale_gate_fc1 = ( + 1.0 / args.gemm1_scales_global * (1.0 / args.hidden_states_scale_global) ) - + scale_c_fc2 = 1.0 / args_dequant.c_global_sf * (1.0 / args.gemm2_scales_global) return { "gemm1_weights": gemm1_weights_fp8_shuffled, "gemm2_weights": gemm2_weights_fp8_shuffled, @@ -881,18 +733,13 @@ def call_moe( routed_scaling = kwargs["routed_scaling"] routing_method_type = kwargs["routing_method_type"] tile_tokens_dim = kwargs["tile_tokens_dim"] - - # Quantize to FP8 per-tensor using pre-computed global scale factor hidden_states_fp8, _ = quant_fp8_per_tensor( hidden_states_orig, hidden_states_scale_global ) - output = trtllm_fp8_per_tensor_scale_moe( - ( - expert_logits.to(torch.bfloat16) - if routing_method_type == RoutingMethodType.Llama4 - else expert_logits - ), + expert_logits.to("bfloat16") + if routing_method_type == RoutingMethodType.Llama4 + else expert_logits, routing_bias, hidden_states_fp8, static_data["gemm1_weights"], @@ -908,13 +755,11 @@ def call_moe( 0, num_experts, routed_scaling, - routing_method_type - == RoutingMethodType.Llama4, # Use_routing_scales_on_input + routing_method_type == RoutingMethodType.Llama4, tile_tokens_dim, routing_method_type, ) - - return output.to(torch.float) + return output.to("float32") def compute_reference(self, args): """FP8 per-tensor reference implementation.""" @@ -925,11 +770,6 @@ def get_tolerances(self): return {"atol": 0.1, "rtol": 0.85, "percent": 0.925} -# ==================================================================================== -# Quantizer Factory -# ==================================================================================== - - def get_moe_impl(quant_mode: QuantMode): """Factory function to get the appropriate MoE implementation.""" if quant_mode == QuantMode.FP8_BLOCK_SCALE: @@ -1022,16 +862,14 @@ def __init__( def routing_reference(expertLogits, topK, padding): """Reference routing implementation for permutation calculation.""" - originalDevice = expertLogits.device + originalDevice = expertLogits.place expertLogits = expertLogits.cpu() - numTokens, numExperts = expertLogits.shape + numTokens, numExperts = tuple(expertLogits.shape) assert topK <= numExperts - - numTokensPerExpert = torch.zeros(numExperts, dtype=torch.int64) - expandedTokenIdxToExpert = -torch.ones(numTokens * topK, dtype=torch.int64) - expandedTokenIdxToIdxInExpert = -torch.ones(numTokens * topK, dtype=torch.int64) - - topKLogits, topKIndices = torch.topk(expertLogits, topK, dim=1) + numTokensPerExpert = paddle.zeros(shape=numExperts, dtype="int64") + expandedTokenIdxToExpert = -paddle.ones(shape=numTokens * topK, dtype="int64") + expandedTokenIdxToIdxInExpert = -paddle.ones(shape=numTokens * topK, dtype="int64") + topKLogits, topKIndices = paddle.topk(x=expertLogits, k=topK, axis=1) for tokenIdx in range(numTokens): for k in range(topK): expandedIdx = tokenIdx * topK + k @@ -1039,8 +877,7 @@ def routing_reference(expertLogits, topK, padding): expandedTokenIdxToExpert[expandedIdx] = expertIndex expandedTokenIdxToIdxInExpert[expandedIdx] = numTokensPerExpert[expertIndex] numTokensPerExpert[expertIndex] += 1 - - paddedTokensPerExpertPrefixSum = torch.zeros(numExperts + 1, dtype=torch.int64) + 
paddedTokensPerExpertPrefixSum = paddle.zeros(shape=numExperts + 1, dtype="int64") for ii in range(numExperts): def divUpMul(a, b): @@ -1050,10 +887,9 @@ def divUpMul(a, b): ii ] + divUpMul(numTokensPerExpert[ii], padding) permutedBufferSize = paddedTokensPerExpertPrefixSum[numExperts] - - expandedTokenIdxToPermutedIdx = -torch.ones(numTokens * topK, dtype=torch.int64) - permutedIdxToExpandedIdx = -torch.ones(permutedBufferSize, dtype=torch.int64) - permutedIdxToTokenIdx = -torch.ones(permutedBufferSize, dtype=torch.int64) + expandedTokenIdxToPermutedIdx = -paddle.ones(shape=numTokens * topK, dtype="int64") + permutedIdxToExpandedIdx = -paddle.ones(shape=permutedBufferSize, dtype="int64") + permutedIdxToTokenIdx = -paddle.ones(shape=permutedBufferSize, dtype="int64") for tokenIdx in range(numTokens): for k in range(topK): expandedIdx = tokenIdx * topK + k @@ -1061,7 +897,6 @@ def divUpMul(a, b): offsetWithinExpert = expandedTokenIdxToIdxInExpert[expandedIdx] offsetForExpert = paddedTokensPerExpertPrefixSum[expert] permutedIdx = offsetForExpert + offsetWithinExpert - expandedTokenIdxToPermutedIdx[expandedIdx] = permutedIdx permutedIdxToExpandedIdx[permutedIdx] = expandedIdx permutedIdxToTokenIdx[permutedIdx] = tokenIdx @@ -1084,41 +919,42 @@ def divUpMul(a, b): def noaux_tc_ref(logits, bias, n_group, topk_group, top_k, routed_scaling_factor): """DeepSeek-style no-aux routing reference implementation.""" - scores = F.sigmoid(logits) + scores = paddle.nn.functional.sigmoid(x=logits) scores_with_bias = scores + bias if n_group > 1: - scores_shape = list(scores_with_bias.shape) - group_scores = torch.sum( - torch.topk( - scores_with_bias.view( + scores_shape = list(tuple(scores_with_bias.shape)) + group_scores = paddle.sum( + x=paddle.topk( + x=scores_with_bias.view( scores_shape[:-1] + [n_group, scores_shape[-1] // n_group] ), k=2, - dim=-1, + axis=-1, largest=True, sorted=True, )[0], - dim=-1, + axis=-1, ) - _, group_idx = torch.topk( - group_scores, k=topk_group, dim=-1, largest=True, sorted=True + _, group_idx = paddle.topk( + x=group_scores, k=topk_group, axis=-1, largest=True, sorted=True + ) + group_mask = paddle.zeros_like(x=group_scores) + group_mask.put_along_axis_( + axis=-1, indices=group_idx, values=1, broadcast=False ) - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(-1, group_idx, 1) score_mask = ( - group_mask.unsqueeze(-1) - .expand(scores_shape[:-1] + [n_group, scores_shape[-1] // n_group]) + group_mask.unsqueeze(axis=-1) + .expand(shape=scores_shape[:-1] + [n_group, scores_shape[-1] // n_group]) .reshape(scores_shape) ) scores_with_bias = scores_with_bias * score_mask - - _, topk_idx = torch.topk( - scores_with_bias, k=top_k, dim=-1, largest=True, sorted=True + _, topk_idx = paddle.topk( + x=scores_with_bias, k=top_k, axis=-1, largest=True, sorted=True ) - new_mask = torch.zeros_like(scores) - new_mask.scatter_(-1, topk_idx, 1) + new_mask = paddle.zeros_like(x=scores) + new_mask.put_along_axis_(axis=-1, indices=topk_idx, values=1, broadcast=False) scores = scores * new_mask - score_sum = torch.sum(scores, dim=-1, keepdim=True) + 1e-20 + score_sum = paddle.sum(x=scores, axis=-1, keepdim=True) + 1e-20 scores = scores / score_sum * routed_scaling_factor return scores @@ -1134,10 +970,9 @@ def routing_reference_no_aux( use_routing_scales_on_input=False, ): """Tiered TopK routing used by DeepSeek.""" - routing_logits = expert_logits.to(dtype=torch.float, device="cuda") + routing_logits = expert_logits.to(dtype="float32", device="gpu") if 
use_routing_scales_on_input: - # if using routing scales on input, topK == 1 and the score is a plain sigmoid - scores = F.sigmoid(routing_logits) + scores = paddle.nn.functional.sigmoid(x=routing_logits) else: scores = noaux_tc_ref( routing_logits, routing_bias, n_groups, top_k_groups, top_k, routed_scaling @@ -1148,15 +983,15 @@ def routing_reference_no_aux( def routing_reference_renormalize(expert_logits, top_k, num_experts, padding): """TopK -> Softmax routing reference.""" - topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1) - topk_values = torch.nn.functional.softmax(topk_values.float(), dim=-1) - - new_mask = torch.zeros_like(expert_logits) - new_mask.scatter_(-1, topk_idx, 1) + topk_values, topk_idx = paddle.topk(x=expert_logits, k=top_k, axis=-1) + topk_values = paddle.nn.functional.softmax( + x=topk_values.astype(dtype="float32"), axis=-1 + ) + new_mask = paddle.zeros_like(x=expert_logits) + new_mask.put_along_axis_(axis=-1, indices=topk_idx, values=1, broadcast=False) scores = expert_logits * new_mask - - for i in range(topk_idx.shape[0]): - for j in range(topk_idx.shape[1]): + for i in range(tuple(topk_idx.shape)[0]): + for j in range(tuple(topk_idx.shape)[1]): scores[i, topk_idx[i, j]] = topk_values[i, j] permute_info = routing_reference(scores, top_k, padding) return permute_info, scores @@ -1165,20 +1000,19 @@ def routing_reference_renormalize(expert_logits, top_k, num_experts, padding): def routing_reference_renormalize_naive(expert_logits, top_k, num_experts, padding): """Softmax->TopK -> Normalize routing reference.""" norm_topk_prob = True - scores = torch.nn.functional.softmax(expert_logits.float(), dim=-1) - topk_values, topk_idx = torch.topk(scores, k=top_k, dim=-1) - - if norm_topk_prob: # only diff with mixtral sparse moe block! 
- topk_values /= topk_values.sum(dim=-1, keepdim=True) + scores = paddle.nn.functional.softmax( + x=expert_logits.astype(dtype="float32"), axis=-1 + ) + topk_values, topk_idx = paddle.topk(x=scores, k=top_k, axis=-1) + if norm_topk_prob: + topk_values /= topk_values.sum(axis=-1, keepdim=True) topk_values = topk_values.to(expert_logits.dtype) scores = scores.to(expert_logits.dtype) - - new_mask = torch.zeros_like(expert_logits) - new_mask.scatter_(-1, topk_idx, 1) + new_mask = paddle.zeros_like(x=expert_logits) + new_mask.put_along_axis_(axis=-1, indices=topk_idx, values=1, broadcast=False) scores = expert_logits * new_mask - - for i in range(topk_idx.shape[0]): - for j in range(topk_idx.shape[1]): + for i in range(tuple(topk_idx.shape)[0]): + for j in range(tuple(topk_idx.shape)[1]): scores[i, topk_idx[i, j]] = topk_values[i, j] permute_info = routing_reference(scores, top_k, padding) return permute_info, scores @@ -1186,14 +1020,12 @@ def routing_reference_renormalize_naive(expert_logits, top_k, num_experts, paddi def routing_reference_topk(expert_logits, top_k, num_experts, padding): """TopK only (no softmax) routing reference.""" - topk_values, topk_idx = torch.topk(expert_logits, k=top_k, dim=-1) - - new_mask = torch.zeros_like(expert_logits) - new_mask.scatter_(-1, topk_idx, 1) + topk_values, topk_idx = paddle.topk(x=expert_logits, k=top_k, axis=-1) + new_mask = paddle.zeros_like(x=expert_logits) + new_mask.put_along_axis_(axis=-1, indices=topk_idx, values=1, broadcast=False) scores = expert_logits * new_mask - - for i in range(topk_idx.shape[0]): - for j in range(topk_idx.shape[1]): + for i in range(tuple(topk_idx.shape)[0]): + for j in range(tuple(topk_idx.shape)[1]): scores[i, topk_idx[i, j]] = topk_values[i, j] permute_info = routing_reference(scores, top_k, padding) return permute_info, scores @@ -1201,32 +1033,27 @@ def routing_reference_topk(expert_logits, top_k, num_experts, padding): def check_accuracy(a, b, atol, rtol, percent): """Unified accuracy checking function with detailed error reporting.""" - if torch.any(torch.isnan(a)): + if paddle.any(x=paddle.isnan(x=a)): raise Exception("NaN in reference output") - if torch.any(torch.isnan(b)): + if paddle.any(x=paddle.isnan(x=b)): raise Exception("NaN in actual output") - if torch.any(torch.isinf(a)): + if paddle.any(x=paddle.isinf(x=a)): raise Exception("Inf in reference output") - if torch.any(torch.isinf(b)): + if paddle.any(x=paddle.isinf(x=b)): raise Exception("Inf in actual output") - assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}" - - left = torch.abs(a - b) - right = atol + rtol * torch.abs(b) - count = torch.sum(left > right) - mismatch_percent = count / a.numel() + assert tuple(a.shape) == tuple( + b.shape + ), f"Shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}" + left = paddle.abs(x=a - b) + right = atol + rtol * paddle.abs(x=b) + count = paddle.sum(x=left > right) + mismatch_percent = count / a.size if mismatch_percent > 1 - percent: raise Exception( - f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} " - f"(threshold: {1 - percent:.4f})" + f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} (threshold: {1 - percent:.4f})" ) -# ==================================================================================== -# FP4 Quantization Functions -# ==================================================================================== - - def calculate_fp4_global_scale_factor(tensor, use_ue8m0=False): """ Calculate FP4 global scale factor for a tensor. 
@@ -1239,22 +1066,21 @@ def calculate_fp4_global_scale_factor(tensor, use_ue8m0=False): Formula: (448 * 6) represents max representable value in FP4 format. """ if use_ue8m0: - return torch.tensor(1.0, dtype=torch.float32) + return paddle.to_tensor(data=1.0, dtype="float32") else: - return (448 * 6) / tensor.float().abs().nan_to_num().max() + return 448 * 6 / tensor.astype(dtype="float32").abs().nan_to_num()._max() def e2m1_and_ufp8_scale_batches( - mat_fp4: torch.Tensor, - scale_tensor: torch.Tensor, - global_scale_tensor: torch.Tensor, + mat_fp4: paddle.Tensor, + scale_tensor: paddle.Tensor, + global_scale_tensor: paddle.Tensor, sf_vec_size: int, ufp8_type: int = 1, ): """Batch FP4 dequantization helper.""" - num_batches = mat_fp4.size(0) + num_batches = mat_fp4.shape[0] scale_tensor = scale_tensor.view(num_batches, -1) - tensors = [ e2m1_and_ufp8sf_scale_to_float( mat_fp4[b, :, :].cpu(), @@ -1262,12 +1088,11 @@ def e2m1_and_ufp8_scale_batches( global_scale_tensor[b].cpu(), sf_vec_size, ufp8_type, - True, # is_sf_swizzled_layout + True, ) for b in range(num_batches) ] - - result = torch.stack(tensors) + result = paddle.stack(x=tensors) return result @@ -1282,11 +1107,9 @@ def quant_fp4(a, a_global_sf, use_ue8m0=False, is_sf_swizzled_layout=True): Pure function - same inputs always produce same outputs. """ sf_vec_size = 32 if use_ue8m0 else 16 - a_fp4, a_sf = fp4_quantize( a.cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, is_sf_swizzled_layout ) - return a_fp4, a_sf, a_global_sf @@ -1296,47 +1119,35 @@ def quant_fp4_batches(a, num_experts, use_ue8m0=False, is_sf_swizzled_layout=Tru sfs = [] global_sfs = [] for i in range(num_experts): - # Use centralized global scale factor calculation a_global_sf = calculate_fp4_global_scale_factor(a[i], use_ue8m0) a_fp4, a_sf, _ = quant_fp4(a[i], a_global_sf, use_ue8m0, is_sf_swizzled_layout) quant_a.append(a_fp4) sfs.append(a_sf) global_sfs.append(a_global_sf) - - result_quant_a = torch.stack(quant_a) - result_sfs = torch.stack(sfs) - result_global_sfs = torch.stack(global_sfs) - + result_quant_a = paddle.stack(x=quant_a) + result_sfs = paddle.stack(x=sfs) + result_global_sfs = paddle.stack(x=global_sfs) return result_quant_a, result_sfs, result_global_sfs def quant_dequant_fp4(a, use_ue8m0=False, is_sf_swizzled_layout=True): """FP4 quantize-dequantize roundtrip function with centralized global scale factor calculation.""" - # Use centralized global scale factor calculation a_global_sf = calculate_fp4_global_scale_factor(a, use_ue8m0) sf_vec_size = 32 if use_ue8m0 else 16 - a_fp4, a_sf = fp4_quantize( a.cuda(), a_global_sf.cuda(), sf_vec_size, use_ue8m0, is_sf_swizzled_layout ) - a_pt = e2m1_and_ufp8sf_scale_to_float( a_fp4.cpu(), a_sf.cpu().reshape(-1), (1 / a_global_sf).cpu(), sf_vec_size, - 1 if not use_ue8m0 else 0, # ufp8_type + 1 if not use_ue8m0 else 0, is_sf_swizzled_layout, ) - return a_pt.cuda(), a_global_sf -# ==================================================================================== -# FP8 Quantization Functions -# ==================================================================================== - - def calculate_fp8_global_scale_factor(tensor): """ Calculate FP8 global scale factor for a tensor. @@ -1348,7 +1159,7 @@ def calculate_fp8_global_scale_factor(tensor): This function is used here for testing/reference purposes. Formula: 448 represents max representable value in FP8 E4M3 format. 
""" - return 448 / tensor.float().abs().nan_to_num().max() + return 448 / tensor.astype(dtype="float32").abs().nan_to_num()._max() def quant_fp8_per_tensor(a, a_global_sf): @@ -1361,91 +1172,71 @@ def quant_fp8_per_tensor(a, a_global_sf): Pure function - same inputs always produce same outputs. """ - a_fp8 = (a * a_global_sf).to(torch.float8_e4m3fn) + a_fp8 = (a * a_global_sf).to(paddle.float8_e4m3fn) return a_fp8, a_global_sf def quant_fp8_per_tensor_batches(a): """FP8 per-tensor batch quantization function with centralized global scale factor calculation.""" - num_batches = a.size(0) + num_batches = a.shape[0] a_quant = [] a_scales = [] - for i in range(num_batches): - # Use centralized global scale factor calculation a_global_sf = calculate_fp8_global_scale_factor(a[i]) a_fp8, _ = quant_fp8_per_tensor(a[i], a_global_sf) a_quant.append(a_fp8) a_scales.append(a_global_sf) - - result_a_quant = torch.stack(a_quant) - result_a_scales = torch.stack(a_scales) - + result_a_quant = paddle.stack(x=a_quant) + result_a_scales = paddle.stack(x=a_scales) return result_a_quant, result_a_scales def quant_dequant_per_tensor_fp8(a): """FP8 per-tensor quantize-dequantize roundtrip function with centralized global scale factor calculation.""" - # Use centralized global scale factor calculation a_global_sf = calculate_fp8_global_scale_factor(a) a_fp8, _ = quant_fp8_per_tensor(a, a_global_sf) - a_pt = a_fp8.to(torch.float) / a_global_sf + a_pt = a_fp8.to("float32") / a_global_sf return a_pt.cuda(), a_global_sf def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n): """Reference FP8 block-scale dequantization.""" - input = input.to(torch.float) - scale = scale.to(torch.float) + input = input.to("float32") + scale = scale.to("float32") if transpose_scale: scale = scale.t() - - m, n = input.shape + m, n = tuple(input.shape) m_tile = 128 if block_m else 1 n_tile = 128 if block_n else 1 - assert m % m_tile == 0 assert n % n_tile == 0 - assert scale.shape == (m // m_tile, n // n_tile) - - # Expand scale to match input dimensions using tensor operations + assert tuple(scale.shape) == (m // m_tile, n // n_tile) if m_tile > 1: - scale = torch.repeat_interleave(scale, m_tile, dim=0) + scale = paddle.repeat_interleave(x=scale, repeats=m_tile, axis=0) if n_tile > 1: - scale = torch.repeat_interleave(scale, n_tile, dim=1) - - # Element-wise multiplication (equivalent to the nested loop logic) + scale = paddle.repeat_interleave(x=scale, repeats=n_tile, axis=1) output = input * scale return output -# ==================================================================================== -# Common MoE Reference Implementation -# ==================================================================================== - - def run_moe_dequant(args, quant_mode: QuantMode): """Common dequantized MoE reference implementation.""" - # Permute total_num_padded_tokens = args.permute_info["permutedBufferSize"] expanded_idx_to_permuted_idx = args.permute_info[ "expandedTokenIdxToPermutedIdx" ].cpu() num_tokens_per_expert = args.permute_info["numTokensPerExpert"].cpu() - permute_output = torch.full( - (total_num_padded_tokens, args.hidden_size), float("nan"), device="cuda" - ).to(torch.float) + permute_output = paddle.full( + shape=(total_num_padded_tokens, args.hidden_size), fill_value=float("nan") + ).to("float32") for i in range(args.num_tokens): for j in range(args.top_k): permuted_idx = expanded_idx_to_permuted_idx[i * args.top_k + j] permute_output[permuted_idx] = args.hidden_states[i] - - # Gemm1 - gemm1_output = 
torch.full( - (total_num_padded_tokens, 2 * args.intermediate_size), - float("nan"), - device="cuda", - ).to(torch.float) + gemm1_output = paddle.full( + shape=(total_num_padded_tokens, 2 * args.intermediate_size), + fill_value=float("nan"), + ).to("float32") i = 0 for expert_idx in range(args.num_experts): my_num_tokens = num_tokens_per_expert[expert_idx] @@ -1457,33 +1248,24 @@ def run_moe_dequant(args, quant_mode: QuantMode): gemm1_output[i : i + my_num_tokens] = my_c i += my_num_tokens i = (i + args.padding - 1) // args.padding * args.padding - if args.use_routing_scales_on_input: assert args.top_k == 1 - # For each token and its top_k experts for token_idx in range(args.num_tokens): for k in range(args.top_k): - # Get the permuted index for this token's k-th expert expanded_idx = token_idx * args.top_k + k permuted_idx = expanded_idx_to_permuted_idx[expanded_idx] - expert_weight = args.permute_info["topKLogits"].to(torch.float) - # Get the expert weight for this token and expert + expert_weight = args.permute_info["topKLogits"].to("float32") weight = expert_weight[token_idx, k] - # Scale the corresponding row in gemm1_output gemm1_output[permuted_idx] *= weight - - # Activation - activation_output = torch.full( - (total_num_padded_tokens, args.intermediate_size), float("nan"), device="cuda" - ).to(torch.float) - + activation_output = paddle.full( + shape=(total_num_padded_tokens, args.intermediate_size), fill_value=float("nan") + ).to("float32") gated_act_type = args.gated_act_type gated_act_type_to_func = { - 0: F.silu, - 1: F.gelu, + (0): paddle.nn.functional.silu, + (1): paddle.nn.functional.gelu, } gated_act_func = gated_act_type_to_func[gated_act_type] - i = 0 for expert_idx in range(args.num_experts): my_num_tokens = num_tokens_per_expert[expert_idx] @@ -1495,41 +1277,35 @@ def run_moe_dequant(args, quant_mode: QuantMode): activation_output[i : i + my_num_tokens] = gated_act_func(my_x2) * my_x1 i += my_num_tokens i = (i + args.padding - 1) // args.padding * args.padding - if quant_mode == QuantMode.FP4_NVFP4_NVFP4: - # Use centralized function for activation quantization activation_output, c_global_sf = quant_dequant_fp4( - activation_output.to(torch.bfloat16), False, True + activation_output.to("bfloat16"), False, True ) - activation_output = activation_output.to(torch.float) + activation_output = activation_output.to("float32") args.c_global_sf = c_global_sf elif quant_mode == QuantMode.FP8_PER_TENSOR: activation_output, c_global_sf = quant_dequant_per_tensor_fp8( - activation_output.to(torch.bfloat16) + activation_output.to("bfloat16") ) - activation_output = activation_output.to(torch.float) + activation_output = activation_output.to("float32") args.c_global_sf = c_global_sf elif quant_mode == QuantMode.FP4_MXFP4_MXFP8: activation_output, scale_bytes = mxfp8_quantize( - activation_output.to(torch.bfloat16), True + activation_output.to("bfloat16"), True ) - scale_bytes = scale_bytes.view(torch.uint8).reshape(-1).cpu() + scale_bytes = scale_bytes.view("uint8").reshape(-1).cpu() activation_output = ( - mxfp8_dequantize_host( - activation_output.cpu().view(torch.uint8), scale_bytes - ) + mxfp8_dequantize_host(activation_output.cpu().view("uint8"), scale_bytes) .cuda() - .to(torch.float) + .to("float32") ) args.c_global_sf = 1.0 - else: # mxfp4Bf16 - activation_output = activation_output.to(torch.bfloat16).to(torch.float) + else: + activation_output = activation_output.to("bfloat16").to("float32") args.c_global_sf = 1.0 - - # Gemm2 - gemm2_output = torch.full( - 
(total_num_padded_tokens, args.hidden_size), float("nan"), device="cuda" - ).to(torch.float) + gemm2_output = paddle.full( + shape=(total_num_padded_tokens, args.hidden_size), fill_value=float("nan") + ).to("float32") i = 0 for expert_idx in range(args.num_experts): my_num_tokens = num_tokens_per_expert[expert_idx] @@ -1541,14 +1317,12 @@ def run_moe_dequant(args, quant_mode: QuantMode): gemm2_output[i : i + my_num_tokens] = my_c i += my_num_tokens i = (i + args.padding - 1) // args.padding * args.padding - - # Finalize - expert_weight = args.permute_info["topKLogits"].to(torch.float) - finalize_output = torch.full( - (args.num_tokens, args.hidden_size), float("nan"), device="cuda" - ).to(torch.float) + expert_weight = args.permute_info["topKLogits"].to("float32") + finalize_output = paddle.full( + shape=(args.num_tokens, args.hidden_size), fill_value=float("nan") + ).to("float32") for i in range(args.num_tokens): - acc = torch.zeros(args.hidden_size, dtype=torch.float, device="cuda") + acc = paddle.zeros(shape=args.hidden_size, dtype="float32") for top_k_idx in range(args.top_k): expanded_idx = i * args.top_k + top_k_idx permuted_idx = expanded_idx_to_permuted_idx[expanded_idx] @@ -1563,33 +1337,26 @@ def run_moe_dequant(args, quant_mode: QuantMode): return finalize_output -# ==================================================================================== -# Quantization-Specific Reference Implementations -# ==================================================================================== - - def run_moe_reference_fp4(args, quant_mode: QuantMode): sf_vec_size = 16 if quant_mode == QuantMode.FP4_NVFP4_NVFP4 else 32 ufp8_type_weights = 1 if quant_mode == QuantMode.FP4_NVFP4_NVFP4 else 0 - if quant_mode == QuantMode.FP4_NVFP4_NVFP4: hidden_states_dequant = e2m1_and_ufp8sf_scale_to_float( args.hidden_states.cpu(), - args.hidden_states_scale.cpu().view(torch.uint8).reshape(-1), + args.hidden_states_scale.cpu().view("uint8").reshape(-1), (1 / args.hidden_states_scale_global).cpu(), sf_vec_size, ufp8_type_weights, - True, # is_sf_swizzled_layout + True, ).cuda() elif quant_mode == QuantMode.FP4_MXFP4_MXFP8: hidden_states_dequant = mxfp8_dequantize_host( - args.hidden_states.cpu().view(torch.uint8), - args.hidden_states_scale.cpu().view(torch.uint8).reshape(-1), - True, # is_sf_swizzled_layout + args.hidden_states.cpu().view("uint8"), + args.hidden_states_scale.cpu().view("uint8").reshape(-1), + True, ).cuda() else: - hidden_states_dequant = args.hidden_states.to(torch.bfloat16).to(torch.float) - + hidden_states_dequant = args.hidden_states.to("bfloat16").to("float32") gemm1_weights_dequant = e2m1_and_ufp8_scale_batches( args.gemm1_weights, args.gemm1_scales, @@ -1597,7 +1364,6 @@ def run_moe_reference_fp4(args, quant_mode: QuantMode): sf_vec_size, ufp8_type_weights, ).cuda() - gemm2_weights_dequant = e2m1_and_ufp8_scale_batches( args.gemm2_weights, args.gemm2_scales, @@ -1605,7 +1371,6 @@ def run_moe_reference_fp4(args, quant_mode: QuantMode): sf_vec_size, ufp8_type_weights, ).cuda() - args_dequant = moe_args_dequant( args.num_tokens, args.num_experts, @@ -1621,34 +1386,27 @@ def run_moe_reference_fp4(args, quant_mode: QuantMode): args.use_routing_scales_on_input, args.gated_act_type, ) - return run_moe_dequant(args_dequant, quant_mode), args_dequant def run_moe_reference_dsfp8(args): """FP8 block-scale reference implementation.""" - # Generate block scales at runtime for FP8 block scaling - # Use deterministic scales for testing consistency - hidden_states_scale = 2.0 * torch.ones( - 
(args.hidden_size // 128, args.num_tokens), device="cuda", dtype=torch.float + hidden_states_scale = 2.0 * paddle.ones( + shape=(args.hidden_size // 128, args.num_tokens), dtype="float32" ) - hidden_states_dequant = dequant_reference_dsfp8( args.hidden_states, hidden_states_scale, True, False, True ) - gemm1_weights_dequant = {} for i in range(args.num_experts): gemm1_weights_dequant[i] = dequant_reference_dsfp8( args.gemm1_weights[i], args.gemm1_scales[i], False, True, True ) - gemm2_weights_dequant = {} for i in range(args.num_experts): gemm2_weights_dequant[i] = dequant_reference_dsfp8( args.gemm2_weights[i], args.gemm2_scales[i], False, True, True ) - args_dequant = moe_args_dequant( args.num_tokens, args.num_experts, @@ -1662,30 +1420,26 @@ def run_moe_reference_dsfp8(args): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - GatedActType.SwiGlu.value, # gated_act_type + GatedActType.SwiGlu.value, ) - return run_moe_dequant(args_dequant, QuantMode.FP8_BLOCK_SCALE), args_dequant def run_moe_reference_per_tensor_scale_fp8(args): """FP8 per-tensor reference implementation.""" hidden_states_dequant = ( - args.hidden_states.to(torch.float) / args.hidden_states_scale_global + args.hidden_states.to("float32") / args.hidden_states_scale_global ) - gemm1_weights_dequant = {} for i in range(args.num_experts): gemm1_weights_dequant[i] = ( - args.gemm1_weights[i].to(torch.float) / args.gemm1_scales_global[i] + args.gemm1_weights[i].to("float32") / args.gemm1_scales_global[i] ) - gemm2_weights_dequant = {} for i in range(args.num_experts): gemm2_weights_dequant[i] = ( - args.gemm2_weights[i].to(torch.float) / args.gemm2_scales_global[i] + args.gemm2_weights[i].to("float32") / args.gemm2_scales_global[i] ) - args_dequant = moe_args_dequant( args.num_tokens, args.num_experts, @@ -1699,15 +1453,13 @@ def run_moe_reference_per_tensor_scale_fp8(args): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - GatedActType.SwiGlu.value, # gated_act_type + GatedActType.SwiGlu.value, ) - return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): """Unified actual computation that delegates to implementation-specific methods.""" - # 1. Prepare static weights for the kernel (offline processing) static_data = moe_impl.prepare_static_weights_for_kernel( args_dequant, args, @@ -1718,8 +1470,6 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): args.num_experts, kwargs["weight_processing"], ) - - # 2. Call MoE with runtime input quantization + kernel execution kernel_kwargs = { "expert_logits": kwargs["expert_logits"], "routing_bias": kwargs["routing_bias"], @@ -1736,7 +1486,6 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): "do_finalize": True, "gated_act_type": args.gated_act_type, } - return moe_impl.call_moe( static_data, kwargs["hidden_states_orig"], @@ -1746,20 +1495,15 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): def calculate_tile_tokens_dim(num_tokens: int, num_experts: int, top_k: int) -> int: - # Guess tokens per expert assuming perfect expert distribution first. num_tokens_per_expert = num_tokens * top_k // num_experts - - # And pad the number to the next power of 2. tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. 
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) - return tile_tokens_dim @pytest.fixture(scope="module") def cache_permute_indices(): - _cache_permute_indices: Dict[torch.Size, torch.Tensor] = {} + _cache_permute_indices: Dict[list, paddle.Tensor] = {} return _cache_permute_indices @@ -1789,10 +1533,7 @@ def cache_permute_indices(): "routed_scaling": 2.5, "has_routing_bias": True, "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [ - FP4Moe, - FP8BlockScaleMoe, - ], + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], }, id="DSv3", ), @@ -1806,10 +1547,7 @@ def cache_permute_indices(): "routed_scaling": 2.5, "has_routing_bias": True, "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [ - FP4Moe, - FP8BlockScaleMoe, - ], + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], }, id="DSLite", ), @@ -1929,25 +1667,21 @@ def test_moe_quantization_classes( Each quantization class clearly shows which precision is being used. """ - # Skip incompatible combinations if gated_act_type == GatedActType.GeGlu and ( type(moe_impl) is not FP4Moe or moe_impl.quant_mode != QuantMode.FP4_NVFP4_NVFP4 or routing_config["routing_method_type"] != RoutingMethodType.TopK or num_tokens > 128 ): - # GeGlu is only supported for FP4Moe FP4_NVFP4_NVFP4 and TopK routing pytest.skip( f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}" ) elif gated_act_type == GatedActType.SwiGlu and ( hidden_size > 1024 or intermediate_size > 1024 ): - # Skip some tests for SwiGlu for testing speed pytest.skip( f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}" ) - if type(moe_impl) not in routing_config["compatible_moe_impls"]: pytest.skip( f"Incompatible: {moe_impl.name} + {routing_config['routing_method_type'].name}" @@ -1956,13 +1690,9 @@ def test_moe_quantization_classes( pytest.skip( f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}" ) - moe_impl._cache_permute_indices = cache_permute_indices - seed = 0 - torch.random.manual_seed(seed) - - # Extract routing configuration + paddle.seed(seed=seed) top_k = routing_config["top_k"] padding = routing_config["padding"] n_groups = routing_config["n_groups"] @@ -1971,52 +1701,30 @@ def test_moe_quantization_classes( num_experts = routing_config["num_experts"] routing_method_type = routing_config["routing_method_type"] tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k) - - # Validation checks assert top_k <= num_experts assert top_k <= 8 - if (top_k_groups is not None) and (n_groups is not None) and (n_groups > 0): + if top_k_groups is not None and n_groups is not None and n_groups > 0: assert top_k_groups <= 4 assert num_experts > n_groups assert num_experts % n_groups == 0 assert num_experts % 4 == 0 - assert top_k < (top_k_groups * num_experts / n_groups) - - # Create test data based on routing method and quantization mode - # Different kernels have different dtype requirements for routing logits + assert top_k < top_k_groups * num_experts / n_groups if routing_method_type == RoutingMethodType.DeepSeekV3: - # DeepSeekV3 uses float for routing logits - expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( - torch.float - ) + expert_logits = paddle.randn(shape=(num_tokens, num_experts)).to("float32") else: - # Other routing methods (Renormalize, RenormalizeNaive, Llama4) use bfloat16 - expert_logits = torch.randn((num_tokens, num_experts), 
device="cuda").to( - torch.bfloat16 - ) - + expert_logits = paddle.randn(shape=(num_tokens, num_experts)).to("bfloat16") if routing_config["has_routing_bias"]: - routing_bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16) + routing_bias = paddle.randn(shape=num_experts, dtype="bfloat16") else: routing_bias = None - - hidden_states = 2 * torch.randn( - (num_tokens, hidden_size), device="cuda", dtype=torch.bfloat16 + hidden_states = 2 * paddle.randn(shape=(num_tokens, hidden_size), dtype="bfloat16") + gemm1_weights = paddle.randn( + shape=(num_experts, 2 * intermediate_size, hidden_size), dtype="bfloat16" ) - gemm1_weights = torch.randn( - (num_experts, 2 * intermediate_size, hidden_size), - device="cuda", - dtype=torch.bfloat16, + gemm2_weights = paddle.randn( + shape=(num_experts, hidden_size, intermediate_size), dtype="bfloat16" ) - gemm2_weights = torch.randn( - (num_experts, hidden_size, intermediate_size), - device="cuda", - dtype=torch.bfloat16, - ) - - # Generate routing info use_routing_scales_on_input = routing_method_type == RoutingMethodType.Llama4 - if routing_method_type == RoutingMethodType.DeepSeekV3: permute_info, scores = routing_reference_no_aux( expert_logits, @@ -2055,21 +1763,13 @@ def test_moe_quantization_classes( raise NotImplementedError( f"Routing method {routing_method_type} not implemented" ) - - # 1. Quantize weights offline (static, done once) + compute global scale factors weights_data = moe_impl.quantize_weights( gemm1_weights, gemm2_weights, hidden_states ) - - # 2. Quantize inputs at runtime (dynamic, done per inference) using pre-computed scales inputs_data = moe_impl.quantize_inputs( hidden_states, weights_data["hidden_states_scale_global"] ) - - # 3. Combine quantized data quant_data = {**weights_data, **inputs_data} - - # Create arguments for reference computation args = moe_args( num_tokens, num_experts, @@ -2091,15 +1791,9 @@ def test_moe_quantization_classes( use_routing_scales_on_input, gated_act_type, ) - - # Compute reference output using the moe_impl output_dequant_reference, args_dequant = moe_impl.compute_reference(args) - - # Validate that reference computation succeeded if output_dequant_reference is None: pytest.fail("Reference computation failed to produce output") - - # Compute actual output using the moe_impl output_dequant_actual = moe_impl.compute_production( args_dequant, args, @@ -2116,8 +1810,6 @@ def test_moe_quantization_classes( weight_processing=weight_processing, enable_pdl=True, ) - - # Compare outputs using moe_impl-specific tolerances tolerances = moe_impl.get_tolerances() check_accuracy( output_dequant_reference, diff --git a/tests/test_trtllm_gen_mla.py b/tests/test_trtllm_gen_mla.py index ad29d77e64..8193043e5f 100644 --- a/tests/test_trtllm_gen_mla.py +++ b/tests/test_trtllm_gen_mla.py @@ -1,19 +1,20 @@ +import sys + + import math +import paddle import pytest -import torch +from flashinfer.paddle_utils import * import flashinfer global_workspace_buffer = None -@pytest.mark.parametrize( - "batch_size", - [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024], -) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024]) @pytest.mark.parametrize("scale", [1.0, 0.5]) -@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [paddle.float8_e4m3fn, "bfloat16"]) @pytest.mark.parametrize("page_size", [32, 64]) @pytest.mark.parametrize("q_len_per_request", [1, 2]) @pytest.mark.parametrize("dynamic_scale", [False]) @@ -21,61 +22,46 @@ def 
test_trtllm_batch_decode_mla( batch_size: int, scale: float, - dtype: torch.dtype, + dtype: paddle.dtype, page_size: int, q_len_per_request: int, dynamic_scale: bool, enable_pdl: bool, ): - if dynamic_scale and dtype != torch.float8_e4m3fn: + if dynamic_scale and dtype != paddle.float8_e4m3fn: pytest.skip("Dynamic scale is not supported for non-fp8 dtype") - - torch.manual_seed(42) + paddle.seed(seed=42) device = "cuda:0" - - # Fixed max sequence length MAX_SEQ_LEN = 1024 - - # Deepseek attention config (decode-MLA) num_q_heads = 128 qk_nope_head_dim = 128 qk_rope_head_dim = 64 kv_lora_rank = 512 - - # Initialize tensors - query = torch.randn( - batch_size, - q_len_per_request, - num_q_heads, - kv_lora_rank + qk_rope_head_dim, - device=device, + query = paddle.randn( + shape=[ + batch_size, + q_len_per_request, + num_q_heads, + kv_lora_rank + qk_rope_head_dim, + ] ).to(dtype) - num_tokens = MAX_SEQ_LEN * batch_size num_blocks = (num_tokens + page_size - 1) // page_size - - # Sequence lengths and block tables - seq_lens = [torch.randint(1, MAX_SEQ_LEN, (1,)).item() for _ in range(batch_size)] + seq_lens = [ + paddle.randint(low=1, high=MAX_SEQ_LEN, shape=(1,)).item() + for _ in range(batch_size) + ] seq_lens[-1] = MAX_SEQ_LEN max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) - + seq_lens_tensor = paddle.to_tensor(data=seq_lens, dtype="int32", place=device) blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size - max_num_blocks_per_seq = blocks_per_seq.max().item() - - # Generate random but unique block IDs for all sequences + max_num_blocks_per_seq = blocks_per_seq._max().item() total_blocks_needed = sum(blocks_per_seq) - all_block_ids = torch.randperm( - total_blocks_needed, device=device - ) # Random permutation - - # Generate unique block IDs for all sequences + all_block_ids = paddle.randperm(n=total_blocks_needed) block_id = 0 - block_tables = torch.zeros( - (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=device + block_tables = paddle.zeros( + shape=(batch_size, max_num_blocks_per_seq), dtype="int32" ) - - # Populate block tables and track block assignments block_id = 0 for i in range(batch_size): num_blocks_needed = blocks_per_seq[i] @@ -83,42 +69,30 @@ def test_trtllm_batch_decode_mla( block_id : block_id + num_blocks_needed ] block_id += num_blocks_needed - - # Create interleaved KV cache - # Allocate more than needed blocks, block_id is just enough, to mimick real-world cases - kv_cache = torch.randn( - size=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim), device=device + kv_cache = paddle.randn( + shape=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim) ).to(dtype) - # (num_blocks, 1, page_size, kv_lora_rank + qk_rope_head_dim) - - # Allocate workspace buffer - # todo(Yingyi): calculate the actual size of workspace buffer global global_workspace_buffer if global_workspace_buffer is None: - global_workspace_buffer = torch.zeros( - 128 * 1024 * 1024, dtype=torch.int8, device=device - ) + global_workspace_buffer = paddle.zeros(shape=128 * 1024 * 1024, dtype="int8") workspace_buffer = global_workspace_buffer - bmm1_log2_scale_tensor = ( - torch.tensor( - [scale / ((128 + 64) ** 0.5 * math.log2(math.e))], - dtype=torch.float32, - device=device, + paddle.to_tensor( + data=[scale / ((128 + 64) ** 0.5 * math.log2(math.e))], + dtype="float32", + place=device, ) if dynamic_scale else None ) bmm2_scale_tensor = ( - torch.tensor([1.0], dtype=torch.float32, device=device) + paddle.to_tensor(data=[1.0], 
dtype="float32", place=device) if dynamic_scale else None ) - - # Run decode-MLA output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, - kv_cache=kv_cache.unsqueeze(1), + kv_cache=kv_cache.unsqueeze(axis=1), workspace_buffer=workspace_buffer, qk_nope_head_dim=qk_nope_head_dim, kv_lora_rank=kv_lora_rank, @@ -126,38 +100,26 @@ def test_trtllm_batch_decode_mla( block_tables=block_tables, seq_lens=seq_lens_tensor, max_seq_len=max_seq_len, - bmm1_scale=scale / ((128 + 64) ** 0.5), + bmm1_scale=scale / (128 + 64) ** 0.5, bmm2_scale=1.0, bmm1_scale_log2_tensor=bmm1_log2_scale_tensor, bmm2_scale_tensor=bmm2_scale_tensor, enable_pdl=enable_pdl, ) - - # Run reference attention and align output - sm_scale = scale / ( - (128 + 64) ** 0.5 - ) # use head dimension before matrix absorption - workspace_buffer_ref = torch.empty( - 128 * 1024 * 1024, dtype=torch.int8, device=device - ) + sm_scale = scale / (128 + 64) ** 0.5 + workspace_buffer_ref = paddle.empty(shape=128 * 1024 * 1024, dtype="int8") wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( - workspace_buffer_ref, - backend="fa2", + workspace_buffer_ref, backend="fa2" ) - - if dtype == torch.float8_e4m3fn: - # convert query and kv_cache to bfloat16 - query = query.to(torch.bfloat16) - kv_cache = kv_cache.to(torch.bfloat16) - + if dtype == paddle.float8_e4m3fn: + query = query.to("bfloat16") + kv_cache = kv_cache.to("bfloat16") q_indptr = ( - torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) - * q_len_per_request + paddle.arange(start=0, end=batch_size + 1, dtype="int32") * q_len_per_request ) - kv_indptr = torch.zeros_like(q_indptr) - kv_indptr[1:] = torch.cumsum(blocks_per_seq, dim=0) - kv_indices = all_block_ids.int() - + kv_indptr = paddle.zeros_like(x=q_indptr) + kv_indptr[1:] = paddle.cumsum(x=blocks_per_seq, axis=0) + kv_indices = all_block_ids.astype(dtype="int32") wrapper.plan( q_indptr, kv_indptr, @@ -178,37 +140,31 @@ def test_trtllm_batch_decode_mla( q_pe = query[..., kv_lora_rank:].view( batch_size * q_len_per_request, num_q_heads, qk_rope_head_dim ) - - # todo: fix kv_cache ckv = kv_cache[..., :kv_lora_rank] kpe = kv_cache[..., kv_lora_rank:] - o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False) - - # check is nan - assert not torch.isnan(o_ref).any(), "o_ref is nan" - assert not torch.isnan(output).any(), "output is nan" - - if dtype == torch.float8_e4m3fn: + assert not paddle.isnan(x=o_ref).astype("bool").any(), "o_ref is nan" + assert not paddle.isnan(x=output).astype("bool").any(), "output is nan" + if dtype == paddle.float8_e4m3fn: try: - torch.testing.assert_close( - output, - o_ref.view(batch_size, q_len_per_request, num_q_heads, -1), - rtol=1e-1, - atol=1e-1, - ) # todo: do reference with normal attention? 
+ assert paddle.allclose( + x=output, + y=o_ref.view(batch_size, q_len_per_request, num_q_heads, -1), + rtol=0.1, + atol=0.1, + ).item(), "" except AssertionError as e: print("output:", output) print("o_ref:", o_ref) raise e else: try: - torch.testing.assert_close( - output, - o_ref.view(batch_size, q_len_per_request, num_q_heads, -1), - rtol=1e-2, - atol=1e-2, - ) + assert paddle.allclose( + x=output, + y=o_ref.view(batch_size, q_len_per_request, num_q_heads, -1), + rtol=0.01, + atol=0.01, + ).item(), "" except AssertionError as e: print("output:", output) print("o_ref:", o_ref) diff --git a/tests/test_trtllm_mnnvl_allreduce.py b/tests/test_trtllm_mnnvl_allreduce.py index a1a18bb228..7dcc672a0c 100644 --- a/tests/test_trtllm_mnnvl_allreduce.py +++ b/tests/test_trtllm_mnnvl_allreduce.py @@ -1,43 +1,42 @@ -# Check torch version: +import sys + + import sys from typing import Tuple +import paddle import pytest -import torch -from mpi4py import MPI # Added MPI import +from mpi4py import MPI +from flashinfer.paddle_utils import * import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping - -# Use flashinfer.norm.rmsnorm as reference implementation. from flashinfer.norm import rmsnorm -@torch.inference_mode() +@paddle.no_grad() def row_linear_residual_norm_fusion_forward( - x: torch.Tensor, - residual: torch.Tensor, - norm_weight: torch.Tensor, + x: paddle.Tensor, + residual: paddle.Tensor, + norm_weight: paddle.Tensor, eps: float, hidden_size: int, - dtype: torch.dtype, + dtype: paddle.dtype, mapping: Mapping, fusion: bool, - reference_output: tuple[torch.Tensor, ...], + reference_output: tuple[paddle.Tensor, ...], multicast_ptr: int, buffer_ptrs_dev: int, unicast_ptr: int, max_num_elements_mnnvl: int, - buffer_flags_mnnvl: torch.Tensor, + buffer_flags_mnnvl: paddle.Tensor, ): x = x.cuda() residual = residual.cuda() norm_weight = norm_weight.cuda() reference_output = tuple(t.cuda() for t in reference_output) - tensor_parallel_size = mapping.tp_size tensor_parallel_rank = mapping.tp_rank - MPI.COMM_WORLD.barrier() def func( @@ -51,23 +50,15 @@ def func( unicast_ptr, max_num_elements_mnnvl, ): - # For both fused and unfused cases: - shape = input.shape - + shape = tuple(input.shape) assert max_num_elements_mnnvl % hidden_size == 0 - input = input.view(-1, shape[-1]) - buffer_M = max_num_elements_mnnvl // hidden_size - if enable_fusion: use_pdl = True - - prenorm_output = torch.empty_like(residual) - normed_output = torch.empty_like(residual) - + prenorm_output = paddle.empty_like(x=residual) + normed_output = paddle.empty_like(x=residual) trtllm_mnnvl_ar.mpi_barrier() - trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm( prenorm_output, normed_output, @@ -84,12 +75,9 @@ def func( residual, use_pdl, ) - return normed_output.view(shape), prenorm_output.view(shape) - else: - output = torch.empty_like(input) - + output = paddle.empty_like(x=input) trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce( input, multicast_ptr, @@ -98,9 +86,9 @@ def func( buffer_flags_mnnvl, tensor_parallel_size, tensor_parallel_rank, - True, # wait_for_results - False, # launch_with_pdl - output, # Need to provide output tensor since we are writing them out. 
+ True, + False, + output, ) return (output.view(shape),) @@ -115,157 +103,106 @@ def func( unicast_ptr, max_num_elements_mnnvl, ) - - assert output[0].shape == reference_output[0].shape - + assert tuple(output[0].shape) == tuple(reference_output[0].shape) if tensor_parallel_rank == 0: print("output[0] (first 10 values):", output[0].flatten()[:10]) print( - "reference_output[0] (first 10 values):", - reference_output[0].flatten()[:10], + "reference_output[0] (first 10 values):", reference_output[0].flatten()[:10] ) - if fusion: print("output[1] (first 10 values):", output[1].flatten()[:10]) print( "reference_output[1] (first 10 values):", reference_output[1].flatten()[:10], ) - - torch.testing.assert_close( - output[0], - reference_output[0], - rtol=0.05, - atol=0.15, - ) - + assert paddle.allclose( + x=output[0], y=reference_output[0], rtol=0.05, atol=0.15 + ).item(), "" if fusion: - torch.testing.assert_close( - output[1], - reference_output[1], - rtol=0.05, - atol=0.15, - ) + assert paddle.allclose( + x=output[1], y=reference_output[1], rtol=0.05, atol=0.15 + ).item(), "" """Main test function that runs on each MPI rank""" -@pytest.mark.parametrize( - "seq_lens", - [ - [1], - [4], - [15], - [27, 11, 24], - [127], - ], -) # Test with different sequence length lists +@pytest.mark.parametrize("seq_lens", [[1], [4], [15], [27, 11, 24], [127]]) @pytest.mark.parametrize("fusion", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) def test_mnnvl_allreduce_full( - monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int + monkeypatch, + seq_lens: list[int], + fusion: bool, + dtype: paddle.dtype, + hidden_size: int, ): - monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. 
- - # Get MPI info + monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") rank = MPI.COMM_WORLD.Get_rank() world_size = MPI.COMM_WORLD.Get_size() - gpus_per_node = torch.cuda.device_count() - + gpus_per_node = paddle.device.cuda.device_count() if gpus_per_node == 0: pytest.skip("MNNVL allreduce test requires at least one CUDA device per node") - - # Ensure we have exactly 2 ranks for this test if world_size < 2: if rank == 0: print(f"ERROR: This test requires at least 2 MPI ranks, got {world_size}") sys.exit(1) - mapping = Mapping( world_size=world_size, rank=rank, gpus_per_node=gpus_per_node, tp_size=world_size, ) - - # Set CUDA device based on rank - torch.cuda.set_device(mapping.local_rank) - + paddle.device.set_device(device=device2str(mapping.local_rank)) if mapping.local_rank == 0: print( f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" ) print( - f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" + f"[Node {mapping.node_rank}] Rank {rank} using GPU {paddle.device.get_device()}" ) - tensor_parallel_size = world_size - eps = 1e-5 - torch.manual_seed(42) - - # Track if this rank failed + eps = 1e-05 + paddle.seed(seed=42) rank_failed = False failure_message = "" - try: - # Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list - # This workspace is sized for the maximum expected sequence length and can be reused within each list - # Each parameterized list gets its own fresh workspace allocation - mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( - trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype) - ) - + ( + mcast_buffer_mnnvl, + buffer_flags_mnnvl, + max_num_elements_mnnvl, + ) = trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype) multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( mapping.tp_rank ) - - # Test each sequence length with the same workspace (reusing allocated buffers within this list) for seq_len in seq_lens: if rank == 0: print( f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" ) - - # Generate test data (same on all ranks due to same seed) - x_full = torch.randn( - (tensor_parallel_size, seq_len, hidden_size), - dtype=dtype, - device=torch.device("cuda"), - ) - residual = torch.randn( - (seq_len, hidden_size), dtype=dtype, device=torch.device("cuda") + x_full = paddle.randn( + shape=(tensor_parallel_size, seq_len, hidden_size), dtype=dtype ) - norm_weight = torch.randn( - (hidden_size,), dtype=dtype, device=torch.device("cuda") - ) - - # Each rank gets its slice of the input + residual = paddle.randn(shape=(seq_len, hidden_size), dtype=dtype) + norm_weight = paddle.randn(shape=(hidden_size,), dtype=dtype) x = x_full[rank, :, :] - - # Compute reference output based on fusion mode - reference_output: Tuple[torch.Tensor, ...] = None + reference_output: Tuple[paddle.Tensor, ...] 
= None if fusion: - # Fused case: AllReduce + Residual Add + RMS Norm - allreduce_result = torch.sum(x_full, dim=0) # AllReduce result - residual_out = allreduce_result + residual # Add residual + allreduce_result = paddle.sum(x=x_full, axis=0) + residual_out = allreduce_result + residual print( "Device of residual_out:{}, norm_weight:{}".format( - residual_out.device, norm_weight.device + residual_out.place, norm_weight.place ) ) norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) - - reference_output = (norm_out, residual_out) + reference_output = norm_out, residual_out else: - # Non-fused case: Only AllReduce - allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + allreduce_result = paddle.sum(x=x_full, axis=0) reference_output = (allreduce_result,) - - # Run the test with the same workspace row_linear_residual_norm_fusion_forward( x, residual, @@ -282,35 +219,22 @@ def test_mnnvl_allreduce_full( max_num_elements_mnnvl, buffer_flags_mnnvl, ) - - # Synchronize before next test trtllm_mnnvl_ar.mpi_barrier() - print( f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}" ) - except Exception as e: rank_failed = True failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}" print(failure_message) - # Gather failure status from all ranks all_failures = MPI.COMM_WORLD.allgather(rank_failed) - - # If any rank failed, fail the test if any(all_failures): failed_ranks = [i for i, failed in enumerate(all_failures) if failed] if rank == 0: print(f"Test failed on ranks: {failed_ranks}") - - # Fail the test on all ranks pytest.fail(f"Test failed on ranks {failed_ranks}") trtllm_mnnvl_ar.mpi_barrier() - finally: - # Ensure cleanup happens for this list's workspace if "mcast_buffer_mnnvl" in locals(): del mcast_buffer_mnnvl - - # Final synchronization and check for failures across all ranks trtllm_mnnvl_ar.mpi_barrier() diff --git a/tests/test_trtllm_moe_allreduce_fusion.py b/tests/test_trtllm_moe_allreduce_fusion.py index 59e65c108c..baef0cc565 100644 --- a/tests/test_trtllm_moe_allreduce_fusion.py +++ b/tests/test_trtllm_moe_allreduce_fusion.py @@ -1,98 +1,75 @@ +import sys + + import multiprocessing as mp import socket from typing import Any import numpy as np +import paddle import pytest -import torch -import torch.distributed as dist +from flashinfer.paddle_utils import * import flashinfer.comm as comm -# todo(Yingyi): add benchmark and quant test - -# Usage: test var kOneShotMaxTokenNum = 128 MAX_TOKEN_NUM = 2048 HIDDEN_SIZE = 7168 MAX_EXPERT_NUM = 16 SF_VEC_SIZE = 16 - -# temp var -SCALE_FACTOR_RANGE = (-1, 1) +SCALE_FACTOR_RANGE = -1, 1 def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) + device = device2str(f"cuda:{rank}") + paddle.device.set_device(device=device2str(device)) distributed_init_method = f"tcp://localhost:{distributed_init_port}" - dist.init_process_group( - backend="nccl", - init_method=distributed_init_method, - rank=rank, - world_size=world_size, - ) - group = dist.group.WORLD - + paddle.distributed.init_parallel_env() +>>>>>> group = torch.distributed.group.WORLD try: - device = torch.device(f"cuda:{rank}") - token_nums = [ - 1, - 64, - 128, - 256, - 2048, - ] # 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 + device = device2str(f"cuda:{rank}") + token_nums = [1, 64, 128, 256, 2048] candidate_active_expert_num = [8, 12, 16] - # candidate_active_expert_num = [1] # debug-only 
swizzled_layout_codes = [ comm.QuantizationSFLayout.LINEAR, comm.QuantizationSFLayout.SWIZZLED_128x4, comm.QuantizationSFLayout.SWIZZLED_8x4, ] launch_with_pdls = [True, False] - - # create workspace for moe allreduce fusion - ipc_handles, workspace_tensor = ( - comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( - rank, world_size, MAX_TOKEN_NUM, HIDDEN_SIZE, group=group - ) + ( + ipc_handles, + workspace_tensor, + ) = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, world_size, MAX_TOKEN_NUM, HIDDEN_SIZE, group=group ) - test_loop = 5 - for token_num in token_nums: for active_expert_num in candidate_active_expert_num: for swizzled_layout_code in swizzled_layout_codes: for launch_with_pdl in launch_with_pdls: - dist.barrier(group=group) + paddle.distributed.barrier(group=group) test_passed = True print( f"test RANK {rank}: token{token_num}-expert{active_expert_num}-tp{world_size}-{dtype}-layout{swizzled_layout_code}-pdl{launch_with_pdl} start" ) - dist.barrier(group=group) - torch.cuda.synchronize() + paddle.distributed.barrier(group=group) + paddle.device.synchronize() for _ in range(test_loop): message_size = token_num * HIDDEN_SIZE - - residual_in = torch.randn( - message_size, dtype=dtype, device=device - ) + residual_in = paddle.randn(shape=message_size, dtype=dtype) residual_in_clone = residual_in.clone() - - moe_allreduce_out = torch.zeros( - message_size, dtype=dtype, device=device + moe_allreduce_out = paddle.zeros( + shape=message_size, dtype=dtype ) - residual_out = torch.empty_like(residual_in) - norm_out = torch.empty_like(residual_in) - quant_out = torch.empty( - message_size // 4, dtype=dtype, device=device - ) # quant: fp16/bf16 -> fp4, reference: cpp/tensorrt_llm/thop/allreduceOp.cpp:L487 - - scale_out = None - assert HIDDEN_SIZE % SF_VEC_SIZE == 0, ( - "HIDDEN_SIZE must be divisible by SF_VEC_SIZE" + residual_out = paddle.empty_like(x=residual_in) + norm_out = paddle.empty_like(x=residual_in) + quant_out = paddle.empty( + shape=message_size // 4, dtype=dtype ) + scale_out = None + assert ( + HIDDEN_SIZE % SF_VEC_SIZE == 0 + ), "HIDDEN_SIZE must be divisible by SF_VEC_SIZE" if ( swizzled_layout_code == comm.QuantizationSFLayout.SWIZZLED_128x4 @@ -102,113 +79,82 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): token_num, HIDDEN_SIZE // SF_VEC_SIZE ) ) - scale_out = torch.empty( - padded_message_size, dtype=dtype, device=device + scale_out = paddle.empty( + shape=padded_message_size, dtype=dtype ) else: - scale_out = torch.empty( - message_size // SF_VEC_SIZE, - dtype=dtype, - device=device, + scale_out = paddle.empty( + shape=message_size // SF_VEC_SIZE, dtype=dtype ) - - rms_gamma = torch.randn( - HIDDEN_SIZE, dtype=dtype, device=device - ) + rms_gamma = paddle.randn(shape=HIDDEN_SIZE, dtype=dtype) scale_factor = ( - torch.rand(1, dtype=torch.float32, device=device) + paddle.rand(shape=[1], dtype="float32") * (SCALE_FACTOR_RANGE[1] - SCALE_FACTOR_RANGE[0]) + SCALE_FACTOR_RANGE[0] ) - rms_eps = 1e-3 + rms_eps = 0.001 scale_factor_float = scale_factor.item() - - # init moe params - # [device_num_expert, m] - moe_reduction_scale_input = torch.randn( - active_expert_num * token_num, - dtype=torch.float32, - device=device, + moe_reduction_scale_input = paddle.randn( + shape=active_expert_num * token_num, dtype="float32" ) - moe_reduction_scale_input_clone = ( moe_reduction_scale_input.clone() ) - - # [device_num_expert, m, 7168] - moe_reduction_active_experts_token_input = torch.randn( - active_expert_num * message_size, - 
dtype=dtype, - device=device, - ) - moe_reduction_active_experts_token_input_clone = ( - moe_reduction_active_experts_token_input.clone() + moe_reduction_active_experts_token_input = paddle.randn( + shape=active_expert_num * message_size, dtype=dtype ) - # [m, 7168] - moe_reduction_token_input = torch.randn( - message_size, dtype=dtype, device=device + ( + moe_reduction_active_experts_token_input_clone + ) = moe_reduction_active_experts_token_input.clone() + moe_reduction_token_input = paddle.randn( + shape=message_size, dtype=dtype ) moe_reduction_token_input_clone = ( moe_reduction_token_input.clone() ) - - # == Calculate reference output == - # 1. MoE Reduction moe_expert_out = ( moe_reduction_active_experts_token_input_clone.view( active_expert_num, token_num, HIDDEN_SIZE - ).to(torch.float32) + ).to("float32") ) moe_scales = moe_reduction_scale_input_clone.view( active_expert_num, token_num - ).to(torch.float32) - moe_scales = moe_scales.unsqueeze( - 2 - ) # [active_expert_num, token_num, 1] + ).to("float32") + moe_scales = moe_scales.unsqueeze(axis=2) scaled_expert_out = moe_expert_out * moe_scales.to( - torch.float32 - ) # [active_expert_num, token_num, HIDDEN_SIZE] - reduced_expert_out = torch.sum( - scaled_expert_out, dim=0 - ) # [token_num, HIDDEN_SIZE] - - # 2. Add FC2 output + "float32" + ) + reduced_expert_out = paddle.sum(x=scaled_expert_out, axis=0) moe_out_ref = ( reduced_expert_out + moe_reduction_token_input_clone.view( token_num, HIDDEN_SIZE - ).to(torch.float32) - ) # [token_num, HIDDEN_SIZE] - - # 3. All-Reduce + ).to("float32") + ) moe_allreduce_ref = moe_out_ref.clone().to(dtype) - dist.all_reduce(moe_allreduce_ref, group=group) - moe_allreduce_ref = moe_allreduce_ref.to(torch.float32) - - # 4. Fused Ops + paddle.distributed.all_reduce( + tensor=moe_allreduce_ref, group=group + ) + moe_allreduce_ref = moe_allreduce_ref.to("float32") ref_residual_out = ( moe_allreduce_ref + residual_in_clone.view(token_num, HIDDEN_SIZE).to( - torch.float32 + "float32" ) ) - variance = ( - ref_residual_out.to(torch.float32) - .pow(2) - .mean(dim=-1, keepdim=True) + ref_residual_out.to("float32") + .pow(y=2) + .mean(axis=-1, keepdim=True) ) - hidden_states = ref_residual_out * torch.rsqrt( - variance + rms_eps + hidden_states = ref_residual_out * paddle.rsqrt( + x=variance + rms_eps ) - ref_norm_out = rms_gamma.to(torch.float32) * hidden_states - - # 5. 
Run kernel - # warmup - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - for _ in range(3): # Multiple warmup iterations + ref_norm_out = rms_gamma.to("float32") * hidden_states + s = paddle.device.Stream() + s.wait_stream(paddle.device.current_stream()) + with paddle.device.stream_guard(stream=s): + for _ in range(3): comm.trtllm_moe_allreduce_fusion( world_size=world_size, world_rank=rank, @@ -231,13 +177,11 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): quant_out=quant_out, scale_out=scale_out, ) - torch.cuda.current_stream().wait_stream(s) - torch.cuda.synchronize() # Ensure warmup is complete - - # capture - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - for _ in range(3): # Multiple iterations in graph + paddle.device.current_stream().wait_stream(s) + paddle.device.synchronize() +>>>>>> g = torch.cuda.CUDAGraph() +>>>>>> with torch.cuda.graph(g): + for _ in range(3): comm.trtllm_moe_allreduce_fusion( world_size=world_size, world_rank=rank, @@ -260,42 +204,33 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): quant_out=quant_out, scale_out=scale_out, ) - - # replay g.replay() - - # match shape moe_allreduce_out = moe_allreduce_out.view( token_num, HIDDEN_SIZE ) residual_out = residual_out.view(token_num, HIDDEN_SIZE) norm_out = norm_out.view(token_num, HIDDEN_SIZE) - - torch.cuda.synchronize() - - # 6. Check correctness - tolerance = 8e-2 if dtype == torch.float16 else 8e-1 - # 6.1 Check allreduce_out - if not torch.allclose( - moe_allreduce_out.to(torch.float32), - moe_allreduce_ref, + paddle.device.synchronize() + tolerance = 0.08 if dtype == "float16" else 0.8 + if not paddle.allclose( + x=moe_allreduce_out.to("float32"), + y=moe_allreduce_ref, atol=tolerance, - rtol=1e-2, - ): + rtol=0.01, + ).item(): test_passed = False print(f"Rank {rank} moe_allreduce_out mismatch") print(f"moe_allreduce_out: {moe_allreduce_out}") print(f"moe_allreduce_ref: {moe_allreduce_ref}") - # Print max diff elements for allreduce_out - max_diff = torch.max( - torch.abs( - moe_allreduce_out.to(torch.float32) + max_diff = paddle.max( + x=paddle.abs( + x=moe_allreduce_out.to("float32") - moe_allreduce_ref ) ) - max_diff_idx = torch.argmax( - torch.abs( - moe_allreduce_out.to(torch.float32) + max_diff_idx = paddle.argmax( + x=paddle.abs( + x=moe_allreduce_out.to("float32") - moe_allreduce_ref ) ) @@ -311,36 +246,30 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): print( f"Rank {rank} moe_allreduce_out ref value at max diff: {moe_allreduce_ref.view(-1)[max_diff_idx]}" ) - - torch.testing.assert_close( - moe_allreduce_out.to(torch.float32), - moe_allreduce_ref, + assert paddle.allclose( + x=moe_allreduce_out.to("float32"), + y=moe_allreduce_ref, atol=tolerance, - rtol=1e-2, - ) - - # 6.2 Check residual_out - if not torch.allclose( - residual_out.to(torch.float32), - ref_residual_out, + rtol=0.01, + ).item(), "" + if not paddle.allclose( + x=residual_out.to("float32"), + y=ref_residual_out, atol=tolerance, - rtol=1e-2, - ): + rtol=0.01, + ).item(): test_passed = False print(f"Rank {rank} residual_out mismatch") print(f"residual_out: {residual_out}") print(f"ref_residual_out: {ref_residual_out}") - # Print max diff elements for residual_out - max_diff = torch.max( - torch.abs( - residual_out.to(torch.float32) - - ref_residual_out + max_diff = paddle.max( + x=paddle.abs( + x=residual_out.to("float32") - ref_residual_out ) ) - max_diff_idx = torch.argmax( - 
torch.abs( - residual_out.to(torch.float32) - - ref_residual_out + max_diff_idx = paddle.argmax( + x=paddle.abs( + x=residual_out.to("float32") - ref_residual_out ) ) print(f"Rank {rank} residual_out max diff: {max_diff}") @@ -353,29 +282,31 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): print( f"Rank {rank} residual_out ref value at max diff: {ref_residual_out.view(-1)[max_diff_idx]}" ) - torch.testing.assert_close( - residual_out.to(torch.float32), - ref_residual_out, + assert paddle.allclose( + x=residual_out.to("float32"), + y=ref_residual_out, atol=tolerance, - rtol=1e-2, - ) - # 6.3 Check norm_out - if not torch.allclose( - norm_out.to(torch.float32), - ref_norm_out, + rtol=0.01, + ).item(), "" + if not paddle.allclose( + x=norm_out.to("float32"), + y=ref_norm_out, atol=tolerance, - rtol=1e-2, - ): + rtol=0.01, + ).item(): test_passed = False print(f"Rank {rank} norm_out mismatch") print(f"norm_out: {norm_out}") print(f"ref_norm_out: {ref_norm_out}") - # Print max diff elements for norm_out - max_diff = torch.max( - torch.abs(norm_out.to(torch.float32) - ref_norm_out) + max_diff = paddle.max( + x=paddle.abs( + x=norm_out.to("float32") - ref_norm_out + ) ) - max_diff_idx = torch.argmax( - torch.abs(norm_out.to(torch.float32) - ref_norm_out) + max_diff_idx = paddle.argmax( + x=paddle.abs( + x=norm_out.to("float32") - ref_norm_out + ) ) print(f"Rank {rank} norm_out max diff: {max_diff}") print( @@ -387,17 +318,13 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): print( f"Rank {rank} norm_out ref value at max diff: {ref_norm_out.view(-1)[max_diff_idx]}" ) - - torch.testing.assert_close( - norm_out.to(torch.float32), - ref_norm_out, + assert paddle.allclose( + x=norm_out.to("float32"), + y=ref_norm_out, atol=tolerance, - rtol=1e-2, - ) - # 6.4 Check quant_out - # todo - - dist.barrier(group=group) + rtol=0.01, + ).item(), "" + paddle.distributed.barrier(group=group) if test_passed: print( f"test RANK {rank}: token{token_num}-expert{active_expert_num}-tp{world_size}-{dtype}-layout{swizzled_layout_code}-pdl{launch_with_pdl} passed" @@ -407,11 +334,9 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port): f"test RANK {rank}: token{token_num}-expert{active_expert_num}-tp{world_size}-{dtype}-layout{swizzled_layout_code}-pdl{launch_with_pdl} failed" ) finally: - dist.barrier(group=group) - + paddle.distributed.barrier(group=group) comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group) - - dist.destroy_process_group(group=group) +>>>>>> torch.distributed.destroy_process_group(group=group) def get_open_port() -> int: @@ -426,46 +351,39 @@ def get_open_port() -> int: def multi_process_parallel( - world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = () + world_size: int, dtype: paddle.dtype, test_target: Any, target_args: tuple = () ) -> None: mp.set_start_method("spawn", force=True) - procs = [] distributed_init_port = get_open_port() for i in range(world_size): proc_args = (world_size, i, dtype, distributed_init_port) + target_args proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") - proc.start() + """Not Support auto convert *.start, please judge whether it is Pytorch API and convert by yourself""" +>>>>>> proc.start() procs.append(proc) - for i in range(world_size): procs[i].join() - assert procs[i].exitcode == 0, ( - f"Process {i} failed with exit code {procs[i].exitcode}" - ) + assert ( + procs[i].exitcode == 0 + ), f"Process {i} failed with 
exit code {procs[i].exitcode}" @pytest.mark.parametrize("world_size", [2, 4]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) def test_trtllm_moe_allreduce_fusion(world_size, dtype): np.random.seed(42) - torch.manual_seed(42) - torch.cuda.manual_seed_all(42) - available_gpus = torch.cuda.device_count() + paddle.seed(seed=42) + paddle.seed(seed=42) + available_gpus = paddle.device.cuda.device_count() if world_size > available_gpus: raise ValueError( f"world_size {world_size} is greater than available_gpus {available_gpus}" ) print(f"Running test for world_size={world_size}") - - multi_process_parallel( - world_size, - dtype, - _run_correctness_worker, - target_args=(), - ) + multi_process_parallel(world_size, dtype, _run_correctness_worker, target_args=()) print(f"moe allreduce fusion tp = {world_size}: OK") if __name__ == "__main__": - test_trtllm_moe_allreduce_fusion(2, torch.float16) + test_trtllm_moe_allreduce_fusion(2, "float16") diff --git a/tests/test_trtllm_moe_allreduce_fusion_finalize.py b/tests/test_trtllm_moe_allreduce_fusion_finalize.py index b59d471103..431a4a0e5a 100644 --- a/tests/test_trtllm_moe_allreduce_fusion_finalize.py +++ b/tests/test_trtllm_moe_allreduce_fusion_finalize.py @@ -1,25 +1,23 @@ +import sys + + import multiprocessing as mp import socket from typing import Any import numpy as np +import paddle import pytest -import torch -import torch.distributed as dist +from flashinfer.paddle_utils import * import flashinfer.comm as comm -# todo(Yingyi): add benchmark and quant test - -# Usage: test var kOneShotMaxTokenNum = 128 MAX_TOKEN_NUM = 2048 HIDDEN_SIZE = 7168 MAX_EXPERT_NUM = 16 SF_VEC_SIZE = 16 - -# temp var -SCALE_FACTOR_RANGE = (-1, 1) +SCALE_FACTOR_RANGE = -1, 1 def _run_correctness_worker( @@ -33,52 +31,40 @@ def _run_correctness_worker( expanded_idx_to_permuted_idx, residual, ): - def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): - y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + def rms_norm(x: paddle.Tensor, weight: paddle.Tensor = None, eps: float = 1e-06): + y = x * paddle.rsqrt(x=x.pow(y=2).mean(axis=-1, keepdim=True) + eps) if weight is not None: y = y * weight return y - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) + device = device2str(f"cuda:{rank}") + paddle.device.set_device(device=device2str(device)) distributed_init_method = f"tcp://localhost:{distributed_init_port}" - dist.init_process_group( - backend="nccl", - init_method=distributed_init_method, - rank=rank, - world_size=world_size, - ) - group = dist.group.WORLD - + paddle.distributed.init_parallel_env() +>>>>>> group = torch.distributed.group.WORLD try: - device = torch.device(f"cuda:{rank}") + device = device2str(f"cuda:{rank}") seq_lens = [16] top_k = 8 - eps = 1e-5 - + eps = 1e-05 launch_with_pdls = [True, False] - - # create workspace for moe allreduce fusion - ipc_handles, workspace_tensor = ( - comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( - rank, world_size, MAX_TOKEN_NUM, HIDDEN_SIZE, group=group - ) + ( + ipc_handles, + workspace_tensor, + ) = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, world_size, MAX_TOKEN_NUM, HIDDEN_SIZE, group=group ) - test_loop = 5 - for seq_len in seq_lens: for launch_with_pdl in launch_with_pdls: - dist.barrier(group=group) + paddle.distributed.barrier(group=group) test_passed = True print( f"test RANK {rank}: 
seq_len{seq_len}-topk{top_k}-tp{world_size}-{dtype}-pdl{launch_with_pdl} start" ) - dist.barrier(group=group) - torch.cuda.synchronize() + paddle.distributed.barrier(group=group) + paddle.device.synchronize() for _ in range(test_loop): - # == Generate input == - # move to local device shared_expert_output = shared_expert_output.to(device) fc2_output = fc2_output.to(device) scale = scale.to(device) @@ -86,22 +72,14 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): device ) residual = residual.to(device) - - # make clone fc2_output_clone = fc2_output.clone() - norm_weight = torch.randn( - (HIDDEN_SIZE,), dtype=dtype, device=device - ) - - norm_out = torch.empty_like(residual) - residual_out = torch.empty_like(residual) - - # == Run kernel == - torch.cuda.synchronize() - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - # warmup - with torch.cuda.stream(s): + norm_weight = paddle.randn(shape=(HIDDEN_SIZE,), dtype=dtype) + norm_out = paddle.empty_like(x=residual) + residual_out = paddle.empty_like(x=residual) + paddle.device.synchronize() + s = paddle.device.Stream() + s.wait_stream(paddle.device.current_stream()) + with paddle.device.stream_guard(stream=s): for _ in range(test_loop): comm.trtllm_moe_finalize_allreduce_fusion( allreduce_in=fc2_output, @@ -118,11 +96,9 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): norm_out=norm_out, residual_out=residual_out, ) - torch.cuda.current_stream().wait_stream(s) - - # capture - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): + paddle.device.current_stream().wait_stream(s) +>>>>>> g = torch.cuda.CUDAGraph() +>>>>>> with torch.cuda.graph(g): for _ in range(test_loop): comm.trtllm_moe_finalize_allreduce_fusion( allreduce_in=fc2_output, @@ -139,58 +115,49 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): norm_out=norm_out, residual_out=residual_out, ) - - # replay g.replay() - - torch.cuda.synchronize() - - # == Calculate reference output == - expert_reduction = torch.sum( - fc2_output_clone[expanded_idx_to_permuted_idx] - * scale.unsqueeze(-1), - dim=1, + paddle.device.synchronize() + expert_reduction = paddle.sum( + x=fc2_output_clone[expanded_idx_to_permuted_idx] + * scale.unsqueeze(axis=-1), + axis=1, ) - torch_before_residual = ( expert_reduction + shared_expert_output ) * world_size torch_residual = torch_before_residual + residual - torch_residual = torch_residual.to(torch.float32) + torch_residual = torch_residual.to("float32") torch_output_hidden_states = rms_norm( torch_residual, norm_weight, eps ).to(dtype) - - # == Check correctness == - if not torch.allclose( - residual_out.to(torch.float32), - torch_residual.to(torch.float32), + if not paddle.allclose( + x=residual_out.to("float32"), + y=torch_residual.to("float32"), rtol=0.2, atol=0.2, - ): + ).item(): test_passed = False print(f"Rank {rank} residual_out mismatch") print(f"residual_out: {residual_out}") print(f"torch_residual: {torch_residual}") print( - f"max diff: {torch.max(torch.abs(residual_out.to(torch.float32) - torch_residual.to(torch.float32)))}" + f"max diff: {paddle.max(x=paddle.abs(x=residual_out.to('float32') - torch_residual.to('float32')))}" ) print( - f"max diff idx: {torch.argmax(torch.abs(residual_out.to(torch.float32) - torch_residual.to(torch.float32)))}" + f"max diff idx: {paddle.argmax(x=paddle.abs(x=residual_out.to('float32') - torch_residual.to('float32')))}" ) print( - f"max diff value: 
{residual_out.to(torch.float32).view(-1)[torch.argmax(torch.abs(residual_out.to(torch.float32) - torch_residual.to(torch.float32)))]}" + f"max diff value: {residual_out.to('float32').view(-1)[paddle.argmax(x=paddle.abs(x=residual_out.to('float32') - torch_residual.to('float32')))]}" ) print( - f"max diff ref value: {torch_residual.to(torch.float32).view(-1)[torch.argmax(torch.abs(residual_out.to(torch.float32) - torch_residual.to(torch.float32)))]}" + f"max diff ref value: {torch_residual.to('float32').view(-1)[paddle.argmax(x=paddle.abs(x=residual_out.to('float32') - torch_residual.to('float32')))]}" ) - - if not torch.allclose( - norm_out.to(torch.float32), - torch_output_hidden_states.to(torch.float32), + if not paddle.allclose( + x=norm_out.to("float32"), + y=torch_output_hidden_states.to("float32"), rtol=0.2, atol=0.2, - ): + ).item(): test_passed = False print(f"Rank {rank} norm_out mismatch") print(f"norm_out: {norm_out}") @@ -198,32 +165,30 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): f"torch_output_hidden_states: {torch_output_hidden_states}" ) print( - f"max diff: {torch.max(torch.abs(norm_out.to(torch.float32) - torch_output_hidden_states.to(torch.float32)))}" + f"max diff: {paddle.max(x=paddle.abs(x=norm_out.to('float32') - torch_output_hidden_states.to('float32')))}" ) print( - f"max diff idx: {torch.argmax(torch.abs(norm_out.to(torch.float32) - torch_output_hidden_states.to(torch.float32)))}" + f"max diff idx: {paddle.argmax(x=paddle.abs(x=norm_out.to('float32') - torch_output_hidden_states.to('float32')))}" ) print( - f"max diff value: {norm_out.to(torch.float32).view(-1)[torch.argmax(torch.abs(norm_out.to(torch.float32) - torch_output_hidden_states.to(torch.float32)))]}" + f"max diff value: {norm_out.to('float32').view(-1)[paddle.argmax(x=paddle.abs(x=norm_out.to('float32') - torch_output_hidden_states.to('float32')))]}" ) print( - f"max diff ref value: {torch_output_hidden_states.to(torch.float32).view(-1)[torch.argmax(torch.abs(norm_out.to(torch.float32) - torch_output_hidden_states.to(torch.float32)))]}" + f"max diff ref value: {torch_output_hidden_states.to('float32').view(-1)[paddle.argmax(x=paddle.abs(x=norm_out.to('float32') - torch_output_hidden_states.to('float32')))]}" ) - - torch.testing.assert_close( - residual_out.to(torch.float32), - torch_residual.to(torch.float32), + assert paddle.allclose( + x=residual_out.to("float32"), + y=torch_residual.to("float32"), rtol=0.2, atol=0.2, - ) - torch.testing.assert_close( - norm_out.to(torch.float32), - torch_output_hidden_states.to(torch.float32), + ).item(), "" + assert paddle.allclose( + x=norm_out.to("float32"), + y=torch_output_hidden_states.to("float32"), rtol=0.2, atol=0.2, - ) - - dist.barrier(group=group) + ).item(), "" + paddle.distributed.barrier(group=group) if test_passed: print( f"test RANK {rank}: seq_len{seq_len}-topk{top_k}-tp{world_size}-{dtype}-pdl{launch_with_pdl} passed" @@ -233,11 +198,9 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6): f"test RANK {rank}: seq_len{seq_len}-topk{top_k}-tp{world_size}-{dtype}-pdl{launch_with_pdl} failed" ) finally: - dist.barrier(group=group) - + paddle.distributed.barrier(group=group) comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group) - - dist.destroy_process_group(group=group) +>>>>>> torch.distributed.destroy_process_group(group=group) def get_open_port() -> int: @@ -252,51 +215,48 @@ def get_open_port() -> int: def multi_process_parallel( - world_size: int, dtype: 
torch.dtype, test_target: Any, target_args: tuple = () + world_size: int, dtype: paddle.dtype, test_target: Any, target_args: tuple = () ) -> None: mp.set_start_method("spawn", force=True) - procs = [] distributed_init_port = get_open_port() for i in range(world_size): proc_args = (world_size, i, dtype, distributed_init_port) + target_args proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") - proc.start() + """Not Support auto convert *.start, please judge whether it is Pytorch API and convert by yourself""" +>>>>>> proc.start() procs.append(proc) - for i in range(world_size): procs[i].join() - assert procs[i].exitcode == 0, ( - f"Process {i} failed with exit code {procs[i].exitcode}" - ) + assert ( + procs[i].exitcode == 0 + ), f"Process {i} failed with exit code {procs[i].exitcode}" @pytest.mark.parametrize("world_size", [2, 4]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) def test_trtllm_moe_finalize_allreduce_fusion(world_size, dtype): np.random.seed(42) - torch.manual_seed(42) - torch.cuda.manual_seed_all(42) - available_gpus = torch.cuda.device_count() + paddle.seed(seed=42) + paddle.seed(seed=42) + available_gpus = paddle.device.cuda.device_count() if world_size > available_gpus: raise ValueError( f"world_size {world_size} is greater than available_gpus {available_gpus}" ) print(f"Running test for world_size={world_size}") - - # generate shared random input tensor across all ranks seq_len = 16 hidden_size = 7168 top_k = 8 - - shared_expert_output = torch.randn((seq_len, hidden_size), dtype=dtype) - fc2_output = torch.randn((seq_len * top_k, hidden_size), dtype=dtype) - scale = torch.randn((seq_len, top_k), dtype=dtype) - expanded_idx_to_permuted_idx = torch.randint( - 0, seq_len * top_k, (seq_len, top_k), dtype=torch.int32 + shared_expert_output = paddle.randn(shape=(seq_len, hidden_size), dtype=dtype) + fc2_output = paddle.randn(shape=(seq_len * top_k, hidden_size), dtype=dtype) + scale = paddle.randn(shape=(seq_len, top_k), dtype=dtype) + expanded_idx_to_permuted_idx = paddle.randint( + low=0, high=seq_len * top_k, shape=(seq_len, top_k), dtype="int32" + ) + residual = paddle.randn( + shape=shared_expert_output.shape, dtype=shared_expert_output.dtype ) - residual = torch.randn_like(shared_expert_output) - multi_process_parallel( world_size, dtype, diff --git a/tests/test_vllm_custom_allreduce.py b/tests/test_vllm_custom_allreduce.py index a94975caa6..8f3f43d131 100644 --- a/tests/test_vllm_custom_allreduce.py +++ b/tests/test_vllm_custom_allreduce.py @@ -1,48 +1,36 @@ -# flashinfer: adapted from sglang + vllm -# refer to sgl-kernel/tests/test_custom_allreduce.py from sglang +import sys + import logging import multiprocessing as mp import socket from typing import Any +import paddle import pytest -import torch -import torch.distributed as dist +from flashinfer.paddle_utils import * import flashinfer.comm as comm -# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py - - logger = logging.getLogger(__name__) def _run_correctness_worker(world_size, rank, distributed_init_port): - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) + device = device2str(f"cuda:{rank}") + paddle.device.set_device(device=device2str(device)) distributed_init_method = f"tcp://localhost:{distributed_init_port}" - dist.init_process_group( - backend="nccl", - init_method=distributed_init_method, - rank=rank, - 
world_size=world_size, - ) - group = dist.group.WORLD - + paddle.distributed.init_parallel_env() +>>>>>> group = torch.distributed.group.WORLD try: - device = torch.device(f"cuda:{rank}") + device = device2str(f"cuda:{rank}") max_size = 8192 * 1024 meta_ptrs = comm.create_shared_buffer( comm.vllm_meta_size() + max_size, group=group ) - - rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device) + rank_data = paddle.empty(shape=8 * 1024 * 1024, dtype="uint8") buffer_ptrs = comm.create_shared_buffer(max_size, group=group) - custom_ptr = comm.vllm_init_custom_ar(meta_ptrs, rank_data, rank, True) comm.vllm_register_buffer(custom_ptr, buffer_ptrs) - test_sizes = [ 512, 2560, @@ -56,41 +44,31 @@ def _run_correctness_worker(world_size, rank, distributed_init_port): 2097152, ] num_ctas = [1, 2, 4, 8, 16, 32, 36] - dtypes = [torch.float32, torch.float16, torch.bfloat16] + dtypes = ["float32", "float16", "bfloat16"] test_loop = 10 - for test_size in test_sizes: for num_cta in num_ctas: for dtype in dtypes: for _ in range(test_loop): - inp1 = torch.randint( - 1, 16, (test_size,), dtype=dtype, device=device + inp1 = paddle.randint( + low=1, high=16, shape=(test_size,), dtype=dtype ) inp1_ref = inp1.clone() - out1 = torch.empty_like(inp1) - + out1 = paddle.empty_like(x=inp1) comm.vllm_all_reduce( - custom_ptr, - inp1, - out1, - buffer_ptrs[rank], - max_size, - num_cta, + custom_ptr, inp1, out1, buffer_ptrs[rank], max_size, num_cta ) - - dist.all_reduce(inp1_ref, group=group) - - torch.testing.assert_close(out1, inp1_ref) + paddle.distributed.all_reduce(tensor=inp1_ref, group=group) + assert paddle.allclose(x=out1, y=inp1_ref).item(), "" finally: - dist.barrier(group=group) + paddle.distributed.barrier(group=group) if custom_ptr is not None: comm.vllm_dispose(custom_ptr) if buffer_ptrs: comm.free_shared_buffer(buffer_ptrs, group) if meta_ptrs: comm.free_shared_buffer(meta_ptrs, group) - - dist.destroy_process_group(group=group) +>>>>>> torch.distributed.destroy_process_group(group=group) def get_open_port() -> int: @@ -108,33 +86,28 @@ def multi_process_parallel( world_size: int, test_target: Any, target_args: tuple = () ) -> None: mp.set_start_method("spawn", force=True) - procs = [] distributed_init_port = get_open_port() for i in range(world_size): proc_args = (world_size, i, distributed_init_port) + target_args proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") - proc.start() + """Not Support auto convert *.start, please judge whether it is Pytorch API and convert by yourself""" +>>>>>> proc.start() procs.append(proc) - for i in range(world_size): procs[i].join() - assert procs[i].exitcode == 0, ( - f"Process {i} failed with exit code {procs[i].exitcode}" - ) + assert ( + procs[i].exitcode == 0 + ), f"Process {i} failed with exit code {procs[i].exitcode}" @pytest.mark.parametrize("world_size", [2, 4]) def test_vllm_custom_allreduce(world_size): - available_gpus = torch.cuda.device_count() + available_gpus = paddle.device.cuda.device_count() if world_size > available_gpus: raise ValueError( f"world_size {world_size} is greater than available_gpus {available_gpus}" ) print(f"Running test for world_size={world_size}") - multi_process_parallel( - world_size, - _run_correctness_worker, - target_args=(), - ) + multi_process_parallel(world_size, _run_correctness_worker, target_args=()) print(f"custom allreduce tp = {world_size}: OK") diff --git a/tests/utils_fp4.py b/tests/utils_fp4.py index f5aaa8e7ec..fa262ab2d1 100644 --- a/tests/utils_fp4.py +++ 
b/tests/utils_fp4.py @@ -1,18 +1,12 @@ -import torch +import sys + + +import paddle +from flashinfer.paddle_utils import * import flashinfer.utils as utils FLOAT4_E2M1_MAX = 6.0 - -# E2M1 to float -# 0111 -> 6 -# 0110 -> 4 -# 0101 -> 3 -# 0100 -> 2 -# 0011 -> 1.5 -# 0010 -> 1 -# 0001 -> 0.5 -# 0000 -> 0 E2M1_TO_FLOAT32 = [ 0.0, 0.5, @@ -34,21 +28,18 @@ def cast_from_fp4(x): - # The fp4 values are packed in uint8 as [v_1st | v_2nd] - v_2nd = x & 0xF - v_1st = (x >> 4) & 0xF - c = torch.stack((v_2nd, v_1st), dim=-1) - new_shape = c.shape[:-2] + ( - c.shape[-2] * c.shape[-1], - ) # fuse the dim added by stack - lookup_table = torch.tensor(E2M1_TO_FLOAT32, device=c.device) - out = lookup_table[c.to(torch.long)].reshape(new_shape).to(torch.float32) + v_2nd = x & 15 + v_1st = x >> 4 & 15 + c = paddle.stack(x=(v_2nd, v_1st), axis=-1) + new_shape = tuple(c.shape)[:-2] + (tuple(c.shape)[-2] * tuple(c.shape)[-1],) + lookup_table = paddle.to_tensor(data=E2M1_TO_FLOAT32, place=c.place) + out = lookup_table[c.to("int64")].reshape(new_shape).to("float32") return out def cast_to_fp4(x): - sign = torch.sign(x) - x = torch.abs(x) + sign = paddle.sign(x=x) + x = paddle.abs(x=x) x[(x >= 0.0) & (x <= 0.25)] = 0.0 x[(x > 0.25) & (x < 0.75)] = 0.5 x[(x >= 0.75) & (x <= 1.25)] = 1.0 @@ -61,8 +52,10 @@ def cast_to_fp4(x): def get_reciprocal(x): - if isinstance(x, torch.Tensor): - return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x) + if isinstance(x, paddle.Tensor): + return paddle.where( + condition=x == 0, x=paddle.to_tensor(data=0.0, dtype=x.dtype), y=1.0 / x + ) elif isinstance(x, (float, int)): return 0.0 if x == 0 else 1.0 / x else: @@ -70,31 +63,31 @@ def get_reciprocal(x): def ref_fp4_quant(x, global_scale, block_size, sf_use_ue8m0=False): - assert isinstance(global_scale, (float, int)) or global_scale.dtype == torch.float32 - - sliced_shape = x.shape[:-1] + (x.shape[-1] // block_size, block_size) - sliced_x = torch.reshape(x, sliced_shape) - vec_max = torch.max(torch.abs(sliced_x), dim=-1, keepdim=True)[0].to(torch.float32) + assert isinstance(global_scale, (float, int)) or global_scale.dtype == "float32" + sliced_shape = tuple(x.shape)[:-1] + (tuple(x.shape)[-1] // block_size, block_size) + sliced_x = paddle.reshape(x=x, shape=sliced_shape) + vec_max = ( + paddle.max(keepdim=True, x=paddle.abs(x=sliced_x), axis=-1), + paddle.argmax(keepdim=True, x=paddle.abs(x=sliced_x), axis=-1), + )[0].to("float32") scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) if sf_use_ue8m0: - scale = (scale.view(torch.int32) + 0x007FFFFF) & 0x7F800000 - scale = scale.view(torch.float32) + scale = scale.view("int32") + 8388607 & 2139095040 + scale = scale.view("float32") else: - scale = scale.to(torch.float8_e4m3fn).to(torch.float32) + scale = scale.to(paddle.float8_e4m3fn).to("float32") output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) - - scaled_x = sliced_x.to(torch.float32) * output_scale - clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(x.shape) - return cast_to_fp4(clipped_x), scale.squeeze(-1) + scaled_x = sliced_x.to("float32") * output_scale + clipped_x = paddle.clip(x=scaled_x, min=-6.0, max=6.0).reshape(tuple(x.shape)) + return cast_to_fp4(clipped_x), scale.squeeze(axis=-1) def recover_swizzled_scales(scale, m, n, block_size, sf_start_index=0): - assert sf_start_index + m <= scale.shape[0] - full_m = scale.shape[0] + assert sf_start_index + m <= tuple(scale.shape)[0] + full_m = tuple(scale.shape)[0] scale_n = n // block_size rounded_n = utils.round_up(scale_n, 
4) - # Recover the swizzled scaling factor to linear layout - tmp = torch.reshape(scale, (1, full_m // 128, rounded_n // 4, 32, 4, 4)) - tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) - result = torch.reshape(tmp, (full_m, rounded_n)).to(torch.float32) + tmp = paddle.reshape(x=scale, shape=(1, full_m // 128, rounded_n // 4, 32, 4, 4)) + tmp = paddle.transpose(x=tmp, perm=(0, 1, 4, 3, 2, 5)) + result = paddle.reshape(x=tmp, shape=(full_m, rounded_n)).to("float32") return result[sf_start_index : sf_start_index + m, :scale_n]
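The packed-FP4 helpers ported in tests/utils_fp4.py above are easiest to sanity-check with a tiny, framework-free example. The sketch below reproduces only the nibble unpacking and table lookup that cast_from_fp4 performs; decode_fp4_byte is an illustrative helper (not part of the file), and the negative half of the lookup table is assumed to mirror the positive codes under the sign bit, as in the standard E2M1 format, since only the positive codes are enumerated in the file's original comment.

E2M1_CODES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]   # codes 0b0000..0b0111, per the original comment
E2M1_TABLE = E2M1_CODES + [-v for v in E2M1_CODES]       # assumed sign-mirrored negative half

def decode_fp4_byte(packed: int) -> tuple:
    """Illustrative helper: decode one uint8 that packs two E2M1 values as [v_1st | v_2nd]."""
    v_2nd = packed & 0xF          # low nibble holds the second value
    v_1st = (packed >> 4) & 0xF   # high nibble holds the first value
    return E2M1_TABLE[v_2nd], E2M1_TABLE[v_1st]

# 0x71 packs code 0b0111 (6.0) in the high nibble and 0b0001 (0.5) in the low nibble,
# so it decodes to (0.5, 6.0), matching the (v_2nd, v_1st) stacking order in cast_from_fp4.
assert decode_fp4_byte(0x71) == (0.5, 6.0)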