Use NCCL to allreduce batch sizes
Masahiro Tanaka committed Jan 25, 2022
1 parent 749f226 commit 3cd7e70
Showing 6 changed files with 45 additions and 3 deletions.
5 changes: 5 additions & 0 deletions src/comm/NCCLWrapper.cpp
@@ -327,6 +327,11 @@ void NCCLWrapper::allreduceMin(
     doAllreduce(tag, tensors, ncclMin);
 }
 
+void NCCLWrapper::allreduceMax(
+        int tag, const std::vector<at::Tensor>& tensors) {
+    doAllreduce(tag, tensors, ncclMax);
+}
+
 void NCCLWrapper::reduce(
         int tag, const std::vector<at::Tensor>& tensors,
         const std::vector<int>& roots) {
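
Note: doAllreduce evidently dispatches the given ncclRedOp_t to NCCL's allreduce. For reference, a minimal standalone sketch of an in-place max-allreduce of a single int64 at the raw NCCL level — the function name is illustrative and the communicator/stream setup is omitted:

#include <cuda_runtime.h>
#include <nccl.h>

// In-place allreduce of one int64 value that already resides on the GPU.
// After the call, every rank's dev_val holds the maximum over all ranks.
void allreduceMaxInt64(int64_t* dev_val, ncclComm_t comm, cudaStream_t stream) {
    ncclAllReduce(dev_val, dev_val, /*count=*/1, ncclInt64, ncclMax, comm, stream);
    cudaStreamSynchronize(stream);
}
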
1 change: 1 addition & 0 deletions src/comm/NCCLWrapper.h
@@ -49,6 +49,7 @@ class NCCLWrapper {
 
     void allreduce(int tag, const std::vector<at::Tensor>& tensors);
     void allreduceMin(int tag, const std::vector<at::Tensor>& tensors);
+    void allreduceMax(int tag, const std::vector<at::Tensor>& tensors);
     void reduce(
             int tag, const std::vector<at::Tensor>& tensors,
             const std::vector<int>& roots);
28 changes: 28 additions & 0 deletions src/comm/SComm.cpp
@@ -665,6 +665,34 @@ MPI_Comm SComm::getCommunicator(int tag, const std::unordered_set<int>& ranks) {
     return *comm_map_.at(tag);
 }
 
+int64_t allReduceBatchSize(
+        int64_t batch_size, const std::function<void(int, at::Tensor)>& f) {
+    at::TensorOptions options =
+            torch::TensorOptions().dtype(c10::ScalarType::Long);
+    at::Tensor ten = torch::from_blob(&batch_size, {}, options).cuda();
+
+    NCCLWrapper& nccl = NCCLWrapper::get();
+    TagMap& tag_map = TagMap::get();
+    int tag = tag_map.getRankSetTag(mpi::getAllRanks());
+    nccl.createCommunicator(tag, mpi::getAllRanks());
+    f(tag, {ten});
+    return ten.cpu().item<int64_t>();
+}
+
+int64_t SComm::allReduceSumBatchSize(int64_t batch_size) {
+    return allReduceBatchSize(batch_size, [](int tag, at::Tensor ten) {
+        NCCLWrapper& nccl = NCCLWrapper::get();
+        nccl.allreduce(tag, {ten});
+    });
+}
+
+int64_t SComm::allReduceMaxBatchSize(int64_t batch_size) {
+    return allReduceBatchSize(batch_size, [](int tag, at::Tensor ten) {
+        NCCLWrapper& nccl = NCCLWrapper::get();
+        nccl.allreduceMax(tag, {ten});
+    });
+}
+
 void SComm::destroy() {
     for (auto& c : comm_map_) {
         c.second.reset();
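
Note: the helper above round-trips the scalar through the GPU: it wraps batch_size in a 0-dimensional Long tensor with torch::from_blob, moves it to CUDA so NCCL can reduce it, runs the collective, and reads the result back on the CPU. A usage sketch of the two new entry points (local_batch_size is an assumed caller-side variable; the rank set must already be initialized):

SComm& scomm = SComm::get();

// Sum of the per-rank batch sizes, e.g. to compute the global batch size.
int64_t global_batch = scomm.allReduceSumBatchSize(local_batch_size);

// Maximum of the per-rank batch sizes, e.g. to pad an uneven last batch.
int64_t max_batch = scomm.allReduceMaxBatchSize(local_batch_size);
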
3 changes: 3 additions & 0 deletions src/comm/SComm.h
@@ -104,6 +104,9 @@ class SComm {
             const torch::jit::IValue& tensor, const RouteDP& route, bool is_bwd,
             const IRType& global_type, int split_delay = 0);
 
+    int64_t allReduceSumBatchSize(int64_t batch_size);
+    int64_t allReduceMaxBatchSize(int64_t batch_size);
+
     MPI_Comm getCommunicator(int tag, const std::unordered_set<int>& ranks);
 
     void destroy();
8 changes: 6 additions & 2 deletions src/comp/GraphLauncher.cpp
@@ -401,7 +401,9 @@ torch::jit::IValue GraphLauncher::forward(
     IValueMap pad_inputs;
     int64_t global_batch_size;
     if (gather_inputs_) {
-        int64_t max_local_batch_size = mpi::allReduceMaxBatchSize(input_batch_size);
+        SComm& scomm = SComm::get();
+        int64_t max_local_batch_size =
+                scomm.allReduceMaxBatchSize(input_batch_size);
         last_batch_size_ = max_local_batch_size;
         global_batch_size = max_local_batch_size * mpi::getSize();
         pad_inputs =
@@ -472,7 +474,9 @@ IValueMap GraphLauncher::backward(
     IValueMap scaled_inputs;
     int64_t global_batch_size;
     if (gather_inputs_) {
-        int64_t max_local_batch_size = mpi::allReduceMaxBatchSize(input_batch_size);
+        SComm& scomm = SComm::get();
+        int64_t max_local_batch_size =
+                scomm.allReduceMaxBatchSize(input_batch_size);
         const auto pad_inputs =
                 alignBatch(inputs, max_local_batch_size, deployment_.graph, true);
         global_batch_size = max_local_batch_size * mpi::getSize();
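
Note: both call sites follow the same pattern. As a worked example with illustrative numbers: on 3 ranks whose final batches hold 32, 32, and 17 samples, allReduceMaxBatchSize returns 32 on every rank, alignBatch pads each rank's inputs up to 32, and global_batch_size = 32 * mpi::getSize() = 96.
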
3 changes: 2 additions & 1 deletion src/comp/RaNNCModule.cpp
@@ -160,8 +160,9 @@ std::vector<long> RaNNCModule::init(
     std::vector<torch::jit::IValue> input_ivals =
             torch::jit::_toTypeInferredIValue(args).toTuple()->elements();
     int64_t local_batch_size = guessBatchSize(input_ivals);
+    SComm& scomm = SComm::get();
     int64_t batch_size = gather_inputs
-            ? scomm.allReduceSumBatchSize(local_batch_size)
+            ? scomm.allReduceSumBatchSize(local_batch_size)
            : local_batch_size;
 
     config::Config& conf = config::Config::get();
