Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions aiter/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,17 @@ def mla_decode_fwd(
else:
if num_kv_splits is None:
num_kv_splits = get_cu_num()
if nhead == 16 or (
nhead == 128 and q.dtype == dtypes.fp8 and kv_buffer.dtype == dtypes.fp8
if (
nhead == 16
or (
nhead == 128 and q.dtype == dtypes.fp8 and kv_buffer.dtype == dtypes.fp8
)
or (
nhead == 32
and q.dtype == dtypes.fp8
and kv_buffer.dtype == dtypes.fp8
and max_seqlen_q == 4
)
):
# Natively support cases
pass
Expand Down
56 changes: 28 additions & 28 deletions aiter/ops/triton/_triton_kernels/attention/fav3_sage_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def compute_window_bounds(

# Right boundary
if IS_CAUSAL:
# Causal cap: col row + diag
# Causal cap: col <= row + diag
right_min = tl.minimum(seqlen_k - 1, q_start + diag)
right_max = tl.minimum(seqlen_k - 1, q_end + diag)
else:
Expand Down Expand Up @@ -654,14 +654,14 @@ def handle_padded_last_block(
# current 'full' range right edge
full_right_block = clipped_left + n_full_blocks - 1

# If last_block is already beyond full_right_block, it's already in back-masked nothing to do
# If last_block is already beyond full_right_block, it's already in back-masked -> nothing to do
last_already_back_masked = last_block > full_right_block
if not last_already_back_masked:
# If the window starts past last_block, it was counted in front-masked
if clipped_left > last_block:
n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1)
else:
# Otherwise it was counted 'full' move it out of full
# Otherwise it was counted 'full' -> move it out of full
n_full_blocks = tl.maximum(0, n_full_blocks - 1)
# In both cases we need one more back-masked block
n_back_masked_blocks = n_back_masked_blocks + 1
Expand All @@ -678,7 +678,7 @@ def compute_padding_info(seqlen_k, BLOCK_N: tl.constexpr):
# K blocks visualization:
# Block 0 Block 1 Block 2 (last)
# K0 K1 K2 K3 K4 K5 K6 K7 K8 K9 ?? ??
# --------- --------- ---↑ ↑---
# |---------| |---------| |---| |---|
# full block full block valid pad
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
Expand Down Expand Up @@ -731,7 +731,7 @@ def compute_block_masking(
IS_CAUSAL,
)

# window vanishes early exit
# window vanishes -> early exit
if right_max < left_min:
return 0, 0, 0, 0, n_extra_tokens

Expand Down Expand Up @@ -770,24 +770,24 @@ def compute_block_masking(
# ========== CAUSAL MODE: Classify K Blocks ==========
# Calculate causal boundary for this Q block
# [K0 K1 K2 K3] [K4 K5 K6 K7] [K8 K9 ?? ??]
# Q0-Q3: [ 1 0 0 0] [ 0 0 0 0] [ 0 0 -- --] Q0
# [ 1 1 0 0] [ 0 0 0 0] [ 0 0 -- --] Q1
# [ 1 1 1 0] [ 0 0 0 0] [ 0 0 -- --] Q2
# [ 1 1 1 1] [ 1 1 0 0] [ 0 0 -- --] Q3
# can see up to K5
# Q0-Q3: [ 1 0 0 0] [ 0 0 0 0] [ 0 0 -- --] <- Q0
# [ 1 1 0 0] [ 0 0 0 0] [ 0 0 -- --] <- Q1
# [ 1 1 1 0] [ 0 0 0 0] [ 0 0 -- --] <- Q2
# [ 1 1 1 1] [ 1 1 0 0] [ 0 0 -- --] <- Q3
# ^ can see up to K5
#
# Q4-Q7: [ 1 1 1 1] [ 1 1 1 0] [ 0 0 -- --] Q4
# [ 1 1 1 1] [ 1 1 1 1] [ 0 0 -- --] Q5
# [ 1 1 1 1] [ 1 1 1 1] [ 1 0 -- --] Q6
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -- --] Q7
# Q4-Q7: [ 1 1 1 1] [ 1 1 1 0] [ 0 0 -- --] <- Q4
# [ 1 1 1 1] [ 1 1 1 1] [ 0 0 -- --] <- Q5
# [ 1 1 1 1] [ 1 1 1 1] [ 1 0 -- --] <- Q6
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -- --] <- Q7

# ------------------------------------------------------------
# 1. figure out, in tokens, the right-most K position
# this Q-block may attend to
# ------------------------------------------------------------
k_max_token = q_end + diag # last visible K index

# this Q-block is entirely above the diagonal nothing to do
# this Q-block is entirely above the diagonal => nothing to do
if k_max_token < 0:
return 0, 0, 0, 0, n_extra_tokens

Expand All @@ -801,12 +801,12 @@ def compute_block_masking(

# ------------------------------------------------------------
# 3. classify those visible blocks
# we *never* skip or mask blocks in front, because causal
# - we *never* skip or mask blocks in front, because causal
# attention always starts at K0
# the back side can require several masked blocks:
# intersection of the causal diagonal with K-grid
# (at most BLOCK_M / BLOCK_N blocks)
# plus one extra block if this Q-block stops in the
# - the back side can require several masked blocks:
# o intersection of the causal diagonal with K-grid
# (at most ceil(BLOCK_M / BLOCK_N) blocks)
# o plus one extra block if this Q-block stops in the
# middle of a K-block or the last K-block is padded
# ------------------------------------------------------------
padded_last_k = n_extra_tokens != 0
Expand All @@ -823,15 +823,15 @@ def compute_block_masking(
# Without causal mask, all positions can attend to all positions
# Only need to handle the padding in the last block
# [K0 K1 K2 K3] [K4 K5 K6 K7] [K8 K9 ?? ??]
# Q0-Q3: [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞]
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞]
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞]
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞]
# Q0-Q3: [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -inf -inf]
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -inf -inf]
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -inf -inf]
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -inf -inf]
#
# Q4-Q7: [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞]
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞]
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞]
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞]
# Q4-Q7: [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -inf -inf]
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -inf -inf]
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -inf -inf]
# [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -inf -inf]

n_front_skip_blocks = 0 # never skips the left side
n_front_masked_blocks = 0 # ditto
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _causal_conv1d_update_split_qkv_kernel_v2(
per_cu_batchs = batch // batch_cus
cu_mores = batch % batch_cus

# 原理:负数的最高位为1,右移后为-1;正数右移后为0 → 取反+1后映射为1/0
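# Principle: a negative number's highest bit is 1, so it becomes -1 after the right shift; a positive number becomes 0 after the right shift -> after inverting and adding 1 the result maps to 1/0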
# x = x - thresh
# t = (x >> (x.bit_length() - 1)) ^ 1
# r = val0 + (val1 - val0) * t
Expand Down Expand Up @@ -1009,7 +1009,7 @@ def causal_conv1d_update_split_qkv(
_, width = weight.shape
num_cache_lines, _, state_len = conv_state.size()

# 创建输出 buffer(已经是分离的!)
# Create the output buffers (already split!)
query = torch.empty(
(batch, key_dim, seqlen),
dtype=x.dtype,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (C) 2023-2026, Songlin Yang, Yu Zhang

import contextlib
import functools
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from .cumsum import (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
2 changes: 1 addition & 1 deletion aiter/ops/triton/attention/fav3_sage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

from __future__ import annotations
from typing import Optional, Tuple
Expand Down
2 changes: 1 addition & 1 deletion aiter/ops/triton/gated_delta_net/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
2 changes: 1 addition & 1 deletion aiter/ops/triton/gated_delta_net/gated_delta_rule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Adapted from flash-linear-attention: Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Expand Down
2 changes: 1 addition & 1 deletion aiter/ops/triton/utils/_triton/pid_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: MIT

# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton
import triton.language as tl
Expand Down
2 changes: 1 addition & 1 deletion csrc/include/aiter_hip_common.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <cstdint>
Expand Down
Loading
Loading