@@ -893,22 +893,22 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
893
893
894
894
// GEMM
895
895
TensorWrapper input_a_chunk, input_b_chunk;
896
- if (ag_on_B) { // AllGather is performed on input B tensor (default case).
897
- // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
898
- input_a_chunk = get_tensor_chunk (A, transb ? input_a_chunk_size * send_chunk_id / 2 : 0 ,
896
+ if (ag_on_B) { // AllGather is performed on input B tensor (default case).
897
+ // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
898
+ input_a_chunk = get_tensor_chunk (
899
+ A, transb ? input_a_chunk_size * send_chunk_id / 2 : 0 ,
899
900
transb ? std::vector<size_t >{k_chunk * 2 , m} : shape_to_vector (A.shape ()));
900
901
input_b_chunk =
901
902
get_buffer_chunk_like (B, input_b_chunk_size * send_chunk_id / 2 , input_b_chunk_shape);
902
- } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
903
+ } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
903
904
assert (transa == false && transb == true );
904
- input_a_chunk = get_buffer_chunk_like (
905
- A, input_a_chunk_size * send_chunk_id / 2 , std::vector<size_t >{k_chunk * 2 , m}
906
- );
905
+ input_a_chunk = get_buffer_chunk_like (A, input_a_chunk_size * send_chunk_id / 2 ,
906
+ std::vector<size_t >{k_chunk * 2 , m});
907
907
input_b_chunk =
908
908
get_tensor_chunk (B, input_b_chunk_size * send_chunk_id / 2 , input_b_chunk_shape);
909
909
}
910
- auto output_chunk =
911
- get_tensor_chunk (D, transb ? 0 : output_chunk_size * send_chunk_id / 2 , output_chunk_shape);
910
+ auto output_chunk = get_tensor_chunk (D, transb ? 0 : output_chunk_size * send_chunk_id / 2 ,
911
+ output_chunk_shape);
912
912
auto aux_chunk = (do_gelu)
913
913
? get_tensor_chunk (pre_gelu_out, output_chunk_size * send_chunk_id / 2 ,
914
914
{2 * n_chunk, k})
@@ -963,15 +963,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
963
963
964
964
// GEMM
965
965
TensorWrapper input_a_chunk, input_b_chunk;
966
- if (ag_on_B) { // AllGather is performed on input B tensor (default case).
967
- // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
968
- input_a_chunk = get_tensor_chunk (A, transb ? input_a_chunk_size * send_chunk_id : 0 ,
969
- transb ? std::vector<size_t >{k_chunk, m} : shape_to_vector (A.shape ()));
966
+ if (ag_on_B) { // AllGather is performed on input B tensor (default case).
967
+ // Use case: AG->{FC2, PROJ}_Wgrad, AG->{FC1, QKV}_FPROP.
968
+ input_a_chunk =
969
+ get_tensor_chunk (A, transb ? input_a_chunk_size * send_chunk_id : 0 ,
970
+ transb ? std::vector<size_t >{k_chunk, m} : shape_to_vector (A.shape ()));
970
971
input_b_chunk =
971
972
get_buffer_chunk_like (B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
972
- } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
973
+ } else { // AllGather is performed on input A tensor. Use case: AG->{FC1, QKV}_Wgrad.
973
974
assert (transa == false && transb == true );
974
- input_a_chunk = get_buffer_chunk_like (A, input_a_chunk_size * send_chunk_id,
975
+ input_a_chunk = get_buffer_chunk_like (
976
+ A, input_a_chunk_size * send_chunk_id,
975
977
transb ? std::vector<size_t >{k_chunk, m} : std::vector<size_t >{m, k});
976
978
input_b_chunk =
977
979
get_tensor_chunk (B, input_b_chunk_size * send_chunk_id, input_b_chunk_shape);
0 commit comments