@@ -673,10 +673,6 @@ def test_grouped_qdq(
673673 n_groups = n_groups ,
674674 )
675675
676- # grouped_quantize does not work with cudaGraph yet, so the jitting will breaks
677- # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
678- # disable cudaGraph, then use the following jitted function
679-
680676 scaled_tensor = tex .grouped_quantize (
681677 x , group_sizes = group_sizes , flatten_axis = flatten_axis , quantizer = grouped_quantizer
682678 )
@@ -1312,16 +1308,14 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
13121308 )
13131309 ref_out = self ._ref_grouped_dense (lhs , rhs , None , group_sizes , contracting_dims )
13141310
1315- # grouped_gemm does not work with cudaGraph yet, so the jitting will breaks
1316- # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
1317- # disable cudaGraph, then use the following jitted function
1318-
13191311 # jitting grouped_gemm
1320- # prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
1321- # lhs, rhs, group_sizes, contracting_dims,
1322- # )
1312+ prim_out = jax .jit (tex .grouped_gemm , static_argnames = ("contracting_dims" ,))(
1313+ lhs ,
1314+ rhs ,
1315+ group_sizes ,
1316+ contracting_dims ,
1317+ )
13231318
1324- prim_out = tex .grouped_gemm (lhs , rhs , group_sizes , contracting_dims )
13251319 self ._assert_grouped_gemm_output (prim_out , group_sizes , ref_out , dtype )
13261320
13271321 @pytest .mark .skipif (not is_fp8_supported , reason = fp8_unsupported_reason )
@@ -1350,12 +1344,7 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout
13501344 )
13511345 ref_out = self ._ref_grouped_dense (lhs , rhs , None , group_sizes , contracting_dims )
13521346
1353- # jitting grouped_gemm
1354- # prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))(
1355- # lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
1356- # )
1357-
1358- prim_out = tex .grouped_gemm (
1347+ prim_out = jax .jit (tex .grouped_gemm , static_argnames = ("contracting_dims" ,))(
13591348 lhs , rhs , group_sizes , contracting_dims , quantizer_set = quantizer_set
13601349 )
13611350
@@ -1391,9 +1380,9 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape):
13911380
13921381 value_n_grad_ref_func = value_and_grad (self ._ref_sum_grouped_dense , (0 , 1 , 2 ))
13931382 # jitting the grouped_dense
1394- # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
1395- # static_argnums=(4,) )
1396- value_n_grad_prim_func = value_and_grad ( self . _primitive_sum_grouped_dense , ( 0 , 1 , 2 ) )
1383+ value_n_grad_prim_func = jit (
1384+ value_and_grad ( self . _primitive_sum_grouped_dense , ( 0 , 1 , 2 )), static_argnums = (4 ,)
1385+ )
13971386
13981387 ref_out_sum , (ref_dgrad , ref_wgrad , ref_dbias ) = value_n_grad_ref_func (
13991388 x , kernel , bias , group_sizes , contracting_dims
@@ -1432,9 +1421,9 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
14321421 value_n_grad_ref_func = value_and_grad (self ._ref_sum_grouped_dense , (0 , 1 , 2 ))
14331422
14341423 # jitting the grouped_dense
1435- # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
1436- # static_argnums=(4,) )
1437- value_n_grad_prim_func = value_and_grad ( self . _primitive_sum_grouped_dense , ( 0 , 1 , 2 ) )
1424+ value_n_grad_prim_func = jit (
1425+ value_and_grad ( self . _primitive_sum_grouped_dense , ( 0 , 1 , 2 )), static_argnums = (4 ,)
1426+ )
14381427
14391428 ref_out_sum , (ref_dgrad , ref_wgrad , ref_dbias ) = value_n_grad_ref_func (
14401429 x ,