diff --git a/megatron/arguments.py b/megatron/arguments.py index c18235a78..2e7a4c109 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -242,6 +242,9 @@ def parse_args(extra_args_provider=None, defaults={}, # Checks. if args.ffn_hidden_size is None: args.ffn_hidden_size = 4 * args.hidden_size + + if args.student_ffn_hidden_size is None: + args.student_ffn_hidden_size = 4 * args.student_hidden_size if args.kv_channels is None: assert args.hidden_size % args.num_attention_heads == 0 @@ -353,9 +356,18 @@ def _add_network_size_args(parser): help='Number of transformer layers.') group.add_argument('--hidden-size', type=int, default=None, help='Tansformer hidden size.') + group.add_argument('--student-num-layers', type=int, default=None, + help='Number of student transformer layers.') + group.add_argument('--student-hidden-size', type=int, default=None, + help='Student transformer hidden size.') + group.add_argument('--student-num-attention-heads', type=int, default=None, + help='Number of student transformer attention heads.') group.add_argument('--ffn-hidden-size', type=int, default=None, help='Transformer Feed-Forward Network hidden size. ' 'This is set to 4*hidden-size if not provided') + group.add_argument('--student-ffn-hidden-size', type=int, default=None, + help='Student Transformer Feed-Forward Network hidden size. ' + 'This is set to 4*student-hidden-size if not provided') group.add_argument('--num-attention-heads', type=int, default=None, help='Number of transformer attention heads.') group.add_argument('--kv-channels', type=int, default=None, @@ -660,6 +672,10 @@ def _add_checkpointing_args(parser): help='Do not save current rng state.') group.add_argument('--load', type=str, default=None, help='Directory containing a model checkpoint.') + group.add_argument('--teacher-load', type=str, default=None, + help='Directory containing the teacher model checkpoint.') + group.add_argument('--student-load', type=str, default=None, + help='Directory containing the student model checkpoint.') group.add_argument('--no-load-optim', action='store_true', default=None, help='Do not load optimizer when loading checkpoint.') group.add_argument('--no-load-rng', action='store_true', default=None, @@ -715,8 +731,12 @@ def _add_distributed_args(parser): group.add_argument('--tensor-model-parallel-size', type=int, default=1, help='Degree of tensor model parallelism.') + group.add_argument('--student-tensor-model-parallel-size', type=int, default=1, + help='Degree of tensor model parallelism for the student model.') group.add_argument('--pipeline-model-parallel-size', type=int, default=1, help='Degree of pipeline model parallelism.') + group.add_argument('--student-pipeline-model-parallel-size', type=int, default=1, + help='Degree of pipeline model parallelism for the student model.') group.add_argument('--model-parallel-size', type=int, default=None, help='Old model parallel argument, do not use. 
Use ' '--tensor-model-parallel-size instead.') diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index 55e9c9dd8..13a108096 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -41,7 +41,10 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): ctx.normalized_shape = normalized_shape ctx.eps = eps - input_ = input.contiguous() + if isinstance(input, tuple): + input_ = input[0].contiguous() + else: + input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( @@ -91,7 +94,6 @@ def __init__(self, normalized_shape, eps=1e-5): or version.parse(torch.__version__) >= version.parse("1.11.0") # https://github.com/pytorch/pytorch/pull/66920 ) - def reset_parameters(self): init.ones_(self.weight) @@ -99,13 +101,31 @@ def reset_parameters(self): def forward(self, input): + if isinstance(input, tuple): + input = input[0] if self.layernorm_tp_auto_sync: torch.distributed.all_reduce(self.weight, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group()) torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group()) if self.use_meg_ds_fused_layer_norm: + #if False: return FusedLayerNormAffineFunction.apply( input, self.weight, self.bias, self.normalized_shape, self.eps) else: return F.layer_norm(input, self.normalized_shape, self.weight, self.bias) + +class MixedFusedLayerNormTeacher(MixedFusedLayerNorm): + # @torch.no_grad() + def forward(self, input): + # if isinstance(input, tuple): + input, *original_input = input + # return (super().forward(input), *original_input) + # else: + # return super().forward(input) + + return super().forward(input), *original_input + +class MixedFusedLayerNormStudent(MixedFusedLayerNorm): + def forward(self, input): + return (super().forward(input), input[1]) \ No newline at end of file diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index 07192e2bf..9813085e1 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -188,7 +188,7 @@ def forward_fused_softmax(self, input, mask): if self.attn_mask_type == AttnMaskType.causal: assert sq == sk, "causal mask is only for self attention" - assert mask is None, "Mask is silently ignored due to the use of a custom kernel" + # assert mask is None, "Mask is silently ignored due to the use of a custom kernel" # input is 3D tensor (attn_batches, sq, sk) input = input.view(-1, sq, sk) @@ -236,3 +236,4 @@ def get_batch_per_block(sq, sk, b, np): import scaled_masked_softmax_cuda return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) + diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index a9e3e2604..730a7faed 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -30,9 +30,11 @@ from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm +from megatron.model.fused_layer_norm import MixedFusedLayerNormTeacher as LayerNormTeacher +from megatron.model.fused_layer_norm import MixedFusedLayerNormStudent as LayerNormStudent from megatron.model.module import float16_to_fp32 -from .language_model import EmbeddingPipe -from .transformer import ParallelTransformerLayerPipe +from .language_model import EmbeddingPipe, EmbeddingPipeTeacher, EmbeddingPipeStudent +from .transformer import 
ParallelTransformerLayerPipe, ParallelTransformerLayerPipeTeacher, ParallelTransformerLayerPipeStudent def post_language_model_processing(lm_output, labels, logit_weights, @@ -195,6 +197,56 @@ def CrossEntropy(output, labels): return CrossEntropy +def get_ts_loss(is_prefix: bool): + def TeacherStudentLoss(output, labels): + student_logits, teacher_logits = output + if isinstance(teacher_logits, tuple): + teacher_logits = teacher_logits[0] + labels, loss_mask = labels[0], labels[1] + + args = get_args() + + losses = mpu.vocab_parallel_cross_entropy(student_logits.contiguous().float(), labels) + + if is_prefix: + micro_batch_size, sequence_length = loss_mask.shape + average_tokens_per_sample: torch.Tensor + if args.loss_on_targets_only: + # HACK: This is useful when we obtain loss masks that are microbatch dependent. Consequently, if we want to + # preserve the notion that all tokens have the same impact on the loss, we can only normalise using a + # microbatch independent value. It should be expected weight over a microbatch. + # Here we still use `sequence_length`, that's batch size dependent, in order to be backwards compatible with + # current experiment on vanilla gpt. + if args.reweight_loss_based_on_position_frequency: + reweight = torch.arange( + sequence_length, 0, -1, dtype=torch.float, device=loss_mask.device + ) / (sequence_length + 1) * 2 + average_tokens_per_sample = reweight.flip(-1).cumsum(-1).mean() + else: + average_tokens_per_sample = (sequence_length + 1) / 2 + else: + average_tokens_per_sample = sequence_length + expected_number_of_tokens = average_tokens_per_sample * micro_batch_size + else: + expected_number_of_tokens = loss_mask.sum() + + loss_mask = loss_mask.view(-1) + loss = torch.sum(losses.view(-1) * loss_mask) / expected_number_of_tokens + + # TODO: check if the formula is correct + # teacher_logits = teacher_logits.detach() + # First pass it on CPU - otherwise we get OOM errors + # teacher_logits = teacher_logits.detach() + softmax_labels = torch.nn.Softmax(dim=-1)(teacher_logits.contiguous().float()) + student_log_softmax = -torch.nn.LogSoftmax(dim=-1)(student_logits.contiguous().float()) + + softmax_logits = student_log_softmax * softmax_labels + logits_loss = softmax_logits.mean() + + return loss + logits_loss + return TeacherStudentLoss + + class GPTModelPipe(PipelineModule,MegatronModule): """GPT-2 Language model.""" @@ -222,15 +274,15 @@ def _to_float16(inputs): self.specs.append(_to_float16) # Embedding layer - self.specs.append(TiedLayerSpec('embed', - EmbeddingPipe, + self.specs.append(TiedLayerSpec('embed_teacher', + EmbeddingPipeTeacher, args.hidden_size, args.padded_vocab_size, args.hidden_dropout, init_method=init_method, num_tokentypes=num_tokentypes, tied_weight_attr='word_embeddings_weight')) - + if args.fp32_residual_connection: if getattr(args, 'pretrain_causal_attention', False): self.specs.append(lambda x: x.transpose(0, 1).contiguous().float()) else: # EmbeddingPipe returns attention mask as well self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:])) else: if getattr(args, 'pretrain_causal_attention', False): - self.specs.append(lambda x: x.transpose(0, 1).contiguous()) + self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:])) + # self.specs.append(lambda x: (x.transpose(0, 1).contiguous(), *x[1:])) else: # EmbeddingPipe returns attention mask as well self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:])) for layer_idx in range(args.num_layers): self.specs.append( - 
LayerSpec(ParallelTransformerLayerPipe, + LayerSpec(ParallelTransformerLayerPipeTeacher, init_method=init_method, output_layer_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers), @@ -256,14 +309,15 @@ def _to_float16(inputs): # Undo data format change def undo(x): - if not getattr(args, 'pretrain_causal_attention', False): - x = x[0] + if isinstance(x, tuple): + return (x[0].transpose(0, 1).contiguous(), (x[1:])) + # return (x[0].transpose(0, 1).contiguous(), *x[1:]) return x.transpose(0, 1).contiguous() self.specs.append(undo) # Final layernorm after transformer layers self.specs.append( - LayerSpec(LayerNorm, + LayerSpec(LayerNormTeacher, args.hidden_size, eps=args.layernorm_epsilon)) @@ -275,8 +329,8 @@ def _logits_helper(embedding, lm_output): self.parallel_output) self.specs.append( - TiedLayerSpec('embed', - EmbeddingPipe, + TiedLayerSpec('embed_teacher', + EmbeddingPipeTeacher, args.hidden_size, args.padded_vocab_size, args.hidden_dropout, @@ -286,19 +340,17 @@ def _logits_helper(embedding, lm_output): tied_weight_attr='word_embeddings_weight') ) + # self.specs.append(lambda x: print(x[0])) # Convert to fp32 if needed - if args.fp16 or args.bf16: - self.specs.append(float16_to_fp32) + # if args.fp16 or args.bf16: + # self.specs.append(float16_to_fp32) if args.checkpoint_activations: interval = args.checkpoint_num_layers else: interval = 0 - from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology - topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(), - num_mp=mpu.get_tensor_model_parallel_world_size(), - num_dp=mpu.get_data_parallel_world_size()) + # here one can extend the regex to include more layers to be counted towards partitioning, # e.g. 'type:transformer|embedding' will add up all the transformer blocks and also the first @@ -306,14 +358,109 @@ def _logits_helper(embedding, lm_output): # balance you may want to use less transformer layers # # caveat emptor: the current implementation of PP fails unless each stage has at least one + + # Beginning student model + + init_method = init_method_normal(args.init_method_std) + + + def _to_float16(inputs): + if args.fp16: + return fp32_to_float16(inputs, lambda v: v.half()) + elif args.bf16: + return fp32_to_float16(inputs, lambda v: v.bfloat16()) + else: + return inputs + + + # Embedding layer + self.specs.append(TiedLayerSpec('embed_student', + EmbeddingPipeStudent, + args.student_hidden_size, + args.padded_vocab_size, + args.hidden_dropout, + init_method=init_method, + num_tokentypes=num_tokentypes, + tied_weight_attr='word_embeddings_weight')) + + if args.fp32_residual_connection: + if getattr(args, 'pretrain_causal_attention', False): + self.specs.append(lambda x: x.transpose(0, 1).contiguous().float()) + else: + # EmbeddingPipe returns attention mask as well + self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:])) + else: + if getattr(args, 'pretrain_causal_attention', False): + self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), x[1])) + else: + # EmbeddingPipe returns attention mask as well + self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:])) + + for layer_idx in range(args.student_num_layers): + self.specs.append( + LayerSpec(ParallelTransformerLayerPipeStudent, + init_method=init_method, + output_layer_init_method=scaled_init_method_normal(args.init_method_std, + args.student_num_layers), + layer_number=layer_idx, + # TODO: Change naming of class from GPT to something that encapsulate 
prefix lm. + self_attn_mask_type=attn_mask_type)) + + # Undo data format change + def undo(x): + if isinstance(x, tuple): + return (x[0].transpose(0, 1).contiguous(), x[1:]) + return x.transpose(0, 1).contiguous() + self.specs.append(undo) + + # Final layernorm after transformer layers + self.specs.append( + LayerSpec(LayerNormStudent, + args.student_hidden_size, + eps=args.layernorm_epsilon)) + + def _logits_helper_student(embedding, lm_output): + """A wrapper to massage inputs/outputs from pipeline. """ + return parallel_lm_logits( + lm_output, + embedding.word_embeddings_weight, + self.parallel_output, permute_output=True) + + + self.specs.append( + TiedLayerSpec('embed_student', + EmbeddingPipeStudent, + args.student_hidden_size, + args.padded_vocab_size, + args.hidden_dropout, + init_method=init_method, + num_tokentypes=num_tokentypes, + forward_fn=_logits_helper_student, + tied_weight_attr='word_embeddings_weight') + ) + + # Convert to fp32 if needed + if args.fp16 or args.bf16: + self.specs.append(float16_to_fp32) + + if args.checkpoint_activations: + interval = args.checkpoint_num_layers + else: + interval = 0 + # transformer layer if args.pp_partition_method is not None: partition_method = args.pp_partition_method else: partition_method = 'type:transformer' - + + from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology + topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(), + num_mp=mpu.get_tensor_model_parallel_world_size(), + num_dp=mpu.get_data_parallel_world_size()) + super().__init__(layers=self.specs, - loss_fn=get_cross_entropy(is_prefix=attn_mask_type is AttnMaskType.prefix), + loss_fn=get_ts_loss(is_prefix=attn_mask_type is AttnMaskType.prefix), topology=topo, activation_checkpoint_interval=interval, partition_method=partition_method) diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index fc284431a..001d19709 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -15,6 +15,7 @@ """Transformer based language model.""" +from importlib import invalidate_caches import torch import torch.nn.functional as F @@ -27,8 +28,14 @@ from megatron.model.utils import init_method_normal, scaled_init_method_normal def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, - bias=None): + bias=None, permute_output=False): """LM logits using word embedding weights.""" + if isinstance(input_, tuple): + # retrieve the input tensor from the tuple + original_inputs = input_[1] + input_ = input_[0] + else: + original_inputs = None # Parallel logits. input_parallel = mpu.copy_to_tensor_model_parallel_region(input_) # Matrix multiply. @@ -36,11 +43,12 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, logits_parallel = F.linear(input_parallel, word_embeddings_weight) else: logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) + # Gather if needed. 
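# Sketch: the student stack above reuses one weight matrix for its input embedding and
# its output logits projection via TiedLayerSpec('embed_student', ...) and
# _logits_helper_student. A minimal standalone illustration of that tying, using plain
# torch.nn in place of the Megatron mpu/pipeline machinery (class and helper names here
# are illustrative assumptions, not identifiers from the repo):
import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedEmbedding(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)

    @property
    def word_embeddings_weight(self):
        return self.word_embeddings.weight

    def forward(self, token_ids):
        return self.word_embeddings(token_ids)

def logits_helper(embedding, lm_output):
    # Same role as _logits_helper_student: project hidden states with the tied weight.
    return F.linear(lm_output, embedding.word_embeddings_weight)

embed = TiedEmbedding(vocab_size=50, hidden_size=16)
hidden = torch.randn(2, 4, 16)
print(logits_helper(embed, hidden).shape)  # torch.Size([2, 4, 50])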
if parallel_output: - return logits_parallel - - return mpu.gather_from_tensor_model_parallel_region(logits_parallel) + return (logits_parallel, original_inputs) if permute_output else (original_inputs, logits_parallel) + + return (mpu.gather_from_tensor_model_parallel_region(logits_parallel), original_inputs) if permute_output else (original_inputs, mpu.gather_from_tensor_model_parallel_region(logits_parallel)) def get_language_model(num_tokentypes, add_pooler, @@ -271,9 +279,11 @@ class EmbeddingPipe(Embedding): def forward(self, inputs, **kwargs): if not hasattr(self, '_args'): self._args = get_args() - input_ids = inputs[0] position_ids = inputs[1] + if isinstance(input_ids, tuple): + # print(input_ids) + input_ids = input_ids[0] if getattr(self._args, 'pretrain_causal_attention', False): attention_mask = None else: @@ -298,6 +308,16 @@ def word_embeddings_weight(self): """Easy accessory for the DeepSpeed pipeline engine to tie embeddings across stages.""" return self.word_embeddings.weight +class EmbeddingPipeTeacher(EmbeddingPipe): + # @torch.no_grad() + def forward(self, inputs, **kwargs): + return (super().forward(inputs, **kwargs), *inputs) + +class EmbeddingPipeStudent(EmbeddingPipe): + def forward(self, inputs, **kwargs): + inputs, logits_teacher = inputs + return (super().forward(inputs, **kwargs), logits_teacher) + class TransformerLanguageModel(MegatronModule): """Transformer language model. diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 03e6faaec..78400ad22 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -23,6 +23,7 @@ from megatron import mpu from .module import MegatronModule from megatron.enums import AttnMaskType, LayerType, AttnType, PositionEmbeddingType +from megatron.model.fused_layer_norm import MixedFusedLayerNormStudent as LayerNormStudent from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl @@ -65,18 +66,27 @@ class ParallelMLP(MegatronModule): applied. """ - def __init__(self, init_method, output_layer_init_method): + def __init__(self, init_method, output_layer_init_method, student_=False): super(ParallelMLP, self).__init__() args = get_args() # Project to ffn_hidden_size - self.dense_h_to_4h = mpu.ColumnParallelLinear( - args.hidden_size, - # GLU is a special activation that divides the dimension by a factor 2. - 2 * args.ffn_hidden_size if args.glu_activation else args.ffn_hidden_size, - gather_output=False, - init_method=init_method, - skip_bias_add=True) + if not student_: + self.dense_h_to_4h = mpu.ColumnParallelLinear( + args.hidden_size, + # GLU is a special activation that divides the dimension by a factor 2. + 2 * args.ffn_hidden_size if args.glu_activation else args.ffn_hidden_size, + gather_output=False, + init_method=init_method, + skip_bias_add=True) + else: + self.dense_h_to_4h = mpu.ColumnParallelLinear( + args.student_hidden_size, + # GLU is a special activation that divides the dimension by a factor 2. + 2 * args.student_ffn_hidden_size if args.glu_activation else args.student_ffn_hidden_size, + gather_output=False, + init_method=init_method, + skip_bias_add=True) self.bias_gelu_fusion = args.bias_gelu_fusion self.activation_func = F.gelu @@ -88,12 +98,20 @@ def __init__(self, init_method, output_layer_init_method): self.activation_func = erf_gelu # Project back to h. 
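# Sketch: the return-tuple ordering in parallel_lm_logits is what hands the teacher's
# output to the student. The teacher head (permute_output=False) returns
# (original_inputs, teacher_logits), which EmbeddingPipeStudent unpacks; the student
# head passes permute_output=True so the loss receives (student_logits, teacher_logits).
# A single-GPU sketch of that ordering, with no tensor parallelism (an assumption):
import torch

def lm_logits(hidden_or_tuple, weight, permute_output=False):
    if isinstance(hidden_or_tuple, tuple):
        hidden, original_inputs = hidden_or_tuple
    else:
        hidden, original_inputs = hidden_or_tuple, None
    logits = hidden @ weight.t()
    # Teacher head: original inputs first, so the next stage can re-embed them.
    # Student head: its own logits first, so the loss sees (student, teacher).
    return (logits, original_inputs) if permute_output else (original_inputs, logits)

weight = torch.randn(50, 16)
teacher_hidden = torch.randn(2, 4, 16)
token_ids = torch.randint(0, 50, (2, 4))

original_inputs, teacher_logits = lm_logits((teacher_hidden, token_ids), weight)
assert torch.equal(original_inputs, token_ids)   # the student embedding re-reads these
assert teacher_logits.shape == (2, 4, 50)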
- self.dense_4h_to_h = mpu.RowParallelLinear( - args.ffn_hidden_size, - args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - skip_bias_add=True) + if not student_: + self.dense_4h_to_h = mpu.RowParallelLinear( + args.ffn_hidden_size, + args.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) + else: + self.dense_4h_to_h = mpu.RowParallelLinear( + args.student_ffn_hidden_size, + args.student_hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) def forward(self, hidden_states): @@ -123,7 +141,8 @@ class ParallelAttention(MegatronModule): def __init__(self, init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, - attn_mask_type=AttnMaskType.padding): + attn_mask_type=AttnMaskType.padding, + student_=False): super(ParallelAttention, self).__init__() args = get_args() self.fp16 = args.fp16 @@ -145,27 +164,27 @@ def __init__(self, init_method, self.hidden_size_per_partition = mpu.divide(projection_size, world_size) self.hidden_size_per_attention_head = mpu.divide( - projection_size, args.num_attention_heads) + projection_size, args.num_attention_heads if not student_ else args.student_num_attention_heads) self.num_attention_heads_per_partition = mpu.divide( - args.num_attention_heads, world_size) + args.num_attention_heads if not student_ else args.student_num_attention_heads, world_size) # Strided linear layer. if attention_type == AttnType.self_attn: self.query_key_value = mpu.ColumnParallelLinear( - args.hidden_size, + args.hidden_size if not student_ else args.student_hidden_size, 3 * projection_size, gather_output=False, init_method=init_method) else: assert attention_type == AttnType.cross_attn self.query = mpu.ColumnParallelLinear( - args.hidden_size, + args.hidden_size if not student_ else args.student_hidden_size, projection_size, gather_output=False, init_method=init_method) self.key_value = mpu.ColumnParallelLinear( - args.hidden_size, + args.hidden_size if not student_ else args.student_hidden_size, 2 * projection_size, gather_output=False, init_method=init_method) @@ -192,7 +211,7 @@ def __init__(self, init_method, # Output. self.dense = mpu.RowParallelLinear( projection_size, - args.hidden_size, + args.hidden_size if not student_ else args.student_hidden_size, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True) @@ -439,7 +458,8 @@ class ParallelTransformerLayer(MegatronModule): def __init__(self, init_method, output_layer_init_method, layer_number, layer_type=LayerType.encoder, - self_attn_mask_type=AttnMaskType.padding): + self_attn_mask_type=AttnMaskType.padding, + student_=False): args = get_args() super(ParallelTransformerLayer, self).__init__() @@ -463,7 +483,8 @@ def __init__(self, init_method, output_layer_init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, - attn_mask_type=self_attn_mask_type) + attn_mask_type=self_attn_mask_type, + student_=student_) self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion @@ -477,10 +498,11 @@ def __init__(self, init_method, output_layer_init_method, init_method, output_layer_init_method, layer_number, - attention_type=AttnType.cross_attn) + attention_type=AttnType.cross_attn, + student_=student_) # Layernorm on the attention output. 
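# Sketch: the `student_` flag threaded through ParallelMLP / ParallelAttention only
# switches the width-dependent sizes to the student values. The same selection,
# expressed once (toy config; field names mirror the new arguments, the helper itself
# is an assumption):
from dataclasses import dataclass

@dataclass
class Dims:
    hidden_size: int
    ffn_hidden_size: int
    num_attention_heads: int

@dataclass
class DistillArgs:
    hidden_size: int = 1024
    ffn_hidden_size: int = 4096
    num_attention_heads: int = 16
    student_hidden_size: int = 512
    student_ffn_hidden_size: int = 2048
    student_num_attention_heads: int = 8

def pick_dims(args, student: bool) -> Dims:
    if student:
        return Dims(args.student_hidden_size,
                    args.student_ffn_hidden_size,
                    args.student_num_attention_heads)
    return Dims(args.hidden_size, args.ffn_hidden_size, args.num_attention_heads)

print(pick_dims(DistillArgs(), student=True))
# Centralising the selection like this would avoid repeating the
# `... if not student_ else ...` expression for every layer dimension.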
self.post_inter_attention_layernorm = LayerNorm( - args.hidden_size, + args.hidden_size if not student_ else args.student_hidden_size, eps=args.layernorm_epsilon) # MLP @@ -504,6 +526,8 @@ def forward(self, hidden_states, attention_mask, # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) + if isinstance(layernorm_output, tuple): + layernorm_output, _ = layernorm_output # Self attention. attention_output, attention_bias = \ self.self_attention(layernorm_output, @@ -521,6 +545,12 @@ def forward(self, hidden_states, attention_mask, else: residual = hidden_states + if isinstance(residual, tuple): + if len(residual) > 1: + residual, _ = residual + else: + residual = residual[0] + # jit scripting for a nn.module (with dropout) is not # trigerring the fusion kernel. For now, we use two # different nn.functional routines to account for varying @@ -567,6 +597,9 @@ def forward(self, hidden_states, attention_mask, layernorm_output = self.post_inter_attention_layernorm(layernorm_input) # MLP. + # print("========", layernorm_output) + if isinstance(layernorm_output, tuple): + layernorm_output, _ = layernorm_output mlp_output, mlp_bias = self.mlp(layernorm_output) # Second residual connection. @@ -647,6 +680,108 @@ def forward(self, inputs, **kwargs): else: raise RuntimeError('Received more inputs than understood.') +class ParallelTransformerLayerPipeTeacher(ParallelTransformerLayerPipe): + """Extends ParallelTransformerLayer to forward attention_mask through the pipeline. + + Forward has two usages that affect attention mask communication: + + 1) forward((input, attn_mask) , **kwargs) -> (output, mask) + When the attention mask is provided as the second positional + argument, typical pipeline behavior is used and both the output + *and* mask are returned in a tuple. This tuple is then forwarded + to the next stage in the pipeline. + + This version is useful if masks are dynamic. + + 2) forward(input, **kwargs) -> output + When the mask is static over all samples, it is advantageous to + cache the mask and avoid communicating it. + """ + # @torch.no_grad() + def forward(self, inputs, **kwargs): + return (super().forward(inputs[0], **kwargs), *inputs[1:]) + +class ParallelTransformerLayerPipeStudent(ParallelTransformerLayerPipe): + """Extends ParallelTransformerLayer to forward attention_mask through the pipeline. + + Forward has two usages that affect attention mask communication: + + 1) forward((input, attn_mask) , **kwargs) -> (output, mask) + When the attention mask is provided as the second positional + argument, typical pipeline behavior is used and both the output + *and* mask are returned in a tuple. This tuple is then forwarded + to the next stage in the pipeline. + + This version is useful if masks are dynamic. + + 2) forward(input, **kwargs) -> output + When the mask is static over all samples, it is advantageous to + cache the mask and avoid communicating it. + """ + def __init__(self, init_method, output_layer_init_method, + layer_number, layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding): + args = get_args() + + super(ParallelTransformerLayer, self).__init__() + self.layer_number = layer_number + self.layer_type = layer_type + + self.apply_residual_connection_post_layernorm \ + = args.apply_residual_connection_post_layernorm + + self.bf16 = args.bf16 + self.fp32_residual_connection = args.fp32_residual_connection + + # Layernorm on the input data. 
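# Sketch: ParallelTransformerLayerPipeTeacher/Student (and the LayerNorm subclasses)
# all follow the same pattern: run the parent forward on the first tuple element and
# carry the remaining elements through the pipeline untouched. A standalone version of
# that threading with stand-in blocks (nn.Linear here is only a placeholder for the
# real transformer layer):
import torch
import torch.nn as nn

class TeacherBlockPipe(nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, inputs):
        hidden, *extras = inputs
        # Transform the hidden states only; the extras (e.g. the original token ids)
        # ride along to the stage that needs them.
        return (self.block(hidden), *extras)

class StudentBlockPipe(TeacherBlockPipe):
    def forward(self, inputs):
        # The student variant carries exactly one extra: the teacher logits.
        return (self.block(inputs[0]), inputs[1])

stages = nn.Sequential(*[TeacherBlockPipe(nn.Linear(8, 8)) for _ in range(3)])
tokens = torch.randint(0, 100, (2, 4))
hidden = torch.randn(2, 4, 8)
out = stages((hidden, tokens))
assert torch.equal(out[1], tokens)   # original inputs survive every stage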
+ self.input_layernorm = LayerNormStudent( + args.student_hidden_size, + eps=args.layernorm_epsilon) + + # Self attention. + self.self_attention = ParallelAttention( + init_method, + output_layer_init_method, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=self_attn_mask_type, + student_=True) + self.hidden_dropout = args.hidden_dropout + self.bias_dropout_fusion = args.bias_dropout_fusion + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormStudent( + args.student_hidden_size, + eps=args.layernorm_epsilon) + + if self.layer_type == LayerType.decoder: + self.inter_attention = ParallelAttention( + init_method, + output_layer_init_method, + layer_number, + attention_type=AttnType.cross_attn, + student_=True) + # Layernorm on the attention output. + self.post_inter_attention_layernorm = LayerNormStudent( + args.student_hidden_size, + eps=args.layernorm_epsilon) + + # MLP + self.mlp = ParallelMLP(init_method, + output_layer_init_method, student_=True) + + # Alibi + if args.position_embedding_type == PositionEmbeddingType.alibi: + self.alibi = self._build_alibi_tensor(args.seq_length, args.student_num_attention_heads, args.micro_batch_size).to(torch.cuda.current_device()) + if args.params_dtype == torch.float16: + self.alibi = self.alibi.to(torch.float16) + elif args.params_dtype == torch.bfloat16: + self.alibi = self.alibi.to(torch.bfloat16) + else: + self.alibi = None + + def forward(self, inputs, **kwargs): + return (super().forward(inputs[0], **kwargs), inputs[1]) class ParallelTransformer(MegatronModule): """Transformer class.""" diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py index 738717d55..31f1e3a3f 100644 --- a/megatron/optimizer/__init__.py +++ b/megatron/optimizer/__init__.py @@ -18,10 +18,26 @@ from megatron import get_args from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm +from megatron.model.fused_layer_norm import MixedFusedLayerNormTeacher, MixedFusedLayerNormStudent +from megatron.model.transformer import ParallelTransformerLayerPipeStudent, ParallelTransformerLayerPipeTeacher +from megatron.model.language_model import EmbeddingPipeStudent, EmbeddingPipeTeacher from .grad_scaler import ConstantGradScaler, DynamicGradScaler from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer +def _filter_for_teacher_student(modules): + trainable_modules = [] + + for module in modules: + # for module_ in module.modules(): + for module_ in module.children(): + # TODO: this is empty ??? + if isinstance(module_, (ParallelTransformerLayerPipeStudent, EmbeddingPipeStudent, MixedFusedLayerNormStudent)): + trainable_modules.append(module_) + # return modules + return trainable_modules + + def _get_params_for_weight_decay_optimization(modules): """Divide params into with-weight-decay and without-weight-decay groups. 
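# Sketch: _filter_for_teacher_student and the weight-decay change below point at the
# real goal: only the student half should be optimized while the teacher stays frozen.
# One compact way to express that with plain PyTorch (class names and the freezing
# strategy are assumptions, not the repo's final mechanism):
import torch
import torch.nn as nn

class StudentBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)

class TeacherBlock(StudentBlock):
    pass

model = nn.ModuleList([TeacherBlock(), StudentBlock(), StudentBlock()])

weight_decay_params, no_weight_decay_params = [], []
for module in model:
    is_student = type(module) is StudentBlock
    for submodule in module.modules():
        for name, param in submodule._parameters.items():
            if param is None:
                continue
            if not is_student:
                # Teacher parameters are frozen and never handed to the optimizer.
                param.requires_grad_(False)
                continue
            if isinstance(submodule, nn.LayerNorm) or name == 'bias':
                no_weight_decay_params.append(param)
            else:
                weight_decay_params.append(param)

optimizer = torch.optim.AdamW([
    {'params': weight_decay_params, 'weight_decay': 0.01},
    {'params': no_weight_decay_params, 'weight_decay': 0.0},
])
print(len(weight_decay_params), len(no_weight_decay_params))   # 2 6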
@@ -30,9 +46,12 @@ def _get_params_for_weight_decay_optimization(modules): weight_decay_params = {'params': []} no_weight_decay_params = {'params': [], 'weight_decay': 0.0} + + # modules = _filter_for_teacher_student(modules) + for module in modules: for module_ in module.modules(): - if isinstance(module_, LayerNorm): + if isinstance(module_, MixedFusedLayerNormStudent): no_weight_decay_params['params'].extend( [p for p in list(module_._parameters.values()) if p is not None]) diff --git a/megatron/training.py b/megatron/training.py index bd00bc77e..d45eabc3c 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -953,8 +953,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler, params_norm = None if args.log_params_norm: params_norm = calc_params_l2_norm(model) + + # raise NotImplementedError(optimizer.param_groups) + report_memory_flag = training_log(loss_dict, total_loss_dict, - optimizer.param_groups[0]['lr'], + optimizer.param_groups[0]['lr'] if len(optimizer.param_groups) > 0 else 0.0, iteration, loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad,
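# Sketch: get_ts_loss above combines the usual masked token cross-entropy with a
# soft-target term built from the teacher's softmax; as its TODO notes, the soft term
# is a plain mean over every logit entry rather than a temperature-scaled KL. The same
# combination with single-GPU ops, using F.cross_entropy and a simple mask-sum
# normalisation in place of mpu.vocab_parallel_cross_entropy and the prefix-LM token
# accounting (both substitutions are assumptions made for brevity):
import torch
import torch.nn.functional as F

def teacher_student_loss(student_logits, teacher_logits, labels, loss_mask):
    vocab = student_logits.size(-1)
    # Hard-label term: masked token-level cross-entropy.
    ce = F.cross_entropy(student_logits.float().reshape(-1, vocab),
                         labels.reshape(-1), reduction='none')
    hard_loss = (ce * loss_mask.reshape(-1)).sum() / loss_mask.sum()

    # Soft-target term, matching the diff: teacher softmax times negative student
    # log-softmax, averaged over all elements.
    soft_targets = F.softmax(teacher_logits.float(), dim=-1)
    neg_log_probs = -F.log_softmax(student_logits.float(), dim=-1)
    soft_loss = (soft_targets * neg_log_probs).mean()

    return hard_loss + soft_loss

student = torch.randn(2, 5, 11)
teacher = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))
mask = torch.ones(2, 5)
print(teacher_student_loss(student, teacher, labels, mask))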