[JAX] Add collective GEMM without compute/communication overlap #1675
base: main
Conversation
Signed-off-by: Philipp Hack <[email protected]>
LGTM, pending very minor docstring fix.
    use_split_accumulator,
):
    """
    Fused attention fwd lowering rules
Looks like a leftover incorrect docstring from the copied primitive template.
def _jax_cast_fp8(inputs, scale, amax, out_dtype):
    """
    JAX native fp8 casting implementation
    """
    casted_output = _jax_quantize(inputs, scale, dq_dtype=out_dtype)
    updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype))
    return casted_output, updated_amax
Please use _jax_quantize() instead.
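A minimal sketch of what the call site could look like after this suggestion (not code from the PR; it reuses the names inputs, scale, amax, and out_dtype from the snippet above and assumes _jax_quantize is in scope):

import jax
import jax.numpy as jnp

# Hedged sketch: call the existing _jax_quantize() helper directly and keep the
# amax update at the call site, instead of going through a _jax_cast_fp8() wrapper.
casted_output = _jax_quantize(inputs, scale, dq_dtype=out_dtype)
updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype))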
_load_library()
if module_name not in sys.modules:
    _load_library()
Hi,
Any reasons for these changes?
@@ -101,7 +103,6 @@ def _load_library():
)

__all__ = [
    "fp8_autocast",
I think we do need to export fp8_autocast.
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
    lhs_scale_inv = lhs.scale_inv.reshape(-1)
    rhs_scale_inv = rhs.scale_inv.reshape(-1)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
    lhs_scale_inv = lhs_scale_inv.reshape(-1)
    rhs_scale_inv = rhs_scale_inv.reshape(-1)
Hi,
- Why do we need to reshape the scale_inv for DelayedScaling?
- For MXFP8, don't we need to call swizzle_scale?
Result_Type out_amax_updated, Result_Type out_scale_updated,
Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace,
bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad,
bool accumulate, bool use_split_accumulator);
Don't you need to bind the scaling_mode?
auto workspace_ = TensorWrapper(workspace, std::vector<size_t>{workspace_size}, DType::kByte);

// cuBLAS is column-major, so we swap LHS and RHS in the arguments
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
I think cuda::sm_count() involves initializing a new cudaDeviceProp, which may break cudaGraph. Please query the sm_count via the handler instead, as:
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin;
@@ -59,6 +59,7 @@ pybind11::dict Registrations() {
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
                     pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));

  dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler);
Please add the prepare phase as in te_grouped_gemm_ffi.
Description
Rebase of #1307:
Implements the XLA custom calls in C++ and the corresponding JAX primitive, including custom partitioning rules.
Custom partitioning rules for an LHS:([B,] M, K) x RHS:([B,] K, N) = OUT:([B,] M, N) batched mat-mul operation, where [B] is the batch dimension (a short sketch of the resulting output collectives follows the list):
Preserve the partitioning of the [B] dimension for all operands.
Always all-gather LHS along the M dimension.
Error out if RHS is partitioned in both K and N dimensions.
Force the K dimension of LHS to match the partitioning of the K dimension of RHS.
If the K dimension is partitioned but the M dimension is not, jax.lax.psum (all-reduce) the output over the TP mesh resource.
If both the M and K dimensions are partitioned, jax.lax.psum_scatter (reduce-scatter) the output over the TP mesh resource.
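A minimal JAX sketch of the last two rules (not the PR's actual partitioning implementation; out, k_is_partitioned, m_is_partitioned, and tp_resource are assumed names, and the helper must run inside shard_map/pmap with tp_resource as a mesh axis):

import jax

def _finalize_output(out, k_is_partitioned, m_is_partitioned, tp_resource):
    # Hypothetical helper illustrating the output collectives described above.
    if k_is_partitioned and m_is_partitioned:
        # Both M and K sharded: reduce-scatter the partial sums along the TP axis.
        return jax.lax.psum_scatter(out, axis_name=tp_resource, tiled=True)
    if k_is_partitioned:
        # K sharded, M replicated: all-reduce the partial sums along the TP axis.
        return jax.lax.psum(out, axis_name=tp_resource)
    return out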
In practice, the RHS matrix (typically the weight tensor) should be allocated with transposed contracting dimensions ([B,] N, K) for optimal GEMM heuristics in cuBLASLt. This layout is also mandatory for FP8 inputs.
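For reference, a small standalone example (shapes and names invented here, not taken from the PR) of contracting against a weight stored in the transposed (N, K) layout:

import jax
import jax.numpy as jnp

M, K, N = 128, 256, 512
lhs = jnp.zeros((M, K), dtype=jnp.bfloat16)  # activations: (M, K)
rhs = jnp.zeros((N, K), dtype=jnp.bfloat16)  # weight allocated as (N, K)

# Contract the K dimension of both operands (LHS dim 1 with RHS dim 1); no batch dims.
out = jax.lax.dot_general(lhs, rhs, dimension_numbers=(((1,), (1,)), ((), ())))
assert out.shape == (M, N)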
This PR does NOT update fused ops or Flax/Praxis modules to use the new GEMM custom op over the existing XLA pattern matching approach.
Changes
Please list the changes introduced in this PR:
Added JAX primitive for the XLA custom call.
Added serial unit test.
Added distributed unit test.