examples/jax/encoder/test_model_parallel_encoder.py (4 additions & 4 deletions)
@@ -503,7 +503,7 @@ def test_te_delayed_scaling_fp8(self):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.361 and actual[1] > 0.84
+        assert actual[0] < 0.362 and actual[1] > 0.84
 
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
@@ -535,7 +535,7 @@ def test_te_delayed_scaling_fp8_with_sp(self):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.361 and actual[1] > 0.84
+        assert actual[0] < 0.362 and actual[1] > 0.84
 
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8_with_sp(self):
@@ -569,7 +569,7 @@ def test_te_delayed_scaling_fp8_shardy(self):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.361 and actual[1] > 0.84
+        assert actual[0] < 0.362 and actual[1] > 0.84
 
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp_shardy(self):
@@ -579,7 +579,7 @@ def test_te_delayed_scaling_fp8_with_sp_shardy(self):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.361 and actual[1] > 0.84
+        assert actual[0] < 0.362 and actual[1] > 0.84
 
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8_shardy(self):
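All four hunks above make the same change: the loss ceiling in the FP8 DelayedScaling tests is relaxed from 0.361 to 0.362, while the accuracy floor of 0.84 is unchanged. A minimal sketch of the shared assertion pattern, with a hypothetical train_and_evaluate stub standing in for the real training loop (the return values are illustrative only):

from types import SimpleNamespace

def train_and_evaluate(args):
    """Hypothetical stand-in for the real training loop in
    test_model_parallel_encoder.py; returns (test_loss, test_accuracy)."""
    # ... train the encoder with args.use_fp8 / args.fp8_recipe ...
    return 0.3615, 0.85  # illustrative values only

args = SimpleNamespace(use_fp8=True, fp8_recipe="DelayedScaling")
actual = train_and_evaluate(args)
# A loss of 0.3615 fails the old ceiling (< 0.361) but passes the
# new one (< 0.362); the accuracy floor stays at 0.84.
assert actual[0] < 0.362 and actual[1] > 0.84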
tests/jax/test_layer.py (9 additions & 0 deletions)
@@ -430,6 +430,9 @@ class EncoderRunner(BaseRunner):
         "attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
             "attention/DotProductAttention_0/softmax_offset"
         ),
+        "attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset": (
+            "attention/DotProductAttention_0/softmax_offset"
+        ),
         "mlp/wi_kernel": "mlp/wi/kernel",
         "mlp/wi_bias": "mlp/wi/bias",
         "mlp/wo_kernel": "mlp/wo/kernel",
@@ -478,13 +481,19 @@ class DecoderRunner(BaseRunner):
         "encoder_decoder_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
             "encoder_decoder_attention/DotProductAttention_0/softmax_offset"
         ),
+        "encoder_decoder_attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset": (
+            "encoder_decoder_attention/DotProductAttention_0/softmax_offset"
+        ),
         "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
         "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
         "self_attention/query/scale": "pre_self_attention_layer_norm/scale",
         "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
         "self_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
             "self_attention/DotProductAttention_0/softmax_offset"
         ),
+        "self_attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset": (
+            "self_attention/DotProductAttention_0/softmax_offset"
+        ),
         "mlp/wi_kernel": "mlp/wi/kernel",
         "mlp/wi_bias": "mlp/wi/bias",
         "mlp/wo_kernel": "mlp/wo/kernel",
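The added entries mirror the existing unfused ones: a softmax_offset parameter produced under the fused attention path (_FusedDotProductAttention_0) now remaps to the same canonical key as the unfused path, so comparison against the reference parameter tree works regardless of which attention backend is selected. A minimal sketch of how such a remap table normalizes parameter paths (the normalize helper and its use here are hypothetical, not the runner's actual code):

# Hypothetical sketch: normalize parameter paths with a remap table so
# fused and unfused attention implementations compare under one key.
REMAP = {
    "attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset":
        "attention/DotProductAttention_0/softmax_offset",
    "attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset":
        "attention/DotProductAttention_0/softmax_offset",
}

def normalize(path: str) -> str:
    """Map an implementation-specific parameter path to its canonical name."""
    return REMAP.get(path, path)

# Both backends' softmax_offset land on the same canonical key, so the
# tree comparison no longer depends on which attention kernel was used.
fused = "attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset"
unfused = "attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset"
assert normalize(fused) == normalize(unfused)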