
[HiCache] refactor page_first_direct io kernel #18113

Open
huangtingwei9988 wants to merge 13 commits into sgl-project:main from antgroup:refactor_direct_io_backend

Conversation

@huangtingwei9988
Collaborator

@huangtingwei9988 huangtingwei9988 commented Feb 2, 2026

Motivation

The implementation of transfer_kv_page_first_direct_impl performs a large number of item(), select(), and slice() computations, which significantly hurts host-side performance.

Furthermore, because CPU-side execution is so slow, forward computation cannot overlap with cache loading, which significantly hurts TTFT when a cache hit occurs.

Modifications

Calling the cudaMemcpyAsync interface directly and passing raw pointers avoids the item(), select(), and slice() computations, which significantly improves performance.
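To make the idea concrete, here is a minimal pure-Python sketch of the host-side offset arithmetic such a refactor relies on. The function name, argument names, and layout assumptions (page_first_direct source shaped (page_num, layer_num, page_size, head_num, head_dim), layer_first destination slice shaped (size, head_num, head_dim)) are illustrative, not the PR's actual code:

```python
def page_copy_args(src_page: int, dst_token_start: int, layer_id: int,
                   layer_num: int, page_size: int, head_num: int,
                   head_dim: int, itemsize: int):
    """Return (src_offset_bytes, dst_offset_bytes, nbytes) for copying one
    page of one layer from a page_first_direct buffer into a layer_first
    destination slice. Plain integer arithmetic -- no item()/select()/slice().
    """
    token_bytes = head_num * head_dim * itemsize   # bytes of one token's K (or V)
    nbytes = page_size * token_bytes               # one page per layer is contiguous
    src_off = ((src_page * layer_num + layer_id) * page_size) * token_bytes
    dst_off = dst_token_start * token_bytes
    return src_off, dst_off, nbytes

# With the benchmark's defaults (layer_num=32, page_size=64, head_num=8,
# head_dim=128, fp16 itemsize=2), copying source page 3 to destination
# tokens starting at 192:
src_off, dst_off, nbytes = page_copy_args(3, 192, 0, 32, 64, 8, 128, 2)
```

Each such tuple would feed one cudaMemcpyAsync(dst_base + dst_off, src_base + src_off, nbytes, cudaMemcpyHostToDevice, stream) call on the C++ side.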

Furthermore, because the CPU-side submission rate now exceeds the GPU-side execution rate, cache loading can overlap with forward execution.
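A simplified back-of-the-envelope model (my own, not part of the PR) shows why the faster submit path matters: submissions serialize on the CPU, executions serialize on the GPU stream, and the scheduler is blocked for the whole submission phase. Cutting per-copy submit time frees the CPU sooner to run forward passes:

```python
def pipeline_times_ms(n: int, submit_ms: float, exec_ms: float):
    """Model n async copies: copy i cannot start on the GPU before it is
    submitted by the CPU. Returns (cpu_blocking_time, wall_time) in ms."""
    cpu_busy = n * submit_ms          # total time the scheduler spends launching
    cpu_free = gpu_free = 0.0
    for _ in range(n):
        cpu_free += submit_ms                       # next submit completes
        gpu_free = max(gpu_free, cpu_free) + exec_ms  # GPU runs it when ready
    return cpu_busy, gpu_free

# Using the per-call numbers measured above for a 32-layer submission burst:
# before, the CPU is blocked ~43.4 ms (32 * 1.3558); after, only ~13.7 ms
# (32 * 0.4274), while the GPU keeps draining copies in the background.
before = pipeline_times_ms(32, 1.3558, 1.3565)
after = pipeline_times_ms(32, 0.4274, 1.0996)
```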

Accuracy Tests

root@gpulingjun010013003244:/home/shenghai.htw# pytest test_kvcacheio.py
=========================================================================================== test session starts ===========================================================================================
platform linux -- Python 3.12.12, pytest-9.0.0, pluggy-1.6.0
rootdir: /home/shenghai.htw
plugins: anyio-4.11.0, typeguard-4.4.4
collected 192 items                                                                                                                                                                                       

test_kvcacheio.py ............................................................................................................................................................................................ [ 97%]
....                                                                                                                                                                                                [100%]

===================================================================================== 192 passed in 215.77s (0:03:35) =====================================================================================

Benchmarking and Profiling

Kernel launch becomes roughly 3x faster (1.3558 ms -> 0.4274 ms), which greatly alleviates the long blocking of the scheduler process in the start_writing and start_loading methods.

Before:

================================================================================
Function: transfer_kv_per_layer_direct_pf_lf
Direction: page_first_direct -> layer_first
--------------------------------------------------------------------------------
CPU Performance (Async - without sync):
  Mean time: 1.3558 ms
  Median time: 1.3466 ms
  Std dev: 0.0201 ms
  Min time: 1.3319 ms
  Max time: 1.4273 ms
--------------------------------------------------------------------------------
GPU Performance (CUDA event timing):
  Mean time: 1.3565 ms
  Median time: 1.3474 ms
  Std dev: 0.0191 ms
  Min time: 1.3327 ms
  Max time: 1.3991 ms
  Bandwidth: 18.00 GB/s
--------------------------------------------------------------------------------
Overhead Analysis:
  CPU Async time: 1.3558 ms (kernel launch overhead)
================================================================================

After:

================================================================================
Function: transfer_kv_per_layer_direct_pf_lf
Direction: page_first_direct -> layer_first
--------------------------------------------------------------------------------
CPU Performance (Async - without sync):
  Mean time: 0.4274 ms
  Median time: 0.4263 ms
  Std dev: 0.0083 ms
  Min time: 0.4184 ms
  Max time: 0.4820 ms
--------------------------------------------------------------------------------
GPU Performance (CUDA event timing):
  Mean time: 1.0996 ms
  Median time: 1.0978 ms
  Std dev: 0.0090 ms
  Min time: 1.0940 ms
  Max time: 1.1703 ms
  Bandwidth: 22.20 GB/s
--------------------------------------------------------------------------------
Overhead Analysis:
  CPU Async time: 0.4274 ms (kernel launch overhead)
================================================================================

python benchmark_transfer_kv.py

#!/usr/bin/env python3
"""
Benchmark script for transfer_kv_per_layer_direct_pf_lf

Tests performance of transfer_kv_per_layer_direct_pf_lf: page_first_direct -> layer_first
Tests layer 0 only.
"""

import torch
import time
import statistics
from typing import Tuple
import argparse

try:
    from sgl_kernel import kvcacheio
except ImportError:
    # Fallback if import path is different
    import sys
    sys.path.insert(0, '/usr/local/lib/python3.12/dist-packages')
    from sgl_kernel import kvcacheio


def create_layer_first_tensors(
    layer_num: int,
    size: int,
    head_num: int,
    head_dim: int,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create layer_first layout tensors: (2, layer_num, size, head_num, head_dim)
    Returns: (k_buffer, v_buffer)
    """
    dims = (2, layer_num, size, head_num, head_dim)
    kv_buffer = torch.randn(*dims, dtype=dtype, device=device)
    k_buffer = kv_buffer[0]  # Shape: (layer_num, size, head_num, head_dim)
    v_buffer = kv_buffer[1]  # Shape: (layer_num, size, head_num, head_dim)
    return k_buffer, v_buffer


def create_page_first_direct_tensors(
    page_num: int,
    layer_num: int,
    page_size: int,
    head_num: int,
    head_dim: int,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create page_first_direct layout tensors: (2, page_num, layer_num, page_size, head_num, head_dim)
    Returns: (k_buffer, v_buffer)
    """
    dims = (2, page_num, layer_num, page_size, head_num, head_dim)
    kv_buffer = torch.randn(*dims, dtype=dtype, device=device)
    k_buffer = kv_buffer[0]  # Shape: (page_num, layer_num, page_size, head_num, head_dim)
    v_buffer = kv_buffer[1]  # Shape: (page_num, layer_num, page_size, head_num, head_dim)
    return k_buffer, v_buffer


def create_indices(
    num_pages: int,
    page_size: int,
    device: str = "cuda",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create src and dst indices for transfer.
    Indices are page-aligned (multiples of page_size).
    """
    # Create sequential page indices
    src_page_indices = torch.arange(num_pages, dtype=torch.int64, device=device)
    dst_page_indices = torch.arange(num_pages, dtype=torch.int64, device=device)

    # Expand to token-level indices (vectorized: page_idx * page_size + offset)
    offsets = torch.arange(page_size, dtype=torch.int64, device=device)
    src_indices = (src_page_indices.unsqueeze(1) * page_size + offsets).reshape(-1)
    dst_indices = (dst_page_indices.unsqueeze(1) * page_size + offsets).reshape(-1)

    return src_indices, dst_indices


def benchmark_transfer_kv_per_layer_direct_pf_lf(
    page_size: int,
    num_pages: int,
    layer_num: int,
    head_num: int,
    head_dim: int,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
    num_warmup: int = 10,
    num_iterations: int = 100,
) -> dict:
    """
    Benchmark transfer_kv_per_layer_direct_pf_lf: page_first_direct -> layer_first
    Tests layer 0 only.
    Note: src is on CPU (pin_memory), dst is on GPU (device)
    """
    layer_id = 0
    
    # Create source tensors (page_first_direct layout) on CPU with pin_memory
    src_k, src_v = create_page_first_direct_tensors(
        page_num=num_pages,
        layer_num=layer_num,
        page_size=page_size,
        head_num=head_num,
        head_dim=head_dim,
        dtype=dtype,
        device="cpu",
    )
    src_k = src_k.pin_memory()
    src_v = src_v.pin_memory()
    
    # Create destination tensors (layer_first layout) on GPU
    size = num_pages * page_size
    dst_k, dst_v = create_layer_first_tensors(
        layer_num=layer_num,
        size=size,
        head_num=head_num,
        head_dim=head_dim,
        dtype=dtype,
        device=device,
    )
    
    # Create indices
    src_indices, dst_indices = create_indices(num_pages, page_size, device=device)
    
    # Prepare src_ptrs and dst_ptrs
    src_ptrs = [src_k, src_v]
    dst_ptrs = [dst_k[layer_id], dst_v[layer_id]]
    
    # Warmup
    for _ in range(num_warmup):
        kvcacheio.transfer_kv_per_layer_direct_pf_lf(
            src_ptrs=src_ptrs,
            dst_ptrs=dst_ptrs,
            src_indices=src_indices,
            dst_indices=dst_indices,
            layer_id=layer_id,
            page_size=page_size,
        )
    
    torch.cuda.synchronize()
    
    # Create CUDA events for GPU timing
    if device == "cuda":
        gpu_start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iterations)]
        gpu_end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iterations)]
    
    # Benchmark
    cpu_times_async = []  # CPU time without synchronization
    gpu_times = []
    
    for i in range(num_iterations):
        torch.cuda.synchronize()  # Ensure previous operations are done
        
        # CPU timing (async) - measure time without waiting for GPU
        cpu_start_async = time.perf_counter()
        
        # GPU timing (if CUDA available)
        if device == "cuda":
            gpu_start_events[i].record()
        
        kvcacheio.transfer_kv_per_layer_direct_pf_lf(
            src_ptrs=src_ptrs,
            dst_ptrs=dst_ptrs,
            src_indices=src_indices,
            dst_indices=dst_indices,
            layer_id=layer_id,
            page_size=page_size,
        )
        
        cpu_end_async = time.perf_counter()  # Measure async time immediately after call
        
        if device == "cuda":
            gpu_end_events[i].record()
        
        torch.cuda.synchronize()  # Wait for GPU operations to complete (for GPU timing)
        
        cpu_times_async.append((cpu_end_async - cpu_start_async) * 1000)  # Convert to ms
    
    # Get GPU times
    if device == "cuda":
        for i in range(num_iterations):
            gpu_times.append(gpu_start_events[i].elapsed_time(gpu_end_events[i]))  # Already in ms
    
    # Calculate statistics
    total_size_bytes = num_pages * page_size * head_num * head_dim * dtype.itemsize * 2  # K + V
    
    cpu_async_mean = statistics.mean(cpu_times_async)
    
    result = {
        "function": "transfer_kv_per_layer_direct_pf_lf",
        "direction": "page_first_direct -> layer_first",
        "layer_id": layer_id,
        "num_pages": num_pages,
        "page_size": page_size,
        "num_tokens": num_pages * page_size,
        "layer_num": layer_num,
        "head_num": head_num,
        "head_dim": head_dim,
        "dtype": str(dtype),
        "cpu_async_mean_time_ms": cpu_async_mean,
        "cpu_async_median_time_ms": statistics.median(cpu_times_async),
        "cpu_async_std_time_ms": statistics.stdev(cpu_times_async) if len(cpu_times_async) > 1 else 0.0,
        "cpu_async_min_time_ms": min(cpu_times_async),
        "cpu_async_max_time_ms": max(cpu_times_async),
        "total_size_bytes": total_size_bytes,
    }
    
    if device == "cuda" and len(gpu_times) > 0:
        gpu_mean = statistics.mean(gpu_times)
        gpu_bandwidth_gbps = (total_size_bytes / (1024**3)) / (gpu_mean / 1000) if gpu_mean > 0 else 0.0
        result.update({
            "gpu_mean_time_ms": gpu_mean,
            "gpu_median_time_ms": statistics.median(gpu_times),
            "gpu_std_time_ms": statistics.stdev(gpu_times) if len(gpu_times) > 1 else 0.0,
            "gpu_min_time_ms": min(gpu_times),
            "gpu_max_time_ms": max(gpu_times),
            "gpu_bandwidth_gbps": gpu_bandwidth_gbps,
        })
    else:
        result.update({
            "gpu_mean_time_ms": 0.0,
            "gpu_median_time_ms": 0.0,
            "gpu_std_time_ms": 0.0,
            "gpu_min_time_ms": 0.0,
            "gpu_max_time_ms": 0.0,
            "gpu_bandwidth_gbps": 0.0,
        })
    
    return result


def print_results(results: dict):
    """Print benchmark results in a formatted way"""
    print("\n" + "=" * 80)
    print(f"Function: {results['function']}")
    print(f"Direction: {results['direction']}")
    print("-" * 80)
    print(f"Configuration:")
    if "layer_id" in results:
        print(f"  Layer ID: {results['layer_id']}")
    print(f"  Number of pages: {results['num_pages']}")
    print(f"  Page size: {results['page_size']}")
    print(f"  Total tokens: {results['num_tokens']}")
    print(f"  Number of layers: {results['layer_num']}")
    print(f"  Number of heads: {results['head_num']}")
    print(f"  Head dimension: {results['head_dim']}")
    print(f"  Data type: {results['dtype']}")
    print(f"  Total size: {results['total_size_bytes'] / (1024**2):.2f} MB")
    print("-" * 80)
    print(f"CPU Performance (Async - without sync):")
    print(f"  Mean time: {results['cpu_async_mean_time_ms']:.4f} ms")
    print(f"  Median time: {results['cpu_async_median_time_ms']:.4f} ms")
    print(f"  Std dev: {results['cpu_async_std_time_ms']:.4f} ms")
    print(f"  Min time: {results['cpu_async_min_time_ms']:.4f} ms")
    print(f"  Max time: {results['cpu_async_max_time_ms']:.4f} ms")
    print("-" * 80)
    if results.get('gpu_mean_time_ms', 0) > 0:
        print(f"GPU Performance (CUDA event timing):")
        print(f"  Mean time: {results['gpu_mean_time_ms']:.4f} ms")
        print(f"  Median time: {results['gpu_median_time_ms']:.4f} ms")
        print(f"  Std dev: {results['gpu_std_time_ms']:.4f} ms")
        print(f"  Min time: {results['gpu_min_time_ms']:.4f} ms")
        print(f"  Max time: {results['gpu_max_time_ms']:.4f} ms")
        print(f"  Bandwidth: {results['gpu_bandwidth_gbps']:.2f} GB/s")
        print("-" * 80)
        async_overhead_ms = results['cpu_async_mean_time_ms']
        print(f"Overhead Analysis:")
        print(f"  CPU Async time: {async_overhead_ms:.4f} ms (kernel launch overhead)")
    print("=" * 80)


def main():
    parser = argparse.ArgumentParser(description="Benchmark transfer_kv_per_layer_direct_pf_lf for all layers")
    parser.add_argument("--layer-num", type=int, default=32, help="Number of layers")
    parser.add_argument("--head-num", type=int, default=8, help="Number of attention heads")
    parser.add_argument("--head-dim", type=int, default=128, help="Head dimension")
    parser.add_argument("--page-size", type=int, default=64, help="Page size")
    parser.add_argument("--num-pages", type=int, default=100, help="Number of pages to transfer")
    parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "bfloat16"], help="Data type")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
    parser.add_argument("--warmup", type=int, default=10, help="Number of warmup iterations")
    parser.add_argument("--iterations", type=int, default=100, help="Number of benchmark iterations")
    
    args = parser.parse_args()
    
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map[args.dtype]
    
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        device = "cpu"
    
    print(f"Device: {device}")
    print(f"PyTorch version: {torch.__version__}")
    if device == "cuda":
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    
    print("\n" + "#" * 80)
    print("# Benchmarking transfer_kv_per_layer_direct_pf_lf (Layer 0)")
    print("#" * 80)
    
    # Call function - tests layer 0 only
    result = benchmark_transfer_kv_per_layer_direct_pf_lf(
        page_size=args.page_size,
        num_pages=args.num_pages,
        layer_num=args.layer_num,
        head_num=args.head_num,
        head_dim=args.head_dim,
        dtype=dtype,
        device=device,
        num_warmup=args.warmup,
        num_iterations=args.iterations,
    )
    
    # Print results
    print_results(result)


if __name__ == "__main__":
    main()

Co-authors: @hzh0425 @zhaoyongke

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@huangtingwei9988
Collaborator Author

/tag-run-ci-label

@gemini-code-assist
Contributor

Summary of Changes

Hello @huangtingwei9988, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the transfer_kv_page_first_direct_impl function in sgl-kernel/csrc/kvcacheio/transfer.cu to improve performance. The original implementation used several item(), select(), and slice() operations, which were found to be a bottleneck. This PR replaces those operations with direct calls to cudaMemcpyAsync to improve performance by overlapping cache loading and forward execution.

Highlights

  • Performance Improvement: This refactor replaces the use of item(), select(), and slice() with direct calls to cudaMemcpyAsync, significantly reducing CPU overhead.
  • Asynchronous Memory Copy: By using cudaMemcpyAsync, the CPU can submit memory transfer requests faster than the GPU executes them, enabling overlap between cache loading and forward execution.
  • Code Simplification: The change simplifies the code by directly passing pointers to cudaMemcpyAsync, avoiding intermediate computations.


@gemini-code-assist gemini-code-assist bot left a comment
Code Review

This pull request refactors the transfer_kv_page_first_direct_impl function to enhance performance by replacing slower PyTorch tensor operations like item(), select(), and slice() with direct pointer manipulation and asynchronous CUDA memory copies. This is a solid optimization that should deliver the performance improvements described. The implementation appears correct, and I have a couple of minor suggestions to improve code readability by avoiding variable shadowing.

@xiezhq-hermann
Collaborator

Nice work. By the way, you might also want to try cudaMemcpyBatchAsync for better direct I/O performance.
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY.html#group__CUDART__MEMORY_1g6126baf5d881835091c59e48890d6854
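As I understand the batched API, it takes parallel arrays of destination pointers, source pointers, and sizes, so the host-side work reduces to gathering one descriptor per page and issuing a single call instead of one cudaMemcpyAsync per page. A pure-Python sketch of that gathering step, under the same hypothetical page_first_direct layout of (page_num, layer_num, page_size, head_num, head_dim) (names are mine, not the PR's):

```python
def build_batch(src_pages, dst_token_starts, layer_id: int, layer_num: int,
                page_size: int, token_bytes: int):
    """Collect per-page (src_offset, dst_offset, nbytes) descriptors so all
    pages can be issued in one batched copy call."""
    nbytes = page_size * token_bytes       # one contiguous page per layer
    srcs, dsts, sizes = [], [], []
    for page, token_start in zip(src_pages, dst_token_starts):
        srcs.append(((page * layer_num + layer_id) * page_size) * token_bytes)
        dsts.append(token_start * token_bytes)
        sizes.append(nbytes)
    return srcs, dsts, sizes

# Two pages, layer 0, layer_num=32, page_size=64, 2048 bytes per token (K or V):
srcs, dsts, sizes = build_batch([0, 1], [0, 64], 0, 32, 64, 2048)
```

On the C++ side these three arrays (with the offsets added to the base pointers) would map onto the dsts/srcs/sizes arguments of a single batched copy, which is presumably where the extra launch-side speedup reported below comes from.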

@hzh0425 hzh0425 self-assigned this Feb 4, 2026
@huangtingwei9988
Collaborator Author

huangtingwei9988 commented Feb 4, 2026

After integrating cudaMemcpyBatchAsync, performance improved significantly: nearly an 8-fold speedup in kernel launch (CPU) and roughly a 2-fold increase in transfer bandwidth (GPU).

================================================================================
Function: transfer_kv_per_layer_direct_pf_lf
Direction: page_first_direct -> layer_first
--------------------------------------------------------------------------------
CPU Performance (Async - without sync):
  Mean time: 0.1688 ms
  Median time: 0.1661 ms
  Std dev: 0.0094 ms
  Min time: 0.1641 ms
  Max time: 0.2295 ms
--------------------------------------------------------------------------------
GPU Performance (CUDA event timing):
  Mean time: 0.6404 ms
  Median time: 0.6369 ms
  Std dev: 0.0085 ms
  Min time: 0.6350 ms
  Max time: 0.6941 ms
  Bandwidth: 38.12 GB/s
--------------------------------------------------------------------------------
root@gpulingjun010013003244:/home/shenghai.htw# pytest test_kvcacheio.py 
======================================================================== test session starts ========================================================================
platform linux -- Python 3.12.12, pytest-9.0.0, pluggy-1.6.0
rootdir: /home/shenghai.htw
plugins: anyio-4.11.0, typeguard-4.4.4
collected 192 items                                                                                                                                                 

test_kvcacheio.py ........................................................................................................................................... [ 72%]
.....................................................                                                                                                                                     [100%]

================================================================================ 192 passed in 211.19s (0:03:31) ================================================================================

