
Commit 2f696ed

Hanbin Hu and BichengYing authored
Improve neighbor allreduce (#78)
* Fixed the self_weight under empty receiving case
* Enable empty send neighbors and fix HalfTensor for recv_size==0
* Rename neighbor_weights to src_weights, and send_neighbors to dst_weights for neighbor_allreduce
* A script to test existing examples
* Accept dst_weights as Dict, and reorganize DoNeighborAllreduce
* Reorganize CheckNeighborSendRecvPattern
* Fix timeline_ptr for NCCL
* Put dst_weights information into TensorTableEntry
* First version of neighbor_allreduce dst_weight; known issues: fusion not implemented, CUDA data_weight problem
* Add some delay after data_weight as a temporary solution
* CPU fusion for dst_weighted added
* Add ReadyEvent for dst_weight for single-entry neighbor_allreduce
* Remove const identifier for tensor dtype as it is meaningless
* Add CUDA source for scale buffer
* Scale buffer to modify itself
* Add .o files to .gitignore
* dst_weight using CUDA for fused entry & compile flow in Python setup.py
* make clean: also remove *.o files generated by nvcc
* Add fix for NCCL single entry
* Make setup.py more robust
* Add timeout and CUDA check
* Move test example
* Fix NCCL-side dst_weight fusion bug
* Add agg backend to make matplotlib more stable
* Address comments for setup.py
* Simpler logic for dst_weighting_enabled and weighted_average_computation
* Better consideration for weight buffer size
* Make src_weights a std::map, and simplify logic for PerformNeighborAllreduceCallback
* Add TODO #80 and #81, and simplify the logic for dst_weight
* Wrap CheckNeighborSendRecvPattern again
* Add two more TODOs
* Address review comments
* Add condition variable to control the loop (#88)
* Minor update on topology_setting in global_state
* Add missing <condition_variable> header
* Change cv.wait to cv.wait_for 10 seconds
* Address comment and remove adjusting resetVersionWinMem in ibfrun

Co-authored-by: ybc <[email protected]>
1 parent 8bde896 commit 2f696ed

34 files changed: +1233 -508 lines changed

.gitignore

+1

@@ -5,6 +5,7 @@ __pycache__/
 
 # C extensions
 *.so
+*.o
 
 # Distribution / packaging
 .Python

Makefile

+7 -3

@@ -19,7 +19,7 @@ test_torch: test_torch_basic test_torch_ops test_torch_win_ops test_torch_optimi
 test_tensorflow: test_tensorflow_basic test_tensorflow_ops
 test_all: test_torch test_tensorflow
 
-clean: clean_build clean_so
+clean: clean_build clean_so clean_o
 
 .PHONY: test_torch_basic
 test_torch_basic:
@@ -51,8 +51,12 @@ test_tensorflow_ops:
 
 .PHONY: clean_build
 clean_build:
-	rm -R build
+	rm -fR build
 
 .PHONY: clean_so
 clean_so:
-	rm ./bluefog/torch/mpi_lib.*.so
+	rm -f ./bluefog/torch/mpi_lib.*.so
+
+.PHONY: clean_o
+clean_o:
+	rm -f ./bluefog/common/cuda/*.o

bluefog/common/common.h

+7 -2

@@ -209,10 +209,10 @@ class TensorShape {
 
 class Tensor {
  public:
-  virtual const DataType dtype() const = 0;
+  virtual DataType dtype() const = 0;
   virtual const TensorShape shape() const = 0;
   virtual const void* data() const = 0;
-  virtual std::shared_ptr<common::Tensor> data_weight(float weight) = 0;
+  virtual std::unique_ptr<common::Tensor> data_weight(float weight) = 0;
   virtual int64_t size() const = 0;
   virtual ~Tensor() = default;
 };
@@ -241,6 +241,7 @@ class OpContext {
                                std::shared_ptr<Tensor>* tensor) = 0;
   virtual Status AllocateZeros(int64_t num_elements, DataType dtype,
                                std::shared_ptr<Tensor>* tensor) = 0;
+  virtual std::shared_ptr<ReadyEvent> RecordReadyEvent(int device) = 0;
   virtual Framework framework() const = 0;
   virtual ~OpContext() = default;
 };
@@ -279,10 +280,14 @@ struct TensorTableEntry {
   // Neighbors for dynamic neighbor_allreduce.
   std::shared_ptr<std::vector<int>> send_neighbors;
   std::shared_ptr<std::vector<int>> recv_neighbors;
+  std::shared_ptr<std::vector<double>> send_weights;
 
   // Boolean value if dynamic neighbor is enabled.
   bool dynamic_neighbors_enabled = false;
 
+  // Boolean value for enabling destination(send) weighting operation or not.
+  bool dst_weighting_enabled = false;
+
   // Boolean value for enabling topology check.
   bool enable_topo_check = false;
 
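
For orientation, here is a minimal sketch of how the new TensorTableEntry fields might be filled for a dynamic neighbor_allreduce with destination weighting. The field names come from the struct above; the helper function, the include path, and the concrete neighbor ranks and weights are illustrative assumptions, not taken from this commit.

#include <memory>
#include <vector>

#include "bluefog/common/common.h"  // assumed include path

// Hypothetical call site; only the TensorTableEntry fields are from common.h.
void FillDynamicNeighborInfo(bluefog::common::TensorTableEntry& entry) {
  // Ranks this process sends to and receives from (illustrative values).
  entry.send_neighbors =
      std::make_shared<std::vector<int>>(std::vector<int>{1, 3});
  entry.recv_neighbors =
      std::make_shared<std::vector<int>>(std::vector<int>{2, 4});
  // One weight per send neighbor; the outgoing copy for send_neighbors[i]
  // is scaled by send_weights[i] before it is transmitted.
  entry.send_weights =
      std::make_shared<std::vector<double>>(std::vector<double>{0.5, 0.5});
  entry.dynamic_neighbors_enabled = true;
  entry.dst_weighting_enabled = true;
}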

bluefog/common/cuda/cuda_kernels.cu

+120

@@ -0,0 +1,120 @@
+// Copyright (C) 2020 NVIDIA CORPORATION. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "cuda_kernels.h"
+
+#include <stdexcept>
+#include <cuda_fp16.h>
+
+namespace bluefog {
+namespace common {
+
+template<typename T, typename TS>
+__global__ void scale_buffer_k(T* buffer, int64_t num_elements, const TS scale_factor) {
+
+  const size_t idx = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
+
+  for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) {
+    buffer[i] *= scale_factor;
+  }
+}
+
+// Specialization for half2
+__global__ void scale_buffer_half2_k(__half* buffer, int64_t num_elements, const __half scale_factor) {
+
+  const size_t idx = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
+
+#if __CUDA_ARCH__ > 530
+  __half2* buffer_h2 = reinterpret_cast<__half2 *>(buffer);
+  const __half2 scale_factor_h2 = __halves2half2(scale_factor, scale_factor);
+
+  for (size_t i = idx; i < num_elements / 2; i += gridDim.x * blockDim.x) {
+    buffer_h2[i] = __hmul2(scale_factor_h2, buffer_h2[i]);
+  }
+
+  // Deal with last element if num_elements is odd
+  if (idx == 0 && num_elements % 2) {
+    buffer[num_elements - 1] = __hmul(scale_factor, buffer[num_elements - 1]);
+  }
+#else
+  for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) {
+    buffer[i] = __float2half(__half2float(scale_factor) * __half2float(buffer[i]));
+  }
+#endif
+}
+
+// Specialization for architectures without __half compute
+template<>
+__global__ void scale_buffer_k(__half* buffer, int64_t num_elements, const __half scale_factor) {
+
+  const size_t idx = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
+
+#if __CUDA_ARCH__ > 530
+  for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) {
+    buffer[i] *= scale_factor;
+  }
+#else
+  for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) {
+    buffer[i] = __float2half(__half2float(scale_factor) * __half2float(buffer[i]));
+  }
+#endif
+}
+
+#define NTHREADS_SCALE_BUFFER_KERNEL 512
+void ScaleBufferCudaImpl(double scale_factor, void* buffer_data, const int64_t num_elements,
+                         DataType dtype, cudaStream_t stream) {
+  const int64_t blocks = (num_elements + NTHREADS_SCALE_BUFFER_KERNEL - 1) / NTHREADS_SCALE_BUFFER_KERNEL;
+  const int threads = NTHREADS_SCALE_BUFFER_KERNEL;
+  switch (dtype) {
+    case DataType::BLUEFOG_UINT8:
+      scale_buffer_k<<<blocks, threads, 0, stream>>>((uint8_t*) buffer_data, num_elements, scale_factor);
+      break;
+    case DataType::BLUEFOG_INT8:
+      scale_buffer_k<<<blocks, threads, 0, stream>>>((int8_t*) buffer_data, num_elements, scale_factor);
+      break;
+    case DataType::BLUEFOG_INT32:
+      scale_buffer_k<<<blocks, threads, 0, stream>>>((int32_t*) buffer_data, num_elements, scale_factor);
+      break;
+    case DataType::BLUEFOG_INT64:
+      scale_buffer_k<<<blocks, threads, 0, stream>>>((int64_t*) buffer_data, num_elements, scale_factor);
+      break;
+    case DataType::BLUEFOG_FLOAT16:
+    {
+      __half scale_factor_half = __float2half((float) scale_factor);
+      if ((size_t) buffer_data % 4 == 0) {
+        // If alignment allows, use half2 specialized kernel
+        int64_t num_elements_h2 = (num_elements + 1) / 2;
+        int64_t blocks_h2 = (num_elements_h2 + NTHREADS_SCALE_BUFFER_KERNEL - 1) / NTHREADS_SCALE_BUFFER_KERNEL;
+        scale_buffer_half2_k<<<blocks_h2, threads, 0, stream>>>((__half*) buffer_data, num_elements, scale_factor_half);
+      } else {
+        scale_buffer_k<<<blocks, threads, 0, stream>>>((__half*) buffer_data, num_elements, scale_factor_half);
+      }
+      break;
+    }
+    case DataType::BLUEFOG_FLOAT32:
+      scale_buffer_k<<<blocks, threads, 0, stream>>>((float*) buffer_data, num_elements, (float) scale_factor);
+      break;
+    case DataType::BLUEFOG_FLOAT64:
+      scale_buffer_k<<<blocks, threads, 0, stream>>>((double*) buffer_data, num_elements, scale_factor);
+      break;
+    default:
+      throw std::logic_error("Type " + DataType_Name(dtype) +
+                             " not supported by ScaleBufferCudaImpl.");
+  }
+}
+
+} // namespace common
+} // namespace bluefog

bluefog/common/cuda/cuda_kernels.h

+33

@@ -0,0 +1,33 @@
+// Copyright (C) 2020 NVIDIA CORPORATION. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef CUDA_KERNELS_H
+#define CUDA_KERNELS_H
+
+#include <cuda_runtime.h>
+
+#include "../common.h"
+
+namespace bluefog {
+namespace common {
+
+// Scales buffer by scalar
+void ScaleBufferCudaImpl(double scale_factor, void* buffer_data, const int64_t num_elements,
+                         DataType dtype, cudaStream_t stream);
+
+} // namespace common
+} // namespace bluefog
+
+#endif // CUDA_KERNELS_H
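
A small usage sketch of the new API, assuming the caller already holds a device pointer and the CUDA stream created in MPIContext::Initialize further below: before a weighted send, the fused copy of the outgoing data is scaled in place by the destination weight. The helper name ScaleOutgoingCopy, the include path, and the explicit synchronization are illustrative; only ScaleBufferCudaImpl and its signature come from this commit.

#include <cstdint>

#include "bluefog/common/cuda/cuda_kernels.h"  // assumed include path

// Illustrative helper: scale the outgoing device copy for one destination.
void ScaleOutgoingCopy(void* send_buf, int64_t num_elements,
                       bluefog::common::DataType dtype, double dst_weight,
                       cudaStream_t stream) {
  // The kernel is enqueued on `stream`; for FLOAT16 it takes the half2 path
  // when the buffer is 4-byte aligned (see cuda_kernels.cu above).
  bluefog::common::ScaleBufferCudaImpl(dst_weight, send_buf, num_elements,
                                       dtype, stream);
  // The MPI/NCCL send must not start before the scaling finishes; a plain
  // synchronize is the simplest (if not the fastest) way to guarantee that.
  cudaStreamSynchronize(stream);
}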

bluefog/common/global_state.h

+11 -3

@@ -18,6 +18,7 @@
 #define BLUEFOG_COMMON_GLOBAL_STATE_H
 
 #include <atomic>
+#include <condition_variable>
 #include <chrono>
 #include <memory>
 #include <queue>
@@ -54,6 +55,14 @@ struct BluefogGlobalState {
   // Whether collective context has been completed on the background thread.
   std::atomic_bool initialization_done{false};
 
+  // Condition variable and its mutex for main loop in communication thread.
+  std::condition_variable loop_cv;
+  std::mutex loop_mutex;
+
+  // Under negotiation, the entries sends to master first and wait until it
+  // returns ok to run. This variable keeps the records of that.
+  std::atomic_int unfinished_enqueued_entries{0};
+
   // Timeline writer.
   Timeline timeline;
 
@@ -80,13 +89,12 @@ struct BluefogGlobalState {
   // Threshold for Tensor Fusion. All tensors that occupy memory beyond this
   // threshold will be fused.
   int64_t tensor_fusion_threshold = 8 * 1024 * 1024;
+  int64_t tensor_fusion_threshold_for_dst_weight = tensor_fusion_threshold;
   FusionBufferManager fusion_buffer;
 
   // Because setting topology happens in the main thread instead of communication
-  // thread. Following three variables are to sync between them.
+  // thread. Not really used since the condition variable refactor.
   std::atomic_bool setting_topology{false};
-  std::atomic_bool setting_topology_done{false};
-  std::atomic_bool ready_to_setting_topology{false};
 
   // Only exists on the coordinator node (rank zero). Maintains a vector of
   // requests to allreduce every tensor (keyed by tensor name).
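
A sketch of the intended control flow for loop_cv, based on the commit-message items "Add condition variable to control the loop" and "Change cv.wait to cv.wait_for 10 seconds". The function names and the exact wake-up predicate are assumptions; the real loop lives in the communication-thread code, which is not shown on this page.

#include <chrono>
#include <condition_variable>
#include <mutex>

#include "bluefog/common/global_state.h"  // assumed include path

// Producer side (hypothetical): when an entry is enqueued for negotiation,
// bump the counter and wake the background loop.
void NotifyNewEntry(bluefog::common::BluefogGlobalState& state) {
  state.unfinished_enqueued_entries.fetch_add(1);
  state.loop_cv.notify_one();
}

// Consumer side (hypothetical): instead of spinning, the background loop
// sleeps until there is work or a 10-second timeout elapses.
void WaitForWork(bluefog::common::BluefogGlobalState& state) {
  std::unique_lock<std::mutex> lock(state.loop_mutex);
  state.loop_cv.wait_for(lock, std::chrono::seconds(10), [&state] {
    return state.unfinished_enqueued_entries.load() > 0;
  });
}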

bluefog/common/mpi_context.cc

+9 -2

@@ -75,7 +75,7 @@ bool WindowManager::InitializeMutexWin(const MPI_Comm& mpi_comm) {
 std::vector<int> WindowManager::GetVersionMemoryCopy() { return version_mem_; }
 
 void WindowManager::resetVersionWinMem(int initialValue /*=0*/) {
-  for (int i = 0; i < version_mem_.size(); i++) {
+  for (size_t i = 0; i < version_mem_.size(); i++) {
     version_mem_[i] = initialValue;
   }
 }
@@ -222,7 +222,7 @@ MPI_Op MPIContext::GetMPISumOp(DataType dtype) {
   return dtype == DataType::BLUEFOG_FLOAT16 ? mpi_float16_sum : MPI_SUM;
 }
 
-MPI_Comm MPIContext::GetMPICommunicator(Communicator comm) {
+MPI_Comm MPIContext::GetMPICommunicator(Communicator comm) const {
   switch (comm) {
     case Communicator::GLOBAL:
       return mpi_comm;
@@ -332,6 +332,13 @@ void MPIContext::Initialize(const std::vector<int>& ranks,
 
   // Create custom MPI float16 summation op.
   MPI_Op_create(&float16_sum, 1, &mpi_float16_sum);
+
+#if HAVE_CUDA
+  int greatest_priority;
+  CUDACHECK(cudaDeviceGetStreamPriorityRange(NULL, &greatest_priority));
+  CUDACHECK(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking,
+                                         greatest_priority));
+#endif
 }
 
 void MPIContext::Finalize(MPIContextManager& ctx_manager) {

bluefog/common/mpi_context.h

+14 -1

@@ -22,6 +22,10 @@
 #include <unordered_map>
 #include <vector>
 
+#if HAVE_CUDA
+#include "cuda_runtime.h"
+#endif
+
 #include "common.h"
 #include "mpi.h"
 
@@ -144,7 +148,7 @@ class MPIContext {
 
   MPI_Op GetMPISumOp(DataType dtype);
 
-  MPI_Comm GetMPICommunicator(Communicator comm);
+  MPI_Comm GetMPICommunicator(Communicator comm) const;
 
   int GetMPITypeSize(DataType dtype);
 
@@ -232,8 +236,17 @@ class MPIContext {
   // MPI Custom data type for float16.
   MPI_Datatype mpi_float16_t;
   MPI_Op mpi_float16_sum;
+
+  // TODO(hhb): #80 We should use a common context for MPI and NCCL controller for CUDA usage.
+#if HAVE_CUDA
+  // CUDA Stream
+  cudaStream_t stream;
+#endif
 };
 
+std::string GenerateNeighborExchangeErrorMessage(const std::vector<MPI_Status>& statuses,
+                                                 int nsend, int nrecv);
+
 } // namespace common
 } // namespace bluefog
 
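
A hedged sketch of a plausible call site for the new GenerateNeighborExchangeErrorMessage declaration: after MPI_Waitall over the neighbor send/receive requests, a failing return code is turned into a readable error. The wrapper function, the request bookkeeping, and the include path are assumptions; only the helper's declaration comes from mpi_context.h above.

#include <stdexcept>
#include <vector>

#include <mpi.h>

#include "bluefog/common/mpi_context.h"  // assumed include path

// Hypothetical wrapper around the wait step of a neighbor exchange.
// `requests` holds the nsend MPI_Isend requests followed by the nrecv
// MPI_Irecv requests, so statuses is sized nsend + nrecv as the helper expects.
void WaitNeighborExchange(std::vector<MPI_Request>& requests,
                          int nsend, int nrecv) {
  std::vector<MPI_Status> statuses(requests.size());
  int ret = MPI_Waitall(static_cast<int>(requests.size()), requests.data(),
                        statuses.data());
  if (ret != MPI_SUCCESS) {
    throw std::runtime_error(
        bluefog::common::GenerateNeighborExchangeErrorMessage(statuses, nsend,
                                                              nrecv));
  }
}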
