diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py
index be5c8ef987..84095ed51c 100644
--- a/tests/jax/test_distributed_layernorm.py
+++ b/tests/jax/test_distributed_layernorm.py
@@ -19,6 +19,10 @@ from transformer_engine.common import recipe
 from transformer_engine.jax.layernorm import layernorm
 from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available
+from transformer_engine.jax.cpp_extensions.base import primitive_context
+from transformer_engine.jax.cpp_extensions.normalization import (
+    is_norm_zero_centered_gamma_in_weight_dtype,
+)
 
 DTYPES = [jnp.bfloat16, jnp.float32]
 
@@ -41,7 +45,7 @@ class TestDistributedLayernorm:
 
-    def generate_inputs(self, shape, mesh_resource, dtype, shard_weights):
+    def generate_inputs(self, shape, mesh_resource, dtype):
         weight_shape = (shape[-1],)
 
         x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
@@ -55,37 +59,36 @@ def generate_inputs(self, shape, mesh_resource, dtype, shard_weights):
         else:
             raise NotImplementedError
 
-        g_pspec = b_pspec = (
-            PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
-        )
+        g_pspec = b_pspec = PartitionSpec(None)
 
         return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
 
     def generate_collectives_count_ref(
-        self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
+        self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe, use_te_norm
     ):
-        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
-        is_dp_enabled = mesh_resource.dp_resource is not None
+        if mesh_axes == ("tp",):
+            # No collectives when only tensor parallelism is used, as we do not shard the hidden dim for LN with TP.
+            return generate_collectives_count(allreduce=0, allgather=0, other=0)
+
+        dtype = jax.dtypes.canonicalize_dtype(dtype)
         assert ln_type in ["layernorm", "rmsnorm"]
 
         all_reduce_loss_bytes = 4  # 1 * FP32
-        # for loss, dgamma and dbeta
-        # TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp
-        weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1
-        allreduce_total_bytes = (
-            all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
-        )
-        other_bytes = 0
+
+        # JAX is able to optimize away the computation for dbeta because our loss function is
+        # jnp.mean, so it can determine that dbeta is always 1.0 / beta.shape[-1].
+        dbeta_needs_allreduce = ln_type == "layernorm" and use_te_norm
+        # all-reduce for dgamma and, if required, also dbeta
+        weight_count = 2 if dbeta_needs_allreduce else 1
+        allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * dtype.itemsize
         if fp8_recipe == recipe.Float8CurrentScaling():
-            allreduce_total_bytes += jax_dtype.itemsize  # 1 * dtype for the amax reduction
-        return generate_collectives_count(
-            allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
-        )
+            allreduce_total_bytes += dtype.itemsize  # 1 * dtype for the amax reduction
+        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)
 
     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
     @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
-    @pytest_parametrize_wrapper("shard_weights", [False, True])
+    @pytest_parametrize_wrapper("use_te_norm", [False, True])
     @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
     @pytest_parametrize_wrapper("use_shardy", [False, True])
     def test_layernorm(
@@ -97,7 +100,7 @@ def test_layernorm(
         data_shape,
         dtype,
         zero_centered_gamma,
-        shard_weights,
+        use_te_norm,
         fp8_recipe,
         use_shardy,
     ):
@@ -126,51 +129,46 @@ def ref_func(x, gamma, beta):
             return jnp.mean(output)
 
         (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
-            data_shape, mesh_resource, dtype, shard_weights
+            data_shape, mesh_resource, dtype
         )
         collective_count_ref = self.generate_collectives_count_ref(
-            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
+            mesh_resource,
+            ln_type,
+            data_shape,
+            dtype,
+            mesh_axes,
+            fp8_recipe,
+            use_te_norm=use_te_norm,
         )
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
+        prim_ctx = primitive_context(
+            f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"
+        )
+        with mesh, fp8_autocast(
+            enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
+        ), prim_ctx:
             x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
             gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
             beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
 
             with warnings.catch_warnings(record=True) as warns:
-                try:
-                    compare_ops(
-                        target_func,
-                        ref_func,
-                        [x_, gamma_, beta_],
-                        collective_count_ref,
-                        grad_args=(0, 1, 2),
-                        metric_fwd_dtype=q_dtype,
-                        metric_bwd_dtype=q_dtype,
-                        in_shardings=(x_pspec, g_pspec, b_pspec),
-                        out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
-                    )
-                except AssertionError as err:
-                    # Layernorm should still produce the correct numerical result with
-                    # gamma/beta sharded. However, the collective count may not be the same
-                    # when XLA is forced to unshard gamma and/or beta. We can catch
-                    # and ignore that specific error here.
-                    if (
-                        g_pspec[-1] is None and b_pspec[-1] is None
-                    ) or "Expected collective count" not in str(err):
-                        raise err
-                finally:
-                    for w in warns:
-                        assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
-                            "Layernorm primitive did not raise the correct warning for "
-                            "unsupported sharding of gamma and/or beta"
-                        )
+                compare_ops(
+                    target_func,
+                    ref_func,
+                    [x_, gamma_, beta_],
+                    collective_count_ref,
+                    grad_args=(0, 1, 2),
+                    metric_fwd_dtype=q_dtype,
+                    metric_bwd_dtype=q_dtype,
+                    in_shardings=(x_pspec, g_pspec, b_pspec),
+                    out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
+                )
 
     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
     @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
     @pytest_parametrize_wrapper("dtype", DTYPES)
-    @pytest_parametrize_wrapper("shard_weights", [False, True])
+    @pytest_parametrize_wrapper("use_te_norm", [False, True])
     @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
     @pytest_parametrize_wrapper("use_shardy", [False, True])
     def test_rmsnorm(
@@ -181,7 +179,7 @@ def test_rmsnorm(
         mesh_resource,
         data_shape,
         dtype,
-        shard_weights,
+        use_te_norm,
         fp8_recipe,
         use_shardy,
     ):
@@ -191,8 +189,13 @@ def test_rmsnorm(
         q_dtype = jnp.float8_e4m3fn
 
         def target_func(x, gamma):
-            quantizer = QuantizerFactory.create_set().x
-            return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer))
+            with primitive_context(
+                f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"
+            ):
+                quantizer = QuantizerFactory.create_set().x
+                return jnp.mean(
+                    layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer)
+                )
 
         def ref_func(x, gamma):
             x = jnp.asarray(x, jnp.float32)
@@ -202,40 +205,37 @@ def ref_func(x, gamma):
             return jnp.mean(output)
 
         (x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
-            data_shape, mesh_resource, dtype, shard_weights
+            data_shape, mesh_resource, dtype
        )
         collective_count_ref = self.generate_collectives_count_ref(
-            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
+            mesh_resource,
+            ln_type,
+            data_shape,
+            dtype,
+            mesh_axes,
+            fp8_recipe,
+            use_te_norm=use_te_norm,
         )
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
+        prim_ctx = primitive_context(
+            f"NormFwdPrimitive={use_te_norm},NormBwdPrimitive={use_te_norm}"
+        )
+        with mesh, fp8_autocast(
+            enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
+        ), prim_ctx:
             x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
             gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
 
             with warnings.catch_warnings(record=True) as warns:
-                try:
-                    compare_ops(
-                        target_func,
-                        ref_func,
-                        [x_, gamma_],
-                        collective_count_ref,
-                        grad_args=(0, 1),
-                        metric_fwd_dtype=q_dtype,
-                        metric_bwd_dtype=q_dtype,
-                        in_shardings=(x_pspec, g_pspec),
-                        out_shardings=(None, (x_pspec, g_pspec)),
-                    )
-                except AssertionError as err:
-                    # RmsNorm should still produce the correct numerical result with
-                    # gamma/beta sharded. However, the collective count may not be the same
-                    # when XLA is forced to unshard gamma. We can catch
-                    # and ignore that specific error here.
-                    if g_pspec[-1] is None or "Expected collective count" not in str(err):
-                        raise err
-                finally:
-                    for w in warns:
-                        assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
-                            "RmsNorm primitive did not raise the correct warning for "
-                            "unsupported sharding of gamma and/or beta"
-                        )
+                compare_ops(
+                    target_func,
+                    ref_func,
+                    [x_, gamma_],
+                    collective_count_ref,
+                    grad_args=(0, 1),
+                    metric_fwd_dtype=q_dtype,
+                    metric_bwd_dtype=q_dtype,
+                    in_shardings=(x_pspec, g_pspec),
+                    out_shardings=(None, (x_pspec, g_pspec)),
+                )
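
Note on the collective-count reference above: with only data parallelism active, the expected all-reduce traffic is the 4-byte FP32 loss, plus one hidden-size gradient per weight that XLA cannot fold away (dgamma always; dbeta only when the TE norm primitive is used, since XLA constant-folds dbeta for a jnp.mean loss on the native path), plus one extra element for the amax reduction under current scaling. Below is a minimal standalone sketch of that arithmetic; the helper name is chosen here for illustration only and is not part of the test suite.

import jax
import jax.numpy as jnp

def expected_allreduce_bytes(ln_type, hidden_size, dtype, use_te_norm, current_scaling):
    total = 4  # FP32 scalar loss
    # dgamma is always all-reduced; dbeta survives only on the TE primitive path.
    weight_count = 2 if (ln_type == "layernorm" and use_te_norm) else 1
    itemsize = jax.dtypes.canonicalize_dtype(dtype).itemsize
    total += weight_count * hidden_size * itemsize
    if current_scaling:
        total += itemsize  # amax all-reduce
    return total

# bfloat16 layernorm on the TE path: 4 + 2 * 1024 * 2 bytes
assert expected_allreduce_bytes("layernorm", 1024, jnp.bfloat16, True, False) == 4100
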
diff --git a/tests/jax/utils.py b/tests/jax/utils.py
index 8ad6dccfec..baab7229b6 100644
--- a/tests/jax/utils.py
+++ b/tests/jax/utils.py
@@ -25,6 +25,7 @@
     make_swa_mask,
 )
 from transformer_engine.jax.quantize.helper import DType as TEDType
+from transformer_engine.jax.cpp_extensions.base import primitive_context
 
 PRNGKey = Any
 Shape = Tuple[int, ...]
@@ -1604,18 +1605,5 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
 
 @contextmanager
 def use_jax_gemm(enabled=False):
-    orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS", None)
-
-    try:
-        if enabled:
-            os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false"
-        else:
-            os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=true"
+    with primitive_context(f"GemmPrimitive={not enabled}"):
         yield
-
-    finally:
-        if enabled:
-            if orig_custom_calls_filter is None:
-                os.environ.pop("NVTE_JAX_CUSTOM_CALLS")
-            else:
-                os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_custom_calls_filter
diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py
index fcc2108cca..41e2cdfcfa 100644
--- a/transformer_engine/jax/cpp_extensions/base.py
+++ b/transformer_engine/jax/cpp_extensions/base.py
@@ -7,6 +7,7 @@ import warnings
 from abc import ABCMeta, abstractmethod
 from functools import partial
+from contextlib import contextmanager
 
 from packaging import version
 from jax.extend import core
@@ -34,7 +35,7 @@ class BasePrimitive(metaclass=ABCMeta):
     _is_enabled = True
 
     # Default list of primitives to disable for all recipes
-    _default_disable_names = ["GemmPrimitive"]
+    _default_disable_names = ["GemmPrimitive", "NormFwdPrimitive", "NormBwdPrimitive"]
 
     @classmethod
     def enabled(cls):
@@ -258,3 +259,66 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
                 cls.set_enabled(False)
         else:
             raise ValueError(f"Primitive not found in registry: {name}")
+
+
+def _parse_custom_call_string_to_dict(custom_calls_str: str):
+    """
+    Parses a custom call string into a dictionary of primitive names and their enabled states.
+    The input string can be a single value 'true' or 'false' to enable/disable all primitives,
+    or a comma-separated list of key=value pairs.
+
+    Args:
+        custom_calls_str: A string representing the custom call settings.
+
+    Returns:
+        A dictionary where keys are primitive names and values are booleans indicating enabled state.
+    """
+    custom_calls_str = custom_calls_str.strip()
+    if custom_calls_str.lower() == "true":
+        return {primitive_name: True for primitive_name in _primitive_registry}
+    if custom_calls_str.lower() == "false":
+        return {primitive_name: False for primitive_name in _primitive_registry}
+
+    settings = {}
+    for pair in custom_calls_str.split(","):
+        pair = pair.strip()
+        if "=" in pair:
+            key, value = pair.split("=", 1)
+            key = key.strip()
+            value = value.strip().lower()
+            settings[key] = value == "true"
+    return settings
+
+
+@contextmanager
+def primitive_context(primitive_enabling_changes: str):
+    """Context manager to temporarily change the enabled state of primitives.
+
+    This context manager allows for temporary changes to the enabled state of
+    primitives within its scope. Any changes made will be reverted once the
+    context is exited.
+
+    Args:
+        primitive_enabling_changes: A string describing the changes to apply to the
+            enabled state of primitives. It uses the same format as the
+            `NVTE_JAX_CUSTOM_CALLS` environment variable: a comma-separated list such
+            as `Prim1=true,Prim2=false`, or a bare `true`/`false` to enable or disable
+            all primitives, respectively.
+    """
+    orig_env_var = os.getenv("NVTE_JAX_CUSTOM_CALLS")
+
+    primitives = {}
+    if orig_env_var is not None:
+        primitives = _parse_custom_call_string_to_dict(orig_env_var)
+
+    changes = _parse_custom_call_string_to_dict(primitive_enabling_changes)
+    primitives.update(changes)
+
+    updated_env_var = ",".join(
+        f"{name}={'true' if enabled else 'false'}" for name, enabled in primitives.items()
+    )
+
+    os.environ["NVTE_JAX_CUSTOM_CALLS"] = str(updated_env_var)
+    try:
+        yield
+    finally:
+        if orig_env_var is not None:
+            os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_env_var
+        else:
+            del os.environ["NVTE_JAX_CUSTOM_CALLS"]
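
The primitive_context helper added above merges its argument into whatever NVTE_JAX_CUSTOM_CALLS already holds and restores the previous value on exit, so nested scopes compose instead of clobbering each other. A short usage sketch against the code in this patch:

import os
from transformer_engine.jax.cpp_extensions.base import primitive_context

os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=true"

# Temporarily force the JAX-native norm path while leaving the GEMM setting untouched.
with primitive_context("NormFwdPrimitive=false,NormBwdPrimitive=false"):
    merged = os.environ["NVTE_JAX_CUSTOM_CALLS"]
    assert "GemmPrimitive=true" in merged
    assert "NormFwdPrimitive=false" in merged

# The original value is restored once the scope exits.
assert os.environ["NVTE_JAX_CUSTOM_CALLS"] == "GemmPrimitive=true"
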
diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py
index 3b563efbd0..86f6f48694 100644
--- a/transformer_engine/jax/cpp_extensions/normalization.py
+++ b/transformer_engine/jax/cpp_extensions/normalization.py
@@ -828,18 +828,23 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)
     """
     JAX native layernorm implementation
     """
-    x_ = jnp.asarray(x, jnp.float32)
-    if not is_norm_zero_centered_gamma_in_weight_dtype(
-        quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
-    ):
-        gamma = gamma.astype(jnp.float32)
-    mean = jnp.mean(x_, axis=-1, keepdims=True)
-    var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
+    mean = jnp.mean(x, axis=-1, keepdims=True)
+    var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
     rsigma = jax.lax.rsqrt(var + epsilon)
-    normed_input = (x_ - mean) * rsigma
+    normed_input = (x - mean) * rsigma
+
+    gamma_ = gamma
     if zero_centered_gamma:
-        gamma += 1.0
-    output = normed_input * gamma + beta
+        zero_centered_gamma_dtype = (
+            gamma.dtype
+            if is_norm_zero_centered_gamma_in_weight_dtype(
+                quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
+            )
+            else jnp.float32
+        )
+        gamma_ = (gamma + jnp.asarray(1.0, dtype=zero_centered_gamma_dtype)).astype(gamma.dtype)
+
+    output = normed_input * gamma_ + beta
 
     if quantizer:
         ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
@@ -853,17 +858,22 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
     """
     JAX native rmsnorm implementation
     """
-    x_ = jnp.asarray(x, jnp.float32)
-    if not is_norm_zero_centered_gamma_in_weight_dtype(
-        quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
-    ):
-        gamma = gamma.astype(jnp.float32)
-    var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
+    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
     rsigma = jax.lax.rsqrt(var + epsilon)
-    normed_input = x_ * rsigma
+    normed_input = x * rsigma
+
+    gamma_ = gamma
     if zero_centered_gamma:
-        gamma += 1.0
-    output = normed_input * gamma
+        zero_centered_gamma_dtype = (
+            gamma.dtype
+            if is_norm_zero_centered_gamma_in_weight_dtype(
+                quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
+            )
+            else jnp.float32
+        )
+        gamma_ = gamma + jnp.asarray(1.0, dtype=zero_centered_gamma_dtype)
+
+    output = normed_input * gamma_
 
     if quantizer:
         ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
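
For context on the zero_centered_gamma handling in the JAX-native norms above: the dtype in which the +1 offset is applied matters because low-precision weight dtypes round it differently than float32, and is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode) lets the pure-JAX reference follow the same convention as the TE backend for a given scaling mode. A small self-contained illustration of the rounding effect (pure JAX, not the TE kernel itself):

import jax.numpy as jnp

# Zero-centered gamma stores gamma - 1, so stored values are typically small.
g = jnp.asarray([1e-3], dtype=jnp.bfloat16)

# Offset applied in the weight dtype: near 1.0 the bfloat16 spacing is 2**-7,
# so the small offset rounds away entirely.
gamma_in_weight_dtype = g + jnp.asarray(1.0, dtype=jnp.bfloat16)  # -> 1.0

# Offset applied in float32: the offset is preserved until any later cast.
gamma_in_float32 = g.astype(jnp.float32) + jnp.asarray(1.0, dtype=jnp.float32)  # -> ~1.001

print(gamma_in_weight_dtype, gamma_in_float32)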