Skip to content
Draft
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
c35c8ac
add initial implementation of projection mapping
anhminhnguyenhoang Jan 16, 2026
2a8325c
Refactor mHC kernel and wrapper to include sigmoid activation in proj…
waqahmed-amd-fi Jan 19, 2026
f11244d
Add Sinkhorn-Knopp log-domain kernel implementation
anhminhnguyenhoang Jan 20, 2026
87e5839
clean up sinkhorn-knopp tests
anhminhnguyenhoang Jan 20, 2026
152bb21
review invalid test case
anhminhnguyenhoang Jan 20, 2026
18a62a1
Refactor mHC kernel and wrapper to implement equations 14-18 as fused…
waqahmed-amd-fi Jan 20, 2026
36e36d6
Fix H dims
waqahmed-amd-fi Jan 20, 2026
6156c13
fix test_mhc_output_range
waqahmed-amd-fi Jan 20, 2026
80b8d34
Refactor test cases in mHC and Sinkhorn-Knopp implementations and sim…
anhminhnguyenhoang Jan 20, 2026
4801d63
Fix issues (#1878)
waqahmed-amd-fi Jan 21, 2026
d952829
optimization to loads x_tile once, reducing memory bandwidth
waqahmed-amd-fi Jan 21, 2026
30741ca
Update mHC implementation to apply Sinkhorn-Knopp (Equation 19) to ma…
waqahmed-amd-fi Jan 21, 2026
686711f
Refactor mHC implementation to separate projection (phi) matrices int…
waqahmed-amd-fi Jan 21, 2026
831572b
Enhance mHC fused kernel to implement stream-aware processing
anhminhnguyenhoang Jan 22, 2026
ad198b3
Refactor mHC implementation
anhminhnguyenhoang Jan 22, 2026
05810eb
Adjust tolerance levels in mHC tests based on input size to improve a…
anhminhnguyenhoang Jan 22, 2026
7d333a8
Add benchmark scripts for mHC kernel performance evaluation
waqahmed-amd-fi Jan 22, 2026
46d023a
add modes to bench
waqahmed-amd-fi Jan 23, 2026
e8f4464
- Add naive configs for fused mHC and Sinkhorn-Knopp kernels
anhminhnguyenhoang Jan 23, 2026
68a2df8
switch to using exp2/log2 for sinkhorn-knopp for optimization
anhminhnguyenhoang Jan 23, 2026
df24377
Sort benchmark configurations by hidden dimension and refine FLOPs ca…
waqahmed-amd-fi Jan 26, 2026
a4a1793
Refactor Sinkhorn-Knopp kernel to support batch processing
anhminhnguyenhoang Jan 26, 2026
23609ed
better tuned configs
anhminhnguyenhoang Jan 26, 2026
514f3f5
Refactor mHC fused kernel for improved arithmetic operations and clar…
waqahmed-amd-fi Jan 27, 2026
10b122a
Add split-K support to mHC kernel by new split and reduce kernels to …
anhminhnguyenhoang Jan 27, 2026
9f6aba3
add better config with split reduce usage
anhminhnguyenhoang Jan 27, 2026
833c47a
Apply optim in mhc_fused to split reduce kernels, rename functions fo…
anhminhnguyenhoang Jan 27, 2026
70490ad
Add json config loading
anhminhnguyenhoang Jan 28, 2026
b01c00a
Add tuned JSON configuration files for fused mhc kernels
anhminhnguyenhoang Jan 28, 2026
42c944d
inittial implementation of zero-iteration Sinkhorn-Knopp (mHC-Lite). …
waqahmed-amd-fi Jan 28, 2026
190122e
optimized zero-iteration Sinkhorn-Knopp (mHC-Lite)
waqahmed-amd-fi Jan 29, 2026
507d95a
Removed Unused Projection Code (Wrapper Function)
waqahmed-amd-fi Jan 29, 2026
c2ccafc
add config loading bug fix due to caching and better tuned configs
anhminhnguyenhoang Jan 29, 2026
7862631
2D grid parallelization. Key improvements:
waqahmed-amd-fi Jan 29, 2026
cd84075
remove _sinkhorn_knopp_lite, and implement mHC_lite i.e., non-iterati…
waqahmed-amd-fi Jan 29, 2026
3e85b9d
add config loading bug fix due to caching and better tuned configs
anhminhnguyenhoang Jan 29, 2026
5374eb6
add mhc-lite
anhminhnguyenhoang Jan 30, 2026
4f4c272
revised mHC and mHC-Lite description for clarity
waqahmed-amd-fi Jan 30, 2026
57c75ac
update comments and replace if-else with assert check
waqahmed-amd-fi Jan 30, 2026
68c76b0
revised _mhc_lite_fused_split_kernel kernel
waqahmed-amd-fi Jan 30, 2026
604426d
revised _mhc_lite_fused_reduce_kernel
waqahmed-amd-fi Jan 30, 2026
e52ebdc
add mhc-lite bench mode
anhminhnguyenhoang Jan 30, 2026
e70f3dd
integrate mhc-lite into mhc_fused
anhminhnguyenhoang Jan 30, 2026
41b4908
update config loading for mode
anhminhnguyenhoang Jan 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel updated 76 files
+6 −6 .github/CODEOWNERS
+13 −3 CMakeLists.txt
+6 −34 Jenkinsfile
+16 −0 README.md
+30 −0 example/CMakeLists.txt
+1 −1 example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
+12 −3 example/ck_tile/38_block_scale_gemm/CMakeLists.txt
+6 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp
+6 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp
+6 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp
+6 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp
+0 −222 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp
+53 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp
+57 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp
+53 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp
+57 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp
+0 −62 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp
+50 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp
+52 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp
+50 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp
+52 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp
+0 −270 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp
+55 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp
+59 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp
+55 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp
+59 −0 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp
+33 −6 example/ck_tile/38_block_scale_gemm/gemm_quant.cpp
+35 −26 experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp
+86 −641 experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp
+84 −0 ...l/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+84 −0 ...uilder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
+84 −0 ...nclude/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
+739 −0 experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp
+8 −0 experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp
+8 −0 ...ilder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+8 −0 ...er/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
+8 −0 ...de/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
+2 −1 experimental/builder/test/CMakeLists.txt
+78 −78 experimental/builder/test/conv/ck/test_conv_traits.cpp
+0 −1,127 experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp
+800 −0 experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp
+262 −0 experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp
+175 −63 include/ck/library/utility/gpu_verification.hpp
+4 −0 include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp
+8 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+4 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp
+4 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp
+1 −1 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+1 −1 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
+59 −17 include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
+2 −10 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+17 −1 include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
+50 −27 include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp
+5 −4 include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp
+3 −1 include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp
+3 −1 include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp
+48 −26 include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp
+0 −2 include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp
+22 −13 include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
+54 −56 profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
+44 −62 profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
+44 −30 profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
+254 −0 script/run-tests.ps1
+26 −0 test/ck_tile/gemm_block_scale/CMakeLists.txt
+39 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp
+54 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp
+41 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp
+63 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp
+8 −5 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp
+10 −5 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp
+22 −29 test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
+4 −3 test/ck_tile/gemm_streamk/CMakeLists.txt
+1 −1 test/gpu_verification/test_gpu_verification.cpp
+21 −17 test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+4 −4 test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp
+4 −4 test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp
12 changes: 12 additions & 0 deletions aiter/ops/triton/_triton_kernels/fusions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

from aiter.ops.triton._triton_kernels.fusions.mhc import (
_mhc_fused_kernel,
_sinkhorn_knopp_log_domain_kernel,
)

__all__ = [
"_mhc_fused_kernel",
"_sinkhorn_knopp_log_domain_kernel",
]
260 changes: 260 additions & 0 deletions aiter/ops/triton/_triton_kernels/fusions/mhc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

"""
Triton kernel for mHC (manifold-constrained Hyper Connection) operations.

Implements equations 14-18 from the mHC paper in a single fused kernel:
- Eq 14: H̃ = x̃φ (matrix multiplication)
- Eq 15: r = ||x̃||₂ / √(nC) (RMS normalization)
- Eq 16: [H^pre, H^post, H^res] = 1/r [α^pre·H̃^pre, α^post·H̃^post, α^res·H̃^res] + b
- Eq 17: H^pre = σ(H^pre)
- Eq 18: H^post = 2σ(H^post)
- H^res: identity (no activation, ready for equation 19: Sinkhorn-Knopp)

Single fused kernel minimizes memory traffic and kernel launch overhead.
"""

import triton
import triton.language as tl


@triton.jit
def _mhc_fused_kernel(
x_ptr,
phi_ptr,
alpha_pre,
alpha_post,
alpha_res,
bias_ptr,
out_pre_ptr,
out_post_ptr,
out_res_ptr,
M: tl.constexpr, # rows: x.shape[0] - the batch/sequence dimension. Represents how many input vectors we're processing
K: tl.constexpr, # input features: nC = x.shape[1] - must match phi.shape[0]. Called nC in the paper (n × C where C is some latent dimension)
N: tl.constexpr, # output features: n² + 2n - total output dimension split into 3 streams (pre: n, post: n, res: n²)
n: tl.constexpr, # stream parameter: n - Hyperparameter from paper controlling manifold dimension. Determines stream sizes
eps: tl.constexpr, # epsilon for numerical stability in RMSNorm
stride_xm,
stride_xk,
stride_phik,
stride_phin,
stride_pre_m,
stride_pre_n,
stride_post_m,
stride_post_n,
stride_res_m,
stride_res_n,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""
Fused kernel for equations 14-18.

Computes three separate outputs:
- H^pre: (M, n) with sigmoid activation (H^{pre} ∈ ℝ^{1×n})
- H^post: (M, n) with 2*sigmoid activation (H^{post} ∈ ℝ^{1×n})
- H^res: (M, n²) with identity (no activation) (H^{res} ∈ ℝ^{n×n})

All operations fused in a single kernel pass for maximum efficiency.
"""
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)

# Row and column indices
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

# Eq 14 & 15: Compute matrix multiplication and RMS norm in single pass
acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
acc_sq = tl.zeros([BLOCK_M], dtype=tl.float32)
for k in range(0, K, BLOCK_K):
rk = k + tl.arange(0, BLOCK_K)

x_tile = tl.load(
x_ptr + rm[:, None] * stride_xm + rk[None, :] * stride_xk,
mask=(rm[:, None] < M) & (rk[None, :] < K),
other=0.0,
)

phi_tile = tl.load(
phi_ptr + rk[:, None] * stride_phik + rn[None, :] * stride_phin,
mask=(rk[:, None] < K) & (rn[None, :] < N),
other=0.0,
)

# Eq 14: Accumulate matrix multiplication H̃ = x̃φ
acc += tl.dot(x_tile, phi_tile)

# Eq 15: Accumulate squared sum for RMS norm r = ||x̃||₂ / √(nC)
x_tile_f32 = x_tile.to(tl.float32)
acc_sq += tl.sum(x_tile_f32 * x_tile_f32, axis=1)

rms = tl.sqrt(acc_sq / K + eps)
# Performance optimization: compute 1/r once instead of dividing N times
# Division is ~10x slower than multiplication on GPUs
rsigma = 1.0 / rms

# Load bias
bias = tl.load(bias_ptr + rn, mask=rn < N, other=0.0).to(tl.float32)

# Eq 16: Apply stream-specific scaling and bias
# Output is split into 3 contiguous streams:
# Pre-stream: indices [0, n) - n elements for manifold projection (H^{pre} ∈ ℝ^{1×n})
# Post-stream: indices [n, 2n) - n elements for post-processing (H^{post} ∈ ℝ^{1×n})
# Res-stream: indices [2n, 2n+n²) - n² elements for residual connections (H^{res} ∈ ℝ^{n×n})
n_pre_end = n # End of pre-stream
n_post_end = 2 * n # End of post-stream

# Create boolean masks to identify which stream each output column belongs to
is_pre = rn < n_pre_end
is_post = (rn >= n_pre_end) & (rn < n_post_end)

# Select the appropriate scaling factor (alpha) for each stream
# This creates a vector where each element has its stream-specific alpha
alpha = tl.where(is_pre, alpha_pre,
tl.where(is_post, alpha_post, alpha_res))

# Apply Eq 16: H = (1/r) * α * H̃ + b
# rsigma[:, None] broadcasts 1/r across columns (per-row normalization)
# alpha is per-column (stream-specific scaling)
# acc is the matrix product H̃ from Eq 14
out = rsigma[:, None] * alpha * acc + bias[None, :]

# Apply stream-specific activations and store to separate output buffers

# Pre-stream (Eq 17): H^pre = σ(H^pre) - sigmoid activation
# Columns [0:n] go to out_pre
out_pre = tl.sigmoid(out)
rn_pre = rn # Pre-stream columns [0:n]
tl.store(
out_pre_ptr + rm[:, None] * stride_pre_m + rn_pre[None, :] * stride_pre_n,
out_pre,
mask=(rm[:, None] < M) & is_pre[None, :] & (rn_pre[None, :] >= 0) & (rn_pre[None, :] < n),
)

# Post-stream (Eq 18): H^post = 2σ(H^post) - scaled sigmoid activation
# Columns [n:2n] go to out_post
out_post = 2.0 * tl.sigmoid(out)
rn_post = rn - n # Map global column index to post-stream local index [0:n]
tl.store(
out_post_ptr + rm[:, None] * stride_post_m + rn_post[None, :] * stride_post_n,
out_post,
mask=(rm[:, None] < M) & is_post[None, :] & (rn_post[None, :] >= 0) & (rn_post[None, :] < n),
)

# Res-stream: H^res remains unchanged (identity activation)
# Columns [2n:2n+n²] go to out_res
# This preserves the values for subsequent Sinkhorn-Knopp normalization (Eq 19)
is_res = rn >= n_post_end
n_squared = n * n
out_res = out
rn_res = rn - (2 * n) # Map global column index to res-stream local index [0:n²]
tl.store(
out_res_ptr + rm[:, None] * stride_res_m + rn_res[None, :] * stride_res_n,
out_res,
mask=(rm[:, None] < M) & is_res[None, :] & (rn_res[None, :] >= 0) & (rn_res[None, :] < n_squared),
)


@triton.jit
def _sinkhorn_knopp_log_domain_kernel(
# Pointers
logits_ptr, # Input: (M, N, N) raw logits
out_ptr, # Output: (M, N, N) doubly stochastic matrices
# Dimensions
M, # Batch size (number of matrices)
# Strides
stride_batch, # Stride for batch dimension
stride_row, # Stride for row dimension
stride_col, # Stride for column dimension
# Meta-parameters
N: tl.constexpr, # Matrix size (must be power of 2, max 64)
NUM_ITERS: tl.constexpr, # Number of Sinkhorn iterations
):
"""
Log-domain Sinkhorn-Knopp kernel for projecting raw logits onto doubly stochastic matrices.

Computes doubly stochastic matrix P where all rows and columns sum to 1.

Grid: (M,) - one program per batch element

Reference algorithm (Exponential Domain)- Sinkhorn & Knopp (1967):
──────────────────────────────────────────────────────
1. P = exp(A) # Ensure positivity
2. For each iteration:
- P = P / P.sum(axis=cols) # Row normalize
- P = P / P.sum(axis=rows) # Col normalize
3. Output: P

Problem: exp(large) → Inf, exp(-large) → 0, causing overflow/underflow.

Implementation algorithm (Log Domain):
───────────────────────────────────────────────────────
1. log_u = 0, log_v = 0
2. For each iteration:
- log_u = -logsumexp(A + log_v, axis=cols) # Row normalize
- log_v = -logsumexp(A + log_u, axis=rows) # Col normalize
3. Output: P = exp(A + log_u + log_v)

Key insight: Division becomes subtraction in log space.
logsumexp uses stable formula: max(x) + log(Σ exp(x - max(x)))

"""
batch_idx = tl.program_id(axis=0)

if batch_idx >= M:
return

# Base offset for this batch
batch_offset = batch_idx * stride_batch

# Compute flat indices within this batch's matrix
row_idx = tl.arange(0, N)[:, None] # (N, 1)
col_idx = tl.arange(0, N)[None, :] # (1, N)
flat_idx = row_idx * stride_row + col_idx * stride_col

# Load the NxN matrix (raw logits) in log domain
log_A = tl.load(logits_ptr + batch_offset + flat_idx).to(tl.float32)

# Initialize log scaling factors
# Initially u = v = 1 (no scaling), so log(1) = 0,
log_u = tl.zeros((N,), dtype=tl.float32) # Row scalings
log_v = tl.zeros((N,), dtype=tl.float32) # Column scalings

# Iterate and alternate between row and column normalization.
for _ in range(NUM_ITERS):
# Add column scaling: scaled[i,j] = log_A[i,j] + log_v[j]
scaled_row = log_A + log_v[None, :] # (N, N)

# Compute max per row for numerical stability (prevents overflow in exp)
row_max = tl.max(scaled_row, axis=1) # (N,)

# Compute logsumexp per row
exp_shifted = tl.exp(scaled_row - row_max[:, None])
row_sum_exp = tl.sum(exp_shifted, axis=1) # (N,)
log_row_sums = row_max + tl.log(row_sum_exp) # (N,)

# Update row scaling: log_u = -log(row_sum) to normalize rows to 1
log_u = -log_row_sums

# Add row scaling: scaled[i,j] = log_A[i,j] + log_u[i]
scaled_col = log_A + log_u[:, None] # (N, N)

# Compute max per column for numerical stability
col_max = tl.max(scaled_col, axis=0) # (N,)

# Compute logsumexp per column
exp_shifted = tl.exp(scaled_col - col_max[None, :])
col_sum_exp = tl.sum(exp_shifted, axis=0) # (N,)
log_col_sums = col_max + tl.log(col_sum_exp) # (N,)

# Update column scaling: log_v = -log(col_sum) to normalize cols to 1
log_v = -log_col_sums

# Combine base logits with accumulated scaling factors:
log_P = log_A + log_u[:, None] + log_v[None, :]
P = tl.exp(log_P)

tl.store(out_ptr + batch_offset + flat_idx, P.to(out_ptr.dtype.element_ty))
9 changes: 9 additions & 0 deletions aiter/ops/triton/fusions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

from aiter.ops.triton.fusions.mhc import mhc, sinkhorn_knopp

__all__ = [
"mhc",
"sinkhorn_knopp",
]
Loading