Merged

Commits (55 total; diff shows changes from 47 commits)
25bcbda
Initial work toward restoring UB support in te.Sequential
timmoon10 Feb 28, 2025
49c6a02
Forward UB linear runs, but has numerical error
timmoon10 Feb 28, 2025
4bacf2a
Debug UB forward tests
timmoon10 Mar 1, 2025
bd1d50a
Minor tweaks
timmoon10 Mar 4, 2025
b8b325b
Remove Python checks for MXFP8 UB linear forward
timmoon10 Mar 4, 2025
c8b2c51
Add dim check for MXFP8 full tiles
timmoon10 Mar 6, 2025
9f562b6
Move QuantizedTensor logic out of UB comm and into Python helper func…
timmoon10 Mar 11, 2025
c7a5e65
Support MXFP8 AGs
timmoon10 Mar 12, 2025
0c1a98f
Coalesce NCCL all-gathers for MXFP8 all-gather
timmoon10 Mar 14, 2025
4304ddf
Merge branch 'main' into mxfp8-ub-debug
timmoon10 Mar 14, 2025
15c34ec
Merge branch 'main' into optimize-wgrad-allgather
timmoon10 Mar 15, 2025
0917b20
Initial impl of backward UB linear in te.Sequential
timmoon10 Mar 17, 2025
33a3dbb
Merge branch 'optimize-wgrad-allgather' into mxfp8-ub-debug
timmoon10 Mar 17, 2025
a86cbb9
Debug UB linear backward with no quantization
timmoon10 Mar 17, 2025
3ff2955
Fix chunk dims for dgrad GEMM + dx RS
timmoon10 Mar 17, 2025
0942ee3
Debugging MXFP8 UB cases
timmoon10 Mar 18, 2025
df59ba0
Use NCCL to overlap dy AG with dgrad GEMM
timmoon10 Mar 26, 2025
1047a09
Merge branch 'main' into mxfp8-ub-debug
timmoon10 Apr 17, 2025
ffa65bf
Debug UB GEMM tests
timmoon10 Apr 17, 2025
8531147
Initial refactoring of linear module forward
timmoon10 Apr 22, 2025
0767cea
Refactor linear module backward
timmoon10 Apr 22, 2025
1a03c8f
Debug linear module UB tests
timmoon10 Apr 23, 2025
0639166
Tweak test tensor dims
timmoon10 Apr 23, 2025
6f855aa
Merge branch 'main' into mxfp8-ub-debug
timmoon10 Apr 23, 2025
d4ac2ea
Do not store autograd context within wgrad GEMM closure
timmoon10 Apr 23, 2025
7803cb0
Fix linter warnings
timmoon10 Apr 24, 2025
a8f1ada
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2025
46675e5
Update LayerNormLinear
timmoon10 Apr 24, 2025
7119346
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2025
7df9b74
Update LayerNormMLP
timmoon10 Apr 25, 2025
a85a79e
Debug UB tests
timmoon10 Apr 25, 2025
13fefeb
Merge branch 'main' into mxfp8-ub-debug
timmoon10 Apr 25, 2025
6f7da09
Fix linter warnings
timmoon10 Apr 25, 2025
9e16039
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2025
f717755
Debug test failures
timmoon10 Apr 26, 2025
f4544d9
Minor style tweaks
timmoon10 Apr 26, 2025
776fbe5
Merge branch 'main' into mxfp8-ub-debug
timmoon10 Apr 26, 2025
8c825f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2025
f783b95
Fix incorrect usage for GEMM input with block-scaled FP8
timmoon10 Apr 28, 2025
9288e6a
Merge branch 'main' into mxfp8-ub-debug
timmoon10 Apr 28, 2025
dfb53ca
Merge branch 'main' into mxfp8-ub-debug
timmoon10 May 1, 2025
546e02a
Fix RS out dims
timmoon10 May 1, 2025
8e63f8c
Disable dgrad GEMM + UB AG + NCCL AG overlapping
timmoon10 May 2, 2025
ab04e50
Merge branch 'main' into mxfp8-ub-debug
timmoon10 May 2, 2025
9794f91
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2025
418aab2
Disable dgrad GEMM + UB AG + NCCL AG overlap in te.Sequential
timmoon10 May 2, 2025
da753ea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2025
fe02a7e
Merge branch 'main' into mxfp8-ub-debug
timmoon10 May 6, 2025
19459ec
Restore support for internal quantized tensors
timmoon10 May 6, 2025
0efa1a9
Add tests for MXFP8 GEMM with UB
timmoon10 May 7, 2025
fb5ca6e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2025
70057c6
Fix linter warnings
timmoon10 May 7, 2025
7153804
Debug test failures
timmoon10 May 7, 2025
d1fc045
Debug test failures
timmoon10 May 7, 2025
6994f29
Merge branch 'main' into mxfp8-ub-debug
timmoon10 May 7, 2025
qa/L1_pytorch_distributed_unittest/test.sh (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PAT
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"

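Note: the hunk above re-enables the userbuffers test that had been commented out pending UB support in te.Sequential. Outside the QA script, the same test can be run directly (JUnit XML flag omitted; assumes TE_PATH points at the repository root):

    python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py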
tests/pytorch/distributed/run_gemm_with_overlap.py (43 changes: 30 additions & 13 deletions)
@@ -20,7 +20,10 @@
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes
from transformer_engine.pytorch.module.base import (
fill_userbuffers_buffer_for_all_gather,
get_cublas_workspace_size_bytes,
)

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
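The newly imported fill_userbuffers_buffer_for_all_gather helper replaces the copy_into_buffer/get_buffer pair for all-gather overlap in the hunks below. A minimal usage sketch, with argument roles inferred from those call sites (the second return value is discarded by this test, so it is left unnamed):

    # Sketch only; mirrors the call sites below rather than documenting the helper's full API.
    gemm_inp, _ = fill_userbuffers_buffer_for_all_gather(
        ub_obj,                        # comm+GEMM overlap object
        inp_fp8 if opts.fp8 else inp,  # local input shard, possibly already quantized
        inp_quantizer,                 # quantizer, or None for high-precision inputs
        tp_group,                      # tensor-parallel process group
    )
    # gemm_inp is usable directly as the GEMM input, replacing the old
    # ub_obj.get_buffer(inp_quantizer, False, inp_g.size()) lookup.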
@@ -544,26 +547,30 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
rs_out2 = None
if opts.comm_type == tex.CommOverlapType.AG:
if opts.bulk_overlap:
ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True)
fill_userbuffers_buffer_for_all_gather(
ub_obj,
bulk_inp,
bulk_inp_quantizer,
tp_group,
)
gemm_inp = inp
else:
ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True)
gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size())
gemm_inp, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
inp_fp8 if opts.fp8 else inp,
inp_quantizer,
tp_group,
)
if ub_obj2 is not None:
if opts.fp8 and opts.fp8_output:
ub_obj2.set_buffer_params(out_quantizer)
rs_out2 = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
)
else:
if opts.bulk_overlap:
ub_obj.copy_into_buffer(
bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False
bulk_inp_fp8._data if opts.fp8 else bulk_inp,
local_chunk=False,
)
if opts.fp8:
ub_obj.set_buffer_params(bulk_inp_quantizer)
elif opts.fp8 and opts.fp8_output:
ub_obj.set_buffer_params(out_quantizer)
gemm_inp = inp_fp8 if opts.fp8 else inp
rs_out = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
@@ -688,10 +695,20 @@ def _gemm():
output_info = ""
if opts.comm_type == tex.CommOverlapType.AG:
# Bulk overlap AG output is already gathered
test_out = ub_obj.get_buffer(bulk_inp_quantizer, False)
if bulk_inp_quantizer is None:
test_out = ub_obj.get_buffer(False)
else:
test_out = Float8Tensor(
shape=test_out.shape,
dtype=torch.bfloat16,
data=ub_obj.get_buffer(False),
fp8_scale=bulk_inp_quantizer.scale,
fp8_dtype=bulk_inp_quantizer.dtype,
quantizer=bulk_inp_quantizer,
)
else:
# Bulk overlap RS output needs to be gathered
out_local = ub_obj.get_buffer(bulk_inp_quantizer, True)
out_local = ub_obj.get_buffer(True)
output_info += f"rs_output: {list(out_local.shape)} | "
test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0]

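In the hunk above, get_buffer no longer takes a quantizer; it only selects between the full buffer (False) and the local chunk (True), and the caller re-wraps the raw storage when the data is quantized. A hedged sketch of that pattern for the bulk all-gather check, using the buffer's own shape and assuming Float8Tensor.dequantize() for the final comparison:

    raw = ub_obj.get_buffer(False)  # gathered Userbuffers workspace
    if bulk_inp_quantizer is None:
        test_out = raw
    else:
        # Re-wrap the raw FP8 bytes with the quantizer's scaling metadata.
        test_out = Float8Tensor(
            shape=raw.shape,
            dtype=torch.bfloat16,
            data=raw,
            fp8_scale=bulk_inp_quantizer.scale,
            fp8_dtype=bulk_inp_quantizer.dtype,
            quantizer=bulk_inp_quantizer,
        )
    # e.g. torch.testing.assert_close(test_out.dequantize(), ref_out, rtol=..., atol=...)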
tests/pytorch/distributed/run_layer_with_overlap.py (11 changes: 9 additions & 2 deletions)
@@ -17,7 +17,12 @@
import torch.distributed as dist

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling
from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Format,
MXFP8BlockScaling,
)

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
@@ -163,7 +168,7 @@ def _parse_args(argv=None, namespace=None):
"--quantization",
type=str.lower,
default="none",
choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"],
choices=["none", "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
help="Quantization recipe",
)
parser.add_argument(
@@ -414,6 +419,8 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
)
elif opts.quantization == "fp8_current_scaling":
fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
elif opts.quantization == "mxfp8":
fp8_recipe = MXFP8BlockScaling()

# Prepare random input tensors
test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
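The driver script now accepts --quantization=mxfp8 and builds an MXFP8BlockScaling recipe alongside the existing delayed- and current-scaling ones. A minimal standalone sketch of exercising a layer under that recipe (tensor sizes are placeholders; MXFP8 also needs hardware support, which the test suite gates via FP8GlobalStateManager.is_mxfp8_available()):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import MXFP8BlockScaling

    recipe = MXFP8BlockScaling()  # default recipe options assumed
    layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)  # placeholder dims
    x = torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = layer(x)
    y.sum().backward()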
tests/pytorch/distributed/test_comm_gemm_overlap.py (57 changes: 22 additions & 35 deletions)
@@ -15,11 +15,12 @@
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

RNG_SEED: int = 42
SEQ_LENGTH: int = 1024
BATCH_SIZE: int = 2
NUM_HEADS: int = 16
NUM_HEADS: int = 32
HEAD_DIM: int = 48
TE_LAYERS = [
te.Linear,
@@ -107,8 +108,10 @@ def _run_layer_with_overlap(
test_cmd.append("--overlap-rs-dgrad")

if fp8:
if not fp8_available:
if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
test_cmd.append("--fp8")
test_cmd.append(f"--quantization={quantization}")

@@ -251,15 +254,7 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d

@pytest.mark.parametrize(
"quantization",
["fp8_delayed_scaling", "fp8_current_scaling"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
@@ -279,28 +274,31 @@
)
),
ids=[
f" {te.Linear.__name__} - ROW-PARALLEL ",
f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",
f"{te.Linear.__name__}-row_tensor_parallel",
f"{te.Linear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
f"{te.Linear.__name__}-col_tensor_parallel-DGRAD+RS",
f"{te.LayerNormLinear.__name__}-row_tensor_parallel",
f"{te.LayerNormLinear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
f"{te.LayerNormLinear.__name__}-col_tensor_parallel-DGRAD+RS",
]
+ [
" " + " - ".join(test_name_parts) + " "
"-".join(test_name_parts)
for test_name_parts in zip(
[layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]),
)
],
)
def test_layers_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization
layer_type,
linear_parallel_mode,
overlap_rs_dgrad,
quantization,
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization)
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization)


@pytest.mark.parametrize(
@@ -347,22 +345,11 @@ def test_multi_layer_with_overlap_bf16(

@pytest.mark.parametrize(
"quantization",
["fp8_delayed_scaling", "fp8_current_scaling"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
)
@pytest.mark.parametrize(
"num_layers",
(2,),
ids=[
" 2 layers ",
],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
@@ -374,19 +361,19 @@ def test_multi_layer_with_overlap_bf16(
)
),
ids=[
" " + " - ".join(test_name_parts) + " "
"-".join(test_name_parts)
for test_name_parts in zip(
[te.TransformerLayer.__name__ for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"],
)
],
)
def test_multi_layer_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers
)
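These tests skip unless at least two GPUs are available. To run only the new MXFP8 parametrizations of this file, a standard pytest keyword filter works, e.g.:

    python3 -m pytest -v -s tests/pytorch/distributed/test_comm_gemm_overlap.py -k mxfp8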