diff --git a/journey/understanding_tensor_parallel/0_naive_tensor_parallel.py b/journey/understanding_tensor_parallel/0_naive_tensor_parallel.py new file mode 100644 index 0000000..7f7f0ba --- /dev/null +++ b/journey/understanding_tensor_parallel/0_naive_tensor_parallel.py @@ -0,0 +1,134 @@ +import os +import math +import random +import numpy as np +from copy import deepcopy +from typing import List, Dict +import argparse + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + + +def set_seed(seed=1234): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def init_dist(): + rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size) + print(f"rank: {rank}, world size: {world_size}") + return rank, world_size + +def print_message_with_master_process(rank, message): + if rank==0: + print(message) + +class DummyModel(torch.nn.Module): + def __init__(self, hidden, bias=False): + super(DummyModel, self).__init__() + assert bias == False, "currently bias is not supported" + self.fc1 = torch.nn.Linear(hidden, hidden, bias=bias) # for Colwise, 128, 128 + self.fc2 = torch.nn.Linear(hidden, hidden, bias=bias) # for Rowwise, 128, 128 + + def forward(self, x): + return self.fc2(torch.relu(self.fc1(x))) + +def colwise_backward(self, grad_output): + grad_input = grad_output.mm(self.weight.t()) + dist.all_reduce(grad_input, op=dist.ReduceOp.SUM) # addmm + return grad_input + +def rowwise_forward(self, x): + bias = self.bias if self.bias else None + x = F.linear(x, self.weight, bias) + dist.all_reduce(x, op=dist.ReduceOp.SUM) + return x + +def parallelize_module( + model: torch.nn.Module, + world_size: int, + rank: int, + layer_tp_plan: Dict +): + assert world_size > 1, "need at least two devices for TP" + + for name, module in model.named_children(): + if name in layer_tp_plan: + assert layer_tp_plan[name] in ['colwise', 'rowwise'], "plan should be colwise or rowwise" + + ''' + for example, weight of column wise parallel linear layer should be splitted into row-wise + because pytorch implementation of linear layer is X = XW^T (F.linear(x, self.weight, bias)) + ''' + if layer_tp_plan[name] == 'rowwise': + assert module.weight.size(1) % world_size == 0 + chunk_size = module.weight.size(1)//world_size # e.g. world_size = 2, rank = 0, 1 + module.weight.data = module.weight.data[:, chunk_size*rank: chunk_size*(rank+1)].contiguous() # weight 128, 16 // input 10, 128 + module.forward = rowwise_forward.__get__(module) + + elif layer_tp_plan[name] == 'colwise': + assert module.weight.size(0) % world_size == 0 + chunk_size = module.weight.size(0)//world_size + module.weight.data = module.weight.data[chunk_size*rank: chunk_size*(rank+1), :].contiguous() # weight 16, 128 // input 10, 16 + module.backward = colwise_backward.__get__(module) + + +def main(args): + rank, world_size = init_dist() + device = f"cuda:{rank}" + bsz, hidden = 8, 128 + num_iter, lr = 2, 0.01 + + ## create model and parallelize if TP + set_seed() + model = DummyModel(hidden).to(device).train() + if args.TP: + layer_tp_plan = { + "fc1": 'colwise', + "fc2": 'rowwise', + } + parallelize_module(model, world_size, rank, layer_tp_plan) + optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1) + print_message_with_master_process(rank, f'model: {model}') + + ## create dummy input + set_seed() + x = torch.randn(bsz, hidden).to(device) + + ## for loop + for iter in range(num_iter): + output = model(x) + loss = output.sum() + loss.backward() + + ## get gathered gradient results + if args.TP: + fc1_grad = [torch.zeros_like(model.fc1.weight, dtype=torch.float32) for _ in range(world_size)] + dist.all_gather(fc1_grad, model.fc1.weight.grad) + fc1_grad = torch.cat(fc1_grad, dim=0) + else: + fc1_grad = model.fc1.weight.grad + + optimizer.step() + optimizer.zero_grad() + + ## print outputs + message = f''' + iter: {iter+1} + output: {output} + loss: {loss} + fc1_grad = {fc1_grad} + ''' + print_message_with_master_process(rank, message) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--TP', action='store_true') + args, _ = parser.parse_known_args() + main(args) \ No newline at end of file diff --git a/journey/understanding_tensor_parallel/1_transformer_tensor_parallel.py b/journey/understanding_tensor_parallel/1_transformer_tensor_parallel.py new file mode 100644 index 0000000..6c9974c --- /dev/null +++ b/journey/understanding_tensor_parallel/1_transformer_tensor_parallel.py @@ -0,0 +1,371 @@ +import os +import math +import random +import numpy as np +from copy import deepcopy +from typing import List, Dict +import argparse + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from torch_profiler_utils import ContextManagers, get_torch_profiler + +from pdb import set_trace as Tra + + +def set_seed(seed=1234): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def init_dist(): + rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size) + print(f"rank: {rank}, world size: {world_size}") + return rank, world_size + +def print_message_with_master_process(rank, message): + if rank==0: + print(message) + +''' +adapted from karpathy +https://github.com/karpathy/nanoGPT/blob/master/model.py +''' +class Attention(nn.Module): + def __init__(self, hidden, nhead, bias=False): + super(Attention, self).__init__() + assert hidden % nhead == 0, "hidden size should be divisible by nhead" + self.dhead = hidden // nhead + self.q_proj = nn.Linear(hidden, hidden, bias=bias) + self.k_proj = nn.Linear(hidden, hidden, bias=bias) + self.v_proj = nn.Linear(hidden, hidden, bias=bias) + self.o_proj = nn.Linear(hidden, hidden, bias=bias) + + def forward(self, x): + B, T, C = x.size() + q = self.q_proj(x).view(B, T, -1, self.dhead).transpose(1, 2).contiguous() # B, nhead, T, dhead + k = self.k_proj(x).view(B, T, -1, self.dhead).transpose(1, 2).contiguous() # B, nhead, T, dhead + v = self.v_proj(x).view(B, T, -1, self.dhead).transpose(1, 2).contiguous() # B, nhead, T, dhead + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True) + x = x.transpose(1, 2).contiguous().view(B, T, -1) + return self.o_proj(x) + +class MLP(nn.Module): + def __init__(self, hidden, bias=False): + super(MLP, self).__init__() + self.ffn1 = nn.Linear(hidden, 4*hidden, bias) + self.act = nn.GELU() + self.ffn2 = nn.Linear(4*hidden, hidden, bias) + + def forward(self, x): + return self.ffn2(self.act(self.ffn1(x))) + +class LayerNorm(nn.Module): + def __init__(self, hidden, bias=False): + super(LayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden)) + self.bias = nn.Parameter(torch.zeros(hidden)) if bias else None + + def forward(self, x): + return F.layer_norm(x.float(), self.weight.shape, self.weight, self.bias, 1e-5).type_as(x) + +class ResidualBlock(nn.Module): + def __init__(self, hidden, nhead, bias=False): + super(ResidualBlock, self).__init__() + self.ln1 = LayerNorm(hidden, bias) + self.attn = Attention(hidden, nhead, bias) + self.ln2 = LayerNorm(hidden, bias) + self.mlp = MLP(hidden, bias) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + return x + self.mlp(self.ln2(x)) + +class Transformer(nn.Module): + def __init__(self, vocab_size, block_size, hidden, nhead, nlayer, bias=False): + super(Transformer, self).__init__() + assert bias == False, "currently bias is not supported" + self.vocab_size = vocab_size + self.nhead = nhead + self.model = nn.ModuleDict( + dict( + wte = nn.Embedding(vocab_size, hidden), # long tensor -> 3d tensor -> channel dim 쪼개 + wpe = nn.Embedding(block_size, hidden), + h = nn.ModuleList([ResidualBlock(hidden, nhead, bias) for _ in range(nlayer)]), + ln = LayerNorm(hidden, bias=bias), + ) + ) + self.lm_head = nn.Linear(hidden, vocab_size, bias=bias) + self.model.wte.weight = self.lm_head.weight # for pure megatron implementation, we automatically tie embedding + + def compute_loss(self, z, y, ignore_index=-100, reduction='mean'): + return F.cross_entropy(z, y, ignore_index=ignore_index, reduction=reduction) + + def forward(self, x, y): + B, T = x.size() + pos = torch.arange(0, T, dtype=torch.long, device=x.device) + x = self.model.wte(x) + self.model.wpe(pos) + for block in self.model.h: + x = block(x) + x = self.model.ln(x) + z = self.lm_head(x).float() # projection to logit space and upcast + z = z[..., :-1, :].contiguous().view(B*(T-1), -1) # B*T, C + y = y.view(-1) # B*T, 1 + return self.compute_loss(z, y), z + +class g(torch.autograd.Function): + def forward(ctx, x): + dist.all_reduce(x, op=dist.ReduceOp.SUM) + return x + def backward(ctx, dx): + return dx + +class f(torch.autograd.Function): + def forward(ctx, x): + return x + def backward(ctx, dx): + dist.all_reduce(dx, op=dist.ReduceOp.SUM) + return dx + +def rowwise_forward(self, x): + bias = self.bias if self.bias else None + x = F.linear(x, self.weight, bias) + return g.apply(x) + +def colwise_forward(self, x): + bias = self.bias if self.bias else None + x = f.apply(x) + return F.linear(x, self.weight, bias) + +''' +Refrences for vocab parallel (but it's not exactly same) +https://github.com/NVIDIA/Megatron-LM/blob/2d487b1871ba64ef1625781ea05715af1bc0d8ee/megatron/core/tensor_parallel/cross_entropy.py#L121-L126 +https://github.com/NVIDIA/Megatron-LM/blob/e8f8e63f13a074f7e35d72c8bfb3e1168cd84e8e/megatron/core/tensor_parallel/layers.py#L151 +https://github.com/pytorch/pytorch/blob/5ed3b70d09a4ab2a5be4becfda9dd0d3e3227c39/torch/distributed/tensor/parallel/loss.py#L126 +https://github.com/pytorch/pytorch/blob/41e653456e4a96b43ea96c9cd3cddc63ea74711d/torch/ao/nn/qat/modules/embedding_ops.py#L11 +https://github.com/mgmalek/efficient_cross_entropy/blob/main/modules.py +''' + +def get_mask_and_masked_input(x, vocab_start_index, vocab_end_index): + x_mask = (x < vocab_start_index) | (x >= vocab_end_index) + x = x.clone() - vocab_start_index + x[x_mask] = 0 + return x, x_mask + +class LossParallel_: + def get_logit_max(z): + return torch.max(z.float(), dim=-1)[0] + + def get_exp(z, z_max): + z -= z_max.unsqueeze(dim=-1) + exp = torch.exp(z) # B*T, C + sum_exp = torch.sum(exp, dim=-1, keepdim=True) # B*T, 1 + return z, exp, sum_exp + + def get_one_hot(y, z, vocab_start_index, vocab_end_index): + y, y_mask = get_mask_and_masked_input(y, vocab_start_index, vocab_end_index) + y = F.one_hot(y, num_classes=z.size(1)) + y.masked_fill_(y_mask.unsqueeze(-1), 0.0) + return y, y_mask + + def get_nll_loss(z, y, exp, sum_exp, y_one_hot, y_mask, ignore_index, reduction): + # compute loss using log sum exponential trick # https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/ + log_sum_exp = torch.log(sum_exp) # normalizer + log_sum_exp.masked_fill_(y_mask.unsqueeze(-1), 0.0) + gt_z = torch.sum(z * y_one_hot, dim=1) + + # Compute the loss + divisor = 1 if reduction == 'sum' else (y!=ignore_index).sum() + loss = (log_sum_exp.squeeze(1) - gt_z) / divisor + loss = torch.where(y == ignore_index, torch.tensor(0.0, device=z.device), loss) # token-level loss + loss = loss.sum() + return loss, divisor + +class LossParallel(torch.autograd.Function): + def forward(ctx, z, y, vocab_start_index, vocab_end_index, ignore_index=-100, reduction='mean'): + + # communicate max logit value for numerical stability + z_max = LossParallel_.get_logit_max(z) # B*T, C + dist.all_reduce(z_max, op=dist.ReduceOp.MAX) # max + + # get numerical stable exponentiated vectors + z, exp_z, sum_exp_z = LossParallel_.get_exp(z, z_max) + dist.all_reduce(sum_exp_z, op=dist.ReduceOp.SUM) + + # compute loss and reduce all + y_one_hot, y_mask = LossParallel_.get_one_hot(y, z, vocab_start_index, vocab_end_index) + loss, divisor = LossParallel_.get_nll_loss(z, y, exp_z, sum_exp_z, y_one_hot, y_mask, ignore_index, reduction) + dist.all_reduce(loss, op=dist.ReduceOp.SUM) # mean and sum loss + + # store results for backward + ctx.save_for_backward(exp_z.div_(sum_exp_z), y_one_hot, divisor) + return loss + + def backward(ctx, grad_output): + y_hat, y_one_hot, divisor = ctx.saved_tensors + dz = y_hat - y_one_hot # logit gradient + dz /= divisor # dL/dLogit + dz *= grad_output # 1.0 because it's end + return dz, None, None, None, None, None # No gradients needed for y, ignore_index, or reduction parameters + +def embedding_parallel(self, x): + x, x_mask = get_mask_and_masked_input(x, self.vocab_start_index, self.vocab_end_index) + x = F.embedding(x, self.weight) + x.masked_fill_(x_mask.unsqueeze(-1), 0.0) + return g.apply(x) # because readout layer is col-wise, embedding layer is row-wise + +def parallelize_module( + args, + model: nn.Module, + world_size: int, + rank: int, +): + assert world_size > 1, "need at least two devices for TP" + colwise_list = ['q_proj', 'k_proj', 'v_proj', 'ffn1'] + rowwise_list = ['o_proj', 'ffn2'] + + for name, module in model.named_children(): + if isinstance(module, nn.Module): + parallelize_module(args, module, world_size, rank) + + ''' + pytorch impl matmul with transposed weight matrix, + so you should slice weight matrix counter-intuitively. + ''' + for _ in rowwise_list: + if _ in name.lower(): + assert module.weight.size(1) % world_size == 0 + chunk_size = module.weight.size(1)//world_size + module.weight.data = module.weight.data[:, chunk_size*rank: chunk_size*(rank+1)].contiguous() + module.forward = rowwise_forward.__get__(module) + for _ in colwise_list: + if _ in name.lower(): + assert module.weight.size(0) % world_size == 0 + chunk_size = module.weight.size(0)//world_size + module.weight.data = module.weight.data[chunk_size*rank: chunk_size*(rank+1), :].contiguous() + module.forward = colwise_forward.__get__(module) + + ''' + you should slice embedding weight matrix col-wise (vocab dimension), + so you need to perform softmax operation across sliced vocab dim. + and because original megatron paper tie embedding and unembedding matrices, you should care this too. + ''' + if args.loss_parallel: + if 'lm_head' in name.lower() or 'wte' in name.lower(): + ## TODO: need vocab padding + chunk_size = module.weight.size(0)//world_size + vocab_start_index = chunk_size*rank + vocab_end_index = chunk_size*(rank+1) + + if 'lm_head' in name.lower(): + module.weight.data = module.weight.data[vocab_start_index:vocab_end_index, :].contiguous() + module.forward = colwise_forward.__get__(module) + def loss_parallel(x, y, ignore_index=-100, reduction='mean'): + return LossParallel.apply(x, y, vocab_start_index, vocab_end_index, ignore_index, reduction) + model.compute_loss = loss_parallel + + elif 'wte' in name.lower(): + module.vocab_start_index = vocab_start_index + module.vocab_end_index = vocab_end_index + module.forward = embedding_parallel.__get__(module) + +def get_dummy_input( + vocab_size, + device, + batch_size=256, + seq_len=1024, +): + num_pad_tokens = seq_len//10 + input_ids = torch.randint(vocab_size, (batch_size, seq_len)) + labels = torch.cat((input_ids[:, 1:seq_len-num_pad_tokens], torch.full((batch_size, num_pad_tokens), -100)),1) + return { + 'input_ids': input_ids.to(device), + 'labels': labels.to(device), + } + +def main(args): + rank, world_size = init_dist() + device = f"cuda:{rank}" + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained('gpt2') + vocab_size = len(tokenizer) + block_size = tokenizer.model_max_length + hidden, nhead, nlayer = args.hidden, 8, 2 + + set_seed() + model = Transformer(vocab_size, block_size, hidden, nhead, nlayer).to(device).train() + if args.TP: + assert model.nhead % world_size == 0, "nhead should be divisible by TP degree" + parallelize_module(args, model, world_size, rank) + else: + if args.loss_parallel: + def loss_parallel(x, y, ignore_index=-100, reduction='mean'): + return LossParallel.apply(x, y, 0, vocab_size, ignore_index, reduction) + model.compute_loss = loss_parallel + lr = 0.01 + optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1) + + if args.batch_size and args.seq_len: + input_ids = get_dummy_input(vocab_size-1, device, args.batch_size, args.seq_len) + else: + sent = "i love tensor parallelism." + input_ids = tokenizer(sent, return_tensors='pt').to(device) + input_ids['labels'] = input_ids['input_ids'][:, 1:] + + if args.use_torch_profiler: + num_wait_steps, num_warmup_steps, num_active_steps, num_repeat = 1, 2, 3, 1 + num_iter = int((num_wait_steps + num_warmup_steps + num_active_steps)*num_repeat) + context = [ + get_torch_profiler( + num_wait_steps=num_wait_steps, + num_warmup_steps=num_warmup_steps, + num_active_steps=num_active_steps, + num_repeat=num_repeat, + save_dir_name=f'TP_{args.TP}_world_size_{world_size}_hidden_{hidden}' + ) + ] + else: + num_iter = 5 + context = [] + + with ContextManagers(context) as p: + for iter in range(num_iter): + loss, z = model(input_ids['input_ids'], input_ids['labels']) + z.retain_grad() + loss.backward() + + message = f''' + iter: {iter+1} + input size: {input_ids['input_ids'].size()} + num padding toekns: {(input_ids['labels'] == -100).sum()} + loss: {loss} + ''' + # message += f''' + # z.grad: {z.grad} + # ''' + + optimizer.step() + optimizer.zero_grad() + + print_message_with_master_process(rank, message) + if args.use_torch_profiler: + p.step() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', default=None, type=int) + parser.add_argument('--seq_len', default=None, type=int) + + parser.add_argument('--hidden', default=256, type=int) + parser.add_argument('--TP', action='store_true') + parser.add_argument('--loss_parallel', action='store_true') + parser.add_argument('--use_torch_profiler', action='store_true') + args, _ = parser.parse_known_args() + main(args) diff --git a/journey/understanding_tensor_parallel/README.md b/journey/understanding_tensor_parallel/README.md new file mode 100644 index 0000000..3e53bef --- /dev/null +++ b/journey/understanding_tensor_parallel/README.md @@ -0,0 +1,397 @@ +# References + +- Papers + - [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) + - [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/abs/2205.05198) +- Others + - [pytorch.org/tutorials/intermediate/TP_tutorial.html](https://pytorch.org/tutorials/intermediate/TP_tutorial.html) + - [lightning.ai/docs/pytorch/stable/advanced/model_parallel/tp.html](https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/tp.html) + - [pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) + + +# Goal + +- [x] naive impl on MLP +- [x] transformer (using Autograd) +- [x] vocab parallel (loss parallel) +- [ ] sequence parallel (TBC) + +# Examples (run scripts) + +## naive impl on MLP + +![simple_mlp_TP](./assets/images/simple_mlp_TP.png) + + +- w/o TP + +```bash +export MASTER_ADDR=node0 &&\ +export MASTER_PORT=23458 &&\ +torchrun --nproc_per_node=1 --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ +0_naive_tensor_parallel.py +``` + +```python +rank: 0, world size: 1 +model: DummyModel( + (fc1): Linear(in_features=128, out_features=128, bias=False) + (fc2): Linear(in_features=128, out_features=128, bias=False) +) + + iter: 1 + output: tensor([[-0.0446, 0.0869, 0.2034, ..., 0.0353, -0.2906, 0.0388], + [-0.0149, 0.3999, 0.0187, ..., 0.1280, -0.1074, 0.2212], + [ 0.0592, 0.2287, 0.2629, ..., -0.3098, 0.3747, 0.1021], + ..., + [-0.1120, 0.1608, 0.1155, ..., 0.0570, -0.0458, 0.3998], + [-0.0837, 0.1127, 0.1840, ..., -0.0339, 0.3072, 0.6933], + [ 0.1525, 0.2822, -0.0211, ..., 0.1974, 0.0768, 0.2375]], + device='cuda:0', grad_fn=) + loss: 0.969451904296875 + fc1_grad = tensor([[-0.7231, 0.7115, -0.2774, ..., -0.6077, -0.0960, 0.1508], + [-0.0553, -0.4548, -0.0235, ..., 0.1630, -0.1945, -0.1485], + [ 1.4298, -1.3797, 1.5428, ..., 2.0844, -0.6803, 0.3992], + ..., + [-1.3434, 1.1863, -0.8411, ..., -0.6940, 0.9600, 0.8013], + [-0.1506, 0.7074, -0.3786, ..., -1.2123, 1.7474, 1.8508], + [-0.5859, 0.4911, -0.4167, ..., -0.0043, 0.1661, 0.3382]], + device='cuda:0') + + + iter: 2 + output: tensor([[-0.5817, -0.0260, -0.5679, ..., -0.5887, -0.6975, -0.1548], + [-0.2621, 0.1407, -0.4802, ..., -0.1570, -0.2467, 0.1012], + [-0.2493, 0.1170, -0.3523, ..., -0.7328, 0.1866, -0.3034], + ..., + [-0.3621, -0.0533, -0.3692, ..., -0.4276, -0.2218, 0.1831], + [-0.4475, 0.1047, -0.7256, ..., -0.5500, -0.0167, 0.1446], + [-0.1938, -0.2023, -0.7151, ..., -0.1744, -0.3086, -0.0498]], + device='cuda:0', grad_fn=) + loss: -445.4638671875 + fc1_grad = tensor([[ 2.4085, 1.6419, 0.8216, ..., 2.0955, 0.7012, -1.0162], + [-0.7059, -5.8104, -0.3002, ..., 2.0821, -2.4854, -1.8969], + [ 4.1235, -3.9789, 4.4493, ..., 6.0115, -1.9621, 1.1513], + ..., + [ 0.2301, -1.8097, -0.5846, ..., 1.1556, -0.6764, -0.2249], + [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], + [ 1.4045, -0.0199, 0.4096, ..., 0.3518, -0.3399, -1.3144]], + device='cuda:0') +``` + +- w/ TP + +```bash +export LOCAL_RANK=1 &&\ +export WORLD_SIZE=2 &&\ +export MASTER_ADDR=node0 &&\ +export MASTER_PORT=23458 &&\ +torchrun --nproc_per_node=$WORLD_SIZE --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ +0_naive_tensor_parallel.py --TP +``` + +```python +rank: 0, world size: 2 +rank: 1, world size: 2 +model: DummyModel( + (fc1): Linear(in_features=128, out_features=128, bias=False) + (fc2): Linear(in_features=128, out_features=128, bias=False) +) + + iter: 1 + output: tensor([[-0.0446, 0.0869, 0.2034, ..., 0.0353, -0.2906, 0.0388], + [-0.0149, 0.3999, 0.0187, ..., 0.1280, -0.1074, 0.2212], + [ 0.0592, 0.2287, 0.2629, ..., -0.3098, 0.3747, 0.1021], + ..., + [-0.1120, 0.1608, 0.1155, ..., 0.0570, -0.0458, 0.3998], + [-0.0837, 0.1127, 0.1840, ..., -0.0339, 0.3072, 0.6933], + [ 0.1525, 0.2822, -0.0211, ..., 0.1974, 0.0768, 0.2375]], + device='cuda:0', grad_fn=) + loss: 0.9694492816925049 + fc1_grad = tensor([[-0.7231, 0.7115, -0.2774, ..., -0.6077, -0.0960, 0.1508], + [-0.0553, -0.4548, -0.0235, ..., 0.1630, -0.1945, -0.1485], + [ 1.4298, -1.3797, 1.5428, ..., 2.0844, -0.6803, 0.3992], + ..., + [-1.3434, 1.1863, -0.8411, ..., -0.6940, 0.9600, 0.8013], + [-0.1506, 0.7074, -0.3786, ..., -1.2123, 1.7474, 1.8508], + [-0.5859, 0.4911, -0.4167, ..., -0.0043, 0.1661, 0.3382]], + device='cuda:0') + + + iter: 2 + output: tensor([[-0.5817, -0.0260, -0.5679, ..., -0.5887, -0.6975, -0.1548], + [-0.2621, 0.1407, -0.4802, ..., -0.1570, -0.2467, 0.1012], + [-0.2493, 0.1170, -0.3523, ..., -0.7328, 0.1866, -0.3034], + ..., + [-0.3621, -0.0533, -0.3692, ..., -0.4276, -0.2218, 0.1831], + [-0.4475, 0.1047, -0.7256, ..., -0.5500, -0.0167, 0.1446], + [-0.1938, -0.2023, -0.7151, ..., -0.1744, -0.3086, -0.0498]], + device='cuda:0', grad_fn=) + loss: -445.4638671875 + fc1_grad = tensor([[ 2.4085, 1.6419, 0.8216, ..., 2.0955, 0.7012, -1.0162], + [-0.7059, -5.8104, -0.3002, ..., 2.0821, -2.4854, -1.8969], + [ 4.1235, -3.9789, 4.4493, ..., 6.0115, -1.9621, 1.1513], + ..., + [ 0.2301, -1.8097, -0.5846, ..., 1.1556, -0.6764, -0.2249], + [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], + [ 1.4045, -0.0199, 0.4096, ..., 0.3518, -0.3399, -1.3144]], + device='cuda:0') +``` + + +## Transformer with TP + +![transformer_TP](./assets/images/transformer_TP.png) + +![vocab_parallel](./assets/images/vocab_parallel.png) + +- w/o TP + +```bash +export MASTER_ADDR=node0 &&\ +export MASTER_PORT=23458 &&\ +torchrun --nproc_per_node=1 --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ +1_transformer_tensor_parallel.py +``` + +```python +rank: 0, world size: 1 + + iter: 1 + loss: 10.939807891845703 + + + iter: 2 + loss: 3.437135934829712 + + + iter: 3 + loss: 1.5810130834579468 + + + iter: 4 + loss: 0.453738808631897 + + + iter: 5 + loss: 0.1264963299036026 +``` + +- w/ TP + +```bash +export LOCAL_RANK=1 &&\ +export WORLD_SIZE=2 &&\ +export MASTER_ADDR=node0 &&\ +export MASTER_PORT=23458 &&\ +torchrun --nproc_per_node=$WORLD_SIZE --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ +1_transformer_tensor_parallel.py --TP +``` + + +```python +rank: 1, world size: 2 +rank: 0, world size: 2 + + iter: 1 + loss: 10.939807891845703 + + + iter: 2 + loss: 3.4371347427368164 + + + iter: 3 + loss: 1.58101224899292 + + + iter: 4 + loss: 0.45373836159706116 + + + iter: 5 + loss: 0.12649638950824738 +``` + + +### Memroy Profiling + +``` +--use_torch_profiler +``` + +```bash +export MASTER_ADDR=node0 &&\ +export MASTER_PORT=23458 &&\ +torchrun --nproc_per_node=1 --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ +1_transformer_tensor_parallel.py --use_torch_profiler --hidden=2048 +``` + +```bash +export LOCAL_RANK=1 &&\ +export WORLD_SIZE=2 &&\ +export MASTER_ADDR=node0 &&\ +export MASTER_PORT=23458 &&\ +torchrun --nproc_per_node=$WORLD_SIZE --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ +1_transformer_tensor_parallel.py --TP --use_torch_profiler --hidden=2048 +``` + + +![baseline_torch_profiler_fig1](./assets/images/baseline_torch_profiler_fig1.png) + +![baseline_torch_profiler_fig2](./assets/images/baseline_torch_profiler_fig2.png) + +![TP_torch_profiler_fig1](./assets/images/TP_torch_profiler_fig1.png) + +![TP_torch_profiler_fig2](./assets/images/TP_torch_profiler_fig2.png) + + +### Applying Vocab (Loss) Parallel + +```bash +export MASTER_ADDR=node0 &&\ +export MASTER_PORT=23458 &&\ +torchrun --nproc_per_node=1 --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ +1_transformer_tensor_parallel.py --batch_size 2 --seq_len 64 +``` + +```python + iter: 1 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 11.14531421661377 + + + iter: 2 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 7.8605475425720215 + + + iter: 3 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 6.055154800415039 + + + iter: 4 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 4.597280502319336 + + + iter: 5 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 3.266993761062622 +``` + + +```bash +export MASTER_ADDR=node0 &&\ +export MASTER_PORT=23458 &&\ +torchrun --nproc_per_node=1 --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ +1_transformer_tensor_parallel.py --batch_size 2 --seq_len 64 --loss_parallel +``` + +```python + iter: 1 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 11.145313262939453 + + + iter: 2 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 7.860340595245361 + + + iter: 3 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 6.054848670959473 + + + iter: 4 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 4.597006320953369 + + + iter: 5 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 3.2667441368103027 +``` + +```bash +export LOCAL_RANK=1 &&\ +export WORLD_SIZE=2 &&\ +export MASTER_ADDR=node0 &&\ +export MASTER_PORT=23458 &&\ +torchrun --nproc_per_node=$WORLD_SIZE --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ +1_transformer_tensor_parallel.py --batch_size 2 --seq_len 64 --loss_parallel --TP +``` + +- there is a bug lol + +```python + iter: 1 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 11.145294189453125 + + + iter: 2 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 7.860313415527344 + + + iter: 3 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 6.0548553466796875 + + + iter: 4 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 4.596996307373047 + + + iter: 5 + input size: torch.Size([2, 64]) + num padding toekns: 12 + loss: 3.2667508125305176 +``` + + +### Profling Final Results + +- batch_size: 256 +- sqe_len: 256 +- 1gpu baseline 2gpu TP + +```bash +export LOCAL_RANK=1 &&\ +export WORLD_SIZE=2 &&\ +export MASTER_ADDR=node0 &&\ +export MASTER_PORT=23458 &&\ +torchrun --nproc_per_node=$WORLD_SIZE --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ +1_transformer_tensor_parallel.py --batch_size 256 --seq_len 256 --use_torch_profiler +``` + +![hidden_256_batch_256_seq_len_256_baseline](./assets/images/hidden_256_batch_256_seq_len_256_baseline.png) + +```bash +export LOCAL_RANK=1 &&\ +export WORLD_SIZE=2 &&\ +export MASTER_ADDR=node0 &&\ +export MASTER_PORT=23458 &&\ +torchrun --nproc_per_node=$WORLD_SIZE --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \ +1_transformer_tensor_parallel.py --batch_size 256 --seq_len 256 --loss_parallel --TP --use_torch_profiler +``` + +![hidden_256_batch_256_seq_len_256_TP_2gpu](./assets/images/hidden_256_batch_256_seq_len_256_TP_2gpu.png) \ No newline at end of file diff --git a/journey/understanding_tensor_parallel/assets/images/TP_torch_profiler_fig1.png b/journey/understanding_tensor_parallel/assets/images/TP_torch_profiler_fig1.png new file mode 100644 index 0000000..4e06ec9 Binary files /dev/null and b/journey/understanding_tensor_parallel/assets/images/TP_torch_profiler_fig1.png differ diff --git a/journey/understanding_tensor_parallel/assets/images/TP_torch_profiler_fig2.png b/journey/understanding_tensor_parallel/assets/images/TP_torch_profiler_fig2.png new file mode 100644 index 0000000..45c0530 Binary files /dev/null and b/journey/understanding_tensor_parallel/assets/images/TP_torch_profiler_fig2.png differ diff --git a/journey/understanding_tensor_parallel/assets/images/baseline_torch_profiler_fig1.png b/journey/understanding_tensor_parallel/assets/images/baseline_torch_profiler_fig1.png new file mode 100644 index 0000000..fa5fac4 Binary files /dev/null and b/journey/understanding_tensor_parallel/assets/images/baseline_torch_profiler_fig1.png differ diff --git a/journey/understanding_tensor_parallel/assets/images/baseline_torch_profiler_fig2.png b/journey/understanding_tensor_parallel/assets/images/baseline_torch_profiler_fig2.png new file mode 100644 index 0000000..c18a589 Binary files /dev/null and b/journey/understanding_tensor_parallel/assets/images/baseline_torch_profiler_fig2.png differ diff --git a/journey/understanding_tensor_parallel/assets/images/hidden_256_batch_256_seq_len_256_TP_2gpu.png b/journey/understanding_tensor_parallel/assets/images/hidden_256_batch_256_seq_len_256_TP_2gpu.png new file mode 100644 index 0000000..12cb459 Binary files /dev/null and b/journey/understanding_tensor_parallel/assets/images/hidden_256_batch_256_seq_len_256_TP_2gpu.png differ diff --git a/journey/understanding_tensor_parallel/assets/images/hidden_256_batch_256_seq_len_256_baseline.png b/journey/understanding_tensor_parallel/assets/images/hidden_256_batch_256_seq_len_256_baseline.png new file mode 100644 index 0000000..51ea4fb Binary files /dev/null and b/journey/understanding_tensor_parallel/assets/images/hidden_256_batch_256_seq_len_256_baseline.png differ diff --git a/journey/understanding_tensor_parallel/assets/images/simple_mlp_TP.png b/journey/understanding_tensor_parallel/assets/images/simple_mlp_TP.png new file mode 100644 index 0000000..f92dcd4 Binary files /dev/null and b/journey/understanding_tensor_parallel/assets/images/simple_mlp_TP.png differ diff --git a/journey/understanding_tensor_parallel/assets/images/transformer_TP.png b/journey/understanding_tensor_parallel/assets/images/transformer_TP.png new file mode 100644 index 0000000..7b48a10 Binary files /dev/null and b/journey/understanding_tensor_parallel/assets/images/transformer_TP.png differ diff --git a/journey/understanding_tensor_parallel/assets/images/vocab_parallel.png b/journey/understanding_tensor_parallel/assets/images/vocab_parallel.png new file mode 100644 index 0000000..7f2508c Binary files /dev/null and b/journey/understanding_tensor_parallel/assets/images/vocab_parallel.png differ diff --git a/journey/understanding_tensor_parallel/torch_profiler_utils.py b/journey/understanding_tensor_parallel/torch_profiler_utils.py new file mode 100644 index 0000000..5003752 --- /dev/null +++ b/journey/understanding_tensor_parallel/torch_profiler_utils.py @@ -0,0 +1,98 @@ + +import os +import torch +import socket +from datetime import datetime, timedelta +from contextlib import contextmanager, ExitStack +from typing import Any, ContextManager, Iterable, List, Tuple + + +class ContextManagers: + """ + Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers` + in the `fastcore` library. + """ + + def __init__(self, context_managers: List[ContextManager]): + self.context_managers = context_managers + self.stack = ExitStack() + + def __enter__(self): + entered_contexts = [ + self.stack.enter_context(cm) for cm in self.context_managers + ] + # Assuming you want to return the first context manager, adjust as needed + return entered_contexts[0] if entered_contexts else None + + def __exit__(self, *args, **kwargs): + self.stack.__exit__(*args, **kwargs) + + +def get_torch_profiler( + use_tensorboard=True, + root_dir="./assets/torch_profiler_log", + save_dir_name="tmp", + + num_wait_steps=1, # During this phase profiler is not active. + num_warmup_steps=2, # During this phase profiler starts tracing, but the results are discarded. + num_active_steps=2, # During this phase profiler traces and records data. + num_repeat=1, # Specifies an upper bound on the number of cycles. + + record_shapes=True, + profile_memory=True, + + with_flops=True, + with_stack = False, # Enable stack tracing, adds extra profiling overhead. stack tracing adds an extra profiling overhead. + with_modules=True, +): + save_path=os.path.join(root_dir, save_dir_name) + os.makedirs(save_path, exist_ok=True) + + ''' + https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-profiler-to-analyze-long-running-jobs + https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html + https://github.com/pytorch/kineto/blob/main/tb_plugin/README.md + https://oss.navercorp.com/seunghyun-seo1/seosh_fairseq/blob/main/toward_iclr/cuda_profile_speech_encoder.py + + https://pytorch.org/blog/accelerating-generative-ai-2/ + https://www.deepspeed.ai/tutorials/pytorch-profiler/ + https://ui.perfetto.dev + chrome://tracing/ + + https://pytorch.org/blog/introducing-pytorch-profiler-the-new-and-improved-performance-tool/ + https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html + https://pytorch.org/blog/pytorch-profiler-1.9-released/ + + 231214 added + https://pytorch.org/blog/understanding-gpu-memory-1/ + https://github.com/pytorch/pytorch.github.io/tree/site/assets/images/understanding-gpu-memory-1 + ''' + + def trace_handler(prof: torch.profiler.profile): + TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S" + + host_name = socket.gethostname() + timestamp = datetime.now().strftime(TIME_FORMAT_STR) + file_prefix = f"{host_name}_{timestamp}" + prof.export_chrome_trace(f"{save_path}/{file_prefix}.json.gz") + + return torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=record_shapes, + profile_memory=profile_memory, + + with_flops=with_flops, + with_stack = with_stack, + with_modules = with_modules, + + schedule=torch.profiler.schedule( + wait=num_wait_steps, + warmup=num_warmup_steps, + active=num_active_steps, + repeat=num_repeat, + ), + on_trace_ready = trace_handler if not use_tensorboard else torch.profiler.tensorboard_trace_handler(save_path), + ) \ No newline at end of file