2222#include " ../transpose/cast_transpose.h"
2323#include " ../util/vectorized_pointwise.h"
2424#include " ../utils.cuh"
25+ #include " cast_kernels_spec.cuh"
2526#include " math.h"
2627#include " ptx.cuh"
2728#include " transformer_engine/transformer_engine.h"
28- #include " cast_kernels_spec.cuh"
2929
3030namespace transformer_engine {
3131
@@ -1083,31 +1083,21 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
10831083 output->dtype (), OType,
10841084
10851085 if (spec::hasSpec<IS_DBIAS, IS_DACT, IS_ACT, IType, OType>()) {
1086-
10871086 switch (scaling_type) {
10881087 case ScalingType::ROWWISE: {
10891088 using traits = spec::CastTraits<IType, OType, true , false >;
10901089 auto kernel = spec::cast_mxfp8_kernel<traits>;
10911090
1092- cudaFuncSetAttribute (
1093- kernel,
1094- cudaFuncAttributeMaxDynamicSharedMemorySize,
1095- traits::smem);
1091+ cudaFuncSetAttribute (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
1092+ traits::smem);
10961093
1097- dim3 block (traits::threadLayout::num,
1098- traits::warpLayout::N,
1099- traits::warpLayout::M);
1094+ dim3 block (traits::threadLayout::num, traits::warpLayout::N, traits::warpLayout::M);
11001095 dim3 grid ((cols + traits::blockDimN - 1 ) / traits::blockDimN,
11011096 (rows + traits::blockDimM - 1 ) / traits::blockDimM);
11021097 kernel<<<grid, block, traits::smem, stream>>> (
1103- reinterpret_cast <typename traits::IType *>(input.data .dptr ),
1104- reinterpret_cast <typename traits::OType *>(output->data .dptr ),
1105- scales_rowwise_ptr,
1106- rows,
1107- cols,
1108- scale_stride_rowwise,
1109- scale_stride_colwise
1110- );
1098+ reinterpret_cast <typename traits::IType *>(input.data .dptr ),
1099+ reinterpret_cast <typename traits::OType *>(output->data .dptr ),
1100+ scales_rowwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
11111101
11121102 break ;
11131103 }
@@ -1119,55 +1109,35 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
11191109 using traits = spec::CastTraits<IType, OType, true , true >;
11201110 auto kernel = spec::cast_mxfp8_kernel<traits>;
11211111
1122- cudaFuncSetAttribute (
1123- kernel,
1124- cudaFuncAttributeMaxDynamicSharedMemorySize,
1125- traits::smem
1126- );
1112+ cudaFuncSetAttribute (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
1113+ traits::smem);
11271114 // TMA for loading, so that we don't need STS for transposing
11281115 alignas (64 ) CUtensorMap tensor_map_input{};
11291116 constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
1130- create_2D_tensor_map (tensor_map_input,
1131- input.data ,
1132- traits::input_swizzle_pattern,
1133- rows, cols,
1134- traits::blockIterDim::M, traits::blockIterDim::N,
1135- /* stride_elems=*/ cols,
1136- /* offset_elems=*/ 0 ,
1137- input_type_bit_size);
1117+ create_2D_tensor_map (tensor_map_input, input.data , traits::input_swizzle_pattern,
1118+ rows, cols, traits::blockIterDim::M, traits::blockIterDim::N,
1119+ /* stride_elems=*/ cols,
1120+ /* offset_elems=*/ 0 , input_type_bit_size);
11381121 alignas (64 ) CUtensorMap tensor_map_rowwise_output{};
11391122 alignas (64 ) CUtensorMap tensor_map_colwise_output{};
11401123 constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
1141- create_2D_tensor_map (tensor_map_rowwise_output,
1142- output->data ,
1143- traits::output_swizzle_pattern,
1144- rows, cols,
1124+ create_2D_tensor_map (tensor_map_rowwise_output, output->data ,
1125+ traits::output_swizzle_pattern, rows, cols,
11451126 traits::blockIterDim::M, traits::blockIterDim::N,
1146- /* stride_elems=*/ cols,
1147- /* offset_elems=*/ 0 ,
1127+ /* stride_elems=*/ cols,
1128+ /* offset_elems=*/ 0 , output_type_bit_size);
1129+ create_2D_tensor_map (tensor_map_colwise_output, output->columnwise_data ,
1130+ traits::output_swizzle_pattern, rows, cols,
1131+ traits::blockIterDim::M, traits::blockIterDim::N, cols, 0 ,
11481132 output_type_bit_size);
1149- create_2D_tensor_map (tensor_map_colwise_output,
1150- output->columnwise_data ,
1151- traits::output_swizzle_pattern,
1152- rows, cols,
1153- traits::blockIterDim::M, traits::blockIterDim::N,
1154- cols, 0 , output_type_bit_size);
11551133
1156- dim3 block (traits::rowThreadLayout::num,
1157- traits::numWarps);
1134+ dim3 block (traits::rowThreadLayout::num, traits::numWarps);
11581135 dim3 grid ((cols + traits::blockDIM::N - 1 ) / traits::blockDIM::N,
11591136 (rows + traits::blockDIM::M - 1 ) / traits::blockDIM::M);
11601137 kernel<<<grid, block, traits::smem, stream>>> (
1161- tensor_map_input,
1162- tensor_map_rowwise_output,
1163- tensor_map_colwise_output,
1164- scales_rowwise_ptr,
1165- scales_colwise_ptr,
1166- rows,
1167- cols,
1168- scale_stride_rowwise,
1169- scale_stride_colwise
1170- );
1138+ tensor_map_input, tensor_map_rowwise_output, tensor_map_colwise_output,
1139+ scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
1140+ scale_stride_colwise);
11711141
11721142 break ;
11731143 }
0 commit comments