36 changes: 31 additions & 5 deletions paddle/phi/infermeta/fusion.cc
@@ -866,11 +866,12 @@ void FusedActDequantInferMeta(const MetaTensor& x,
x.dtype()));

PADDLE_ENFORCE_EQ(
x_scale.dtype(),
DataType::FLOAT32,
common::errors::InvalidArgument(
"The data type of X_scale should be FLOAT32, but received %s.",
x_scale.dtype()));
x_scale.dtype() == DataType::FLOAT32 ||
x_scale.dtype() == DataType::INT32,
true,
common::errors::InvalidArgument("The data type of X_scale should be "
"FLOAT32 or INT32, but received %s.",
x_scale.dtype()));

PADDLE_ENFORCE_EQ(x_dims.size(),
2,
@@ -893,6 +894,31 @@ void FusedActDequantInferMeta(const MetaTensor& x,
common::errors::InvalidArgument(
"The cols of X should be positive, but received %d.", cols));

auto scale_dims = x_scale.dims();
int64_t scale_cols_expected = (cols + 127) / 128;
if (x_scale.dtype() == DataType::INT32) {
scale_cols_expected = (scale_cols_expected + 3) / 4;
}

// Check scale shape assuming it is [rows, scale_cols] or flattened
if (scale_dims.size() == 2) {
PADDLE_ENFORCE_EQ(scale_dims[0],
rows,
common::errors::InvalidArgument(
"The rows of X_scale should be equal to rows of X"));
PADDLE_ENFORCE_EQ(
scale_dims[1],
scale_cols_expected,
common::errors::InvalidArgument("The cols of X_scale should be %d",
scale_cols_expected));
} else if (scale_dims.size() == 1) {
PADDLE_ENFORCE_EQ(
scale_dims[0],
rows * scale_cols_expected,
common::errors::InvalidArgument("The numel of X_scale should be %d",
rows * scale_cols_expected));
}

out->set_dims(x_dims);
out->set_dtype(DataType::BFLOAT16);
out->set_layout(x.layout());
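For reference (not part of the diff), a minimal Python sketch of the shape rule the new InferMeta checks enforce: a float32 scale carries one value per 128-column block of X, and the int32 (UE8M0) path additionally packs four exponent bytes into each element. The checks accept either the 2-D [rows, scale_cols] layout or the equivalent flattened 1-D layout.

def expected_scale_cols(cols: int, scale_is_int32: bool) -> int:
    # One scale per 128-column block of X.
    scale_cols = (cols + 127) // 128
    if scale_is_int32:
        # UE8M0 path: four uint8 exponent bytes are packed into each int32.
        scale_cols = (scale_cols + 3) // 4
    return scale_cols

# e.g. cols = 7168: 56 float32 scales per row, or 14 packed int32 words per row.
assert expected_scale_cols(7168, False) == 56
assert expected_scale_cols(7168, True) == 14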
47 changes: 37 additions & 10 deletions paddle/phi/kernels/fusion/gpu/fused_act_dequant_kernel.cu
@@ -23,8 +23,22 @@ struct alignas(16) VectorType {
T data[N];
};

template <typename ScaleT, bool using_ue8m0_scale>
__device__ __forceinline__ float LoadScale(const ScaleT* ptr, int64_t idx) {
if constexpr (using_ue8m0_scale) {
int packed_scale = reinterpret_cast<const int*>(ptr)[idx / 4];
int scale_offset = idx % 4;
uint8_t scale_u8 = (packed_scale >> (scale_offset * 8)) & 0xFF;
int val_as_int = static_cast<int>(scale_u8) << 23;
return __int_as_float(val_as_int);
} else {
return ptr[idx];
}
}

template <typename ScaleT, bool using_ue8m0_scale>
__global__ void FusedActDequant(const phi::float8_e4m3fn* __restrict__ Xin,
const float* __restrict__ Xscale,
const ScaleT* __restrict__ Xscale,
phi::bfloat16* __restrict__ out,
const int rows,
const int cols) {
@@ -51,7 +65,7 @@ __global__ void FusedActDequant(const phi::float8_e4m3fn* __restrict__ Xin,

int64_t scale_idx =
(int64_t)this_row_idx * (int64_t)Xscale_stride + (x_offset / 128);
float this_scale = Xscale[scale_idx];
float this_scale = LoadScale<ScaleT, using_ue8m0_scale>(Xscale, scale_idx);

VectorType<__nv_bfloat16, vector_size> out_vec;

@@ -73,8 +87,12 @@ __global__ void FusedActDequant(const phi::float8_e4m3fn* __restrict__ Xin,
int64_t idx = X_idx + tid;
if (tid < remaining_elements) {
float X_value = static_cast<float>(Xin[idx]);
X_value *= Xscale[(int64_t)this_row_idx * (int64_t)Xscale_stride +
(x_offset / 128)];

int64_t scale_idx =
(int64_t)this_row_idx * (int64_t)Xscale_stride + (x_offset / 128);
float this_scale =
LoadScale<ScaleT, using_ue8m0_scale>(Xscale, scale_idx);
X_value *= this_scale;
out[idx] = __float2bfloat16(X_value);
}
}
@@ -97,12 +115,21 @@ void FusedActDequantKernel(const Context& dev_ctx,
dim3 grid(rows);
dim3 block(256);

FusedActDequant<<<grid, block, 0, dev_ctx.stream()>>>(
x.data<phi::float8_e4m3fn>(),
x_scale.data<float>(),
out->data<phi::bfloat16>(),
rows,
cols);
if (x_scale.dtype() == phi::DataType::FLOAT32) {
FusedActDequant<float, false>
<<<grid, block, 0, dev_ctx.stream()>>>(x.data<phi::float8_e4m3fn>(),
x_scale.data<float>(),
out->data<phi::bfloat16>(),
rows,
cols);
} else if (x_scale.dtype() == phi::DataType::INT32) {
FusedActDequant<int, true>
<<<grid, block, 0, dev_ctx.stream()>>>(x.data<phi::float8_e4m3fn>(),
x_scale.data<int>(),
out->data<phi::bfloat16>(),
rows,
cols);
}

#ifdef PADDLE_WITH_CUDA_CHECK
auto cuda_error = cudaGetLastError();
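As an illustration (not part of the diff), a small Python sketch of what LoadScale does on the UE8M0 path: it selects one exponent byte from a packed little-endian int32 word and shifts it into the exponent field of a float32, which yields 2 ** (byte - 127). The float32 path reads the scale directly, so both template instantiations feed the same dequantization math.

import struct

def load_ue8m0_scale(packed_words, idx):
    # Four little-endian exponent bytes live in each packed int32 word.
    word = packed_words[idx // 4]
    exp_byte = (word >> ((idx % 4) * 8)) & 0xFF
    # Place the byte in the float32 exponent field (bit 23); sign and mantissa stay zero.
    return struct.unpack('<f', struct.pack('<I', exp_byte << 23))[0]

# Exponent bytes 126, 127, 128 decode to 0.5, 1.0, 2.0.
packed = [126 | (127 << 8) | (128 << 16)]
assert [load_ue8m0_scale(packed, i) for i in range(3)] == [0.5, 1.0, 2.0]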
4 changes: 2 additions & 2 deletions python/paddle/incubate/nn/functional/fp8.py
@@ -90,8 +90,8 @@ def fused_act_dequant(
Args:
x (Tensor): Input quantized tensor with dtype float8_e4m3fn and shape [M, N]. This tensor contains the quantized
activations from previous layers.
x_scale (Tensor): Dequantization scale tensor with dtype float32 and shape [M, (N + 127) // 128].
Each scale value corresponds to a 128-column block in the input tensor.
x_scale (Tensor): Dequantization scale tensor, either float32 with shape [M, (N + 127) // 128] or int32 (packed UE8M0) with shape [M, (N + 511) // 512].
Each scale value corresponds to a 128-column block in the input tensor.

Returns:
Tensor. Dequantized output tensor with dtype bfloat16 and shape [M, N]. The values are
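A worked example of the two documented x_scale layouts (sizes chosen for illustration only): since ceil(ceil(N / 128) / 4) == ceil(N / 512), the int32 shape below is consistent with the scale_cols_expected computation in the InferMeta change above.

M, N = 4096, 7168                          # example sizes
fp32_scale_shape = (M, (N + 127) // 128)   # (4096, 56)
ue8m0_scale_shape = (M, (N + 511) // 512)  # (4096, 14): four exponent bytes per int32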
131 changes: 131 additions & 0 deletions test/legacy_test/test_fused_act_dequant_op.py
@@ -19,6 +19,98 @@
import paddle


def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y


def align(x: int, y: int) -> int:
return ceil_div(x, y) * y


def get_tma_aligned_size(x: int, element_size: int) -> int:
"""
Align x to TMA-required size.
Args:
x: size in elements
element_size: size of each element in bytes
Returns:
Aligned size in elements
"""
kNumTMAAlignmentBytes = 16
assert kNumTMAAlignmentBytes % element_size == 0
return align(x, kNumTMAAlignmentBytes // element_size)


def ceil_to_ue8m0_paddle(x: paddle.Tensor):
"""
x > 0
return 2 ^ ceil(log2(x))
"""
# log2(x)
log2_x = paddle.log(x) / paddle.log(paddle.to_tensor(2.0, dtype=x.dtype))
# ceil
ceil_log2_x = paddle.ceil(log2_x)
# 2^k
return paddle.pow(paddle.to_tensor(2.0, dtype=x.dtype), ceil_log2_x)
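(Editorial note, not part of the diff: a plain-Python restatement of what ceil_to_ue8m0_paddle computes element-wise, assuming positive inputs.)

import math

def ceil_to_ue8m0(v: float) -> float:
    # Round a positive scale up to the next power of two: 0.3 -> 0.5, 3.0 -> 4.0.
    return 2.0 ** math.ceil(math.log2(v))

assert ceil_to_ue8m0(0.3) == 0.5 and ceil_to_ue8m0(3.0) == 4.0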


def _get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(
x: paddle.Tensor,
):
assert x.dtype == paddle.float and x.dim() in (2, 3)

ue8m0_tensor = (x.view(paddle.int) >> 23).to(paddle.uint8)

mn, k = x.shape[-2], x.shape[-1]
remove_dim = False

if x.dim() == 2:
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]

aligned_mn = get_tma_aligned_size(mn, 4)
aligned_k = align(k, 4)

padded = paddle.zeros(
(b, aligned_mn, aligned_k), device=x.device, dtype=paddle.uint8
)
padded[:, :mn, :k] = ue8m0_tensor

padded = (
padded.view(-1)
.view(dtype=paddle.int)
.view(b, aligned_mn, aligned_k // 4)
)

transposed = paddle.zeros(
(b, aligned_k // 4, aligned_mn), device=x.device, dtype=paddle.int
).mT
transposed[:, :, :] = padded

aligned_x = transposed[:, :mn, :]

return aligned_x.squeeze(0) if remove_dim else aligned_x


def transform_scale_ue8m0(sf, mn, weight_block_size=None):
get_mn_major_tma_aligned_packed_ue8m0_tensor = (
_get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl
)
if weight_block_size:
assert weight_block_size == [128, 128]
sf = sf.index_select(-2, paddle.arange(mn, device=sf.device) // 128)
sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
return sf


def quant_ref(x_scale_fp32, mn, weight_block_size=None):
x_scale_fp32_ = ceil_to_ue8m0_paddle(x_scale_fp32)
ref_e8m0_scale = transform_scale_ue8m0(
x_scale_fp32_, mn=mn, weight_block_size=weight_block_size
)
return ref_e8m0_scale
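(Editorial note, not part of the diff: a simplified NumPy sketch of the per-row packing that transform_scale_ue8m0 builds on. Power-of-two float32 scales become their exponent bytes, and groups of four bytes are reinterpreted as little-endian int32 words; the TMA alignment and MN-major layout handled above are omitted here.)

import numpy as np

def pack_ue8m0_row(scales_fp32):
    # Extract the float32 exponent byte of each (power-of-two) scale.
    exp_bytes = (np.asarray(scales_fp32, dtype=np.float32).view(np.uint32) >> 23).astype(np.uint8)
    # Pad the row to a multiple of 4 and reinterpret each 4-byte group as one int32.
    exp_bytes = np.pad(exp_bytes, (0, (-len(exp_bytes)) % 4))
    return exp_bytes.view(np.int32)

# Four scales of 1.0 (exponent byte 0x7F each) pack into the single word 0x7F7F7F7F.
assert pack_ue8m0_row([1.0, 1.0, 1.0, 1.0]).tolist() == [0x7F7F7F7F]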


class TestActQuantDequant(unittest.TestCase):
"""Test cases for activation quantization and dequantization functions."""

@@ -227,6 +319,45 @@ def _test_single_case(
f"Results don't match for shape [{height}, {width}]: {e!s}"
)

def test_ue8m0_support(self):
"""Test ue8m0 support in fused_act_dequant."""
if not hasattr(paddle.incubate.nn.functional, 'fused_act_dequant'):
self.skipTest(
"fused_act_dequant not available in this Paddle version"
)

height, width = 4096, 7168
x = paddle.clip(
paddle.randn([height, width]).astype("bfloat16"), min=-50, max=50
)
x_fp8, scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
x, quant_method="1x128", output_scale_transpose=False
)

# 1. Align scale to ue8m0 values (2^k) in float32
scale_aligned_fp32 = ceil_to_ue8m0_paddle(scale)

# 2. Pack aligned scale to ue8m0 format (int32)
scale_packed_int32 = transform_scale_ue8m0(
scale_aligned_fp32, mn=height
)

# 3. Run fused_act_dequant with aligned float32 scale
out_fp32 = paddle.incubate.nn.functional.fused_act_dequant(
x_fp8, scale_aligned_fp32
)

# 4. Run fused_act_dequant with packed int32 scale
out_ue8m0 = paddle.incubate.nn.functional.fused_act_dequant(
x_fp8, scale_packed_int32
)

# 5. Compare
out_fp32_np = out_fp32.numpy()
out_ue8m0_np = out_ue8m0.numpy()

np.testing.assert_allclose(out_fp32_np, out_ue8m0_np, rtol=0, atol=0)

def test_invalid_inputs(self):
"""Test error handling for invalid inputs."""
# Test non-divisible block size