44 changes: 43 additions & 1 deletion tritonbench/operators/decoding_attention/operator.py
@@ -57,11 +57,42 @@
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu/experimental:gen_ai_attention_ops"
    )

    HAS_FB_IMPORT = True
except ImportError:
    HAS_FB_IMPORT = False

# Load FlashInfer FMHA Gen library (includes TRTLLM kernels)
torch.ops.load_library("//deeplearning/flashinfer:fmha_gen")

# Initialize FlashInfer cubin loader
try:
    from flashinfer.jit.cubin_loader import setup_cubin_loader

    # Find the path of the loaded library; torch.ops.load_library above
    # should already have dlopen'ed it into this process.
    lib_name = "libdeeplearning_flashinfer_fmha_gen.so"

    # Try to find it in /proc/self/maps
    found = False
    with open('/proc/self/maps', 'r') as f:
        for line in f:
            if lib_name in line:
                # Extract the library path (field 6 onward, in case the path contains spaces)
                parts = line.strip().split()
                if len(parts) >= 6:
                    lib_path = ' '.join(parts[5:])
                    setup_cubin_loader(lib_path)
                    found = True
                    break

    if not found:
        print(f"Warning: Could not find {lib_name} in loaded libraries")
except Exception as e:
    print(f"Warning: Could not initialize FlashInfer cubin loader: {e}")
    import traceback
    traceback.print_exc()

from .trtllm_utils import trtllm_paged_attention_decode_func

from tritonbench.utils.triton_op import (
    BenchmarkOperator,
@@ -663,3 +694,14 @@ def aiter_paged_fp8kv(
            k_scale_asm,
            v_scale_asm,
        )

    @register_benchmark()
    def trtllm_decode_fmha(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        cache_seqlens: torch.Tensor,
    ) -> Callable:
        args = trtllm_paged_attention_decode_func(q, k_cache, v_cache, cache_seqlens)
        return lambda: torch.ops.fmha_gen.trtllm_paged_attention_decode(*args)
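For a quick sanity check outside the benchmark harness, the new entry can be exercised directly, since it is only a thunk around the fmha_gen op. A minimal sketch, assuming a CUDA device, that the fmha_gen library has already been loaded with torch.ops.load_library("//deeplearning/flashinfer:fmha_gen") as above, and purely illustrative shapes (not the benchmarked configurations):

import torch

from tritonbench.operators.decoding_attention.trtllm_utils import (
    trtllm_paged_attention_decode_func,
)

# Illustrative shapes; max_seq_kv is a multiple of the 16-element page size.
batch, seq_q, max_seq_kv = 4, 1, 256
n_qo_heads, n_kv_heads, head_dim = 8, 1, 128
device = "cuda"

q = torch.randn(batch, seq_q, n_qo_heads, head_dim, dtype=torch.float16, device=device)
k_cache = torch.randn(
    batch, max_seq_kv, n_kv_heads, head_dim, dtype=torch.float16, device=device
)
v_cache = torch.randn_like(k_cache)
cache_seqlens = torch.full((batch,), max_seq_kv, dtype=torch.int32, device=device)

args = trtllm_paged_attention_decode_func(q, k_cache, v_cache, cache_seqlens)
out = args[0]  # preallocated output tensor; the kernel writes its result here
torch.ops.fmha_gen.trtllm_paged_attention_decode(*args)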
141 changes: 141 additions & 0 deletions tritonbench/operators/decoding_attention/trtllm_utils.py
@@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
TRTLLM FMHA utility functions for handling tensor conversion and kernel preparation.
"""

import torch


def trtllm_paged_attention_decode_func(q, k_cache, v_cache, cache_seqlens):
"""
TRTLLM FMHA paged attention decode function that prepares inputs for the
FlashInfer fmha_gen library's trtllm_paged_attention_decode kernel.

This function converts standard KV cache tensors to paged format and prepares
all necessary parameters for the TRTLLM kernel.

Args:
q: Query tensor [batch, seq_len_q, num_qo_heads, head_dim]
k_cache: Key cache tensor [batch, max_seq_len_kv, num_kv_heads, head_dim]
v_cache: Value cache tensor [batch, max_seq_len_kv, num_kv_heads, head_dim]
cache_seqlens: Sequence lengths tensor [batch]

Returns:
Tuple of arguments for torch.ops.fmha_gen.trtllm_paged_attention_decode:
(out, out_scale_factor, query, key_cache, value_cache, workspace_buffer,
block_tables, seq_lens, max_kv_len, bmm1_scale, bmm2_scale, o_sf_scale,
o_sf_vec_size, o_sf_start_index, window_left, sm_count, enable_pdl,
workspace_size, attention_sinks)
"""

    device = q.device
    # Convert input tensors to paged format for TRTLLM FMHA
    batch_size, seq_len_q, num_qo_heads, head_dim = q.shape
    _, max_seq_len_kv, num_kv_heads, _ = k_cache.shape

    # Use page size of 16 for TRTLLM FMHA
    page_size = 16
    max_num_blocks_per_seq = (max_seq_len_kv + page_size - 1) // page_size
    total_pages = batch_size * max_num_blocks_per_seq

    # Reshape k_cache and v_cache to paged format [total_pages, num_kv_heads, page_size, head_dim].
    # NOTE: the views below assume max_seq_len_kv is a multiple of page_size.
    k_cache_paged = k_cache.view(
        batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim
    )
    k_cache_paged = k_cache_paged.permute(0, 1, 3, 2, 4).contiguous()
    k_cache_paged = k_cache_paged.view(total_pages, num_kv_heads, page_size, head_dim)

    v_cache_paged = v_cache.view(
        batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim
    )
    v_cache_paged = v_cache_paged.permute(0, 1, 3, 2, 4).contiguous()
    v_cache_paged = v_cache_paged.view(total_pages, num_kv_heads, page_size, head_dim)

    # Create block tables: each sequence owns a contiguous, identity-mapped range of pages
    block_tables = torch.zeros(
        (batch_size, max_num_blocks_per_seq), dtype=torch.int32, device=device
    )
    for i in range(batch_size):
        for j in range(max_num_blocks_per_seq):
            block_tables[i, j] = i * max_num_blocks_per_seq + j

    # Create output tensor
    out = torch.zeros_like(q)

    # Create workspace buffer
    workspace_size = 128 * 1024 * 1024  # 128MB
    workspace_buffer = torch.zeros(workspace_size, dtype=torch.uint8, device=device)

    # Attention parameters
    max_seq_len = cache_seqlens.max().item()
    bmm1_scale = 1.0 / (head_dim**0.5)
    bmm2_scale = 1.0

    # Output scale factor parameters (not used for non-FP8)
    out_scale_factor = None  # Optional tensor for FP8 output scaling
    o_sf_scale = -1.0  # Output scale factor scale (disabled when -1)
    o_sf_vec_size = -1  # Output scale factor vector size (disabled when -1)
    o_sf_start_index = -1  # Output scale factor start index (disabled when -1)

    # Attention window settings
    window_left = -1  # No sliding window (disabled when -1)

    # Device settings
    sm_count = torch.cuda.get_device_properties(device).multi_processor_count

    # PDL (Programmatic Dependent Launch) settings
    enable_pdl = False

    # Attention sinks (optional)
    attention_sinks = None

    # Return tuple matching trtllm_paged_attention_decode signature:
    # void trtllm_paged_attention_decode(
    #     at::Tensor out,
    #     std::optional<at::Tensor> out_scale_factor,
    #     at::Tensor query,
    #     at::Tensor key_cache,
    #     at::Tensor value_cache,
    #     at::Tensor workspace_buffer,
    #     at::Tensor block_tables,
    #     at::Tensor seq_lens,
    #     int64_t max_kv_len,
    #     double bmm1_scale,
    #     double bmm2_scale,
    #     double o_sf_scale,
    #     int64_t o_sf_vec_size,
    #     int64_t o_sf_start_index,
    #     int64_t window_left,
    #     int64_t sm_count,
    #     bool enable_pdl,
    #     int64_t workspace_size,
    #     std::optional<at::Tensor> attention_sinks
    # )

    args = (
        out,  # out
        out_scale_factor,  # out_scale_factor (optional)
        q,  # query
        k_cache_paged,  # key_cache
        v_cache_paged,  # value_cache
        workspace_buffer,  # workspace_buffer
        block_tables,  # block_tables
        cache_seqlens,  # seq_lens
        max_seq_len,  # max_kv_len
        bmm1_scale,  # bmm1_scale
        bmm2_scale,  # bmm2_scale
        o_sf_scale,  # o_sf_scale
        o_sf_vec_size,  # o_sf_vec_size
        o_sf_start_index,  # o_sf_start_index
        window_left,  # window_left
        sm_count,  # sm_count
        enable_pdl,  # enable_pdl
        workspace_size,  # workspace_size
        attention_sinks,  # attention_sinks (optional)
    )
    return args
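A note on the block-table construction above: because the caches are reshaped batch-major into contiguous pages, the table is an identity mapping, so the nested Python loop could equivalently be built with a single torch.arange. A sketch of that alternative (make_block_tables is a hypothetical helper, not part of this PR):

import torch


def make_block_tables(batch_size: int, max_num_blocks_per_seq: int, device) -> torch.Tensor:
    # Page j of sequence i lives at global page index i * max_num_blocks_per_seq + j,
    # so the table is just 0..total_pages-1 reshaped to [batch_size, max_num_blocks_per_seq].
    total_pages = batch_size * max_num_blocks_per_seq
    return torch.arange(total_pages, dtype=torch.int32, device=device).view(
        batch_size, max_num_blocks_per_seq
    )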