-
Notifications
You must be signed in to change notification settings - Fork 499
[JAX] Fix layernorm distributed test sharding and collective assertions #2041
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[JAX] Fix layernorm distributed test sharding and collective assertions #2041
Conversation
Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: Jeremy Berchtold <[email protected]>
cc8be88
to
3fef659
Compare
/te-ci L1 jax |
Signed-off-by: Jeremy Berchtold <[email protected]>
a5b8244
to
025db08
Compare
Signed-off-by: Jeremy Berchtold <[email protected]>
2006c1f
to
e3d2a48
Compare
/te-ci L1 jax |
Signed-off-by: Jeremy Berchtold <[email protected]>
17ae015
to
629dba5
Compare
/te-ci L1 jax |
a3a9513
to
629dba5
Compare
/te-ci L1 jax |
1 similar comment
/te-ci L1 jax |
@@ -828,18 +828,23 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) | |||
""" | |||
JAX native layernorm implementation | |||
""" | |||
x_ = jnp.asarray(x, jnp.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure about this change, as without this casting, the ctype
of normalization will be bf16
.
f"{name}={'true' if enabled else 'false'}" for name, enabled in primitives.items() | ||
) | ||
|
||
os.environ["NVTE_JAX_CUSTOM_CALLS"] = str(updated_env_var) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi,
I think we could use manage_primitives
here and don't need to change the env var.
If you prefer to pass key-value pairs, we could update the manage_primitives
to do so too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, just read the full docstring on manage_primitives
and you're right it will work. I'll update to that instead for this context
I had missed the disable_all_first
argument and thought manage_primitives
was always disabling any unspecified primitives
Description
Disables TE norm primitive. Updates distributed layernorm collection communication byte count logic and weight sharding. Also extend tests to run both TE norm and JAX norm
Type of change
Changes
tests/jax/test_distributed_layernorm.mlp
primitive_context
for modifying which TE primitives are enabled within a blockuse_jax_gemm
to use this utility instead. Slight improvement on existing behavior as previouslyuse_jax_gemm
would ignoreNVTE_JAX_CUSTOM_CALLS
when inside the block, but now the behavior is(all updates specified in NVTE_JAX_CUSTOM_CALLS) + (override GemmPrimitive=true/false)
Checklist: