 from transformer_engine.jax.layernorm import layernorm
 from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available
 from transformer_engine.jax.cpp_extensions.base import primitive_context
-from transformer_engine.jax.cpp_extensions.normalization import is_norm_zero_centered_gamma_in_weight_dtype
+from transformer_engine.jax.cpp_extensions.normalization import (
+    is_norm_zero_centered_gamma_in_weight_dtype,
+)


 DTYPES = [jnp.bfloat16, jnp.float32]
 if is_mxfp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

+
 class TestDistributedLayernorm:

     def generate_inputs(self, shape, mesh_resource, dtype):
@@ -73,17 +76,13 @@ def generate_collectives_count_ref(

         # JAX is able to optimize away the computation for dbeta because our
         # loss function is jnp.mean, it can determine that dbeta is always 1.0/beta.shape[-1]
-        dbeta_needs_allreduce = (ln_type == "layernorm" and use_te_norm)
+        dbeta_needs_allreduce = ln_type == "layernorm" and use_te_norm
         # allreduce for dgamma and if required also dbeta
         weight_count = 2 if dbeta_needs_allreduce else 1
-        allreduce_total_bytes = (
-            all_reduce_loss_bytes + weight_count * shape[-1] * dtype.itemsize
-        )
+        allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * dtype.itemsize
         if fp8_recipe == recipe.Float8CurrentScaling():
             allreduce_total_bytes += dtype.itemsize  # 1 * dtype for the amax reduction
-        return generate_collectives_count(
-            allreduce=allreduce_total_bytes, allgather=0, other=0
-        )
+        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
     @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
@@ -133,12 +132,22 @@ def ref_func(x, gamma, beta):
             data_shape, mesh_resource, dtype
         )
         collective_count_ref = self.generate_collectives_count_ref(
-            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe, use_te_norm=use_te_norm
+            mesh_resource,
+            ln_type,
+            data_shape,
+            dtype,
+            mesh_axes,
+            fp8_recipe,
+            use_te_norm=use_te_norm,
         )
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        prim_ctx = primitive_context(f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}")
-        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource), prim_ctx:
+        prim_ctx = primitive_context(
+            f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"
+        )
+        with mesh, fp8_autocast(
+            enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
+        ), prim_ctx:
             x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
             gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
             beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
@@ -180,9 +189,13 @@ def test_rmsnorm(
         q_dtype = jnp.float8_e4m3fn

         def target_func(x, gamma):
-            with primitive_context(f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"):
+            with primitive_context(
+                f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"
+            ):
                 quantizer = QuantizerFactory.create_set().x
-                return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer))
+                return jnp.mean(
+                    layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer)
+                )

         def ref_func(x, gamma):
             x = jnp.asarray(x, jnp.float32)
@@ -195,12 +208,22 @@ def ref_func(x, gamma):
             data_shape, mesh_resource, dtype
         )
         collective_count_ref = self.generate_collectives_count_ref(
-            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe, use_te_norm=use_te_norm
+            mesh_resource,
+            ln_type,
+            data_shape,
+            dtype,
+            mesh_axes,
+            fp8_recipe,
+            use_te_norm=use_te_norm,
         )
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        prim_ctx = primitive_context(f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}")
-        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource), prim_ctx:
+        prim_ctx = primitive_context(
+            f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"
+        )
+        with mesh, fp8_autocast(
+            enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
+        ), prim_ctx:
             x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
             gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))

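Aside on the generate_collectives_count_ref hunk above: the comment there states that XLA folds dbeta to a constant because the loss is jnp.mean, so only dgamma (plus dbeta in the TE-primitive layernorm path) needs an allreduce. A minimal standalone sketch of that claim, not part of the test file and with illustrative shapes and names: for an affine output gamma * x + beta reduced by jnp.mean, the gradient with respect to every beta element is 1.0 / beta.shape[-1], independent of the data.

import jax
import jax.numpy as jnp

x_norm = jnp.ones((4, 8), dtype=jnp.float32)  # stand-in for a normalized activation
gamma = jnp.ones((8,), dtype=jnp.float32)
beta = jnp.zeros((8,), dtype=jnp.float32)

def loss(beta):
    # Affine output followed by the mean reduction used as the test loss.
    return jnp.mean(x_norm * gamma + beta)

dbeta = jax.grad(loss)(beta)
# Every element equals 1.0 / beta.shape[-1] (here 0.125) regardless of x_norm,
# which is why the native JAX path needs no allreduce for dbeta.
assert jnp.allclose(dbeta, 1.0 / beta.shape[-1])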