diff --git a/jaxformer/hf/codegen/modeling_codegen.py b/jaxformer/hf/codegen/modeling_codegen.py index 8304f92..cf7d0bf 100644 --- a/jaxformer/hf/codegen/modeling_codegen.py +++ b/jaxformer/hf/codegen/modeling_codegen.py @@ -18,19 +18,20 @@ from typing import Tuple import numpy as np - import torch import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss - from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from transformers.utils.model_parallel_utils import assert_device_map, get_device_map -from .configuration_codegen import CodeGenConfig +from .configuration_codegen import CodeGenConfig logger = logging.get_logger(__name__) @@ -43,7 +44,11 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None): # original # sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float() # QHD fix onnx error by https://github.com/microsoft/onnxruntime/discussions/10121#discussioncomment-1987845 - sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len).float(), inv_freq).to(x.device).float() + sinusoid_inp = ( + torch.einsum("i , j -> i j", torch.arange(seq_len).float(), inv_freq) + .to(x.device) + .float() + ) return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) @@ -55,7 +60,12 @@ def rotate_every_two(x): def apply_rotary_pos_emb(x, sincos, offset=0): - sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos) + sin, cos = map( + lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave( + 2, 3 + ), + sincos, + ) # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) return (x * cos) + (rotate_every_two(x) * sin) @@ -67,9 +77,9 @@ def __init__(self, config): max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( - 1, 1, max_positions, max_positions - ), + torch.tril( + torch.ones((max_positions, max_positions), dtype=torch.bool) + ).view(1, 1, max_positions, max_positions), ) self.register_buffer("masked_bias", torch.tensor(-1e9)) @@ -83,7 +93,9 @@ def __init__(self, config): raise ValueError( f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})." 
) - self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) + self.scale_attn = torch.sqrt( + torch.tensor(self.head_dim, dtype=torch.float32) + ).to(torch.get_default_dtype()) self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) @@ -92,8 +104,8 @@ def __init__(self, config): self.rotary_dim = config.rotary_dim def _split_heads(self, x, n_head, dim_head, mp_num): - reshaped = x.reshape(x.shape[:-1] + (n_head//mp_num, dim_head)) - reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + reshaped.shape[-1:]) + reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head)) + reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:]) return reshaped def _merge_heads(self, tensor, num_attention_heads, attn_head_size): @@ -105,7 +117,9 @@ def _merge_heads(self, tensor, num_attention_heads, attn_head_size): elif len(tensor.shape) == 4: tensor = tensor.permute(0, 2, 1, 3).contiguous() else: - raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + raise ValueError( + f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}" + ) new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) return tensor.view(new_shape) @@ -120,7 +134,9 @@ def _attn( # compute causal mask from causal mask buffer query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) + causal_mask = self.bias[ + :, :, key_length - query_length : key_length, :key_length + ].to(torch.bool) # Keep the attention weights computation in fp32 to avoid overflow issues query = query.to(torch.float32) @@ -129,7 +145,9 @@ def _attn( attn_weights = torch.matmul(query, key.transpose(-1, -2)) attn_weights = attn_weights / self.scale_attn - attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + attn_weights = torch.where( + causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype) + ) if attention_mask is not None: # Apply the attention mask @@ -164,10 +182,16 @@ def forward( local_dim = self.head_dim * self.num_attention_heads // mp_num query, value, key = torch.split(qkv_split, local_dim, dim=-1) - query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num) - key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num) + query = self._split_heads( + query, self.num_attention_heads, self.head_dim, mp_num=mp_num + ) + key = self._split_heads( + key, self.num_attention_heads, self.head_dim, mp_num=mp_num + ) - value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num) + value = self._split_heads( + value, self.num_attention_heads, self.head_dim, mp_num=mp_num + ) value = value.permute(0, 2, 1, 3) seq_len = key.shape[1] @@ -210,9 +234,13 @@ def forward( present = None # compute self-attention: V x Softmax(QK^T) - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output, attn_weights = self._attn( + query, key, value, attention_mask, head_mask + ) - attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + attn_output = self._merge_heads( + attn_output, self.num_attention_heads, self.head_dim + ) attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) @@ -225,7 +253,9 @@ def forward( class 
CodeGenMLP(nn.Module): - def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim + def __init__( + self, intermediate_size, config + ): # in MLP: intermediate_size= 4 * embed_dim super().__init__() embed_dim = config.n_embd @@ -324,22 +354,29 @@ def __init__(self, config): self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([CodeGenBlock(config) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads) + self.rotary_dim = min( + config.rotary_dim, config.n_ctx // config.num_attention_heads + ) self.init_weights() # Model parallel self.model_parallel = False self.device_map = None - def parallelize(self, device_map=None): # Check validity of device_map self.device_map = ( - get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + get_device_map(len(self.h), range(torch.cuda.device_count())) + if device_map is None + else device_map ) assert_device_map(self.device_map, len(self.h)) self.model_parallel = True - self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.first_device = ( + "cpu" + if "cpu" in self.device_map.keys() + else "cuda:" + str(min(self.device_map.keys())) + ) self.last_device = "cuda:" + str(max(self.device_map.keys())) self.wte = self.wte.to(self.first_device) # Load onto devices @@ -350,7 +387,6 @@ def parallelize(self, device_map=None): # ln_f to last self.ln_f = self.ln_f.to(self.last_device) - def deparallelize(self): self.model_parallel = False self.device_map = None @@ -382,15 +418,25 @@ def forward( output_hidden_states=None, return_dict=None, ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) @@ -416,7 +462,12 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) # Attention mask. 
@@ -467,7 +518,9 @@ def forward( torch.cuda.set_device(hidden_states.device) # Ensure layer_past is on same device as hidden_states (might not be correct) if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + layer_past = tuple( + past_state.to(hidden_states.device) for past_state in layer_past + ) # Ensure that attention_mask is always on the same device as hidden_states if attention_mask is not None: attention_mask = attention_mask.to(hidden_states.device) @@ -514,7 +567,9 @@ def custom_forward(*inputs): presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + ( + outputs[2 if use_cache else 1], + ) # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: @@ -530,7 +585,16 @@ def custom_forward(*inputs): all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -541,7 +605,11 @@ def custom_forward(*inputs): class CodeGenForCausalLM(CodeGenPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"] + _keys_to_ignore_on_load_missing = [ + r"h\.\d+\.attn\.masked_bias", + r"h\.\d+\.attn\.bias", + r"lm_head\.weight", + ] def __init__(self, config): super().__init__(config) @@ -626,7 +694,9 @@ def forward( ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) transformer_outputs = self.transformer( input_ids, @@ -660,7 +730,9 @@ def forward( shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) loss = loss.to(hidden_states.dtype) @@ -677,13 +749,18 @@ def forward( ) @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the :obj:`past_key_values` cache if :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
""" return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ) for layer_past in past ) diff --git a/jaxformer/hf/sample.py b/jaxformer/hf/sample.py index 4477e10..5e640e8 100644 --- a/jaxformer/hf/sample.py +++ b/jaxformer/hf/sample.py @@ -3,18 +3,16 @@ # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +import argparse import os +import random import re import time -import random -import argparse import torch - from transformers import GPT2TokenizerFast -from jaxformer.hf.codegen.modeling_codegen import CodeGenForCausalLM - +from jaxformer.hf.codegen.modeling_codegen import CodeGenForCausalLM ######################################################################## # util @@ -29,16 +27,16 @@ def __enter__(self): self.t = time.time() def __exit__(self, type, value, traceback): - print(f'{self.desc} took {time.time()-self.t:.02f}s') + print(f"{self.desc} took {time.time()-self.t:.02f}s") def set_env(): - os.environ['TOKENIZERS_PARALLELISM'] = 'false' + os.environ["TOKENIZERS_PARALLELISM"] = "false" def set_seed(seed, deterministic=True): random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) + os.environ["PYTHONHASHSEED"] = str(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) @@ -53,31 +51,38 @@ def cast(model, fp16=True): return model - ######################################################################## # model def create_model(ckpt, fp16=True): if fp16: - return CodeGenForCausalLM.from_pretrained(ckpt, revision='float16', torch_dtype=torch.float16, low_cpu_mem_usage=True) + return CodeGenForCausalLM.from_pretrained( + ckpt, revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True + ) else: return CodeGenForCausalLM.from_pretrained(ckpt) def create_tokenizer(): - t = GPT2TokenizerFast.from_pretrained('gpt2') - t.max_model_input_sizes['gpt2'] = 1e20 + t = GPT2TokenizerFast.from_pretrained("gpt2") + t.max_model_input_sizes["gpt2"] = 1e20 return t def include_whitespace(t, n_min=2, n_max=20, as_special_tokens=False): - t.add_tokens([' ' * n for n in reversed(range(n_min, n_max))], special_tokens=as_special_tokens) + t.add_tokens( + [" " * n for n in reversed(range(n_min, n_max))], + special_tokens=as_special_tokens, + ) return t def include_tabs(t, n_min=2, n_max=20, as_special_tokens=False): - t.add_tokens(['\t' * n for n in reversed(range(n_min, n_max))], special_tokens=as_special_tokens) + t.add_tokens( + ["\t" * n for n in reversed(range(n_min, n_max))], + special_tokens=as_special_tokens, + ) return t @@ -91,6 +96,7 @@ def create_custom_gpt2_tokenizer(): ######################################################################## # sample + def sample( device, model, @@ -101,7 +107,7 @@ def sample( temp=0.2, top_p=0.95, max_length_sample=128, - max_length=2048 + max_length=2048, ): input_ids = tokenizer( @@ -109,7 +115,7 @@ def sample( truncation=True, padding=True, max_length=max_length, - return_tensors='pt', + return_tensors="pt", ).input_ids input_ids_len = input_ids.shape[1] @@ -133,44 +139,42 @@ def sample( def truncate(completion): - def find_re(string, pattern, start_pos): m = pattern.search(string, start_pos) return m.start() if m else -1 terminals = [ re.compile(r, re.MULTILINE) - for r in - [ - '^#', - re.escape('<|endoftext|>'), - "^'''", - '^"""', - '\n\n\n' - ] + 
for r in ["^#", re.escape("<|endoftext|>"), "^'''", '^"""', "\n\n\n"] ] - prints = list(re.finditer('^print', completion, re.MULTILINE)) + prints = list(re.finditer("^print", completion, re.MULTILINE)) if len(prints) > 1: - completion = completion[:prints[1].start()] + completion = completion[: prints[1].start()] - defs = list(re.finditer('^def', completion, re.MULTILINE)) + defs = list(re.finditer("^def", completion, re.MULTILINE)) if len(defs) > 1: - completion = completion[:defs[1].start()] + completion = completion[: defs[1].start()] start_pos = 0 - terminals_pos = [pos for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] if pos != -1] + terminals_pos = [ + pos + for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] + if pos != -1 + ] if len(terminals_pos) > 0: - return completion[:min(terminals_pos)] + return completion[: min(terminals_pos)] else: return completion def test_truncate(): - assert truncate('\nif len_a > len_b:\n result = a\nelse:\n result = b\n\n\n\n#') == '\nif len_a > len_b:\n result = a\nelse:\n result = b' - + assert ( + truncate("\nif len_a > len_b:\n result = a\nelse:\n result = b\n\n\n\n#") + == "\nif len_a > len_b:\n result = a\nelse:\n result = b" + ) ######################################################################## @@ -181,74 +185,89 @@ def main(): # (0) constants - models_nl = ['codegen-350M-nl', 'codegen-2B-nl', 'codegen-6B-nl', 'codegen-16B-nl'] - models_pl = ['codegen-350M-multi', 'codegen-2B-multi', 'codegen-6B-multi', 'codegen-16B-multi', 'codegen-350M-mono', 'codegen-2B-mono', 'codegen-6B-mono', 'codegen-16B-mono'] + models_nl = ["codegen-350M-nl", "codegen-2B-nl", "codegen-6B-nl", "codegen-16B-nl"] + models_pl = [ + "codegen-350M-multi", + "codegen-2B-multi", + "codegen-6B-multi", + "codegen-16B-multi", + "codegen-350M-mono", + "codegen-2B-mono", + "codegen-6B-mono", + "codegen-16B-mono", + ] models = models_nl + models_pl - # (1) params parser = argparse.ArgumentParser() - parser.add_argument('--model', type=str, choices=models, default='codegen-350M-mono') - parser.add_argument('--device', type=str, default='cuda:0') - parser.add_argument('--rng-seed', type=int, default=42) - parser.add_argument('--rng-deterministic', type=bool, default=True) - parser.add_argument('--p', type=float, default=0.95) - parser.add_argument('--t', type=float, default=0.2) - parser.add_argument('--max-length', type=int, default=128) - parser.add_argument('--batch-size', type=int, default=1) - parser.add_argument('--no-fp16', action="store_true") - parser.add_argument('--pad', type=int, default=50256) - parser.add_argument('--context', type=str, default='def helloworld():') + parser.add_argument( + "--model", type=str, choices=models, default="codegen-350M-mono" + ) + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--rng-seed", type=int, default=42) + parser.add_argument("--rng-deterministic", type=bool, default=True) + parser.add_argument("--p", type=float, default=0.95) + parser.add_argument("--t", type=float, default=0.2) + parser.add_argument("--max-length", type=int, default=128) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--no-fp16", action="store_true") + parser.add_argument("--pad", type=int, default=50256) + parser.add_argument("--context", type=str, default="def helloworld():") args = parser.parse_args() - # (2) preamble set_env() set_seed(args.rng_seed, deterministic=args.rng_deterministic) device = torch.device(args.device) - + 
use_fp16 = True - if (args.no_fp16 or device.type == "cpu"): + if args.no_fp16 or device.type == "cpu": use_fp16 = False if args.model.startswith("codegen-16B"): use_fp16 = True - ckpt = f'./checkpoints/{args.model}' - + ckpt = f"./checkpoints/{args.model}" # (3) load - with print_time('loading parameters'): + with print_time("loading parameters"): model = create_model(ckpt=ckpt, fp16=use_fp16).to(device) - - with print_time('loading tokenizer'): + with print_time("loading tokenizer"): if args.model in models_pl: tokenizer = create_custom_gpt2_tokenizer() else: tokenizer = create_tokenizer() - tokenizer.padding_side = 'left' + tokenizer.padding_side = "left" tokenizer.pad_token = args.pad - # (4) sample - with print_time('sampling'): - completion = sample(device=device, model=model, tokenizer=tokenizer, context=args.context, pad_token_id=args.pad, num_return_sequences=args.batch_size, temp=args.t, top_p=args.p, max_length_sample=args.max_length)[0] + with print_time("sampling"): + completion = sample( + device=device, + model=model, + tokenizer=tokenizer, + context=args.context, + pad_token_id=args.pad, + num_return_sequences=args.batch_size, + temp=args.t, + top_p=args.p, + max_length_sample=args.max_length, + )[0] truncation = truncate(completion) - print('=' * 100) + print("=" * 100) print(completion) - print('=' * 100) - print(args.context+truncation) - print('=' * 100) - + print("=" * 100) + print(args.context + truncation) + print("=" * 100) -if __name__ == '__main__': +if __name__ == "__main__": test_truncate() main() - print('done.') + print("done.") diff --git a/jaxformer/hf/train_deepspeed.py b/jaxformer/hf/train_deepspeed.py index 2673de3..45ec023 100644 --- a/jaxformer/hf/train_deepspeed.py +++ b/jaxformer/hf/train_deepspeed.py @@ -1,6 +1,6 @@ # Minimal example of training the 16B checkpoint on GPU with CPU offloading using deepspeed. 
-''' +""" apt install python3.8 python3.8-venv python3.8-dev python3.8 -m venv .venv @@ -10,56 +10,75 @@ pip install transformers==4.21.1 datasets==1.16.1 deepspeed==0.7.0 deepspeed --num_gpus=1 train_deepspeed.py -''' +""" ######################################################################################################## ## imports -import os import argparse -import random import math - +import os +import random from time import time +import deepspeed import numpy as np - import torch - from transformers import AutoConfig, AutoModelForCausalLM -import deepspeed - - ######################################################################################################## ## args -DEEPSPEED_CONFIG = \ -{ - 'fp16': {'enabled': True, 'loss_scale': 0, 'loss_scale_window': 1000, 'initial_scale_power': 12, 'hysteresis': 2, 'min_loss_scale': 1}, - 'optimizer': {'type': 'AdamW', 'params': {'lr': 1e-05, 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0.0}}, - 'scheduler': {'type': 'WarmupLR', 'params': {'warmup_min_lr': 0, 'warmup_max_lr': 1e-05, 'warmup_num_steps': 100}}, - 'zero_optimization': { - 'stage': 3, - 'offload_optimizer': {'device': 'cpu', 'pin_memory': False}, - 'offload_param': {'device': 'cpu', 'pin_memory': False}, - 'overlap_comm': True, - 'contiguous_gradients': True, - 'sub_group_size': 1e9, - 'reduce_bucket_size': 16777216, - 'stage3_prefetch_bucket_size': 15099494.4, - 'stage3_param_persistence_threshold': 40960, - 'stage3_max_live_parameters': 1e9, - 'stage3_max_reuse_distance': 1e9, - 'stage3_gather_fp16_weights_on_model_save': True +DEEPSPEED_CONFIG = { + "fp16": { + "enabled": True, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 12, + "hysteresis": 2, + "min_loss_scale": 1, + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-05, + "betas": [0.9, 0.999], + "eps": 1e-08, + "weight_decay": 0.0, + }, + }, + "scheduler": { + "type": "WarmupLR", + "params": {"warmup_min_lr": 0, "warmup_max_lr": 1e-05, "warmup_num_steps": 100}, + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": {"device": "cpu", "pin_memory": False}, + "offload_param": {"device": "cpu", "pin_memory": False}, + "overlap_comm": True, + "contiguous_gradients": True, + "sub_group_size": 1e9, + "reduce_bucket_size": 16777216, + "stage3_prefetch_bucket_size": 15099494.4, + "stage3_param_persistence_threshold": 40960, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_fp16_weights_on_model_save": True, + }, + "train_batch_size": 32, + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 16, + "gradient_clipping": 1.0, + "steps_per_print": 8, + "wall_clock_breakdown": False, + "compression_training": { + "weight_quantization": {"shared_parameters": {}, "different_groups": {}}, + "activation_quantization": {"shared_parameters": {}, "different_groups": {}}, + "sparse_pruning": {"shared_parameters": {}, "different_groups": {}}, + "row_pruning": {"shared_parameters": {}, "different_groups": {}}, + "head_pruning": {"shared_parameters": {}, "different_groups": {}}, + "channel_pruning": {"shared_parameters": {}, "different_groups": {}}, }, - 'train_batch_size': 32, - 'train_micro_batch_size_per_gpu': 2, - 'gradient_accumulation_steps': 16, - 'gradient_clipping': 1.0, - 'steps_per_print': 8, - 'wall_clock_breakdown': False, - 'compression_training': {'weight_quantization': {'shared_parameters': {}, 'different_groups': {}}, 'activation_quantization': {'shared_parameters': {}, 'different_groups': {}}, 
'sparse_pruning': {'shared_parameters': {}, 'different_groups': {}}, 'row_pruning': {'shared_parameters': {}, 'different_groups': {}}, 'head_pruning': {'shared_parameters': {}, 'different_groups': {}}, 'channel_pruning': {'shared_parameters': {}, 'different_groups': {}}} } @@ -67,7 +86,7 @@ def create_args(args=argparse.Namespace()): args.seed = 42 - args.model = 'Salesforce/codegen-16B-mono' + args.model = "Salesforce/codegen-16B-mono" args.deepspeed_config = DEEPSPEED_CONFIG @@ -76,10 +95,10 @@ def create_args(args=argparse.Namespace()): return args - ######################################################################################################## ## train + def train(args): ####################### @@ -87,11 +106,10 @@ def train(args): set_seed(args.seed) - ####################### ## model - print('initializing model') + print("initializing model") config = AutoConfig.from_pretrained(args.model) config.gradient_checkpointing = True @@ -103,45 +121,50 @@ def train(args): # TODO(enijkamp): we need to set this flag twice? model.gradient_checkpointing_enable() - ####################### ## deepspeed - print('initializing deepspeed') + print("initializing deepspeed") model_parameters = list(filter(lambda p: p.requires_grad, model.parameters())) - model_engine, optimizer, _, _ = deepspeed.initialize(config=args.deepspeed_config, model=model, model_parameters=model_parameters) + model_engine, optimizer, _, _ = deepspeed.initialize( + config=args.deepspeed_config, model=model, model_parameters=model_parameters + ) torch.cuda.empty_cache() - ####################### ## train - print('starting training') + print("starting training") - input_ids = torch.randint(low=0, high=10, size=[args.deepspeed_config['train_micro_batch_size_per_gpu'], 1024], dtype=torch.int64).cuda() + input_ids = torch.randint( + low=0, + high=10, + size=[args.deepspeed_config["train_micro_batch_size_per_gpu"], 1024], + dtype=torch.int64, + ).cuda() - for step in range(args.opt_steps_train+1): + for step in range(args.opt_steps_train + 1): loss = model_engine(input_ids=input_ids, labels=input_ids).loss model_engine.backward(loss) model_engine.step() - print(f'{step} {loss:8.3f}') - + print(f"{step} {loss:8.3f}") ######################################################################################################## ## preamble + def set_gpus(gpu): torch.cuda.set_device(gpu) def set_seed(seed): - os.environ['PYTHONHASHSEED'] = str(seed) + os.environ["PYTHONHASHSEED"] = str(seed) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) @@ -162,21 +185,22 @@ def get_exp_id(file): def get_output_dir(exp_id): import datetime - t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') - output_dir = os.path.join('output/' + exp_id, t) + + t = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + output_dir = os.path.join("output/" + exp_id, t) return output_dir def copy_source(file, output_dir): import shutil - shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file))) - + shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file))) ######################################################################################################## ## main + def main(): # preamble @@ -196,6 +220,5 @@ def main(): train(args=args) - -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main()
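
Note (not part of the diff): the hunks above are a formatting-only pass (black-style line wrapping, quote normalization, and import ordering) over jaxformer/hf; behavior is intended to be unchanged. Below is a minimal smoke-test sketch of the reformatted sampling path, using only helpers that sample.py already defines (create_model, create_custom_gpt2_tokenizer, sample, truncate, set_env, set_seed). The filename, the local checkpoint path, and the pad id mirror the defaults in main() and are assumptions for illustration, not part of this change.

# smoke_test_sample.py (hypothetical) -- a sketch, assuming a local checkpoint
# at ./checkpoints/codegen-350M-mono, as sample.py's main() expects.
import torch

from jaxformer.hf.sample import (
    create_custom_gpt2_tokenizer,
    create_model,
    sample,
    set_env,
    set_seed,
    truncate,
)

set_env()
set_seed(42)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ckpt = "./checkpoints/codegen-350M-mono"  # assumed local path, as in main()

# fp16 only on GPU, mirroring the use_fp16 logic in main()
model = create_model(ckpt=ckpt, fp16=(device.type != "cpu")).to(device)

tokenizer = create_custom_gpt2_tokenizer()
tokenizer.padding_side = "left"
tokenizer.pad_token = 50256  # same pad id as the CLI default (--pad)

context = "def helloworld():"
completion = sample(
    device=device,
    model=model,
    tokenizer=tokenizer,
    context=context,
    pad_token_id=50256,
    num_return_sequences=1,
    temp=0.2,
    top_p=0.95,
    max_length_sample=128,
)[0]

# truncate() trims the raw completion at heuristic stop patterns, as in main()
print(context + truncate(completion))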