Add logic for block-scaled tensors with GEMM swizzled scales #2486
Signed-off-by: Tim Moon <[email protected]>
d274220 to 52ce3a4
for more information, see https://pre-commit.ci
4925b63 to 1de4b5e
/te-ci L1
}
}

void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const void *buf,
Why do we need a v2 here?
I don't want to break the existing APIs. That said, this PR isn't fully backward-compatible because the GEMM no longer secretly assumes that MXFP8 scales are swizzled.
Update copyright years. Tweak comments. Fix various complaints from @greptile-apps. Signed-off-by: Tim Moon <[email protected]>
/te-ci L1
Description
All of the supported block-scaled tensor formats (MXFP8, NVFP4, DSv3 FP8) have two ways of ordering their scaling factors: a compact order and a swizzled order that the GEMM expects.
The core infrastructure handles this in an ad hoc way, blindly assuming that the "right" scale ordering is used for the different operations. The PyTorch infrastructure only supports MXFP8 and NVFP4 scales in compact order, although DSv3 FP8 does have awareness of "compact" and "GEMM-ready" formats. This situation makes it hard to implement fused kernels that can bypass the swizzle kernel.
This PR adds a `with_gemm_swizzled_scales` field in the C++ tensor class so that the core infrastructure can distinguish between the different scale orderings. It also adds this field in the PyTorch quantized tensor classes, and exposes an `optimize_for_gemm` option in the quantizer so that we can create tensors that do not need communication or checkpointing. Finally, it rips out all the DSv3 FP8 infrastructure for the compact format, which is no longer necessary.
Progress
Add option to pre-swizzle weights
Closes #2446.
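To make the idea in the description concrete, here is a minimal toy sketch in plain Python of a tensor that records its own scale ordering, so the GEMM path can check a flag instead of assuming. Everything in it is hypothetical: `BlockScaledTensor`, `toy_swizzle`, `quantize`, and `gemm_ready_scales` are illustration-only names, and the permutation is a placeholder rather than the real GEMM swizzle layout. Only the field name `with_gemm_swizzled_scales` and the option name `optimize_for_gemm` come from this PR; their real signatures and behavior may differ.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class BlockScaledTensor:
    """Toy stand-in for a block-scaled tensor (hypothetical, not the real class)."""

    data: List[float]
    scales: List[float]
    # Records which ordering `scales` is stored in.
    with_gemm_swizzled_scales: bool = False


def toy_swizzle(scales: List[float]) -> List[float]:
    """Placeholder permutation; the real GEMM swizzle layout is different."""
    return scales[::2] + scales[1::2]


def quantize(values: List[float], block_size: int = 2,
             optimize_for_gemm: bool = False) -> BlockScaledTensor:
    """Compute one scale per block; optionally store the scales pre-swizzled."""
    scales = [
        max(abs(v) for v in values[i:i + block_size]) or 1.0
        for i in range(0, len(values), block_size)
    ]
    data = [v / scales[i // block_size] for i, v in enumerate(values)]
    if optimize_for_gemm:
        scales = toy_swizzle(scales)
    return BlockScaledTensor(data, scales, with_gemm_swizzled_scales=optimize_for_gemm)


def gemm_ready_scales(tensor: BlockScaledTensor) -> List[float]:
    """Swizzle only when the tensor says its scales are still in compact order."""
    if tensor.with_gemm_swizzled_scales:
        return tensor.scales
    return toy_swizzle(tensor.scales)


values = [1.0, -2.0, 4.0, 0.5, 8.0, 3.0, 0.25, -16.0]
compact = quantize(values)
pre_swizzled = quantize(values, optimize_for_gemm=True)
```

In this sketch both tensors yield the same GEMM-ready scales, but only the compact one pays for a swizzle at GEMM time. The explicit flag is what lets a fused kernel emit swizzled scales directly and skip that step.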
Type of change
Changes
Please list the changes introduced in this PR:
`optimize_for_gemm` option in PyTorch quantizer
Checklist: