164 changes: 82 additions & 82 deletions tests/jax/test_distributed_layernorm.py
@@ -19,6 +19,10 @@
from transformer_engine.common import recipe
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available
from transformer_engine.jax.cpp_extensions.base import primitive_context
from transformer_engine.jax.cpp_extensions.normalization import (
is_norm_zero_centered_gamma_in_weight_dtype,
)


DTYPES = [jnp.bfloat16, jnp.float32]
@@ -41,7 +45,7 @@

class TestDistributedLayernorm:

def generate_inputs(self, shape, mesh_resource, dtype, shard_weights):
def generate_inputs(self, shape, mesh_resource, dtype):
weight_shape = (shape[-1],)

x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
@@ -55,37 +59,36 @@ def generate_inputs(self, shape, mesh_resource, dtype, shard_weights):
else:
raise NotImplementedError

g_pspec = b_pspec = (
PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
)
g_pspec = b_pspec = PartitionSpec(None)

return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)

def generate_collectives_count_ref(
self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe, use_te_norm
):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
is_dp_enabled = mesh_resource.dp_resource is not None
if mesh_axes == ("tp",):
# No collectives for tensor-parallel-only meshes, as we do not shard the hidden dim for LN with TP.
return generate_collectives_count(allreduce=0, allgather=0, other=0)

dtype = jax.dtypes.canonicalize_dtype(dtype)
assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta
# TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp
weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1
allreduce_total_bytes = (
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
other_bytes = 0

# JAX is able to optimize away the computation for dbeta because our
# loss function is jnp.mean, so it can determine that dbeta is always 1.0/beta.shape[-1]
dbeta_needs_allreduce = ln_type == "layernorm" and use_te_norm
# allreduce for dgamma and, if required, also dbeta
weight_count = 2 if dbeta_needs_allreduce else 1
allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * dtype.itemsize
if fp8_recipe == recipe.Float8CurrentScaling():
allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction
return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
)
allreduce_total_bytes += dtype.itemsize # 1 * dtype for the amax reduction
return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)
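For a concrete sense of the reference count, here is a worked example of the allreduce byte total under a hypothetical configuration: a (32, 128, 512) bfloat16 input with the TE norm primitives enabled, so both dgamma and dbeta need an allreduce, and without the current-scaling amax term. The numbers are illustrative only, not taken from the PR.

```python
import jax
import jax.numpy as jnp

# Hypothetical test configuration, for illustration only.
shape = (32, 128, 512)                      # data_shape; hidden dim is shape[-1]
dtype = jax.dtypes.canonicalize_dtype(jnp.bfloat16)

all_reduce_loss_bytes = 4                   # the scalar FP32 loss
weight_count = 2                            # dgamma and dbeta (layernorm with TE primitives)
allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * dtype.itemsize
print(allreduce_total_bytes)                # 4 + 2 * 512 * 2 = 2052 bytes
```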

@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
@pytest_parametrize_wrapper("shard_weights", [False, True])
@pytest_parametrize_wrapper("use_te_norm", [False, True])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_layernorm(
@@ -97,7 +100,7 @@ def test_layernorm(
data_shape,
dtype,
zero_centered_gamma,
shard_weights,
use_te_norm,
fp8_recipe,
use_shardy,
):
@@ -126,51 +129,46 @@ def ref_func(x, gamma, beta):
return jnp.mean(output)

(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
data_shape, mesh_resource, dtype, shard_weights
data_shape, mesh_resource, dtype
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
mesh_resource,
ln_type,
data_shape,
dtype,
mesh_axes,
fp8_recipe,
use_te_norm=use_te_norm,
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
prim_ctx = primitive_context(
f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"
)
with mesh, fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
), prim_ctx:
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))

with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(
target_func,
ref_func,
[x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=q_dtype,
metric_bwd_dtype=q_dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
)
except AssertionError as err:
# Layernorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma and/or beta. We can catch
# and ignore that specific error here.
if (
g_pspec[-1] is None and b_pspec[-1] is None
) or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
"Layernorm primitive did not raise the correct warning for "
"unsupported sharding of gamma and/or beta"
)
compare_ops(
target_func,
ref_func,
[x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=q_dtype,
metric_bwd_dtype=q_dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
)

@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("shard_weights", [False, True])
@pytest_parametrize_wrapper("use_te_norm", [False, True])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_rmsnorm(
@@ -181,7 +179,7 @@ def test_rmsnorm(
mesh_resource,
data_shape,
dtype,
shard_weights,
use_te_norm,
fp8_recipe,
use_shardy,
):
@@ -191,8 +189,13 @@
q_dtype = jnp.float8_e4m3fn

def target_func(x, gamma):
quantizer = QuantizerFactory.create_set().x
return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer))
with primitive_context(
f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"
):
quantizer = QuantizerFactory.create_set().x
return jnp.mean(
layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer)
)

def ref_func(x, gamma):
x = jnp.asarray(x, jnp.float32)
@@ -202,40 +205,37 @@ def ref_func(x, gamma):
return jnp.mean(output)

(x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
data_shape, mesh_resource, dtype, shard_weights
data_shape, mesh_resource, dtype
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
mesh_resource,
ln_type,
data_shape,
dtype,
mesh_axes,
fp8_recipe,
use_te_norm=use_te_norm,
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
prim_ctx = primitive_context(
f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"
)
with mesh, fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
), prim_ctx:
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))

with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(
target_func,
ref_func,
[x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=q_dtype,
metric_bwd_dtype=q_dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)),
)
except AssertionError as err:
# RmsNorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma. We can catch
# and ignore that specific error here.
if g_pspec[-1] is None or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
"RmsNorm primitive did not raise the correct warning for "
"unsupported sharding of gamma and/or beta"
)
compare_ops(
target_func,
ref_func,
[x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=q_dtype,
metric_bwd_dtype=q_dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)),
)
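The shard_weights parametrization is replaced by use_te_norm, which drives the new primitive_context helper (added in transformer_engine/jax/cpp_extensions/base.py below) to toggle the TE norm custom calls on or off for the duration of a test. A minimal sketch of that pattern, using hypothetical shapes and the same positional layernorm arguments the tests use (quantizer omitted):

```python
import jax.numpy as jnp
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.cpp_extensions.base import primitive_context

# Hypothetical inputs, for illustration only.
x = jnp.ones((32, 128, 512), jnp.bfloat16)
gamma = jnp.ones((512,), jnp.bfloat16)
beta = jnp.zeros((512,), jnp.bfloat16)

use_te_norm = True  # False exercises the native-JAX fallback path
with primitive_context(f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"):
    # Same call shape as the tests' target_func: x, gamma, beta, norm type,
    # zero_centered_gamma, epsilon.
    out = layernorm(x, gamma, beta, "layernorm", False, 1e-6)
```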
16 changes: 2 additions & 14 deletions tests/jax/utils.py
@@ -25,6 +25,7 @@
make_swa_mask,
)
from transformer_engine.jax.quantize.helper import DType as TEDType
from transformer_engine.jax.cpp_extensions.base import primitive_context

PRNGKey = Any
Shape = Tuple[int, ...]
@@ -1604,18 +1605,5 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):

@contextmanager
def use_jax_gemm(enabled=False):
orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS", None)

try:
if enabled:
os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false"
else:
os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=true"
with primitive_context(f"GemmPrimitive={not enabled}"):
yield

finally:
if enabled:
if orig_custom_calls_filter is None:
os.environ.pop("NVTE_JAX_CUSTOM_CALLS")
else:
os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_custom_calls_filter
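With the environment-variable bookkeeping now handled by primitive_context, the helper is used the same way as before. A quick usage sketch follows; the matmul is just a placeholder workload, and the import path assumes the tests/jax directory:

```python
import jax.numpy as jnp
from utils import use_jax_gemm  # tests/jax/utils.py

a = jnp.ones((128, 256), jnp.bfloat16)
b = jnp.ones((256, 64), jnp.bfloat16)

with use_jax_gemm(enabled=True):
    # GemmPrimitive is disabled inside the block, so TE modules fall back
    # to JAX-native GEMMs; the previous setting is restored on exit.
    out = a @ b
```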
66 changes: 65 additions & 1 deletion transformer_engine/jax/cpp_extensions/base.py
@@ -7,6 +7,7 @@
import warnings
from abc import ABCMeta, abstractmethod
from functools import partial
from contextlib import contextmanager
from packaging import version

from jax.extend import core
@@ -34,7 +35,7 @@ class BasePrimitive(metaclass=ABCMeta):
_is_enabled = True

# Default list of primitives to disable for all recipes
_default_disable_names = ["GemmPrimitive"]
_default_disable_names = ["GemmPrimitive", "NormFwdPrimitive", "NormBwdPrimitive"]

@classmethod
def enabled(cls):
@@ -258,3 +259,66 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
cls.set_enabled(False)
else:
raise ValueError(f"Primitive not found in registry: {name}")


def _parse_custom_call_string_to_dict(custom_calls_str: str):
"""
Parses a custom call string into a dictionary of primitive names and their enabled states.
The input string can be a single value 'true' or 'false' to enable/disable all primitives,
or a comma-separated list of key=value pairs.

Args:
custom_calls_str: A string representing the custom call settings.

Returns:
A dictionary where keys are primitive names and values are booleans indicating enabled state.
"""
custom_calls_str = custom_calls_str.strip()
if custom_calls_str.lower() == "true":
return {primitive_name: True for primitive_name in _primitive_registry}
if custom_calls_str.lower() == "false":
return {primitive_name: False for primitive_name in _primitive_registry}

settings = {}
for pair in custom_calls_str.split(","):
pair = pair.strip()
if "=" in pair:
key, value = pair.split("=", 1)
key = key.strip()
value = value.strip().lower()
settings[key] = value == "true"
return settings
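For reference, a couple of illustrative calls and the dictionaries the parser would return, evaluated inside base.py where the function is defined; the primitive names are examples only:

```python
_parse_custom_call_string_to_dict("GemmPrimitive=false, NormFwdPrimitive=true")
# -> {"GemmPrimitive": False, "NormFwdPrimitive": True}

_parse_custom_call_string_to_dict("false")
# -> {name: False for name in _primitive_registry}  (every registered primitive disabled)
```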


@contextmanager
def primitive_context(primitive_enabling_changes: str):
"""Context manager to temporarily change the enabled state of primitives.

This context manager allows for temporary changes to the enabled state of
primitives within its scope. Any changes made will be reverted once the
context is exited.

Args:
primitive_enabling_changes: A string representing the changes to be made to the enabled state of primitives. This input string uses the same pattern as the `NVTE_JAX_CUSTOM_CALLS` environment variable, `Prim1=true,Prim2=false` or `false`/`true` to disable or enable all primitives, respectively.
"""
orig_env_var = os.getenv("NVTE_JAX_CUSTOM_CALLS")

primitives = {}
if orig_env_var is not None:
primitives = _parse_custom_call_string_to_dict(orig_env_var)

changes = _parse_custom_call_string_to_dict(primitive_enabling_changes)
primitives.update(changes)

updated_env_var = ",".join(
f"{name}={'true' if enabled else 'false'}" for name, enabled in primitives.items()
)

os.environ["NVTE_JAX_CUSTOM_CALLS"] = str(updated_env_var)
Collaborator:
Hi,

I think we could use manage_primitives here and wouldn't need to change the env var.
If you prefer to pass key-value pairs, we could update manage_primitives to do so too.

Collaborator (Author):
Thanks, I just read the full docstring on manage_primitives and you're right, it will work. I'll update to that instead for this context.

I had missed the disable_all_first argument and thought manage_primitives always disabled any unspecified primitives.

try:
yield
finally:
if orig_env_var is not None:
os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_env_var
else:
del os.environ["NVTE_JAX_CUSTOM_CALLS"]
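As discussed in the review thread, the same context manager could be built on manage_primitives instead of rewriting NVTE_JAX_CUSTOM_CALLS. A rough sketch of that direction, living in the same module (so _parse_custom_call_string_to_dict, _primitive_registry, manage_primitives, and contextmanager are already in scope), and assuming _primitive_registry maps primitive names to their BasePrimitive subclasses and that enabled() reflects the state to restore:

```python
@contextmanager
def primitive_context(primitive_enabling_changes: str):
    """Sketch: toggle primitives via manage_primitives and restore them on exit."""
    changes = _parse_custom_call_string_to_dict(primitive_enabling_changes)
    # Snapshot the current enabled state of every registered primitive.
    previous = {name: cls.enabled() for name, cls in _primitive_registry.items()}
    try:
        manage_primitives(
            enable_names=[name for name, on in changes.items() if on],
            disable_names=[name for name, on in changes.items() if not on],
        )
        yield
    finally:
        # Restore the snapshot taken above.
        manage_primitives(
            enable_names=[name for name, on in previous.items() if on],
            disable_names=[name for name, on in previous.items() if not on],
        )
```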