diff --git a/configs/debug/diloco.toml b/configs/debug/diloco.toml index c98e4603..654993b2 100644 --- a/configs/debug/diloco.toml +++ b/configs/debug/diloco.toml @@ -9,7 +9,7 @@ micro_bs = 8 [optim] batch_size = 16 warmup_steps = 10 -total_steps = 4 +total_steps = 128 [data] fake = true diff --git a/configs/debug/normal.toml b/configs/debug/normal.toml index cd64084c..3a63372a 100644 --- a/configs/debug/normal.toml +++ b/configs/debug/normal.toml @@ -9,7 +9,7 @@ micro_bs = 8 [optim] batch_size = 16 warmup_steps = 10 -total_steps = 4 +total_steps = 128 [data] fake = true diff --git a/pyproject.toml b/pyproject.toml index b8004493..83c0af8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "psutil", "torch-shampoo @ git+https://github.com/facebookresearch/optimizers.git@main", "liger-kernel-nightly>=0.5.2.dev20250122195349", + "pccl @ git+https://github.com/PrimeIntellect-ai/pccl.git@main#subdirectory=python/framework" ] [project.optional-dependencies] diff --git a/scripts/all_reduce.py b/scripts/all_reduce.py deleted file mode 100644 index 2d99b418..00000000 --- a/scripts/all_reduce.py +++ /dev/null @@ -1,69 +0,0 @@ -from pydantic_config import BaseConfig, parse_argv -import torch -from torch.distributed import destroy_process_group, init_process_group, ReduceOp -import torch.utils.benchmark as benchmark - -from zeroband.collectives import Compression, all_reduce -from zeroband.utils.world_info import get_world_info -from zeroband.utils.logger import get_logger - -from enum import Enum - - -class TorchDtype(str, Enum): - FLOAT32 = "float32" - FLOAT16 = "float16" - BFLOAT16 = "bfloat16" - UINT8 = "uint8" - - -TORCH_DTYPE_MAP = { - None: None, - TorchDtype.FLOAT32: torch.float32, - TorchDtype.FLOAT16: torch.float16, - TorchDtype.BFLOAT16: torch.bfloat16, - TorchDtype.UINT8: torch.uint8, -} - - -class Config(BaseConfig): - size_model: int = int(1e7) - n_iters: int = 4 - compression: Compression = Compression.NO - - -def main(config: Config): - world_info = get_world_info() - - mat = torch.rand(1, config.size_model) - - logger.info( - f"\n ======== Benchmark all reduce between {world_info.world_size} gpus over {world_info.nnodes} nodes =========\n" - ) - - t0 = benchmark.Timer( - stmt="compressed_all_reduce(compression, mat, op=op)", - globals={ - "compressed_all_reduce": all_reduce, - "mat": mat, - "compression": config.compression, - "op": ReduceOp.SUM, - }, - ) - - measured_time = t0.timeit(config.n_iters).mean - - bandwidth = config.size_model * 4 / 1e6 / measured_time - - logger.info(f"Average time per iteration: {measured_time:.2f} seconds, Average bandwidth: {bandwidth:.4f} MB/s") - - -if __name__ == "__main__": - config = Config(**parse_argv()) - - torch.set_float32_matmul_precision("high") - init_process_group(backend="gloo") - - logger = get_logger() - main(config) - destroy_process_group() diff --git a/scripts/simulate_multi_node_diloco.sh b/scripts/simulate_multi_node_diloco.sh index 38212900..af34cf1f 100755 --- a/scripts/simulate_multi_node_diloco.sh +++ b/scripts/simulate_multi_node_diloco.sh @@ -71,7 +71,7 @@ export GLOO_SOCKET_IFNAME=lo for i in $(seq 0 $(($N - 1 ))) do > logs/log$i.log - WANDB_MODE=$([ $i -eq 0 ] && echo "online" || echo "offline") GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((BASE_PORT + $i)) --nnodes=1 $@ --data.data_rank $i --data.data_world_size $N > logs/log$i.log 2>&1 & + WANDB_MODE=$([ $i -eq 0 
] && echo "online" || echo "offline") GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) OMP_NUM_THREADS=1 PCCL_LOG_LEVEL=DEBUG uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((BASE_PORT + $i)) --nnodes=1 $@ --data.data_rank $i --data.data_world_size $N > logs/log$i.log 2>&1 & child_pids+=($!) done diff --git a/scripts/skip_data.py b/scripts/skip_data.py index 2f2bc48a..28bd9be9 100644 --- a/scripts/skip_data.py +++ b/scripts/skip_data.py @@ -24,7 +24,7 @@ from zeroband.data import get_dataloader -from zeroband.utils.world_info import get_world_info +from zeroband.utils.world_info import get_local_world_info from zeroband.utils.logger import get_logger @@ -79,7 +79,7 @@ def skip_data(config: Config): if __name__ == "__main__": torch.manual_seed(42) - world_info = get_world_info() + world_info = get_local_world_info() logger = get_logger() config = Config(**parse_argv()) diff --git a/src/zeroband/C/__init__.py b/src/zeroband/C/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/zeroband/C/collectives.py b/src/zeroband/C/collectives.py deleted file mode 100644 index 8372d121..00000000 --- a/src/zeroband/C/collectives.py +++ /dev/null @@ -1,35 +0,0 @@ -import os -from typing import Optional -import torch -import torch.distributed as dist -from torch.utils import cpp_extension -from pathlib import Path -from torch.testing._internal.distributed.fake_pg import FakeProcessGroup - - -parent = Path(__file__).parent -INCLUDES = [str(parent / "csrc"), str(parent.parent.parent.parent / "third_party/gloo")] -COLLECTIVES_CSRC_PATH = parent / "csrc" / "collectives.cpp" - -collectives_ops = cpp_extension.load( - name="collectives", - sources=[COLLECTIVES_CSRC_PATH], - extra_cflags=["-O3", "-DUSE_C10D_GLOO"], - verbose=False if os.environ.get("ZERO_BAND_LOG_LEVEL") == "DEBUG" else True, - extra_include_paths=INCLUDES, -) - - -def ring_allreduce( - tensor: torch.Tensor, - op: dist.ReduceOp = dist.ReduceOp.SUM, - group: Optional[dist.ProcessGroup] = None, -) -> None: - if group is None: - group = dist.distributed_c10d._get_default_group() - if isinstance(group, dist.distributed_c10d.ProcessGroupGloo): - collectives_ops.ring_allreduce_gloo(tensor, op, group) - elif isinstance(group, FakeProcessGroup): - return - else: - collectives_ops.ring_allreduce(tensor, op, group) diff --git a/src/zeroband/C/compression.py b/src/zeroband/C/compression.py deleted file mode 100644 index f2e3cc21..00000000 --- a/src/zeroband/C/compression.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Tuple -import torch -from torch.utils.cpp_extension import load -from pathlib import Path - -COMPRESS_CSRC_PATH = Path(__file__).parent / "csrc" / "compression.cpp" - -compress_ops = load(name="compression", sources=[COMPRESS_CSRC_PATH], extra_cflags=["-O3"], verbose=False) - - -def uniform_8bit_quantize(tensor: torch.Tensor, inplace: bool = True) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize a tensor to 8-bit integers - Args: - tensor (torch.Tensor): The tensor to quantize - inplace (bool): Whether the operation is allowed to modify the input tensor - Returns: - Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the lookup table - """ - return compress_ops.uniform_8bit_quantize(tensor, inplace) - - -def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int) -> torch.Tensor: - """Return the average value in each bin - Args: - tensor (torch.Tensor): The tensor to average - quant_weight (torch.Tensor): 
The tensor of indices - n_bins (int): The number of bins - Returns: - torch.Tensor: The average value in each bin - """ - return compress_ops.average_buckets(tensor, quant_weight, n_bins) - - -def quantize_per_tensor_uint8(tensor: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor: - """Quantize a tensor to 8-bit integers - - quantized_value = clamp((round(input / scale) + zero_point), 0, 255) - - Args: - tensor (torch.Tensor): The tensor to quantize - scale (float): The scale of the quantization - zero_point (int): The zero point of the quantization - Returns: - torch.Tensor: The quantized tensor - """ - return compress_ops.quantize_per_tensor_uint8(tensor, scale, zero_point) diff --git a/src/zeroband/C/csrc/collectives.cpp b/src/zeroband/C/csrc/collectives.cpp deleted file mode 100644 index ab7777fc..00000000 --- a/src/zeroband/C/csrc/collectives.cpp +++ /dev/null @@ -1,249 +0,0 @@ -#include -#include -#include -#include - -constexpr int BUFFER_COUNT = 2; - -template -void fast_index_add_omp(T* output, const T* lookup_table, const uint8_t* indices, int64_t n) { - #pragma omp parallel for - for (int64_t i = 0; i < n; ++i) { - output[i] += lookup_table[indices[i]]; - } -} - -template -void fast_index_set_omp(T* output, const T* lookup_table, const uint8_t* indices, int64_t n) { - #pragma omp parallel for - for (int64_t i = 0; i < n; ++i) { - output[i] = lookup_table[indices[i]]; - } -} - -inline size_t get_num_threads() { - return std::max(1u, std::thread::hardware_concurrency()); -} - -template -void fast_index_add_worker(T* output, const T* lookup_table, const uint8_t* indices, int64_t start, int64_t end) { - for (int64_t i = start; i < end; ++i) { - output[i] += lookup_table[indices[i]]; - } -} - -template -void fast_index_add(T* output, const T* lookup_table, const uint8_t* indices, int64_t n) { - size_t num_threads = get_num_threads(); - std::vector threads; - int64_t chunk_size = n / num_threads; - - for (size_t i = 0; i < num_threads; ++i) { - int64_t start = i * chunk_size; - int64_t end = (i == num_threads - 1) ? n : (i + 1) * chunk_size; - threads.emplace_back(fast_index_add_worker, output, lookup_table, indices, start, end); - } - - for (auto& thread : threads) { - thread.join(); - } -} - -template -void fast_index_set_worker(T* output, const T* lookup_table, const uint8_t* indices, int64_t start, int64_t end) { - for (int64_t i = start; i < end; ++i) { - output[i] = lookup_table[indices[i]]; - } -} - -template -void fast_index_set(T* output, const T* lookup_table, const uint8_t* indices, int64_t n) { - size_t num_threads = get_num_threads(); - std::vector threads; - int64_t chunk_size = n / num_threads; - - for (size_t i = 0; i < num_threads; ++i) { - int64_t start = i * chunk_size; - int64_t end = (i == num_threads - 1) ? n : (i + 1) * chunk_size; - threads.emplace_back(fast_index_set_worker, output, lookup_table, indices, start, end); - } - - for (auto& thread : threads) { - thread.join(); - } -} - -template -void ring_allreduce( - torch::Tensor& tensor, - c10d::ReduceOp op, - T* group -) { - TORCH_CHECK(group != nullptr, "Group must be provided"); - TORCH_CHECK(op == c10d::ReduceOp::SUM || op == c10d::ReduceOp::AVG, "Unsupported reduce operation. 
Only SUM and AVG are supported."); - - int world_size = group->getSize(); - int rank = group->getRank(); - - // Divide the tensor into chunks - auto flat_tensor = tensor.view({tensor.numel()}); - std::vector chunks = flat_tensor.chunk(world_size * BUFFER_COUNT); - - // Temporary buffers for transferring data - int num_buffers = BUFFER_COUNT * world_size; - std::vector recv_buffer; - std::vector send_buffer; - std::vector send_lookup_buffer; - std::vector recv_lookup_buffer; - std::vector> send_lookup_work(BUFFER_COUNT); - std::vector> recv_lookup_work(BUFFER_COUNT); - std::vector> send_work(BUFFER_COUNT); - std::vector> recv_work(BUFFER_COUNT); - - for (int i = 0; i < BUFFER_COUNT; ++i) { - recv_buffer.push_back(torch::empty_like(chunks[0], torch::kUInt8)); - send_buffer.push_back(torch::Tensor()); - send_lookup_buffer.push_back(torch::Tensor()); - recv_lookup_buffer.push_back(torch::empty({256}, chunks[0].options())); - } - - // Send and receive ranks - int send_rank = (rank + 1) % world_size; - int recv_rank = (rank - 1 + world_size) % world_size; - - // Reduce-scatter loop - for (int step = 1; step <= world_size * BUFFER_COUNT; ++step) { - int send_chunk = (rank * BUFFER_COUNT - step + num_buffers) % num_buffers; - - if (send_work[step % BUFFER_COUNT]) { - send_work[step % BUFFER_COUNT]->wait(); - recv_work[step % BUFFER_COUNT]->wait(); - send_lookup_work[step % BUFFER_COUNT]->wait(); - recv_lookup_work[step % BUFFER_COUNT]->wait(); - - auto& chunk = chunks[send_chunk]; - auto& lookup = recv_lookup_buffer[step % BUFFER_COUNT]; - auto& indices = recv_buffer[step % BUFFER_COUNT]; - - fast_index_add_omp( - static_cast(chunk.data_ptr()), - static_cast(lookup.data_ptr()), - static_cast(indices.data_ptr()), - chunk.numel() - ); - } - - if (step <= (world_size - 1) * BUFFER_COUNT) { - // Quantize and send - std::tie(send_buffer[step % BUFFER_COUNT], send_lookup_buffer[step % BUFFER_COUNT]) = uniform_8bit_quantize(chunks[send_chunk], false); - - std::vector send_tensors = {send_lookup_buffer[step % BUFFER_COUNT]}; - send_lookup_work[step % BUFFER_COUNT] = group->send(send_tensors, send_rank, step + 1000); - - std::vector recv_tensors = {recv_lookup_buffer[step % BUFFER_COUNT]}; - recv_lookup_work[step % BUFFER_COUNT] = group->recv(recv_tensors, recv_rank, step + 1000); - - send_tensors = {send_buffer[step % BUFFER_COUNT]}; - send_work[step % BUFFER_COUNT] = group->send(send_tensors, send_rank, step); - - recv_tensors = {recv_buffer[step % BUFFER_COUNT]}; - recv_work[step % BUFFER_COUNT] = group->recv(recv_tensors, recv_rank, step); - } - } - - // TODO: Interleave these with the previous loop? 
- if (op == c10d::ReduceOp::AVG) { - for (int i = 0; i < BUFFER_COUNT; ++i) { - chunks[i + rank * BUFFER_COUNT].div_(world_size); - } - } - - for (int i = 0; i < BUFFER_COUNT; ++i) { - std::tie(send_buffer[0], send_lookup_buffer[0]) = uniform_8bit_quantize(chunks[i + rank * BUFFER_COUNT], true); - auto& chunk = chunks[i + rank * BUFFER_COUNT]; - auto& lookup = send_lookup_buffer[0]; - auto& indices = send_buffer[0]; - - fast_index_set_omp( - static_cast(chunk.data_ptr()), - static_cast(lookup.data_ptr()), - static_cast(indices.data_ptr()), - chunk.numel() - ); - } - - // Reset buffers for the second phase - recv_buffer.clear(); - send_buffer.clear(); - send_lookup_buffer.clear(); - recv_lookup_buffer.clear(); - for (int i = 0; i < BUFFER_COUNT; ++i) { - recv_buffer.push_back(torch::empty_like(chunks[0], torch::kUInt8)); - send_buffer.push_back(torch::Tensor()); - send_lookup_buffer.push_back(torch::Tensor()); - recv_lookup_buffer.push_back(torch::empty({256}, chunks[0].options())); - } - std::fill(send_work.begin(), send_work.end(), nullptr); - std::fill(recv_work.begin(), recv_work.end(), nullptr); - std::fill(send_lookup_work.begin(), send_lookup_work.end(), nullptr); - std::fill(recv_lookup_work.begin(), recv_lookup_work.end(), nullptr); - - for (int step = 1; step <= world_size * BUFFER_COUNT; ++step) { - int send_chunk = (rank * BUFFER_COUNT + BUFFER_COUNT - step + num_buffers) % num_buffers; - - if (send_work[step % BUFFER_COUNT]) { - send_work[step % BUFFER_COUNT]->wait(); - recv_work[step % BUFFER_COUNT]->wait(); - send_lookup_work[step % BUFFER_COUNT]->wait(); - recv_lookup_work[step % BUFFER_COUNT]->wait(); - - auto& chunk = chunks[send_chunk]; - auto& lookup = recv_lookup_buffer[step % BUFFER_COUNT]; - auto& indices = recv_buffer[step % BUFFER_COUNT]; - - fast_index_set_omp( - static_cast(chunk.data_ptr()), - static_cast(lookup.data_ptr()), - static_cast(indices.data_ptr()), - chunk.numel() - ); - } - - if (step <= (world_size - 1) * BUFFER_COUNT) { - // Quantize and send - // todo(jackmin): this quantization is redundant, we should be able to reuse the quantized values we just received - std::tie(send_buffer[step % BUFFER_COUNT], send_lookup_buffer[step % BUFFER_COUNT]) = uniform_8bit_quantize(chunks[send_chunk], false); - - std::vector send_tensors = {send_lookup_buffer[step % BUFFER_COUNT]}; - send_lookup_work[step % BUFFER_COUNT] = group->send(send_tensors, send_rank, step + 1000); - - std::vector recv_tensors = {recv_lookup_buffer[step % BUFFER_COUNT]}; - recv_lookup_work[step % BUFFER_COUNT] = group->recv(recv_tensors, recv_rank, step + 1000); - - send_tensors = {send_buffer[step % BUFFER_COUNT]}; - send_work[step % BUFFER_COUNT] = group->send(send_tensors, send_rank, step); - - recv_tensors = {recv_buffer[step % BUFFER_COUNT]}; - recv_work[step % BUFFER_COUNT] = group->recv(recv_tensors, recv_rank, step); - } - } -} - -PYBIND11_MODULE(collectives, m) { - m.def( - "ring_allreduce", - &ring_allreduce, - "Ring allreduce implementation", - py::arg("tensor"), - py::arg("op"), - py::arg("pg") - ); - m.def( - "ring_allreduce_gloo", - &ring_allreduce, - "Ring allreduce implementation", - py::arg("tensor"), - py::arg("op"), - py::arg("pg") - ); -} \ No newline at end of file diff --git a/src/zeroband/C/csrc/compression.cpp b/src/zeroband/C/csrc/compression.cpp deleted file mode 100644 index 8bd7dcbd..00000000 --- a/src/zeroband/C/csrc/compression.cpp +++ /dev/null @@ -1,155 +0,0 @@ -#include - -namespace py = pybind11; - -constexpr int n_bins = 256; // 8-bit quantization 
-constexpr double RANGE_IN_SIGMAS = 6.0; -const int max_num_threads = std::thread::hardware_concurrency(); - -torch::Tensor quantize_per_tensor_multithreaded(const torch::Tensor& tensor, float scale, int32_t zero_point, int num_threads) { - torch::TensorOptions options = tensor.options().dtype(torch::kByte); - torch::Tensor quantized_tensor = torch::empty_like(tensor, options); - - float* tensor_data = tensor.data_ptr(); - uint8_t* quant_data = quantized_tensor.data_ptr(); - int64_t numel = tensor.numel(); - float inv_scale = 1.0f / scale; - - std::vector threads; - int64_t chunk_size = numel / num_threads; - - auto quantize_chunk = [&](int64_t start, int64_t end) { - for (int64_t i = start; i < end; ++i) { - int32_t quant_val = static_cast(std::round(tensor_data[i] * inv_scale)) + zero_point; - quant_data[i] = static_cast(std::clamp(quant_val, 0, 255)); - } - }; - - for (int i = 0; i < num_threads - 1; ++i) { - int64_t start = i * chunk_size; - int64_t end = (i + 1) * chunk_size; - threads.emplace_back(quantize_chunk, start, end); - } - - // Handle the last chunk (which may be slightly larger due to rounding) - threads.emplace_back(quantize_chunk, (num_threads - 1) * chunk_size, numel); - - // Wait for all threads to complete - for (auto& thread : threads) { - thread.join(); - } - - return quantized_tensor; -} - -torch::Tensor average_buckets_multithread(const torch::Tensor& tensor, const torch::Tensor& quant_weight, int64_t n_bins, int num_threads) { - torch::NoGradGuard no_grad; - auto flat_tensor = tensor.flatten().contiguous(); - auto flat_quant_weight = quant_weight.flatten().contiguous(); - auto options = flat_tensor.options(); - auto bin_sums = torch::zeros({n_bins}, options); - auto bin_counts = torch::zeros({n_bins}, options.dtype(torch::kLong)); - - // Get raw pointers - float* tensor_data = flat_tensor.data_ptr(); - uint8_t* quant_data = flat_quant_weight.data_ptr(); - float* sums_data = bin_sums.data_ptr(); - int64_t* counts_data = bin_counts.data_ptr(); - int64_t numel = flat_tensor.numel(); - - // Create a vector to hold our threads - std::vector threads; - - // Lambda function for the work each thread will do - auto worker = [&](int64_t start, int64_t end) { - std::vector local_sums(n_bins, 0.0f); - std::vector local_counts(n_bins, 0); - - for (int64_t i = start; i < end; ++i) { - uint8_t bin = quant_data[i]; - if (bin < n_bins) { // No need to check for >= 0 as uint8_t is always non-negative - local_sums[bin] += tensor_data[i]; - local_counts[bin]++; - } - } - - // Use a mutex to safely update the shared data - static std::mutex mutex; - std::lock_guard lock(mutex); - for (int64_t i = 0; i < n_bins; ++i) { - sums_data[i] += local_sums[i]; - counts_data[i] += local_counts[i]; - } - }; - - // Divide the work among threads - int64_t chunk_size = numel / num_threads; - for (unsigned int i = 0; i < num_threads; ++i) { - int64_t start = i * chunk_size; - int64_t end = (i == num_threads - 1) ? numel : (i + 1) * chunk_size; - threads.emplace_back(worker, start, end); - } - - // Wait for all threads to complete - for (auto& thread : threads) { - thread.join(); - } - - // Compute averages - for (int64_t i = 0; i < n_bins; ++i) { - sums_data[i] = counts_data[i] > 0 ? 
sums_data[i] / counts_data[i] : 0.0f; - } - - return bin_sums; -} - -std::tuple uniform_8bit_quantize(torch::Tensor tensor, bool inplace) { - int offset = n_bins / 2; - - // Centered tensor handling (currently commented out, so no centering) - torch::Tensor centered_tensor = tensor; - - // Calculate unbiased standard deviation - double std_unbiased = centered_tensor.norm().item() / std::sqrt(centered_tensor.numel() - 1); - - // Calculate scale for quantization - double scale = RANGE_IN_SIGMAS * std_unbiased / n_bins; - - // Perform quantization - torch::Tensor quantized_tensor = quantize_per_tensor_multithreaded(centered_tensor, scale, offset, max_num_threads); - - // Call average_buckets to create the lookup table - torch::Tensor lookup = average_buckets_multithread(tensor, quantized_tensor, n_bins, max_num_threads); - - return std::make_tuple(quantized_tensor, lookup); -} - - -// PyBind11 module -PYBIND11_MODULE(compression, m) { - m.def( - "average_buckets", - &average_buckets_multithread, - "Average buckets for quantized values", - py::arg("tensor"), - py::arg("quant_weight"), - py::arg("n_bins"), - py::arg("num_threads") = max_num_threads - ) - .def( - "uniform_8bit_quantize", - &uniform_8bit_quantize, - "Uniform 8-bit quantization function", - py::arg("tensor"), - py::arg("inplace") = true - ) - .def( - "quantize_per_tensor_uint8", - &quantize_per_tensor_multithreaded, - "Faster torch::quantize_per_tensor", - py::arg("tensor"), - py::arg("scale"), - py::arg("zero_point"), - py::arg("num_threads") = max_num_threads - ); -} diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index bdeb4d48..c7b28e0c 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -29,33 +29,31 @@ import warnings import logging from torch.distributed._tensor.api import DTensor -from zeroband.utils.state_dict_send_recv import ( - _get_sendable_state_dict, - recv_state_dict, - send_state_dict, - send_tensor_and_state_dict, -) from distributed_shampoo import DistributedShampoo from zeroband.utils.logger import get_logger from zeroband.config import CkptConfig -from zeroband.utils.world_info import get_world_info +from zeroband.utils.world_info import get_local_world_info ## code inspired by torchtitan https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py +# Note: this is grandfathered in code from before PCCL. +# The veracity of this state is not enforced by PCCL. 
+# This step count may diverge and should not be relied upon for critical operations, +# especially if it entails branching logic, which WILL deadlock your code @dataclass class TrainingProgress(Stateful): - total_tokens: int - outer_step: int - step: int + num_trained_tokens: int + num_performed_outer_steps: int + num_performed_inner_steps: int def state_dict(self) -> dict[str, Any]: - return {"total_tokens": self.total_tokens, "outer_step": self.outer_step, "step": self.step} + return {"num_trained_tokens": self.num_trained_tokens, "num_performed_outer_steps": self.num_performed_outer_steps, "num_performed_inner_steps": self.num_performed_inner_steps} def load_state_dict(self, state_dict: dict[str, Any]) -> None: - self.total_tokens = state_dict["total_tokens"] - self.outer_step = state_dict["outer_step"] - self.step = state_dict["step"] + self.num_trained_tokens = state_dict["num_trained_tokens"] + self.num_performed_outer_steps = state_dict["num_performed_outer_steps"] + self.num_performed_inner_steps = state_dict["num_performed_inner_steps"] class ModelWrapper(Stateful): @@ -202,13 +200,13 @@ def __init__( self._init_state() self._logger = get_logger(config) - self.world_info = get_world_info() + self.local_world_info = get_local_world_info() self.non_blocking_process: list[multiprocessing.Process] = [] self.blocking_process: list[multiprocessing.Process] = [] self._live_reco_thread: threading.Thread | None = None - if self.world_info.local_rank == 0: + if self.local_world_info.local_rank == 0: if self.config.path is not None: self.check_path_access(self.config.path) @@ -262,10 +260,10 @@ def save(self, remote: bool = False) -> None: """ - step_ckpt_path = os.path.join(self.config.path, f"step_{self.training_progress.step}") + step_ckpt_path = os.path.join(self.config.path, f"step_{self.training_progress.num_performed_inner_steps}") if remote and self.config.remote is not None: - remote_ckpt_path = os.path.join(self.config.remote.path, f"step_{self.training_progress.step}") + remote_ckpt_path = os.path.join(self.config.remote.path, f"step_{self.training_progress.num_performed_inner_steps}") # if we are not in self recovery mode we save to disk time_start = time.perf_counter() @@ -274,7 +272,7 @@ def save(self, remote: bool = False) -> None: # push to remote non_error_barrier() - if self.world_info.local_rank == 0: + if self.local_world_info.local_rank == 0: if remote and self.config.remote is not None: self._async_save_remote(step_ckpt_path, remote_ckpt_path) @@ -293,20 +291,20 @@ def _save(self, ckpt_path: str): dcp.save(self.states, checkpoint_id=ckpt_path) if self.diloco_offloaded_optimizer: - with open(os.path.join(ckpt_path, f"__{self.world_info.local_rank}_0.pt"), "wb") as f: + with open(os.path.join(ckpt_path, f"__{self.local_world_info.local_rank}_0.pt"), "wb") as f: state = {} state["optimizer"] = OuterOptimizerWrapper(self.diloco_offloaded_optimizer).state_dict() torch.save(state, f) data_path = os.path.join(ckpt_path, "data") - self.save_data(data_path, self.dataloader, self.world_info.local_rank) + self.save_data(data_path, self.dataloader, self.local_world_info.local_rank) non_error_barrier() if self.config.remote_data_path is not None: remote_data_path = os.path.join( - self.config.remote_data_path, f"data_{self.data_rank}", f"step_{self.training_progress.step}" + self.config.remote_data_path, f"data_{self.data_rank}", f"step_{self.training_progress.num_performed_inner_steps}" ) latest_remote_data_path = os.path.join(self.config.remote_data_path, f"data_{self.data_rank}", "latest") @@ 
-352,7 +350,7 @@ def wait_for_blocking_job(self): self.blocking_process = [] - if self.world_info.local_rank == 0: + if self.local_world_info.local_rank == 0: if self.config.topk is not None: delete_topk(self.logger, self.config.path, self.config.topk) @@ -365,7 +363,7 @@ def _del__(self): @torch.no_grad() def _load_data(self, resume_ckpt_path: str): self._logger.debug(f"loading data from {resume_ckpt_path}") - world_info = get_world_info() + world_info = get_local_world_info() data_path = os.path.join(resume_ckpt_path, "data") @@ -392,7 +390,7 @@ def load( """ time_start = time.perf_counter() - world_info = get_world_info() + world_info = get_local_world_info() files = os.listdir(resume_ckpt_path) @@ -405,7 +403,7 @@ def load( dcp.load(self.states, checkpoint_id=resume_ckpt_path) if self.config.token_count is not None: - self.training_progress.total_tokens = self.config.token_count + self.training_progress.num_trained_tokens = self.config.token_count self._logger.debug("sync inner model") # todo(refactor): here we should rather let the diloco class handle this logic @@ -439,115 +437,6 @@ def remote_data_load(self): data_path = dest self._load_data(data_path) - @torch.no_grad() - def recv_ckpt_from_peer(self, global_pg: dist.ProcessGroup): - assert self.diloco_offloaded_param_list is not None, "recv_ckpt_from_peers is only supported with diloco" - - time_start = time.perf_counter() - self._logger.debug(f"Start receiving ckpt from rank {self.config.live_recovery_rank_src}") - - jobs = [] - buffers = [] - for i, param in enumerate(self.diloco_offloaded_param_list): - data = param.data - if isinstance(param.data, DTensor): - data = param.data.to_local() - - buffer = torch.empty_like(data) - buffers.append(buffer) - jobs.append(global_pg.recv([buffer], self.config.live_recovery_rank_src, i)) - - for job in jobs: - job.wait() - - for buffer, param in zip(buffers, self.model.parameters()): - data = param.data - if isinstance(data, DTensor): - data = data.to_local() - data.copy_(buffer) - - self._logger.debug("live recovery progress: offloaded model received 1/5") - - outer_opt_state_dict = recv_state_dict( - global_pg, self.config.live_recovery_rank_src, self.diloco_offloaded_optimizer.state_dict() - ) - self.diloco_offloaded_optimizer.load_state_dict(outer_opt_state_dict) - - self._logger.debug("live recovery progress: outer optimizer state dict received 2/5") - - training_process_state_dict = recv_state_dict( - global_pg, self.config.live_recovery_rank_src, self.training_progress.state_dict() - ) - self.training_progress.load_state_dict(training_process_state_dict) - self._logger.debug("live recovery progress: training progress state dict received 3/5") - - for group in self.optimizer.param_groups: - for p in group["params"]: - p.grad = torch.randn_like(p) - - self.optimizer.step() - self.optimizer.zero_grad() - - inner_opt_state_dict = recv_state_dict( - global_pg, self.config.live_recovery_rank_src, self.optimizer.state_dict() - ) - self.optimizer.load_state_dict(inner_opt_state_dict) - - self._logger.debug("live recovery progress: inner optimizer state dict received 4/5") - - sheduler_state_dict = recv_state_dict( - global_pg, self.config.live_recovery_rank_src, self.scheduler.state_dict() - ) - self.scheduler.load_state_dict(sheduler_state_dict) - - self._logger.debug("live recovery progress: scheduler state dict received 5/5") - - self._logger.debug( - f"Received ckpt from rank {self.config.live_recovery_rank_src} in {time.perf_counter() - time_start} seconds" - ) - - @torch.no_grad() - 
def send_ckpt_to_peer(self, global_pg: dist.ProcessGroup, dest_rank: int, blocking: bool = False): - def async_send(): - assert self.diloco_offloaded_param_list is not None, "send_ckpt_to_peers is only supported with diloco" - time_start = time.perf_counter() - self._logger.debug(f"Start sending ckpt to rank {dest_rank}") - - try: - jobs = [] - for i, param in enumerate(self.diloco_offloaded_param_list): - data = param.data - if isinstance(data, DTensor): - data = data.to_local() - jobs.append(global_pg.send([data], dest_rank, i)) - - for job in jobs: - job.wait() - - send_state_dict(global_pg, self.diloco_offloaded_optimizer.state_dict(), dest_rank) - send_state_dict(global_pg, self.training_progress.state_dict(), dest_rank) - - inner_optimizer_non_tensor_state_dict, inner_optimizer_tensors = _get_sendable_state_dict( - self.optimizer.state_dict() - ) - send_tensor_and_state_dict( - global_pg, dest_rank, inner_optimizer_non_tensor_state_dict, inner_optimizer_tensors - ) - - send_state_dict(global_pg, self.scheduler.state_dict(), dest_rank) - except RuntimeError as e: - self._logger.error(f"Error sending ckpt to rank {dest_rank}: {e}") - else: - self._logger.debug(f"Sent ckpt to rank {dest_rank} in {time.perf_counter() - time_start} seconds") - - thread = threading.Thread(target=async_send) - thread.start() - self._logger.debug("Live recovery thread started") - if blocking: - thread.join() - else: - self._live_reco_thread = thread - def delete_topk(logger: logging.Logger, ckpt_path: str, topk: int): checkpoints_to_delete = get_checkpoints_to_delete(ckpt_path, topk) diff --git a/src/zeroband/collectives.py b/src/zeroband/collectives.py deleted file mode 100644 index f9f6d47c..00000000 --- a/src/zeroband/collectives.py +++ /dev/null @@ -1,192 +0,0 @@ -from typing import Callable, Optional, TypeAlias -import torch -import torch.distributed as dist - -from zeroband.config import Compression - -AllReduceFunc: TypeAlias = Callable[ - [torch.Tensor, dist.ReduceOp, Optional[dist.ProcessGroup], Optional[torch.dtype]], None -] - - -def gloo_all_reduce( - tensor: torch.Tensor, - op: dist.ReduceOp = dist.ReduceOp.SUM, # type: ignore (defined weird) - group: Optional[dist.ProcessGroup] = None, -) -> None: - """Wrap gloo all reduce""" - if group is None: - group = dist.distributed_c10d._get_default_group() - if op not in [dist.ReduceOp.SUM, dist.ReduceOp.AVG]: - raise ValueError(f"Unsupported reduce operation {op}. 
Only SUM and AVG are supported.") - - # group = cast(dist.ProcessGroup, group) # just type hint stuff for IDE - if op == dist.ReduceOp.AVG: - # todo check numerical stability of doing post or pre div - tensor.div_(group.size()) - - dist.all_reduce(tensor, op, group=group) - - -def all_reduce( - compression: Compression, - tensor: torch.Tensor, - op: dist.ReduceOp = dist.ReduceOp.SUM, # type: ignore - group: Optional[dist.ProcessGroup] = None, -) -> None: - if compression == Compression.UINT8: - from zeroband.C.collectives import ring_allreduce as ring_allreduce_c - - return ring_allreduce_c(tensor, op, group) - else: - return gloo_all_reduce(tensor, op, group) - - -# =============== -# Code purgatory -# --------------- -# This code is still here because it is used by tests -# ring_allreduce is used by tests/test_c/test_collectives.py to make sure the new c impl doesnt deviate too much numerically -BUFFER_COUNT = 2 - - -def ring_allreduce_py( - tensor: torch.Tensor, - op: dist.ReduceOp = dist.ReduceOp.SUM, # type: ignore - group: Optional[dist.ProcessGroup] = None, - transfer_dtype: Optional[torch.dtype] = None, - quantization_func: Optional[Callable] = None, -) -> None: - """ - Perform all-reduce on a tensor using ring algorithm. - The accumulation will be done in-place on the input tensor. - The transfers will be done using the specified transfer_dtype. - """ - if quantization_func is not None: - if transfer_dtype is not None: - raise ValueError("Quantization and transfer_dtype cannot be used together") - transfer_dtype = tensor.dtype - if transfer_dtype is None: - transfer_dtype = tensor.dtype - if group is None: - group = dist.distributed_c10d._get_default_group() - if op not in [dist.ReduceOp.SUM, dist.ReduceOp.AVG]: - raise ValueError(f"Unsupported reduce operation {op}. 
Only SUM and AVG are supported.") - - world_size = group.size() - rank = group.rank() - - # Divide the tensor into chunks - flat_tensor = tensor.as_strided((tensor.numel(),), (1,)) - chunks = flat_tensor.chunk(world_size * BUFFER_COUNT) - - assert flat_tensor.size(0) % (world_size * BUFFER_COUNT) == 0, "Tensor size must be divisible by world size" - - # Temporary buffers for transferring data - num_buffers = BUFFER_COUNT * world_size - if quantization_func is not None: - recv_buffer = [torch.empty_like(chunks[0], dtype=torch.uint8) for _ in range(BUFFER_COUNT)] - send_buffer = [None for _ in range(BUFFER_COUNT)] - send_lookup_buffer = [None for _ in range(BUFFER_COUNT)] - recv_lookup_buffer = [torch.empty(256, dtype=chunks[0].dtype) for _ in range(BUFFER_COUNT)] - send_lookup_work = [None for _ in range(BUFFER_COUNT)] - recv_lookup_work = [None for _ in range(BUFFER_COUNT)] - else: - recv_buffer = [torch.empty_like(chunks[0], dtype=transfer_dtype) for _ in range(BUFFER_COUNT)] - send_buffer = [torch.empty_like(chunks[0], dtype=transfer_dtype) for _ in range(BUFFER_COUNT)] - send_work = [None] * BUFFER_COUNT - recv_work = [None] * BUFFER_COUNT - - send_rank = (rank + 1) % world_size - recv_rank = (rank - 1) % world_size - for step in range(1, world_size * BUFFER_COUNT + 1): - send_chunk = (rank * BUFFER_COUNT - step) % num_buffers - - if send_work[step % BUFFER_COUNT] is not None: - send_work[step % BUFFER_COUNT].wait() - recv_work[step % BUFFER_COUNT].wait() - if quantization_func is not None: - send_lookup_work[step % BUFFER_COUNT].wait() - recv_lookup_work[step % BUFFER_COUNT].wait() - # print(recv_lookup_buffer[step % BUFFER_COUNT][recv_buffer[step % BUFFER_COUNT].long()]) - chunks[send_chunk].add_( - recv_lookup_buffer[step % BUFFER_COUNT][recv_buffer[step % BUFFER_COUNT].long()] - ) - else: - chunks[send_chunk].add_(recv_buffer[step % BUFFER_COUNT]) - - if step <= (world_size - 1) * BUFFER_COUNT: - # Send and receive - if quantization_func is not None: - send_buffer[step % BUFFER_COUNT], send_lookup_buffer[step % BUFFER_COUNT] = quantization_func( - chunks[send_chunk] - ) - send_lookup_work[step % BUFFER_COUNT] = dist.isend( - send_lookup_buffer[step % BUFFER_COUNT], dst=send_rank, group=group, tag=step + 1000 - ) - recv_lookup_work[step % BUFFER_COUNT] = dist.irecv( - recv_lookup_buffer[step % BUFFER_COUNT], src=recv_rank, group=group, tag=step + 1000 - ) - else: - send_buffer[step % BUFFER_COUNT].copy_(chunks[send_chunk]) - send_work[step % BUFFER_COUNT] = dist.isend( - send_buffer[step % BUFFER_COUNT], dst=send_rank, group=group, tag=step - ) - recv_work[step % BUFFER_COUNT] = dist.irecv( - recv_buffer[step % BUFFER_COUNT], src=recv_rank, group=group, tag=step - ) - - if op == dist.ReduceOp.AVG: - for i in range(BUFFER_COUNT): - chunks[i + rank * BUFFER_COUNT].divide_(world_size) - if quantization_func is not None: - for i in range(BUFFER_COUNT): - quant_weight, lookup = quantization_func(chunks[i + rank * BUFFER_COUNT]) - chunks[i + rank * BUFFER_COUNT].copy_(lookup[quant_weight.long()]) - - if quantization_func is not None: - recv_buffer = [torch.empty_like(chunks[0], dtype=torch.uint8) for _ in range(BUFFER_COUNT)] - send_buffer = [None for _ in range(BUFFER_COUNT)] - send_lookup_buffer = [None for _ in range(BUFFER_COUNT)] - recv_lookup_buffer = [torch.empty(256, dtype=chunks[0].dtype) for _ in range(BUFFER_COUNT)] - send_lookup_work = [None for _ in range(BUFFER_COUNT)] - recv_lookup_work = [None for _ in range(BUFFER_COUNT)] - send_work = [None] * BUFFER_COUNT - recv_work = 
[None] * BUFFER_COUNT - - for step in range(1, world_size * BUFFER_COUNT + 1): - send_chunk = (rank * BUFFER_COUNT + BUFFER_COUNT - step) % num_buffers - - if send_work[step % BUFFER_COUNT] is not None: - send_work[step % BUFFER_COUNT].wait() - recv_work[step % BUFFER_COUNT].wait() - if quantization_func is not None: - send_lookup_work[step % BUFFER_COUNT].wait() - recv_lookup_work[step % BUFFER_COUNT].wait() - chunks[send_chunk].copy_( - recv_lookup_buffer[step % BUFFER_COUNT][recv_buffer[step % BUFFER_COUNT].long()] - ) - else: - chunks[send_chunk].copy_(recv_buffer[step % BUFFER_COUNT]) - - if step <= (world_size - 1) * BUFFER_COUNT: - # Send and receive - if quantization_func is not None: - send_buffer[step % BUFFER_COUNT], send_lookup_buffer[step % BUFFER_COUNT] = quantization_func( - chunks[send_chunk] - ) - send_lookup_work[step % BUFFER_COUNT] = dist.isend( - send_lookup_buffer[step % BUFFER_COUNT], dst=send_rank, group=group, tag=step + 1000 - ) - recv_lookup_work[step % BUFFER_COUNT] = dist.irecv( - recv_lookup_buffer[step % BUFFER_COUNT], src=recv_rank, group=group, tag=step + 1000 - ) - else: - send_buffer[step % BUFFER_COUNT].copy_(chunks[send_chunk]) - - send_work[step % BUFFER_COUNT] = dist.isend( - send_buffer[step % BUFFER_COUNT], dst=send_rank, group=group, tag=step - ) - recv_work[step % BUFFER_COUNT] = dist.irecv( - recv_buffer[step % BUFFER_COUNT], src=recv_rank, group=group, tag=step - ) diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py deleted file mode 100644 index ca3d7ce6..00000000 --- a/src/zeroband/comms.py +++ /dev/null @@ -1,609 +0,0 @@ -import sys -import os -import time -import subprocess -from torch.distributed.device_mesh import init_device_mesh -from zeroband.utils.world_info import get_world_info -from zeroband.utils.logger import get_logger -import torch.distributed as dist -from datetime import timedelta -from typing import List, Tuple, Optional -from torch.testing._internal.distributed.fake_pg import FakeProcessGroup -import multiprocessing as mp -from uuid import uuid4 -import toposolve -from zeroband.utils.ip import parse_iperf_output - -TCPSTORE_TIMEOUT = timedelta(seconds=int(os.getenv("ZERO_BAND_GLOBAL_STORE_TIMEOUT_SECONDS", "300"))) -TCPSTORE_POLLING_INTERVAL = float(os.getenv("ZERO_BAND_GLOBAL_STORE_POLLING_INTERVAL_SECONDS", "0.1")) -GLOBAL_PG_TIMEOUT = timedelta(seconds=int(os.getenv("ZERO_BAND_GLOBAL_PG_TIMEOUT_SECONDS", "600"))) -MAX_JOINERS = 100 # Maximum number of nodes that can join in a single reinit -HEARTBEAT_INTERVAL = int( - os.getenv("ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS", "2") -) # Interval in seconds between heartbeats -HEARTBEAT_TIMEOUT = int( - os.getenv("ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS", "10") -) # Time in seconds after which a node is considered dead if no heartbeat is received -IPERF_PORT = int(os.getenv("ZERO_BAND_IPERF_PORT", "10101")) -IPERF_IFNAME = os.getenv("GLOO_SOCKET_IFNAME", "eth0") -BENCH_TENSOR_SIZE = 1_000_000 - - -class ElasticDeviceMesh: - """A class to manage the process groups for elastic training without restarts. - - The way it works is rank 0 coordinates the joining and leaving of nodes. - Rank 0 manages the status to coordinate the creation and recreation of the process groups. - When a node wants to join, rank 0 will setup the store so that all nodes know the new world size and their respective ranks. 
- - Store keys used: - - status: "init", "running", "reinit" - - world_size: The current world size - - mesh_count: The version of the mesh - - rank_{uuid}: The rank of the node with the given uuid - - joiner_{i}: The uuid of the ith joiner. Its a KV implmentation of a queue. - """ - - local_pg: dist.ProcessGroup - global_pg: dist.ProcessGroup - - def __init__( - self, backend: str = "cpu:gloo,cuda:nccl", enable: bool = True, live_recovery_rank_src: int | None = None - ): - self._logger = get_logger() - self.world_info = get_world_info() - self.live_recovery_rank_src = live_recovery_rank_src - - # Initialize global process group - self.global_pg = FakeProcessGroup(self.world_info.rank, 1) - - self.enable = enable - if enable: - self._init_global_pg() - - # Initialize local process group - dist.init_process_group(backend=backend) - self.mesh = init_device_mesh( - "cuda", - (self.world_info.nnodes, self.world_info.local_world_size), - mesh_dim_names=("internode", "intranode"), - ) - self.local_pg = self.mesh.get_group("intranode") - - # Start heartbeat - - self.cuda_local_mesh = init_device_mesh("cuda", mesh_shape=(self.local_pg.size(),)) - self.cpu_local_mesh = init_device_mesh("cpu", mesh_shape=(self.local_pg.size(),)) - - # Logging - if self.enable: - self._optimize_ring_ranks() - if self.live_recovery_rank_src is not None: - self.live_recovery.ask_for_live_ckpt(self.live_recovery_rank_src) - self.global_pg.barrier().wait() - - self._logger.info(f"global_pg size : {self.global_pg.size()}, local_pg size: {self.local_pg.size()}") - - def __del__(self): - self._stop_heartbeat() - dist.destroy_process_group() - - def _init_global_store(self): - self._logger.info( - f"[{self.world_info.global_unique_id}](Leader: {self._global_leader}) TCPStore init: Connecting via {self.world_info.global_addr}:{self.world_info.global_port + self.world_info.rank}" - ) - self.global_store = dist.TCPStore( - host_name=self.world_info.global_addr, - port=self.world_info.global_port + self.world_info.rank, - timeout=TCPSTORE_TIMEOUT, - is_master=self._global_leader, - ) - self.god_store = dist.TCPStore( - host_name=self.world_info.global_addr, - port=self.world_info.global_port, - timeout=TCPSTORE_TIMEOUT, - is_master=False, - ) - - def _init_global_store_values(self): - """Initialize the global store with mesh_count, joiner_0, and status. 
Also sets the global status.""" - self._logger.debug("Initializing global store values") - self.global_store.set(f"gid_{self.world_info.global_rank}", self.world_info.global_unique_id) - self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank)) - if self._global_leader: - self.global_store.set("mesh_count", "0") - self.global_store.set("world_size", str(self.world_info.global_world_size)) - self.global_store.set("joiner_0", "null") - for i in range(self.world_info.global_world_size): - self.global_store.set(f"barrier_{i}", "null") - self._global_ids = [ - self.global_store.get(f"gid_{i}").decode("utf-8") for i in range(self.world_info.global_world_size) - ] - for i in self._global_ids: - for j in self._global_ids: - self.global_store.set(f"ping_{i}_{j}", "1000_000_000") - self.global_store.set("status", "init") - self.global_status = "init" - else: - self.global_status = self._wait_for_status() - self._global_ids = [ - self.global_store.get(f"gid_{i}").decode("utf-8") for i in range(self.world_info.global_world_size) - ] - - def _create_global_pg(self): - # Delete the old global_pg - if hasattr(self, "global_pg"): - if sys.getrefcount(self.global_pg) > 2: - self._logger.warning( - f"Global PG refcount was {sys.getrefcount(self.global_pg)} when 2 is expected during deletion. This may cause a memory leak." - ) - del self.global_pg # TODO(jackmin): Where do we catch errors in teardown? - self._logger.info("Destroyed process group") - - # Get new global rank and world size - self.world_info.global_rank = int( - self.global_store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8") - ) - self.world_info.global_world_size = int(self.global_store.get("world_size").decode("utf-8")) - self.mesh_count = int(self.global_store.get("mesh_count").decode("utf-8")) - self._logger.debug( - f"New global rank: {self.world_info.global_rank}, New global world size: {self.world_info.global_world_size} New mesh count: {self.mesh_count}" - ) - - # Create prefix store - prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", self.global_store) - self._logger.debug(f"Created prefix store with mesh_{self.mesh_count}") - - # Create process group - self._logger.debug( - f"Creating global pg with {self.world_info.global_world_size} rank {self.world_info.global_rank}" - ) - self.global_pg = dist.ProcessGroupGloo( - prefix_store, self.world_info.global_rank, self.world_info.global_world_size, GLOBAL_PG_TIMEOUT - ) - self._logger.debug("Global pg created with %d peers. 
Timeout of %s", self.global_pg.size(), GLOBAL_PG_TIMEOUT) - - def _optimize_ring_ranks(self): - self._global_ids = [ - self.global_store.get(f"gid_{i}").decode("utf-8") for i in range(self.world_info.global_world_size) - ] - if self.world_info.local_rank == 0: - self._logger.debug("Measuring bandwidths") - self._measure_connectivity() - self._logger.debug("Measuring bandwidths done") - - self.local_pg.barrier().wait() - self.global_pg.barrier().wait() - - if self._global_leader: - self._logger.debug("Calculating TSP") - pings = self.get_pings() - min_dist, path = toposolve.TSPSolver().solve_tsp(pings) - self._logger.debug(f"Min distance: {min_dist}") - self._logger.debug(f"Path: {path}") - new_gids = [self._global_ids[i] for i in path[:-1]] - assert set(new_gids) == set(self._global_ids) - - for i, gid in enumerate(new_gids): - self.global_store.set(f"rank_{gid}", str(i)) - self.global_store.set(f"gid_{i}", gid) - self.global_store.set("mesh_count", str(self.mesh_count + 1)) - - self.local_pg.barrier().wait() - self.global_pg.barrier().wait() - - self._global_ids = [ - self.global_store.get(f"gid_{i}").decode("utf-8") for i in range(self.world_info.global_world_size) - ] - self._create_global_pg() - - def _queue_join(self): - """Queue a node to join the mesh.""" - for i in range(MAX_JOINERS): - joiner_id = self.global_store.get(f"joiner_{i}").decode("utf-8") - if joiner_id == "null": - self.global_store.set(f"joiner_{i}", self.world_info.global_unique_id) - self.global_store.set(f"joiner_{i + 1}", "null") - break - else: - raise RuntimeError("Too many joiners") - - def _get_joiners(self) -> Tuple[List[str], List[str]]: - joiners = [] - for i in range(MAX_JOINERS): - joiner_id = self.global_store.get(f"joiner_{i}").decode("utf-8") - if joiner_id == "null": - break - joiners.append(joiner_id) - return joiners - - def _clear_joiners(self): - self.global_store.set("joiner_0", "null") - - def _wait_for_status(self, status: Optional[str] = None) -> str: - """Wait for status to be set in the store. - - Args: - store (dist.Store): The store to check. - status (Optional[str], optional): The status to wait for. If None, wait for any status. Defaults to None. - Returns: - status (str): The status. 
- """ - while True: - try: - ret = self.global_store.get("status").decode("utf-8") - if status is None or ret == status: - return ret - time.sleep(TCPSTORE_POLLING_INTERVAL) - except dist.DistStoreError as e: - if status is not None: - raise e - time.sleep(0.1) - - def _init_global_pg(self) -> None: - # Each rank gets its own global store with global rank 0 as the master - time_start = time.perf_counter() - - self._global_leader = self.world_info.global_rank == 0 - self._init_global_store() - - # Initialize store values - self._init_global_store_values() - - self.live_recovery = LiveRecovery(store=self.global_store) - - if self.global_status == "running": # Join path - # Ask to join and then wait for the status to be "reinit" - self._logger.info("Waiting to join") - self._queue_join() - self._wait_for_status("reinit") - - # Create global process group - self._create_global_pg() - - # Update global store values - if self._global_leader: - self.global_store.set("status", "running") - self.global_store.set("resolved_time", uuid4().hex) - self.global_status = "running" - self._last_resolved_time = self.global_store.get("resolved_time").decode("utf-8") - - self._start_heartbeat() - - self._logger.info( - f"Elastic Device mesh init done with {self.global_pg.size()} peers in {time.perf_counter() - time_start} seconds" - ) - - if self.world_info.local_rank == 0: - self._start_iperf_server() - self._evicted_nodes = [] - - def _start_heartbeat(self): - """Start sending heartbeats to the global store in a separate process.""" - self._heartbeat_stop_event = mp.Event() - self._heartbeat_process = mp.Process(target=self._heartbeat_loop, args=(self._heartbeat_stop_event,)) - self._heartbeat_process.start() - - def _stop_heartbeat(self): - """Stop the heartbeat process.""" - self._send_deathrattle() - if hasattr(self, "_heartbeat_stop_event"): - self._heartbeat_stop_event.set() - self._heartbeat_process.join() - - def _heartbeat_loop(self, stop_event): - """Continuously send heartbeats until stopped.""" - try: - while not stop_event.is_set(): - self._send_heartbeat() - time.sleep(HEARTBEAT_INTERVAL) - finally: - self._send_deathrattle() - - def _send_heartbeat(self): - """Send a heartbeat to the global store.""" - current_time = time.time() - try: - self.global_store.set(f"heartbeat_{self.world_info.global_unique_id}", str(current_time)) - except Exception: - self._logger.error("Error sending heartbeat", exc_info=True) - pass - - def _send_deathrattle(self): - """Send a deathrattle to the global store.""" - if hasattr(self, "global_store"): - self.global_store.set(f"heartbeat_{self.world_info.global_unique_id}", "-100") - else: - import warnings - - warnings.warn("global_store garbage collected. Skipping deathrattle.") - - def _check_heartbeats(self) -> List[str]: - """Check heartbeats and return a list of nodes that have missed their heartbeats.""" - dead_nodes = [] - current_time = time.time() - for gid in self._global_ids: - try: - last_heartbeat = float(self.global_store.get(f"heartbeat_{gid}").decode("utf-8")) - self._logger.debug(f"Node {gid} last heartbeat: {last_heartbeat}") - if current_time - last_heartbeat > HEARTBEAT_TIMEOUT: - dead_nodes.append(gid) - self.global_store.delete_key(f"heartbeat_{gid}") - except dist.DistStoreError: - self._logger.warning(f"Node {gid} has no heartbeat") - return dead_nodes - - def _resolve_world(self, admit_joiners: bool = False) -> bool: - """Set the new world size and ranks for all nodes if there are joiners or dead nodes. Else, do nothing. 
- - Args: - admit_joiners (bool, optional): Whether to admit joiners. Defaults to False. - Returns: - bool: True if the world was changed, False otherwise. - """ - # Find joiners - if admit_joiners: - joiners = self._get_joiners() - else: - joiners = [] - - # Check for dead nodes - dead_nodes = self._check_heartbeats() - self._logger.debug( - "Joiners (%sadmitting): %s, Dead nodes: %s, Evicting nodes: %s", - "" if admit_joiners else "not ", - joiners, - dead_nodes, - self._evicted_nodes, - ) - dead_nodes.extend(self._evicted_nodes) - - # If no joiners or dead nodes, no resolution needed - if len(joiners) == 0 and len(dead_nodes) == 0: - return False - - # Remap live ranks to smaller world_size caused by dead nodes - leaving_nodes = set(dead_nodes) - live_ranks = [i for i in self._global_ids if i not in leaving_nodes] - for i, rank in enumerate(live_ranks): - self.global_store.set(f"rank_{rank}", str(i)) - self.global_store.set(f"gid_{i}", rank) - new_world_size = len(live_ranks) - - # Give joiners new ranks - for joiner_id in joiners: - self.global_store.set(f"rank_{joiner_id}", str(new_world_size)) - self.global_store.set(f"gid_{new_world_size}", joiner_id) - live_ranks.append(joiner_id) - new_world_size += 1 - - self._global_ids = live_ranks - for i in self._global_ids: - for j in self._global_ids: - self.global_store.set(f"ping_{i}_{j}", "1000_000_000") - for i in range(1, new_world_size): - self.global_store.set(f"barrier_{i}", "null") - # Update world_size - self.global_store.set("world_size", str(new_world_size)) - self.global_store.set("mesh_count", str(self.mesh_count + 1)) - # Set status to "reinit" - self.global_store.set("status", "reinit") - return True - - def maybe_reinit_global_pg(self, admit_joiners: bool = False) -> bool: - """Reinitialize the global_pg if there are is a state change. - - Args: - admit_joiners (bool, optional): Whether to admit joiners. Defaults to False. - Returns: - bool: True if the global_pg was reinitialized, False otherwise. - """ - if not self.enable: - # no op if disabled - return - - time_start = time.perf_counter() - self._logger.debug("[%s] Resolving world", self.world_info.global_unique_id) - if self._global_leader: - self._resolve_world(admit_joiners=admit_joiners) - self.global_store.set("resolved_time", uuid4().hex) - else: - while (ans := self.global_store.get("resolved_time").decode("utf-8")) == self._last_resolved_time: - # TODO: Have a timeout here in case the leader is dead - time.sleep(TCPSTORE_POLLING_INTERVAL) - self._last_resolved_time = ans - - self._logger.debug("World resolved in %s seconds", time.perf_counter() - time_start) - - status = self.global_store.get("status").decode("utf-8") - if status == "running": # No joiners or dead nodes - return False - - # Reinit Path - try: - self._create_global_pg() - self._optimize_ring_ranks() - self.global_pg.barrier().wait() - except Exception as e: - self._logger.error(f"Error recreating process group: {e}. 
Retrying...") - return self.maybe_reinit_global_pg(admit_joiners=admit_joiners) - - if self._global_leader: - self._clear_joiners() - self.global_store.set("status", "running") - - self._logger.debug("Reinitialized global_pg done in %s seconds", time.perf_counter() - time_start) - - # TODO: We need to reset the self.world_info.global_rank reference - # Somehow the reference becomes stale and the heartbeats become wrong - # This will be fixed when heartbeats become unique id dependent which never changes - self._logger.debug("Reset Heartbet") - self._stop_heartbeat() - self._start_heartbeat() - self._logger.debug("Reset Heartbeat done") - return True - - def get_global_pg(self, maybe_reinit: bool = False) -> dist.ProcessGroup: - """Get the global process group. If maybe_reinit is True, reinitialize the global process group if needed.""" - if maybe_reinit: - self.maybe_reinit_global_pg() - return self.global_pg - - def monitored_barrier(self, flag: str): - flag = str(flag) - time_start = time.perf_counter() - self._logger.debug("[%s] Monitored Barrier %s", self.world_info.global_unique_id, flag) - if self._global_leader: - self._logger.debug("Others have %d seconds to resolve", GLOBAL_PG_TIMEOUT.total_seconds()) - while not all( - self.global_store.get(f"barrier_{i}").decode("utf-8") == flag - for i in range(1, self.world_info.global_world_size) - ): - if time.perf_counter() - time_start > GLOBAL_PG_TIMEOUT.total_seconds(): - self._logger.error("Monitored barrier failed due to timeout") - self._evicted_nodes = [ - i - for i in range(1, self.world_info.global_world_size) - if self.global_store.get(f"barrier_{i}").decode("utf-8") != flag - ] - self._logger.info("Evicting nodes: %s", self._evicted_nodes) - self.global_store.set(f"barrier_{self.world_info.global_rank}", "error") - # We neeed to evict the dead node - raise RuntimeError("Monitored barrier failed due to timeout") - time.sleep(TCPSTORE_POLLING_INTERVAL) - self.global_store.set(f"barrier_{self.world_info.global_rank}", flag) - else: - self.global_store.set(f"barrier_{self.world_info.global_rank}", flag) - while (ans := self.global_store.get("barrier_0").decode("utf-8")) != flag: - if ans == "error": - raise RuntimeError("Monitored barrier failed due to error") - # TODO: Have a timeout here in case the leader is dead - time.sleep(TCPSTORE_POLLING_INTERVAL) - - self._logger.debug("Monitored barrier resolved in %s seconds", time.perf_counter() - time_start) - - def get_pings(self) -> List[List[int]]: - pings = [[1000_000_000] * self.world_info.global_world_size for _ in range(self.world_info.global_world_size)] - for i, e1 in enumerate(self._global_ids): - for j, e2 in enumerate(self._global_ids): - if i == j: - continue - pings[i][j] = int(self.god_store.get(f"ping_{e1}_{e2}")) - - self._logger.debug("\n %s", format_grid(pings)) - return pings - - def _start_iperf_server(self) -> None: - """Start the iperf server process.""" - try: - from zeroband.utils.ip import get_ip_address - - iperf_addr = get_ip_address(IPERF_IFNAME) - iperf_port = IPERF_PORT + self.world_info.global_rank - cmd: List[str] = ["iperf", "-s", "-p", str(iperf_port)] - self.server_process = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - self.god_store.set(f"iperf_{self.world_info.global_unique_id}", f"{iperf_addr}:{iperf_port}") - self._logger.info(f"Started iperf server on {iperf_addr} with port {iperf_port}") - except Exception as e: - self._logger.error(f"Failed to start iperf server: {str(e)}") - raise - - def 
_measure_connectivity(self): - for i in self._global_ids: - if i == self.world_info.global_unique_id: - continue - target_host, target_port = self.god_store.get(f"iperf_{i}").decode("utf-8").split(":") - target_port = int(target_port) - time_taken = self.measure_bandwidth(target_host, target_port) - self.god_store.set(f"ping_{self.world_info.global_unique_id}_{i}", str(time_taken)) - - def measure_bandwidth(self, target_host: str, target_port: int) -> int: - """ - Measure bandwidth to a specific target. - - Args: - target_host: The host to measure bandwidth to - target_port: The port to measure bandwidth to - - Returns: - int: The time taken to transfer 10Tb of data in seconds - """ - try: - cmd: List[str] = [ - "iperf", - "-c", - target_host, - "-p", - str(target_port), - "-t", - "1", # 1 second test - ] - result: subprocess.CompletedProcess = subprocess.run(cmd, capture_output=True, text=True, timeout=5) - - if result.returncode != 0: - raise Exception(f"iperf error: {result.stderr}") - - time_taken: int = int(1e13 / parse_iperf_output(result.stdout)) - time_taken = min(time_taken, 1_000_000_000) - - return time_taken - except Exception as e: - self._logger.error(f"Error measuring bandwidth to {target_host}:{target_port} {str(e)}") - return int(1e9) - - -def format_grid(grid): - N = len(grid) - - # Set the main diagonal elements to 0 - for i in range(N): - grid[i][i] = 0 - - # Determine the width needed for formatting based on max possible value (99.99) and indices - cell_width = 6 - - # Create header row with column indices - header_row = " " + " | ".join(f"{j:>{cell_width-1}}" for j in range(N)) - - # Start building the formatted grid string - formatted_grid = header_row + "\n" - - for i, row in enumerate(grid): - # Format each element in the row - formatted_row = [f"{i:>2}"] # Add row index at the beginning of the row - for value in row: - # Divide by 1000 and format to 2 decimal places - formatted_value = f"{value / 1000:.2f}" - formatted_row.append(formatted_value) - - # Join the elements of the row with '|' and add it to the grid string - formatted_grid += " | ".join(formatted_row).center(cell_width * (N + 1)) + "\n" - - return formatted_grid.strip() - - -class LiveRecovery: - def __init__(self, store: dist.Store): - self.logger = get_logger() - self.world_info = get_world_info() - - self.store = dist.PrefixStore("live_recovery", store) - self.reset() - - def reset(self): - self.store.set(f"rank_{self.world_info.global_rank}", "null") - - def should_send_ckpt_to(self) -> int | None: - """use this function to check if someone is awaiting for a live ckpt""" - data = self.store.get(f"rank_{self.world_info.global_rank}").decode("utf-8") - if data == "null": - return None - try: - return int(data) - except ValueError as e: - self.logger.error(f"Error parsing live recovery data: {e}") - return None - - def ask_for_live_ckpt(self, rank: int) -> int | None: - """use this function to send a signal to a node to ask for a live ckpt""" - self.store.set(f"rank_{rank}", str(self.world_info.global_rank)) diff --git a/src/zeroband/compression.py b/src/zeroband/compression.py deleted file mode 100644 index 2fc1da75..00000000 --- a/src/zeroband/compression.py +++ /dev/null @@ -1,70 +0,0 @@ -# Code adapted from https://github.com/PrimeIntellect-ai/hivemind/blob/213bff98a62accb91f254e2afdccbf1d69ebdea9/hivemind/compression/quantization.py -# Original code is licensed under the MIT License. -# See the LICENSE file in the original repository for more information. 
- -import torch -import numpy as np -from typing import Tuple -import math -from concurrent.futures import ThreadPoolExecutor -import os - -RANGE_IN_SIGMAS: int = 6 -EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128))) -n_bins = 2**8 - - -def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int) -> torch.Tensor: - """Return the average value in each bucket""" - bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten()) - bin_counts = torch.clamp_min_(torch.bincount(quant_weight.flatten(), minlength=n_bins), 1) - lookup = bin_sums / bin_counts - return lookup - - -def get_chunk_size(num_elements: int, min_chunk_size: int) -> int: - """Adjust chunk_size to minimize imbalance between chunk sizes""" - if min_chunk_size >= num_elements: - return min_chunk_size - leftover_elements = num_elements % min_chunk_size - num_chunks = num_elements // min_chunk_size - return min_chunk_size + (leftover_elements - 1) // num_chunks + 1 - - -def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10**5) -> np.ndarray: - """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel.""" - if not array.data.c_contiguous and array.data.f_contiguous: - array = array.T - array = np.ascontiguousarray(array.reshape(-1)) - quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=array.dtype) - chunk_size = get_chunk_size(len(array), min_chunk_size) - num_chunks = (len(array) - 1) // chunk_size + 1 - partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=array.dtype) - - jobs = [] - for i in range(num_chunks): - chunk = slice(chunk_size * i, chunk_size * (i + 1)) - jobs.append(EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i])) - - for job in jobs: - job.result() - return np.quantile(partition_quantiles, quantiles) - - -def uniform_8bit_quantize(tensor: torch.Tensor, inplace: bool = True) -> Tuple[torch.Tensor, torch.Tensor]: - offset = n_bins // 2 - # shift = tensor.mean() - # centered_tensor = tensor.sub_(shift) if inplace else tensor - shift - centered_tensor = tensor - std_unbiased = centered_tensor.norm() / math.sqrt(centered_tensor.numel() - 1) - scale = RANGE_IN_SIGMAS * std_unbiased / n_bins - quantized = torch.quantize_per_tensor(centered_tensor, scale, offset, torch.quint8).int_repr() - lookup = average_buckets(tensor, quantized, n_bins) - return quantized, lookup - - -def quantile_8bit_quantize(tensor: torch.Tensor, inplace: bool = True) -> Tuple[torch.Tensor, torch.Tensor]: - borders = torch.as_tensor(quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1]) - quantized = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1) - lookup = average_buckets(tensor, quantized, n_bins) - return quantized, lookup diff --git a/src/zeroband/config.py b/src/zeroband/config.py index 11c27af5..7b543b1f 100644 --- a/src/zeroband/config.py +++ b/src/zeroband/config.py @@ -170,6 +170,10 @@ class Config(BaseConfig): log_level: Literal["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" log_all_rank: bool = False + # CCoIP / PCCL + ccoip_master_addr: str = '127.0.0.1:48148' + ccoip_master_connection_attempts: int = 15 + # sub config diloco: DilocoConfig | None = None data: DataConfig = DataConfig() diff --git a/src/zeroband/data.py b/src/zeroband/data.py index 50ff1f58..fc1d185f 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -6,7 +6,7 @@ from zeroband.models.llama.model import 
create_block_mask_from_seqlens from zeroband.utils.logger import get_logger -from zeroband.utils.world_info import get_world_info +from zeroband.utils.world_info import get_local_world_info from zeroband.config import DataConfig import torch @@ -339,7 +339,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def _prefetch_next(self): def _task() -> None: # NOTE: Each thread gets its own threadlocal CUDA context and has to reset the device. - local_rank = get_world_info().local_rank + local_rank = get_local_world_info().local_rank torch.cuda.set_device(local_rank) # Grab batch or return sentinel diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 630a8d88..d6bc19bc 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -1,12 +1,17 @@ import re import time +from typing import List + +import pccl import torch +from pccl import Communicator, AsyncReduceHandle, ReduceOp, Attribute, ReduceOperandDescriptor, DistributionHint, \ + DataType, QuantizationOptions, QuantizationAlgorithm from torch import nn -from zeroband.comms import ElasticDeviceMesh -from zeroband.collectives import Compression, all_reduce -from zeroband.utils.world_info import get_world_info +from torch.distributed import init_device_mesh + +from zeroband.utils.world_info import get_local_world_info from zeroband.utils.logger import get_logger -from zeroband.config import DilocoConfig +from zeroband.config import DilocoConfig, Compression import torch.distributed as dist from torch.distributed._tensor.api import DTensor from functools import lru_cache @@ -21,6 +26,124 @@ def _find_first_number(s: str) -> int: return -1 + +def all_reduce_multiple_with_retry(communicator: Communicator, + tensors: list[torch.Tensor], + op: ReduceOp, + compression: Compression, + max_in_flight: int = 8): + """ + Launches concurrent all-reduce operations on a list of tensors, + waits for them all, and retries if a peer fails or the world size changes. + Will attempt to target :param max_in_flight: concurrent all-reduce operations. + The more similar your tensors are in size, the better this in flight system will work. + Future versions of PCCL may provide a "wait for any of multiple async ops" api to improve this pattern. 
+ """ + world_size = communicator.get_attribute(Attribute.CURRENT_WORLD_SIZE) + + total_tx = 0 + total_rx = 0 + + def launch_all_reduce(x: torch.Tensor, tag: int): + op_desc = ReduceOperandDescriptor( + datatype=DataType.FLOAT, + distribution_hint=DistributionHint.NORMAL + ) + if compression == Compression.NO: + quant_desc = QuantizationOptions( + quantized_datatype=DataType.FLOAT, + algorithm=QuantizationAlgorithm.NONE + ) + else: + quant_desc = QuantizationOptions( + quantized_datatype=DataType.UINT8, + algorithm=QuantizationAlgorithm.MIN_MAX + ) + + return communicator.all_reduce_async( + x, x, + operand_descriptor=op_desc, + quantization_options=quant_desc, + op=op, + tag=tag + ) + + handles = [None for _ in range(len(tensors))] + done_handles = set() + + in_flight = 0 + for tensor_index in range(len(tensors)): + dst_tensor = tensors[tensor_index] + + if in_flight >= max_in_flight: + break + + handles[tensor_index] = launch_all_reduce( + dst_tensor, + tensor_index + ) + in_flight += 1 + + while world_size > 1: + all_done = True + for tensor_index in range(len(tensors)): + handle = handles[tensor_index] + dst_tensor = tensors[tensor_index] + + if handle is None: + if tensor_index in done_handles: + continue + + if in_flight >= max_in_flight: + continue + + handle = handles[tensor_index] = launch_all_reduce( + dst_tensor, + tensor_index + ) + in_flight += 1 + + is_success, status, info = handle.wait() + world_size = communicator.get_attribute(Attribute.CURRENT_WORLD_SIZE) + if not is_success: + print(f"Reduce failed: {status}; Starting recovery procedure") + handles[tensor_index] = None + # Wait for all ongoing ops to finish or fail before retry + for j in range(len(tensors)): + if j == tensor_index: + continue + h_j = handles[j] + if h_j is not None: + s_j, _, _ = h_j.wait() + if s_j: + done_handles.add(j) + in_flight -= 1 + handles[j] = None + all_done = False + break + + # success for this handle + handles[tensor_index] = None + done_handles.add(tensor_index) + + total_tx += info.tx_bytes + total_rx += info.rx_bytes + + in_flight -= 1 + + if all_done: + break + + if world_size == 1: + # If we are alone, just finalize all handles and return + for h in handles: + if h is not None: + h.wait() + return False + + return True + + class Diloco: """ This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852. 
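For reference, a minimal sketch of how the retry-capable helper added above can be driven on its own. This is not part of the patch: the master address is only the new ccoip_master_addr default, and the bucket count and sizes are invented for illustration.

# Illustrative only -- drives all_reduce_multiple_with_retry against a locally
# running PCCL master (the default ccoip_master_addr from config.py).
import torch
from pccl import Communicator, ReduceOp
from zeroband.config import Compression
from zeroband.diloco import all_reduce_multiple_with_retry

communicator = Communicator("127.0.0.1:48148")
communicator.connect(n_attempts=15)

# e.g. grouped pseudo-gradient buckets; reduced in place (send and recv buffer are the same tensor)
buckets = [torch.randn(1 << 20) for _ in range(4)]
reduced = all_reduce_multiple_with_retry(
    communicator,
    buckets,
    op=ReduceOp.AVG,
    compression=Compression.NO,
    max_in_flight=8,
)
# Returns False if the peer found itself alone (world size 1); otherwise the buckets
# were averaged across peers, with the reduce retried after any peer failure.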
@@ -48,21 +171,17 @@ class Diloco: """ def __init__( - self, - config: DilocoConfig, - model: nn.Module, - elastic_device_mesh: ElasticDeviceMesh, + self, + config: DilocoConfig, + model: nn.Module, ): self.config = config - if config.compression == Compression.UINT8: - from zeroband.C.collectives import ring_allreduce as _ # noqa: F401 - # just force compilation - - self.elastic_device_mesh = elastic_device_mesh - self._logger = get_logger() - self.world_info = get_world_info() + self.local_world_info = get_local_world_info() + + self.cuda_local_mesh = init_device_mesh("cuda", mesh_shape=(self.local_world_info.world_size,)) + self.cpu_local_mesh = init_device_mesh("cpu", mesh_shape=(self.local_world_info.world_size,)) self._init_offloaded_optimizer(model=model) @@ -75,61 +194,39 @@ def _init_offloaded_optimizer(self, model): self._logger.debug("offload model to cpu") @torch.no_grad() - def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str = "outer"): + def sync_pseudo_gradient(self, model: nn.Module, communicator: Communicator, fake: bool = False): """ Sync the pseudo gradient from the local process group to the global process group """ _start_time = time.perf_counter() - self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False) - world_size_post_init = self.elastic_device_mesh.global_pg.size() - - world_size = world_size_post_init - - self._logger.debug("sync pseudo gradient %s with world size %d", " fake" if fake else "", world_size) - - global_pg = self.elastic_device_mesh.global_pg - for i in range(self.config.retry_all_reduce): - for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): - assert isinstance(param_offloaded.grad, DTensor) - if fake: - param_offloaded.grad.to_local().zero_() - else: - param_offloaded.grad.to_local().copy_(param_offloaded.data.to_local()) - param_offloaded.grad.to_local().sub_(param.data.to_local().to(param_offloaded.data.device)) - try: - self.offloaded_grad_flat_tensor.div_(world_size) - _collective_start_time = time.perf_counter() - self._logger.debug("Waiting on barrier") - self.elastic_device_mesh.monitored_barrier(flag) - - self._logger.debug("Beginning all reduce") - # all_reduce(self.config.compression, self.offloaded_grad_flat_tensor, dist.ReduceOp.SUM, global_pg) - for j, tensor_group in enumerate(self._offloaded_grad_grouped_tensor): - t0 = time.perf_counter() - all_reduce(self.config.compression, tensor_group, dist.ReduceOp.SUM, global_pg) - self._logger.debug( - f"{j}/{len(self._offloaded_grad_grouped_tensor)} all reduce bucket done in {time.perf_counter() - t0:.6f} seconds, numel: {tensor_group.numel()}" - ) + self._logger.debug("sync pseudo gradient %s with world size %d", "fake" if fake else "", self.local_world_info.world_size) + for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): + assert isinstance(param_offloaded.grad, DTensor) + if fake: + param_offloaded.grad.to_local().zero_() + else: + param_offloaded.grad.to_local().copy_(param_offloaded.data.to_local()) + param_offloaded.grad.to_local().sub_(param.data.to_local().to(param_offloaded.data.device)) + try: + _collective_start_time = time.perf_counter() + + self._logger.debug("Beginning all reduce") + reduce_tensors = [self.offloaded_grad_flat_tensor] + for j, tensor_group in enumerate(self._offloaded_grad_grouped_tensor): + t0 = time.perf_counter() + reduce_tensors.append(tensor_group) self._logger.debug( - f"All reduce takes {time.perf_counter() - _collective_start_time:.6f} seconds numels: 
{self.offloaded_grad_flat_tensor.numel()}" + f"{j}/{len(self._offloaded_grad_grouped_tensor)} all reduce bucket done in {time.perf_counter() - t0:.6f} seconds, numel: {tensor_group.numel()}" ) - break - except Exception as e: - self._logger.error(f"Error syncing pseudo gradient: {e}, retry {i+1}/{self.config.retry_all_reduce}") - global_pg = self.elastic_device_mesh.get_global_pg(maybe_reinit=True) - else: - self._logger.error( - "Failed to sync pseudo gradient after %d retries. Resorting to calculating pseudo-gradient without reduce", - self.config.retry_all_reduce, + all_reduce_multiple_with_retry(communicator, reduce_tensors, ReduceOp.AVG, self.config.compression, max_in_flight=16) + + self._logger.debug( + f"All reduce takes {time.perf_counter() - _collective_start_time:.6f} seconds numels: {self.offloaded_grad_flat_tensor.numel()}" ) - for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): - if fake: - param_offloaded.grad.to_local().zero_() - else: - param_offloaded.grad.to_local().copy_(param_offloaded.data.to_local()) - param_offloaded.grad.to_local().sub_(param.data.to_local().to(param_offloaded.data.device)) + except Exception as e: + self._logger.error(f"Error syncing pseudo gradient: {e}") self._logger.info(f"Sync psuedo-gradient in {time.perf_counter() - _start_time:.6f} seconds") @@ -174,14 +271,14 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: offloaded_param = nn.Parameter( DTensor.from_local( data_tensor, - device_mesh=self.elastic_device_mesh.cpu_local_mesh, + device_mesh=self.cpu_local_mesh, placements=param.data.placements, ) ) offloaded_param.grad = DTensor.from_local( grad_tensor, - device_mesh=self.elastic_device_mesh.cpu_local_mesh, + device_mesh=self.cpu_local_mesh, placements=param.data.placements, ) # here we pre-allocate the grad DTensor on cpu. 
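As an aside, the pseudo-gradient that the rewritten sync_pseudo_gradient stores in the offloaded grad buffers can be shown on plain CPU tensors. The shapes here are invented; the real code operates on CPU-offloaded DTensor shards and grouped flat buffers.

# Sketch only: what param_offloaded.grad holds before the collective.
import torch

offloaded = torch.ones(4)           # parameter copy kept from the previous outer step
current = torch.full((4,), 0.5)     # parameter after the inner optimizer steps

pseudo_grad = offloaded - current   # the copy_(offloaded) followed by sub_(current) above
print(pseudo_grad)                  # tensor([0.5000, 0.5000, 0.5000, 0.5000])

# These buffers are then all-reduced with ReduceOp.AVG, which is why the explicit
# offloaded_grad_flat_tensor.div_(world_size) from the old code path is gone.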
@@ -201,12 +298,12 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: return offloaded_params @torch.no_grad() - def step(self, model: nn.Module, fake: bool = False, flag: str = "outer"): + def step(self, model: nn.Module, communicator: Communicator, fake: bool = False): """ Step the optimizer """ time_start = time.perf_counter() - self.sync_pseudo_gradient(model, fake=fake, flag=flag) + self.sync_pseudo_gradient(model, communicator, fake=fake) self._logger.info(f"all reduce pseudo gradient in: {time.perf_counter() - time_start} seconds") if self.outer_optimizer is not None: diff --git a/src/zeroband/master.py b/src/zeroband/master.py new file mode 100644 index 00000000..4c91bd29 --- /dev/null +++ b/src/zeroband/master.py @@ -0,0 +1,25 @@ +from pccl import MasterNode +import argparse +import logging + +logging.basicConfig(level=logging.INFO) + + +def main(): + parser = argparse.ArgumentParser(description='PCCL Master Node') + parser.add_argument( + '--listen-address', + type=str, + default='0.0.0.0:48148', + help='Address for the master node to listen on (format: host:port)' + ) + + args = parser.parse_args() + + logging.info(f"Starting master node on {args.listen_address}") + master: MasterNode = MasterNode(listen_address=args.listen_address) + master.run() + + +if __name__ == '__main__': + main() diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 06585bcc..3d3f6633 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -8,8 +8,7 @@ from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy # type: ignore from torch.autograd.profiler import record_function -from zeroband.checkpoint import CkptManager, TrainingProgress -from zeroband.comms import ElasticDeviceMesh +from zeroband.checkpoint import TrainingProgress, CkptManager from zeroband.config import Config, resolve_env_vars from zeroband.data import TEST_VOCAB_SIZE, get_dataloader from zeroband.diloco import Diloco @@ -30,13 +29,16 @@ from zeroband.utils.metric_logger import MetricLogger, WandbMetricLogger, DummyMetricLogger from zeroband.utils.activation_ckpt import apply_ac_ckpt from zeroband.utils.profiler import MemoryProfiler -from zeroband.utils.world_info import get_world_info from zeroband.utils.logger import get_logger from zeroband.utils.stopwatch import Stopwatch from transformers import AutoTokenizer from pydantic_config import parse_argv +from zeroband.utils.world_info import get_local_world_info + +from pccl import Communicator, PCCLError, Attribute, SharedState, TensorInfo + def log_hash_training_state( config: Config, @@ -71,15 +73,15 @@ def log_hash_training_state( metrics.update( {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash} ) - if world_info.rank == 0: + if local_world_info.rank == 0: assert metric_logger is not None metric_logger.log(metrics) def train(config: Config): # batch_size is the total batch size for all GPUs - assert config.optim.batch_size % world_info.local_world_size == 0 - batch_size = config.optim.batch_size // world_info.local_world_size + assert config.optim.batch_size % local_world_info.local_world_size == 0 + batch_size = config.optim.batch_size // local_world_info.local_world_size assert batch_size % config.train.micro_bs == 0, ( f"The micro batch size ({config.train.micro_bs}) must divide the number of samples on each GPU ({batch_size})." 
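As a usage note, here is how the new coordination pieces fit together, assuming pccl is installed from the new pyproject dependency. The host and port are the defaults from master.py and config.py; this is a sketch, not a prescribed deployment.

# Sketch: start the coordination master once per run (what src/zeroband/master.py does).
from pccl import MasterNode

MasterNode(listen_address="0.0.0.0:48148").run()  # blocks for the lifetime of the run

# Each training worker then dials the master using the new config keys, as train.py does:
#   communicator = Communicator(config.ccoip_master_addr)                      # default "127.0.0.1:48148"
#   communicator.connect(n_attempts=config.ccoip_master_connection_attempts)   # default 15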
@@ -108,8 +110,8 @@ def train(config: Config): with sw.record_block("Get Dataloader"): train_dataloader = get_dataloader( tokenizer=tokenizer, - world_size=world_info.world_size, - rank=world_info.rank, + world_size=local_world_info.world_size, + rank=local_world_info.rank, batch_size=config.train.micro_bs, data_config=config.data, ) @@ -137,10 +139,6 @@ def train(config: Config): num = 1 if isinstance(config.train.ac_ckpt, bool) else config.train.ac_ckpt apply_ac_ckpt(model, num) - elastic_device_mesh = ElasticDeviceMesh( - enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src - ) - mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None ) @@ -155,14 +153,12 @@ def train(config: Config): fully_shard( transformer_block, mp_policy=mp_policy, - mesh=elastic_device_mesh.cuda_local_mesh, reshard_after_forward=reshard_after_forward, offload_policy=offload_policy, ) fully_shard( model, mp_policy=mp_policy, - mesh=elastic_device_mesh.cuda_local_mesh, reshard_after_forward=config.train.reshard_after_forward, offload_policy=offload_policy, ) @@ -171,7 +167,7 @@ def train(config: Config): with sw.record_block("Optimizer Setup"): inner_optimizer = get_optimizer(config, model.parameters()) - diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None + diloco = Diloco(config.diloco, model) if config.diloco is not None else None scheduler = get_scheduler( sched_type=config.optim.sched_type, @@ -181,7 +177,7 @@ def train(config: Config): num_training_steps=config.optim.total_steps, ) - training_progress = TrainingProgress(total_tokens=0, outer_step=0, step=0) + training_progress = TrainingProgress(0, 0, 0) ckpt_manager = CkptManager( config=config.ckpt, @@ -195,11 +191,11 @@ def train(config: Config): diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None, # type: ignore ) - if world_info.rank == 0: + if local_world_info.rank == 0: logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger metric_logger = logger_cls( project=config.project, - logger_config={"config": config.model_dump(), "world_info": world_info.json()}, + logger_config={"config": config.model_dump(), "world_info": local_world_info.json()}, resume=config.wandb_resume, ) else: @@ -219,7 +215,7 @@ def train(config: Config): data_path=config.ckpt.data_path, ) log_hash_training_state( - config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume" + config, model, inner_optimizer, diloco, metric_logger, step=training_progress.num_performed_inner_steps, id="resume" ) if config.train.memory_profiler is not None: @@ -228,61 +224,69 @@ def train(config: Config): num_inner_steps = config.diloco.inner_steps if config.diloco is not None else 1 perf_counter = PerfCounter(window_size=10) + logger.debug("Connecting to CCoIP master...") + communicator = Communicator(config.ccoip_master_addr) + try: + communicator.connect(n_attempts=config.ccoip_master_connection_attempts) + except PCCLError: + logger.error("Failed to connect to CCoIP master") + raise + logger.debug("Finished setup in %f seconds", sw.elapsed()) - need_live_recovery = config.ckpt.live_recovery_rank_src is not None + local_iter = 0 + world_size: int = communicator.get_attribute(Attribute.CURRENT_WORLD_SIZE) + + num_syncs = 0 + dummy_tensor = torch.zeros(1, device='cpu') + entries = [ + TensorInfo.from_torch(dummy_tensor, "dummy", 
allow_content_inequality=False) + ] + shared_state: SharedState = SharedState(entries) + while True: if num_inner_steps > 1: # if we don't use diloco we don't print the outer step logs - logger.info(f"outer_step step: {training_progress.outer_step}") + logger.info(f"outer_step step: {training_progress.num_performed_outer_steps}") + + if local_iter > 0: + # keep retrying if it fails + while True: + try: + communicator.update_topology() + break + except PCCLError: + # could be pccl.UpdateTopologyFailed or other + logger.error("Failed to update topology, retrying...") + time.sleep(0.1) + world_size = communicator.get_attribute(Attribute.CURRENT_WORLD_SIZE) + + if world_size < 2: + logger.info("World size is less than 2, waiting for more peers...") + time.sleep(1) + local_iter += 1 + continue + + current_device = torch.cuda.current_device() # refer to .set_device in the main function + + # Perform cuda device synchronization + # if your shared state partially or fully resides on the GPU we must wait until all currently dispatched kernels have completed + # to avoid validating or potentially transmitting data that is currently being in-place modified. + torch.cuda.synchronize(current_device) + + sync_info = communicator.sync_shared_state(shared_state) + num_syncs += 1 + if num_syncs > 1: + # assert sync_info.rx_bytes == 0, "We should not be receiving any data after the initial sync; Otherwise, the peer has drifted" + pass + + if shared_state.revision * num_inner_steps >= config.optim.total_steps: + logger.info("Reached the total number of steps, exiting ...") + break time_start_outer = time.perf_counter() - if config.diloco is not None: - assert diloco is not None - # this is a patch for now to allow live recovery worker to not affect the all reduce at all - - if not need_live_recovery: - elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True) - - maybe_dest_rank = elastic_device_mesh.live_recovery.should_send_ckpt_to() - if maybe_dest_rank is not None: - logger.info(f"Start live recovery to rank {maybe_dest_rank}") - ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank, blocking=True) - - elastic_device_mesh.live_recovery.reset() - else: - ## receiving - time_start_live_recovery = time.perf_counter() - logger.info(f"Start live recovery from rank {config.ckpt.live_recovery_rank_src}") - - ## we create grad buffer and opts stats mamnually, the value will be overwritten by the ckpt but we need the DTensor to be correctly init before loading it - - diloco.outer_optimizer.step() # need to step to init the DTensor stats - - ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg) - - log_hash_training_state( - config, - model, - inner_optimizer, - diloco, - metric_logger, - step=training_progress.step, - id="live_reco_recv", - ) - need_live_recovery = False - - if config.ckpt.remote_data_load: - ckpt_manager.remote_data_load() - - logger.info("live recovery done in %f", time.perf_counter() - time_start_live_recovery) - - # at the beginning of the inner steps we allow joiner to arrive. 
- # We maybe reinit before the all reduce but only to allow leaving, not to join anymore - for inner_step in range(num_inner_steps): - logger.debug("Starting inner step.") sw.start("inner_step") loss_batch = 0 @@ -297,8 +301,6 @@ def train(config: Config): model.set_requires_gradient_sync(not is_accumulating) with sw.record_block("Load batch"): - # TODO/NOTE: We could overlap sending the batch with communication - # although to be honest the perf impact is minimal batch = next(train_dataloader_iterator) input_ids = batch["input_ids"] labels = batch["labels"] @@ -342,27 +344,14 @@ def train(config: Config): else: loss_batch += loss.detach().clone() - elapsed = sw.stop("grad_acc_step") - logger.debug(f"Grad acc step {grad_acc_step} completed in {elapsed:.2f} seconds") - with sw.record_block("Loss allreduce()"): # Launch both allreduces at the same time to hide latency - loss_allreduce = dist.all_reduce( - tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True - ) - if config.optim.z_loss: - z_loss_allreduce = dist.all_reduce( - tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True - ) - - assert isinstance(loss_allreduce, torch.distributed.Work) - loss_allreduce.wait() + dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG) if config.optim.z_loss: - assert isinstance(z_loss_allreduce, torch.distributed.Work) - z_loss_allreduce.wait() + dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG) with sw.record_block("Clip Grad"): - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).full_tensor() # type: ignore (is a dtensor) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).full_tensor() with sw.record_block("Optimizer Step"): inner_optimizer.step() @@ -372,7 +361,7 @@ def train(config: Config): inner_optimizer.zero_grad() # logging - training_progress.step += 1 + training_progress.num_performed_inner_steps += 1 inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0] # syncing loss across all data parallel rank within a nodes @@ -380,19 +369,22 @@ def train(config: Config): perf_counter.count_tokens(new_tokens) if config.diloco is None: - training_progress.total_tokens += new_tokens + training_progress.num_trained_tokens += new_tokens else: # we count the total tokens with respect to all diloco workers # might need to tweak this as some worker might fail to join the all reduce later - training_progress.total_tokens += new_tokens * elastic_device_mesh.global_pg.size() + + # this is technically a faulty approximation, but we don't necessarily care + # for what constitutes a high level monitoring summary statistic + training_progress.num_trained_tokens += new_tokens * world_size assert isinstance(loss_batch, torch.Tensor) metrics = { "Loss": loss_batch.item(), - "step": training_progress.step, + "step": training_progress.num_performed_inner_steps, "inner_lr": inner_lr, "Perplexity": torch.exp(loss_batch).item(), - "total_tokens": training_progress.total_tokens, + "total_tokens": training_progress.num_trained_tokens, "time": time.time(), "grad_norm": grad_norm.item(), } @@ -401,21 +393,21 @@ def train(config: Config): assert isinstance(z_loss_batch, torch.Tensor) metrics["z_loss"] = z_loss_batch.item() - log = f"step: {training_progress.step}, loss: {loss_batch.item():.4f}" + log = f"step: {training_progress.num_performed_inner_steps}, loss: {loss_batch.item():.4f}" tokens_per_second = perf_counter.get_tokens_per_second() if tokens_per_second is not None: 
metrics["tokens_per_second"] = tokens_per_second metrics["mfu"] = ( - 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops / world_info.local_world_size + 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops / local_world_info.local_world_size ) log += f", tokens_per_second: {tokens_per_second:.2f}, mfu: {metrics['mfu']:.2f}" if config.diloco is not None: - metrics["num_peers"] = elastic_device_mesh.global_pg.size() + metrics["num_peers"] = world_size log += f", diloco_peers: {metrics['num_peers']}" - if world_info.rank == 0: + if local_world_info.rank == 0: assert metric_logger is not None metric_logger.log(metrics) @@ -424,32 +416,29 @@ def train(config: Config): if config.train.memory_profiler is not None: memory_profiler.step() - elapsed = sw.stop("inner_step") - logger.debug(f"Inner step {inner_step} completed in {elapsed:.2f} seconds") - if config.diloco is not None: assert diloco is not None time_start_inner = time.perf_counter() - diloco.step(model=model, flag=str(training_progress.outer_step)) + diloco.step(model, communicator) diloco_time = time.perf_counter() - time_start_inner log_hash_training_state( - config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="outer_step" + config, model, inner_optimizer, diloco, metric_logger, step=training_progress.num_performed_inner_steps, id="outer_step" ) - training_progress.outer_step += 1 + training_progress.num_performed_outer_steps += 1 if ( config.ckpt.interval is not None - and training_progress.step > 0 - and training_progress.step % config.ckpt.interval == 0 + and training_progress.num_performed_inner_steps > 0 + and training_progress.num_performed_inner_steps % config.ckpt.interval == 0 ): # we only allow to checkpoint after a outer step. For non diloco training outer step = 1 anyway - do_remote = config.ckpt.remote is not None and training_progress.step % config.ckpt.remote.interval == 0 + do_remote = config.ckpt.remote is not None and training_progress.num_performed_inner_steps % config.ckpt.remote.interval == 0 ckpt_manager.save(remote=do_remote) log_hash_training_state( - config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="save" + config, model, inner_optimizer, diloco, metric_logger, step=training_progress.num_performed_inner_steps, id="save" ) if config.diloco: @@ -459,35 +448,30 @@ def train(config: Config): * config.data.seq_length / (time.perf_counter() - time_start_outer) ) - mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops / world_info.local_world_size + mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops / local_world_info.local_world_size logger.info(f"effective mfu: {mfu}") - if world_info.rank == 0: + if local_world_info.rank == 0: assert metric_logger is not None metric_logger.log( { "outer_mfu": mfu, - "step": training_progress.step, - "outer_step": training_progress.outer_step, + "step": training_progress.num_performed_inner_steps, + "outer_step": training_progress.num_performed_outer_steps, "outer_tokens_per_second": tokens_per_second, "all_reduce_step": diloco_time, } ) - if training_progress.step >= config.optim.total_steps: - # we only allow to break outisde of the inner loop. - # This avoid ending the training in the middle of a the inner loop - # Since ckpt strategy and all reduce is done at the outer loop level. 
- break + shared_state.revision += 1 + local_iter += 1 - if world_info.rank == 0: + if local_world_info.rank == 0: assert metric_logger is not None metric_logger.finish() ckpt_manager.wait_for_blocking_job() - del elastic_device_mesh # allow to clean up for smoother tests transition - if config.train.memory_profiler is not None: logger.debug(f"Max memory used: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") @@ -503,11 +487,11 @@ def train(config: Config): config = Config(**parse_argv()) # type: ignore resolve_env_vars(config) - world_info = get_world_info() + local_world_info = get_local_world_info() logger = get_logger(config) # torch.set_default_device("cuda") - torch.cuda.set_device(world_info.local_rank) + torch.cuda.set_device(local_world_info.local_rank) def pretty_dict(d, indent=2): for key, value in d.items(): @@ -521,7 +505,7 @@ def pretty_dict(d, indent=2): pretty_dict(config.model_dump()) try: - if config.train.torch_profiler and world_info.rank == 0: + if config.train.torch_profiler and local_world_info.rank == 0: # NOTE(apaz-cli): I cannot seem to get the memory profiler to work. # Running into this issue: https://github.com/pytorch/pytorch/issues/64345 # In the meantime, we can use the memory snapshotter. diff --git a/src/zeroband/utils/ip.py b/src/zeroband/utils/ip.py deleted file mode 100644 index 4ec30aa9..00000000 --- a/src/zeroband/utils/ip.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Optional -import socket -import fcntl -import struct - -MULTIPLIER = {"Kbits/sec": 1e3, "Mbits/sec": 1e6, "Gbits/sec": 1e9, "Tbits/sec": 1e12} - - -def parse_iperf_output(output: str) -> Optional[int]: - try: - value, mult = output.strip().split()[-2:] - return int(float(value) * MULTIPLIER[mult]) - except Exception: - return None - - -# Taken from https://stackoverflow.com/questions/24196932/how-can-i-get-the-ip-address-from-a-nic-network-interface-controller-in-python -def get_ip_address(ifname: str) -> str: - """Get the IP address of the specified network interface. - - Args: - ifname (str): The name of the network interface. - Returns: - str: The IP address of the network interface. 
- """ - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - ret = socket.inet_ntoa( - fcntl.ioctl( - s.fileno(), - 0x8915, # SIOCGIFADDR - struct.pack("256s", ifname.encode("utf-8")[:15]), - )[20:24] - ) - s.close() - return ret diff --git a/src/zeroband/utils/logger.py b/src/zeroband/utils/logger.py index 91050bf0..6cc17477 100644 --- a/src/zeroband/utils/logger.py +++ b/src/zeroband/utils/logger.py @@ -1,7 +1,7 @@ import logging from zeroband.config import Config -from zeroband.utils.world_info import get_world_info +from zeroband.utils.world_info import get_local_world_info logger = None @@ -18,7 +18,7 @@ def __init__(self, local_rank: int): self.local_rank = local_rank def format(self, record): - log_format = "{asctime} [{levelname}] [Rank {local_rank}] {message}" + log_format = "{asctime} [{levelname}] [LocalRank {local_rank}] {message}" formatter = logging.Formatter(log_format, style="{", datefmt="%H:%M:%S") record.local_rank = self.local_rank # Add this line to set the local rank in the record return formatter.format(record) @@ -30,11 +30,11 @@ def get_logger(config: Config | None = None, name: str | None = None) -> logging return logger try: - world_info = get_world_info() + world_info = get_local_world_info() except KeyError: - from zeroband.utils.world_info import WorldInfo + from zeroband.utils.world_info import LocalWorldInfo - world_info = WorldInfo.__new__(WorldInfo) + world_info = LocalWorldInfo.__new__(LocalWorldInfo) world_info.local_rank = 0 logger = logging.getLogger(name or __name__) diff --git a/src/zeroband/utils/profiler.py b/src/zeroband/utils/profiler.py index e6a87b32..498300ec 100644 --- a/src/zeroband/utils/profiler.py +++ b/src/zeroband/utils/profiler.py @@ -2,7 +2,7 @@ import pickle import torch from zeroband.utils.logger import get_logger -from zeroband.utils.world_info import get_world_info +from zeroband.utils.world_info import get_local_world_info _MAX_ENTRIES = 10000 @@ -16,7 +16,7 @@ def __init__(self, freq: int, snapshot_dir: str): torch.cuda.memory._record_memory_history(max_entries=_MAX_ENTRIES) self.freq = freq - self.world_info = get_world_info() + self.world_info = get_local_world_info() self.logger = get_logger() self.step_num = 0 diff --git a/src/zeroband/utils/state_dict_send_recv.py b/src/zeroband/utils/state_dict_send_recv.py deleted file mode 100644 index 66366dd9..00000000 --- a/src/zeroband/utils/state_dict_send_recv.py +++ /dev/null @@ -1,165 +0,0 @@ -import io -import pickle -import torch -from torch.distributed import ProcessGroup -from torch.distributed._tensor.api import DTensor - - -def _object_to_tensor(obj): - f = io.BytesIO() - pickle.Pickler(f).dump(obj) - byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] - # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. - # Otherwise, it will casue 100X slowdown. 
- # See: https://github.com/pytorch/pytorch/issues/65696 - byte_tensor = torch.ByteTensor(byte_storage) - local_size = torch.LongTensor([byte_tensor.numel()]) - return byte_tensor, local_size - - -def _tensor_to_object(tensor, tensor_size): - tensor = tensor.cpu() - buf = tensor.numpy().tobytes()[:tensor_size] - return pickle.Unpickler(io.BytesIO(buf)).load() - - -def _tensor_to_placeholder(idx: int, tensor: torch.Tensor) -> str: - return f"zeroband_tensor_{idx}_{tensor.shape}_{tensor.dtype}" - - -def _validate_placeholder_to_tensor(placeholder: str, tensors: list[torch.Tensor]) -> torch.Tensor: - """ - validate that the tensor is compatible with the placeholder. - """ - try: - idx, shape, dtype = placeholder.split("_")[2:] - except ValueError as e: - raise ValueError(f"Invalid tensor placeholder {placeholder}") from e - - tensor = tensors[int(idx)] - if shape != str(tensor.shape): - raise ValueError( - f"tensor {idx} try to load a tensor with shape {shape} but the tensor has shape {tensor.shape}" - ) - if dtype != str(tensor.dtype): - raise ValueError( - f"tensor {idx} try to load a tensor with dtype {dtype} but the tensor has dtype {tensor.dtype}" - ) - - return tensor - - -def _get_sendable_state_dict(state_dict: dict) -> tuple[dict, list[torch.Tensor]]: - """ - This function take a state dict (dict with tensor inside) and return a torch.send/recv-able format. - - It splits the state dict into two part : - * a list of tensor - * a dict emptied from tensor - - The order is deterministic. The function can be used in pair with _load_sendable_state_dict - """ - tensors: list[torch.Tensor] = [] - - def _split(state_dict_, tensors_): - new_dict = {} - for key, value in state_dict_.items(): - if isinstance(value, dict): - new_dict[key] = _split(value, tensors_) - elif isinstance(value, torch.Tensor): - idx = len(tensors_) - tensors_.append(value) - new_dict[key] = _tensor_to_placeholder(idx, value) - else: - new_dict[key] = value - - return new_dict - - state_dict = _split(state_dict, tensors) - return state_dict, tensors - - -def _load_sendable_state_dict(tensors: list[torch.Tensor], state_dict: dict) -> dict: - """ - This function take a list of tensor and a state dict and return state dict. 
- - The function can be used in pair with _get_sendable_state_dict - """ - - def _load(state_dict_): - for key, value in list(state_dict_.items()): # list needed as we modify the state_dict_ as we traverse it - if isinstance(value, dict): - state_dict_[key] = _load(value) - elif isinstance(value, str) and value.startswith("zeroband_tensor_"): - state_dict_[key] = _validate_placeholder_to_tensor(value, tensors) - - return state_dict_ - - return _load(state_dict) - - -def send_state_dict(pg: ProcessGroup, state_dict: dict, dest_rank: int) -> None: - non_tensored_state_dict, tensors = _get_sendable_state_dict(state_dict) - send_tensor_and_state_dict(pg, dest_rank, non_tensored_state_dict, tensors) - - -def send_tensor_and_state_dict(pg: ProcessGroup, dest_rank: int, state_dict: dict, tensors: list[torch.Tensor]) -> None: - # logger = get_logger() - # logger.debug(f"recv tensors {get_tensor_list_signature(tensors)}") - - state_dict_tensor_buffer, size = _object_to_tensor(state_dict) - pg.send([size], dest_rank, 0).wait() - pg.send([state_dict_tensor_buffer], dest_rank, 0).wait() - - jobs = [] - for i, tensor in enumerate(tensors): - buffer = tensor - if isinstance(tensor, DTensor): - buffer = tensor.to_local() - - buffer = buffer.detach().cpu() - - jobs.append(pg.send([buffer], dest_rank, i)) - - for job in jobs: - job.wait() - - -def recv_state_dict(pg: ProcessGroup, src_rank: int, og_state_dict: dict) -> dict: - size = torch.LongTensor(1) - - # Receive object sizes - pg.recv([size], src_rank, 0).wait() - # Tensor to receive serialized objects into. - object_tensor = torch.empty(size.item(), dtype=torch.uint8) - - pg.recv([object_tensor], src_rank, 0).wait() - state_dict = _tensor_to_object(object_tensor, size) - - _, tensors = _get_sendable_state_dict(og_state_dict) - - jobs = [] - datas = [] - for i, tensor in enumerate(tensors): - buffer = tensor - if isinstance(tensor, DTensor): - buffer = tensor.to_local() - - data = torch.empty_like(buffer, device="cpu") - jobs.append(pg.recv([data], src_rank, i)) - datas.append(data) - - for job in jobs: - job.wait() - - for tensor, data in zip(tensors, datas): - if isinstance(tensor, DTensor): - tensor = tensor.to_local() - tensor.copy_(data) - - state_dict = _load_sendable_state_dict(tensors, state_dict) - - # logger = get_logger() - # logger.debug(f"recv tensors {get_tensor_list_signature(tensors)}") - - return state_dict diff --git a/src/zeroband/utils/stopwatch.py b/src/zeroband/utils/stopwatch.py index 2b49d4fb..004b9251 100644 --- a/src/zeroband/utils/stopwatch.py +++ b/src/zeroband/utils/stopwatch.py @@ -17,7 +17,7 @@ def __enter__(self): if self.sw.disabled: return self - self.sw.start_block(message=f"Starting \"{self.prof_name}\"") + self.sw.start_block(message=f'Starting "{self.prof_name}"') return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -26,15 +26,15 @@ def __exit__(self, exc_type, exc_val, exc_tb): if self.sw.disabled: return - self.sw.end_block(format_str=f"Finished \"{self.prof_name}\"") + self.sw.end_block(format_str=f'Finished "{self.prof_name}"') class Stopwatch: def __init__(self, config: Config | None = None): - self.timers: dict[str, dict[str, float]] = {} # Timer name -> {start_time, last_lap_time} - self.stack: list[str] = [] # List timer names in order of last constructed + self.timers: dict[str, dict[str, float]] = {} # Timer name -> {start_time, last_lap_time} + self.stack: list[str] = [] # List timer names in order of last constructed self.logger = get_logger(config) - self.disabled = (config.log_level != 
"DEBUG") if config else False + self.disabled = True def _resolve_name(self, name: str | None) -> str: if name is None: @@ -48,10 +48,7 @@ def start(self, name: str) -> None: return current_time = time.perf_counter() - self.timers[name] = { - 'start_time': current_time, - 'last_lap_time': current_time - } + self.timers[name] = {"start_time": current_time, "last_lap_time": current_time} self.stack.append(name) def _lap(self, name: str | None = None) -> float: @@ -67,8 +64,8 @@ def _lap(self, name: str | None = None) -> float: raise ValueError(f"Timer '{name}' does not exist") current_time = time.perf_counter() - elapsed = current_time - timer['last_lap_time'] - timer['last_lap_time'] = current_time + elapsed = current_time - timer["last_lap_time"] + timer["last_lap_time"] = current_time return elapsed def start_block(self, message: str | None = None, name: str | None = None) -> None: @@ -83,13 +80,13 @@ def end_block(self, format_str: str | None = None, name: str | None = None) -> N if self.disabled: return - lap_time = self._lap(name) - if not format_str: - return - elif "{" in format_str: - self.logger.debug(format_str.format(name=name, time=lap_time)) - else: - self.logger.debug(f"{format_str} in {lap_time:.2f} seconds") + # lap_time = self._lap(name) + # if not format_str: + # return + # elif "{" in format_str: + # self.logger.debug(format_str.format(name=name, time=lap_time)) + # else: + # self.logger.debug(f"{format_str} in {lap_time:.2f} seconds") def elapsed(self, name: str | None = None) -> float: if self.disabled: @@ -101,7 +98,7 @@ def elapsed(self, name: str | None = None) -> float: raise ValueError(f"Timer '{name}' does not exist") current_time = time.perf_counter() - return current_time - timer['start_time'] + return current_time - timer["start_time"] def stop(self, name: str | None = None) -> float: if self.disabled: @@ -127,4 +124,3 @@ def record_block(self, prof_name: str) -> _RecordBlockContext: start_message is passed as start_block's message. """ return _RecordBlockContext(self, prof_name) - diff --git a/src/zeroband/utils/wget.py b/src/zeroband/utils/wget.py deleted file mode 100644 index 849e504e..00000000 --- a/src/zeroband/utils/wget.py +++ /dev/null @@ -1,20 +0,0 @@ -import subprocess - -import shutil - -def _get_cut_dirs_from_url(url: str) -> int: - return len(url.rstrip().partition("//")[-1].split("/")) - -def wget(source: str, destination: str) -> None: - # logger = get_logger() - cmd = f"wget -r -np -nH --cut-dirs={_get_cut_dirs_from_url(source)} -P {destination} {source}" - - if shutil.which("wget") is None: - raise RuntimeError("wget is required but not found. Please install wget and try again.") - - try: - subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - # logger.error(f"Error output: {e.stderr}") - print(f"Error output: {e.stderr}") - raise e diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py index 8027e848..e7144437 100644 --- a/src/zeroband/utils/world_info.py +++ b/src/zeroband/utils/world_info.py @@ -3,33 +3,30 @@ world_info = None -class WorldInfo: - """This class parse env var about torch world into class variables.""" +class LocalWorldInfo: + """ + Local World information. + The "local world" shall mean the world within the worker that is contributing as one peer to the training run. + PCCL does not have concept of ranks and this information is strictly separate from PCCL related state. 
+ """ world_size: int rank: int - local_rank: int + local_world_size: int + local_rank: int + + num_nodes: int def __init__(self): self.world_size = int(os.environ["WORLD_SIZE"]) self.rank = int(os.environ["RANK"]) self.local_rank = int(os.environ["LOCAL_RANK"]) self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - self.nnodes = self.world_size // self.local_world_size - - self.global_unique_id = os.environ.get("GLOBAL_UNIQUE_ID", None) - self.global_addr = os.environ.get("GLOBAL_ADDR", None) - self.global_port = int(os.environ.get("GLOBAL_PORT")) if "GLOBAL_PORT" in os.environ else None - self.global_world_size = int(os.environ.get("GLOBAL_WORLD_SIZE", 1)) - self.global_rank = int(os.environ.get("GLOBAL_RANK", 0)) + self.num_nodes = self.world_size // self.local_world_size def __repr__(self): - return f"WorldInfo(world_size={self.world_size}, rank={self.rank}, local_rank={self.local_rank}, local_world_size={self.local_world_size}, nnodes={self.nnodes}, global_unique_id={self.global_unique_id}, global_addr={self.global_addr}, global_port={self.global_port}, global_world_size={self.global_world_size}, global_rank={self.global_rank})" - - @property - def diloco_rank(self): - return self.global_rank + return f"WorldInfo(world_size={self.world_size}, rank={self.rank}, local_rank={self.local_rank}, local_world_size={self.local_world_size}, num_nodes={self.num_nodes})" def json(self) -> dict[str, int | str]: return { @@ -37,20 +34,15 @@ def json(self) -> dict[str, int | str]: "rank": self.rank, "local_rank": self.local_rank, "local_world_size": self.local_world_size, - "nnodes": self.nnodes, - "global_unique_id": self.global_unique_id, - "global_addr": self.global_addr, - "global_port": self.global_port, - "global_world_size": self.global_world_size, - "global_rank": self.global_rank, + "num_nodes": self.num_nodes } -def get_world_info() -> WorldInfo: +def get_local_world_info() -> LocalWorldInfo: """ Return a WorldInfo singleton. 
""" global world_info if world_info is None: - world_info = WorldInfo() - return world_info + world_info = LocalWorldInfo() + return world_info \ No newline at end of file diff --git a/tests/test_dist/test_comms.py b/tests/test_dist/test_comms.py deleted file mode 100644 index 28732949..00000000 --- a/tests/test_dist/test_comms.py +++ /dev/null @@ -1,237 +0,0 @@ -import time -import torch -import torch.distributed as dist -import pytest -from zeroband.comms import ElasticDeviceMesh -import multiprocessing as mp - -pytest.skip("Skipping test file", allow_module_level=True) -# skipping this test for now as they slow down the ci and we are going to remove them anyway - - -@pytest.mark.parametrize("world_size", [2, 8]) -def test_elastic_device_mesh_no_global(world_size: int, random_available_port: int, mock_env): - def foo(**kwargs): - with mock_env(**kwargs): - edm = ElasticDeviceMesh(enable=False) - - rank = int(kwargs["RANK"]) - a = torch.arange(3) * (rank + 1) - dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.local_pg) - sum_ints = world_size * (world_size + 1) // 2 - assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints])) - - dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg) - assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints])) - - del edm - - processes = [] - for rank in range(world_size): - processes.append( - mp.Process( - target=foo, - kwargs={ - "MASTER_ADDR": "localhost", - "MASTER_PORT": str(random_available_port), - "RANK": str(rank), - "WORLD_SIZE": str(world_size), - "LOCAL_RANK": str(rank), - "LOCAL_WORLD_SIZE": str(world_size), - "ZERO_BAND_LOG_LEVEL": "DEBUG", - }, - ) - ) - for p in processes: - p.start() - for p in processes: - p.join() - if p.exitcode != 0: - pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}") - - -@pytest.mark.parametrize("world_size", [2, 8]) -@pytest.mark.parametrize("global_world_size", [2, 8]) -def test_elastic_device_mesh(world_size: int, global_world_size: int, mock_env): - def foo(**kwargs): - with mock_env(**kwargs): - edm = ElasticDeviceMesh() - - rank = int(kwargs["RANK"]) - a = torch.arange(3) * (rank + 1) - dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.local_pg) - sum_ints = world_size * (world_size + 1) // 2 - assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints])) - - global_rank = int(kwargs["GLOBAL_RANK"]) - a = torch.arange(3) * (global_rank + 1) + rank - dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg) - sum_ints = global_world_size * (global_world_size + 1) // 2 - assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]) + rank * global_world_size) - - del edm - - global_ports = [i for i in range(21970, 21970 + world_size)] - master_ports = [i for i in range(31000, 31000 + global_world_size)] - processes = [] - for global_rank in range(global_world_size): - for rank in range(world_size): - processes.append( - mp.Process( - target=foo, - kwargs={ - "MASTER_ADDR": "localhost", - "MASTER_PORT": str(master_ports[global_rank]), - "RANK": str(rank), - "WORLD_SIZE": str(world_size), - "LOCAL_RANK": str(rank), - "LOCAL_WORLD_SIZE": str(world_size), - "GLOBAL_UNIQUE_ID": str(global_rank), - "GLOBAL_ADDR": "localhost", - "GLOBAL_PORT": str(global_ports[0]), - "GLOBAL_RANK": str(global_rank), - "GLOBAL_WORLD_SIZE": str(global_world_size), - "ZERO_BAND_LOG_LEVEL": "DEBUG", - }, - ) - ) - for p in processes: - p.start() - for p in processes: - p.join() - if p.exitcode != 0: - pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}") - - 
-@pytest.mark.parametrize("world_size", [1, 2]) -@pytest.mark.parametrize("global_world_size", [2, 4]) -def test_elastic_device_mesh_on_off_ramp(world_size: int, global_world_size: int, mock_env): - ready_event = mp.Event() - - def foo(**kwargs): - with mock_env(**kwargs): - test_value = int(kwargs["TEST_VALUE"]) - - edm = ElasticDeviceMesh() - edm.maybe_reinit_global_pg() - assert edm.mesh_count == 0 - assert edm.global_pg.size() == global_world_size - - ready_event.wait() # Wait for bar to signal readiness - time.sleep(0.5) # Give time for bar to queue - - edm.maybe_reinit_global_pg() - assert edm.mesh_count == 0 - assert edm.global_pg.size() == global_world_size - - time.sleep(1) # TODO: I actually don't know why this is necessary - - edm.maybe_reinit_global_pg(admit_joiners=True) - assert edm.mesh_count == 1 - assert edm.global_pg.size() == global_world_size + 1 - - a = torch.arange(3) * (test_value + 1) - sum_ints = global_world_size * (global_world_size + 1) // 2 + 100 - dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg) - assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints])) - - if test_value == 1: - return - time.sleep(2) - edm.maybe_reinit_global_pg() - assert edm.mesh_count == 2 - assert edm.global_pg.size() == global_world_size - - a = torch.arange(3) * (test_value + 1) - sum_ints = global_world_size * (global_world_size + 1) // 2 + 100 - 2 - dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg) - assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints])) - - dist.barrier(edm.global_pg) - - del edm - - def bar(**kwargs): - with mock_env(**kwargs): - test_value = int(kwargs["TEST_VALUE"]) - time.sleep(1) - - ready_event.set() # Signal that we are about to queue - - edm = ElasticDeviceMesh() - assert edm.mesh_count == 1 - assert edm.global_pg.size() == global_world_size + 1 - - a = torch.arange(3) * test_value - sum_ints = global_world_size * (global_world_size + 1) // 2 + 100 - dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg) - assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints])) - - edm.maybe_reinit_global_pg() - assert edm.mesh_count == 2 - assert edm.global_pg.size() == global_world_size - - a = torch.arange(3) * test_value - sum_ints = global_world_size * (global_world_size + 1) // 2 + 100 - 2 - dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg) - assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints])) - - dist.barrier(edm.global_pg) - - del edm - - global_ports = [i for i in range(21970, 21970 + world_size)] - master_ports = [i for i in range(31000, 31000 + global_world_size + 1)] - processes = [] - for global_rank in range(global_world_size): - for rank in range(world_size): - processes.append( - mp.Process( - target=foo, - kwargs={ - "MASTER_ADDR": "localhost", - "MASTER_PORT": str(master_ports[global_rank]), - "RANK": str(rank), - "WORLD_SIZE": str(world_size), - "LOCAL_RANK": str(rank), - "LOCAL_WORLD_SIZE": str(world_size), - "GLOBAL_UNIQUE_ID": str(global_rank), - "GLOBAL_ADDR": "localhost", - "GLOBAL_PORT": str(global_ports[0]), - "GLOBAL_RANK": str(global_rank), - "GLOBAL_WORLD_SIZE": str(global_world_size), - "ZERO_BAND_LOG_LEVEL": "DEBUG", - "ZERO_BAND_LOG_ALL_RANK": "true", - "TEST_VALUE": str(global_rank), - }, - ) - ) - - for rank in range(world_size): - processes.append( - mp.Process( - target=bar, - kwargs={ - "MASTER_ADDR": "localhost", - "MASTER_PORT": str(master_ports[global_world_size]), - "RANK": str(rank), - "WORLD_SIZE": str(world_size), - "LOCAL_RANK": 
-                    "LOCAL_WORLD_SIZE": str(world_size),
-                    "GLOBAL_UNIQUE_ID": "A",
-                    "GLOBAL_ADDR": "localhost",
-                    "GLOBAL_PORT": str(global_ports[0]),
-                    "GLOBAL_RANK": "100",
-                    "GLOBAL_WORLD_SIZE": str(global_world_size),
-                    "ZERO_BAND_LOG_LEVEL": "DEBUG",
-                    "TEST_VALUE": "100",
-                },
-            )
-        )
-
-    for p in processes:
-        p.start()
-    for p in processes:
-        p.join()
-        if p.exitcode != 0:
-            pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}")
diff --git a/tests/test_dist/test_diloco.py b/tests/test_dist/test_diloco.py
deleted file mode 100644
index ba71f107..00000000
--- a/tests/test_dist/test_diloco.py
+++ /dev/null
@@ -1,64 +0,0 @@
-"""test Diloco."""
-
-import multiprocessing
-import pytest
-
-import torch
-import torch.distributed as dist
-from torch.distributed.fsdp import ShardingStrategy
-
-from zeroband.diloco import Diloco, DilocoConfig
-
-
-@pytest.mark.skip("test failed since introduce of custom all reduce")
-@pytest.mark.parametrize("world_size", [2])  # [1, 2])
-def test_diloco_all_reduce(world_size, random_available_port, dist_environment):
-    """
-    In this test we manually create an inner model and an outer model whose weights we control:
-    the inner model has weight (rank + 1) / 2 and the outer model has weight (rank + 1).
-
-    Since we know the world size, we can predict the result of the all-reduce of the pseudo gradient
-    and therefore test that it is done correctly.
-    """
-
-    class FakeElasticDeviceMesh:
-        def __init__(self):
-            self.global_pg = dist.new_group(backend="gloo")
-
-        def maybe_reinit_global_pg(self, *args, **kwargs) -> None: ...
-
-    def all_reduce(rank: int, world_size: int):
-        with dist_environment(random_available_port, rank=rank, world_size=world_size, global_unique_id=str(rank)):
-            diloco_config = DilocoConfig(inner_steps=10)
-
-            model = torch.nn.Linear(10, 10)
-
-            # init param to rank + 1
-            for param in model.parameters():
-                param.data = (rank + 1) * torch.ones_like(param.data).to("cuda")
-
-            diloco = Diloco(diloco_config, model, ShardingStrategy.FULL_SHARD, FakeElasticDeviceMesh())
-
-            # simulate inner model updates
-            for param in model.parameters():
-                param.data = (rank + 1) / 2 * torch.ones_like(param.data).to("cuda")
-
-            diloco.sync_pseudo_gradient(model)
-
-            for param in diloco.param_list_cpu:
-                print(f"param.grad.mean() {param.grad.mean()}")
-                target = (
-                    torch.ones_like(param.grad)
-                    * sum([(rank + 1) - (rank + 1) / 2 for rank in range(world_size)])
-                    / world_size
-                )
-                assert param.grad.mean() == target.mean()
-
-    processes = [multiprocessing.Process(target=all_reduce, args=(rank, world_size)) for rank in range(world_size)]
-    for p in processes:
-        p.start()
-    for p in processes:
-        p.join()
-        if p.exitcode != 0:
-            pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}")
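The deleted test above derives its expected value from the weights it sets up: outer weights of rank + 1 and inner weights of (rank + 1) / 2, so each rank contributes a pseudo-gradient of (rank + 1) / 2 that is then averaged across ranks. As a reading aid, here is a minimal standalone sketch of that arithmetic in plain Python; it is not part of the diff and needs no distributed setup.

# Standalone sketch of the arithmetic behind the deleted assertion:
# pseudo-gradient on each rank = outer weight - inner weight = (rank + 1) / 2,
# and an all-reduce (SUM) followed by division by world_size is just the mean.
world_size = 2

pseudo_grads = [(rank + 1) - (rank + 1) / 2 for rank in range(world_size)]
expected = sum(pseudo_grads) / world_size

print(expected)  # 0.75 for world_size == 2

For world_size == 2 this gives (0.5 + 1.0) / 2 = 0.75, the value the deleted test compared param.grad.mean() against.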
diff --git a/tests/test_dist/test_send_state_dict.py b/tests/test_dist/test_send_state_dict.py
deleted file mode 100644
index e4e1f22f..00000000
--- a/tests/test_dist/test_send_state_dict.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import os
-import pytest
-import torch
-from zeroband.comms import ElasticDeviceMesh
-from zeroband.utils.state_dict_send_recv import (
-    _get_sendable_state_dict,
-    _load_sendable_state_dict,
-    recv_state_dict,
-    send_state_dict,
-)
-import multiprocessing as mp
-
-
-def test_load_state_dict():
-    state_dict_to_send = {
-        "step": 0,
-        "world": "karl is having his best life",
-        "optim_sates": torch.ones(10),
-        "nested_data": {"foo": "bar", "tensor": torch.ones(10)},
-    }
-
-    state_dict_copy = {
-        "step": 0,
-        "world": "karl is having his best life",
-        "optim_sates": torch.ones(10),
-        "nested_data": {"foo": "bar", "tensor": torch.ones(10)},
-    }
-
-    non_tensored_state_send, tensors_send = _get_sendable_state_dict(state_dict_to_send)
-
-    assert isinstance(non_tensored_state_send["optim_sates"], str)
-    assert non_tensored_state_send["optim_sates"].startswith("zeroband_tensor")
-
-    print(len(tensors_send))
-    print(non_tensored_state_send)
-    _load_sendable_state_dict(tensors_send, non_tensored_state_send)
-
-    assert (state_dict_to_send["optim_sates"] == state_dict_copy["optim_sates"]).all()
-    assert id(state_dict_to_send["optim_sates"]) != id(state_dict_copy["optim_sates"])
-
-    assert (state_dict_to_send["nested_data"]["tensor"] == state_dict_copy["nested_data"]["tensor"]).all()
-    assert id(state_dict_to_send["nested_data"]["tensor"]) != id(state_dict_copy["nested_data"]["tensor"])
-
-    assert state_dict_to_send["step"] == state_dict_copy["step"]
-    assert state_dict_to_send["world"] == state_dict_copy["world"]
-    assert state_dict_to_send["nested_data"]["foo"] == state_dict_copy["nested_data"]["foo"]
-
-
-@pytest.mark.skip(reason="hang")
-@pytest.mark.parametrize("world_size", [2])
-def test_send_recv_state_dict(world_size: int, random_available_port: int, mock_env):
-    def foo(**kwargs):
-        with mock_env(**kwargs):
-            edm = ElasticDeviceMesh()
-
-            state_dict_to_send = {
-                "step": 0,
-                "world": "karl is having his best life",
-                "optim_sates": torch.ones(10),
-                "nested_data": {"foo": "bar", "tensor": torch.ones(10)},
-            }
-
-            state_dict_to_recv = {
-                "step": 10,
-                "world": "karl is in holiday",
-                "optim_sates": torch.zeros(10),
-                "nested_data": {"foo": "barman", "tensor": torch.zeros(10)},
-            }
-
-            rank = int(os.environ.get("RANK"))
-
-            if rank == 0:
-                send_state_dict(state_dict_to_send, 1, world_size)
-            else:
-                state_dict = recv_state_dict(pg=edm.global_pg, rank=0, world_size=world_size)
-
-                assert (state_dict["optim_sates"] == state_dict_to_recv["optim_sates"]).all()
-                assert id(state_dict["optim_sates"]) != id(state_dict_to_recv["optim_sates"])
-
-                assert (state_dict["nested_data"]["tensor"] == state_dict_to_recv["nested_data"]["tensor"]).all()
-                assert id(state_dict["nested_data"]["tensor"]) != id(state_dict_to_recv["nested_data"]["tensor"])
-
-                assert state_dict["step"] == state_dict_to_recv["step"]
-                assert state_dict["world"] == state_dict_to_recv["world"]
-                assert state_dict["nested_data"]["foo"] == state_dict_to_recv["nested_data"]["foo"]
-
-            del edm
-
-    processes = []
-    for rank in range(world_size):
-        processes.append(
-            mp.Process(
-                target=foo,
-                kwargs={
-                    "MASTER_ADDR": "localhost",
-                    "MASTER_PORT": str(random_available_port),
-                    "RANK": str(rank),
-                    "WORLD_SIZE": str(world_size),
-                    "LOCAL_RANK": str(rank),
-                    "LOCAL_WORLD_SIZE": str(world_size),
-                    "ZERO_BAND_LOG_LEVEL": "DEBUG",
-                },
-            )
-        )
-    for p in processes:
-        p.start()
-    for p in processes:
-        p.join()
-        if p.exitcode != 0:
-            pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}")
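The deleted test_load_state_dict pins down the contract of _get_sendable_state_dict / _load_sendable_state_dict: tensors are pulled out of a (possibly nested) state dict, replaced by string placeholders that start with "zeroband_tensor", and later restored from the separate tensor list. The sketch below is a toy illustration of that round trip, not the real zeroband.utils.state_dict_send_recv code; the helper names and the exact placeholder suffix are invented for the example.

import torch


def split_state_dict(state_dict, tensors=None):
    # Replace every tensor with a "zeroband_tensor_<index>" placeholder and
    # collect the tensors in a flat list, recursing into nested dicts.
    tensors = [] if tensors is None else tensors
    out = {}
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            out[key] = f"zeroband_tensor_{len(tensors)}"
            tensors.append(value)
        elif isinstance(value, dict):
            out[key], _ = split_state_dict(value, tensors)
        else:
            out[key] = value
    return out, tensors


def merge_state_dict(non_tensor_state, tensors):
    # Inverse of split_state_dict: swap the placeholders back for the tensors.
    out = {}
    for key, value in non_tensor_state.items():
        if isinstance(value, str) and value.startswith("zeroband_tensor_"):
            out[key] = tensors[int(value.rsplit("_", 1)[1])]
        elif isinstance(value, dict):
            out[key] = merge_state_dict(value, tensors)
        else:
            out[key] = value
    return out


state = {"step": 0, "optim": torch.ones(10), "nested": {"tensor": torch.zeros(3)}}
flat, extracted = split_state_dict(state)
assert flat["optim"].startswith("zeroband_tensor")
restored = merge_state_dict(flat, extracted)
assert torch.equal(restored["optim"], state["optim"])
assert torch.equal(restored["nested"]["tensor"], state["nested"]["tensor"])

The real helpers restore the tensors into fresh objects (hence the id(...) != assertions in the deleted test), while this sketch simply rebuilds a new dict; only the split-and-restore idea carries over, presumably so the non-tensor part can be serialized and sent separately from the tensor payload.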
diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py
index 58607ad3..ce63ef95 100644
--- a/tests/test_torchrun/test_train.py
+++ b/tests/test_torchrun/test_train.py
@@ -6,18 +6,18 @@

 import pytest
 import socket

-from zeroband.diloco import Compression
-
 import torch

+from zeroband.config import Compression
+
 num_gpu = torch.cuda.device_count()


-def get_random_available_port_list(num_port):
+def get_random_available_port_list(num_ports):
     # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
     ports = []
-    while len(ports) < num_port:
+    while len(ports) < num_ports:
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
             s.bind(("", 0))
             new_port = s.getsockname()[1]
@@ -146,42 +146,42 @@ def test_ckpt(tmp_path: Path, soap: bool):
         num_gpus,
         "debug/diloco.toml",
         extra_args=[
-                "--project",
-                str(v1_file),
-                "--ckpt.path",
-                str(v1_ckpt),
-                "--ckpt.interval",
-                "5",
-                "--optim.total_steps",
-                "20",
-                "--train.log_model_hash",
-                "--no-data.sequence_packing",
-                "--train.attn_fn",
-                "math",
-            ]
-            + (["--optim.optim.precondition_frequency", "1"] if soap else []),
+            "--project",
+            str(v1_file),
+            "--ckpt.path",
+            str(v1_ckpt),
+            "--ckpt.interval",
+            "5",
+            "--optim.total_steps",
+            "20",
+            "--train.log_model_hash",
+            "--no-data.sequence_packing",
+            "--train.attn_fn",
+            "math",
+        ]
+        + (["--optim.optim.precondition_frequency", "1"] if soap else []),
         diloco=True,
     )
     _test_multi_gpu(
         num_gpus,
         "debug/diloco.toml",
         extra_args=[
-                "--project",
-                str(v2_file),
-                "--ckpt.path",
-                str(v2_ckpt),
-                "--ckpt.interval",
-                "5",
-                "--ckpt.resume",
-                str(v1_ckpt / "step_5"),
-                "--optim.total_steps",
-                "20",
-                "--train.log_model_hash",
-                "--no-data.sequence_packing",
-                "--train.attn_fn",
-                "math",
-            ]
-            + (["--optim.optim.precondition_frequency", "1"] if soap else []),
+            "--project",
+            str(v2_file),
+            "--ckpt.path",
+            str(v2_ckpt),
+            "--ckpt.interval",
+            "5",
+            "--ckpt.resume",
+            str(v1_ckpt / "step_5"),
+            "--optim.total_steps",
+            "20",
+            "--train.log_model_hash",
+            "--no-data.sequence_packing",
+            "--train.attn_fn",
+            "math",
+        ]
+        + (["--optim.optim.precondition_frequency", "1"] if soap else []),
         diloco=True,
     )
     # _test_multi_gpu(
diff --git a/third_party/gloo b/third_party/gloo
deleted file mode 160000
index 5354032e..00000000
--- a/third_party/gloo
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 5354032ea08eadd7fc4456477f7f7c6308818509
diff --git a/uv.lock b/uv.lock
index 2ddf34af..fb907d12 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,4 +1,5 @@
 version = 1
+revision = 1
 requires-python = ">=3.10"
 resolution-markers = [
     "python_full_version >= '3.13' and sys_platform == 'linux'",
@@ -765,6 +766,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
 ]

+[[package]]
+name = "ipaddress"
+version = "1.0.23"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/b9/9a/3e9da40ea28b8210dd6504d3fe9fe7e013b62bf45902b458d1cdc3c34ed9/ipaddress-1.0.23.tar.gz", hash = "sha256:b7f8e0369580bb4a24d5ba1d7cc29660a4a6987763faf1d8a8046830e020e7e2", size = 32958 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/c2/f8/49697181b1651d8347d24c095ce46c7346c37335ddc7d255833e7cde674d/ipaddress-1.0.23-py2.py3-none-any.whl", hash = "sha256:6e0f4a39e66cb5bb9a137b00276a2eff74f93b71dcbdad6f10ff7df9d3557fcc", size = 18159 },
+]
+
 [[package]]
 name = "jinja2"
 version = "3.1.5"
@@ -1465,6 +1475,16 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/50/14/c5a0e1a947909810fc4c043b84cac472b70e438148d34f5393be1bac663f/pathvalidate-3.2.3-py3-none-any.whl", hash = "sha256:5eaf0562e345d4b6d0c0239d0f690c3bd84d2a9a3c4c73b99ea667401b27bee1", size = 24130 },
 ]

+[[package]]
+name = "pccl"
+version = "0.1.0"
+source = { git = "https://github.com/PrimeIntellect-ai/pccl.git?subdirectory=python%2Fframework&rev=main#17109a50dfbec0c461239dca6bfe408a842ae1cd" }
+dependencies = [
+    { name = "cffi" },
+    { name = "ipaddress" },
+ { name = "pycparser" }, +] + [[package]] name = "peft" version = "0.14.0" @@ -2864,6 +2884,7 @@ dependencies = [ { name = "liger-kernel-nightly" }, { name = "ninja" }, { name = "numpy" }, + { name = "pccl" }, { name = "psutil" }, { name = "pyarrow" }, { name = "pydantic-config" }, @@ -2898,6 +2919,7 @@ requires-dist = [ { name = "lm-eval", marker = "extra == 'all'" }, { name = "ninja" }, { name = "numpy" }, + { name = "pccl", git = "https://github.com/PrimeIntellect-ai/pccl.git?subdirectory=python%2Fframework&rev=main" }, { name = "psutil" }, { name = "pyarrow" }, { name = "pydantic-config", git = "https://github.com/samsja/pydantic_config.git?rev=b7becc3" }, @@ -2910,6 +2932,7 @@ requires-dist = [ { name = "wandb", marker = "extra == 'all'" }, { name = "zstandard" }, ] +provides-extras = ["all"] [package.metadata.requires-dev] dev = [