[JAX] Add collective GEMM without compute/communication overlap #1675

Open · philipphack wants to merge 3 commits into main

Conversation

philipphack

Description

Rebase of #1307:

Implements the XLA custom calls in C++ and the corresponding JAX primitive, including custom partitioning rules.

Custom partitioning rules for a LHS:([B,] M, K) x RHS:([B,] K, N) = OUT:([B,] M, N) batched mat-mul operation, where [B] is the batch dimension (a sketch of the resulting output collectives follows below):

  • Preserve the partitioning of the [B] dimension for all operands.
  • Always all-gather LHS along the M dimension.
  • Error out if RHS is partitioned in both the K and N dimensions.
  • Force the K dimension of LHS to match the partitioning of the K dimension of RHS.
  • If the K dimension is partitioned but the M dimension is not, jax.lax.psum (all-reduce) the output over the TP mesh resource.
  • If both the M and K dimensions are partitioned, jax.lax.psum_scatter (reduce-scatter) the output over the TP mesh resource.
In practice, the RHS matrix (typically the weight tensor) should be allocated with transposed contracting dimensions ([B,] N, K) for optimal GEMM heuristics in cuBLASLt. This layout is also mandatory for FP8 inputs.
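
A hedged sketch of the output collectives described by the last two rules, written with jax.experimental.shard_map rather than the PR's custom-partitioning machinery; the mesh axis name "tp", the two-device mesh, and the helper names (local_gemm, scatter_output) are assumptions made for this example, which needs at least two devices to run.

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Assumed 1-D tensor-parallel mesh over two devices.
mesh = Mesh(jax.devices()[:2], axis_names=("tp",))

def local_gemm(lhs_block, rhs_block, scatter_output=False):
    # Each device holds a K-slice of both operands, so the local product
    # is a partial sum that must be combined across the "tp" axis.
    partial = jnp.dot(lhs_block, rhs_block)
    if scatter_output:
        # M and K both partitioned -> reduce-scatter the output along M
        # (out_specs would then be P("tp", None)).
        return jax.lax.psum_scatter(partial, "tp", scatter_dimension=0, tiled=True)
    # Only K partitioned -> all-reduce (psum) the output.
    return jax.lax.psum(partial, "tp")

lhs = jnp.ones((8, 4))    # (M, K)
rhs = jnp.ones((4, 16))   # (K, N)

out = shard_map(
    local_gemm,
    mesh=mesh,
    in_specs=(P(None, "tp"), P("tp", None)),  # shard K on both operands
    out_specs=P(),                            # replicated output after psum
)(lhs, rhs)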

This PR does NOT update the fused ops or Flax/Praxis modules to use the new GEMM custom op in place of the existing XLA pattern-matching approach.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added JAX primitive for the XLA custom call.
  • Added serial unit test.
  • Added distributed unit test.

ptrendx requested a review from denera on April 15, 2025.
denera (Collaborator) left a comment:

LGTM, pending very minor docstring fix.

use_split_accumulator,
):
"""
Fused attention fwd lowering rules
Collaborator:

Looks like leftover incorrect docstring from the copied primitive template.

Comment on lines +37 to +44
def _jax_cast_fp8(inputs, scale, amax, out_dtype):
"""
JAX native fp8 casting implementation
"""
casted_output = _jax_quantize(inputs, scale, dq_dtype=out_dtype)
updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype))
return casted_output, updated_amax

Collaborator:

Suggested change
def _jax_cast_fp8(inputs, scale, amax, out_dtype):
"""
JAX native fp8 casting implementation
"""
casted_output = _jax_quantize(inputs, scale, dq_dtype=out_dtype)
updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype))
return casted_output, updated_amax

Please use _jax_quantize() instead.
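
A hedged sketch of the substitution at a call site, assuming _jax_quantize keeps the signature used in the helper above; x, scale, amax, and fp8_dtype are placeholder names for this example, and the explicit amax update is copied from the removed helper for callers that still need the delayed-scaling bookkeeping.

# Hypothetical call site before: y, amax = _jax_cast_fp8(x, scale, amax, fp8_dtype)
# After: call the shared helper directly and keep the amax update inline.
y = _jax_quantize(x, scale, dq_dtype=fp8_dtype)
amax = jax.lax.max(amax, jnp.max(jnp.abs(x)).astype(amax.dtype))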

_load_library()
if module_name not in sys.modules:
_load_library()

Collaborator:

Hi,
Any reasons for these changes?

@@ -101,7 +103,6 @@ def _load_library():
)

__all__ = [
"fp8_autocast",
Collaborator:

I think we do need to export fp8_autocast.

Comment on lines +1196 to +1201
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
lhs_scale_inv = lhs.scale_inv.reshape(-1)
rhs_scale_inv = rhs.scale_inv.reshape(-1)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
lhs_scale_inv = lhs_scale_inv.reshape(-1)
rhs_scale_inv = rhs_scale_inv.reshape(-1)
Collaborator:

Hi,

  1. Why do we need to reshape the scale_inv for DelayedScaling?
  2. For MXFP8, don't we need to call swizzle_scale?

Result_Type out_amax_updated, Result_Type out_scale_updated,
Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace,
bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad,
bool accumulate, bool use_split_accumulator);
Collaborator:

Don't you need to bind the scaling_mode?

auto workspace_ = TensorWrapper(workspace, std::vector<size_t>{workspace_size}, DType::kByte);

// cuBLAS is column-major, so we swap LHS and RHS in the arguments
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
Collaborator:

I think cuda::sm_count() involves initializing a new cudaDeviceProp which may break cudaGraph.
Please query the sm_count via the handler instead, as:

auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin;

@@ -59,6 +59,7 @@ pybind11::dict Registrations() {
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));

dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler);
Collaborator:

Please add the prepare phase as in the te_grouped_gemm_ffi.
