[JAX] Update helper tests #1664

Closed
86 changes: 64 additions & 22 deletions tests/jax/test_custom_call_compute.py
@@ -4,6 +4,7 @@

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from functools import reduce
@@ -54,6 +55,7 @@
""" Find supported scaling modes"""
if is_fp8_supported:
supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
if is_mxfp8_supported:
supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)

@@ -71,8 +73,9 @@ def is_shape_supported_by_mxfp8(input_shape):

def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
assert_allclose(a.scale_inv, b.scale_inv)
assert_allclose(a.data, b.data)
assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))

elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)
@@ -159,7 +162,12 @@ def test_act_grad(self, shape, activation_type):
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_act_grad_with_tensor_scaling_fp8(
self, random_inputs, activation_type, output_type, scaling_mode
):
x = random_inputs
x = jnp.expand_dims(x, axis=-2)
x = jnp.repeat(x, len(activation_type), axis=-2)
@@ -170,7 +178,7 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type,
)

quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
scaling_mode=scaling_mode,
q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE,
)
@@ -188,8 +196,11 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type,
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_act_forward_with_delayed_scaling_fp8(
self, random_inputs, activation_type, output_type, q_layout
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_act_forward_with_tensor_scaling_fp8(
self, random_inputs, activation_type, output_type, q_layout, scaling_mode
):
x = random_inputs
x = jnp.expand_dims(x, axis=-2)
@@ -198,7 +209,7 @@ def test_act_forward_with_delayed_scaling_fp8(

te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
scaling_mode=scaling_mode,
q_dtype=output_type,
q_layout=q_layout,
)
@@ -335,8 +346,20 @@ def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_norm_grad_with_delayed_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_norm_grad_with_tensor_scaling_fp8(
self,
n,
hidden,
norm_type,
zero_centered_gamma,
epsilon,
inp_dtype,
out_dtype,
q_layout,
scaling_mode,
):
"""
Test transformer_engine.jax.layernorm.layernorm
@@ -345,9 +368,7 @@ def test_norm_grad_with_delayed_scaling_fp8(
pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
q_dtype=out_dtype,
q_layout=q_layout,
scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
)
self._test_norm_grad(
n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
@@ -406,8 +427,20 @@ def _test_norm_forward(
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_norm_forward_with_delayed_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_norm_forward_with_tensor_scaling_fp8(
self,
n,
hidden,
norm_type,
zero_centered_gamma,
epsilon,
inp_dtype,
out_dtype,
q_layout,
scaling_mode,
):
if norm_type == "rmsnorm" and zero_centered_gamma is True:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
@@ -420,7 +453,7 @@ def test_norm_forward_with_delayed_scaling_fp8(
epsilon=epsilon,
inp_dtype=inp_dtype,
out_dtype=out_dtype,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
scaling_mode=scaling_mode,
q_layout=q_layout,
)

@@ -448,8 +481,8 @@ def test_norm_forward_with_block_scaling_fp8(
}

ALL_QUANTIZE_TEST_SHAPES = [
(32, 64),
(2, 64, 32),
(32, 256, 128),
(64, 32, 32, 256),
]

QUANTIZE_TEST_SHAPES = {
@@ -630,16 +663,19 @@ def test_quantize_dact_dbias_no_quantization(
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_quantize_dact_dbias_delayed_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
@pytest_parametrize_wrapper(
"scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
)
def test_quantize_dact_dbias_tensor_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
):
self._test_quantize_dact_dbias(
in_dtype=in_dtype,
input_shape=input_shape,
out_dtype=out_dtype,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
scaling_mode=scaling_mode,
activation_type=activation_type,
is_dbias=is_dbias,
q_layout=q_layout,
@@ -830,7 +866,10 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
Test layernorm_dense VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")

# zero_centered_gamma is already tested in TestNorm
@@ -916,7 +955,10 @@ def test_layernorm_mlp_grad(
Test layernorm_mlp VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")

# zero_centered_gamma is already tested in TestNorm
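A minimal sketch of what the new tensor-scaling parametrization exercises, assuming QuantizerFactory, QuantizeLayout, and ScalingMode are importable from transformer_engine.jax.quantize (as in test_helper.py below) and that quantizers expose a quantize() convenience method; the actual tests above go through the fused TE custom calls instead:

import jax.numpy as jnp
import pytest
from transformer_engine.jax.quantize import QuantizerFactory, QuantizeLayout, ScalingMode

@pytest.mark.parametrize(
    "scaling_mode",
    [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING],
)
def test_rowwise_quantizer_sketch(scaling_mode):
    # Build a row-wise FP8 quantizer for the selected tensor-scaling mode,
    # mirroring the QuantizerFactory.create(...) calls in the diff above.
    quantizer = QuantizerFactory.create(
        scaling_mode=scaling_mode,
        q_dtype=jnp.float8_e4m3fn,
        q_layout=QuantizeLayout.ROWWISE,
    )
    x = jnp.ones((32, 64), dtype=jnp.bfloat16)
    scaled = quantizer.quantize(x)  # assumed convenience API (see lead-in)
    # Row-wise quantization keeps the input shape; .data mirrors the
    # ScaledTensor1x attribute used by assert_bitwise_scaled_tensors above.
    assert scaled.data.shape == x.shape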
109 changes: 72 additions & 37 deletions tests/jax/test_helper.py
@@ -10,47 +10,22 @@
import numpy as np

from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo
from transformer_engine.jax.quantize import (
QuantizeConfig,
is_fp8_available,
ScalingMode,
update_collections,
)
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)


class TestQuantizeConfig(unittest.TestCase):

@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_initialize(self):
margin = 5.0
fp8_format = FP8Format.E4M3
amax_history_len = 10

QuantizeConfig.initialize(
margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
)

self.assertEqual(
QuantizeConfig.MARGIN,
margin,
f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
f" but got {QuantizeConfig.MARGIN}.",
)
self.assertEqual(
QuantizeConfig.FP8_FORMAT,
fp8_format,
f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
f" but got {QuantizeConfig.FP8_FORMAT}.",
)
self.assertEqual(
QuantizeConfig.AMAX_HISTORY_LEN,
amax_history_len,
f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
)

QuantizeConfig.finalize()
class TestHelper(unittest.TestCase):

@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_update_collections(self):
@@ -61,12 +36,12 @@ def test_update_collections(self):
"test1": original_val,
"test2": original_val,
}
updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
updated_state = update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val)

original_state = flax.core.frozen_dict.FrozenDict(original_state)
updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
updated_state = update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val)

@@ -82,8 +57,18 @@ def _compare_delay_scaling(self, ref, test):
self.assertTrue(ref.amax_history_len == test.amax_history_len)
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)

def _compare_current_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)

def _compare_mxfp8_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)

@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast(self):
def test_fp8_autocast_delayed_scaling(self):
QuantizeConfig.finalize() # Ensure this test is not affected by previous tests.
self._check_defult_state()

@@ -107,6 +92,56 @@ def test_fp8_autocast(self):

self._check_defult_state()

@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_current_scaling(self):
QuantizeConfig.finalize() # Ensure this test is not affected by previous tests.
self._check_defult_state()

with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
self.assertFalse(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(Float8CurrentScaling())

self._check_defult_state()

cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs)

self._check_defult_state()

cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs)

self._check_defult_state()

@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_scaling(self):
QuantizeConfig.finalize() # Ensure this test is not affected by previous tests.
self._check_defult_state()

with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
self.assertFalse(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(MXFP8BlockScaling())

self._check_defult_state()

bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=bs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(bs)

self._check_defult_state()

bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=bs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(bs)

self._check_defult_state()

@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_with_sharding_resource(self):
QuantizeConfig.finalize() # Ensure this test is not affected by previous tests.
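For reference, a condensed sketch of the recipe-to-config behaviour these helper tests assert, built only from symbols that appear in this diff (fp8_autocast, QuantizeConfig, ScalingMode, and the Float8CurrentScaling/MXFP8BlockScaling recipes); treat it as an illustration rather than additional test code:

from transformer_engine.common.recipe import Float8CurrentScaling, MXFP8BlockScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.quantize import QuantizeConfig, ScalingMode, is_fp8_available

if is_fp8_available()[0]:
    # Current (per-tensor) scaling: the autocast context switches QuantizeConfig
    # to CURRENT_TENSOR_SCALING and carries the recipe's margin and format.
    recipe = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
    with fp8_autocast(enabled=True, fp8_recipe=recipe):
        assert QuantizeConfig.is_fp8_enabled()
        assert QuantizeConfig.SCALING_MODE == ScalingMode.CURRENT_TENSOR_SCALING

if is_fp8_available(ScalingMode.MXFP8_1D_SCALING)[0]:
    # MXFP8 block scaling follows the same pattern with MXFP8_1D_SCALING.
    with fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
        assert QuantizeConfig.SCALING_MODE == ScalingMode.MXFP8_1D_SCALING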
22 changes: 13 additions & 9 deletions transformer_engine/common/normalization/layernorm/ln_api.cpp
@@ -31,19 +31,23 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}

NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(gamma.data.shape == beta.data.shape);
NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]);
NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape.");
NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0], "Gamma must have the same hidden size.");

NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative.");

NVTE_CHECK(epsilon >= 0.f);
NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x.");

NVTE_CHECK(z->data.shape == x.data.shape);
NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]},
"Mu must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(mu->data.dtype == DType::kFloat32, "Mu must be a float32 tensor.");

NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]});
NVTE_CHECK(mu->data.dtype == DType::kFloat32);
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]},
"RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");

NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]});
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32);
NVTE_CHECK(gamma.data.dtype == beta.data.dtype, "Gamma and Beta must have the same dtype.");

if (!workspace->data.shape.empty()) {
CheckInputTensor(x, "x");