
Commit a3a9513

Merge branch 'main' into dev/jberchtold/fix-layernorm-distributed-tests
2 parents: 629dba5 + b6b3abc


54 files changed: +1551 −824 lines

benchmarks/linear/benchmark_grouped_linear.py

Lines changed: 1 addition & 1 deletion

@@ -247,7 +247,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
     num_gemms_list = [8]
 
     if args.profile:
-        mkns = [(4096, 4096, 4096)]
+        mkns = [(4096 * 8, 4096, 4096)]
         # in profile mode, only run one recipe specified in args.recipe
         assert args.recipe != "all", (
             "In profile mode, only one recipe can be specified, please specify the recipe as"

qa/L0_pytorch_debug_unittest/test.sh

Lines changed: 8 additions & 0 deletions

@@ -14,11 +14,19 @@
 
 FAIL=0
 
+# It is not installed as a requirement,
+# because it is not available on PyPI.
+pip uninstall -y nvdlfw-inspect
+pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
+
 pip install pytest==8.2.1
 pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
 pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
 pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+
 
 # standard sanity and numerics tests with initialized debug
 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1

qa/L1_pytorch_distributed_unittest/test.sh

Lines changed: 5 additions & 0 deletions

@@ -21,6 +21,11 @@ FAILED_CASES=""
 mkdir -p "$XML_LOG_DIR"
 
 
+# It is not installed as a requirement,
+# because it is not available on PyPI.
+pip uninstall -y nvdlfw-inspect
+pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
+
 pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"

tests/jax/conftest.py

Lines changed: 53 additions & 0 deletions

@@ -5,6 +5,8 @@
 import os
 import jax
 import pytest
+from collections import defaultdict
+import time
 
 
 import transformer_engine.jax

@@ -32,3 +34,54 @@ def enable_fused_attn_after_hopper():
     yield
     if "NVTE_FUSED_ATTN" in os.environ:
         del os.environ["NVTE_FUSED_ATTN"]
+
+
+class TestTimingPlugin:
+    """
+    Plugin to measure test execution time. Enable test timing by setting NVTE_JAX_TEST_TIMING=1
+    in the environment.
+    """
+
+    def __init__(self):
+        self.test_timings = defaultdict(list)
+
+    @pytest.hookimpl(tryfirst=True)
+    def pytest_runtest_setup(self, item):
+        item._timing_start = time.time()
+
+    @pytest.hookimpl(trylast=True)
+    def pytest_runtest_teardown(self, item, nextitem):
+        if hasattr(item, "_timing_start"):
+            duration = time.time() - item._timing_start
+
+            # Extract base function name without parameters
+            test_name = item.name
+            if "[" in test_name:
+                base_name = test_name.split("[")[0]
+            else:
+                base_name = test_name
+
+            self.test_timings[base_name].append(duration)
+
+    def pytest_sessionfinish(self, session, exitstatus):
+        print("\n" + "=" * 80)
+        print("TEST RUNTIME SUMMARY (grouped by function)")
+        print("=" * 80)
+
+        total_overall = 0
+        for test_name, durations in sorted(self.test_timings.items()):
+            total_time = sum(durations)
+            count = len(durations)
+            avg_time = total_time / count if count > 0 else 0
+            total_overall += total_time
+
+            print(f"{test_name:<60} | {count:3}x | {total_time:7.2f}s | avg: {avg_time:6.2f}s")
+
+        print("=" * 80)
+        print(f"{'TOTAL RUNTIME':<60} | {'':>3} | {total_overall:7.2f}s |")
+        print("=" * 80)
+
+
+def pytest_configure(config):
+    if os.getenv("NVTE_JAX_TEST_TIMING", "0") == "1":
+        config.pluginmanager.register(TestTimingPlugin(), "test_timing")
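
For reference, a minimal sketch of driving the new timing summary from a local Python session; the test path below is only illustrative. The plugin registers itself only when NVTE_JAX_TEST_TIMING=1 is in the environment before pytest configures.

    # Hypothetical local run; the target path is illustrative.
    import os
    import pytest

    os.environ["NVTE_JAX_TEST_TIMING"] = "1"  # picked up by pytest_configure above
    pytest.main(["-v", "tests/jax/test_custom_call_compute.py"])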

tests/jax/test_custom_call_compute.py

Lines changed: 13 additions & 24 deletions

@@ -673,10 +673,6 @@ def test_grouped_qdq(
             n_groups=n_groups,
         )
 
-        # grouped_quantize does not work with cudaGraph yet, so the jitting will breaks
-        # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
-        # disable cudaGraph, then use the following jitted function
-
         scaled_tensor = tex.grouped_quantize(
             x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
         )

@@ -1312,16 +1308,14 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
         )
         ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
 
-        # grouped_gemm does not work with cudaGraph yet, so the jitting will breaks
-        # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
-        # disable cudaGraph, then use the following jitted function
-
         # jitting grouped_gemm
-        # prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
-        #     lhs, rhs, group_sizes, contracting_dims,
-        # )
+        prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
+            lhs,
+            rhs,
+            group_sizes,
+            contracting_dims,
+        )
 
-        prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
         self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
 
     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)

@@ -1350,12 +1344,7 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout
         )
         ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
 
-        # jitting grouped_gemm
-        # prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))(
-        #     lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
-        # )
-
-        prim_out = tex.grouped_gemm(
+        prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
             lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
         )
 

@@ -1391,9 +1380,9 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape):
 
         value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
         # jitting the grouped_dense
-        # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
-        #     static_argnums=(4,))
-        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
+        value_n_grad_prim_func = jit(
+            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
+        )
 
         ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
             x, kernel, bias, group_sizes, contracting_dims

@@ -1432,9 +1421,9 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
         value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
 
         # jitting the grouped_dense
-        # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
-        #     static_argnums=(4,))
-        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
+        value_n_grad_prim_func = jit(
+            value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
+        )
 
         ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
             x,
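
As background for the re-enabled jitted calls above, here is a minimal, self-contained sketch of the same jax.jit pattern; a generic dot_general stands in for tex.grouped_gemm, and the shapes are made up. The point is that contracting_dims changes the traced computation, so it is passed via static_argnames (or static_argnums in the positional variant) rather than traced.

    # Sketch only: generic contraction in place of tex.grouped_gemm.
    import jax
    import jax.numpy as jnp

    def contract(lhs, rhs, contracting_dims=((1,), (0,))):
        # dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
        return jax.lax.dot_general(lhs, rhs, (contracting_dims, ((), ())))

    jitted = jax.jit(contract, static_argnames=("contracting_dims",))
    out = jitted(jnp.ones((8, 16)), jnp.ones((16, 4)), contracting_dims=((1,), (0,)))
    assert out.shape == (8, 4)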

tests/jax/test_distributed_layernorm_mlp.py

Lines changed: 0 additions & 2 deletions

@@ -333,7 +333,6 @@ def _test_layernorm_mlp(
     with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
         ln_mlp_single = LayerNormMLP(
             layernorm_type=layernorm_type,
-            transpose_batch_sequence=False,  # input: [batch, seqlen, hidden]
             intermediate_dim=INTERMEDIATE,
             activations=activation_type,
             use_bias=use_bias,

@@ -352,7 +351,6 @@ def _test_layernorm_mlp(
     ):
         ln_mlp_sharded = LayerNormMLP(
             layernorm_type=layernorm_type,
-            transpose_batch_sequence=False,
             intermediate_dim=INTERMEDIATE,
             activations=activation_type,
             scale_axes=LN_SCALE_AXES,

tests/jax/test_distributed_softmax.py

Lines changed: 4 additions & 4 deletions

@@ -135,7 +135,7 @@ def impl_test_softmax(
         )
 
     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
-    @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
+    @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
     @pytest.mark.parametrize(
         "softmax_type",
         [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],

@@ -168,14 +168,14 @@ def test_softmax(
             dtype,
             bad_sharding,
             broadcast_batch_mask,
-            use_shardy=False,
+            use_shardy=True,
         )
 
     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
     @pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
     @pytest.mark.parametrize("bad_sharding", [False, True])
     @pytest.mark.parametrize("broadcast_batch_mask", [False, True])
-    def test_softmax_shardy(
+    def test_softmax_gspmd(
         self,
         device_count,
         mesh_shape,

@@ -196,5 +196,5 @@ def test_softmax_shardy(
             dtype=DTYPES[0],
             bad_sharding=bad_sharding,
             broadcast_batch_mask=broadcast_batch_mask,
-            use_shardy=True,
+            use_shardy=False,
         )
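
For context on the renamed test and the flipped use_shardy defaults: at the JAX level, the Shardy-vs-GSPMD choice is typically controlled by the jax_use_shardy_partitioner config flag. The helper below is only an assumption about what the use_shardy argument maps to, not what impl_test_softmax actually does internally.

    # Hedged sketch: toggle the Shardy partitioner around a callable, then restore GSPMD.
    import jax

    def run_with_partitioner(fn, use_shardy):
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        try:
            return fn()
        finally:
            jax.config.update("jax_use_shardy_partitioner", False)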
