Skip to content

Commit 17ae015

Browse files
Merge branch 'main' into dev/jberchtold/fix-layernorm-distributed-tests
2 parents f2af0f1 + 9f9b481 commit 17ae015

File tree

3 files changed

+15
-28
lines changed

3 files changed

+15
-28
lines changed

tests/jax/test_custom_call_compute.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -673,10 +673,6 @@ def test_grouped_qdq(
673673
n_groups=n_groups,
674674
)
675675

676-
# grouped_quantize does not work with cudaGraph yet, so the jitting will breaks
677-
# To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
678-
# disable cudaGraph, then use the following jitted function
679-
680676
scaled_tensor = tex.grouped_quantize(
681677
x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
682678
)
@@ -1312,16 +1308,14 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
13121308
)
13131309
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
13141310

1315-
# grouped_gemm does not work with cudaGraph yet, so the jitting will breaks
1316-
# To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
1317-
# disable cudaGraph, then use the following jitted function
1318-
13191311
# jitting grouped_gemm
1320-
# prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
1321-
# lhs, rhs, group_sizes, contracting_dims,
1322-
# )
1312+
prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
1313+
lhs,
1314+
rhs,
1315+
group_sizes,
1316+
contracting_dims,
1317+
)
13231318

1324-
prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
13251319
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
13261320

13271321
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@@ -1350,12 +1344,7 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout
13501344
)
13511345
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
13521346

1353-
# jitting grouped_gemm
1354-
# prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))(
1355-
# lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
1356-
# )
1357-
1358-
prim_out = tex.grouped_gemm(
1347+
prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
13591348
lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
13601349
)
13611350

@@ -1391,9 +1380,9 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape):
13911380

13921381
value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
13931382
# jitting the grouped_dense
1394-
# value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
1395-
# static_argnums=(4,))
1396-
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
1383+
value_n_grad_prim_func = jit(
1384+
value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
1385+
)
13971386

13981387
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
13991388
x, kernel, bias, group_sizes, contracting_dims
@@ -1432,9 +1421,9 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
14321421
value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
14331422

14341423
# jitting the grouped_dense
1435-
# value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
1436-
# static_argnums=(4,))
1437-
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
1424+
value_n_grad_prim_func = jit(
1425+
value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,)
1426+
)
14381427

14391428
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
14401429
x,

transformer_engine/jax/csrc/extensions/gemm.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,8 +592,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
592592
.Attr<bool>("rhs_is_trans")
593593
.Attr<JAXX_Scaling_Mode>("scaling_mode")
594594
.Attr<bool>("has_bias")
595-
.Attr<bool>("is_grouped_dense_wgrad"),
596-
FFI_CudaGraph_Traits);
595+
.Attr<bool>("is_grouped_dense_wgrad"));
597596

598597
} // namespace jax
599598
} // namespace transformer_engine

transformer_engine/jax/csrc/extensions/quantization.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,8 +410,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI,
410410
.Ret<Buffer_Type>() // amax
411411
.Attr<JAXX_Scaling_Mode>("scaling_mode")
412412
.Attr<int64_t>("q_layout")
413-
.Attr<int64_t>("flatten_axis"),
414-
FFI_CudaGraph_Traits);
413+
.Attr<int64_t>("flatten_axis"));
415414

416415
} // namespace jax
417416
} // namespace transformer_engine

0 commit comments

Comments
 (0)