diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py
index 8917e92465..8533824169 100644
--- a/tests/jax/test_custom_call_compute.py
+++ b/tests/jax/test_custom_call_compute.py
@@ -4,6 +4,7 @@
 
 import jax
 import jax.numpy as jnp
+import numpy as np
 import pytest
 from jax import jit, value_and_grad
 from functools import reduce
@@ -54,6 +55,7 @@
 """ Find supported scaling modes"""
 if is_fp8_supported:
     supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
+    supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
 if is_mxfp8_supported:
     supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)
 
@@ -71,8 +73,9 @@ def is_shape_supported_by_mxfp8(input_shape):
 
 def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
     if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
+        assert_allclose(a.scale_inv, b.scale_inv)
         assert_allclose(a.data, b.data)
-        assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
+
     elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
         assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
         assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)
@@ -159,7 +162,12 @@ def test_act_grad(self, shape, activation_type):
     @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
-    def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_act_grad_with_tensor_scaling_fp8(
+        self, random_inputs, activation_type, output_type, scaling_mode
+    ):
         x = random_inputs
         x = jnp.expand_dims(x, axis=-2)
         x = jnp.repeat(x, len(activation_type), axis=-2)
@@ -170,7 +178,7 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type,
         )
 
         quantizer = QuantizerFactory.create(
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+            scaling_mode=scaling_mode,
             q_dtype=output_type,
             q_layout=QuantizeLayout.ROWWISE,
         )
@@ -188,8 +196,11 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type,
     @pytest_parametrize_wrapper(
         "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
     )
-    def test_act_forward_with_delayed_scaling_fp8(
-        self, random_inputs, activation_type, output_type, q_layout
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_act_forward_with_tensor_scaling_fp8(
+        self, random_inputs, activation_type, output_type, q_layout, scaling_mode
     ):
         x = random_inputs
         x = jnp.expand_dims(x, axis=-2)
@@ -198,7 +209,7 @@ def test_act_forward_with_delayed_scaling_fp8(
 
         te_quantizer, jax_quantizer = QuantizerFactory.create(
             n_quantizers=2,
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+            scaling_mode=scaling_mode,
             q_dtype=output_type,
             q_layout=q_layout,
         )
@@ -335,8 +346,20 @@ def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp
     @pytest_parametrize_wrapper(
         "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
     )
-    def test_norm_grad_with_delayed_scaling_fp8(
-        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_norm_grad_with_tensor_scaling_fp8(
+        self,
+        n,
+        hidden,
+        norm_type,
+        zero_centered_gamma,
+        epsilon,
+        inp_dtype,
+        out_dtype,
+        q_layout,
+        scaling_mode,
     ):
         """
         Test transformer_engine.jax.layernorm.layernorm
@@ -345,9 +368,7 @@ def test_norm_grad_with_delayed_scaling_fp8(
             pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
 
         quantizer = QuantizerFactory.create(
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
-            q_dtype=out_dtype,
-            q_layout=q_layout,
+            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
         )
         self._test_norm_grad(
             n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
@@ -406,8 +427,20 @@ def _test_norm_forward(
     @pytest_parametrize_wrapper(
         "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
     )
-    def test_norm_forward_with_delayed_scaling_fp8(
-        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_norm_forward_with_tensor_scaling_fp8(
+        self,
+        n,
+        hidden,
+        norm_type,
+        zero_centered_gamma,
+        epsilon,
+        inp_dtype,
+        out_dtype,
+        q_layout,
+        scaling_mode,
     ):
         if norm_type == "rmsnorm" and zero_centered_gamma is True:
             pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
@@ -420,7 +453,7 @@ def test_norm_forward_with_delayed_scaling_fp8(
             epsilon=epsilon,
             inp_dtype=inp_dtype,
             out_dtype=out_dtype,
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+            scaling_mode=scaling_mode,
             q_layout=q_layout,
         )
 
@@ -448,8 +481,8 @@ def test_norm_forward_with_block_scaling_fp8(
 }
 
 ALL_QUANTIZE_TEST_SHAPES = [
-    (32, 64),
-    (2, 64, 32),
+    (32, 256, 128),
+    (64, 32, 32, 256),
 ]
 
 QUANTIZE_TEST_SHAPES = {
@@ -630,16 +663,19 @@ def test_quantize_dact_dbias_no_quantization(
     @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
     @pytest_parametrize_wrapper("is_dbias", [True, False])
     @pytest_parametrize_wrapper(
-        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
+        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
     )
-    def test_quantize_dact_dbias_delayed_scaling(
-        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_quantize_dact_dbias_tensor_scaling(
+        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
     ):
         self._test_quantize_dact_dbias(
             in_dtype=in_dtype,
             input_shape=input_shape,
             out_dtype=out_dtype,
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+            scaling_mode=scaling_mode,
             activation_type=activation_type,
             is_dbias=is_dbias,
             q_layout=q_layout,
@@ -830,7 +866,10 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
         Test layernorm_dense VJP Rule
         """
         # No Norm FWD E5M2 in TE backend
-        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
+            ScalingMode.DELAYED_TENSOR_SCALING,
+            ScalingMode.CURRENT_TENSOR_SCALING,
+        ):
             pytest.skip("E5M2 is not supported in normalization with TE Backend!")
 
         # zero_centered_gamma is already tested in TestNorm
@@ -916,7 +955,10 @@ def test_layernorm_mlp_grad(
         Test layernorm_mlp VJP Rule
         """
         # No Norm FWD E5M2 in TE backend
-        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
+            ScalingMode.DELAYED_TENSOR_SCALING,
+            ScalingMode.CURRENT_TENSOR_SCALING,
+        ):
             pytest.skip("E5M2 is not supported in normalization with TE Backend!")
 
         # zero_centered_gamma is already tested in TestNorm
diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py
index 175de417bc..b608e57a35 100644
--- a/tests/jax/test_helper.py
+++ b/tests/jax/test_helper.py
@@ -10,47 +10,22 @@
 import numpy as np
 
 from utils import assert_allclose
-from transformer_engine.common.recipe import DelayedScaling
+from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
 from transformer_engine.common.recipe import Format as FP8Format
 from transformer_engine.jax import fp8_autocast, get_delayed_scaling
-from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo
+from transformer_engine.jax.quantize import (
+    QuantizeConfig,
+    is_fp8_available,
+    ScalingMode,
+    update_collections,
+)
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
 
 is_fp8_supported, reason = is_fp8_available()
+is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
 
 
-class TestQuantizeConfig(unittest.TestCase):
-
-    @unittest.skipIf(not is_fp8_supported, reason=reason)
-    def test_initialize(self):
-        margin = 5.0
-        fp8_format = FP8Format.E4M3
-        amax_history_len = 10
-
-        QuantizeConfig.initialize(
-            margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
-        )
-
-        self.assertEqual(
-            QuantizeConfig.MARGIN,
-            margin,
-            f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
-            f" but got {QuantizeConfig.MARGIN}.",
-        )
-        self.assertEqual(
-            QuantizeConfig.FP8_FORMAT,
-            fp8_format,
-            f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
-            f" but got {QuantizeConfig.FP8_FORMAT}.",
-        )
-        self.assertEqual(
-            QuantizeConfig.AMAX_HISTORY_LEN,
-            amax_history_len,
-            f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
-            f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
-        )
-
-        QuantizeConfig.finalize()
+class TestHelper(unittest.TestCase):
 
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_update_collections(self):
@@ -61,12 +36,12 @@ def test_update_collections(self):
             "test1": original_val,
             "test2": original_val,
         }
-        updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
+        updated_state = update_collections({"test1": updated_val}, original_state)
         self.assertEqual(updated_state["test1"], updated_val)
         self.assertEqual(updated_state["test2"], original_val)
 
         original_state = flax.core.frozen_dict.FrozenDict(original_state)
-        updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
+        updated_state = update_collections({"test1": updated_val}, original_state)
         self.assertEqual(updated_state["test1"], updated_val)
         self.assertEqual(updated_state["test2"], original_val)
 
@@ -82,8 +57,18 @@ def _compare_delay_scaling(self, ref, test):
         self.assertTrue(ref.amax_history_len == test.amax_history_len)
         self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
 
+    def _compare_current_scaling(self, test):
+        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
+        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
+        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
+
+    def _compare_mxfp8_scaling(self, test):
+        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
+        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
+        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
+
     @unittest.skipIf(not is_fp8_supported, reason=reason)
-    def test_fp8_autocast(self):
+    def test_fp8_autocast_delayed_scaling(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_defult_state()
 
@@ -107,6 +92,56 @@ def test_fp8_autocast(self):
 
         self._check_defult_state()
 
+    @unittest.skipIf(not is_fp8_supported, reason=reason)
+    def test_fp8_autocast_current_scaling(self):
+        QuantizeConfig.finalize()  # Ensure the test is not affected by previous tests.
+        self._check_defult_state()
+
+        with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
+            self.assertFalse(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(Float8CurrentScaling())
+
+        self._check_defult_state()
+
+        cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
+        with fp8_autocast(enabled=True, fp8_recipe=cs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(cs)
+
+        self._check_defult_state()
+
+        cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
+        with fp8_autocast(enabled=True, fp8_recipe=cs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(cs)
+
+        self._check_defult_state()
+
+    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
+    def test_fp8_autocast_mxfp8_scaling(self):
+        QuantizeConfig.finalize()  # Ensure the test is not affected by previous tests.
+        self._check_defult_state()
+
+        with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
+            self.assertFalse(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(MXFP8BlockScaling())
+
+        self._check_defult_state()
+
+        bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
+        with fp8_autocast(enabled=True, fp8_recipe=bs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(bs)
+
+        self._check_defult_state()
+
+        bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
+        with fp8_autocast(enabled=True, fp8_recipe=bs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(bs)
+
+        self._check_defult_state()
+
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_with_sharding_resource(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp
index f6b6ae22c2..7723a59035 100644
--- a/transformer_engine/common/normalization/layernorm/ln_api.cpp
+++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp
@@ -31,19 +31,23 @@ void layernorm_fwd(const Tensor& x,      // BxSxhidden_size
     NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
   }
 
-  NVTE_CHECK(x.data.shape.size() == 2);
-  NVTE_CHECK(gamma.data.shape == beta.data.shape);
-  NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]);
+  NVTE_CHECK(x.data.shape.size() == 2, "x must be a 2D tensor.");
+  NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape.");
+  NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0], "Gamma must match the hidden size of x.");
+
+  NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative.");
 
-  NVTE_CHECK(epsilon >= 0.f);
+  NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x.");
 
-  NVTE_CHECK(z->data.shape == x.data.shape);
+  NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]},
+             "Mu must be 1D tensor with shape (x.shape[0],).");
+  NVTE_CHECK(mu->data.dtype == DType::kFloat32, "Mu must be a float32 tensor.");
 
-  NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]});
-  NVTE_CHECK(mu->data.dtype == DType::kFloat32);
+  NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]},
+             "RSigma must be 1D tensor with shape (x.shape[0],).");
+  NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
 
-  NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]});
-  NVTE_CHECK(rsigma->data.dtype == DType::kFloat32);
+  NVTE_CHECK(gamma.data.dtype == beta.data.dtype, "Gamma and Beta must have the same dtype.");
 
   if (!workspace->data.shape.empty()) {
     CheckInputTensor(x, "x");
diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py
index c27f6f50f7..8f01a08798 100644
--- a/transformer_engine/jax/cpp_extensions/activation.py
+++ b/transformer_engine/jax/cpp_extensions/activation.py
@@ -119,6 +119,11 @@ def abstract(
             f" {x_aval.shape} and act_len {act_len}"
         )
 
+        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
+            "Current tensor scaling is not supported for fused activation and quantization. Please"
+            " do activation in higher-precision then quantize with current tensor scaling."
+        )
+
         out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
         out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
 
@@ -451,6 +456,12 @@ def abstract(
             f" {x_aval.shape} and act_len {act_len}"
         )
         assert scale_aval.dtype == jnp.float32
+
+        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
+            "Current tensor scaling is not supported for fused dact and quantization. Please do"
+            " dact in higher-precision then quantize with current tensor scaling."
+        )
+
         ir_hidden_size = dz_aval.shape[-1]
         gi_hidden_size = act_len * x_aval.shape[-1]
         assert act_len * ir_hidden_size == gi_hidden_size
@@ -463,7 +474,10 @@ def abstract(
             scaling_mode
         ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
         if is_2x:
-            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
+            if scaling_mode in (
+                ScalingMode.DELAYED_TENSOR_SCALING.value,
+                ScalingMode.CURRENT_TENSOR_SCALING.value,
+            ):
                 colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
             else:
                 colwise_out_shape = out_shape
@@ -669,6 +683,10 @@ def infer_sharding_from_operands(
         x_spec = get_padded_spec(arg_infos[1])
         scale_spec = get_padded_spec(arg_infos[2])
 
+        assert (
+            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
+        ), "Partitioned current tensor scaling is not yet supported."
+
         out_sharding = NamedSharding(
             mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
         )
@@ -937,6 +955,16 @@ def act_lu(
         out = out.reshape(output_shape)
         return out
 
+    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
+        # Current scaling does not support fused operations.
+        # Perform the activation in higher precision, then quantize.
+        out = act_lu(
+            x=x.astype(jnp.float32),
+            activation_type=activation_type,
+            quantizer=None,
+        )
+        out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
+        return out
+
     if isinstance(quantizer, DelayedScaleQuantizer):
         scale = quantizer.scale
 
@@ -1012,8 +1040,12 @@ def quantize_dact_dbias(
 
     # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
     if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
-        out = dact_lu(dz, x, activation_type, quantizer=None)
-        return _quantize_dbias_impl(out, quantizer, is_dbias=True, flatten_axis=-2)
+        out = dact_lu(
+            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
+        )
+        return _quantize_dbias_impl(
+            out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
+        )
 
     is_gated = act_len == 2
     # TE/common does not support DelayedScaling2x for gated-act yet
@@ -1056,6 +1088,20 @@ def quantize_dact_dbias(
             dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
         return output.astype(x.dtype), dbias
 
+    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
+        # Current scaling does not support fused operations.
+        # Perform dact in higher precision, then quantize.
+        out, _ = quantize_dact_dbias(
+            dz=dz.astype(jnp.float32),
+            x=x.astype(jnp.float32),
+            activation_type=activation_type,
+            is_dbias=False,
+            quantizer=None,
+        )
+        out, dbias = _quantize_dbias_impl(
+            out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
+        )
+        return out, dbias
+
     if isinstance(quantizer, DelayedScaleQuantizer):
         scale = quantizer.scale
 
@@ -1095,7 +1141,11 @@ def quantize_dact_dbias(
     )
 
     # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
-    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
+    if (
+        quantizer.scaling_mode
+        in (ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING)
+        and quantizer.is_2x2x()
+    ):
         colwise_scale_inv = rowwise_scale_inv
 
     quantizer.update(updated_amax)
diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py
index 0327542c2f..175f1e61ce 100644
--- a/transformer_engine/jax/cpp_extensions/gemm.py
+++ b/transformer_engine/jax/cpp_extensions/gemm.py
@@ -175,7 +175,7 @@ def _dequantize(x, scale_inv, dq_dtype):
         4,
     ),
 )
-def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
+def __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
     # Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching
     lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype)
     rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype)
@@ -193,12 +193,13 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
     return out_3d
 
 
-def _jax_gemm_delayed_scaling_fp8(
+def _jax_gemm_tensor_scaling_fp8(
     lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
 ):
     """FP8 GEMM for XLA pattern match"""
-    assert (
-        rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
+    assert rhs.scaling_mode in (
+        ScalingMode.DELAYED_TENSOR_SCALING,
+        ScalingMode.CURRENT_TENSOR_SCALING,
     ), "rhs does not have delayed tensor scaling mode"
 
     (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
@@ -216,7 +217,7 @@ def _jax_gemm_delayed_scaling_fp8(
     precision = (
         jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT
     )
-    out_3d = __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
+    out_3d = __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
 
     # Reshape [B, M, N] -> [..., M, N]
     out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
@@ -291,8 +292,11 @@ def _jax_gemm(
 
     def _jax_gemm_fp8_impl(lhs, rhs):
 
-        if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
-            return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums)
+        if lhs.scaling_mode in (
+            ScalingMode.DELAYED_TENSOR_SCALING,
+            ScalingMode.CURRENT_TENSOR_SCALING,
+        ):
+            return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums)
 
         if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
             return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
@@ -403,7 +407,10 @@ def grouped_gemm(
             rhs_shape = rhs.data.shape
             out_dtype = lhs.dq_dtype
             # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout
-            if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+            if lhs.scaling_mode in (
+                ScalingMode.DELAYED_TENSOR_SCALING,
+                ScalingMode.CURRENT_TENSOR_SCALING,
+            ):
                 assert not (
                     lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
                 ), "FP8 GEMM does not support E5M2 * E5M2"
@@ -430,7 +437,10 @@ def grouped_gemm(
         if scaling_mode == ScalingMode.NO_SCALING:
             lhs_3d = _shape_normalization(lhs, lhs_dn)
             rhs_3d = _shape_normalization(rhs, rhs_dn)
-        elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+        elif scaling_mode in (
+            ScalingMode.DELAYED_TENSOR_SCALING,
+            ScalingMode.CURRENT_TENSOR_SCALING,
+        ):
             lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
             rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
         elif scaling_mode == ScalingMode.MXFP8_1D_SCALING:
@@ -473,7 +483,7 @@ def grouped_gemm(
         if scaling_mode == ScalingMode.NO_SCALING:
             lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
             rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
-        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+        if scaling_mode in (ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING):
             lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1))
             rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1))
         if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py
index d64104ac27..a74f0f9088 100644
--- a/transformer_engine/jax/cpp_extensions/misc.py
+++ b/transformer_engine/jax/cpp_extensions/misc.py
@@ -216,7 +216,8 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1,
     """
     should_apply_war = (
         quantizer is not None
-        and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
+        and quantizer.scaling_mode
+        in (ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING)
         and quantizer.is_2x2x()
     )
     if not should_apply_war:
diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py
index 388d4f17ee..c079edbee1 100644
--- a/transformer_engine/jax/cpp_extensions/normalization.py
+++ b/transformer_engine/jax/cpp_extensions/normalization.py
@@ -26,6 +26,7 @@
     te_dtype_to_jax_dtype,
     NamedSharding,
 )
+from .quantization import _quantize_dbias_impl
 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
 from ..quantize import ScaledTensor, ScaledTensorFactory
 from ..quantize import (
@@ -100,6 +101,11 @@ def abstract(
         assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
         assert scale_aval is None or scale_aval.dtype == jnp.float32
 
+        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
+            "Current tensor scaling is not supported for fused norm and quantization. Please do"
+            " norm in higher-precision then quantize with current tensor scaling."
+        )
+
         mu_rsigama_dtype = jnp.float32
 
         if norm_type == NVTE_Norm_Type.LayerNorm:
@@ -731,6 +737,8 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)
     JAX native layernorm implementation
     """
     x_ = jnp.asarray(x, jnp.float32)
+    gamma = gamma.astype(jnp.float32)
+    beta = beta.astype(jnp.float32)
     mean = jnp.mean(x_, axis=-1, keepdims=True)
     var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
     rsigma = jax.lax.rsqrt(var + epsilon)
@@ -752,6 +760,7 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
     JAX native rmsnorm implementation
     """
     x_ = jnp.asarray(x, jnp.float32)
+    gamma = gamma.astype(jnp.float32)
     var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
     rsigma = jax.lax.rsqrt(var + epsilon)
     normed_input = x_ * rsigma
@@ -827,9 +836,26 @@ def layernorm_fwd(
         )
         return output, mu, rsigma
 
+    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
+        # Current scaling does not support fused operations.
+        # Perform the norm in higher precision, then quantize.
+        out, mu, rsigma = layernorm_fwd(
+            x=x.astype(jnp.float32),
+            gamma=gamma.astype(jnp.float32),
+            beta=beta.astype(jnp.float32),
+            zero_centered_gamma=zero_centered_gamma,
+            epsilon=epsilon,
+            quantizer=None,
+        )
+        out = quantizer.quantize(out, dq_dtype=x.dtype)
+        return out, mu, rsigma
+
     is_2x2x = quantizer.is_2x2x()
     # TE/common normalization doesn't support 2x delayed scaling
-    if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+    if quantizer.is_2x2x() and quantizer.scaling_mode in (
+        ScalingMode.DELAYED_TENSOR_SCALING,
+        ScalingMode.CURRENT_TENSOR_SCALING,
+    ):
         is_2x2x = False
     (
         rowwise_casted_output,
@@ -857,7 +883,10 @@ def layernorm_fwd(
     quantizer.update(updated_amax)
 
     # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
-    if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+    if quantizer.is_2x2x() and quantizer.scaling_mode in (
+        ScalingMode.DELAYED_TENSOR_SCALING,
+        ScalingMode.CURRENT_TENSOR_SCALING,
+    ):
         colwise_casted_output = jnp.transpose(
             rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
         )
@@ -1009,9 +1038,24 @@ def rmsnorm_fwd(
         )
         return output, rsigma
 
+    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
+        # Current scaling does not support fused operations.
+        # Perform the norm in higher precision, then quantize.
+        out, rsigma = rmsnorm_fwd(
+            x=x,
+            gamma=gamma,
+            zero_centered_gamma=zero_centered_gamma,
+            epsilon=epsilon,
+            quantizer=None,
+        )
+        out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
+        return out, rsigma
+
     is_2x2x = quantizer.is_2x2x()
     # TE/common normalization doesn't support 2x delayed scaling
-    if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+    if quantizer.is_2x2x() and quantizer.scaling_mode in (
+        ScalingMode.DELAYED_TENSOR_SCALING,
+        ScalingMode.CURRENT_TENSOR_SCALING,
+    ):
         is_2x2x = False
     (
         rowwise_casted_output,
@@ -1039,7 +1083,10 @@ def rmsnorm_fwd(
     quantizer.update(updated_amax)
 
     # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
-    if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+    if quantizer.is_2x2x() and quantizer.scaling_mode in (
+        ScalingMode.DELAYED_TENSOR_SCALING,
+        ScalingMode.CURRENT_TENSOR_SCALING,
+    ):
         colwise_casted_output = jnp.transpose(
             rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
         )
diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py
index 2911b5a420..93333d358f 100644
--- a/transformer_engine/jax/cpp_extensions/quantization.py
+++ b/transformer_engine/jax/cpp_extensions/quantization.py
@@ -93,7 +93,10 @@ def abstract(
         ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
 
         if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
-            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
+            if scaling_mode in (
+                ScalingMode.DELAYED_TENSOR_SCALING.value,
+                ScalingMode.CURRENT_TENSOR_SCALING.value,
+            ):
                 colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
             else:
                 colwise_out_shape = out_shape
@@ -298,6 +301,11 @@ def infer_sharding_from_operands(
         result_infos,
     ):
         del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer)  # Unused.
+
+        assert (
+            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
+        ), "Current tensor scaling is not yet supported for multi-GPU partitioning."
+
         x_spec = get_padded_spec(arg_infos[0])
         scale_spec = get_padded_spec(arg_infos[1])
         out_sharding = NamedSharding(
@@ -370,6 +378,11 @@ def partition(
         result_infos,
     ):
         del result_infos, is_outer
+
+        assert (
+            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
+        ), "Current tensor scaling is not yet supported for multi-GPU partitioning."
+
         x_spec = get_padded_spec(arg_infos[0])
         scale_spec = get_padded_spec(arg_infos[1])
         out_sharding = NamedSharding(
@@ -592,7 +605,11 @@ def _quantize_dbias_impl(
         is_outer=True,
     )
     # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
-    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
+    if (
+        quantizer.scaling_mode
+        in (ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING)
+        and quantizer.is_2x2x()
+    ):
         colwise_scale_inv = rowwise_scale_inv
 
     quantizer.update(updated_amax)
diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp
index fc7f231f34..9487c2bd86 100644
--- a/transformer_engine/jax/csrc/extensions/activation.cpp
+++ b/transformer_engine/jax/csrc/extensions/activation.cpp
@@ -44,6 +44,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
   auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
   output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
 
+  NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
+             "Current tensor scaling does not support fused operations. Please call this primitive "
+             "in higher-precision then quantize with current scaling.");
+
   if (is_fp8_dtype(out_dtype)) {
     if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
       NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
@@ -152,6 +156,10 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
   auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
   auto dbias_shape = std::vector<size_t>{hidden_size};
 
+  NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
+             "Current tensor scaling does not support fused operations. Please call this primitive "
+             "in higher-precision then quantize with current scaling.");
+
   // Evil hack to specify TE impl
   // Note: nvte_quantize_dbias_dgelu chooses its internal impl based
   // on what pointers are allocated, e.g. whether to output with
@@ -219,6 +227,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
   auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
   auto flatten_axis = output_buf->dimensions().size() - 2;  // output has act axis
 
+  NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
+             "Current tensor scaling does not support fused operations. Please call this primitive "
+             "in higher-precision then quantize with current scaling.");
+
   auto *output = output_buf->untyped_data();
   auto *colwise_output = colwise_output_buf->untyped_data();
   auto *dbias = dbias_buf->untyped_data();
diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp
index d4b9bf720e..1653adb8ec 100644
--- a/transformer_engine/jax/csrc/extensions/gemm.cpp
+++ b/transformer_engine/jax/csrc/extensions/gemm.cpp
@@ -95,7 +95,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
     lhs_i.set_rowwise_data(static_cast<void *>(lhs_ptr), lhs_dtype, lhs_shape);
     rhs_i.set_rowwise_data(static_cast<void *>(rhs_ptr), rhs_dtype, rhs_shape);
 
-    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
+    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
+        scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
       lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat32,
                                   std::vector<size_t>{1});
       rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat32,
@@ -190,8 +190,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten,
 
   return GroupedGemmImpl(lhs_ptr, lhs_dtype, lhs_sinv_ptr, lhs_sinv_dtype, rhs_ptr, rhs_dtype,
                          rhs_sinv_ptr, rhs_sinv_dtype, bias_ptr, bias_dtype, out_ptr, out_dtype,
-                         workspace_ptr, workspace_size, num_gemms, dim_list_ptr, scaling_mode,
-                         stream);
+                         workspace_ptr, workspace_size, num_gemms, dim_list_ptr,
+                         scaling_mode, stream);
 }
 
 XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h
index f7577c24f3..4c3d29ef0a 100644
--- a/transformer_engine/jax/csrc/extensions/misc.h
+++ b/transformer_engine/jax/csrc/extensions/misc.h
@@ -44,6 +44,7 @@ enum class JAXX_Scaling_Mode : int64_t {
   NO_SCALING = 0,
   DELAYED_TENSOR_SCALING = 1,
   MXFP8_1D_SCALING = 2,
+  CURRENT_TENSOR_SCALING = 3,
 };
 
 static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
@@ -57,6 +58,9 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
     case JAXX_Scaling_Mode::MXFP8_1D_SCALING:
       return NVTEScalingMode::NVTE_MXFP8_1D_SCALING;
       break;
+    case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
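+      // Current tensor scaling reuses the delayed-scaling tensor format in TE/common;
+      // only the JAX-side amax/scale computation differs.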
+      return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
+      break;
     default:
       NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
       break;
diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp
index e23e42f528..f3aa237996 100644
--- a/transformer_engine/jax/csrc/extensions/normalization.cpp
+++ b/transformer_engine/jax/csrc/extensions/normalization.cpp
@@ -98,7 +98,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
   auto workspace_shape = std::vector<size_t>{workspace_size};
 
   auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
-  auto gamma_tensor = TensorWrapper(gamma, gamma_shape, in_dtype);
+  auto gamma_tensor = TensorWrapper(gamma, gamma_shape, w_dtype);
 
   auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);
   auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin;
@@ -107,6 +107,10 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
   auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
   output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
 
+  NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
+             "Current tensor scaling does not support fused operations. Please call this primitive "
+             "in higher-precision then quantize with current scaling.");
+
   if (is_fp8_dtype(out_dtype)) {
     output_tensor.set_rowwise_scale_inv(
         scale_inv_buf->untyped_data(),
@@ -134,6 +138,8 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
   }
 
   if (_norm_type == NVTE_Norm_Type::LayerNorm) {
+    NVTE_CHECK(w_dtype == convert_ffi_datatype_to_te_dtype(beta_buf.element_type()),
+               "gamma and beta must have the same data type.");
     auto beta_tensor = TensorWrapper(beta, gamma_shape, w_dtype);
     auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
 
diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp
index 5c165cccb6..3866aca62f 100644
--- a/transformer_engine/jax/csrc/extensions/pybind.cpp
+++ b/transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -142,6 +142,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
       .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING)
       .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
       .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
+      .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING)
       .export_values();
 
   pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp
index 481dbd7cdf..d544972c97 100644
--- a/transformer_engine/jax/csrc/extensions/quantization.cpp
+++ b/transformer_engine/jax/csrc/extensions/quantization.cpp
@@ -7,6 +7,7 @@
 
 #include "extensions.h"
 #include "transformer_engine/cast.h"
+#include "transformer_engine/recipe.h"
 #include "xla/ffi/api/c_api.h"
 
 namespace transformer_engine {
@@ -107,12 +108,15 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
   auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
   auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
 
+  bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
+                                 scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
+
   if (quantize_layout == QuantizeLayout::ROWWISE ||
       quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
     output_tensor.set_rowwise_data(output, out_dtype, output_shape);
 
     if (is_fp8_dtype(out_dtype)) {
-      if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
+      if (is_tensor_scaling) {
         float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
         float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
         NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
@@ -142,11 +146,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
                           : output_shape;
     output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
     // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
-    auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
-                        ? scale_inv_buf
-                        : colwise_scale_inv_buf;
+    auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf;
 
-    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
+    if (is_tensor_scaling) {
       output_tensor.set_columnwise_scale_inv(
           tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
           std::vector<size_t>{1});
@@ -159,6 +161,21 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
     }
   }
 
+  if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
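+    // For current scaling, compute the amax of the input on-device and derive the scale
+    // from it before quantizing, rather than using a scale carried over from prior steps.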
+    nvte_compute_amax(input_tensor.data(),   // input data
+                      output_tensor.data(),  // output data (for amax)
+                      stream);
+
+    QuantizationConfigWrapper quant_config;
+    // Defaults for now; TODO(Jeremy): move these to parameters.
+    bool force_pow_2_scales = false;
+    float amax_epsilon = 0.0f;
+    quant_config.set_force_pow_2_scales(force_pow_2_scales);
+    quant_config.set_amax_epsilon(amax_epsilon);
+    nvte_compute_scale_from_amax(output_tensor.data(), quant_config, stream);
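+    // The scale has been derived from the freshly computed amax, so drop the amax pointer.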
+    output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
+  }
+
   auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
   auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
 
diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py
index d68eb3c6c2..d43c61c9fc 100644
--- a/transformer_engine/jax/quantize/dequantizer.py
+++ b/transformer_engine/jax/quantize/dequantizer.py
@@ -85,6 +85,7 @@ def _dq_func_block_scaling(scaled_tensor):
 
     funcs = {
         ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling,
+        ScalingMode.CURRENT_TENSOR_SCALING: _dq_func_tensor_scaling,
         ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling,
     }
 
diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py
index 98f280b9a9..3a2d6a57d4 100644
--- a/transformer_engine/jax/quantize/helper.py
+++ b/transformer_engine/jax/quantize/helper.py
@@ -94,7 +94,7 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
         A tuple of (bool, str) indicating support and any error message
     """
     gpu_arch = get_device_compute_capability(gpu_id)
-    if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+    if scaling_mode in (ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING):
         return _check_delayed_scaling_fp8_support(gpu_arch)
     if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
         return _check_block_scaling_fp8_support(gpu_arch)
@@ -182,6 +182,8 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
         return ScalingMode.DELAYED_TENSOR_SCALING
     if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
         return ScalingMode.MXFP8_1D_SCALING
+    if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
+        return ScalingMode.CURRENT_TENSOR_SCALING
     raise ValueError("Invalid fp8_recipe!")
 
 
@@ -309,6 +311,30 @@ def finalize() -> None:
         QuantizeConfig.finalize()
 
 
+class CurrentScalingQuantizeConfig:
+    """Configuration class for current scaling FP8 recipe.
+
+    This class provides specific initialization and finalization for current scaling
+    FP8 quantization mode.
+    """
+
+    @staticmethod
+    def initialize(fp8_recipe: recipe.Recipe) -> None:
+        """Initialize current scaling FP8 configuration.
+
+        Args:
+            fp8_recipe: The FP8 recipe to use for initialization
+        """
+        cls = QuantizeConfig
+        cls.initialize(fp8_recipe)
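+        # Current scaling computes the scale from each tensor's amax, so no amax history is kept.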
+        cls.AMAX_HISTORY_LEN = 0
+
+    @staticmethod
+    def finalize() -> None:
+        """Reset the current scaling configuration."""
+        QuantizeConfig.finalize()
+
+
 class BlockScalingQuantizeConfig:
     """Configuration class for block scaling FP8 recipe.
 
@@ -385,6 +411,8 @@ def fp8_autocast(
     Config = DelayedScalingQuantizeConfig
     if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
         Config = BlockScalingQuantizeConfig
+    if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
+        Config = CurrentScalingQuantizeConfig
 
     try:
         with global_shard_guard(mesh_resource):
diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py
index b57043a034..becacde65d 100644
--- a/transformer_engine/jax/quantize/quantizer.py
+++ b/transformer_engine/jax/quantize/quantizer.py
@@ -27,6 +27,7 @@
     "QuantizeLayout",
     "Quantizer",
     "QuantizerSet",
+    "CurrentScaleQuantizer",
     "DelayedScaleQuantizer",
     "BlockScaleQuantizer",
     "QuantizerFactory",
@@ -159,34 +160,26 @@ def get_scale_dtype(self):
 
 @register_pytree_node_class
 @dataclass
-class DelayedScaleQuantizer(Quantizer):
-    """Quantizer implementation using delayed scaling.
+class CurrentScaleQuantizer(Quantizer):
+    """Quantizer implementation using current scaling.
 
-    This quantizer uses delayed scaling mode with float32 scales and maintains
-    a history of maximum absolute values for dynamic scaling.
+    This quantizer uses current scaling mode with float32 scales, computing the scale
+    from the amax of each input tensor at quantization time.
 
     Attributes:
-        scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
+        scaling_mode: Set to ScalingMode.CURRENT_TENSOR_SCALING
         q_layout: Quantization axis (default: ROWWISE_COLWISE)
-        scale: Current scaling factor
-        amax_history: History of maximum absolute values
     """
 
-    scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
+    scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
     q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
 
-    scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
-    amax_history: jnp.ndarray = field(
-        default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
-    )
-
     def tree_flatten(self):
         """Flatten the quantizer for JAX tree operations.
 
         Returns:
             Tuple of (children, aux_data) for tree operations
         """
-        children = (self.scale, self.amax_history)
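+        # Current scaling keeps no per-quantizer state, so there are no pytree leaves.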
+        children = ()
         aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
         return (children, aux_data)
 
@@ -217,15 +210,18 @@ def _quantize_func(
             x: Input tensor to quantize
             is_colwise: Whether to use column-wise quantization
             dq_dtype: Data type for dequantized values
             flatten_axis: The quantization axis for the tensor
         Returns:
             A ScaledTensor1x containing the quantized data
         """
         dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
 
-        compute_dtype = self.scale.dtype
+        compute_dtype = self.scaling_mode.get_scale_dtype()
         dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
-        scaled_x = x.astype(compute_dtype) * self.scale
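+        # Current scaling: derive the scale from this tensor's amax, backed off by 2**MARGIN.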
+        amax = jnp.max(jnp.abs(x)).reshape((1,))
+        fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
+        scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
+        scaled_x = x.astype(compute_dtype) * scale
 
         # quantize() in the old dot.py do this way, leave this code block here for future debugging
         # compute_dtype = x.dtype
@@ -233,8 +229,7 @@ def _quantize_func(
         # scaled_x = x * self.scale.astype(compute_dtype)
 
         clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
-        scale_inv = 1.0 / self.scale
-        self.update(jnp.max(jnp.abs(x)).reshape((1,)))
+        scale_inv = 1.0 / scale
         return ScaledTensorFactory.create_1x(
             data=clipped_scaled_x,
             scale_inv=scale_inv,
@@ -294,6 +289,75 @@ def quantize(
             return colwise_tensor
         return rowwise_tensor
 
+
+@register_pytree_node_class
+@dataclass
+class DelayedScaleQuantizer(CurrentScaleQuantizer):
+    """Quantizer implementation using delayed scaling.
+
+    This quantizer uses delayed scaling mode with float32 scales and maintains
+    a history of maximum absolute values for dynamic scaling.
+
+    Attributes:
+        scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
+        q_layout: Quantization axis (default: ROWWISE_COLWISE)
+        scale: Current scaling factor
+        amax_history: History of maximum absolute values
+    """
+
+    scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
+    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
+
+    scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
+    amax_history: jnp.ndarray = field(
+        default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
+    )
+
+    def tree_flatten(self):
+        """Flatten the quantizer for JAX tree operations.
+
+        Returns:
+            Tuple of (children, aux_data) for tree operations
+        """
+        children = (self.scale, self.amax_history)
+        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
+        return (children, aux_data)
+
+    def _quantize_func(
+        self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
+    ) -> ScaledTensor1x:
+        """Quantize function helper for delayed scaling FP8.
+
+        Args:
+            x: Input tensor to quantize
+            is_colwise: Whether to use column-wise quantization
+            dq_dtype: Data type for dequantized values
+            flatten_axis: The quantization axis for the tensor
+        Returns:
+            A ScaledTensor1x containing the quantized data
+        """
+        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
+
+        compute_dtype = self.scale.dtype
+        dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
+        scaled_x = x.astype(compute_dtype) * self.scale
+
+        # quantize() in the old dot.py do this way, leave this code block here for future debugging
+        # compute_dtype = x.dtype
+        # dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
+        # scaled_x = x * self.scale.astype(compute_dtype)
+
+        clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
+        scale_inv = 1.0 / self.scale
+        self.update(jnp.max(jnp.abs(x)).reshape((1,)))
+        return ScaledTensorFactory.create_1x(
+            data=clipped_scaled_x,
+            scale_inv=scale_inv,
+            scaling_mode=self.scaling_mode,
+            dq_dtype=dq_dtype,
+            flatten_axis=flatten_axis,
+        )
+
     @staticmethod
     @jax.jit
     def _update_amax_history(amax_history, new_amax):
@@ -531,6 +595,7 @@ class QuantizerFactory:
 
     quantizer_type_map = {
         ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
+        ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer,
         ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer,
     }
 
diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py
index 34f63a994c..edbf38e99d 100644
--- a/transformer_engine/jax/quantize/scaling_modes.py
+++ b/transformer_engine/jax/quantize/scaling_modes.py
@@ -60,10 +60,10 @@ def get_scale_shape(
         """
 
 
-class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
-    """Implementation for delayed scaling mode.
+class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
+    """Implementation for current scaling mode.
 
-    This implementation provides metadata for delayed scaling mode, including scale data type and shape.
+    This implementation provides metadata for current scaling mode, including scale data type and shape.
     """
 
     def get_scale_dtype(self) -> jnp.dtype:
@@ -96,6 +96,13 @@ def get_scale_shape(
         return (1,)
 
 
+class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl):
+    """Implementation for delayed scaling mode.
+
+    This implementation provides metadata for delayed scaling mode, including scale data type and shape.
+    """
+
+
 class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
     """Implementation for block scaling mode.
 
@@ -226,12 +233,14 @@ class ScalingMode(Enum):
     This class defines the available scaling modes for tensor quantization:
     - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
     - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
+    - CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales
     - NO_SCALING: No scaling applied
     """
 
     NO_SCALING = JAXX_Scaling_Mode.NO_SCALING
     DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING
     MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING
+    CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING
 
     def _get_impl(self) -> ScalingModeMetadataImpl:
         """Get the implementation for this scaling mode.
@@ -329,5 +338,6 @@ def tree_unflatten(cls, aux_data, _children):
     ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
     ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
+    ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
     # WAR
     ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(),
 }