diff --git a/tests/fairseq_layers.py b/tests/fairseq_layers.py index 2a85fd27..9b878e43 100644 --- a/tests/fairseq_layers.py +++ b/tests/fairseq_layers.py @@ -4,9 +4,8 @@ We use layers from Facebook Fairseq as our baseline for unit test """ -from typing import Dict, List, Optional, Callable +from typing import Dict, List, Optional import math -from copy import deepcopy import torch import torch.nn as nn @@ -18,7 +17,7 @@ from torch import Tensor -class TransformerEncoderLayer(nn.Module): +class FSTransformerEncoderLayer(nn.Module): """Encoder layer implemented by fairseq. This version only removes the "args" parameter, no other changes @@ -165,128 +164,7 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): return x -class TransformerSentenceEncoderLayer(nn.Module): - """ - Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained - models. - """ - - def __init__( - self, - embedding_dim: int = 768, - ffn_embedding_dim: int = 3072, - num_attention_heads: int = 8, - dropout: float = 0.1, - attention_dropout: float = 0.1, - activation_dropout: float = 0.1, - activation_fn: str = "relu", - export: bool = False, - q_noise: float = 0.0, - qn_block_size: int = 8, - init_fn: Callable = None, - ) -> None: - super().__init__() - - if init_fn is not None: - init_fn() - - # Initialize parameters - self.embedding_dim = embedding_dim - self.dropout_module = FairseqDropout( - dropout, module_name=self.__class__.__name__ - ) - self.activation_dropout_module = FairseqDropout( - activation_dropout, module_name=self.__class__.__name__ - ) - - # Initialize blocks - self.activation_fn = utils.get_activation_fn(activation_fn) - self.self_attn = self.build_self_attention( - self.embedding_dim, - num_attention_heads, - dropout=attention_dropout, - self_attention=True, - q_noise=q_noise, - qn_block_size=qn_block_size, - ) - - # layer norm associated with the self attention layer - self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export) - - self.fc1 = self.build_fc1( - self.embedding_dim, - ffn_embedding_dim, - q_noise=q_noise, - qn_block_size=qn_block_size, - ) - self.fc2 = self.build_fc2( - ffn_embedding_dim, - self.embedding_dim, - q_noise=q_noise, - qn_block_size=qn_block_size, - ) - - # layer norm associated with the position wise feed-forward NN - self.final_layer_norm = LayerNorm(self.embedding_dim, export=export) - - def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): - return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) - - def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): - return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) - - def build_self_attention( - self, - embed_dim, - num_attention_heads, - dropout, - self_attention, - q_noise, - qn_block_size, - ): - return MultiheadAttention( - embed_dim, - num_attention_heads, - dropout=dropout, - self_attention=True, - q_noise=q_noise, - qn_block_size=qn_block_size, - ) - - def forward( - self, - x: torch.Tensor, - self_attn_mask: Optional[torch.Tensor] = None, - self_attn_padding_mask: Optional[torch.Tensor] = None, - ): - """ - LayerNorm is applied either before or after the self-attention/ffn - modules similar to the original Transformer implementation. 
- """ - residual = x - x, attn = self.self_attn( - query=x, - key=x, - value=x, - key_padding_mask=self_attn_padding_mask, - need_weights=False, - attn_mask=self_attn_mask, - ) - x = self.dropout_module(x) - x = residual + x - x = self.self_attn_layer_norm(x) - - residual = x - x = self.activation_fn(self.fc1(x)) - x = self.activation_dropout_module(x) - x = self.fc2(x) - x = self.dropout_module(x) - x = residual + x - x = self.final_layer_norm(x) - return x, attn - - -class TransformerDecoderLayer(nn.Module): +class FSTransformerDecoderLayer(nn.Module): """Decoder layer implemented by fairseq. This version only removes the "args" parameter, no other changes """ @@ -544,72 +422,6 @@ def make_generation_fast_(self, need_attn: bool = False, **kwargs): self.need_attn = need_attn -def generate_enc_layer(): - hidden_size = 1024 - intermediate_size = 1024 * 4 - heads = 16 - hidden_dropout_ratio = 0.0 - attn_dropout_ratio = 0.0 - activation_dropout_ratio = 0.0 - pre_layer_norm = True - layer = TransformerEncoderLayer( - hidden_size, - intermediate_size, - heads, - hidden_dropout_ratio, - attn_dropout_ratio, - activation_dropout_ratio, - pre_layer_norm, - activation_fn="relu", - ) - layer.to(torch.device("cuda:0"), dtype=torch.half) - return layer - - -def generate_dec_layer(): - hidden_size = 1024 - intermediate_size = 1024 * 4 - heads = 16 - hidden_dropout_ratio = 0.0 - attn_dropout_ratio = 0.0 - activation_dropout_ratio = 0.0 - pre_layer_norm = True - layer = TransformerDecoderLayer( - embed_dim=hidden_size, - ffn_embed_dim=intermediate_size, - nhead=heads, - encoder_embed_dim=hidden_size, - dropout=hidden_dropout_ratio, - attn_dropout=attn_dropout_ratio, - activation_dropout=activation_dropout_ratio, - normalize_before=pre_layer_norm, - activation_fn="relu", - ) - - layer.to(torch.device("cuda:0"), dtype=torch.half) - return layer - - -def generate_bert_enc_layer(): - hidden_size = 1024 - intermediate_size = 1024 * 4 - heads = 16 - hidden_dropout_ratio = 0.0 - attn_dropout_ratio = 0.0 - activation_dropout_ratio = 0.0 - layer = TransformerSentenceEncoderLayer( - hidden_size, - intermediate_size, - heads, - hidden_dropout_ratio, - attn_dropout_ratio, - activation_dropout_ratio, - activation_fn="gelu", - ) - layer.to(torch.device("cuda:0")) - return layer - - class SinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length. 
@@ -674,7 +486,7 @@ def forward( ).detach() -class TransformerEmbeddingLayer(nn.Module): +class FSTransformerEmbeddingLayer(nn.Module): def __init__( self, vocab_size, embedding_dim, max_seq_len, padding_idx, dropout, fp16 ): @@ -703,21 +515,97 @@ def forward(self, input): return x -def generate_emb_layer(ls_emb_config): - layer = TransformerEmbeddingLayer( - ls_emb_config.vocab_size, - ls_emb_config.embedding_dim, - ls_emb_config.max_seq_len, - ls_emb_config.padding_idx, - ls_emb_config.dropout, - ls_emb_config.fp16, - ) - dtype = torch.float16 if ls_emb_config.fp16 else torch.float32 - layer.to(torch.device("cuda:0"), dtype=dtype) - - return layer - +class FSCrossEntropyLayer(nn.Module): + def __init__(self, epsilon, ignore_index): + super().__init__() -if __name__ == "__main__": - generate_enc_layer() - generate_dec_layer() + self.epsilon = epsilon + self.ignore_index = ignore_index + + def label_smoothed_nll_loss(self, lprobs, target, reduce=True): + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + if self.ignore_index is not None: + pad_mask = target.eq(self.ignore_index) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + else: + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + if reduce: + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + eps_i = self.epsilon / (lprobs.size(-1) - 1) + loss = (1.0 - self.epsilon - eps_i) * nll_loss + eps_i * smooth_loss + return loss, nll_loss + + def forward(self, inputs, targets): + x = torch.nn.functional.log_softmax(inputs, dim=-1, dtype=torch.float32) + loss, nll_loss = self.label_smoothed_nll_loss(x, targets) + loss = loss.to(inputs) + nll_loss = nll_loss.to(inputs) + + return loss, nll_loss + + +def get_fairseq_enc_params(fairseq_layer): + initial_weights = [] + initial_biases = [] + + initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone()) + initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone()) + initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone()) + initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone()) + initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone()) + initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone()) + initial_weights.append(fairseq_layer.self_attn.out_proj.weight.detach().clone()) + initial_biases.append(fairseq_layer.self_attn.out_proj.bias.detach().clone()) + initial_weights.append(fairseq_layer.self_attn_layer_norm.weight.detach().clone()) + initial_biases.append(fairseq_layer.self_attn_layer_norm.bias.detach().clone()) + + initial_weights.append(fairseq_layer.fc1.weight.detach().clone()) + initial_biases.append(fairseq_layer.fc1.bias.detach().clone()) + initial_weights.append(fairseq_layer.fc2.weight.detach().clone()) + initial_biases.append(fairseq_layer.fc2.bias.detach().clone()) + initial_weights.append(fairseq_layer.final_layer_norm.weight.detach().clone()) + initial_biases.append(fairseq_layer.final_layer_norm.bias.detach().clone()) + return initial_weights, initial_biases + + +def get_fairseq_dec_params(fairseq_layer): + initial_weights = [] + initial_biases = [] + + initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone()) + initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone()) + 
initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone()) + initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone()) + initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone()) + initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone()) + initial_weights.append(fairseq_layer.self_attn.out_proj.weight.detach().clone()) + initial_biases.append(fairseq_layer.self_attn.out_proj.bias.detach().clone()) + initial_weights.append(fairseq_layer.self_attn_layer_norm.weight.detach().clone()) + initial_biases.append(fairseq_layer.self_attn_layer_norm.bias.detach().clone()) + + initial_weights.append(fairseq_layer.encodec_attn.q_proj.weight.detach().clone()) + initial_biases.append(fairseq_layer.encodec_attn.q_proj.bias.detach().clone()) + initial_weights.append(fairseq_layer.encodec_attn.k_proj.weight.detach().clone()) + initial_biases.append(fairseq_layer.encodec_attn.k_proj.bias.detach().clone()) + initial_weights.append(fairseq_layer.encodec_attn.v_proj.weight.detach().clone()) + initial_biases.append(fairseq_layer.encodec_attn.v_proj.bias.detach().clone()) + initial_weights.append(fairseq_layer.encodec_attn.out_proj.weight.detach().clone()) + initial_biases.append(fairseq_layer.encodec_attn.out_proj.bias.detach().clone()) + initial_weights.append( + fairseq_layer.encodec_attn_layer_norm.weight.detach().clone() + ) + initial_biases.append(fairseq_layer.encodec_attn_layer_norm.bias.detach().clone()) + + initial_weights.append(fairseq_layer.fc1.weight.detach().clone()) + initial_biases.append(fairseq_layer.fc1.bias.detach().clone()) + initial_weights.append(fairseq_layer.fc2.weight.detach().clone()) + initial_biases.append(fairseq_layer.fc2.bias.detach().clone()) + initial_weights.append(fairseq_layer.final_layer_norm.weight.detach().clone()) + initial_biases.append(fairseq_layer.final_layer_norm.bias.detach().clone()) + return initial_weights, initial_biases diff --git a/tests/gen_test_layers.py b/tests/gen_test_layers.py new file mode 100644 index 00000000..1ed51590 --- /dev/null +++ b/tests/gen_test_layers.py @@ -0,0 +1,251 @@ +import torch + +from tests.fairseq_layers import ( + FSTransformerEncoderLayer, + FSTransformerDecoderLayer, + FSTransformerEmbeddingLayer, + FSCrossEntropyLayer, + get_fairseq_enc_params, + get_fairseq_dec_params, +) +from lightseq.training import ( + LSTransformerEncoderLayer, + LSTransformerEmbeddingLayer, + LSCrossEntropyLayer, +) +from examples.training.fairseq.fs_modules.ls_fs_transformer_decoder_layer import ( + LSFSTransformerDecoderLayer, +) + + +###################### encoder layer ###################### +def gen_enc_layer(config): + def gen_ls_enc_layer(initial_weights=None, initial_biases=None): + enc_config = LSTransformerEncoderLayer.get_config( + max_batch_tokens=config.max_batch_tokens, + max_seq_len=config.max_seq_len, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + nhead=config.nhead, + attn_prob_dropout_ratio=config.attn_prob_dropout_ratio, + activation_dropout_ratio=config.activation_dropout_ratio, + hidden_dropout_ratio=config.hidden_dropout_ratio, + pre_layer_norm=config.pre_layer_norm, + fp16=config.fp16, + local_rank=config.local_rank, + activation_fn=config.activation_fn, + ) + layer = LSTransformerEncoderLayer(enc_config, initial_weights, initial_biases) + layer.to( + torch.device("cuda:{}".format(config.local_rank)), + dtype=(torch.half if config.fp16 else torch.float), + ) + layer.train() + return layer + + def gen_fs_enc_layer(): + 
layer = FSTransformerEncoderLayer( + embed_dim=config.hidden_size, + ffn_embed_dim=config.intermediate_size, + nhead=config.nhead, + dropout=config.hidden_dropout_ratio, + attn_dropout=config.attn_prob_dropout_ratio, + activation_dropout=config.activation_dropout_ratio, + normalize_before=config.pre_layer_norm, + activation_fn=config.activation_fn, + ) + layer.to( + torch.device("cuda:{}".format(config.local_rank)), + dtype=(torch.half if config.fp16 else torch.float), + ) + layer.train() + return layer + + custom_enc_layer_list = [] + fairseq_enc_layer_list = [] + + for _ in range(config.num_layers): + fairseq_enc_layer = gen_fs_enc_layer() + initial_enc_weights, initial_enc_biases = get_fairseq_enc_params( + fairseq_enc_layer + ) + custom_enc_layer = gen_ls_enc_layer(initial_enc_weights, initial_enc_biases) + custom_enc_layer_list.append(custom_enc_layer) + fairseq_enc_layer_list.append(fairseq_enc_layer) + + return torch.nn.ModuleList(custom_enc_layer_list), torch.nn.ModuleList( + fairseq_enc_layer_list + ) + + +###################### decoder layer ###################### +def gen_dec_layer(config): + def gen_ls_dec_layer(initial_weights=None, initial_biases=None): + dec_config = LSFSTransformerDecoderLayer.get_config( + max_batch_tokens=config.max_batch_tokens, + max_seq_len=config.max_seq_len, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + nhead=config.nhead, + attn_prob_dropout_ratio=config.attn_prob_dropout_ratio, + activation_dropout_ratio=config.activation_dropout_ratio, + hidden_dropout_ratio=config.hidden_dropout_ratio, + pre_layer_norm=config.pre_layer_norm, + fp16=config.fp16, + local_rank=config.local_rank, + nlayer=config.num_layers, + activation_fn=config.activation_fn, + ) + layer = LSFSTransformerDecoderLayer( + dec_config, + initial_weights, + initial_biases, + ) + layer.to( + torch.device("cuda:{}".format(config.local_rank)), + dtype=(torch.half if config.fp16 else torch.float), + ) + layer.train() + return layer + + def gen_fs_dec_layer(): + layer = FSTransformerDecoderLayer( + embed_dim=config.hidden_size, + ffn_embed_dim=config.intermediate_size, + nhead=config.nhead, + encoder_embed_dim=config.hidden_size, + dropout=config.hidden_dropout_ratio, + attn_dropout=config.attn_prob_dropout_ratio, + activation_dropout=config.activation_dropout_ratio, + normalize_before=config.pre_layer_norm, + activation_fn=config.activation_fn, + ) + layer.to( + torch.device("cuda:{}".format(config.local_rank)), + dtype=(torch.half if config.fp16 else torch.float), + ) + layer.train() + return layer + + custom_dec_layer_list = [] + fairseq_dec_layer_list = [] + _initial_dec_weights_list = [] + _initial_dec_biases_list = [] + _initial_encdec_attn_kvw_list = [] + _initial_encdec_attn_kvb_list = [] + + for _ in range(config.num_layers): + fairseq_dec_layer = gen_fs_dec_layer() + initial_dec_weights, initial_dec_biases = get_fairseq_dec_params( + fairseq_dec_layer + ) + fairseq_dec_layer_list.append(fairseq_dec_layer) + _initial_dec_weights_list.append(initial_dec_weights) + _initial_dec_biases_list.append(initial_dec_biases) + _initial_encdec_attn_kvw_list.append(initial_dec_weights[6]) + _initial_encdec_attn_kvw_list.append(initial_dec_weights[7]) + _initial_encdec_attn_kvb_list.append(initial_dec_biases[6]) + _initial_encdec_attn_kvb_list.append(initial_dec_biases[7]) + + _initial_encdec_attn_kvw = torch.cat(_initial_encdec_attn_kvw_list, dim=0) + _initial_encdec_attn_kvb = torch.cat(_initial_encdec_attn_kvb_list, dim=0) + for i in 
range(config.num_layers): + _initial_dec_weights_list[i].pop(7) + _initial_dec_weights_list[i].pop(6) + if i == 0: + _initial_dec_weights_list[i].append(_initial_encdec_attn_kvw) + _initial_dec_biases_list[i].pop(7) + _initial_dec_biases_list[i].pop(6) + if i == 0: + _initial_dec_biases_list[i].append(_initial_encdec_attn_kvb) + custom_dec_layer = gen_ls_dec_layer( + _initial_dec_weights_list[i], _initial_dec_biases_list[i] + ) + custom_dec_layer_list.append(custom_dec_layer) + + return torch.nn.ModuleList(custom_dec_layer_list), torch.nn.ModuleList( + fairseq_dec_layer_list + ) + + +###################### embedding layer ###################### +def gen_emb_layer(config): + def gen_ls_emb_layer(initial_embedding=None): + emb_config = LSTransformerEmbeddingLayer.get_config( + vocab_size=config.vocab_size, + embedding_dim=config.hidden_size, + max_batch_tokens=config.max_batch_tokens, + max_seq_len=config.max_seq_len, + padding_idx=config.padding_idx, + dropout=config.hidden_dropout_ratio, + fp16=config.fp16, + local_rank=config.local_rank, + ) + layer = LSTransformerEmbeddingLayer( + emb_config, + initial_embedding, + ) + layer.to( + torch.device("cuda:{}".format(config.local_rank)), + dtype=(torch.half if config.fp16 else torch.float), + ) + layer.train() + return layer + + def gen_fs_emb_layer(): + layer = FSTransformerEmbeddingLayer( + vocab_size=config.vocab_size, + embedding_dim=config.hidden_size, + max_seq_len=config.max_seq_len, + padding_idx=config.padding_idx, + dropout=config.hidden_dropout_ratio, + fp16=config.fp16, + ) + layer.to( + torch.device("cuda:{}".format(config.local_rank)), + dtype=(torch.half if config.fp16 else torch.float), + ) + layer.train() + return layer + + fairseq_emb_layer = gen_fs_emb_layer() + initial_embedding = fairseq_emb_layer.embeddings.weight.detach().clone() + custom_emb_layer = gen_ls_emb_layer(initial_embedding) + + return custom_emb_layer, fairseq_emb_layer + + +###################### cross entropy layer ###################### +def gen_ce_layer(config): + def gen_ls_ce_layer(): + ce_config = LSCrossEntropyLayer.get_config( + max_batch_tokens=config.max_batch_tokens, + padding_idx=config.padding_idx, + epsilon=config.label_smooth, + fp16=config.fp16, + local_rank=config.local_rank, + ) + layer = LSCrossEntropyLayer(ce_config) + layer.to( + torch.device("cuda:{}".format(config.local_rank)), + dtype=(torch.half if config.fp16 else torch.float), + ) + layer.train() + return layer + + def gen_fs_ce_layer(): + layer = FSCrossEntropyLayer( + epsilon=config.label_smooth, + ignore_index=config.padding_idx, + ) + layer.to( + torch.device("cuda:{}".format(config.local_rank)), + dtype=(torch.half if config.fp16 else torch.float), + ) + layer.train() + return layer + + fairseq_ce_layer = gen_fs_ce_layer() + custom_ce_layer = gen_ls_ce_layer() + + return custom_ce_layer, fairseq_ce_layer diff --git a/tests/test_ls_kernels.py b/tests/test_ls_kernels.py index 6600e9df..293e24f8 100644 --- a/tests/test_ls_kernels.py +++ b/tests/test_ls_kernels.py @@ -14,8 +14,7 @@ @kt.case() def test_launch_bias_add_transform_20314(): batch_size, seq_len = kt.bs_sl() - hidden_dim = kt.hidden_dim - nhead = kt.nhead + hidden_dim, nhead = kt.h_nh head_dim = int(hidden_dim / nhead) count = random.randint(1, 20) print( @@ -56,8 +55,7 @@ def baseline(): @kt.case() def test_launch_transform_0213(): batch_size, seq_len = kt.bs_sl() - hidden_dim = kt.hidden_dim - nhead = kt.nhead + hidden_dim, nhead = kt.h_nh head_dim = int(hidden_dim / nhead) print( "(batch_size, seq_len, hidden_dim, 
nhead): " @@ -93,8 +91,7 @@ def baseline(): @kt.case() def test_launch_transform4d_0213(): batch_size, seq_len = kt.bs_sl() - hidden_dim = kt.hidden_dim - nhead = kt.nhead + hidden_dim, nhead = kt.h_nh head_dim = int(hidden_dim / nhead) trans_count = random.choice([1, 3]) print( @@ -128,7 +125,7 @@ def baseline(): return custom, baseline -@kt.case(atol=1e-3, rtol=1e-3, ntest=20) +@kt.case(rtol=1e-3, atol=1e-3) def test_launch_attn_softmax(): batch_size, from_len = kt.bs_sl() is_dec_self_attn = random.choice([True, False]) @@ -145,7 +142,7 @@ def test_launch_attn_softmax(): beam_size = random.choice([3, 4, 5]) batch_size *= beam_size - nhead = kt.nhead + _, nhead = kt.h_nh print( "(batch_size, nhead, from_len, to_len, is_dec_self_attn): " f"({batch_size}, {nhead}, {from_len}, {to_len}, {is_dec_self_attn})" @@ -199,9 +196,9 @@ def baseline(): return custom, baseline -@kt.case(atol=1e-2, rtol=1e-3) +@kt.case(rtol=1e-3, atol=1e-2) def test_launch_attn_softmax_bw(): - nhead = kt.nhead + _, nhead = kt.h_nh batch_size, from_len = kt.bs_sl() _, to_len = kt.bs_sl(batch_size) print( @@ -243,7 +240,7 @@ def baseline(): @kt.case() def test_launch_fused_add2(): batch_size, seq_len = kt.bs_sl() - hidden_dim = kt.hidden_dim + hidden_dim, _ = kt.h_nh print( "(batch_size, seq_len, hidden_dim): " f"({batch_size}, {seq_len}, {hidden_dim})" ) @@ -275,11 +272,11 @@ def baseline(): return custom, baseline -@kt.case(atol=1e-2, rtol=1e-3) +@kt.case(rtol=1e-3, atol=1e-2) def test_launch_layer_norm(): batch_size, seq_len = kt.bs_sl() bsz_seq = batch_size * seq_len - hidden_dim = kt.hidden_dim + hidden_dim, _ = kt.h_nh with_mean = random.choice([True, False]) print( "(batch_token_num, hidden_dim, with_mean): " @@ -316,11 +313,11 @@ def baseline(): return custom, baseline -@kt.case(atol=1e-3, rtol=1e-2) +@kt.case(rtol=1e-2, atol=1e-3) def test_launch_ln_bw(): batch_size, seq_len = kt.bs_sl() bsz_seq = batch_size * seq_len - hidden_dim = kt.hidden_dim + hidden_dim, _ = kt.h_nh with_mean = random.choice([True, False]) fuse_add = random.choice([True, False]) print( @@ -396,10 +393,10 @@ def baseline(): return custom, baseline -@kt.case() +@kt.case(rtol=1e-3, atol=1e-4) def test_launch_ffn_bias_bwd(): batch_size, seq_len = kt.bs_sl() - hidden_dim = kt.hidden_dim + hidden_dim, _ = kt.h_nh coef = random.randint(1, 4) print("(rows, cols): " f"({batch_size*seq_len}, {coef*hidden_dim})") @@ -434,8 +431,7 @@ def baseline(): @kt.case() def test_launch_concat3_dim1(): batch_size, seq_len = kt.bs_sl() - hidden_dim = kt.hidden_dim - nhead = kt.nhead + hidden_dim, nhead = kt.h_nh head_dim = int(hidden_dim / nhead) assert seq_len > 1 sl1 = random.randint(1, seq_len - 1) @@ -465,10 +461,10 @@ def baseline(): return custom, baseline -@kt.case(dtypes=[torch.float32]) +@kt.case(dtypes=[torch.float]) def test_adam(): batch_size, seq_len = kt.bs_sl() - hidden_dim = kt.hidden_dim + hidden_dim, _ = kt.h_nh cus_p = kt.rand((batch_size, seq_len, hidden_dim * 32)) cus_out_p = kt.rand((batch_size, seq_len, hidden_dim * 32)) cus_exp_avg = kt.rand((batch_size, seq_len, hidden_dim * 32)) @@ -524,20 +520,15 @@ def baseline(): return custom, baseline -@kt.case(dtypes=[torch.float, torch.half], ntest=5, atol=1e-2, rtol=1e-2) +@kt.case(rtol=1e-2, atol=1e-2) def test_launch_dropout_relu_bias(): batch_size, seq_len = kt.bs_sl() - hidden_dim = kt.hidden_dim + hidden_dim, _ = kt.h_nh print("test shape:", (batch_size, seq_len, hidden_dim)) test_input = kt.rand((batch_size, seq_len, hidden_dim)) test_bias = kt.rand((hidden_dim,)) - test_out_base = 
kt.rand((batch_size, seq_len, hidden_dim)) test_out_cus = kt.rand((batch_size, seq_len, hidden_dim)) - test_mask_base = torch.rand((batch_size, seq_len, hidden_dim)).to( - dtype=torch.uint8, - device="cuda:0", - ) test_mask_cus = torch.rand((batch_size, seq_len, hidden_dim)).to( dtype=torch.uint8, device="cuda:0" ) @@ -569,21 +560,15 @@ def baseline(): return custom, baseline -@kt.case(dtypes=[torch.float, torch.half], ntest=5, atol=1e-2, rtol=1e-2) +@kt.case(rtol=1e-2, atol=1e-2) def test_launch_dropout_gelu_bias(): batch_size, seq_len = kt.bs_sl() - hidden_dim = kt.hidden_dim + hidden_dim, _ = kt.h_nh print("test shape:", (batch_size, seq_len, hidden_dim)) test_input = kt.rand((batch_size, seq_len, hidden_dim)) test_bias = kt.rand((hidden_dim,)) - test_out_base = kt.rand((batch_size, seq_len, hidden_dim)) test_out_cus = kt.rand((batch_size, seq_len, hidden_dim)) - temp = kt.rand((batch_size, seq_len, hidden_dim)) - test_mask_base = torch.rand((batch_size, seq_len, hidden_dim)).to( - dtype=torch.uint8, - device="cuda:0", - ) test_mask_cus = torch.rand((batch_size, seq_len, hidden_dim)).to( dtype=torch.uint8, device="cuda:0" ) @@ -615,10 +600,11 @@ def baseline(): return custom, baseline -@kt.case(dtypes=[torch.float, torch.half], ntest=5, atol=1e-2, rtol=1e-2) +@kt.case(rtol=1e-2, atol=1e-2) def test_launch_dropout_relu_bias_bwd(): batch_size, seq_len = kt.bs_sl() - hidden_dim = kt.hidden_dim * 4 + hidden_dim, _ = kt.h_nh + hidden_dim *= 4 print("test shape:", (batch_size, seq_len, hidden_dim)) test_input = kt.rand((batch_size, seq_len, hidden_dim)) @@ -661,10 +647,11 @@ def baseline(): return custom, baseline -@kt.case(dtypes=[torch.float, torch.half], ntest=5, atol=1e-2, rtol=1e-2) +@kt.case(rtol=1e-2, atol=1e-2) def test_launch_dropout_gelu_bias_bwd(): batch_size, seq_len = kt.bs_sl() - hidden_dim = kt.hidden_dim * 4 + hidden_dim, _ = kt.h_nh + hidden_dim *= 4 print("test shape:", (batch_size, seq_len, hidden_dim)) test_input = kt.rand((batch_size, seq_len, hidden_dim)) @@ -719,22 +706,22 @@ def baseline(): if __name__ == "__main__": - kt.init(device="cuda:0", nhead=16) - kernel_list = [ - "test_launch_transform_0213", - "test_launch_bias_add_transform_20314", - "test_launch_transform4d_0213", - "test_launch_fused_add2", - "test_launch_ffn_bias_bwd", - "test_launch_attn_softmax", - "test_launch_attn_softmax_bw", - "test_launch_layer_norm", - "test_launch_ln_bw", - "test_launch_concat3_dim1", - "test_adam", - "test_launch_dropout_gelu_bias", - "test_launch_dropout_relu_bias", - "test_launch_dropout_relu_bias_bwd", - "test_launch_dropout_gelu_bias_bwd", - ] - kt.run(kernel_list) + kt.run( + [ + "test_launch_transform_0213", + "test_launch_bias_add_transform_20314", + "test_launch_transform4d_0213", + "test_launch_fused_add2", + "test_launch_ffn_bias_bwd", + "test_launch_attn_softmax", + "test_launch_attn_softmax_bw", + "test_launch_layer_norm", + "test_launch_ln_bw", + "test_launch_concat3_dim1", + "test_adam", + "test_launch_dropout_gelu_bias", + "test_launch_dropout_relu_bias", + "test_launch_dropout_relu_bias_bwd", + "test_launch_dropout_gelu_bias_bwd", + ] + ) diff --git a/tests/test_ls_ops.py b/tests/test_ls_ops.py index 7660bf43..20013b53 100644 --- a/tests/test_ls_ops.py +++ b/tests/test_ls_ops.py @@ -1,274 +1,55 @@ +import multiprocessing as mp import random -import math -from copy import deepcopy -from dataclasses import dataclass import torch -import torch.nn as nn - from tests.util import ( TestDecorator, - get_fairseq_enc_params, - get_fairseq_dec_params, - max_batch_tokens, 
- max_seq_len, split_custom_layer_grad, copy_grad_from_paras, ) - -from tests import fairseq_layers -from lightseq.training.ops.pytorch.transformer_encoder_layer import ( - LSTransformerEncoderLayer, -) -from lightseq.training.ops.pytorch.transformer_embedding_layer import ( - LSTransformerEmbeddingLayer, -) -from lightseq.training.ops.pytorch.cross_entropy_layer import LSCrossEntropyLayer -from examples.training.fairseq.fs_modules.ls_fs_transformer_decoder_layer import ( - LSFSTransformerDecoderLayer, -) - -kt = TestDecorator() - -num_layers = 1 - -###################### encoding layer ###################### - - -def generate_enc_layer(initial_weights=None, initial_biases=None): - config = LSTransformerEncoderLayer.get_config( - max_batch_tokens=max_batch_tokens, - max_seq_len=max_seq_len, - hidden_size=1024, - intermediate_size=4096, - nhead=16, - attn_prob_dropout_ratio=0.0, - activation_dropout_ratio=0.0, - hidden_dropout_ratio=0.0, - pre_layer_norm=True, - fp16=True, - local_rank=0, - activation_fn="relu", - ) - layer = LSTransformerEncoderLayer(config, initial_weights, initial_biases) - layer.to(torch.device("cuda:0"), dtype=torch.half) - return layer - - -custom_enc_layer_list = [] -fairseq_enc_layer_list = [] - - -def gen_enc_layer_pair(): - fairseq_enc_layer = fairseq_layers.generate_enc_layer() - fairseq_enc_layer.train() - initial_enc_weights, initial_enc_biases = get_fairseq_enc_params(fairseq_enc_layer) - custom_enc_layer = generate_enc_layer(initial_enc_weights, initial_enc_biases) - custom_enc_layer.train() - return fairseq_enc_layer, custom_enc_layer - - -for _ in range(num_layers): - fairseq_enc_layer, custom_enc_layer = gen_enc_layer_pair() - custom_enc_layer_list.append(custom_enc_layer) - fairseq_enc_layer_list.append(fairseq_enc_layer) - - -###################### bert encoder layer ###################### - - -def get_test_bert_encoder(num_layers): - def ls_generate_bert_enc_layer(initial_weights=None, initial_biases=None): - config = LSTransformerEncoderLayer.get_config( - max_batch_tokens=max_batch_tokens, - max_seq_len=max_seq_len, - hidden_size=1024, - intermediate_size=4096, - nhead=16, - attn_prob_dropout_ratio=0.0, - activation_dropout_ratio=0.0, - hidden_dropout_ratio=0.0, - pre_layer_norm=False, - fp16=True, - local_rank=0, - activation_fn="gelu", - ) - layer = LSTransformerEncoderLayer(config, initial_weights, initial_biases) - layer.to(torch.device("cuda:0")) - return layer - - def gen_bert_enc_layer_pair(): - fairseq_enc_layer = fairseq_layers.generate_bert_enc_layer() - fairseq_enc_layer.train() - initial_enc_weights, initial_enc_biases = get_fairseq_enc_params( - fairseq_enc_layer - ) - custom_enc_layer = ls_generate_bert_enc_layer( - initial_enc_weights, initial_enc_biases - ) - custom_enc_layer.train() - return fairseq_enc_layer, custom_enc_layer - - custom_bert_enc_layer_list = [] - fairseq_bert_enc_layer_list = [] - for _ in range(num_layers): - fairseq_enc_layer, custom_enc_layer = gen_bert_enc_layer_pair() - custom_bert_enc_layer_list.append(custom_enc_layer) - fairseq_bert_enc_layer_list.append(fairseq_enc_layer) - - return torch.nn.ModuleList(custom_bert_enc_layer_list), torch.nn.ModuleList( - fairseq_bert_enc_layer_list - ) - - -###################### decoding layer ###################### - - -def generate_dec_layer(initial_weights=None, initial_biases=None): - config = LSFSTransformerDecoderLayer.get_config( - max_batch_tokens=max_batch_tokens, - max_seq_len=max_seq_len, - hidden_size=1024, - intermediate_size=4096, - nhead=16, - 
attn_prob_dropout_ratio=0.0, - activation_dropout_ratio=0.0, - hidden_dropout_ratio=0.0, - pre_layer_norm=True, - fp16=True, - local_rank=0, - nlayer=num_layers, - activation_fn="relu", - ) - layer = LSFSTransformerDecoderLayer( - config, - initial_weights, - initial_biases, - ) - layer.to(torch.device("cuda:0"), dtype=torch.half) - return layer - - -custom_dec_layer_list = [] -fairseq_dec_layer_list = [] -_initial_dec_weights_list = [] -_initial_dec_biases_list = [] -_initial_encdec_attn_kvw_list = [] -_initial_encdec_attn_kvb_list = [] - -for _ in range(num_layers): - fairseq_dec_layer = fairseq_layers.generate_dec_layer() - fairseq_dec_layer.train() - initial_dec_weights, initial_dec_biases = get_fairseq_dec_params(fairseq_dec_layer) - fairseq_dec_layer_list.append(fairseq_dec_layer) - _initial_dec_weights_list.append(initial_dec_weights) - _initial_dec_biases_list.append(initial_dec_biases) - _initial_encdec_attn_kvw_list.append(initial_dec_weights[6]) - _initial_encdec_attn_kvw_list.append(initial_dec_weights[7]) - _initial_encdec_attn_kvb_list.append(initial_dec_biases[6]) - _initial_encdec_attn_kvb_list.append(initial_dec_biases[7]) - -_initial_encdec_attn_kvw = torch.cat(_initial_encdec_attn_kvw_list, dim=0) -_initial_encdec_attn_kvb = torch.cat(_initial_encdec_attn_kvb_list, dim=0) -for i in range(num_layers): - _initial_dec_weights_list[i].pop(7) - _initial_dec_weights_list[i].pop(6) - if i == 0: - _initial_dec_weights_list[i].append(_initial_encdec_attn_kvw) - _initial_dec_biases_list[i].pop(7) - _initial_dec_biases_list[i].pop(6) - if i == 0: - _initial_dec_biases_list[i].append(_initial_encdec_attn_kvb) - custom_dec_layer = generate_dec_layer( - _initial_dec_weights_list[i], _initial_dec_biases_list[i] - ) - custom_dec_layer.train() - custom_dec_layer_list.append(custom_dec_layer) - -# ###################### embedding layer ###################### - -ls_emb_config_fp16 = LSTransformerEmbeddingLayer.get_config( - vocab_size=40480, - embedding_dim=1024, - max_batch_tokens=9216, - max_seq_len=256, - padding_idx=2, - dropout=0.0, - fp16=True, - local_rank=0, -) -ls_emb_config_fp32 = deepcopy(ls_emb_config_fp16) -ls_emb_config_fp32.fp16 = False - -fs_emb_layer_fp32 = fairseq_layers.generate_emb_layer(ls_emb_config_fp32) -fs_emb_layer_fp16 = fairseq_layers.generate_emb_layer(ls_emb_config_fp16) -fs_emb_layer_fp32.train() -fs_emb_layer_fp16.train() - - -def generate_emb_layer(config, initial_weights=None): - custom_layer = LSTransformerEmbeddingLayer(config, initial_weights) - dtype = torch.float16 if config.fp16 else torch.float32 - custom_layer.to(torch.device("cuda:0"), dtype=dtype) - return custom_layer - - -custom_emb_layer_fp32 = generate_emb_layer( - ls_emb_config_fp32, fs_emb_layer_fp32.embeddings.weight.detach().clone() -) -custom_emb_layer_fp16 = generate_emb_layer( - ls_emb_config_fp16, fs_emb_layer_fp16.embeddings.weight.detach().clone() -) -custom_emb_layer_fp32.train() -custom_emb_layer_fp16.train() - -###################### cross entropy layer ###################### - -ce_config_fp16 = LSCrossEntropyLayer.get_config( - max_batch_tokens=9216, - padding_idx=2, - epsilon=0.1, - fp16=True, - local_rank=0, +from tests.gen_test_layers import ( + gen_enc_layer, + gen_dec_layer, + gen_emb_layer, + gen_ce_layer, ) -ce_config_fp32 = deepcopy(ce_config_fp16) -ce_config_fp32.fp16 = False -def generate_cross_entropy_layer(config): - dtype = torch.float16 if config.fp16 else torch.float32 - custom_layer = LSCrossEntropyLayer(config) - custom_layer.to(torch.device("cuda:0"), 
dtype=dtype) - return custom_layer +kt = TestDecorator() +config = kt.generate_config() +kt.dtypes = [torch.half if config.fp16 else torch.float] -custom_cross_entropy_layer_fp32 = generate_cross_entropy_layer(ce_config_fp32) -custom_cross_entropy_layer_fp16 = generate_cross_entropy_layer(ce_config_fp16) -custom_cross_entropy_layer_fp32.train() -custom_cross_entropy_layer_fp16.train() +custom_enc_layers, fairseq_enc_layers = gen_enc_layer(config) +custom_dec_layers, fairseq_dec_layers = gen_dec_layer(config) +custom_emb_layer, fairseq_emb_layer = gen_emb_layer(config) +custom_ce_layer, fairseq_ce_layer = gen_ce_layer(config) -@kt.case(dtypes=[torch.half], rtol=1e-3, atol=1e-2, ntest=10) +@kt.case(rtol=1e-3, atol=1e-2) def test_encoder_layer_forward(): batch_size, seq_len = kt.bs_sl() - print(f"(batch_size, seq_len): ({batch_size}, {seq_len})") + hidden_size = config.hidden_size - hidden_states = kt.rand((batch_size, seq_len, 1024)) + print( + f"(batch_size, seq_len, hidden_size): ({batch_size}, {seq_len}, {hidden_size})" + ) + hidden_states = kt.rand((batch_size, seq_len, hidden_size)) self_attn_padding_mask = kt.attn_mask(batch_size, seq_len, dtype=torch.bool) def custom(): res = hidden_states.clone() - for i in range(num_layers): - res = custom_enc_layer_list[i](res, self_attn_padding_mask) + for layer in custom_enc_layers: + res = layer(res, self_attn_padding_mask) return [ res.contiguous().detach(), ] def baseline(): res = hidden_states.transpose(0, 1).contiguous().clone() - for i in range(num_layers): - res = fairseq_enc_layer_list[i](res, self_attn_padding_mask) + for layer in fairseq_enc_layers: + res = layer(res, self_attn_padding_mask) return [ res.transpose(0, 1).contiguous().detach(), ] @@ -276,170 +57,43 @@ def baseline(): return custom, baseline -@kt.case(dtypes=[torch.half], rtol=1e-2, atol=1e-2, ntest=10) +@kt.case(rtol=1e-2, atol=1e-2) def test_encoder_layer_backward(): batch_size, seq_len = kt.bs_sl() - print(f"(batch_size, seq_len): ({batch_size}, {seq_len})") - hidden_size = 1024 - shs = hidden_size * hidden_size - - hidden_states = kt.rand((batch_size, seq_len, hidden_size)) - self_attn_padding_mask = kt.attn_mask(batch_size, seq_len, dtype=torch.bool) - loss_data = torch.randn(1, dtype=hidden_states.dtype).sum() - - def custom(): - for i in range(num_layers): - custom_enc_layer_list[i].zero_grad() - res = hidden_states.clone() - for i in range(num_layers): - res = custom_enc_layer_list[i](res, self_attn_padding_mask) - custom_loss = (res / 1000).sum() - custom_loss.data.copy_(loss_data) - custom_loss.backward() - grad_list = [] - for i in range(num_layers - 1, -1, -1): - """ - attn_qkvw, attn_qkvb, attn_ow, attn_ob, attn_nw, attn_nb, - inter_w, inter_b, output_w, output_b, ffn_nw, ffn_nb - """ - grads = split_custom_layer_grad(custom_enc_layer_list[i]) - grad_list.extend( - [ - grads[8], - grads[9], - grads[6], - grads[7], - grads[10], - grads[11], - grads[2], - grads[3], - grads[0][:shs], - grads[1][:hidden_size], - grads[0][shs : shs * 2], - grads[1][hidden_size : hidden_size * 2], - grads[0][shs * 2 : shs * 3], - grads[1][hidden_size * 2 : hidden_size * 3], - grads[4], - grads[5], - ] - ) - return grad_list - - def baseline(): - for i in range(num_layers): - fairseq_enc_layer_list[i].zero_grad() - res = hidden_states.transpose(0, 1).clone() - for i in range(num_layers): - res = fairseq_enc_layer_list[i](res, self_attn_padding_mask) - fairseq_loss = (res / 1000).sum() - fairseq_loss.data.copy_(loss_data) - fairseq_loss.backward() - grad_list = [] - for i in 
range(num_layers - 1, -1, -1): - curl = fairseq_enc_layer_list[i] - cur_grads = copy_grad_from_paras( - [ - curl.fc2.weight, - curl.fc2.bias, - curl.fc1.weight, - curl.fc1.bias, - curl.final_layer_norm.weight, - curl.final_layer_norm.bias, - curl.self_attn.out_proj.weight, - curl.self_attn.out_proj.bias, - curl.self_attn.q_proj.weight, - curl.self_attn.q_proj.bias, - curl.self_attn.k_proj.weight, - curl.self_attn.k_proj.bias, - curl.self_attn.v_proj.weight, - curl.self_attn.v_proj.bias, - curl.self_attn_layer_norm.weight, - curl.self_attn_layer_norm.bias, - ] - ) - grad_list.extend(cur_grads) - return grad_list - - return custom, baseline - - -@kt.case(dtypes=[torch.float, torch.half], rtol=1e-3, atol=1e-2, ntest=10) -def test_bert_encoder_layer_forward(): - batch_size, seq_len = kt.bs_sl() - print(f"(batch_size, seq_len): ({batch_size}, {seq_len})") - - hidden_states = kt.rand((batch_size, seq_len, 1024)) - self_attn_padding_mask = kt.attn_mask(batch_size, seq_len, dtype=torch.bool) - num_layers = 1 - - custom_bert_enc_layer_list, fairseq_bert_enc_layer_list = get_test_bert_encoder( - num_layers + hidden_size = config.hidden_size + print( + f"(batch_size, seq_len, hidden_size): ({batch_size}, {seq_len}, {hidden_size})" ) - custom_bert_enc_layer_list = custom_bert_enc_layer_list.to(kt.dtype) - fairseq_bert_enc_layer_list = fairseq_bert_enc_layer_list.to(kt.dtype) - - def custom(): - res = hidden_states.clone() - for i in range(num_layers): - res = custom_bert_enc_layer_list[i](res, self_attn_padding_mask) - return [ - res.contiguous().detach(), - ] - - def baseline(): - res = hidden_states.transpose(0, 1).contiguous().clone() - for i in range(num_layers): - res = fairseq_bert_enc_layer_list[i]( - res, self_attn_padding_mask=self_attn_padding_mask - )[0] - return [ - res.transpose(0, 1).contiguous().detach(), - ] - - del custom_bert_enc_layer_list, fairseq_bert_enc_layer_list - return custom, baseline - - -@kt.case(dtypes=[torch.float, torch.half], rtol=1e-2, atol=1e-2, ntest=10) -def test_bert_encoder_layer_backward(): - batch_size, seq_len = kt.bs_sl() - print(f"(batch_size, seq_len): ({batch_size}, {seq_len})") - hidden_size = 1024 shs = hidden_size * hidden_size - hidden_states = kt.rand((batch_size, seq_len, hidden_size)) self_attn_padding_mask = kt.attn_mask(batch_size, seq_len, dtype=torch.bool) - num_layers = 1 - custom_bert_enc_layer_list, fairseq_bert_enc_layer_list = get_test_bert_encoder( - num_layers - ) - custom_bert_enc_layer_list = custom_bert_enc_layer_list.to(kt.dtype).train() - fairseq_bert_enc_layer_list = fairseq_bert_enc_layer_list.to(kt.dtype).train() - - cus_x = hidden_states.clone() - for i in range(num_layers): - cus_x = custom_bert_enc_layer_list[i](cus_x, self_attn_padding_mask) - custom_loss = (cus_x / 1000).sum() + # custom fw + custom_enc_layers.zero_grad() + res = hidden_states.clone() + for layer in custom_enc_layers: + res = layer(res, self_attn_padding_mask) + custom_loss = (res / 1000).sum() - base_x = hidden_states.transpose(0, 1).clone() - for i in range(num_layers): - base_x = fairseq_bert_enc_layer_list[i]( - base_x, self_attn_padding_mask=self_attn_padding_mask - )[0] - fairseq_loss = (base_x.transpose(0, 1) / 1000).sum() + # fairseq fw + fairseq_enc_layers.zero_grad() + res = hidden_states.transpose(0, 1).clone() + for layer in fairseq_enc_layers: + res = layer(res, self_attn_padding_mask) + fairseq_loss = (res / 1000).sum() def custom(): - custom_bert_enc_layer_list.zero_grad() + custom_enc_layers.zero_grad() custom_loss.backward(retain_graph=True) 
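# Explanatory note, not part of the patch: the LightSeq encoder layer stores the Q/K/V
# projections fused in a single parameter ("attn_qkvw"/"attn_qkvb" in the layout listed
# below), so split_custom_layer_grad() returns one flat gradient for all three. With
# shs = hidden_size * hidden_size, the slices taken below line up with fairseq's separate
# projections:
#     grads[0][:shs]              -> self_attn.q_proj.weight.grad
#     grads[0][shs:2 * shs]       -> self_attn.k_proj.weight.grad
#     grads[0][2 * shs:3 * shs]   -> self_attn.v_proj.weight.grad
# and grads[1] is sliced in chunks of hidden_size the same way for the biases.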
+ grad_list = [] - for i in range(num_layers - 1, -1, -1): + for i in range(config.num_layers - 1, -1, -1): """ attn_qkvw, attn_qkvb, attn_ow, attn_ob, attn_nw, attn_nb, inter_w, inter_b, output_w, output_b, ffn_nw, ffn_nb """ - grads = split_custom_layer_grad(custom_bert_enc_layer_list[i]) + grads = split_custom_layer_grad(custom_enc_layers[i]) grad_list.extend( [ grads[8], @@ -463,11 +117,12 @@ def custom(): return grad_list def baseline(): - fairseq_bert_enc_layer_list.zero_grad() + fairseq_enc_layers.zero_grad() fairseq_loss.backward(retain_graph=True) + grad_list = [] - for i in range(num_layers - 1, -1, -1): - curl = fairseq_bert_enc_layer_list[i] + for i in range(config.num_layers - 1, -1, -1): + curl = fairseq_enc_layers[i] cur_grads = copy_grad_from_paras( [ curl.fc2.weight, @@ -491,28 +146,30 @@ def baseline(): grad_list.extend(cur_grads) return grad_list - del custom_bert_enc_layer_list, fairseq_bert_enc_layer_list return custom, baseline -@kt.case(dtypes=[torch.half], rtol=1e-3, atol=1e-2, ntest=10) +@kt.case(rtol=1e-3, atol=1e-2) def test_decoder_layer_forward(): batch_size, enc_seq_len = kt.bs_sl() _, dec_seq_len = kt.bs_sl(batch_size) + hidden_size = config.hidden_size + print( - f"(batch_size, enc_seq_len, dec_seq_len): ({batch_size}, {enc_seq_len}, {dec_seq_len})" + f"(batch_size, enc_seq_len, dec_seq_len, hidden_size): " + f"({batch_size}, {enc_seq_len}, {dec_seq_len}, {hidden_size})" ) - hidden_states = kt.rand((batch_size, dec_seq_len, 1024)) - encoder_out = kt.rand((enc_seq_len, batch_size, 1024)) + hidden_states = kt.rand((batch_size, dec_seq_len, hidden_size)) + encoder_out = kt.rand((enc_seq_len, batch_size, hidden_size)) incremental_state = None encoder_padding_mask = kt.attn_mask(batch_size, enc_seq_len, dtype=torch.bool) self_attn_mask = kt.dec_self_attn_mask(dec_seq_len) * -1e8 def custom(): res = hidden_states.clone() - for i in range(num_layers): - res, _, _ = custom_dec_layer_list[i]( + for layer in custom_dec_layers: + res, _, _ = layer( res, encoder_out=encoder_out, encoder_padding_mask=encoder_padding_mask, @@ -524,8 +181,8 @@ def custom(): def baseline(): res = hidden_states.transpose(0, 1).clone() - for i in range(num_layers): - res, _, _ = fairseq_dec_layer_list[i]( + for layer in fairseq_dec_layers: + res, _, _ = layer( res, encoder_out=encoder_out, encoder_padding_mask=encoder_padding_mask, @@ -539,45 +196,45 @@ def baseline(): return custom, baseline -@kt.case(dtypes=[torch.half], rtol=1e-2, atol=1e-2, ntest=10) +@kt.case(rtol=1e-2, atol=1e-2) def test_decoder_layer_backward(): batch_size, enc_seq_len = kt.bs_sl() _, dec_seq_len = kt.bs_sl(batch_size) + hidden_size = config.hidden_size print( - f"(batch_size, enc_seq_len, dec_seq_len): ({batch_size}, {enc_seq_len}, {dec_seq_len})" + f"(batch_size, enc_seq_len, dec_seq_len, hidden_size): " + f"({batch_size}, {enc_seq_len}, {dec_seq_len}, {hidden_size})" ) - hidden_size = 1024 + shs = hidden_size * hidden_size hidden_states = kt.rand((batch_size, dec_seq_len, hidden_size)) encoder_out = kt.rand((enc_seq_len, batch_size, hidden_size)) incremental_state = None encoder_padding_mask = kt.attn_mask(batch_size, enc_seq_len, dtype=torch.bool) self_attn_mask = kt.dec_self_attn_mask(dec_seq_len) * -1e8 - loss_data = torch.randn(1, dtype=hidden_states.dtype).sum() def custom(): - for i in range(num_layers): - custom_dec_layer_list[i].zero_grad() - res = hidden_states.clone() - for i in range(num_layers): - res, _, _ = custom_dec_layer_list[i]( - res, - encoder_out=encoder_out, + custom_dec_layers.zero_grad() + 
cus_res = hidden_states.clone() + cus_encoder_out = encoder_out.clone() + for layer in custom_dec_layers: + cus_res, _, _ = layer( + cus_res, + encoder_out=cus_encoder_out, encoder_padding_mask=encoder_padding_mask, incremental_state=incremental_state, ) - custom_loss = (res / 1000).sum() - custom_loss.data.copy_(loss_data) + custom_loss = (cus_res / 1000).sum() custom_loss.backward() grad_list = [] - for i in range(num_layers - 1, -1, -1): + for i in range(config.num_layers - 1, -1, -1): """ 0 attn_qkvw, attn_qkvb, attn_ow, attn_ob, attn_nw, attn_nb, 6 encdec_attn_qw, encdec_attn_qb, encdec_attn_ow, encdec_attn_ob, encdec_attn_nw, encdec_attn_nb, 12 inter_w, inter_b, output_w, output_b, ffn_nw, ffn_nb 18 encdec_attn_kvw, encdec_attn_kvb, """ - grads = split_custom_layer_grad(custom_dec_layer_list[i]) + grads = split_custom_layer_grad(custom_dec_layers[i]) grad_list.extend( [ grads[14], @@ -618,121 +275,80 @@ def custom(): return grad_list def baseline(): - for i in range(num_layers): - fairseq_dec_layer_list[i].zero_grad() - res = hidden_states.transpose(0, 1).clone() - for i in range(num_layers): - res, _, _ = fairseq_dec_layer_list[i]( - res, - encoder_out=encoder_out, + fairseq_dec_layers.zero_grad() + base_res = hidden_states.transpose(0, 1).clone() + base_encoder_out = encoder_out.clone() + for layer in fairseq_dec_layers: + base_res, _, _ = layer( + base_res, + encoder_out=base_encoder_out, encoder_padding_mask=encoder_padding_mask, self_attn_mask=self_attn_mask, incremental_state=incremental_state, ) - fairseq_loss = (res / 1000).sum() - fairseq_loss.data.copy_(loss_data) + fairseq_loss = (base_res / 1000).sum() fairseq_loss.backward() + grad_list = [] - for i in range(num_layers - 1, -1, -1): - grad_list.extend( + for i in range(config.num_layers - 1, -1, -1): + curl = fairseq_dec_layers[i] + cur_grads = copy_grad_from_paras( [ - fairseq_dec_layer_list[i].fc2.weight.grad.contiguous().detach(), - fairseq_dec_layer_list[i].fc2.bias.grad.contiguous().detach(), - fairseq_dec_layer_list[i].fc1.weight.grad.contiguous().detach(), - fairseq_dec_layer_list[i].fc1.bias.grad.contiguous().detach(), - fairseq_dec_layer_list[i] - .final_layer_norm.weight.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .final_layer_norm.bias.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .self_attn.out_proj.weight.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .self_attn.out_proj.bias.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .self_attn.q_proj.weight.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .self_attn.q_proj.bias.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .self_attn.k_proj.weight.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .self_attn.k_proj.bias.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .self_attn.v_proj.weight.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .self_attn.v_proj.bias.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .self_attn_layer_norm.weight.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .self_attn_layer_norm.bias.grad.contiguous() - .detach(), - # encdec weights grad - fairseq_dec_layer_list[i] - .encodec_attn.q_proj.weight.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .encodec_attn.q_proj.bias.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .encodec_attn.out_proj.weight.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .encodec_attn.out_proj.bias.grad.contiguous() - .detach(), - 
fairseq_dec_layer_list[i] - .encodec_attn_layer_norm.weight.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .encodec_attn_layer_norm.bias.grad.contiguous() - .detach(), + curl.fc2.weight, + curl.fc2.bias, + curl.fc1.weight, + curl.fc1.bias, + curl.final_layer_norm.weight, + curl.final_layer_norm.bias, + curl.self_attn.out_proj.weight, + curl.self_attn.out_proj.bias, + curl.self_attn.q_proj.weight, + curl.self_attn.q_proj.bias, + curl.self_attn.k_proj.weight, + curl.self_attn.k_proj.bias, + curl.self_attn.v_proj.weight, + curl.self_attn.v_proj.bias, + curl.self_attn_layer_norm.weight, + curl.self_attn_layer_norm.bias, + curl.encodec_attn.q_proj.weight, + curl.encodec_attn.q_proj.bias, + curl.encodec_attn.out_proj.weight, + curl.encodec_attn.out_proj.bias, + curl.encodec_attn_layer_norm.weight, + curl.encodec_attn_layer_norm.bias, ] ) + grad_list.extend(cur_grads) if i == 0: - grad_list.extend( + cur_grads = copy_grad_from_paras( [ - # encdec kv grad - fairseq_dec_layer_list[i] - .encodec_attn.k_proj.weight.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .encodec_attn.k_proj.bias.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .encodec_attn.v_proj.weight.grad.contiguous() - .detach(), - fairseq_dec_layer_list[i] - .encodec_attn.v_proj.bias.grad.contiguous() - .detach(), + curl.encodec_attn.k_proj.weight, + curl.encodec_attn.k_proj.bias, + curl.encodec_attn.v_proj.weight, + curl.encodec_attn.v_proj.bias, ] ) + grad_list.extend(cur_grads) return grad_list return custom, baseline -@kt.case(dtypes=[torch.half], rtol=1e-3, atol=1e-2, ntest=10, nrepeat=1) +@kt.case(rtol=1e-3, atol=1e-2) def test_decoder_layer_forward_inference(): batch_size, enc_seq_len = kt.bs_sl() beam_size = random.randint(2, 5) + hidden_size = config.hidden_size print( - f"(batch_size, enc_seq_len, beam_size): ({batch_size}, {enc_seq_len}, {beam_size})" + f"(batch_size, enc_seq_len, beam_size, hidden_size): " + f"({batch_size}, {enc_seq_len}, {beam_size}, {hidden_size})" ) - ls_encoder_out = kt.rand((enc_seq_len, batch_size, 1024)) + ls_encoder_out = kt.rand((enc_seq_len, batch_size, hidden_size)) fs_encoder_out = ( ls_encoder_out.unsqueeze(2) .repeat(1, 1, beam_size, 1) - .reshape(enc_seq_len, -1, 1024) + .reshape(enc_seq_len, -1, hidden_size) ) ls_enc_mask = kt.attn_mask(batch_size, enc_seq_len, dtype=torch.bool) fs_enc_mask = ( @@ -742,7 +358,7 @@ def test_decoder_layer_forward_inference(): hidden_states_list = [] max_step = 10 for _ in range(max_step): - hidden_states = kt.rand((batch_size * beam_size, 1, 1024)) + hidden_states = kt.rand((batch_size * beam_size, 1, hidden_size)) hidden_states_list.append(hidden_states) def custom(): @@ -750,8 +366,8 @@ def custom(): res_list = [] for i in range(max_step): res = hidden_states_list[i].clone() - for i in range(num_layers): - res, _, _ = custom_dec_layer_list[i]( + for layer in custom_dec_layers: + res, _, _ = layer( res, encoder_out=ls_encoder_out, encoder_padding_mask=ls_enc_mask, @@ -765,8 +381,8 @@ def baseline(): res_list = [] for i in range(max_step): res = hidden_states_list[i].clone().transpose(0, 1) - for i in range(num_layers): - res, _, _ = fairseq_dec_layer_list[i]( + for layer in fairseq_dec_layers: + res, _, _ = layer( res, encoder_out=fs_encoder_out, encoder_padding_mask=fs_enc_mask, @@ -778,14 +394,13 @@ def baseline(): return custom, baseline -@kt.case(ntest=10) +@kt.case(rtol=1e-3, atol=1e-3) def test_embedding_layer_forward(): batch_size, seq_len = kt.bs_sl() print(f"(batch_size, seq_len): ({batch_size}, {seq_len})") 
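# Explanatory note, not part of the patch: kt.attn_mask() built below is a 0/1 mask whose
# ones mark the padded positions, so `input * (1 - padding_mask) + padding_idx * padding_mask`
# overwrites exactly those positions with config.padding_idx while leaving real tokens
# untouched. A toy version of that arithmetic (padding_idx = 0 is an assumption for the
# example):
import torch

toy_input = torch.tensor([[7, 9, 4]])
toy_mask = torch.tensor([[0, 0, 1]])  # 1 marks a padded position
toy_padding_idx = 0
padded = toy_input * (1 - toy_mask) + toy_padding_idx * toy_mask  # -> tensor([[7, 9, 0]])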
padding_mask = kt.attn_mask(batch_size, seq_len, dtype=torch.int) # TODO: can not generate PAD in the middle of the sentences. - config = ls_emb_config_fp16 input = kt.randint(config.padding_idx + 1, config.vocab_size, (batch_size, seq_len)) pad_left = random.choice([True, False]) if pad_left: @@ -793,21 +408,14 @@ def test_embedding_layer_forward(): else: input = input * (1 - padding_mask) + config.padding_idx * padding_mask - if kt.dtype == torch.float: - custom_layer = custom_emb_layer_fp32 - fs_layer = fs_emb_layer_fp32 - else: - custom_layer = custom_emb_layer_fp16 - fs_layer = fs_emb_layer_fp16 - def custom(): - res = custom_layer(input) + res = custom_emb_layer(input) return [ res.contiguous().detach(), ] def baseline(): - x = fs_layer(input) + x = fairseq_emb_layer(input) return [ x.contiguous().detach(), ] @@ -815,13 +423,12 @@ def baseline(): return custom, baseline -@kt.case(ntest=10) +@kt.case(rtol=1e-3, atol=1e-3) def test_embedding_layer_backward(): batch_size, seq_len = kt.bs_sl() print(f"(batch_size, seq_len): ({batch_size}, {seq_len})") padding_mask = kt.attn_mask(batch_size, seq_len, dtype=torch.int) - config = ls_emb_config_fp16 input = kt.randint(config.padding_idx + 1, config.vocab_size, (batch_size, seq_len)) pad_left = random.choice([True, False]) if pad_left: @@ -829,136 +436,93 @@ def test_embedding_layer_backward(): else: input = input * (1 - padding_mask) + config.padding_idx * padding_mask - if kt.dtype == torch.float: - custom_layer = custom_emb_layer_fp32 - fs_layer = fs_emb_layer_fp32 - else: - custom_layer = custom_emb_layer_fp16 - fs_layer = fs_emb_layer_fp16 + custom_emb_layer.zero_grad() + custom_input = input.clone() + res = custom_emb_layer(custom_input) + custom_loss = (res / 1000).sum() - loss_data = torch.randn(1, dtype=kt.dtype).sum() + fairseq_emb_layer.zero_grad() + fs_input = input.clone() + res = fairseq_emb_layer(fs_input) + fs_loss = (res / 1000).sum() def custom(): - custom_layer.zero_grad() - custom_input = input.clone() - res = custom_layer(custom_input) - custom_loss = (res / 1000).sum() - custom_loss.data.copy_(loss_data) - custom_loss.backward() + custom_emb_layer.zero_grad() + custom_loss.backward(retain_graph=True) + return [ - custom_layer.embeddings.grad.contiguous().detach(), + custom_emb_layer.embeddings.grad.contiguous().detach(), ] def baseline(): - fs_layer.zero_grad() - fs_input = input.clone() - res = fs_layer(fs_input) - fs_loss = (res / 1000).sum() - fs_loss.data.copy_(loss_data) - fs_loss.backward() + fairseq_emb_layer.zero_grad() + fs_loss.backward(retain_graph=True) + return [ - fs_layer.embeddings.weight.grad.contiguous().detach(), + fairseq_emb_layer.embeddings.weight.grad.contiguous().detach(), ] return custom, baseline -def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): - if target.dim() == lprobs.dim() - 1: - target = target.unsqueeze(-1) - nll_loss = -lprobs.gather(dim=-1, index=target) - smooth_loss = -lprobs.sum(dim=-1, keepdim=True) - if ignore_index is not None: - pad_mask = target.eq(ignore_index) - nll_loss.masked_fill_(pad_mask, 0.0) - smooth_loss.masked_fill_(pad_mask, 0.0) - else: - nll_loss = nll_loss.squeeze(-1) - smooth_loss = smooth_loss.squeeze(-1) - if reduce: - nll_loss = nll_loss.sum() - smooth_loss = smooth_loss.sum() - eps_i = epsilon / (lprobs.size(-1) - 1) - loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss - return loss, nll_loss - - -@kt.case(ntest=10) +@kt.case() def test_cross_entropy_layer_forward(): batch_size, seq_len = kt.bs_sl() - 
vocab_size = random.randint(30413, 40519) + vocab_size = random.randint(1000, 42000) print(f"(batch_size, seq_len, vocab_size): ({batch_size}, {seq_len}, {vocab_size})") inputs = kt.rand((batch_size, seq_len, vocab_size)) - targets = kt.randint( - ce_config_fp16.padding_idx - 1, vocab_size, (batch_size, seq_len) - ) + targets = kt.randint(0, vocab_size, (batch_size, seq_len)) targets_32 = targets.to(torch.int32) - if kt.dtype == torch.float: - custom_layer = custom_cross_entropy_layer_fp32 - else: - custom_layer = custom_cross_entropy_layer_fp16 - def custom(): - res, cus_nll_loss = custom_layer(inputs, targets_32) + loss, cus_nll_loss = custom_ce_layer(inputs, targets_32) + loss = loss.to(inputs) + cus_nll_loss = cus_nll_loss.to(inputs) return [ - res.contiguous().detach(), + loss.contiguous().detach(), cus_nll_loss.contiguous().detach(), ] def baseline(): - - x = torch.nn.functional.log_softmax(inputs, dim=-1, dtype=torch.float32) - x, base_nll_loss = label_smoothed_nll_loss( - x, targets, ce_config_fp16.epsilon, ignore_index=ce_config_fp16.padding_idx - ) - x = x.to(inputs) - base_nll_loss = base_nll_loss.to(inputs) + loss, base_nll_loss = fairseq_ce_layer(inputs, targets) return [ - x.contiguous().detach(), + loss.contiguous().detach(), base_nll_loss.contiguous().detach(), ] return custom, baseline -@kt.case(ntest=10) +@kt.case() def test_cross_entropy_layer_backward(): batch_size, seq_len = kt.bs_sl() - vocab_size = random.randint(30413, 40519) + vocab_size = random.randint(1000, 42000) print(f"(batch_size, seq_len, vocab_size): ({batch_size}, {seq_len}, {vocab_size})") base_inputs = kt.rand((batch_size, seq_len, vocab_size)).requires_grad_() cus_inputs = base_inputs.clone().detach().requires_grad_() - targets = kt.randint( - ce_config_fp16.padding_idx - 1, vocab_size, (batch_size, seq_len) - ) + targets = kt.randint(0, vocab_size, (batch_size, seq_len)) targets_32 = targets.to(torch.int32) - if kt.dtype == torch.float: - custom_layer = custom_cross_entropy_layer_fp32 - else: - custom_layer = custom_cross_entropy_layer_fp16 - cus_res = custom_layer(cus_inputs, targets_32)[0].to(kt.dtype) - x = torch.nn.functional.log_softmax(base_inputs, dim=-1, dtype=torch.float32) - base_res, _ = label_smoothed_nll_loss( - x, targets, ce_config_fp16.epsilon, ignore_index=ce_config_fp16.padding_idx - ) - base_res = base_res.to(kt.dtype) + custom_ce_layer.zero_grad() + custom_loss, _ = custom_ce_layer(cus_inputs, targets_32) + + fairseq_ce_layer.zero_grad() + base_loss, _ = fairseq_ce_layer(base_inputs, targets) def custom(): - if cus_inputs.grad is not None: - cus_inputs.grad.zero_() - cus_res.backward(retain_graph=True) + custom_ce_layer.zero_grad() + custom_loss.backward(retain_graph=True) + return [ cus_inputs.grad.contiguous().detach(), ] def baseline(): - if base_inputs.grad is not None: - base_inputs.grad.zero_() - base_res.backward(retain_graph=True) + fairseq_ce_layer.zero_grad() + base_loss.backward(retain_graph=True) + return [ base_inputs.grad.contiguous().detach(), ] @@ -966,14 +530,12 @@ def baseline(): return custom, baseline -if __name__ == "__main__": - kt.init(device="cuda:0", nhead=16) +def main(epoch): + print(">>>>>>>>>>>>>>>>>>>>>>Test epoch: {}>>>>>>>>>>>>>>>>>>>>>>".format(epoch)) kt.run( [ "test_encoder_layer_forward", "test_encoder_layer_backward", - "test_bert_encoder_layer_forward", - "test_bert_encoder_layer_backward", "test_decoder_layer_forward", "test_decoder_layer_backward", "test_decoder_layer_forward_inference", @@ -983,3 +545,11 @@ def baseline(): 
"test_cross_entropy_layer_backward", ] ) + + +if __name__ == "__main__": + ctx = mp.get_context("spawn") + for i in range(50): + p = ctx.Process(target=main, args=(i,)) + p.start() + p.join() diff --git a/tests/util.py b/tests/util.py index 17d1fc6d..8ee3645d 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,25 +1,51 @@ import random import time from collections import OrderedDict +from dataclasses import dataclass +from copy import deepcopy import numpy as np import torch -def cast_fp32_tensor(tlist): - return [ele.to(torch.float32) for ele in tlist] - - -def is_nan(x): - return x.isnan().any().item() - - -def is_inf(x): - return x.isinf().any().item() - - -max_batch_tokens = 9216 -max_seq_len = 256 +@dataclass +class Config: + max_batch_tokens: int + max_seq_len: int + vocab_size: int + padding_idx: int + hidden_size: int + intermediate_size: int + nhead: int + attn_prob_dropout_ratio: float + activation_dropout_ratio: float + hidden_dropout_ratio: float + pre_layer_norm: bool + fp16: bool + local_rank: int + activation_fn: str + num_layers: int + label_smooth: float + + +default_config = Config( + max_batch_tokens=9216, + max_seq_len=256, + vocab_size=32000, + padding_idx=0, + hidden_size=1024, + intermediate_size=1024 * 4, + nhead=16, + attn_prob_dropout_ratio=0.0, + activation_dropout_ratio=0.0, + hidden_dropout_ratio=0.0, + pre_layer_norm=True, + fp16=True, + local_rank=0, + activation_fn="relu", + num_layers=2, + label_smooth=0.1, +) class TestDecorator(object): @@ -27,31 +53,47 @@ def __init__(self): self.all_case = OrderedDict() self.dtypes = [torch.float, torch.half] self.dtype = None - self.max_batch_tokens = max_batch_tokens - self.max_seq_len = max_seq_len - - def init(self, device, nhead): - # device: str. e.g. "cuda:0" - self.device = torch.device(device) - assert nhead % 4 == 0 - self.nhead = nhead + self.device = torch.device("cuda:{}".format(default_config.local_rank)) + + def generate_config(self, use_default=False): + if use_default: + return deepcopy(default_config) + config = deepcopy(default_config) + config.vocab_size = random.randint(1000, 42000) + hidden_size, nhead = self.h_nh + config.hidden_size = hidden_size + config.intermediate_size = hidden_size * 4 + config.nhead = nhead + config.pre_layer_norm = random.choice([True, False]) + config.activation_fn = self.act_fn + config.num_layers = random.randint(1, 2) + return config def bs_sl(self, batch_size=None): if batch_size is None: - seq_len = random.randint(1, self.max_seq_len) - max_batch_size = self.max_batch_tokens // seq_len + seq_len = random.randint(1, default_config.max_seq_len) + max_batch_size = default_config.max_batch_tokens // seq_len batch_size = random.randint(1, max_batch_size) else: - max_seq_len = min(self.max_batch_tokens // batch_size, self.max_seq_len) + max_seq_len = min( + default_config.max_batch_tokens // batch_size, + default_config.max_seq_len, + ) seq_len = random.randint(1, max_seq_len) return batch_size, seq_len @property - def hidden_dim(self): - upbound = 1024 // self.nhead + def h_nh(self): + nhead = random.choice([8, 12, 16]) + upbound = 1024 // nhead head_dim = random.choice(range(1, upbound + 1)) - hs = head_dim * self.nhead * self.io_factor - return hs + hs = head_dim * nhead * self.io_factor + return hs, nhead + + @property + def act_fn(self): + act = random.choice(["relu", "gelu"]) + return act @property def io_factor(self): @@ -103,7 +145,7 @@ def dec_self_attn_mask(self, seq_len, dtype=None): mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) return 
         return mask.to(self.device, dtype=dtype)

-    def case(self, dtypes=list(), ntest=5, nrepeat=5, rtol=1e-5, atol=1e-5):
+    def case(self, dtypes=list(), ntest=10, nrepeat=5, rtol=1e-5, atol=1e-5):
         if not dtypes:
             dtypes = self.dtypes

@@ -191,6 +233,18 @@ def run(self, case_names=None):
             self.test(custom, baseline, nrepeat, rtol, atol)


+def cast_fp32_tensor(tlist):
+    return [ele.to(torch.float32) for ele in tlist]
+
+
+def is_nan(x):
+    return x.isnan().any().item()
+
+
+def is_inf(x):
+    return x.isinf().any().item()
+
+
 def flat_dim(idxs, dims):
     assert len(idxs) == len(dims) or len(idxs) == len(dims) + 1
     base = 1
@@ -217,67 +271,6 @@ def expand_dim(idx, dims):
     return res[::-1]


-def get_fairseq_enc_params(fairseq_layer):
-    initial_weights = []
-    initial_biases = []
-
-    initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.out_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.out_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn_layer_norm.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn_layer_norm.bias.detach().clone())
-
-    initial_weights.append(fairseq_layer.fc1.weight.detach().clone())
-    initial_biases.append(fairseq_layer.fc1.bias.detach().clone())
-    initial_weights.append(fairseq_layer.fc2.weight.detach().clone())
-    initial_biases.append(fairseq_layer.fc2.bias.detach().clone())
-    initial_weights.append(fairseq_layer.final_layer_norm.weight.detach().clone())
-    initial_biases.append(fairseq_layer.final_layer_norm.bias.detach().clone())
-    return initial_weights, initial_biases
-
-
-def get_fairseq_dec_params(fairseq_layer):
-    initial_weights = []
-    initial_biases = []
-
-    initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.out_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.out_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn_layer_norm.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn_layer_norm.bias.detach().clone())
-
-    initial_weights.append(fairseq_layer.encodec_attn.q_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.encodec_attn.q_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.encodec_attn.k_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.encodec_attn.k_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.encodec_attn.v_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.encodec_attn.v_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.encodec_attn.out_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.encodec_attn.out_proj.bias.detach().clone())
-    initial_weights.append(
-        fairseq_layer.encodec_attn_layer_norm.weight.detach().clone()
-    )
-    initial_biases.append(fairseq_layer.encodec_attn_layer_norm.bias.detach().clone())
-
-    initial_weights.append(fairseq_layer.fc1.weight.detach().clone())
-    initial_biases.append(fairseq_layer.fc1.bias.detach().clone())
-    initial_weights.append(fairseq_layer.fc2.weight.detach().clone())
-    initial_biases.append(fairseq_layer.fc2.bias.detach().clone())
-    initial_weights.append(fairseq_layer.final_layer_norm.weight.detach().clone())
-    initial_biases.append(fairseq_layer.final_layer_norm.bias.detach().clone())
-    return initial_weights, initial_biases
-
-
 def split_custom_layer_grad(layer):
     res = []
     for i in range(1, len(layer.para_offset)):