
Commit 80d8dc7

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1aecbf0 commit 80d8dc7

File tree

6 files changed: +25 −23 lines


transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

Lines changed: 17 additions & 15 deletions
@@ -893,22 +893,22 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,

     // GEMM
     TensorWrapper input_a_chunk, input_b_chunk;
-    if (ag_on_B) { // AllGather is performed on input B tensor (default case).
-      // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
-      input_a_chunk = get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id / 2 : 0,
+    if (ag_on_B) {  // AllGather is performed on input B tensor (default case).
+      // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
+      input_a_chunk = get_tensor_chunk(
+          A, transb ? input_a_chunk_size * send_chunk_id / 2 : 0,
           transb ? std::vector<size_t>{k_chunk * 2, m} : shape_to_vector(A.shape()));
       input_b_chunk =
           get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape);
-    } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
+    } else {  // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
       assert(transa == false && transb == true);
-      input_a_chunk = get_buffer_chunk_like(
-          A, input_a_chunk_size * send_chunk_id / 2, std::vector<size_t>{k_chunk * 2, m}
-      );
+      input_a_chunk = get_buffer_chunk_like(A, input_a_chunk_size * send_chunk_id / 2,
+                                            std::vector<size_t>{k_chunk * 2, m});
       input_b_chunk =
           get_tensor_chunk(B, input_b_chunk_size * send_chunk_id / 2, input_b_chunk_shape);
     }
-    auto output_chunk =
-        get_tensor_chunk(D, transb ? 0 : output_chunk_size * send_chunk_id / 2, output_chunk_shape);
+    auto output_chunk = get_tensor_chunk(D, transb ? 0 : output_chunk_size * send_chunk_id / 2,
+                                         output_chunk_shape);
     auto aux_chunk = (do_gelu)
                          ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id / 2,
                                             {2 * n_chunk, k})
@@ -963,15 +963,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,

     // GEMM
     TensorWrapper input_a_chunk, input_b_chunk;
-    if (ag_on_B) { // AllGather is performed on input B tensor (default case).
-      // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
-      input_a_chunk = get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id : 0,
-          transb ? std::vector<size_t>{k_chunk, m} : shape_to_vector(A.shape()));
+    if (ag_on_B) {  // AllGather is performed on input B tensor (default case).
+      // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
+      input_a_chunk =
+          get_tensor_chunk(A, transb ? input_a_chunk_size * send_chunk_id : 0,
+                           transb ? std::vector<size_t>{k_chunk, m} : shape_to_vector(A.shape()));
       input_b_chunk =
           get_buffer_chunk_like(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
-    } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
+    } else {  // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
       assert(transa == false && transb == true);
-      input_a_chunk = get_buffer_chunk_like(A, input_a_chunk_size * send_chunk_id,
+      input_a_chunk = get_buffer_chunk_like(
+          A, input_a_chunk_size * send_chunk_id,
           transb ? std::vector<size_t>{k_chunk, m} : std::vector<size_t>{m, k});
       input_b_chunk =
           get_tensor_chunk(B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
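
The two hunks above only re-wrap lines, but the branch they touch encodes the core dispatch rule of split_overlap_ag: whichever operand is being all-gathered is read from the communication ring buffer (get_buffer_chunk_like), while the other operand is sliced from the caller's tensor (get_tensor_chunk). A minimal, self-contained C++ sketch of that rule follows; Chunk, tensor_chunk, buffer_chunk, and pick_gemm_inputs are illustrative stand-ins, not Transformer Engine APIs.

// Illustrative stand-ins (hypothetical, not Transformer Engine code).
#include <cstddef>
#include <iostream>

struct Chunk {
  const char *source;  // "tensor" (caller data) or "buffer" (comm ring buffer)
  size_t offset;       // element offset of the chunk within its source
};

Chunk tensor_chunk(size_t offset) { return {"tensor", offset}; }  // cf. get_tensor_chunk
Chunk buffer_chunk(size_t offset) { return {"buffer", offset}; }  // cf. get_buffer_chunk_like

// When ag_on_B is true, B is the all-gathered operand and comes out of the
// ring buffer; otherwise A does. The real code additionally halves the offset
// (send_chunk_id / 2) in the first hunk, because that path works on
// double-width chunks.
void pick_gemm_inputs(bool ag_on_B, size_t send_chunk_id, size_t a_chunk_size,
                      size_t b_chunk_size, Chunk &a, Chunk &b) {
  if (ag_on_B) {
    a = tensor_chunk(a_chunk_size * send_chunk_id);
    b = buffer_chunk(b_chunk_size * send_chunk_id);
  } else {
    a = buffer_chunk(a_chunk_size * send_chunk_id);
    b = tensor_chunk(b_chunk_size * send_chunk_id);
  }
}

int main() {
  Chunk a, b;
  pick_gemm_inputs(/*ag_on_B=*/true, /*send_chunk_id=*/3, 64, 128, a, b);
  std::cout << "A from " << a.source << " @" << a.offset << ", B from " << b.source << " @"
            << b.offset << "\n";  // A from tensor @192, B from buffer @384
  return 0;
}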

transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h

Lines changed: 2 additions & 3 deletions
@@ -130,9 +130,8 @@ class CommOverlapCore {
   virtual void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B,
                                 bool transb, TensorWrapper &D, TensorWrapper &bias,
                                 TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
-                                bool accumulate, bool use_split_accumulator,
-                                bool ag_on_B, TensorWrapper &B_copy,
-                                cudaStream_t stream_main) {
+                                bool accumulate, bool use_split_accumulator, bool ag_on_B,
+                                TensorWrapper &B_copy, cudaStream_t stream_main) {
     NVTE_ERROR("Operation is not implemented.");
   }

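
For context, the declaration above is the error-raising default in the CommOverlapCore base class; concrete overlap backends override it. A hedged sketch of such an override, assuming the public header shown above supplies TensorWrapper, CommOverlapCore, and cudaStream_t under the transformer_engine namespace (the MyOverlap class and its body are hypothetical):

#include <transformer_engine/comm_gemm_overlap.h>  // assumed to provide the types used below

using namespace transformer_engine;  // assumption about where these names live

class MyOverlap : public CommOverlapCore {
 public:
  // Signature copied from the header diff above; the body is a placeholder.
  void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B,
                        bool transb, TensorWrapper &D, TensorWrapper &bias,
                        TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
                        bool accumulate, bool use_split_accumulator, bool ag_on_B,
                        TensorWrapper &B_copy, cudaStream_t stream_main) override {
    // ... exchange chunks with peers and launch the chunked GEMMs here ...
  }
};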

transformer_engine/pytorch/cpp_extensions/gemm.py

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ def general_gemm(
         workspace.shape[0],
         accumulate,
         use_split_accumulator,
-        ag_on_B, # ag_on_B
+        ag_on_B,  # ag_on_B
     )
     kwargs = {
         "comm_overlap": ub,

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 2 additions & 1 deletion
@@ -120,7 +120,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                              py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
                              DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
                              at::Tensor workspace, size_t workspaceSize, bool accumulate,
-                             bool use_split_accumulator, bool ag_on_B, CommOverlapCore *comm_overlap = nullptr,
+                             bool use_split_accumulator, bool ag_on_B,
+                             CommOverlapCore *comm_overlap = nullptr,
                              std::optional<CommOverlapType> comm_type = std::nullopt,
                              MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);


transformer_engine/pytorch/csrc/extensions/gemm.cpp

Lines changed: 2 additions & 1 deletion
@@ -90,7 +90,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                              py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
                              DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
                              at::Tensor workspace, size_t workspaceSize, bool accumulate,
-                             bool use_split_accumulator, bool ag_on_B, CommOverlapCore* comm_overlap,
+                             bool use_split_accumulator, bool ag_on_B,
+                             CommOverlapCore* comm_overlap,
                              std::optional<CommOverlapType> comm_type, MaybeTensor extra_output,
                              bool bulk_overlap) {
   // Input tensors

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 1 addition & 2 deletions
@@ -110,8 +110,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"),
         py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"),
         py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
-        py::arg("ag_on_B"),
-        py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
+        py::arg("ag_on_B"), py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
         py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
   m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
         py::arg("quantizer"));
