@jberchtold-nvidia jberchtold-nvidia commented Aug 7, 2025

Description

Disables the TE norm primitives by default. Updates the distributed layernorm tests' collective communication byte-count logic and weight sharding. Also extends the tests to run both TE norm and JAX norm.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Set TE norm primitives to disabled by default
  • tests/jax/test_distributed_layernorm.mlp
    • Remove weight sharding for gamma and beta, which were incorrectly sharded in DP
    • Fix collective communication byte-count assertions
    • Tests now run both JAX norm and TE norm
  • Updated the casting in the JAX norm implementation to avoid performing an AllReduce in fp32 when the dtype is bfloat16. Based on running test_custom_call_compute.py, this still appears to give the same output precision as TE norm.
  • Added a context utility primitive_context for modifying which TE primitives are enabled within a block
    • Updated use_jax_gemm to use this utility. This slightly improves the existing behavior: previously use_jax_gemm ignored NVTE_JAX_CUSTOM_CALLS inside the block, whereas the behavior is now (all settings specified in NVTE_JAX_CUSTOM_CALLS) + (override GemmPrimitive=true/false)
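
To illustrate the override semantics described in the last bullet, here is a minimal sketch, not the code in this PR. It assumes NVTE_JAX_CUSTOM_CALLS holds comma-separated Name=true/false pairs. The helper names (parse_flags, format_flags, primitive_overrides) and SomeOtherPrimitive are hypothetical, and a real utility would also need the library to re-read the setting.

```python
import os
from contextlib import contextmanager

ENV_VAR = "NVTE_JAX_CUSTOM_CALLS"  # variable name taken from the PR description


def parse_flags(value):
    """Parse 'A=true,B=false' into {'A': True, 'B': False} (format is an assumption)."""
    flags = {}
    for item in value.split(","):
        item = item.strip()
        if not item:
            continue
        name, _, state = item.partition("=")
        flags[name.strip()] = state.strip().lower() == "true"
    return flags


def format_flags(flags):
    """Inverse of parse_flags: render the mapping as 'A=true,B=false'."""
    return ",".join(
        f"{name}={'true' if enabled else 'false'}" for name, enabled in flags.items()
    )


@contextmanager
def primitive_overrides(**overrides):
    """Hypothetical stand-in for primitive_context: layer overrides on existing settings."""
    previous = os.environ.get(ENV_VAR)
    merged = parse_flags(previous or "")
    merged.update(overrides)  # existing settings survive; overrides win on conflict
    os.environ[ENV_VAR] = format_flags(merged)
    try:
        yield merged
    finally:
        # Restore the exact previous state, including "variable not set".
        if previous is None:
            os.environ.pop(ENV_VAR, None)
        else:
            os.environ[ENV_VAR] = previous


# Usage: keep whatever the user already configured, but force GemmPrimitive off.
os.environ[ENV_VAR] = "SomeOtherPrimitive=false"
with primitive_overrides(GemmPrimitive=False):
    print(os.environ[ENV_VAR])
```

Existing settings are preserved, the override is layered on top, and the previous value is restored when the block exits, even if the body raises.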

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@jberchtold-nvidia jberchtold-nvidia force-pushed the dev/jberchtold/fix-layernorm-distributed-tests branch from cc8be88 to 3fef659 Compare August 7, 2025 22:41
@jberchtold-nvidia jberchtold-nvidia changed the title [Draft][JAX] Fix layernorm distributed test sharding and collective assertions [JAX] Fix layernorm distributed test sharding and collective assertions Aug 8, 2025
@jberchtold-nvidia

/te-ci L1 jax

Signed-off-by: Jeremy Berchtold <[email protected]>
@jberchtold-nvidia jberchtold-nvidia force-pushed the dev/jberchtold/fix-layernorm-distributed-tests branch from a5b8244 to 025db08 Compare August 8, 2025 18:17
Signed-off-by: Jeremy Berchtold <[email protected]>
@jberchtold-nvidia jberchtold-nvidia force-pushed the dev/jberchtold/fix-layernorm-distributed-tests branch from 2006c1f to e3d2a48 Compare August 8, 2025 18:20
@jberchtold-nvidia

/te-ci L1 jax

Signed-off-by: Jeremy Berchtold <[email protected]>
@jberchtold-nvidia jberchtold-nvidia force-pushed the dev/jberchtold/fix-layernorm-distributed-tests branch from 17ae015 to 629dba5 Compare August 8, 2025 20:15
@jberchtold-nvidia

/te-ci L1 jax

@jberchtold-nvidia jberchtold-nvidia force-pushed the dev/jberchtold/fix-layernorm-distributed-tests branch from a3a9513 to 629dba5 Compare August 8, 2025 22:04
@jberchtold-nvidia

/te-ci L1 jax

@@ -828,18 +828,23 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)
"""
JAX native layernorm implementation
"""
x_ = jnp.asarray(x, jnp.float32)

Collaborator

I'm not sure about this change: without this cast, the compute type (ctype) of the normalization will be bf16.

f"{name}={'true' if enabled else 'false'}" for name, enabled in primitives.items()
)

os.environ["NVTE_JAX_CUSTOM_CALLS"] = str(updated_env_var)

Collaborator

Hi,

I think we could use manage_primitives here, so we don't need to change the env var.
If you prefer to pass key-value pairs, we could update manage_primitives to accept them too.

Collaborator Author

Thanks, I just read the full docstring for manage_primitives and you're right, it will work. I'll switch this context utility to use it instead.

I had missed the disable_all_first argument and thought manage_primitives always disabled any unspecified primitives.
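
The distinction discussed in this thread can be sketched with a toy model. This is not Transformer Engine's manage_primitives: only the disable_all_first name comes from the thread, while the function name, the other parameters, and PrimitiveA/PrimitiveB are hypothetical.

```python
def toggle_primitives(state, enable=(), disable=(), disable_all_first=False):
    """Return a new {name: enabled} mapping. Toy model only, not the TE API."""
    updated = dict(state)
    if disable_all_first:
        # Clean slate: everything is off unless re-enabled below.
        updated = {name: False for name in updated}
    for name in enable:
        updated[name] = True
    for name in disable:
        updated[name] = False
    return updated


state = {"GemmPrimitive": True, "PrimitiveA": True, "PrimitiveB": False}

# Without disable_all_first, unspecified primitives keep their current state.
kept = toggle_primitives(state, disable=["GemmPrimitive"])

# With disable_all_first, unspecified primitives are switched off.
reset = toggle_primitives(state, enable=["GemmPrimitive"], disable_all_first=True)
```

In this model, only the second call touches primitives that were not named, which is the behavior the author had assumed was unconditional.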
