Draft
44 commits
c35c8ac
add initial implementation of projection mapping
anhminhnguyenhoang Jan 16, 2026
2a8325c
Refactor mHC kernel and wrapper to include sigmoid activation in proj…
waqahmed-amd-fi Jan 19, 2026
f11244d
Add Sinkhorn-Knopp log-domain kernel implementation
anhminhnguyenhoang Jan 20, 2026
87e5839
clean up sinkhorn-knopp tests
anhminhnguyenhoang Jan 20, 2026
152bb21
review invalid test case
anhminhnguyenhoang Jan 20, 2026
18a62a1
Refactor mHC kernel and wrapper to implement equations 14-18 as fused…
waqahmed-amd-fi Jan 20, 2026
36e36d6
Fix H dims
waqahmed-amd-fi Jan 20, 2026
6156c13
fix test_mhc_output_range
waqahmed-amd-fi Jan 20, 2026
80b8d34
Refactor test cases in mHC and Sinkhorn-Knopp implementations and sim…
anhminhnguyenhoang Jan 20, 2026
4801d63
Fix issues (#1878)
waqahmed-amd-fi Jan 21, 2026
d952829
optimization to load x_tile once, reducing memory bandwidth usage
waqahmed-amd-fi Jan 21, 2026
30741ca
Update mHC implementation to apply Sinkhorn-Knopp (Equation 19) to ma…
waqahmed-amd-fi Jan 21, 2026
686711f
Refactor mHC implementation to separate projection (phi) matrices int…
waqahmed-amd-fi Jan 21, 2026
831572b
Enhance mHC fused kernel to implement stream-aware processing
anhminhnguyenhoang Jan 22, 2026
ad198b3
Refactor mHC implementation
anhminhnguyenhoang Jan 22, 2026
05810eb
Adjust tolerance levels in mHC tests based on input size to improve a…
anhminhnguyenhoang Jan 22, 2026
7d333a8
Add benchmark scripts for mHC kernel performance evaluation
waqahmed-amd-fi Jan 22, 2026
46d023a
add modes to bench
waqahmed-amd-fi Jan 23, 2026
e8f4464
- Add naive configs for fused mHC and Sinkhorn-Knopp kernels
anhminhnguyenhoang Jan 23, 2026
68a2df8
switch to using exp2/log2 for sinkhorn-knopp for optimization
anhminhnguyenhoang Jan 23, 2026
df24377
Sort benchmark configurations by hidden dimension and refine FLOPs ca…
waqahmed-amd-fi Jan 26, 2026
a4a1793
Refactor Sinkhorn-Knopp kernel to support batch processing
anhminhnguyenhoang Jan 26, 2026
23609ed
better tuned configs
anhminhnguyenhoang Jan 26, 2026
514f3f5
Refactor mHC fused kernel for improved arithmetic operations and clar…
waqahmed-amd-fi Jan 27, 2026
10b122a
Add split-K support to mHC kernel by new split and reduce kernels to …
anhminhnguyenhoang Jan 27, 2026
9f6aba3
add better config with split reduce usage
anhminhnguyenhoang Jan 27, 2026
833c47a
Apply optim in mhc_fused to split reduce kernels, rename functions fo…
anhminhnguyenhoang Jan 27, 2026
70490ad
Add json config loading
anhminhnguyenhoang Jan 28, 2026
b01c00a
Add tuned JSON configuration files for fused mhc kernels
anhminhnguyenhoang Jan 28, 2026
42c944d
initial implementation of zero-iteration Sinkhorn-Knopp (mHC-Lite). …
waqahmed-amd-fi Jan 28, 2026
190122e
optimized zero-iteration Sinkhorn-Knopp (mHC-Lite)
waqahmed-amd-fi Jan 29, 2026
507d95a
Removed Unused Projection Code (Wrapper Function)
waqahmed-amd-fi Jan 29, 2026
c2ccafc
add config loading bug fix due to caching and better tuned configs
anhminhnguyenhoang Jan 29, 2026
7862631
2D grid parallelization. Key improvements:
waqahmed-amd-fi Jan 29, 2026
cd84075
remove _sinkhorn_knopp_lite, and implement mHC_lite i.e., non-iterati…
waqahmed-amd-fi Jan 29, 2026
3e85b9d
add config loading bug fix due to caching and better tuned configs
anhminhnguyenhoang Jan 29, 2026
5374eb6
add mhc-lite
anhminhnguyenhoang Jan 30, 2026
4f4c272
revised mHC and mHC-Lite description for clarity
waqahmed-amd-fi Jan 30, 2026
57c75ac
update comments and replace if-else with assert check
waqahmed-amd-fi Jan 30, 2026
68c76b0
revised _mhc_lite_fused_split_kernel kernel
waqahmed-amd-fi Jan 30, 2026
604426d
revised _mhc_lite_fused_reduce_kernel
waqahmed-amd-fi Jan 30, 2026
e52ebdc
add mhc-lite bench mode
anhminhnguyenhoang Jan 30, 2026
e70f3dd
integrate mhc-lite into mhc_fused
anhminhnguyenhoang Jan 30, 2026
41b4908
update config loading for mode
anhminhnguyenhoang Jan 30, 2026
17 changes: 13 additions & 4 deletions aiter/ops/triton/fusions/mhc.py
@@ -34,17 +34,18 @@ def mhc(
out_res: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute mHC projection mapping with all three streams (equations 14-18).
Compute mHC projection mapping with all three streams (equations 14-19).

This function implements:
- Eq 14: H̃ = x̃φ (matrix multiplication)
- Eq 15: r = ||x̃||₂ / √(nC) (RMS normalization)
- Eq 16: [H^pre, H^post, H^res] = 1/r [α^pre·H̃^pre, α^post·H̃^post, α^res·H̃^res] + b
- Eq 17: H^pre = σ(H^pre) - sigmoid activation for pre-stream
- Eq 18: H^post = 2σ(H^post) - scaled sigmoid activation for post-stream
- H^res: identity (no activation, ready for Eq 19: Sinkhorn-Knopp)
- Eq 19: H^res = Sinkhorn(H^res) - project residual stream onto doubly stochastic
manifold (identity activation followed by iterative row/column normalization)

All operations are fused in a single Triton kernel for optimal performance.
All operations are fused in optimized Triton kernels for maximum performance.

Args:
x: Input tensor with shape (M, nC) where M is batch/sequence length and
@@ -66,7 +67,7 @@
Tuple of three tensors (H_pre, H_post, H_res):
- H_pre: (M, n) - manifold projection with sigmoid activation (H^{pre} ∈ ℝ^{M×n})
- H_post: (M, n) - post-processing with scaled sigmoid (H^{post} ∈ ℝ^{M×n})
- H_res: (M, n²) - residual connection, identity activation (H^{res} ∈ ℝ^{M×n²})
- H_res: (M, n²) - doubly stochastic residual connection (H^{res} ∈ ℝ^{M×n²})

Shape requirements:
- x: (M, nC) where nC = n * C (flattened streams)
@@ -82,8 +83,11 @@
>>> phi = torch.randn(nC, N_total, dtype=torch.bfloat16, device='cuda')
>>> bias = torch.randn(N_total, dtype=torch.float32, device='cuda')
>>> alpha_pre, alpha_post, alpha_res = 1.0, 1.5, 0.8
>>>
>>> # Full mHC with Sinkhorn-Knopp (Eq 14-19)
>>> H_pre, H_post, H_res = mhc(x, phi, alpha_pre, alpha_post, alpha_res, bias, n)
>>> H_pre.shape, H_post.shape, H_res.shape # (32, 4), (32, 4), (32, 16)
>>> # H_res is doubly stochastic: rows and columns sum to 1
"""
_LOGGER.info(
f"MHC: x={tuple(x.shape)} phi={tuple(phi.shape)} alpha_pre={alpha_pre} alpha_post={alpha_post} alpha_res={alpha_res}"
@@ -170,6 +174,11 @@ def mhc(
BLOCK_K=BLOCK_K,
)

# Apply Sinkhorn-Knopp (Equation 19) to make H_res doubly stochastic
# Reshape H_res from (M, n²) to (M, n, n) for Sinkhorn kernel
H_res_3d = out_res.view(M, n, n)
sinkhorn_knopp(H_res_3d, out=H_res_3d)

return out_pre, out_post, out_res


51 changes: 37 additions & 14 deletions op_tests/triton_tests/fusions/test_mhc.py
@@ -68,11 +68,13 @@ def test_mhc_correctness(M, n, C, dtype):
atol=1e-2,
rtol=1e-2,
)
# Relaxed tolerance for H_res due to Sinkhorn-Knopp iterative algorithm
# which amplifies small numerical differences, especially with bfloat16
torch.testing.assert_close(
H_res_triton.to(torch.float32),
H_res_torch.to(torch.float32),
atol=1e-2,
rtol=1e-2,
atol=5e-2,
rtol=5e-2,
)

Review thread on this change:
Comment: Did you run into test failures because of this for similar tests, such that you needed to relax the tolerance?
Reply: Yes, mainly because Sinkhorn-Knopp is an iterative process and returns larger differences with only 10 iterations. Maybe we can try 20 for better results?


Expand Down Expand Up @@ -111,11 +113,12 @@ def test_mhc_preallocated_output(M, n, C):
atol=1e-2,
rtol=1e-2,
)
# Relaxed tolerance for H_res due to Sinkhorn-Knopp iterative algorithm
torch.testing.assert_close(
out_res.to(torch.float32),
H_res_torch.to(torch.float32),
atol=1e-2,
rtol=1e-2,
atol=5e-2,
rtol=5e-2,
)


@@ -135,11 +138,15 @@ def test_mhc_different_epsilon(eps, M, n, C):
H_pre_triton, H_post_triton, H_res_triton = mhc(x, phi, alpha_pre, alpha_post, alpha_res, bias, n_streams, eps=eps)

for torch_out, triton_out in [(H_pre_torch, H_pre_triton), (H_post_torch, H_post_triton), (H_res_torch, H_res_triton)]:
# Use relaxed tolerance for H_res due to Sinkhorn-Knopp
is_res = torch_out is H_res_torch
atol = 5e-2 if is_res else 1e-2
rtol = 5e-2 if is_res else 1e-2
torch.testing.assert_close(
triton_out.to(torch.float32),
torch_out.to(torch.float32),
atol=1e-2,
rtol=1e-2,
atol=atol,
rtol=rtol,
)


@@ -164,11 +171,15 @@ def test_mhc_different_alpha(alpha_scale):
H_pre_triton, H_post_triton, H_res_triton = mhc(x, phi, alpha_pre, alpha_post, alpha_res, bias, n_streams)

for torch_out, triton_out in [(H_pre_torch, H_pre_triton), (H_post_torch, H_post_triton), (H_res_torch, H_res_triton)]:
# Use relaxed tolerance for H_res due to Sinkhorn-Knopp
is_res = torch_out is H_res_torch
atol = 5e-2 if is_res else 1e-2
rtol = 5e-2 if is_res else 1e-2
torch.testing.assert_close(
triton_out.to(torch.float32),
torch_out.to(torch.float32),
atol=1e-2,
rtol=1e-2,
atol=atol,
rtol=rtol,
)


@@ -193,11 +204,15 @@ def test_mhc_zero_input():
H_pre_triton, H_post_triton, H_res_triton = mhc(x, phi, alpha_pre, alpha_post, alpha_res, bias, n)

for torch_out, triton_out in [(H_pre_torch, H_pre_triton), (H_post_torch, H_post_triton), (H_res_torch, H_res_triton)]:
# Use relaxed tolerance for H_res due to Sinkhorn-Knopp
is_res = torch_out is H_res_torch
atol = 5e-2 if is_res else 1e-2
rtol = 5e-2 if is_res else 1e-2
torch.testing.assert_close(
triton_out.to(torch.float32),
torch_out.to(torch.float32),
atol=1e-2,
rtol=1e-2,
atol=atol,
rtol=rtol,
)


@@ -222,11 +237,15 @@ def test_mhc_large_values():
H_pre_triton, H_post_triton, H_res_triton = mhc(x, phi, alpha_pre, alpha_post, alpha_res, bias, n)

for torch_out, triton_out in [(H_pre_torch, H_pre_triton), (H_post_torch, H_post_triton), (H_res_torch, H_res_triton)]:
# Use even more relaxed tolerance for large values + Sinkhorn-Knopp
is_res = torch_out is H_res_torch
atol = 0.2 if is_res else 0.1
rtol = 0.1 if is_res else 0.05
torch.testing.assert_close(
triton_out.to(torch.float32),
torch_out.to(torch.float32),
atol=0.1,
rtol=0.05,
atol=atol,
rtol=rtol,
)


@@ -247,11 +266,15 @@ def test_mhc_small_shapes(M, n, C, dtype):
H_pre_triton, H_post_triton, H_res_triton = mhc(x, phi, alpha_pre, alpha_post, alpha_res, bias, n_streams)

for torch_out, triton_out in [(H_pre_torch, H_pre_triton), (H_post_torch, H_post_triton), (H_res_torch, H_res_triton)]:
# Use relaxed tolerance for H_res due to Sinkhorn-Knopp
is_res = torch_out is H_res_torch
atol = 5e-2 if is_res else 1e-2
rtol = 5e-2 if is_res else 1e-2
torch.testing.assert_close(
triton_out.to(torch.float32),
torch_out.to(torch.float32),
atol=1e-2,
rtol=1e-2,
atol=atol,
rtol=rtol,
)


17 changes: 9 additions & 8 deletions op_tests/triton_tests/utils/mhc_ref.py
@@ -42,7 +42,7 @@ def mhc_torch(
eps: float = 1e-6,
) -> torch.Tensor:
"""
PyTorch reference implementation of mHC projection mapping (Eq 14-18).
PyTorch reference implementation of mHC projection mapping (Eq 14-19).

This serves as ground truth for validating the Triton kernel implementation.

Expand All @@ -52,7 +52,7 @@ def mhc_torch(
- Eq 16: [H^pre, H^post, H^res] = 1/r [α^pre·H̃^pre, α^post·H̃^post, α^res·H̃^res] + b (scaling)
- Eq 17: H^pre = σ(H^pre) (sigmoid activation for pre-stream)
- Eq 18: H^post = 2σ(H^post) (scaled sigmoid activation for post-stream)
- H^res: identity (no activation, ready for Eq 19: Sinkhorn-Knopp)
- Eq 19: H^res = Sinkhorn(H^res) (project residual stream onto doubly stochastic manifold)

Args:
x: Input x_l with shape (M, nC) - flattened n-stream residual
@@ -52,7 +52,7 @@
Tuple of three tensors (H_pre, H_post, H_res):
- H_pre: (M, n) manifold projection with sigmoid
- H_post: (M, n) post-processing with 2*sigmoid
- H_res: (M, n²) residual connection (identity)
- H_res: (M, n²) doubly stochastic residual connection
"""
x_f32 = x.to(torch.float32)
nC = x.shape[1]

# Eq 15: r = ||x̃||₂ / √(nC)
mean_sq = torch.mean(x_f32 ** 2, dim=-1, keepdim=True)
@@ -83,7 +82,6 @@
H_tilde = x_norm @ phi_f32

# Split into three streams
n_squared = n * n
H_tilde_pre = H_tilde[:, :n] # n coefficients (H^{pre} ∈ ℝ^{1×n})
H_tilde_post = H_tilde[:, n:2*n] # n coefficients (H^{post} ∈ ℝ^{1×n})
H_tilde_res = H_tilde[:, 2*n:] # n² coefficients (H^{res} ∈ ℝ^{n×n})
@@ -108,9 +106,12 @@
# H^post = 2σ(H^post)
H_post = 2.0 * torch.sigmoid(H_post)

# H^res: identity activation (no change)
# Preserves values for subsequent Sinkhorn-Knopp normalization (Eq 19)
# H_res stays as is
# Eq 19: Apply Sinkhorn-Knopp to H^res for doubly stochastic constraint
# Reshape H_res from (M, n²) to (M, n, n) for Sinkhorn algorithm
M = H_res.shape[0]
H_res_3d = H_res.view(M, n, n)
H_res_ds = sinkhorn_knopp_log_domain_torch(H_res_3d)
H_res = H_res_ds.view(M, -1) # Reshape back to (M, n²)

# Return three separate streams
return H_pre.to(x.dtype), H_post.to(x.dtype), H_res.to(x.dtype)