
Commit a5b8244

Merge branch 'main' into dev/jberchtold/fix-layernorm-distributed-tests
2 parents: 5e7e3b7 + cae1c43

26 files changed: +797 -559 lines changed


benchmarks/linear/benchmark_grouped_linear.py

Lines changed: 1 addition & 1 deletion
@@ -247,7 +247,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
     num_gemms_list = [8]
 
     if args.profile:
-        mkns = [(4096, 4096, 4096)]
+        mkns = [(4096 * 8, 4096, 4096)]
         # in profile mode, only run one recipe specified in args.recipe
         assert args.recipe != "all", (
             "In profile mode, only one recipe can be specified, please specify the recipe as"

tests/jax/conftest.py

Lines changed: 53 additions & 0 deletions
@@ -5,6 +5,8 @@
 import os
 import jax
 import pytest
+from collections import defaultdict
+import time
 
 
 import transformer_engine.jax
@@ -32,3 +34,54 @@ def enable_fused_attn_after_hopper():
     yield
     if "NVTE_FUSED_ATTN" in os.environ:
         del os.environ["NVTE_FUSED_ATTN"]
+
+
+class TestTimingPlugin:
+    """
+    Plugin to measure test execution time. Enable test timing by setting NVTE_JAX_TEST_TIMING=1
+    in the environment.
+    """
+
+    def __init__(self):
+        self.test_timings = defaultdict(list)
+
+    @pytest.hookimpl(tryfirst=True)
+    def pytest_runtest_setup(self, item):
+        item._timing_start = time.time()
+
+    @pytest.hookimpl(trylast=True)
+    def pytest_runtest_teardown(self, item, nextitem):
+        if hasattr(item, "_timing_start"):
+            duration = time.time() - item._timing_start
+
+            # Extract base function name without parameters
+            test_name = item.name
+            if "[" in test_name:
+                base_name = test_name.split("[")[0]
+            else:
+                base_name = test_name
+
+            self.test_timings[base_name].append(duration)
+
+    def pytest_sessionfinish(self, session, exitstatus):
+        print("\n" + "=" * 80)
+        print("TEST RUNTIME SUMMARY (grouped by function)")
+        print("=" * 80)
+
+        total_overall = 0
+        for test_name, durations in sorted(self.test_timings.items()):
+            total_time = sum(durations)
+            count = len(durations)
+            avg_time = total_time / count if count > 0 else 0
+            total_overall += total_time
+
+            print(f"{test_name:<60} | {count:3}x | {total_time:7.2f}s | avg: {avg_time:6.2f}s")
+
+        print("=" * 80)
+        print(f"{'TOTAL RUNTIME':<60} | {'':>3} | {total_overall:7.2f}s |")
+        print("=" * 80)
+
+
+def pytest_configure(config):
+    if os.getenv("NVTE_JAX_TEST_TIMING", "0") == "1":
+        config.pluginmanager.register(TestTimingPlugin(), "test_timing")
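
Note: the plugin above is opt-in. A minimal usage sketch (not part of this commit), assuming the standard pytest entry point; the test file chosen below is only an illustrative example:

# Hedged usage sketch: NVTE_JAX_TEST_TIMING=1 must be set before pytest starts
# so that pytest_configure registers TestTimingPlugin and the summary prints
# at session end.
import os
import pytest

os.environ["NVTE_JAX_TEST_TIMING"] = "1"
pytest.main(["tests/jax/test_distributed_softmax.py", "-q"])

Equivalently, from a shell: NVTE_JAX_TEST_TIMING=1 pytest tests/jax/test_distributed_softmax.py -q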

tests/jax/test_distributed_layernorm_mlp.py

Lines changed: 0 additions & 2 deletions
@@ -333,7 +333,6 @@ def _test_layernorm_mlp(
         with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
             ln_mlp_single = LayerNormMLP(
                 layernorm_type=layernorm_type,
-                transpose_batch_sequence=False,  # input: [batch, seqlen, hidden]
                 intermediate_dim=INTERMEDIATE,
                 activations=activation_type,
                 use_bias=use_bias,
@@ -352,7 +351,6 @@
         ):
             ln_mlp_sharded = LayerNormMLP(
                 layernorm_type=layernorm_type,
-                transpose_batch_sequence=False,
                 intermediate_dim=INTERMEDIATE,
                 activations=activation_type,
                 scale_axes=LN_SCALE_AXES,

tests/jax/test_distributed_softmax.py

Lines changed: 4 additions & 4 deletions
@@ -135,7 +135,7 @@ def impl_test_softmax(
         )
 
     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
-    @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
+    @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
     @pytest.mark.parametrize(
         "softmax_type",
         [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
@@ -168,14 +168,14 @@ def test_softmax(
             dtype,
             bad_sharding,
             broadcast_batch_mask,
-            use_shardy=False,
+            use_shardy=True,
         )
 
     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
     @pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
    @pytest.mark.parametrize("bad_sharding", [False, True])
     @pytest.mark.parametrize("broadcast_batch_mask", [False, True])
-    def test_softmax_shardy(
+    def test_softmax_gspmd(
         self,
         device_count,
         mesh_shape,
@@ -196,5 +196,5 @@ def test_softmax_shardy(
             dtype=DTYPES[0],
             bad_sharding=bad_sharding,
             broadcast_batch_mask=broadcast_batch_mask,
-            use_shardy=True,
+            use_shardy=False,
         )
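
Note on the renames above: after this change, test_softmax exercises use_shardy=True by default, while the renamed test_softmax_gspmd keeps coverage for the legacy path with use_shardy=False. As a hedged sketch only, this is what such a flag commonly toggles on the JAX side (assumption: the shared test helper flips JAX's jax_use_shardy_partitioner config; the actual plumbing lives in the test utilities, not in this diff):

# Hedged sketch, not code from this commit.
import jax

def select_partitioner(use_shardy: bool) -> None:
    # True  -> Shardy partitioner (default test_softmax path after this change)
    # False -> legacy GSPMD partitioner (test_softmax_gspmd path)
    jax.config.update("jax_use_shardy_partitioner", use_shardy)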

tests/pytorch/test_float8blockwisetensor.py

Lines changed: 1 addition & 1 deletion
@@ -219,7 +219,7 @@ def test_quantize_dequantize_compact_format(
             rowwise=True,
             columnwise=dq_columnwise,
             block_scaling_dim=block_scaling_dim,
-            all_gather_usage=True,
+            all_gather_usage=(block_scaling_dim == 1),
         )
         self._test_quantize_dequantize(
             quantizer=quantizer,

transformer_engine/common/common.cu

Lines changed: 1 addition & 0 deletions
@@ -138,6 +138,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                           const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
                           const uint32_t shmemX, const uint32_t stride_elems,
                           const uint32_t offset_elems, const size_t type_num_bits) {
+  cuda_driver::ensure_context_exists();
   // Get a function pointer to the cuTensorMapEncodeTiled driver API
   // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
   static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {

transformer_engine/common/include/transformer_engine/swizzle.h

Lines changed: 14 additions & 0 deletions
@@ -30,6 +30,20 @@ extern "C" {
  */
 void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 
+/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM
+ *
+ *  \param[in]     inputs   Input tensors with non-swizzled scale_inv.
+ *  \param[in,out] outputs  Output tensors which host the swizzled scale_inv.
+ *  \param[in]     stream   CUDA stream used for the operation.
+ *
+ *  Requirements:
+ *  - scale_inv is stored in row-major order.
+ *  - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
+ *  - data is quantized along the K-dimension, i.e. the 1D-scaling block lies along the K-dimension.
+ */
+void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
+                                               const size_t num_tensors, cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
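
The padding requirement in the new doc comment determines the expected scale_inv shapes. As a hedged illustration only (plain Python arithmetic, not a binding to this C API), row-wise scales round up to multiples of 128x4 and column-wise scales to multiples of 4x128:

import math

def padded_scale_inv_shape(rows: int, cols: int, rowwise: bool = True) -> tuple:
    """Round a scale_inv shape up to the alignment described above (illustrative helper)."""
    align_r, align_c = (128, 4) if rowwise else (4, 128)
    return (math.ceil(rows / align_r) * align_r, math.ceil(cols / align_c) * align_c)

print(padded_scale_inv_shape(100, 3))                 # -> (128, 4)
print(padded_scale_inv_shape(3, 100, rowwise=False))  # -> (4, 128)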
