[JAX] grouped_gemm() uses variadic arguments #1658

Merged
merged 14 commits into from
Apr 14, 2025

Conversation

huanghua1994
Collaborator

@huanghua1994 huanghua1994 commented Apr 8, 2025

Description

This PR optimizes the grouped_gemm() implementation in JAX. The original implementation manually flattens all input matrices before lowering to the C++ function and then manually splits the output into a list of tensors. Passing the matrices as variadic arguments instead lets the code avoid these extra copies of the inputs and outputs.

This PR was initially marked as a draft because PR #1545 broke both the original implementation and this implementation for MXFP8 on Blackwell. That breakage has been fixed by PR #1652.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Rewrite grouped_gemm() and GroupedGemmPrimitive using variadic arguments.
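
For readers unfamiliar with the two calling conventions, here is a minimal sketch of the difference, written in plain `jax.numpy`. It is not the code in this PR: the function names `grouped_gemm_flattened` and `grouped_gemm_variadic`, the interleaved operand order (`lhs_0, rhs_0, lhs_1, rhs_1, ...`), and the use of `@` in place of the C++ custom call are all assumptions made for illustration.

```python
import jax.numpy as jnp


def grouped_gemm_flattened(lhs_list, rhs_list):
    """Hypothetical 'flattened' style: pack all operands into flat buffers,
    compute, then split the flat output back into per-group tensors."""
    shapes = [(a.shape, b.shape) for a, b in zip(lhs_list, rhs_list)]
    # Packing copies every input element into a new contiguous buffer.
    flat_lhs = jnp.concatenate([a.ravel() for a in lhs_list])
    flat_rhs = jnp.concatenate([b.ravel() for b in rhs_list])

    outs = []
    lhs_off = rhs_off = 0
    for (m, k), (k2, n) in shapes:
        a = flat_lhs[lhs_off:lhs_off + m * k].reshape(m, k)
        b = flat_rhs[rhs_off:rhs_off + k2 * n].reshape(k2, n)
        outs.append((a @ b).ravel())
        lhs_off += m * k
        rhs_off += k2 * n

    # The single flat output must be split again, which is another copy.
    flat_out = jnp.concatenate(outs)
    result, off = [], 0
    for (m, _), (_, n) in shapes:
        result.append(flat_out[off:off + m * n].reshape(m, n))
        off += m * n
    return result


def grouped_gemm_variadic(*operands):
    """Hypothetical 'variadic' style: each matrix is its own argument and each
    output its own result, so no packing or splitting is needed.
    Assumed operand order: lhs_0, rhs_0, lhs_1, rhs_1, ..."""
    if len(operands) % 2 != 0:
        raise ValueError("expected an even number of operands (lhs/rhs pairs)")
    return [operands[i] @ operands[i + 1] for i in range(0, len(operands), 2)]


# Two groups with different shapes.
lhs = [jnp.ones((2, 3)), jnp.arange(20.0).reshape(4, 5)]
rhs = [jnp.ones((3, 4)), jnp.ones((5, 2))]
out_flat = grouped_gemm_flattened(lhs, rhs)
out_var = grouped_gemm_variadic(lhs[0], rhs[0], lhs[1], rhs[1])
```

Both functions return the same list of per-group products. The variadic form simply skips the `concatenate` and slice steps, which is where the extra copies in the flattened form come from.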

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@huanghua1994 huanghua1994 requested a review from phu0ngng April 8, 2025 20:08
@huanghua1994 huanghua1994 self-assigned this Apr 8, 2025
@huanghua1994 huanghua1994 force-pushed the JAX-GroupedGEMM-VariadicArgs branch from d9b9aef to 600029a Compare April 8, 2025 21:28
@huanghua1994 huanghua1994 changed the title [Draft][JAX] grouped_gemm() uses variadic arguments [JAX] grouped_gemm() uses variadic arguments Apr 8, 2025
@huanghua1994 huanghua1994 force-pushed the JAX-GroupedGEMM-VariadicArgs branch from 600029a to 3af8804 Compare April 10, 2025 16:53
@phu0ngng
Collaborator

/te-ci jax L0

@huanghua1994 huanghua1994 force-pushed the JAX-GroupedGEMM-VariadicArgs branch from 5d581ee to 292f7a9 Compare April 10, 2025 22:41
@huanghua1994 huanghua1994 force-pushed the JAX-GroupedGEMM-VariadicArgs branch from 292f7a9 to c850c00 Compare April 11, 2025 16:27
@phu0ngng
Collaborator

/te-ci jax L0

@phu0ngng
Collaborator

Pipeline #26823920 passed.
Ready to merge.

@phu0ngng phu0ngng merged commit 98b4c0d into NVIDIA:main Apr 14, 2025
12 checks passed
Labels
None yet
Projects
None yet
2 participants