@@ -404,12 +404,11 @@ int DoNeighborAllreduce(::torch::Tensor tensor, ::torch::Tensor output,
404
404
output_reduced.mul_ (weight);
405
405
} else {
406
406
output_reduced.add_ (
407
- output_buffer.slice (0 , i * first_dim, (i + 1 ) * first_dim)
408
- .mul_ (weight));
407
+ output_buffer.slice (0 , i * first_dim, (i + 1 ) * first_dim), weight);
409
408
}
410
409
}
411
410
output_buffer.resize_ (shape_vector);
412
- output_buffer.add_ (tensor_buffer. mul ( self_weight) );
411
+ output_buffer.add_ (tensor_buffer, self_weight);
413
412
if (is_hierarchical){
414
413
// Because there is ncclAllreduce just take sum.
415
414
output_buffer.div_ (bluefog_local_size ());
@@ -495,12 +494,11 @@ int DoNeighborAllreduce(::torch::Tensor tensor, ::torch::Tensor output,
495
494
output_reduced.mul_ (weight);
496
495
} else {
497
496
output_reduced.add_ (
498
- output_buffer.slice (0 , i * first_dim, (i + 1 ) * first_dim)
499
- .mul_ (weight));
497
+ output_buffer.slice (0 , i * first_dim, (i + 1 ) * first_dim), weight);
500
498
}
501
499
}
502
500
output_buffer.resize_ (shape_vector);
503
- output_buffer.add_ (tensor_buffer. mul ( self_weight) );
501
+ output_buffer.add_ (tensor_buffer, self_weight);
504
502
if (is_hierarchical){
505
503
// Because there is ncclAllreduce just take sum.
506
504
output_buffer.div_ (bluefog_local_size ());
@@ -576,7 +574,7 @@ int DoPairGossip(::torch::Tensor tensor, ::torch::Tensor output,
576
574
output_buffer.add_ (tensor_buffer).div_ (2 );
577
575
} else {
578
576
output_buffer.mul_ (pair_weight)
579
- .add_ (tensor_buffer. mul ( self_weight) );
577
+ .add_ (tensor_buffer, self_weight);
580
578
}
581
579
MaybeCopyBufferBack (output, output_buffer);
582
580
}));
@@ -597,7 +595,7 @@ int DoPairGossip(::torch::Tensor tensor, ::torch::Tensor output,
597
595
output_buffer.add_ (tensor_buffer).div_ (2 );
598
596
} else {
599
597
output_buffer.mul_ (pair_weight)
600
- .add_ (tensor_buffer. mul ( self_weight) );
598
+ .add_ (tensor_buffer, self_weight);
601
599
}
602
600
MaybeCopyBufferBack (output, output_buffer);
603
601
}));
0 commit comments