Commit e3d2a48

Format
Signed-off-by: Jeremy Berchtold <[email protected]>
1 parent 68180bc commit e3d2a48

2 files changed, +53 -22 lines

tests/jax/test_distributed_layernorm.py

Lines changed: 39 additions & 16 deletions
@@ -20,7 +20,9 @@
 from transformer_engine.jax.layernorm import layernorm
 from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available
 from transformer_engine.jax.cpp_extensions.base import primitive_context
-from transformer_engine.jax.cpp_extensions.normalization import is_norm_zero_centered_gamma_in_weight_dtype
+from transformer_engine.jax.cpp_extensions.normalization import (
+    is_norm_zero_centered_gamma_in_weight_dtype,
+)


 DTYPES = [jnp.bfloat16, jnp.float32]
@@ -40,6 +42,7 @@
 if is_mxfp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

+
 class TestDistributedLayernorm:

     def generate_inputs(self, shape, mesh_resource, dtype):
@@ -73,17 +76,13 @@ def generate_collectives_count_ref(

         # JAX is able to optimize away the computation for dbeta because our
         # loss function is jnp.mean, it can determine that dbeta is always 1.0/beta.shape[-1]
-        dbeta_needs_allreduce = (ln_type == "layernorm" and use_te_norm)
+        dbeta_needs_allreduce = ln_type == "layernorm" and use_te_norm
         # allreduce for dgamma and if required also dbeta
         weight_count = 2 if dbeta_needs_allreduce else 1
-        allreduce_total_bytes = (
-            all_reduce_loss_bytes + weight_count * shape[-1] * dtype.itemsize
-        )
+        allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * dtype.itemsize
         if fp8_recipe == recipe.Float8CurrentScaling():
             allreduce_total_bytes += dtype.itemsize  # 1 * dtype for the amax reduction
-        return generate_collectives_count(
-            allreduce=allreduce_total_bytes, allgather=0, other=0
-        )
+        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
     @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
@@ -133,12 +132,22 @@ def ref_func(x, gamma, beta):
             data_shape, mesh_resource, dtype
         )
         collective_count_ref = self.generate_collectives_count_ref(
-            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe, use_te_norm=use_te_norm
+            mesh_resource,
+            ln_type,
+            data_shape,
+            dtype,
+            mesh_axes,
+            fp8_recipe,
+            use_te_norm=use_te_norm,
         )
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        prim_ctx = primitive_context(f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}")
-        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource), prim_ctx:
+        prim_ctx = primitive_context(
+            f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"
+        )
+        with mesh, fp8_autocast(
+            enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
+        ), prim_ctx:
             x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
             gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
             beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
@@ -180,9 +189,13 @@ def test_rmsnorm(
         q_dtype = jnp.float8_e4m3fn

         def target_func(x, gamma):
-            with primitive_context(f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"):
+            with primitive_context(
+                f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"
+            ):
                 quantizer = QuantizerFactory.create_set().x
-                return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer))
+                return jnp.mean(
+                    layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer)
+                )

         def ref_func(x, gamma):
             x = jnp.asarray(x, jnp.float32)
@@ -195,12 +208,22 @@ def ref_func(x, gamma):
             data_shape, mesh_resource, dtype
         )
         collective_count_ref = self.generate_collectives_count_ref(
-            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe, use_te_norm=use_te_norm
+            mesh_resource,
+            ln_type,
+            data_shape,
+            dtype,
+            mesh_axes,
+            fp8_recipe,
+            use_te_norm=use_te_norm,
         )
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        prim_ctx = primitive_context(f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}")
-        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource), prim_ctx:
+        prim_ctx = primitive_context(
+            f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"
+        )
+        with mesh, fp8_autocast(
+            enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
+        ), prim_ctx:
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))

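Note on the collective-count reference touched above: because the test's loss is jnp.mean, the pure-JAX norm lets XLA fold dbeta into a constant, so only dgamma (plus dbeta when the TE layernorm primitive is used) needs an all-reduce, on top of the loss reduction and an optional amax reduction under current scaling. A minimal sketch of that byte count, using assumed helper and parameter names rather than the test's actual utilities:

# Sketch (assumed names): expected all-reduce bytes for the distributed norm test.
def expected_allreduce_bytes(ln_type, use_te_norm, hidden_size, itemsize,
                             loss_bytes, current_scaling=False):
    # With a jnp.mean loss, XLA can fold dbeta into a constant for the pure-JAX
    # norm, so dbeta only needs an all-reduce when the TE layernorm primitive runs.
    dbeta_needs_allreduce = ln_type == "layernorm" and use_te_norm
    weight_count = 2 if dbeta_needs_allreduce else 1  # dgamma, plus dbeta if needed
    total = loss_bytes + weight_count * hidden_size * itemsize
    if current_scaling:
        total += itemsize  # one extra element for the amax reduction
    return total

# e.g. TE layernorm, hidden size 1024, bf16 params (itemsize 2), 2-byte loss scalar:
# 2 + 2 * 1024 * 2 = 4098 bytes all-reduced.
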
transformer_engine/jax/cpp_extensions/normalization.py

Lines changed: 14 additions & 6 deletions
@@ -835,9 +835,13 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)

     gamma_ = gamma
     if zero_centered_gamma:
-        zero_centered_gamma_dtype = gamma.dtype if is_norm_zero_centered_gamma_in_weight_dtype(
-            quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
-        ) else jnp.float32
+        zero_centered_gamma_dtype = (
+            gamma.dtype
+            if is_norm_zero_centered_gamma_in_weight_dtype(
+                quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
+            )
+            else jnp.float32
+        )
         gamma_ = (gamma + jnp.asarray(1.0, dtype=zero_centered_gamma_dtype)).astype(gamma.dtype)

     output = normed_input * gamma_ + beta
@@ -860,9 +864,13 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):

     gamma_ = gamma
     if zero_centered_gamma:
-        zero_centered_gamma_dtype = gamma.dtype if is_norm_zero_centered_gamma_in_weight_dtype(
-            quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
-        ) else jnp.float32
+        zero_centered_gamma_dtype = (
+            gamma.dtype
+            if is_norm_zero_centered_gamma_in_weight_dtype(
+                quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
+            )
+            else jnp.float32
+        )
         gamma_ = gamma + jnp.asarray(1.0, dtype=zero_centered_gamma_dtype)

     output = normed_input * gamma_

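The reformatted conditional in both hunks only selects the dtype in which the +1 shift is applied when gamma is zero-centered: the weight dtype or float32, depending on what is_norm_zero_centered_gamma_in_weight_dtype reports for the quantizer's scaling mode. A standalone sketch of the same behavior, with the scaling-mode check reduced to an assumed boolean flag:

import jax.numpy as jnp

def apply_zero_centered_gamma(gamma, shift_in_weight_dtype):
    # Zero-centered gamma stores (gamma - 1); the norm scales with (gamma + 1).
    # The flag stands in for is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode).
    shift_dtype = gamma.dtype if shift_in_weight_dtype else jnp.float32
    return (gamma + jnp.asarray(1.0, dtype=shift_dtype)).astype(gamma.dtype)

gamma = jnp.zeros((8,), dtype=jnp.bfloat16)  # freshly initialized zero-centered gamma
print(apply_zero_centered_gamma(gamma, shift_in_weight_dtype=False))  # ones, still bf16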