64 changes: 56 additions & 8 deletions transformer_engine/pytorch/csrc/common.cpp
```diff
@@ -26,15 +26,9 @@ std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape
   return ret;
 }
 
-std::vector<size_t> getTensorShape(const at::Tensor& t) {
-  std::vector<size_t> shape;
-  for (auto s : t.sizes()) {
-    shape.push_back(s);
-  }
-  return shape;
-}
+NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); }
 
-NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) {
+NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) {
   NVTEShape ret;
   ret.ndim = torch_shape.size();
   constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t);
@@ -178,6 +172,38 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
   return ret;
 }
 
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr,
+    void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
+    NVTEScalingMode scaling_mode) {
+  TensorWrapper ret(scaling_mode);
+  ret.set_rowwise_data(data_ptr, type, shape);
+  const size_t meta_shape_data[1] = {1};
+  const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1);
+  ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
+  ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
+  auto scale_inv_dtype =
+      (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
+  ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
+  return ret;
+}
+
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
+    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
+    NVTEScalingMode scaling_mode) {
+  TensorWrapper ret(scaling_mode);
+  ret.set_rowwise_data(data_ptr, type, shape);
+  const size_t meta_shape_data[1] = {1};
+  const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1);
+  ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
+  ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
+  auto scale_inv_dtype =
+      (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
+  ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
+  return ret;
+}
+
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
     void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
     const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
@@ -199,6 +225,28 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
   return ret;
 }
 
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
+    const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr,
+    void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
+    const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
+    NVTEScalingMode scaling_mode) {
+  TensorWrapper ret(scaling_mode);
+  ret.set_rowwise_data(data_ptr, type, shape);
+  ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape);
+  const size_t meta_shape_data[1] = {1};
+  const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1);
+  ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
+  ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
+  auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING)   ? DType::kFloat8E8M0
+                         : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3
+                                                                   : DType::kFloat32;
+  ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
+  ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype,
+                               columnwise_scale_inv_shape);
+  return ret;
+}
+
 transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax,
                                                               const at::Tensor scale,
                                                               at::Tensor scale_inv,
```
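The new overload families mirror the existing `std::vector<size_t>` ones but accept `NVTEShape` directly, so callers can wrap raw device buffers without materializing a vector per call. A minimal caller sketch, assuming this PR's `common.h`; the function name, dtype choice, and dimension values are hypothetical:

```cpp
#include "common.h"  // declarations added in this PR

using transformer_engine::DType;
using transformer_engine::TensorWrapper;

// Hypothetical caller: wrap a raw row-wise FP8 buffer with per-tensor scaling.
TensorWrapper wrap_fp8_rowwise(void* data, void* amax, void* scale, void* scale_inv) {
  const size_t dims[2] = {128, 256};  // example logical shape
  const NVTEShape shape = nvte_make_shape(dims, 2);
  const size_t sinv_dims[1] = {1};  // per-tensor scaling, so a scalar scale-inv
  const NVTEShape sinv_shape = nvte_make_shape(sinv_dims, 1);
  // Resolves to the NVTEShape overload added above; no std::vector temporary.
  return makeTransformerEngineTensor(data, shape, DType::kFloat8E4M3, amax, scale, scale_inv,
                                     sinv_shape, NVTE_DELAYED_TENSOR_SCALING);
}
```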
21 changes: 19 additions & 2 deletions transformer_engine/pytorch/csrc/common.h
```diff
@@ -339,7 +339,7 @@ class NVFP4Quantizer : public Quantizer {
 
 std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
 
-std::vector<size_t> getTensorShape(const at::Tensor& t);
+NVTEShape getTensorShape(const at::Tensor& t);
 
 transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                       const std::string& fp8_recipe);
@@ -432,6 +432,16 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
     void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape = {1},
     NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
 
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr,
+    void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
+    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
+
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
+    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
+    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
+
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
     void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
     const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
@@ -440,6 +450,13 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
     const std::vector<size_t>& columnwise_scale_inv_shape = {1},
     NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
 
+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
+    const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr,
+    void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
+    const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
+    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
+
 transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                               const NVTEShape& shape,
                                                               const transformer_engine::DType type);
@@ -479,7 +496,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
 
 size_t roundup(const size_t value, const size_t multiple);
 
-NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
+NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape);
 
 std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose);
 
```
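Since `getTensorShape` now returns `NVTEShape`, call sites that need STL operations (`size()`, iteration, `==`) convert explicitly through `convertShape`, declared above. A round-trip sketch, with illustrative tensor values:

```cpp
#include <torch/torch.h>

#include "common.h"

void shape_roundtrip_demo() {
  const at::Tensor t = at::zeros({4, 8});
  const NVTEShape s = getTensorShape(t);          // c10::IntArrayRef -> NVTEShape
  const std::vector<size_t> v = convertShape(s);  // NVTEShape -> vector for STL ops
  TORCH_CHECK(v.size() == s.ndim);
  TORCH_CHECK(v[0] == 4 && v[1] == 8);
}
```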
9 changes: 6 additions & 3 deletions transformer_engine/pytorch/csrc/extensions/bias.cpp
```diff
@@ -26,7 +26,8 @@ std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle
   // Grad output tensor
   auto grad_output_torch = grad_output.contiguous();
   const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
-  const auto shape = getTensorShape(grad_output_torch);
+  const auto shape_nvte = getTensorShape(grad_output_torch);
+  const auto shape = convertShape(shape_nvte);
   auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
 
   // Construct grad bias tensor
@@ -116,11 +117,13 @@ std::vector<py::object> dact_dbias(
   // Grad output and activation input tensors
   grad_output_torch = grad_output_torch.contiguous();
   const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
-  const auto output_shape = getTensorShape(grad_output_torch);
+  const auto output_shape_nvte = getTensorShape(grad_output_torch);
+  const auto output_shape = convertShape(output_shape_nvte);
   auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
   act_input_torch = act_input_torch.contiguous();
   const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch);
-  const auto input_shape = getTensorShape(act_input_torch);
+  const auto input_shape_nvte = getTensorShape(act_input_torch);
+  const auto input_shape = convertShape(input_shape_nvte);
 
   // Construct tensors
   auto quantizer_cpp = convert_quantizer(quantizer_py);
```
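Each call site here follows the same two-step idiom: fetch the `NVTEShape`, then convert. Where the intermediate `NVTEShape` is never reused, the two statements collapse into one; a hypothetical convenience helper (not part of this PR) would make that explicit:

```cpp
// Hypothetical helper, not added by this PR: shape as a vector in one step.
inline std::vector<size_t> getTensorShapeVec(const at::Tensor& t) {
  return convertShape(getTensorShape(t));
}
```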
18 changes: 10 additions & 8 deletions transformer_engine/pytorch/csrc/extensions/gemm.cpp
```diff
@@ -365,14 +365,16 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
   NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING;
   NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING;
 
-  auto te_A = makeTransformerEngineTensor(
-      A.data_ptr(), {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))}, A_type,
-      nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse),
-      nvte_scaling_modeA);
-  auto te_B = makeTransformerEngineTensor(
-      B.data_ptr(), {static_cast<size_t>(B.size(0)), static_cast<size_t>(B.size(1))}, B_type,
-      nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse),
-      nvte_scaling_modeB);
+  const size_t A_shape_data[2] = {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))};
+  const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2);
+  auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr,
+                                          A_scale_inverse.data_ptr(),
+                                          getTensorShape(A_scale_inverse), nvte_scaling_modeA);
+  const size_t B_shape_data[2] = {static_cast<size_t>(B.size(0)), static_cast<size_t>(B.size(1))};
+  const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2);
+  auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr,
+                                          B_scale_inverse.data_ptr(),
+                                          getTensorShape(B_scale_inverse), nvte_scaling_modeB);
   // TODO: D_scale_inv cannot be nullptr when D_type is FP8.
   auto te_D = makeTransformerEngineTensor(
       D.data_ptr(),
```
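The braced initializer lists in the old calls bound to the `std::vector<size_t>` overload; an `NVTEShape` has to be built explicitly through `nvte_make_shape`. Judging by the `max_dimensions` check in `convertTorchShape`, an `NVTEShape` stores its dimensions in an inline array, so returning one by value should be safe; under that assumption, a hypothetical helper (not in this PR) capturing the 2-D pattern used here:

```cpp
// Hypothetical helper: stack-built 2-D NVTEShape from torch sizes.
inline NVTEShape makeShape2D(const at::Tensor& t) {
  const size_t dims[2] = {static_cast<size_t>(t.size(0)), static_cast<size_t>(t.size(1))};
  return nvte_make_shape(dims, 2);  // assumed to copy dims into NVTEShape's inline array
}
```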
6 changes: 4 additions & 2 deletions transformer_engine/pytorch/csrc/extensions/transpose.cpp
```diff
@@ -19,7 +19,8 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor
   init_extension();
 
   // Tensor dimensions
-  const auto shape = getTensorShape(input);
+  const auto shape_nvte = getTensorShape(input);
+  const auto shape = convertShape(shape_nvte);
   std::vector<int64_t> transpose_shape_int64;
   if (shape.size() > 0) {
     transpose_shape_int64.push_back(shape.back());
@@ -60,7 +61,8 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out) {
 
   // Allocate output tensor if needed
   if (!out) {
-    auto in_shape = getTensorShape(input);
+    const auto in_shape_nvte = getTensorShape(input);
+    const auto in_shape = convertShape(in_shape_nvte);
     NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")");
     std::vector<int64_t> out_shape_int64(in_shape.begin(), in_shape.end());
     out_shape_int64[0] = static_cast<int64_t>(in_shape[1]);
```
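`swap_first_dims` derives its output shape by exchanging the first two entries of the converted input shape (the second assignment falls just past the truncated hunk, but is implied by the function's name). The computation in isolation, assuming a rank >= 2 shape:

```cpp
#include <cstdint>
#include <vector>

// Output shape for swap_first_dims: dims 0 and 1 exchanged, rest unchanged.
std::vector<int64_t> swapped_shape(const std::vector<size_t>& in_shape) {
  std::vector<int64_t> out(in_shape.begin(), in_shape.end());
  out[0] = static_cast<int64_t>(in_shape[1]);
  out[1] = static_cast<int64_t>(in_shape[0]);
  return out;
}
```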
35 changes: 20 additions & 15 deletions transformer_engine/pytorch/csrc/quantizer.cpp
```diff
@@ -209,20 +209,22 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::convert_and_update_tensor(
   // Tensor dimensions
   std::vector<size_t> shape;
   if (has_transpose) {
-    const auto transpose_shape = getTensorShape(*transpose_tensor);
+    const auto transpose_shape_nvte = getTensorShape(*transpose_tensor);
+    const auto transpose_shape = convertShape(transpose_shape_nvte);
     if (transpose_shape.size() > 0) {
       for (size_t i = 1; i < transpose_shape.size(); ++i) {
         shape.push_back(transpose_shape[i]);
       }
       shape.push_back(transpose_shape.front());
     }
     if (has_data) {
-      auto expected_shape = getTensorShape(*data_tensor);
+      const auto expected_shape_nvte = getTensorShape(*data_tensor);
+      const auto expected_shape = convertShape(expected_shape_nvte);
       NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
                  ") and transpose (shape=", transpose_shape, ") do not match");
     }
   } else {  // Already checked has_data == true
-    shape = getTensorShape(*data_tensor);
+    shape = convertShape(getTensorShape(*data_tensor));
   }
 
   // Coerce data tensor
@@ -430,20 +432,22 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::convert_and_
   // Tensor dimensions
   std::vector<size_t> shape;
   if (has_transpose) {
-    const auto transpose_shape = getTensorShape(*transpose_tensor);
+    const auto transpose_shape_nvte = getTensorShape(*transpose_tensor);
+    const auto transpose_shape = convertShape(transpose_shape_nvte);
     if (transpose_shape.size() > 0) {
       for (size_t i = 1; i < transpose_shape.size(); ++i) {
         shape.push_back(transpose_shape[i]);
       }
       shape.push_back(transpose_shape.front());
     }
     if (has_data) {
-      auto expected_shape = getTensorShape(*data_tensor);
+      const auto expected_shape_nvte = getTensorShape(*data_tensor);
+      const auto expected_shape = convertShape(expected_shape_nvte);
       NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
                  ") and transpose (shape=", transpose_shape, ") do not match");
     }
   } else {  // Already checked has_data == true
-    shape = getTensorShape(*data_tensor);
+    shape = convertShape(getTensorShape(*data_tensor));
   }
 
   // Coerce data tensor in Python tensor
@@ -680,9 +684,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
       return std::vector<size_t>();
     }
     if (all_gather_usage) {
-      return getTensorShape(*columnwise_data);
+      return convertShape(getTensorShape(*columnwise_data));
     }
-    std::vector<size_t> shape = getTensorShape(*columnwise_data);
+    std::vector<size_t> shape = convertShape(getTensorShape(*columnwise_data));
     std::vector<size_t> shape_transposed(shape.size());
     for (size_t i = 0; i + 1 < shape.size(); ++i) {
       shape_transposed[i] = shape[i + 1];
@@ -694,7 +698,7 @@ };
   };
   std::vector<size_t> shape;
   if (rowwise_data) {
-    shape = getTensorShape(*rowwise_data);
+    shape = convertShape(getTensorShape(*rowwise_data));
     if (columnwise_data) {
       auto expected_shape = get_columnwise_shape(all_gather_usage);
       NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
@@ -1004,14 +1008,14 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
   // Tensor dimensions
   std::vector<size_t> shape;
   if (columnwise_data) {
-    shape = getTensorShape(*columnwise_data);
+    shape = convertShape(getTensorShape(*columnwise_data));
     if (rowwise_data) {
-      auto expected_shape = getTensorShape(*rowwise_data);
+      const auto expected_shape = convertShape(getTensorShape(*rowwise_data));
       NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape,
                  ") and column-wise data (shape=", shape, ") do not match");
     }
   } else {  // Already checked columnwise_data_tensor == true
-    shape = getTensorShape(*rowwise_data);
+    shape = convertShape(getTensorShape(*rowwise_data));
   }
 
   // Coerce row-wise data
@@ -1320,14 +1324,15 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
   // Tensor dimensions, shape means original shape
   std::vector<size_t> shape;
   if (columnwise_data) {
-    shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
+    shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*columnwise_data)), true);
     if (rowwise_data) {
-      auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
+      auto expected_shape =
+          convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false);
       NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape,
                  ") and column-wise data (shape=", shape, ") do not match");
     }
   } else {  // Already checked columnwise_data_tensor == true
-    shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
+    shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false);
  }
 
   size_t flat_first_dim = 1;
```
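For NVFP4 the stored tensor is packed, so its storage shape has to be mapped back to the logical shape before any comparison; this PR threads `convertShape` into the middle of that chain. The composition written out, with names taken from the hunk above; the boolean flag selects the transposed column-wise layout:

```cpp
// Packed FP4 storage shape -> std::vector<size_t> -> logical (unpacked) shape.
const std::vector<size_t> logical_shape =
    convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)),
                                /*transpose=*/false);
```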
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/csrc/type_converters.cpp
```diff
@@ -132,7 +132,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
     const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
     const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast<at::Tensor>();
     ret.set_rowwise_data(data.data_ptr(), dtype,
-                         convert_shape_back_from_fp4(getTensorShape(data), false));
+                         convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false));
     ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv));
     ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise));
   }
@@ -143,7 +143,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
     const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
     const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast<at::Tensor>();
     ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1,
-                            convert_shape_back_from_fp4(getTensorShape(data), false));
+                            convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false));
     ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3,
                                  getTensorShape(scale_inv));
     ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
```
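Worth noting: the scale-inverse and amax setters consume the `NVTEShape` from `getTensorShape` unconverted, since the `TensorWrapper` setters are called with both shape types elsewhere in this PR; only `convert_shape_back_from_fp4`, which operates on `std::vector<size_t>`, forces the `convertShape` detour. The two paths side by side, condensed from the hunks above:

```cpp
// Path 1: the NVTEShape passes straight through to the setter.
ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv));
// Path 2: the FP4 unpacking step is vector-based, hence the extra conversion.
ret.set_rowwise_data(data.data_ptr(), dtype,
                     convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false));
```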