diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 8917e92465..8533824169 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp +import numpy as np import pytest from jax import jit, value_and_grad from functools import reduce @@ -54,6 +55,7 @@ """ Find supported scaling modes""" if is_fp8_supported: supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) + supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING) if is_mxfp8_supported: supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) @@ -71,8 +73,9 @@ def is_shape_supported_by_mxfp8(input_shape): def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x): + assert_allclose(a.scale_inv, b.scale_inv) assert_allclose(a.data, b.data) - assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8)) + elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x): assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor) assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor) @@ -159,7 +162,12 @@ def test_act_grad(self, shape, activation_type): @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) - def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type): + @pytest_parametrize_wrapper( + "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING] + ) + def test_act_grad_with_tensor_scaling_fp8( + self, random_inputs, activation_type, output_type, scaling_mode + ): x = random_inputs x = jnp.expand_dims(x, axis=-2) x = jnp.repeat(x, len(activation_type), axis=-2) @@ -170,7 +178,7 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, ) quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, + scaling_mode=scaling_mode, q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) @@ -188,8 +196,11 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] ) - def test_act_forward_with_delayed_scaling_fp8( - self, random_inputs, activation_type, output_type, q_layout + @pytest_parametrize_wrapper( + "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING] + ) + def test_act_forward_with_tensor_scaling_fp8( + self, random_inputs, activation_type, output_type, q_layout, scaling_mode ): x = random_inputs x = jnp.expand_dims(x, axis=-2) @@ -198,7 +209,7 @@ def test_act_forward_with_delayed_scaling_fp8( te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, - scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, + scaling_mode=scaling_mode, q_dtype=output_type, q_layout=q_layout, ) @@ -335,8 +346,20 @@ def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] ) - def test_norm_grad_with_delayed_scaling_fp8( - self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout + @pytest_parametrize_wrapper( + "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING] + ) + 
def test_norm_grad_with_tensor_scaling_fp8( + self, + n, + hidden, + norm_type, + zero_centered_gamma, + epsilon, + inp_dtype, + out_dtype, + q_layout, + scaling_mode, ): """ Test transformer_engine.jax.layernorm.layernorm @@ -345,9 +368,7 @@ def test_norm_grad_with_delayed_scaling_fp8( pytest.skip("RMSNorm and zero_centered_gamma is not supported!") quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, - q_dtype=out_dtype, - q_layout=q_layout, + scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout ) self._test_norm_grad( n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer @@ -406,8 +427,20 @@ def _test_norm_forward( @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] ) - def test_norm_forward_with_delayed_scaling_fp8( - self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout + @pytest_parametrize_wrapper( + "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING] + ) + def test_norm_forward_with_tensor_scaling_fp8( + self, + n, + hidden, + norm_type, + zero_centered_gamma, + epsilon, + inp_dtype, + out_dtype, + q_layout, + scaling_mode, ): if norm_type == "rmsnorm" and zero_centered_gamma is True: pytest.skip("RMSNorm and zero_centered_gamma is not supported!") @@ -420,7 +453,7 @@ def test_norm_forward_with_delayed_scaling_fp8( epsilon=epsilon, inp_dtype=inp_dtype, out_dtype=out_dtype, - scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, + scaling_mode=scaling_mode, q_layout=q_layout, ) @@ -448,8 +481,8 @@ def test_norm_forward_with_block_scaling_fp8( } ALL_QUANTIZE_TEST_SHAPES = [ - (32, 64), - (2, 64, 32), + (32, 256, 128), + (64, 32, 32, 256), ] QUANTIZE_TEST_SHAPES = { @@ -630,16 +663,19 @@ def test_quantize_dact_dbias_no_quantization( @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("is_dbias", [True, False]) @pytest_parametrize_wrapper( - "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] ) - def test_quantize_dact_dbias_delayed_scaling( - self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout + @pytest_parametrize_wrapper( + "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING] + ) + def test_quantize_dact_dbias_tensor_scaling( + self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode ): self._test_quantize_dact_dbias( in_dtype=in_dtype, input_shape=input_shape, out_dtype=out_dtype, - scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, + scaling_mode=scaling_mode, activation_type=activation_type, is_dbias=is_dbias, q_layout=q_layout, @@ -830,7 +866,10 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): Test layernorm_dense VJP Rule """ # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + if q_dtype == jnp.float8_e5m2 and scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, + ): pytest.skip("E5M2 is not supported in normalization with TE Backend!") # zero_centered_gamma is already tested in TestNorm @@ -916,7 +955,10 @@ def test_layernorm_mlp_grad( Test layernorm_mlp VJP Rule """ # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + if q_dtype == jnp.float8_e5m2 and scaling_mode in ( + 
ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, + ): pytest.skip("E5M2 is not supported in normalization with TE Backend!") # zero_centered_gamma is already tested in TestNorm diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index 175de417bc..b608e57a35 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -10,47 +10,22 @@ import numpy as np from utils import assert_allclose -from transformer_engine.common.recipe import DelayedScaling +from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling from transformer_engine.common.recipe import Format as FP8Format from transformer_engine.jax import fp8_autocast, get_delayed_scaling -from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo +from transformer_engine.jax.quantize import ( + QuantizeConfig, + is_fp8_available, + ScalingMode, + update_collections, +) from transformer_engine.jax.sharding import MeshResource, global_mesh_resource is_fp8_supported, reason = is_fp8_available() +is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) -class TestQuantizeConfig(unittest.TestCase): - - @unittest.skipIf(not is_fp8_supported, reason=reason) - def test_initialize(self): - margin = 5.0 - fp8_format = FP8Format.E4M3 - amax_history_len = 10 - - QuantizeConfig.initialize( - margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len - ) - - self.assertEqual( - QuantizeConfig.MARGIN, - margin, - f"QuantizeConfig.MARGIN initialization failed, should be {margin}" - f" but got {QuantizeConfig.MARGIN}.", - ) - self.assertEqual( - QuantizeConfig.FP8_FORMAT, - fp8_format, - f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}" - f" but got {QuantizeConfig.FP8_FORMAT}.", - ) - self.assertEqual( - QuantizeConfig.AMAX_HISTORY_LEN, - amax_history_len, - f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}" - f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.", - ) - - QuantizeConfig.finalize() +class TestHelper(unittest.TestCase): @unittest.skipIf(not is_fp8_supported, reason=reason) def test_update_collections(self): @@ -61,12 +36,12 @@ def test_update_collections(self): "test1": original_val, "test2": original_val, } - updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state) + updated_state = update_collections({"test1": updated_val}, original_state) self.assertEqual(updated_state["test1"], updated_val) self.assertEqual(updated_state["test2"], original_val) original_state = flax.core.frozen_dict.FrozenDict(original_state) - updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state) + updated_state = update_collections({"test1": updated_val}, original_state) self.assertEqual(updated_state["test1"], updated_val) self.assertEqual(updated_state["test2"], original_val) @@ -82,8 +57,18 @@ def _compare_delay_scaling(self, ref, test): self.assertTrue(ref.amax_history_len == test.amax_history_len) self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) + def _compare_current_scaling(self, test): + self.assertEqual(QuantizeConfig.MARGIN, test.margin) + self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) + self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING) + + def _compare_mxfp8_scaling(self, test): + self.assertEqual(QuantizeConfig.MARGIN, test.margin) + self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) + 
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
+
     @unittest.skipIf(not is_fp8_supported, reason=reason)
-    def test_fp8_autocast(self):
+    def test_fp8_autocast_delayed_scaling(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_defult_state()
 
@@ -107,6 +92,56 @@ def test_fp8_autocast(self):
 
         self._check_defult_state()
 
+    @unittest.skipIf(not is_fp8_supported, reason=reason)
+    def test_fp8_autocast_current_scaling(self):
+        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
+        self._check_defult_state()
+
+        with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
+            self.assertFalse(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(Float8CurrentScaling())
+
+        self._check_defult_state()
+
+        cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
+        with fp8_autocast(enabled=True, fp8_recipe=cs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(cs)
+
+        self._check_defult_state()
+
+        cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
+        with fp8_autocast(enabled=True, fp8_recipe=cs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(cs)
+
+        self._check_defult_state()
+
+    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
+    def test_fp8_autocast_mxfp8_scaling(self):
+        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
+        self._check_defult_state()
+
+        with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
+            self.assertFalse(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(MXFP8BlockScaling())
+
+        self._check_defult_state()
+
+        bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
+        with fp8_autocast(enabled=True, fp8_recipe=bs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(bs)
+
+        self._check_defult_state()
+
+        bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
+        with fp8_autocast(enabled=True, fp8_recipe=bs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(bs)
+
+        self._check_defult_state()
+
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_with_sharding_resource(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
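Reviewer note: a minimal end-to-end sketch of how the new recipe exercised by these tests is driven from user code. It only uses APIs that appear elsewhere in this diff (fp8_autocast, Float8CurrentScaling, QuantizerFactory, QuantizeLayout, ScalingMode); the shape and dtypes are illustrative and the snippet is not part of the patch.

    import jax.numpy as jnp
    from transformer_engine.common.recipe import Float8CurrentScaling
    from transformer_engine.jax import fp8_autocast
    from transformer_engine.jax.quantize import QuantizerFactory, QuantizeLayout, ScalingMode

    x = jnp.ones((32, 256, 128), jnp.bfloat16)

    # Entering the context with this recipe selects CurrentScalingQuantizeConfig,
    # so QuantizeConfig.SCALING_MODE becomes CURRENT_TENSOR_SCALING.
    with fp8_autocast(enabled=True, fp8_recipe=Float8CurrentScaling()):
        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
            q_dtype=jnp.float8_e4m3fn,
            q_layout=QuantizeLayout.ROWWISE,
        )
        scaled = quantizer.quantize(x)  # ScaledTensor1x; scale derived from this tensor's amax
        x_dq = scaled.dequantize()      # upcast data and multiply by scale_inv
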
diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index f6b6ae22c2..7723a59035 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -31,19 +31,23 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } - NVTE_CHECK(x.data.shape.size() == 2); - NVTE_CHECK(gamma.data.shape == beta.data.shape); - NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]); + NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor."); + NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape."); + NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0], "Gamma must have the same hidden size."); + + NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative."); - NVTE_CHECK(epsilon >= 0.f); + NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x."); - NVTE_CHECK(z->data.shape == x.data.shape); + NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]}, + "Mu must be 1D tensor with shape (x.shape[0],)."); + NVTE_CHECK(mu->data.dtype == DType::kFloat32, "Mu must be a float32 tensor."); - NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]}); - NVTE_CHECK(mu->data.dtype == DType::kFloat32); + NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]}, + "RSigma must be 1D tensor with shape (x.shape[0],)."); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor."); - NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]}); - NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); + NVTE_CHECK(gamma.data.dtype == beta.data.dtype, "Gamma and Beta must have the same dtype."); if (!workspace->data.shape.empty()) { CheckInputTensor(x, "x"); diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index c27f6f50f7..8f01a08798 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -119,6 +119,11 @@ def abstract( f" {x_aval.shape} and act_len {act_len}" ) + assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, ( + "Current tensor scaling is not supported for fused activation and quantization. Please" + " do activation in higher-precision then quantize with current tensor scaling." + ) + out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) @@ -451,6 +456,12 @@ def abstract( f" {x_aval.shape} and act_len {act_len}" ) assert scale_aval.dtype == jnp.float32 + + assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, ( + "Current tensor scaling is not supported for fused dact and quantization. Please do" + " dact in higher-precision then quantize with current tensor scaling." 
+ ) + ir_hidden_size = dz_aval.shape[-1] gi_hidden_size = act_len * x_aval.shape[-1] assert act_len * ir_hidden_size == gi_hidden_size @@ -463,7 +474,10 @@ def abstract( scaling_mode ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) if is_2x: - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + if scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING.value, + ScalingMode.CURRENT_TENSOR_SCALING.value, + ): colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) else: colwise_out_shape = out_shape @@ -669,6 +683,10 @@ def infer_sharding_from_operands( x_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[2]) + assert ( + scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value + ), "Partitioned current tensor scaling is not yet supported." + out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" ) @@ -937,6 +955,16 @@ def act_lu( out = out.reshape(output_shape) return out + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. + out = act_lu( + x=x.astype(jnp.float32), + activation_type=activation_type, + quantizer=None, + ) + out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + return out + if isinstance(quantizer, DelayedScaleQuantizer): scale = quantizer.scale @@ -1012,8 +1040,12 @@ def quantize_dact_dbias( # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): - out = dact_lu(dz, x, activation_type, quantizer=None) - return _quantize_dbias_impl(out, quantizer, is_dbias=True, flatten_axis=-2) + out = dact_lu( + dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None + ) + return _quantize_dbias_impl( + out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + ) is_gated = act_len == 2 # TE/common does not support DelayedScaling2x for gated-act yet @@ -1056,6 +1088,20 @@ def quantize_dact_dbias( dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) return output.astype(x.dtype), dbias + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. 
+ out, _ = quantize_dact_dbias( + dz=dz.astype(jnp.float32), + x=x.astype(jnp.float32), + activation_type=activation_type, + is_dbias=False, + quantizer=None, + ) + out, dbias = _quantize_dbias_impl( + out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 + ) + return out, dbias + if isinstance(quantizer, DelayedScaleQuantizer): scale = quantizer.scale @@ -1095,7 +1141,11 @@ def quantize_dact_dbias( ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise - if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): + if ( + quantizer.scaling_mode + in (ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING) + and quantizer.is_2x2x() + ): colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0327542c2f..175f1e61ce 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -175,7 +175,7 @@ def _dequantize(x, scale_inv, dq_dtype): 4, ), ) -def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision): +def __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision): # Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype) rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype) @@ -193,12 +193,13 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision): return out_3d -def _jax_gemm_delayed_scaling_fp8( +def _jax_gemm_tensor_scaling_fp8( lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]] ): """FP8 GEMM for XLA pattern match""" - assert ( - rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING + assert rhs.scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, ), "rhs does not have delayed tensor scaling mode" (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums @@ -216,7 +217,7 @@ def _jax_gemm_delayed_scaling_fp8( precision = ( jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT ) - out_3d = __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision) + out_3d = __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision) # Reshape [B, M, N] -> [..., M, N] out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape) @@ -291,8 +292,11 @@ def _jax_gemm( def _jax_gemm_fp8_impl(lhs, rhs): - if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: - return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums) + if lhs.scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, + ): + return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums) if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) @@ -403,7 +407,10 @@ def grouped_gemm( rhs_shape = rhs.data.shape out_dtype = lhs.dq_dtype # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout - if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + if lhs.scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, + ): assert not ( lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 ), "FP8 GEMM does not support E5M2 * E5M2" @@ -430,7 +437,10 @@ def grouped_gemm( if scaling_mode == 
ScalingMode.NO_SCALING: lhs_3d = _shape_normalization(lhs, lhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn) - elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + elif scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, + ): lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") elif scaling_mode == ScalingMode.MXFP8_1D_SCALING: @@ -473,7 +483,7 @@ def grouped_gemm( if scaling_mode == ScalingMode.NO_SCALING: lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + if scaling_mode in (ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING): lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1)) rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1)) if scaling_mode == ScalingMode.MXFP8_1D_SCALING: diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index d64104ac27..a74f0f9088 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -216,7 +216,8 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, """ should_apply_war = ( quantizer is not None - and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING + and quantizer.scaling_mode + in (ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING) and quantizer.is_2x2x() ) if not should_apply_war: diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 388d4f17ee..c079edbee1 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -26,6 +26,7 @@ te_dtype_to_jax_dtype, NamedSharding, ) +from .quantization import _quantize_dbias_impl from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor, ScaledTensorFactory from ..quantize import ( @@ -100,6 +101,11 @@ def abstract( assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, ( + "Current tensor scaling is not supported for fused norm and quantization. Please do" + " norm in higher-precision then quantize with current tensor scaling." + ) + mu_rsigama_dtype = jnp.float32 if norm_type == NVTE_Norm_Type.LayerNorm: @@ -731,6 +737,8 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) JAX native layernorm implementation """ x_ = jnp.asarray(x, jnp.float32) + gamma = gamma.astype(jnp.float32) + beta = beta.astype(jnp.float32) mean = jnp.mean(x_, axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) rsigma = jax.lax.rsqrt(var + epsilon) @@ -752,6 +760,7 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): JAX native rmsnorm implementation """ x_ = jnp.asarray(x, jnp.float32) + gamma = gamma.astype(jnp.float32) var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) rsigma = jax.lax.rsqrt(var + epsilon) normed_input = x_ * rsigma @@ -827,9 +836,26 @@ def layernorm_fwd( ) return output, mu, rsigma + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + # Current scaling does not support fused operations. 
Perform norm in higher precision then quantize after. + out, mu, rsigma = layernorm_fwd( + x=x.astype(jnp.float32), + gamma=gamma.astype(jnp.float32), + beta=beta.astype(jnp.float32), + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon, + quantizer=None, + ) + out = quantizer.quantize(out, dq_dtype=x.dtype) + # out,_ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + return out, mu, rsigma + is_2x2x = quantizer.is_2x2x() # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, + ): is_2x2x = False ( rowwise_casted_output, @@ -857,7 +883,10 @@ def layernorm_fwd( quantizer.update(updated_amax) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, + ): colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -1009,9 +1038,24 @@ def rmsnorm_fwd( ) return output, rsigma + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + # Current scaling does not support fused operations. Perform norm in higher precision then quantize after. + out, rsigma = rmsnorm_fwd( + x=x, + gamma=gamma, + zero_centered_gamma=zero_centered_gamma, + epsilon=epsilon, + quantizer=None, + ) + out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + return out, rsigma + is_2x2x = quantizer.is_2x2x() # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, + ): is_2x2x = False ( rowwise_casted_output, @@ -1039,7 +1083,10 @@ def rmsnorm_fwd( quantizer.update(updated_amax) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + if quantizer.is_2x2x() and quantizer.scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, + ): colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 2911b5a420..93333d358f 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -93,7 +93,10 @@ def abstract( ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: + if scaling_mode in ( + ScalingMode.DELAYED_TENSOR_SCALING.value, + ScalingMode.CURRENT_TENSOR_SCALING.value, + ): colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) else: colwise_out_shape = out_shape @@ -298,6 +301,11 @@ def infer_sharding_from_operands( result_infos, ): del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer) # Unused. 
+ + assert ( + scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value + ), "Current tensor scaling is not yet supported for multi-GPU partitioning." + x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) out_sharding = NamedSharding( @@ -370,6 +378,11 @@ def partition( result_infos, ): del result_infos, is_outer + + assert ( + scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value + ), "Current tensor scaling is not yet supported for multi-GPU partitioning." + x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) out_sharding = NamedSharding( @@ -592,7 +605,11 @@ def _quantize_dbias_impl( is_outer=True, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): + if ( + quantizer.scaling_mode + in (ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING) + and quantizer.is_2x2x() + ): colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index fc7f231f34..9487c2bd86 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -44,6 +44,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape); + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, + "Current tensor scaling does not support fused operations. Please call this primitive " + "in higher-precision then quantize with current scaling."); + if (is_fp8_dtype(out_dtype)) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); @@ -152,6 +156,10 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size}; auto dbias_shape = std::vector<size_t>{hidden_size}; + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, + "Current tensor scaling does not support fused operations. Please call this primitive " + "in higher-precision then quantize with current scaling."); + // Evil hack to specify TE impl // Note: nvte_quantize_dbias_dgelu chooses its internal impl based // on what pointers are allocated, e.g. whether to output with @@ -219,6 +227,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto act_type = static_cast<NVTE_Activation_Type>(act_enum); auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, + "Current tensor scaling does not support fused operations. 
Please call this primitive " + "in higher-precision then quantize with current scaling."); + auto *output = output_buf->untyped_data(); auto *colwise_output = colwise_output_buf->untyped_data(); auto *dbias = dbias_buf->untyped_data(); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index d4b9bf720e..1653adb8ec 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -95,7 +95,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh lhs_i.set_rowwise_data(static_cast<void *>(lhs_ptr), lhs_dtype, lhs_shape); rhs_i.set_rowwise_data(static_cast<void *>(rhs_ptr), rhs_dtype, rhs_shape); - if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) { lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat32, std::vector<size_t>{1}); rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat32, @@ -190,8 +190,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten, return GroupedGemmImpl(lhs_ptr, lhs_dtype, lhs_sinv_ptr, lhs_sinv_dtype, rhs_ptr, rhs_dtype, rhs_sinv_ptr, rhs_sinv_dtype, bias_ptr, bias_dtype, out_ptr, out_dtype, - workspace_ptr, workspace_size, num_gemms, dim_list_ptr, scaling_mode, - stream); + workspace_ptr, workspace_size, num_gemms, dim_list_ptr, + scaling_mode, stream); } XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index f7577c24f3..4c3d29ef0a 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -44,6 +44,7 @@ enum class JAXX_Scaling_Mode : int64_t { NO_SCALING = 0, DELAYED_TENSOR_SCALING = 1, MXFP8_1D_SCALING = 2, + CURRENT_TENSOR_SCALING = 3, }; static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { @@ -57,6 +58,9 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { case JAXX_Scaling_Mode::MXFP8_1D_SCALING: return NVTEScalingMode::NVTE_MXFP8_1D_SCALING; break; + case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING: + return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; + break; default: NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode)); break; diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index e23e42f528..f3aa237996 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -98,7 +98,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto workspace_shape = std::vector<size_t>{workspace_size}; auto input_tensor = TensorWrapper(input, input_shape, in_dtype); - auto gamma_tensor = TensorWrapper(gamma, gamma_shape, in_dtype); + auto gamma_tensor = TensorWrapper(gamma, gamma_shape, w_dtype); auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32); auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin; @@ -107,6 +107,10 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), 
input_shape); + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, + "Current tensor scaling does not support fused operations. Please call this primitive " + "in higher-precision then quantize with current scaling."); + if (is_fp8_dtype(out_dtype)) { output_tensor.set_rowwise_scale_inv( scale_inv_buf->untyped_data(), @@ -134,6 +138,8 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc } if (_norm_type == NVTE_Norm_Type::LayerNorm) { + NVTE_CHECK(w_dtype == convert_ffi_datatype_to_te_dtype(beta_buf.element_type()), + "gamma and beta must have the same data type."); auto beta_tensor = TensorWrapper(beta, gamma_shape, w_dtype); auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 5c165cccb6..3866aca62f 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -142,6 +142,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING) .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING) + .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) .export_values(); pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout", diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 481dbd7cdf..d544972c97 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -7,6 +7,7 @@ #include "extensions.h" #include "transformer_engine/cast.h" +#include "transformer_engine/recipe.h" #include "xla/ffi/api/c_api.h" namespace transformer_engine { @@ -107,12 +108,15 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; + if (quantize_layout == QuantizeLayout::ROWWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + if (is_tensor_scaling) { float *scale = reinterpret_cast<float *>(scale_buf.untyped_data()); float *amax = reinterpret_cast<float *>(amax_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); @@ -142,11 +146,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T : output_shape; output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) - ? scale_inv_buf - : colwise_scale_inv_buf; + auto &tmp_buf = is_tensor_scaling ? 
scale_inv_buf : colwise_scale_inv_buf; - if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + if (is_tensor_scaling) { output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector<size_t>{1}); @@ -159,6 +161,21 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T } } + if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) { + nvte_compute_amax(input_tensor.data(), // input data + output_tensor.data(), // output data (for amax) + stream); + + QuantizationConfigWrapper quant_config; + /** defaults for now, TODO(Jeremy) move to parameter */ + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(amax_epsilon); + nvte_compute_scale_from_amax(output_tensor.data(), quant_config, stream); + output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1}); + } + auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index d68eb3c6c2..d43c61c9fc 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -85,6 +85,7 @@ def _dq_func_block_scaling(scaled_tensor): funcs = { ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, + ScalingMode.CURRENT_TENSOR_SCALING: _dq_func_tensor_scaling, ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling, } diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 98f280b9a9..3a2d6a57d4 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -94,7 +94,7 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: A tuple of (bool, str) indicating support and any error message """ gpu_arch = get_device_compute_capability(gpu_id) - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: + if scaling_mode in (ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING): return _check_delayed_scaling_fp8_support(gpu_arch) if scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _check_block_scaling_fp8_support(gpu_arch) @@ -182,6 +182,8 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: return ScalingMode.DELAYED_TENSOR_SCALING if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): return ScalingMode.MXFP8_1D_SCALING + if isinstance(fp8_recipe, recipe.Float8CurrentScaling): + return ScalingMode.CURRENT_TENSOR_SCALING raise ValueError("Invalid fp8_recipe!") @@ -309,6 +311,30 @@ def finalize() -> None: QuantizeConfig.finalize() +class CurrentScalingQuantizeConfig: + """Configuration class for current scaling FP8 recipe. + + This class provides specific initialization and finalization for current scaling + FP8 quantization mode. + """ + + @staticmethod + def initialize(fp8_recipe: recipe.Recipe) -> None: + """Initialize current scaling FP8 configuration. + + Args: + fp8_recipe: The FP8 recipe to use for initialization + """ + cls = QuantizeConfig + cls.initialize(fp8_recipe) + cls.AMAX_HISTORY_LEN = 0 + + @staticmethod + def finalize() -> None: + """Reset the current scaling configuration.""" + QuantizeConfig.finalize() + + class BlockScalingQuantizeConfig: """Configuration class for block scaling FP8 recipe. 
@@ -385,6 +411,8 @@ def fp8_autocast(
         Config = DelayedScalingQuantizeConfig
     if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
         Config = BlockScalingQuantizeConfig
+    if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
+        Config = CurrentScalingQuantizeConfig
 
     try:
         with global_shard_guard(mesh_resource):
diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py
index b57043a034..becacde65d 100644
--- a/transformer_engine/jax/quantize/quantizer.py
+++ b/transformer_engine/jax/quantize/quantizer.py
@@ -27,6 +27,7 @@
     "QuantizeLayout",
     "Quantizer",
     "QuantizerSet",
+    "CurrentScaleQuantizer",
     "DelayedScaleQuantizer",
     "BlockScaleQuantizer",
     "QuantizerFactory",
@@ -159,34 +160,26 @@ def get_scale_dtype(self):
 
 @register_pytree_node_class
 @dataclass
-class DelayedScaleQuantizer(Quantizer):
-    """Quantizer implementation using delayed scaling.
+class CurrentScaleQuantizer(Quantizer):
+    """Quantizer implementation using current scaling.
 
-    This quantizer uses delayed scaling mode with float32 scales and maintains
-    a history of maximum absolute values for dynamic scaling.
+    This quantizer uses current scaling mode, computing a float32 scale from each tensor's amax.
 
     Attributes:
-        scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
+        scaling_mode: Set to ScalingMode.CURRENT_TENSOR_SCALING
         q_layout: Quantization axis (default: ROWWISE_COLWISE)
-        scale: Current scaling factor
-        amax_history: History of maximum absolute values
     """
 
-    scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
+    scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
     q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
 
-    scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
-    amax_history: jnp.ndarray = field(
-        default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
-    )
-
     def tree_flatten(self):
         """Flatten the quantizer for JAX tree operations.
 
         Returns:
             Tuple of (children, aux_data) for tree operations
         """
-        children = (self.scale, self.amax_history)
+        children = ()
         aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
         return (children, aux_data)
 
@@ -217,15 +210,18 @@ def _quantize_func(
             x: Input tensor to quantize
             is_colwise: Whether to use column-wise quantization
             dq_dtype: Data type for dequantized values
-            flatten_axis: The quantization axis for the tensor
+
         Returns:
             A ScaledTensor1x containing the quantized data
         """
         dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
 
-        compute_dtype = self.scale.dtype
+        compute_dtype = self.scaling_mode.get_scale_dtype()
         dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
-        scaled_x = x.astype(compute_dtype) * self.scale
+        amax = jnp.max(jnp.abs(x)).reshape((1,))
+        fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
+        scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
+        scaled_x = x.astype(compute_dtype) * scale
 
         # quantize() in the old dot.py do this way, leave this code block here for future debugging
         # compute_dtype = x.dtype
@@ -233,8 +229,7 @@ def _quantize_func(
         # scaled_x = x * self.scale.astype(compute_dtype)
 
         clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
-        scale_inv = 1.0 / self.scale
-        self.update(jnp.max(jnp.abs(x)).reshape((1,)))
+        scale_inv = 1.0 / scale
         return ScaledTensorFactory.create_1x(
             data=clipped_scaled_x,
             scale_inv=scale_inv,
@@ -294,6 +289,75 @@ def quantize(
             return colwise_tensor
         return rowwise_tensor
 
+
+@register_pytree_node_class
+@dataclass
+class DelayedScaleQuantizer(CurrentScaleQuantizer):
+    """Quantizer implementation using delayed scaling.
+ + This quantizer uses delayed scaling mode with float32 scales and maintains + a history of maximum absolute values for dynamic scaling. + + Attributes: + scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING + q_layout: Quantization axis (default: ROWWISE_COLWISE) + scale: Current scaling factor + amax_history: History of maximum absolute values + """ + + scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING + q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE + + scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) + amax_history: jnp.ndarray = field( + default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32) + ) + + def tree_flatten(self): + """Flatten the quantizer for JAX tree operations. + + Returns: + Tuple of (children, aux_data) for tree operations + """ + children = (self.scale, self.amax_history) + aux_data = (self.q_dtype, self.scaling_mode, self.q_layout) + return (children, aux_data) + + def _quantize_func( + self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1 + ) -> ScaledTensor1x: + """Quantize function helper for delayed scaling FP8. + + Args: + x: Input tensor to quantize + is_colwise: Whether to use column-wise quantization + dq_dtype: Data type for dequantized values + flatten_axis: The quantization axis for the tensor + Returns: + A ScaledTensor1x containing the quantized data + """ + dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + + compute_dtype = self.scale.dtype + dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) + scaled_x = x.astype(compute_dtype) * self.scale + + # quantize() in the old dot.py do this way, leave this code block here for future debugging + # compute_dtype = x.dtype + # dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) + # scaled_x = x * self.scale.astype(compute_dtype) + + clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) + scale_inv = 1.0 / self.scale + self.update(jnp.max(jnp.abs(x)).reshape((1,))) + return ScaledTensorFactory.create_1x( + data=clipped_scaled_x, + scale_inv=scale_inv, + scaling_mode=self.scaling_mode, + dq_dtype=dq_dtype, + flatten_axis=flatten_axis, + ) + @staticmethod @jax.jit def _update_amax_history(amax_history, new_amax): @@ -531,6 +595,7 @@ class QuantizerFactory: quantizer_type_map = { ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, + ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer, } diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 34f63a994c..edbf38e99d 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -60,10 +60,10 @@ def get_scale_shape( """ -class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl): - """Implementation for delayed scaling mode. +class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): + """Implementation for current scaling mode. - This implementation provides metadata for delayed scaling mode, including scale data type and shape. + This implementation provides metadata for current scaling mode, including scale data type and shape. """ def get_scale_dtype(self) -> jnp.dtype: @@ -96,6 +96,13 @@ def get_scale_shape( return (1,) +class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl): + """Implementation for delayed scaling mode. 
+ + This implementation provides metadata for delayed scaling mode, including scale data type and shape. + """ + + class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): """Implementation for block scaling mode. @@ -226,12 +233,14 @@ class ScalingMode(Enum): This class defines the available scaling modes for tensor quantization: - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales + - CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales - NO_SCALING: No scaling applied """ NO_SCALING = JAXX_Scaling_Mode.NO_SCALING DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING + CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. @@ -329,5 +338,6 @@ def tree_unflatten(cls, aux_data, _children): ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), # WAR + ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(), ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(), }
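Reviewer note: to make the new per-tensor scale computation easy to sanity-check against CurrentScaleQuantizer._quantize_func (Python) and the nvte_compute_amax / nvte_compute_scale_from_amax path (C++) above, here is a self-contained reference sketch of the same math. The function names and the default margin are illustrative only; they are not APIs added by this patch.

    import jax.numpy as jnp

    def current_scale_quantize(x, q_dtype=jnp.float8_e4m3fn, margin=0.0):
        """Quantize x with a scale computed from this tensor's amax (no amax history)."""
        fp8_max = jnp.asarray(jnp.finfo(q_dtype).max, jnp.float32)
        amax = jnp.max(jnp.abs(x)).astype(jnp.float32)
        scale = (fp8_max / amax) / (2**margin)   # same formula as CurrentScaleQuantizer
        data = jnp.clip(x.astype(jnp.float32) * scale, -fp8_max, fp8_max).astype(q_dtype)
        scale_inv = 1.0 / scale                  # carried alongside the FP8 data
        return data, scale_inv

    def dequantize(data, scale_inv, dq_dtype=jnp.bfloat16):
        """Recover an approximation of x: upcast and multiply by the stored scale_inv."""
        return (data.astype(jnp.float32) * scale_inv).astype(dq_dtype)

Unlike DelayedScaleQuantizer, there is no scale or amax_history state to carry between steps, which is why tree_flatten for CurrentScaleQuantizer returns empty children.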