diff --git a/tensorpipe/benchmark/herring/benchmark_herring_gdr.cc b/tensorpipe/benchmark/herring/benchmark_herring_gdr.cc new file mode 100644 index 000000000..46acb4e9c --- /dev/null +++ b/tensorpipe/benchmark/herring/benchmark_herring_gdr.cc @@ -0,0 +1,832 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cuda_kernels.cuh" + +namespace { + +int64_t deltaAsUs( + std::chrono::steady_clock::time_point start, + std::chrono::steady_clock::time_point stop) { + return std::chrono::duration_cast(stop - start) + .count(); +} + +template +T ceilOfRatio(T num, T den) { + return (num - 1) / den + 1; +} + +class CallbackBarrier { + public: + CallbackBarrier() = default; + + template + auto wrapCallback(T fn) { + return wrapTask( + [this, fn{std::move(fn)}]( + const tensorpipe::Error& error, auto&&... args) mutable { + if (error) { + LOG(ERROR) << error.what(); + std::unique_lock lock(mutex_); + if (!anyError_) { + anyError_ = error; + } + } else { + fn(std::forward(args)...); + } + }); + } + + template + auto wrapTask(T fn) { + { + std::unique_lock lock(mutex_); + numPendingCallbacks_ += 1; + } + return [this, fn{std::move(fn)}](auto&&... args) mutable { + fn(std::forward(args)...); + std::unique_lock lock(mutex_); + numPendingCallbacks_ -= 1; + cv_.notify_all(); + }; + } + + void notifyExternalEventHappened() { + std::unique_lock lock(mutex_); + numExternalEvents_ += 1; + cv_.notify_all(); + } + + void waitForNextExternalEvent() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&]() { + return numPendingCallbacks_ == 0 || numExternalEvents_ > 0; + }); + if (anyError_) { + throw std::runtime_error(anyError_.what()); + } + if (numExternalEvents_ == 0) { + throw std::runtime_error( + "All callbacks terminated before an external event occurred"); + } + numExternalEvents_ -= 1; + } + + void join() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&]() { + return numPendingCallbacks_ == 0 || numExternalEvents_ > 0; + }); + if (anyError_) { + throw std::runtime_error(anyError_.what()); + } + if (numExternalEvents_ > 0) { + throw std::runtime_error( + "An external event occurred while waiting for callbacks to terminate"); + } + } + + ~CallbackBarrier() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&]() { return numPendingCallbacks_ == 0; }); + } + + private: + std::mutex mutex_; + std::condition_variable cv_; + tensorpipe::Error anyError_ = tensorpipe::Error::kSuccess; + size_t numPendingCallbacks_ = 0; + size_t numExternalEvents_ = 0; +}; + +#define CUDA_CHECK(op) \ + { \ + cudaError_t res = (op); \ + if (res != cudaSuccess) { \ + throw std::runtime_error("CUDA error"); \ + } \ + } + +#define NCCL_CHECK(op) \ + { \ + ncclResult_t res = (op); \ + if (res != ncclSuccess) { \ + throw std::runtime_error("NCCL error"); \ + } \ + } + +struct NcclCommDeleter { + void operator()(ncclComm_t comm) { + NCCL_CHECK(ncclCommDestroy(comm)); + } +}; + +using NcclComm = + std::unique_ptr, NcclCommDeleter>; + +NcclComm createNcclComm(int rank, int worldSize, ncclUniqueId uniqueId) { + ncclComm_t comm; + NCCL_CHECK(ncclCommInitRank(&comm, worldSize, uniqueId, rank)); + return NcclComm(comm, NcclCommDeleter{}); +} + +std::shared_ptr createTensorPipeContext(std::string name) { + auto ctx = 
std::make_shared( + tensorpipe::ContextOptions().name(std::move(name))); + ctx->registerTransport(0, "ibv", tensorpipe::transport::ibv::create()); + ctx->registerChannel(0, "cuda_gdr", tensorpipe::channel::cuda_gdr::create()); + return ctx; +} + +// We need this extra named namespace inside our unnamed namespace because of +// https://github.com/pybind/pybind11/issues/3289 +namespace benchmark_herring_gdr { + +struct ServerStats { + struct EpochStats { + struct BucketStats { + struct MachineStats { + int64_t additionTime = 0; + int64_t recvToSendTime = 0; + }; + + std::vector machines; + + explicit BucketStats(size_t numMachines) : machines(numMachines) {} + }; + + std::vector buckets; + + explicit EpochStats(size_t numBuckets, size_t numMachines) + : buckets(numBuckets, BucketStats(numMachines)) {} + }; + + std::vector epochs; + + explicit ServerStats(size_t numEpochs, size_t numBuckets, size_t numMachines) + : epochs(numEpochs, EpochStats(numBuckets, numMachines)) {} +}; + +class Server { + public: + Server( + size_t machineIdx, + size_t deviceIdx, + size_t numMachines, + size_t numDevicesPerMachine, + size_t numBuckets, + size_t bucketSize, + size_t numEpochs, + c10::intrusive_ptr store) + : machineIdx_(machineIdx), + deviceIdx_(deviceIdx), + numMachines_(numMachines), + numBuckets_(numBuckets), + sliceLen_( + (bucketSize / numDevicesPerMachine) * (machineIdx_ + 1) / + numMachines_ - + (bucketSize / numDevicesPerMachine) * machineIdx_ / numMachines_), + numEpochs_(numEpochs), + store_(std::move(store)), + context_(createTensorPipeContext("s" + std::to_string(machineIdx_))), + stats_(numBuckets, numMachines), + recvTimes_( + numBuckets, + std::vector( + numMachines, + std::chrono::steady_clock::time_point())) {} + + ServerStats run() { + allocateTensors(); + startListening(); + waitForIncomingPipes(); + ServerStats stats(numEpochs_, numBuckets_, numMachines_); + for (size_t epochIdx = 0; epochIdx < numEpochs_; epochIdx += 1) { + setTensorsToZero(); + runOneEpoch(); + stats.epochs[epochIdx] = stats_; + } + + // @nocommit + // Ugly hack to prevent the server's TP context from shutting down before + // the clients have received all the data. + std::this_thread::sleep_for(std::chrono::seconds(5)); + + return stats; + } + + private: + const size_t machineIdx_; + const size_t deviceIdx_; + const size_t numMachines_; + const size_t numBuckets_; + const size_t sliceLen_; + const size_t numEpochs_; + const c10::intrusive_ptr store_; + const std::shared_ptr context_; + std::shared_ptr listener_; + std::vector> pipes_; + std::vector buckets_; + std::vector> stagingTensors_; + ServerStats::EpochStats stats_; + std::vector> recvTimes_; + + void allocateTensors() { + buckets_.reserve(numBuckets_); + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + buckets_.push_back(torch::empty( + sliceLen_, + c10::TensorOptions() + .dtype(c10::kFloat) + .device(c10::Device(c10::kCUDA, 0)))); + } + + stagingTensors_.resize(numMachines_); + for (size_t machineIdx = 0; machineIdx < numMachines_; machineIdx += 1) { + stagingTensors_[machineIdx].reserve(numBuckets_); + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + stagingTensors_[machineIdx].push_back(torch::empty( + sliceLen_, + c10::TensorOptions() + .dtype(c10::kFloat) + .device(c10::Device(c10::kCUDA, 0)))); + } + } + } + + void startListening() { + tensorpipe::Error error; + std::string address; + const char* iface = std::getenv("TP_SOCKET_IFNAME"); + std::tie(error, address) = iface != nullptr + ? 
tensorpipe::transport::ibv::lookupAddrForIface(std::string(iface)) + : tensorpipe::transport::ibv::lookupAddrForHostname(); + if (error) { + throw std::runtime_error(error.what()); + } + listener_ = context_->listen({ + "ibv://" + std::move(address), + }); + + std::string key = "machines/" + std::to_string(machineIdx_) + "/servers/" + + std::to_string(deviceIdx_) + "/address"; + std::string concreteAddress = listener_->url("ibv"); + store_->set( + key, + std::vector(concreteAddress.begin(), concreteAddress.end())); + } + + void waitForIncomingPipes() { + CallbackBarrier barrier; + + pipes_.resize(numMachines_); + for (size_t clientMachineIdx = 0; clientMachineIdx < numMachines_; + clientMachineIdx += 1) { + listener_->accept(barrier.wrapCallback( + [&, this](std::shared_ptr pipe) { + int otherClientMachineIdx = std::strtol( + pipe->getRemoteName().c_str() + 1, nullptr, /*base=*/10); + pipes_[otherClientMachineIdx] = std::move(pipe); + })); + } + + barrier.join(); + } + + void setTensorsToZero() { + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + buckets_[bucketIdx].fill_(0); + } + } + + void runOneEpoch() { + c10::cuda::CUDAStream recvStream = + c10::cuda::getStreamFromPool(/*isHighPriority=*/true, /*device=*/0); + std::vector computeStreams; + computeStreams.reserve(numMachines_); + for (size_t otherMachineIdx = 0; otherMachineIdx < numMachines_; + otherMachineIdx += 1) { + computeStreams.push_back( + c10::cuda::getStreamFromPool(/*isHighPriority=*/true, /*device=*/0)); + } + c10::cuda::CUDAStream sendStream = + c10::cuda::getStreamFromPool(/*isHighPriority=*/true, /*device=*/0); + + std::vector> events; + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + events.emplace_back(numMachines_); + } + + CallbackBarrier barrier; + + std::mutex mutex; + std::vector numClientsDoneForBucket(numBuckets_, 0); + std::vector hasBucketBeenSent(numBuckets_, false); + + for (size_t machineIdx = 0; machineIdx < numMachines_; machineIdx += 1) { + tensorpipe::Pipe& pipe = *pipes_[machineIdx]; + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + torch::Tensor& bucket = buckets_[bucketIdx]; + torch::Tensor& stagingTensor = stagingTensors_[machineIdx][bucketIdx]; + + pipe.readDescriptor(barrier.wrapCallback( + [&, machineIdx, bucketIdx](tensorpipe::Descriptor /* unused */) { + tensorpipe::Allocation allocation; + allocation.tensors.resize(1); + allocation.tensors[0].buffer = tensorpipe::CudaBuffer{ + .ptr = stagingTensor.data_ptr(), + .stream = recvStream.stream(), + }; + pipe.read( + std::move(allocation), + barrier.wrapCallback([&, machineIdx, bucketIdx]() { + recvTimes_[bucketIdx][machineIdx] = + std::chrono::steady_clock::now(); + { + at::cuda::CUDAEvent event; + event.record(recvStream); + event.block(computeStreams[machineIdx]); + } + { + std::chrono::steady_clock::time_point additionStartTime = + std::chrono::steady_clock::now(); + atomicAddInto( + buckets_[bucketIdx].data_ptr(), + stagingTensor.data_ptr(), + bucket.numel(), + computeStreams[machineIdx].stream()); + stats_.buckets[bucketIdx] + .machines[machineIdx] + .additionTime = deltaAsUs( + additionStartTime, std::chrono::steady_clock::now()); + } + events[bucketIdx][machineIdx].record( + computeStreams[machineIdx]); + std::unique_lock lock(mutex); + numClientsDoneForBucket[bucketIdx] += 1; + for (size_t otherBucketIdx = 0; + otherBucketIdx < numBuckets_; + otherBucketIdx += 1) { + if (hasBucketBeenSent[otherBucketIdx]) { + continue; + } + if (numClientsDoneForBucket[otherBucketIdx] 
< + numMachines_) { + break; + } + for (size_t otherMachineIdx = 0; + otherMachineIdx < numMachines_; + otherMachineIdx += 1) { + events[bucketIdx][otherMachineIdx].block(sendStream); + } + for (size_t otherMachineIdx = 0; + otherMachineIdx < numMachines_; + otherMachineIdx += 1) { + tensorpipe::Pipe& pipe = *pipes_[otherMachineIdx]; + tensorpipe::Message message; + message.tensors.resize(1); + message.tensors[0] = { + .buffer = + tensorpipe::CudaBuffer{ + .ptr = bucket.data_ptr(), + .stream = sendStream.stream(), + }, + .length = bucket.nbytes(), + .targetDevice = tensorpipe::Device( + tensorpipe::kCudaDeviceType, 0), + }; + stats_.buckets[bucketIdx] + .machines[otherMachineIdx] + .recvToSendTime = deltaAsUs( + recvTimes_[bucketIdx][otherMachineIdx], + std::chrono::steady_clock::now()); + pipe.write( + std::move(message), barrier.wrapCallback([]() {})); + } + hasBucketBeenSent[otherBucketIdx] = true; + } + })); + })); + } + } + + barrier.join(); + } +}; + +struct ClientStats { + struct EpochStats { + struct BucketStats { + struct ServerStats { + int64_t transferTime = 0; + }; + + int64_t ncclReduceScatterTime = 0; + int64_t ncclAllGatherTime = 0; + std::vector servers; + + explicit BucketStats(size_t numServers) : servers(numServers) {} + }; + + int64_t endToEndTime = 0; + std::vector buckets; + + explicit EpochStats(size_t numBuckets, size_t numServers) + : buckets(numBuckets, BucketStats(numServers)) {} + }; + + std::vector epochs; + + explicit ClientStats(size_t numEpochs, size_t numBuckets, size_t numServers) + : epochs(numEpochs, EpochStats(numBuckets, numServers)) {} +}; + +class Client { + public: + Client( + size_t machineIdx, + size_t deviceIdx, + size_t numMachines, + size_t numDevicesPerMachine, + size_t numBuckets, + size_t bucketSize, + size_t numEpochs, + c10::intrusive_ptr store) + : machineIdx_(machineIdx), + deviceIdx_(deviceIdx), + numMachines_(numMachines), + numDevicesPerMachine_(numDevicesPerMachine), + numBuckets_(numBuckets), + bucketSize_(bucketSize), + numEpochs_(numEpochs), + store_(std::move(store)), + context_(createTensorPipeContext("c" + std::to_string(machineIdx))), + stats_(numBuckets, numMachines), + ncclAllGatherStartTimes_( + numBuckets, + std::chrono::steady_clock::time_point()) {} + + ClientStats run() { + allocateTensors(); + setUpNccl(); + connectToServers(); + ClientStats stats(numEpochs_, numBuckets_, numMachines_); + for (size_t epochIdx = 0; epochIdx < numEpochs_; epochIdx += 1) { + setTensorsToOne(); + runOneEpoch(); + checkTensors(); + stats.epochs[epochIdx] = stats_; + } + return stats; + } + + private: + const size_t machineIdx_; + const size_t deviceIdx_; + const size_t numMachines_; + const size_t numDevicesPerMachine_; + const size_t numBuckets_; + const size_t bucketSize_; + const size_t numEpochs_; + const c10::intrusive_ptr store_; + const std::shared_ptr context_; + std::vector> pipes_; + std::vector buckets_; + std::vector stagingTensors_; + NcclComm ncclComm_; + ClientStats::EpochStats stats_; + std::vector ncclAllGatherStartTimes_; + + void allocateTensors() { + buckets_.reserve(numBuckets_); + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + buckets_.push_back(torch::empty( + bucketSize_, + c10::TensorOptions() + .dtype(c10::kFloat) + .device(c10::Device(c10::kCUDA, 0)))); + } + + assert(bucketSize_ % numDevicesPerMachine_ == 0); + stagingTensors_.reserve(numBuckets_); + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + stagingTensors_.push_back(torch::empty( + bucketSize_ / 
numDevicesPerMachine_, + c10::TensorOptions() + .dtype(c10::kFloat) + .device(c10::Device(c10::kCUDA, 0)))); + } + } + + void setUpNccl() { + ncclUniqueId uniqueId; + if (deviceIdx_ == 0) { + NCCL_CHECK(ncclGetUniqueId(&uniqueId)); + store_->set( + "machines/" + std::to_string(machineIdx_) + "/nccl_id", + std::vector( + reinterpret_cast(&uniqueId), + reinterpret_cast(&uniqueId) + sizeof(ncclUniqueId))); + } else { + std::vector uniqueIdData = + store_->get("machines/" + std::to_string(machineIdx_) + "/nccl_id"); + std::memcpy(&uniqueId, uniqueIdData.data(), sizeof(ncclUniqueId)); + } + ncclComm_ = createNcclComm( + /*rank=*/deviceIdx_, + /*worldSize=*/numDevicesPerMachine_, + uniqueId); + } + + void connectToServers() { + pipes_.resize(numMachines_); + for (size_t otherMachineIdx = 0; otherMachineIdx < numMachines_; + otherMachineIdx += 1) { + std::vector addressData = store_->get( + "machines/" + std::to_string(otherMachineIdx) + "/servers/" + + std::to_string(deviceIdx_) + "/address"); + std::string address((char*)addressData.data(), addressData.size()); + pipes_[otherMachineIdx] = context_->connect(std::move(address)); + } + } + + void setTensorsToOne() { + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + buckets_[bucketIdx].fill_(1); + } + } + + void runOneEpoch() { + c10::cuda::CUDAStream stream = + c10::cuda::getStreamFromPool(/*isHighPriority=*/true, /*device=*/0); + + std::chrono::steady_clock::time_point start = + std::chrono::steady_clock::now(); + + CallbackBarrier barrier; + + std::vector reduceScatterEvents(numBuckets_); + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + NCCL_CHECK(ncclReduceScatter( + buckets_[bucketIdx].data_ptr(), + stagingTensors_[bucketIdx].data_ptr(), + bucketSize_ / numDevicesPerMachine_, + ncclFloat, + ncclSum, + ncclComm_.get(), + stream)); + reduceScatterEvents[bucketIdx].record(stream); + } + + std::mutex mutex; + std::vector numServersDoneForBucket(numBuckets_, 0); + std::vector allGatherEvents(numBuckets_); + + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + reduceScatterEvents[bucketIdx].synchronize(); + stats_.buckets[bucketIdx].ncclReduceScatterTime = + deltaAsUs(start, std::chrono::steady_clock::now()); + + torch::Tensor& stagingTensor = stagingTensors_[bucketIdx]; + for (size_t serverMachineIdx = 0; serverMachineIdx < numMachines_; + serverMachineIdx += 1) { + size_t startPos = (bucketSize_ / numDevicesPerMachine_) * + serverMachineIdx / numMachines_; + size_t endPos = (bucketSize_ / numDevicesPerMachine_) * + (serverMachineIdx + 1) / numMachines_; + torch::Tensor slice = + stagingTensor.slice(/*dim=*/0, /*start=*/startPos, /*end=*/endPos); + tensorpipe::Message message; + message.tensors.resize(1); + message.tensors[0] = { + .buffer = + tensorpipe::CudaBuffer{ + .ptr = slice.data_ptr(), + .stream = stream.stream(), + }, + .length = slice.nbytes(), + .targetDevice = tensorpipe::Device(tensorpipe::kCudaDeviceType, 0), + }; + tensorpipe::Pipe& pipe = *pipes_[serverMachineIdx]; + std::chrono::steady_clock::time_point transferStartTime = + std::chrono::steady_clock::now(); + pipe.write( + std::move(message), + barrier.wrapCallback( + [&, bucketIdx, serverMachineIdx, slice, transferStartTime]() { + pipe.readDescriptor(barrier.wrapCallback( + [&, + bucketIdx, + serverMachineIdx, + slice, + transferStartTime](tensorpipe::Descriptor /* unused */) { + tensorpipe::Allocation allocation; + allocation.tensors.resize(1); + allocation.tensors[0].buffer = tensorpipe::CudaBuffer{ + .ptr = 
slice.data_ptr(), + .stream = stream.stream(), + }; + pipe.read( + std::move(allocation), + barrier.wrapCallback([&, + bucketIdx, + serverMachineIdx, + transferStartTime]() { + stats_.buckets[bucketIdx] + .servers[serverMachineIdx] + .transferTime = deltaAsUs( + transferStartTime, + std::chrono::steady_clock::now()); + std::unique_lock lock(mutex); + numServersDoneForBucket[bucketIdx] += 1; + if (numServersDoneForBucket[bucketIdx] < + numMachines_) { + return; + } + ncclAllGatherStartTimes_[bucketIdx] = + std::chrono::steady_clock::now(); + NCCL_CHECK(ncclAllGather( + stagingTensors_[bucketIdx].data_ptr(), + buckets_[bucketIdx].data_ptr(), + bucketSize_ / numDevicesPerMachine_, + ncclFloat, + ncclComm_.get(), + stream)); + allGatherEvents[bucketIdx].record(stream); + barrier.notifyExternalEventHappened(); + })); + })); + })); + } + } + + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + barrier.waitForNextExternalEvent(); + allGatherEvents[bucketIdx].synchronize(); + stats_.buckets[bucketIdx].ncclAllGatherTime = deltaAsUs( + ncclAllGatherStartTimes_[bucketIdx], + std::chrono::steady_clock::now()); + } + + barrier.join(); + stream.synchronize(); + + stats_.endToEndTime = deltaAsUs(start, std::chrono::steady_clock::now()); + } + + void checkTensors() { + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + if (!buckets_[bucketIdx].allclose(torch::full( + {}, + static_cast(numMachines_ * numDevicesPerMachine_), + c10::TensorOptions() + .dtype(c10::kFloat) + .device(c10::Device(c10::kCUDA, 0))))) { + throw std::runtime_error("Bad result"); + } + } + } +}; + +} // namespace benchmark_herring_gdr +} // namespace + +namespace py = pybind11; + +template +using shared_ptr_class_ = py::class_>; + +PYBIND11_MODULE(benchmark_herring_gdr, module) { + shared_ptr_class_ server(module, "Server"); + shared_ptr_class_ client(module, "Client"); + + py::class_< + benchmark_herring_gdr::ServerStats::EpochStats::BucketStats::MachineStats> + serverStatsEpochBucketMachine(module, "ServerStatsEpochBucketMachine"); + serverStatsEpochBucketMachine.def_readonly( + "addition_time", + &benchmark_herring_gdr::ServerStats::EpochStats::BucketStats:: + MachineStats::additionTime); + serverStatsEpochBucketMachine.def_readonly( + "recv_to_send_time", + &benchmark_herring_gdr::ServerStats::EpochStats::BucketStats:: + MachineStats::recvToSendTime); + + py::class_ + serverStatsEpochBucket(module, "ServerStatsEpochBucket"); + serverStatsEpochBucket.def_readonly( + "machines", + &benchmark_herring_gdr::ServerStats::EpochStats::BucketStats::machines); + + py::class_ serverStatsEpoch( + module, "ServerStatsEpoch"); + serverStatsEpoch.def_readonly( + "buckets", &benchmark_herring_gdr::ServerStats::EpochStats::buckets); + + py::class_ serverStats( + module, "ServerStats"); + serverStats.def_readonly( + "epochs", &benchmark_herring_gdr::ServerStats::epochs); + + server.def( + py::init< + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + const c10::intrusive_ptr&>(), + py::arg("machine_idx"), + py::arg("device_idx"), + py::arg("num_machines"), + py::arg("num_devices_per_machine"), + py::arg("num_buckets"), + py::arg("bucket_size"), + py::arg("num_epochs"), + py::arg("store")); + server.def( + "run", + &benchmark_herring_gdr::Server::run, + py::call_guard()); + + py::class_< + benchmark_herring_gdr::ClientStats::EpochStats::BucketStats::ServerStats> + clientStatsEpochBucketMachine(module, "ClientStatsEpochBucketMachine"); + clientStatsEpochBucketMachine.def_readonly( + 
"transfer_time", + &benchmark_herring_gdr::ClientStats::EpochStats::BucketStats:: + ServerStats::transferTime); + + py::class_ + clientStatsEpochBucket(module, "ClientStatsEpochBucket"); + clientStatsEpochBucket.def_readonly( + "servers", + &benchmark_herring_gdr::ClientStats::EpochStats::BucketStats::servers); + clientStatsEpochBucket.def_readonly( + "nccl_all_gather_time", + &benchmark_herring_gdr::ClientStats::EpochStats::BucketStats:: + ncclAllGatherTime); + clientStatsEpochBucket.def_readonly( + "nccl_reduce_scatter_time", + &benchmark_herring_gdr::ClientStats::EpochStats::BucketStats:: + ncclReduceScatterTime); + + py::class_ clientStatsEpoch( + module, "ClientStatsEpoch"); + clientStatsEpoch.def_readonly( + "buckets", &benchmark_herring_gdr::ClientStats::EpochStats::buckets); + clientStatsEpoch.def_readonly( + "end_to_end_time", + &benchmark_herring_gdr::ClientStats::EpochStats::endToEndTime); + + py::class_ clientStats( + module, "ClientStats"); + clientStats.def_readonly( + "epochs", &benchmark_herring_gdr::ClientStats::epochs); + + client.def( + py::init< + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + const c10::intrusive_ptr&>(), + py::arg("machine_idx"), + py::arg("device_idx"), + py::arg("num_machines"), + py::arg("num_devices_per_machine"), + py::arg("num_buckets"), + py::arg("bucket_size"), + py::arg("num_epochs"), + py::arg("store")); + client.def( + "run", + &benchmark_herring_gdr::Client::run, + py::call_guard()); +} diff --git a/tensorpipe/benchmark/herring/benchmark_herring_tcp.cc b/tensorpipe/benchmark/herring/benchmark_herring_tcp.cc new file mode 100644 index 000000000..43004b6c5 --- /dev/null +++ b/tensorpipe/benchmark/herring/benchmark_herring_tcp.cc @@ -0,0 +1,851 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +int64_t deltaAsUs( + std::chrono::steady_clock::time_point start, + std::chrono::steady_clock::time_point stop) { + return std::chrono::duration_cast(stop - start) + .count(); +} + +template +T ceilOfRatio(T num, T den) { + return (num - 1) / den + 1; +} + +class CallbackBarrier { + public: + CallbackBarrier() = default; + + template + auto wrapCallback(T fn) { + return wrapTask( + [this, fn{std::move(fn)}]( + const tensorpipe::Error& error, auto&&... args) mutable { + if (error) { + LOG(ERROR) << error.what(); + std::unique_lock lock(mutex_); + if (!anyError_) { + anyError_ = error; + } + } else { + fn(std::forward(args)...); + } + }); + } + + template + auto wrapTask(T fn) { + { + std::unique_lock lock(mutex_); + numPendingCallbacks_ += 1; + } + return [this, fn{std::move(fn)}](auto&&... 
args) mutable { + fn(std::forward(args)...); + std::unique_lock lock(mutex_); + numPendingCallbacks_ -= 1; + cv_.notify_all(); + }; + } + + void notifyExternalEventHappened() { + std::unique_lock lock(mutex_); + numExternalEvents_ += 1; + cv_.notify_all(); + } + + void waitForNextExternalEvent() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&]() { + return numPendingCallbacks_ == 0 || numExternalEvents_ > 0; + }); + if (anyError_) { + throw std::runtime_error(anyError_.what()); + } + if (numExternalEvents_ == 0) { + throw std::runtime_error( + "All callbacks terminated before an external event occurred"); + } + numExternalEvents_ -= 1; + } + + void join() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&]() { + return numPendingCallbacks_ == 0 || numExternalEvents_ > 0; + }); + if (anyError_) { + throw std::runtime_error(anyError_.what()); + } + if (numExternalEvents_ > 0) { + throw std::runtime_error( + "An external event occurred while waiting for callbacks to terminate"); + } + } + + ~CallbackBarrier() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&]() { return numPendingCallbacks_ == 0; }); + } + + private: + std::mutex mutex_; + std::condition_variable cv_; + tensorpipe::Error anyError_ = tensorpipe::Error::kSuccess; + size_t numPendingCallbacks_ = 0; + size_t numExternalEvents_ = 0; +}; + +#define NCCL_CHECK(op) \ + { \ + ncclResult_t res = (op); \ + if (res != ncclSuccess) { \ + throw std::runtime_error("NCCL error"); \ + } \ + } + +struct NcclCommDeleter { + void operator()(ncclComm_t comm) { + NCCL_CHECK(ncclCommDestroy(comm)); + } +}; + +using NcclComm = + std::unique_ptr, NcclCommDeleter>; + +NcclComm createNcclComm(int rank, int worldSize, ncclUniqueId uniqueId) { + ncclComm_t comm; + NCCL_CHECK(ncclCommInitRank(&comm, worldSize, uniqueId, rank)); + return NcclComm(comm, NcclCommDeleter{}); +} + +std::shared_ptr createTensorPipeContext(std::string name) { + auto ctx = std::make_shared( + tensorpipe::ContextOptions().name(std::move(name))); + ctx->registerTransport(0, "uv", tensorpipe::transport::uv::create()); + ctx->registerChannel(0, "basic", tensorpipe::channel::basic::create()); + return ctx; +} + +constexpr size_t kParamsPerLock = 1024; + +// We need this extra named namespace inside our unnamed namespace because of +// https://github.com/pybind/pybind11/issues/3289 +namespace benchmark_herring_tcp { + +struct ServerStats { + struct EpochStats { + struct BucketStats { + struct MachineStats { + int64_t additionTime = 0; + int64_t recvToSendTime = 0; + }; + + std::vector machines; + + explicit BucketStats(size_t numMachines) : machines(numMachines) {} + }; + + std::vector buckets; + + explicit EpochStats(size_t numBuckets, size_t numMachines) + : buckets(numBuckets, BucketStats(numMachines)) {} + }; + + std::vector epochs; + + explicit ServerStats(size_t numEpochs, size_t numBuckets, size_t numMachines) + : epochs(numEpochs, EpochStats(numBuckets, numMachines)) {} +}; + +class Server { + public: + Server( + size_t machineIdx, + size_t numMachines, + size_t numDevicesPerMachine, + size_t numBuckets, + size_t bucketSize, + size_t numEpochs, + c10::intrusive_ptr store, + size_t numThreads) + : machineIdx_(machineIdx), + numMachines_(numMachines), + numBuckets_(numBuckets), + bucketSize_(bucketSize), + sliceLen_( + (machineIdx_ + 1) * bucketSize_ / numMachines_ - + machineIdx_ * bucketSize_ / numMachines_), + numEpochs_(numEpochs), + store_(std::move(store)), + contexts_([&]() { + std::vector> res(numMachines); + for (size_t machineIdx = 0; machineIdx < 
numMachines; + machineIdx += 1) { + res[machineIdx] = + createTensorPipeContext(std::to_string(machineIdx)); + } + return res; + }()), + threadPool_(numThreads), + stats_(numBuckets, numMachines), + recvTimes_( + numBuckets, + std::vector( + numMachines, + std::chrono::steady_clock::time_point())) {} + + ServerStats run() { + allocateTensors(); + startListening(); + waitForIncomingPipes(); + ServerStats stats(numEpochs_, numBuckets_, numMachines_); + for (size_t epochIdx = 0; epochIdx < numEpochs_; epochIdx += 1) { + setTensorsToZero(); + runOneEpoch(); + stats.epochs[epochIdx] = stats_; + } + return stats; + } + + private: + const size_t machineIdx_; + const size_t numMachines_; + const size_t numBuckets_; + const size_t bucketSize_; + const size_t sliceLen_; + const size_t numEpochs_; + const c10::intrusive_ptr store_; + const std::vector> contexts_; + std::vector> listeners_; + std::vector> pipes_; + std::vector buckets_; + std::vector> bucketLocks_; + std::vector> stagingTensors_; + c10::ThreadPool threadPool_; + ServerStats::EpochStats stats_; + std::vector> recvTimes_; + + void allocateTensors() { + buckets_.reserve(numBuckets_); + bucketLocks_.reserve(numBuckets_); + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + buckets_.push_back(torch::empty(sliceLen_, c10::kFloat)); + size_t numChunks = ceilOfRatio(sliceLen_, kParamsPerLock); + bucketLocks_.push_back(std::vector(numChunks)); + for (size_t chunkIdx = 0; chunkIdx < numChunks; chunkIdx += 1) { + bucketLocks_[bucketIdx][chunkIdx].clear(); + } + } + + stagingTensors_.resize(numMachines_); + for (size_t machineIdx = 0; machineIdx < numMachines_; machineIdx += 1) { + stagingTensors_[machineIdx].reserve(numBuckets_); + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + stagingTensors_[machineIdx].push_back( + torch::empty(sliceLen_, c10::kFloat)); + } + } + } + + void startListening() { + listeners_.resize(numMachines_); + for (size_t machineIdx = 0; machineIdx < numMachines_; machineIdx += 1) { + tensorpipe::Error error; + std::string address; + const char* iface = std::getenv("TP_SOCKET_IFNAME"); + std::tie(error, address) = iface != nullptr + ? 
tensorpipe::transport::uv::lookupAddrForIface(std::string(iface)) + : tensorpipe::transport::uv::lookupAddrForHostname(); + if (error) { + throw std::runtime_error(error.what()); + } + listeners_[machineIdx] = contexts_[machineIdx]->listen({ + "uv://" + std::move(address), + }); + + std::string key = "servers/" + std::to_string(machineIdx_) + "/clients/" + + std::to_string(machineIdx) + "/address"; + std::string concreteAddress = listeners_[machineIdx]->url("uv"); + store_->set( + key, + std::vector(concreteAddress.begin(), concreteAddress.end())); + } + } + + void waitForIncomingPipes() { + CallbackBarrier barrier; + + pipes_.resize(numMachines_); + for (size_t clientMachineIdx = 0; clientMachineIdx < numMachines_; + clientMachineIdx += 1) { + listeners_[clientMachineIdx]->accept(barrier.wrapCallback( + [&, this](std::shared_ptr pipe) { + int otherClientMachineIdx = std::strtol( + pipe->getRemoteName().c_str(), nullptr, /*base=*/10); + pipes_[otherClientMachineIdx] = std::move(pipe); + })); + } + + barrier.join(); + } + + void setTensorsToZero() { + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + buckets_[bucketIdx].fill_(0); + } + } + + void addToBucket(size_t bucketIdx, torch::Tensor increment) { + torch::Tensor& bucket = buckets_[bucketIdx]; + std::vector& locks = bucketLocks_[bucketIdx]; + size_t numChunks = ceilOfRatio(sliceLen_, kParamsPerLock); + for (size_t chunkIdx = 0; chunkIdx < numChunks; chunkIdx += 1) { + size_t chunkStart = chunkIdx * kParamsPerLock; + size_t chunkEnd = std::min(sliceLen_, (chunkIdx + 1) * kParamsPerLock); + bool wasLocked; + do { + wasLocked = locks[chunkIdx].test_and_set(std::memory_order_acquire); + } while (wasLocked); + bucket.slice(/*dim=*/0, /*start=*/chunkStart, /*end=*/chunkEnd) += + increment.slice(/*dim=*/0, /*start=*/chunkStart, /*end=*/chunkEnd); + locks[chunkIdx].clear(std::memory_order_release); + } + } + + void runOneEpoch() { + CallbackBarrier barrier; + + std::mutex mutex; + std::vector numClientsDoneForBucket(numBuckets_, 0); + std::vector hasBucketBeenSent(numBuckets_, false); + + for (size_t machineIdx = 0; machineIdx < numMachines_; machineIdx += 1) { + tensorpipe::Pipe& pipe = *pipes_[machineIdx]; + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + torch::Tensor& bucket = buckets_[bucketIdx]; + torch::Tensor& stagingTensor = stagingTensors_[machineIdx][bucketIdx]; + + pipe.readDescriptor(barrier.wrapCallback([&, machineIdx, bucketIdx]( + tensorpipe::Descriptor + /* unused */) { + tensorpipe::Allocation allocation; + allocation.tensors.resize(1); + allocation.tensors[0].buffer = tensorpipe::CpuBuffer{ + .ptr = stagingTensor.data_ptr(), + }; + pipe.read( + std::move(allocation), + barrier.wrapCallback([&, machineIdx, bucketIdx]() { + recvTimes_[bucketIdx][machineIdx] = + std::chrono::steady_clock::now(); + threadPool_.run(barrier.wrapTask([&, machineIdx, bucketIdx]() { + { + std::chrono::steady_clock::time_point additionStartTime = + std::chrono::steady_clock::now(); + addToBucket(bucketIdx, stagingTensor); + stats_.buckets[bucketIdx] + .machines[machineIdx] + .additionTime = deltaAsUs( + additionStartTime, std::chrono::steady_clock::now()); + } + std::unique_lock lock(mutex); + numClientsDoneForBucket[bucketIdx] += 1; + for (size_t otherBucketIdx = 0; otherBucketIdx < numBuckets_; + otherBucketIdx += 1) { + if (hasBucketBeenSent[otherBucketIdx]) { + continue; + } + if (numClientsDoneForBucket[otherBucketIdx] < + numMachines_) { + break; + } + for (size_t otherMachineIdx = 0; + 
otherMachineIdx < numMachines_; + otherMachineIdx += 1) { + tensorpipe::Pipe& pipe = *pipes_[otherMachineIdx]; + tensorpipe::Message message; + message.tensors.resize(1); + message.tensors[0] = { + .buffer = + tensorpipe::CpuBuffer{ + .ptr = bucket.data_ptr(), + }, + .length = bucket.nbytes(), + .targetDevice = + tensorpipe::Device(tensorpipe::kCpuDeviceType, 0), + }; + stats_.buckets[bucketIdx] + .machines[otherMachineIdx] + .recvToSendTime = deltaAsUs( + recvTimes_[bucketIdx][otherMachineIdx], + std::chrono::steady_clock::now()); + pipe.write( + std::move(message), barrier.wrapCallback([]() {})); + } + hasBucketBeenSent[otherBucketIdx] = true; + } + })); + })); + })); + } + } + + barrier.join(); + } +}; + +struct ClientStats { + struct EpochStats { + struct BucketStats { + struct ServerStats { + int64_t transferTime = 0; + }; + + int64_t ncclReduceTime = 0; + int64_t ncclBroadcastTime = 0; + std::vector servers; + + explicit BucketStats(size_t numMachines) : servers(numMachines) {} + }; + + int64_t endToEndTime = 0; + std::vector buckets; + + explicit EpochStats(size_t numBuckets, size_t numMachines) + : buckets(numBuckets, BucketStats(numMachines)) {} + }; + + std::vector epochs; + + explicit ClientStats(size_t numEpochs, size_t numBuckets, size_t numMachines) + : epochs(numEpochs, EpochStats(numBuckets, numMachines)) {} +}; + +class Client { + public: + Client( + size_t machineIdx, + size_t deviceIdx, + size_t numMachines, + size_t numDevicesPerMachine, + size_t numBuckets, + size_t bucketSize, + size_t numEpochs, + c10::intrusive_ptr store) + : machineIdx_(machineIdx), + deviceIdx_(deviceIdx), + numMachines_(numMachines), + numDevicesPerMachine_(numDevicesPerMachine), + numBuckets_(numBuckets), + bucketSize_(bucketSize), + numEpochs_(numEpochs), + store_(std::move(store)), + contexts_([&]() { + std::vector> res(numMachines); + if (deviceIdx_ == 0) { + for (size_t serverMachineIdx = 0; serverMachineIdx < numMachines; + serverMachineIdx += 1) { + res[serverMachineIdx] = + createTensorPipeContext(std::to_string(machineIdx)); + } + } + return res; + }()), + stats_(numBuckets, numMachines), + ncclBroadcastStartTimes_( + numBuckets, + std::chrono::steady_clock::time_point()) {} + + ClientStats run() { + allocateTensors(); + setUpNccl(); + connectToServers(); + ClientStats stats(numEpochs_, numBuckets_, numMachines_); + for (size_t epochIdx = 0; epochIdx < numEpochs_; epochIdx += 1) { + setTensorsToOne(); + runOneEpoch(); + checkTensors(); + stats.epochs[epochIdx] = stats_; + } + return stats; + } + + private: + const size_t machineIdx_; + const size_t deviceIdx_; + const size_t numMachines_; + const size_t numDevicesPerMachine_; + const size_t numBuckets_; + const size_t bucketSize_; + const size_t numEpochs_; + const c10::intrusive_ptr store_; + const std::vector> contexts_; + std::vector> pipes_; + std::vector buckets_; + tensorpipe::optional> stagingTensors_; + NcclComm ncclComm_; + ClientStats::EpochStats stats_; + std::vector ncclBroadcastStartTimes_; + + void allocateTensors() { + buckets_.reserve(numBuckets_); + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + buckets_.push_back(torch::empty( + bucketSize_, + c10::TensorOptions() + .dtype(c10::kFloat) + .device(c10::Device(c10::kCUDA, 0)))); + } + + if (deviceIdx_ == 0) { + stagingTensors_.emplace(); + stagingTensors_->reserve(numBuckets_); + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + stagingTensors_->push_back(torch::empty( + bucketSize_, + 
c10::TensorOptions().dtype(c10::kFloat).pinned_memory(true))); + } + } + } + + void setUpNccl() { + ncclUniqueId uniqueId; + if (deviceIdx_ == 0) { + NCCL_CHECK(ncclGetUniqueId(&uniqueId)); + store_->set( + "machines/" + std::to_string(machineIdx_) + "/nccl_id", + std::vector( + reinterpret_cast(&uniqueId), + reinterpret_cast(&uniqueId) + sizeof(ncclUniqueId))); + } else { + std::vector uniqueIdData = + store_->get("machines/" + std::to_string(machineIdx_) + "/nccl_id"); + std::memcpy(&uniqueId, uniqueIdData.data(), sizeof(ncclUniqueId)); + } + ncclComm_ = createNcclComm( + /*rank=*/deviceIdx_, + /*worldSize=*/numDevicesPerMachine_, + uniqueId); + } + + void connectToServers() { + if (deviceIdx_ == 0) { + pipes_.resize(numMachines_); + for (size_t serverMachineIdx = 0; serverMachineIdx < numMachines_; + serverMachineIdx += 1) { + std::vector addressData = store_->get( + "servers/" + std::to_string(serverMachineIdx) + "/clients/" + + std::to_string(machineIdx_) + "/address"); + std::string address((char*)addressData.data(), addressData.size()); + pipes_[serverMachineIdx] = + contexts_[serverMachineIdx]->connect(std::move(address)); + } + } + } + + void setTensorsToOne() { + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + buckets_[bucketIdx].fill_(1); + } + } + + void runOneEpoch() { + c10::cuda::CUDAStream stream = + c10::cuda::getStreamFromPool(/*isHighPriority=*/true, /*device=*/0); + + std::chrono::steady_clock::time_point start = + std::chrono::steady_clock::now(); + + if (deviceIdx_ == 0) { + CallbackBarrier barrier; + + std::vector reduceEvents(numBuckets_); + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + NCCL_CHECK(ncclReduce( + buckets_[bucketIdx].data_ptr(), + stagingTensors_.value()[bucketIdx].data_ptr(), + bucketSize_, + ncclFloat, + ncclSum, + 0, + ncclComm_.get(), + stream)); + reduceEvents[bucketIdx].record(stream); + } + + std::mutex mutex; + std::condition_variable cv; + std::vector numMachinesDoneForBucket(numBuckets_, 0); + std::vector broadcastEvents(numBuckets_); + + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + reduceEvents[bucketIdx].synchronize(); + stats_.buckets[bucketIdx].ncclReduceTime = + deltaAsUs(start, std::chrono::steady_clock::now()); + + torch::Tensor& stagingTensor = stagingTensors_.value()[bucketIdx]; + for (size_t serverMachineIdx = 0; serverMachineIdx < numMachines_; + serverMachineIdx += 1) { + size_t startPos = bucketSize_ * serverMachineIdx / numMachines_; + size_t endPos = bucketSize_ * (serverMachineIdx + 1) / numMachines_; + torch::Tensor slice = stagingTensor.slice( + /*dim=*/0, /*start=*/startPos, /*end=*/endPos); + tensorpipe::Message message; + message.tensors.resize(1); + message.tensors[0] = { + .buffer = + tensorpipe::CpuBuffer{ + .ptr = slice.data_ptr(), + }, + .length = slice.nbytes(), + .targetDevice = tensorpipe::Device(tensorpipe::kCpuDeviceType, 0), + }; + tensorpipe::Pipe& pipe = *pipes_[serverMachineIdx]; + std::chrono::steady_clock::time_point transferStartTime = + std::chrono::steady_clock::now(); + pipe.write( + std::move(message), + barrier.wrapCallback([&, + bucketIdx, + serverMachineIdx, + slice, + transferStartTime]() { + pipe.readDescriptor(barrier.wrapCallback( + [&, bucketIdx, serverMachineIdx, slice, transferStartTime]( + tensorpipe::Descriptor /* unused */) { + tensorpipe::Allocation allocation; + allocation.tensors.resize(1); + allocation.tensors[0].buffer = tensorpipe::CpuBuffer{ + .ptr = slice.data_ptr(), + }; + pipe.read( + 
std::move(allocation), + barrier.wrapCallback([&, + bucketIdx, + serverMachineIdx, + transferStartTime]() { + stats_.buckets[bucketIdx] + .servers[serverMachineIdx] + .transferTime = deltaAsUs( + transferStartTime, + std::chrono::steady_clock::now()); + std::unique_lock lock(mutex); + numMachinesDoneForBucket[bucketIdx] += 1; + if (numMachinesDoneForBucket[bucketIdx] < + numMachines_) { + return; + } + ncclBroadcastStartTimes_[bucketIdx] = + std::chrono::steady_clock::now(); + NCCL_CHECK(ncclBroadcast( + stagingTensors_.value()[bucketIdx].data_ptr(), + buckets_[bucketIdx].data_ptr(), + bucketSize_, + ncclFloat, + 0, + ncclComm_.get(), + stream)); + broadcastEvents[bucketIdx].record(stream); + barrier.notifyExternalEventHappened(); + })); + })); + })); + } + } + + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + barrier.waitForNextExternalEvent(); + broadcastEvents[bucketIdx].synchronize(); + stats_.buckets[bucketIdx].ncclBroadcastTime = deltaAsUs( + ncclBroadcastStartTimes_[bucketIdx], + std::chrono::steady_clock::now()); + } + + barrier.join(); + stream.synchronize(); + } else { + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + NCCL_CHECK(ncclReduce( + buckets_[bucketIdx].data_ptr(), + nullptr, + bucketSize_, + ncclFloat, + ncclSum, + 0, + ncclComm_.get(), + stream)); + } + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + NCCL_CHECK(ncclBroadcast( + nullptr, + buckets_[bucketIdx].data_ptr(), + bucketSize_, + ncclFloat, + 0, + ncclComm_.get(), + stream)); + } + + stream.synchronize(); + } + + stats_.endToEndTime = deltaAsUs(start, std::chrono::steady_clock::now()); + } + + void checkTensors() { + for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) { + if (!buckets_[bucketIdx].allclose(torch::full( + {}, + static_cast(numMachines_ * numDevicesPerMachine_), + c10::TensorOptions() + .dtype(c10::kFloat) + .device(c10::Device(c10::kCUDA, 0))))) { + throw std::runtime_error("Bad result"); + } + } + } +}; + +} // namespace benchmark_herring_tcp +} // namespace + +namespace py = pybind11; + +template +using shared_ptr_class_ = py::class_>; + +PYBIND11_MODULE(benchmark_herring_tcp, module) { + shared_ptr_class_ server(module, "Server"); + shared_ptr_class_ client(module, "Client"); + + py::class_< + benchmark_herring_tcp::ServerStats::EpochStats::BucketStats::MachineStats> + serverStatsEpochBucketMachine(module, "ServerStatsEpochBucketMachine"); + serverStatsEpochBucketMachine.def_readonly( + "addition_time", + &benchmark_herring_tcp::ServerStats::EpochStats::BucketStats:: + MachineStats::additionTime); + serverStatsEpochBucketMachine.def_readonly( + "recv_to_send_time", + &benchmark_herring_tcp::ServerStats::EpochStats::BucketStats:: + MachineStats::recvToSendTime); + + py::class_ + serverStatsEpochBucket(module, "ServerStatsEpochBucket"); + serverStatsEpochBucket.def_readonly( + "machines", + &benchmark_herring_tcp::ServerStats::EpochStats::BucketStats::machines); + + py::class_ serverStatsEpoch( + module, "ServerStatsEpoch"); + serverStatsEpoch.def_readonly( + "buckets", &benchmark_herring_tcp::ServerStats::EpochStats::buckets); + + py::class_ serverStats( + module, "ServerStats"); + serverStats.def_readonly( + "epochs", &benchmark_herring_tcp::ServerStats::epochs); + + server.def( + py::init< + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + const c10::intrusive_ptr&, + size_t>(), + py::arg("machine_idx"), + py::arg("num_machines"), + py::arg("num_devices_per_machine"), + py::arg("num_buckets"), + 
py::arg("bucket_size"), + py::arg("num_epochs"), + py::arg("store"), + py::arg("num_threads")); + server.def( + "run", + &benchmark_herring_tcp::Server::run, + py::call_guard()); + + py::class_< + benchmark_herring_tcp::ClientStats::EpochStats::BucketStats::ServerStats> + clientStatsEpochBucketMachine(module, "ClientStatsEpochBucketMachine"); + clientStatsEpochBucketMachine.def_readonly( + "transfer_time", + &benchmark_herring_tcp::ClientStats::EpochStats::BucketStats:: + ServerStats::transferTime); + + py::class_ + clientStatsEpochBucket(module, "ClientStatsEpochBucket"); + clientStatsEpochBucket.def_readonly( + "servers", + &benchmark_herring_tcp::ClientStats::EpochStats::BucketStats::servers); + clientStatsEpochBucket.def_readonly( + "nccl_broadcast_time", + &benchmark_herring_tcp::ClientStats::EpochStats::BucketStats:: + ncclBroadcastTime); + clientStatsEpochBucket.def_readonly( + "nccl_reduce_time", + &benchmark_herring_tcp::ClientStats::EpochStats::BucketStats:: + ncclReduceTime); + + py::class_ clientStatsEpoch( + module, "ClientStatsEpoch"); + clientStatsEpoch.def_readonly( + "buckets", &benchmark_herring_tcp::ClientStats::EpochStats::buckets); + clientStatsEpoch.def_readonly( + "end_to_end_time", + &benchmark_herring_tcp::ClientStats::EpochStats::endToEndTime); + + py::class_ clientStats( + module, "ClientStats"); + clientStats.def_readonly( + "epochs", &benchmark_herring_tcp::ClientStats::epochs); + + client.def( + py::init< + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + size_t, + const c10::intrusive_ptr&>(), + py::arg("machine_idx"), + py::arg("device_idx"), + py::arg("num_machines"), + py::arg("num_devices_per_machine"), + py::arg("num_buckets"), + py::arg("bucket_size"), + py::arg("num_epochs"), + py::arg("store")); + client.def( + "run", + &benchmark_herring_tcp::Client::run, + py::call_guard()); +} diff --git a/tensorpipe/benchmark/herring/benchmark_nccl.cc b/tensorpipe/benchmark/herring/benchmark_nccl.cc new file mode 100644 index 000000000..c2960acd0 --- /dev/null +++ b/tensorpipe/benchmark/herring/benchmark_nccl.cc @@ -0,0 +1,208 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include <chrono>
+#include <cstring>
+#include <memory>
+
+#include <c10/cuda/CUDAStream.h>
+#include <c10d/Store.hpp>
+#include <cuda_runtime.h>
+#include <nccl.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <torch/extension.h>
+#include <torch/torch.h>
+
+namespace {
+
+int64_t deltaAsUs(
+    std::chrono::steady_clock::time_point start,
+    std::chrono::steady_clock::time_point stop) {
+  return std::chrono::duration_cast<std::chrono::microseconds>(stop - start)
+      .count();
+}
+
+#define NCCL_CHECK(op)                        \
+  {                                           \
+    ncclResult_t res = (op);                  \
+    if (res != ncclSuccess) {                 \
+      throw std::runtime_error("NCCL error"); \
+    }                                         \
+  }
+
+struct NcclCommDeleter {
+  void operator()(ncclComm_t comm) {
+    NCCL_CHECK(ncclCommDestroy(comm));
+  }
+};
+
+using NcclComm =
+    std::unique_ptr<std::remove_pointer_t<ncclComm_t>, NcclCommDeleter>;
+
+NcclComm createNcclComm(int rank, int worldSize, ncclUniqueId uniqueId) {
+  ncclComm_t comm;
+  NCCL_CHECK(ncclCommInitRank(&comm, worldSize, uniqueId, rank));
+  return NcclComm(comm, NcclCommDeleter{});
+}
+
+// We need this extra named namespace inside our unnamed namespace because of
+// https://github.com/pybind/pybind11/issues/3289
+namespace benchmark_nccl {
+
+class Client {
+ public:
+  Client(
+      size_t machineIdx,
+      size_t deviceIdx,
+      size_t numMachines,
+      size_t numDevicesPerMachine,
+      size_t numBuckets,
+      size_t bucketSize,
+      size_t numEpochs,
+      c10::intrusive_ptr<c10d::Store> store)
+      : machineIdx_(machineIdx),
+        deviceIdx_(deviceIdx),
+        numMachines_(numMachines),
+        numDevicesPerMachine_(numDevicesPerMachine),
+        numBuckets_(numBuckets),
+        bucketSize_(bucketSize),
+        numEpochs_(numEpochs),
+        store_(std::move(store)) {}
+
+  std::vector<int64_t> run() {
+    allocateTensors();
+    setUpNccl();
+    std::vector<int64_t> stats;
+    for (size_t epochIdx = 0; epochIdx < numEpochs_; epochIdx += 1) {
+      setTensorsToOne();
+      {
+        auto start = std::chrono::steady_clock::now();
+        runOneEpoch();
+        auto stop = std::chrono::steady_clock::now();
+        stats.push_back(deltaAsUs(start, stop));
+      }
+      checkTensors();
+    }
+    return stats;
+  }
+
+ private:
+  const size_t machineIdx_;
+  const size_t deviceIdx_;
+  const size_t numMachines_;
+  const size_t numDevicesPerMachine_;
+  const size_t numBuckets_;
+  const size_t bucketSize_;
+  const size_t numEpochs_;
+  const c10::intrusive_ptr<c10d::Store> store_;
+  std::vector<torch::Tensor> buckets_;
+  NcclComm ncclComm_;
+
+  void allocateTensors() {
+    buckets_.reserve(numBuckets_);
+    for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) {
+      buckets_.push_back(torch::empty(
+          bucketSize_,
+          c10::TensorOptions()
+              .dtype(c10::kFloat)
+              .device(c10::Device(c10::kCUDA, 0))));
+    }
+  }
+
+  void setUpNccl() {
+    ncclUniqueId uniqueId;
+    if (machineIdx_ == 0 && deviceIdx_ == 0) {
+      NCCL_CHECK(ncclGetUniqueId(&uniqueId));
+      store_->set(
+          "nccl_id",
+          std::vector<uint8_t>(
+              reinterpret_cast<uint8_t*>(&uniqueId),
+              reinterpret_cast<uint8_t*>(&uniqueId) + sizeof(ncclUniqueId)));
+    } else {
+      std::vector<uint8_t> uniqueIdData = store_->get("nccl_id");
+      std::memcpy(&uniqueId, uniqueIdData.data(), sizeof(ncclUniqueId));
+    }
+    ncclComm_ = createNcclComm(
+        /*rank=*/machineIdx_ * numDevicesPerMachine_ + deviceIdx_,
+        /*worldSize=*/numMachines_ * numDevicesPerMachine_,
+        uniqueId);
+  }
+
+  void setTensorsToOne() {
+    for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) {
+      buckets_[bucketIdx].fill_(1);
+    }
+  }
+
+  void runOneEpoch() {
+    c10::cuda::CUDAStream stream =
+        c10::cuda::getStreamFromPool(/*isHighPriority=*/true, /*device=*/0);
+
+    for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) {
+      NCCL_CHECK(ncclAllReduce(
+          buckets_[bucketIdx].data_ptr(),
+          buckets_[bucketIdx].data_ptr(),
+          bucketSize_,
+          ncclFloat,
+          ncclSum,
+          ncclComm_.get(),
+          stream));
+    }
+
+    stream.synchronize();
+  }
+
+  void checkTensors() {
+    for (size_t bucketIdx = 0; bucketIdx < numBuckets_; bucketIdx += 1) {
+      if (!buckets_[bucketIdx].allclose(torch::full(
+              {},
+              static_cast<float>(numMachines_ * numDevicesPerMachine_),
+              c10::TensorOptions()
+                  .dtype(c10::kFloat)
+                  .device(c10::Device(c10::kCUDA, 0))))) {
+        throw std::runtime_error("Bad result");
+      }
+    }
+  }
+};
+
+} // namespace benchmark_nccl
+} // namespace
+
+namespace py = pybind11;
+
+template <typename T>
+using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
+
+PYBIND11_MODULE(benchmark_nccl, module) {
+  shared_ptr_class_<benchmark_nccl::Client> client(module, "Client");
+
+  client.def(
+      py::init<
+          size_t,
+          size_t,
+          size_t,
+          size_t,
+          size_t,
+          size_t,
+          size_t,
+          const c10::intrusive_ptr<c10d::Store>&>(),
+      py::arg("machine_idx"),
+      py::arg("device_idx"),
+      py::arg("num_machines"),
+      py::arg("num_devices_per_machine"),
+      py::arg("num_buckets"),
+      py::arg("bucket_size"),
+      py::arg("num_epochs"),
+      py::arg("store"));
+  client.def(
+      "run",
+      &benchmark_nccl::Client::run,
+      py::call_guard<py::gil_scoped_release>());
+}
diff --git a/tensorpipe/benchmark/herring/cuda_kernels.cu b/tensorpipe/benchmark/herring/cuda_kernels.cu
new file mode 100644
index 000000000..92fb359a9
--- /dev/null
+++ b/tensorpipe/benchmark/herring/cuda_kernels.cu
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <stdexcept>
+
+#include <cuda_runtime.h>
+
+// Copied from PyTorch's aten/src/ATen/native/cuda/Loops.cuh
+
+constexpr size_t warp_size = 32;
+constexpr size_t num_threads = warp_size * 2;
+constexpr size_t thread_work_size = 4;
+constexpr size_t block_work_size = thread_work_size * num_threads;
+
+#define CUDA_CHECK(op)                        \
+  {                                           \
+    cudaError_t res = (op);                   \
+    if (res != cudaSuccess) {                 \
+      throw std::runtime_error("CUDA error"); \
+    }                                         \
+  }
+
+namespace {
+
+template <typename T>
+T ceilOfRatio(T num, T den) {
+  return (num - 1) / den + 1;
+}
+
+__global__ void atomicAddIntoKernel(float* dst, float* src, size_t len) {
+  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < len;
+       idx += (gridDim.x * blockDim.x)) {
+    atomicAdd(dst + idx, *(src + idx));
+  }
+}
+
+} // namespace
+
+void atomicAddInto(float* dst, float* src, size_t len, cudaStream_t stream) {
+  int64_t grid = ceilOfRatio(len, block_work_size);
+  atomicAddIntoKernel<<<grid, num_threads, 0, stream>>>(dst, src, len);
+  CUDA_CHECK(cudaGetLastError());
+}
diff --git a/tensorpipe/benchmark/herring/cuda_kernels.cuh b/tensorpipe/benchmark/herring/cuda_kernels.cuh
new file mode 100644
index 000000000..7c080ffdd
--- /dev/null
+++ b/tensorpipe/benchmark/herring/cuda_kernels.cuh
@@ -0,0 +1,13 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cuda_runtime.h>
+
+void atomicAddInto(float* dst, float* src, size_t len, cudaStream_t stream);
diff --git a/tensorpipe/benchmark/herring/launch_herring_gdr.py b/tensorpipe/benchmark/herring/launch_herring_gdr.py
new file mode 100644
index 000000000..75b6c0f75
--- /dev/null
+++ b/tensorpipe/benchmark/herring/launch_herring_gdr.py
@@ -0,0 +1,314 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +import argparse +import dataclasses +import multiprocessing +import os +import sys +from typing import List + +import torch +import torch.distributed +from utils import recv_from_connections_and_join_processes + +# Must come after torch or else it will fail because it won't find libc10.so +import benchmark_herring_gdr # isort: skip + + +@dataclasses.dataclass +class ServerStats: + addition_time: List[List[List[int]]] # epoch x bucket x machine + recv_to_send_time: List[List[List[int]]] # epoch x bucket x machine + + +@dataclasses.dataclass +class ClientStats: + transfer_time: List[List[List[int]]] # epoch x bucket x server + nccl_reduce_scatter_time: List[List[int]] # epoch x bucket + nccl_all_gather_time: List[List[int]] # epoch x bucket + end_to_end_time: List[int] # epoch + + +@dataclasses.dataclass +class OneMachineHerringStats: + addition_time: torch.Tensor # server x epoch x bucket x machine + recv_to_send_time: torch.Tensor # server x epoch x bucket x machine + transfer_time: torch.Tensor # client x epoch x bucket x server + nccl_reduce_scatter_time: torch.Tensor # client x epoch x bucket + nccl_all_gather_time: torch.Tensor # client x epoch x bucket + end_to_end_time: torch.Tensor # client x epoch + + +def run_herring_server( + init_method: str, + machine_idx: int, + device_idx: int, + num_machines: int, + num_devices_per_machine: int, + num_buckets: int, + bucket_size: int, + num_epochs: int, + conn: multiprocessing.connection.Connection, +) -> None: + torch._C._set_print_stack_traces_on_fatal_signal(True) + + rdv_iterator = torch.distributed.rendezvous( + init_method, + (num_machines + machine_idx) * num_devices_per_machine + device_idx, + 2 * num_machines * num_devices_per_machine, + ) + store, _, _ = next(rdv_iterator) + + assert 0 <= machine_idx < num_machines + assert 0 <= device_idx < num_devices_per_machine + + os.environ["CUDA_VISIBLE_DEVICES"] = f"{device_idx}" + + server = benchmark_herring_gdr.Server( + machine_idx=machine_idx, + device_idx=device_idx, + num_machines=num_machines, + num_devices_per_machine=num_devices_per_machine, + num_buckets=num_buckets, + bucket_size=bucket_size, + num_epochs=num_epochs, + store=store, + ) + stats = server.run() + conn.send( + ServerStats( + addition_time=[ + [[ms.addition_time for ms in bs.machines] for bs in es.buckets] + for es in stats.epochs + ], + recv_to_send_time=[ + [[ms.recv_to_send_time for ms in bs.machines] for bs in es.buckets] + for es in stats.epochs + ], + ) + ) + + +def run_herring_client( + init_method: str, + machine_idx: int, + device_idx: int, + num_machines: int, + num_devices_per_machine: int, + num_buckets: int, + bucket_size: int, + num_epochs: int, + conn: multiprocessing.connection.Connection, +) -> None: + torch._C._set_print_stack_traces_on_fatal_signal(True) + + rdv_iterator = torch.distributed.rendezvous( + init_method, + machine_idx * num_devices_per_machine + device_idx, + 2 * num_machines * num_devices_per_machine, + ) + store, _, _ = next(rdv_iterator) + + assert 0 <= machine_idx < num_machines + assert 0 <= device_idx < num_devices_per_machine + + os.environ["CUDA_VISIBLE_DEVICES"] = f"{device_idx}" + + client = benchmark_herring_gdr.Client( + machine_idx=machine_idx, + device_idx=device_idx, + num_machines=num_machines, + num_devices_per_machine=num_devices_per_machine, + num_buckets=num_buckets, + bucket_size=bucket_size, + num_epochs=num_epochs, + store=store, + ) + stats = client.run() + conn.send( + ClientStats( + transfer_time=[ + [[ss.transfer_time for ss in bs.servers] for bs in 
es.buckets] + for es in stats.epochs + ], + nccl_reduce_scatter_time=[ + [bs.nccl_reduce_scatter_time for bs in es.buckets] + for es in stats.epochs + ], + nccl_all_gather_time=[ + [bs.nccl_all_gather_time for bs in es.buckets] for es in stats.epochs + ], + end_to_end_time=[es.end_to_end_time for es in stats.epochs], + ) + ) + + +def run_one_machine_herring( + init_method: str, + machine_idx: int, + num_machines: int, + num_devices_per_machine: int, + num_buckets: int, + bucket_size: int, + num_epochs: int, +) -> OneMachineHerringStats: + server_receiving_conns = [] + server_sending_conns = [] + for _ in range(num_devices_per_machine): + recv_end, send_end = multiprocessing.Pipe() + server_receiving_conns.append(recv_end) + server_sending_conns.append(send_end) + servers = [ + multiprocessing.Process( + target=run_herring_server, + name=f"server_{machine_idx}_{device_idx}", + args=( + init_method, + machine_idx, + device_idx, + num_machines, + num_devices_per_machine, + num_buckets, + bucket_size, + num_epochs, + server_sending_conns[device_idx], + ), + ) + for device_idx in range(num_devices_per_machine) + ] + + client_receiving_conns = [] + client_sending_conns = [] + for _ in range(num_devices_per_machine): + recv_end, send_end = multiprocessing.Pipe() + client_receiving_conns.append(recv_end) + client_sending_conns.append(send_end) + clients = [ + multiprocessing.Process( + target=run_herring_client, + name=f"client_{machine_idx}_{device_idx}", + args=( + init_method, + machine_idx, + device_idx, + num_machines, + num_devices_per_machine, + num_buckets, + bucket_size, + num_epochs, + client_sending_conns[device_idx], + ), + ) + for device_idx in range(num_devices_per_machine) + ] + for t in servers + clients: + t.start() + for c in server_sending_conns + client_sending_conns: + c.close() + + stats = recv_from_connections_and_join_processes( + list(zip(servers, server_receiving_conns)) + + list(zip(clients, client_receiving_conns)) + ) + server_stats = stats[:num_devices_per_machine] + client_stats = stats[num_devices_per_machine:] + + return OneMachineHerringStats( + addition_time=torch.tensor( + [s.addition_time for s in server_stats], + dtype=torch.long, + ), + recv_to_send_time=torch.tensor( + [s.recv_to_send_time for s in server_stats], + dtype=torch.long, + ), + transfer_time=torch.tensor( + [s.transfer_time for s in client_stats], + dtype=torch.long, + ), + nccl_reduce_scatter_time=torch.tensor( + [s.nccl_reduce_scatter_time for s in client_stats], + dtype=torch.long, + ), + nccl_all_gather_time=torch.tensor( + [s.nccl_all_gather_time for s in client_stats], + dtype=torch.long, + ), + end_to_end_time=torch.tensor( + [s.end_to_end_time for s in client_stats], + dtype=torch.long, + ), + ) + + +def main(): + parser = argparse.ArgumentParser(description="NCCL allreduce benchmark") + parser.add_argument( + "--init-method", + type=str, + default="env://", + help="How to do rendezvous between machines (uses PyTorch, hence see its doc)", + ) + parser.add_argument( + "--machine-idx", + type=int, + required=True, + help="The rank of the machine on which this script was invoked (0-based)", + ) + parser.add_argument( + "--num-machines", + type=int, + required=True, + help="On how many machines this script is being invoked (each with its own rank)", + ) + parser.add_argument( + "--num-devices-per-machine", + type=int, + required=True, + help="How many clients this script should launch (each will use one GPU)", + ) + parser.add_argument( + "--num-buckets", + type=int, + required=True, + 
help="How many buffers to do an allreduce over in each epoch", + ) + parser.add_argument( + "--bucket-size", + type=int, + required=True, + help="How big each buffer should be (expressed in number of float32 elements)", + ) + parser.add_argument( + "--num-epochs", + type=int, + required=True, + help="How many times to run the benchmark", + ) + parser.add_argument( + "--output", + type=argparse.FileType("wb"), + default=sys.stdout.buffer, + ) + + args = parser.parse_args() + + res = run_one_machine_herring( + init_method=args.init_method, + machine_idx=args.machine_idx, + num_machines=args.num_machines, + num_devices_per_machine=args.num_devices_per_machine, + num_buckets=args.num_buckets, + bucket_size=args.bucket_size, + num_epochs=args.num_epochs, + ) + + torch.save(res, args.output) + + +if __name__ == "__main__": + main() diff --git a/tensorpipe/benchmark/herring/launch_herring_tcp.py b/tensorpipe/benchmark/herring/launch_herring_tcp.py new file mode 100644 index 000000000..1fde5a315 --- /dev/null +++ b/tensorpipe/benchmark/herring/launch_herring_tcp.py @@ -0,0 +1,317 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import dataclasses +import multiprocessing +import os +import sys +from typing import List + +import torch +import torch.distributed +from utils import recv_from_connections_and_join_processes + +# Must come after torch or else it will fail because it won't find libc10.so +import benchmark_herring_tcp # isort: skip + + +@dataclasses.dataclass +class ServerStats: + addition_time: List[List[List[int]]] # epoch x bucket x machine + recv_to_send_time: List[List[List[int]]] # epoch x bucket x machine + + +@dataclasses.dataclass +class ClientStats: + transfer_time: List[List[List[int]]] # epoch x bucket x server + nccl_reduce_time: List[List[int]] # epoch x bucket + nccl_broadcast_time: List[List[int]] # epoch x bucket + end_to_end_time: List[int] # epoch + + +@dataclasses.dataclass +class OneMachineHerringStats: + addition_time: torch.Tensor # epoch x bucket x machine + recv_to_send_time: torch.Tensor # epoch x bucket x machine + transfer_time: torch.Tensor # client x epoch x bucket x server + nccl_reduce_time: torch.Tensor # client x epoch x bucket + nccl_broadcast_time: torch.Tensor # client x epoch x bucket + end_to_end_time: torch.Tensor # client x epoch + + +def run_herring_server( + init_method: str, + machine_idx: int, + num_machines: int, + num_devices_per_machine: int, + num_buckets: int, + bucket_size: int, + num_epochs: int, + num_compute_threads: int, + conn: multiprocessing.connection.Connection, +) -> None: + torch._C._set_print_stack_traces_on_fatal_signal(True) + + rdv_iterator = torch.distributed.rendezvous( + init_method, + num_machines * num_devices_per_machine + machine_idx, + num_machines * num_devices_per_machine + num_machines, + ) + store, _, _ = next(rdv_iterator) + + assert 0 <= machine_idx < num_machines + + server = benchmark_herring_tcp.Server( + machine_idx=machine_idx, + num_machines=num_machines, + num_devices_per_machine=num_devices_per_machine, + num_buckets=num_buckets, + bucket_size=bucket_size, + num_epochs=num_epochs, + store=store, + num_threads=num_compute_threads, + ) + stats = server.run() + conn.send( + ServerStats( + addition_time=[ + [[ms.addition_time for ms in bs.machines] for bs in es.buckets] + for es in stats.epochs + ], + recv_to_send_time=[ + 
[[ms.recv_to_send_time for ms in bs.machines] for bs in es.buckets] + for es in stats.epochs + ], + ) + ) + + +def run_herring_client( + init_method: str, + machine_idx: int, + device_idx: int, + num_machines: int, + num_devices_per_machine: int, + num_buckets: int, + bucket_size: int, + num_epochs: int, + conn: multiprocessing.connection.Connection, +) -> None: + torch._C._set_print_stack_traces_on_fatal_signal(True) + + rdv_iterator = torch.distributed.rendezvous( + init_method, + machine_idx * num_devices_per_machine + device_idx, + num_machines * num_devices_per_machine + num_machines, + ) + store, _, _ = next(rdv_iterator) + + assert 0 <= machine_idx < num_machines + assert 0 <= device_idx < num_devices_per_machine + + os.environ["CUDA_VISIBLE_DEVICES"] = f"{device_idx}" + + client = benchmark_herring_tcp.Client( + machine_idx=machine_idx, + device_idx=device_idx, + num_machines=num_machines, + num_devices_per_machine=num_devices_per_machine, + num_buckets=num_buckets, + bucket_size=bucket_size, + num_epochs=num_epochs, + store=store, + ) + stats = client.run() + conn.send( + ClientStats( + transfer_time=[ + [[ss.transfer_time for ss in bs.servers] for bs in es.buckets] + for es in stats.epochs + ], + nccl_reduce_time=[ + [bs.nccl_reduce_time for bs in es.buckets] for es in stats.epochs + ], + nccl_broadcast_time=[ + [bs.nccl_broadcast_time for bs in es.buckets] for es in stats.epochs + ], + end_to_end_time=[es.end_to_end_time for es in stats.epochs], + ) + ) + + +def run_one_machine_herring( + init_method: str, + machine_idx: int, + num_machines: int, + num_devices_per_machine: int, + num_buckets: int, + bucket_size: int, + num_epochs: int, + num_compute_threads: int, + num_network_threads: int, +) -> OneMachineHerringStats: + server_receiving_conn, server_sending_conn = multiprocessing.Pipe() + server = multiprocessing.Process( + target=run_herring_server, + name=f"server_{machine_idx}", + args=( + init_method, + machine_idx, + num_machines, + num_devices_per_machine, + num_buckets, + bucket_size, + num_epochs, + num_compute_threads, + server_sending_conn, + ), + ) + + client_receiving_conns = [] + client_sending_conns = [] + for _ in range(num_devices_per_machine): + recv_end, send_end = multiprocessing.Pipe() + client_receiving_conns.append(recv_end) + client_sending_conns.append(send_end) + clients = [ + multiprocessing.Process( + target=run_herring_client, + name=f"client_{machine_idx}_{device_idx}", + args=( + init_method, + machine_idx, + device_idx, + num_machines, + num_devices_per_machine, + num_buckets, + bucket_size, + num_epochs, + client_sending_conns[device_idx], + ), + ) + for device_idx in range(num_devices_per_machine) + ] + for t in [server] + clients: + t.start() + for c in [server_sending_conn] + client_sending_conns: + c.close() + + stats = recv_from_connections_and_join_processes( + [(server, server_receiving_conn)] + list(zip(clients, client_receiving_conns)) + ) + server_stats = stats[0] + client_stats = stats[1:] + + return OneMachineHerringStats( + addition_time=torch.tensor( + server_stats.addition_time, + dtype=torch.long, + ), + recv_to_send_time=torch.tensor( + server_stats.recv_to_send_time, + dtype=torch.long, + ), + transfer_time=torch.tensor( + [s.transfer_time for s in client_stats], + dtype=torch.long, + ), + nccl_reduce_time=torch.tensor( + [s.nccl_reduce_time for s in client_stats], + dtype=torch.long, + ), + nccl_broadcast_time=torch.tensor( + [s.nccl_broadcast_time for s in client_stats], + dtype=torch.long, + ), + 
end_to_end_time=torch.tensor( + [s.end_to_end_time for s in client_stats], + dtype=torch.long, + ), + ) + + +def main(): + parser = argparse.ArgumentParser(description="NCCL allreduce benchmark") + parser.add_argument( + "--init-method", + type=str, + default="env://", + help="How to do rendezvous between machines (uses PyTorch, hence see its doc)", + ) + parser.add_argument( + "--machine-idx", + type=int, + required=True, + help="The rank of the machine on which this script was invoked (0-based)", + ) + parser.add_argument( + "--num-machines", + type=int, + required=True, + help="On how many machines this script is being invoked (each with its own rank)", + ) + parser.add_argument( + "--num-devices-per-machine", + type=int, + required=True, + help="How many clients this script should launch (each will use one GPU)", + ) + parser.add_argument( + "--num-buckets", + type=int, + required=True, + help="How many buffers to do an allreduce over in each epoch", + ) + parser.add_argument( + "--bucket-size", + type=int, + required=True, + help="How big each buffer should be (expressed in number of float32 elements)", + ) + parser.add_argument( + "--num-epochs", + type=int, + required=True, + help="How many times to run the benchmark", + ) + parser.add_argument( + "--num-compute-threads", + type=int, + required=True, + help="How many threads to use to calculate reductions on the servers", + ) + parser.add_argument( + "--num-network-threads", + type=int, + required=True, + help="How many TCP event loop threads to use (to multiplex and saturate bandwidth)", + ) + parser.add_argument( + "--output", + type=argparse.FileType("wb"), + default=sys.stdout.buffer, + ) + + args = parser.parse_args() + + res = run_one_machine_herring( + init_method=args.init_method, + machine_idx=args.machine_idx, + num_machines=args.num_machines, + num_devices_per_machine=args.num_devices_per_machine, + num_buckets=args.num_buckets, + bucket_size=args.bucket_size, + num_epochs=args.num_epochs, + num_compute_threads=args.num_compute_threads, + num_network_threads=args.num_network_threads, + ) + + torch.save(res, args.output) + + +if __name__ == "__main__": + main() diff --git a/tensorpipe/benchmark/herring/launch_nccl.py b/tensorpipe/benchmark/herring/launch_nccl.py new file mode 100644 index 000000000..4d1a14f0f --- /dev/null +++ b/tensorpipe/benchmark/herring/launch_nccl.py @@ -0,0 +1,192 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
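+
+# Launcher for the plain NCCL allreduce baseline. Unlike the Herring variants
+# it spawns only one client process per GPU on this machine (no separate
+# server processes), optionally sets NCCL_SOCKET_NTHREADS and
+# NCCL_NSOCKS_PERTHREAD for those clients, and saves the per-epoch timings
+# with torch.save to --output.
+#
+# Example invocation on machine 0 of a 2-machine job (flag values are
+# illustrative only):
+#
+#   launch_nccl --machine-idx 0 --num-machines 2 --num-devices-per-machine 8 \
+#       --num-buckets 25 --bucket-size 1000000 --num-epochs 10 \
+#       --output stats_machine0.pt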
+ +import argparse +import multiprocessing +import os +import sys +from typing import List, Optional + +import torch +import torch.distributed +from utils import recv_from_connections_and_join_processes + +# Must come after torch or else it will fail because it won't find libc10.so +import benchmark_nccl # isort: skip + + +def run_nccl_client( + init_method: str, + machine_idx: int, + device_idx: int, + num_machines: int, + num_devices_per_machine: int, + num_buckets: int, + bucket_size: int, + num_epochs: int, + num_network_threads: Optional[int], + num_sockets_per_network_thread: Optional[int], + conn: multiprocessing.connection.Connection, +) -> List[int]: + torch._C._set_print_stack_traces_on_fatal_signal(True) + + rdv_iterator = torch.distributed.rendezvous( + init_method, + machine_idx * num_devices_per_machine + device_idx, + num_machines * num_devices_per_machine, + ) + store, _, _ = next(rdv_iterator) + + assert 0 <= machine_idx < num_machines + assert 0 <= device_idx < num_devices_per_machine + + os.environ["CUDA_VISIBLE_DEVICES"] = f"{device_idx}" + if num_network_threads is not None: + os.environ["NCCL_SOCKET_NTHREADS"] = f"{num_network_threads}" + if num_sockets_per_network_thread is not None: + os.environ["NCCL_NSOCKS_PERTHREAD"] = f"{num_sockets_per_network_thread}" + + client = benchmark_nccl.Client( + machine_idx=machine_idx, + device_idx=device_idx, + num_machines=num_machines, + num_devices_per_machine=num_devices_per_machine, + num_buckets=num_buckets, + bucket_size=bucket_size, + num_epochs=num_epochs, + store=store, + ) + conn.send(client.run()) + + +def run_one_machine_nccl( + init_method: str, + machine_idx: int, + num_machines: int, + num_devices_per_machine: int, + num_buckets: int, + bucket_size: int, + num_epochs: int, + num_network_threads: Optional[int], + num_sockets_per_network_thread: Optional[int], +) -> torch.Tensor: + receiving_conns = [] + sending_conns = [] + for _ in range(num_devices_per_machine): + recv_end, send_end = multiprocessing.Pipe() + receiving_conns.append(recv_end) + sending_conns.append(send_end) + clients = [ + multiprocessing.Process( + target=run_nccl_client, + name=f"client_{machine_idx}_{device_idx}", + args=( + init_method, + machine_idx, + device_idx, + num_machines, + num_devices_per_machine, + num_buckets, + bucket_size, + num_epochs, + num_network_threads, + num_sockets_per_network_thread, + sending_conns[device_idx], + ), + ) + for device_idx in range(num_devices_per_machine) + ] + for t in clients: + t.start() + for c in sending_conns: + c.close() + + stats = recv_from_connections_and_join_processes( + list(zip(clients, receiving_conns)) + ) + + return torch.tensor(stats, dtype=torch.long) + + +def main(): + parser = argparse.ArgumentParser(description="NCCL allreduce benchmark") + parser.add_argument( + "--init-method", + type=str, + default="env://", + help="How to do rendezvous between machines (uses PyTorch, hence see its doc)", + ) + parser.add_argument( + "--machine-idx", + type=int, + required=True, + help="The rank of the machine on which this script was invoked (0-based)", + ) + parser.add_argument( + "--num-machines", + type=int, + required=True, + help="On how many machines this script is being invoked (each with its own rank)", + ) + parser.add_argument( + "--num-devices-per-machine", + type=int, + required=True, + help="How many clients this script should launch (each will use one GPU)", + ) + parser.add_argument( + "--num-buckets", + type=int, + required=True, + help="How many buffers to do an allreduce over in 
each epoch", + ) + parser.add_argument( + "--bucket-size", + type=int, + required=True, + help="How big each buffer should be (expressed in number of float32 elements)", + ) + parser.add_argument( + "--num-epochs", + type=int, + required=True, + help="How many times to run the benchmark", + ) + parser.add_argument( + "--num-network-threads", + type=int, + help="The value of the NCCL_SOCKET_NTHREADS env var (see NCCL's doc)", + ) + parser.add_argument( + "--num-sockets-per-network-thread", + type=int, + help="The value of the NCCL_NSOCKS_PERTHREAD env var (see NCCL's doc)", + ) + parser.add_argument( + "--output", + type=argparse.FileType("wb"), + default=sys.stdout.buffer, + ) + + args = parser.parse_args() + + res = run_one_machine_nccl( + init_method=args.init_method, + machine_idx=args.machine_idx, + num_machines=args.num_machines, + num_devices_per_machine=args.num_devices_per_machine, + num_buckets=args.num_buckets, + bucket_size=args.bucket_size, + num_epochs=args.num_epochs, + num_network_threads=args.num_network_threads, + num_sockets_per_network_thread=args.num_sockets_per_network_thread, + ) + + torch.save(res, args.output) + + +if __name__ == "__main__": + main() diff --git a/tensorpipe/benchmark/herring/setup.py b/tensorpipe/benchmark/herring/setup.py new file mode 100644 index 000000000..deab4475e --- /dev/null +++ b/tensorpipe/benchmark/herring/setup.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from setuptools import setup +from torch.utils import cpp_extension + +setup( + name="herring", + ext_modules=[ + cpp_extension.CUDAExtension( + "benchmark_herring_gdr", + [ + "benchmark_herring_gdr.cc", + "cuda_kernels.cu", + ], + libraries=["nccl", "tensorpipe", "tensorpipe_cuda"], + ), + cpp_extension.CUDAExtension( + "benchmark_herring_tcp", + [ + "benchmark_herring_tcp.cc", + ], + libraries=["nccl", "tensorpipe"], + ), + cpp_extension.CUDAExtension( + "benchmark_nccl", + [ + "benchmark_nccl.cc", + ], + libraries=["nccl"], + ), + ], + py_modules=[ + "launch_herring_gdr", + "launch_herring_tcp", + "launch_nccl", + "utils", + ], + cmdclass={"build_ext": cpp_extension.BuildExtension}, + entry_points={ + "console_scripts": [ + "launch_herring_gdr=launch_herring_gdr:main", + "launch_herring_tcp=launch_herring_tcp:main", + "launch_nccl=launch_nccl:main", + ], + }, + setup_requires=["setuptools", "torch"], + install_requires=["torch"], +) diff --git a/tensorpipe/benchmark/herring/utils.py b/tensorpipe/benchmark/herring/utils.py new file mode 100644 index 000000000..cf6e69bfc --- /dev/null +++ b/tensorpipe/benchmark/herring/utils.py @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
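+
+# Shared helper module: launch_herring_gdr.py, launch_herring_tcp.py and
+# launch_nccl.py all use recv_from_connections_and_join_processes below to
+# gather one result per worker process and to kill the remaining workers if
+# any process exits without reporting.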
+ +import multiprocessing +from typing import Any, List, Tuple + + +def recv_from_connections_and_join_processes( + processes_and_connections: List[ + Tuple[multiprocessing.Process, multiprocessing.connection.Connection] + ], +) -> List[Any]: + """ + Wait for processes to return a value via a connection and then to terminate + + Given a list of processes and, for each of them, (the reading end of) a + connection on which the process will send its result, gather the results of + all processes and then join them, with extra care taken to handle any error + (e.g., process crashing without returning) and kill all processes in case. + """ + results = [None] * len(processes_and_connections) + + try: + connections = [c for _, c in processes_and_connections] + sentinels = [p.sentinel for p, _ in processes_and_connections] + not_ready = connections + sentinels + while len(not_ready) > 0: + ready = multiprocessing.connection.wait(not_ready) + for obj in ready: + if obj in connections: + idx = connections.index(obj) + try: + val = obj.recv() + except EOFError: + # We won't get any more values out of this connection. + not_ready.remove(obj) + else: + if results[idx] is not None: + raise RuntimeError( + f"Process {idx} returned more than one value" + ) + # Wrap in a tuple so we can distinguish a process that + # returned None from one that didn't return yet. + results[idx] = (val,) + elif obj in sentinels: + idx = sentinels.index(obj) + proc, _ = processes_and_connections[idx] + proc.join() + if proc.exitcode != 0: + raise RuntimeError( + f"Process {idx} exited with status {proc.exitcode}" + ) + not_ready.remove(obj) + else: + raise RuntimeError(f"Unexpected object: {obj}") + except Exception: + for p, _ in processes_and_connections: + p.kill() + for p, _ in processes_and_connections: + p.join() + raise + + for idx, result in enumerate(results): + if result is None: + raise RuntimeError(f"Process {idx} exited without producing a result") + + # Unwrap from the tuples. + return [r for r, in results]
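
A minimal usage sketch for the helper above (not part of the patch; the worker
function and its squared-number payload are invented for illustration, and it
assumes utils.py is importable from the working directory):

    import multiprocessing

    from utils import recv_from_connections_and_join_processes

    def worker(value, conn):
        # Each worker reports exactly one result over its pipe and then exits.
        conn.send(value * value)

    if __name__ == "__main__":
        processes_and_connections = []
        for value in range(4):
            recv_end, send_end = multiprocessing.Pipe()
            proc = multiprocessing.Process(target=worker, args=(value, send_end))
            proc.start()
            # Close the parent's copy of the sending end so the helper sees
            # EOF once the child has exited.
            send_end.close()
            processes_and_connections.append((proc, recv_end))

        results = recv_from_connections_and_join_processes(processes_and_connections)
        print(results)  # [0, 1, 4, 9]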