
small tidyups #40


Open · wants to merge 1 commit into main
jaxformer/hf/codegen/modeling_codegen.py (155 changes: 116 additions & 39 deletions)
@@ -18,19 +18,20 @@
from typing import Tuple

import numpy as np

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_codegen import CodeGenConfig

from .configuration_codegen import CodeGenConfig

logger = logging.get_logger(__name__)

@@ -43,7 +44,11 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
# original
# sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float()
# QHD fix onnx error by https://github.com/microsoft/onnxruntime/discussions/10121#discussioncomment-1987845
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len).float(), inv_freq).to(x.device).float()
sinusoid_inp = (
torch.einsum("i , j -> i j", torch.arange(seq_len).float(), inv_freq)
.to(x.device)
.float()
)
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)


@@ -55,7 +60,12 @@ def rotate_every_two(x):


def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos)
sin, cos = map(
lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(
2, 3
),
sincos,
)
# einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
return (x * cos) + (rotate_every_two(x) * sin)

@@ -67,9 +77,9 @@ def __init__(self, config):
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
torch.tril(
torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions),
)
self.register_buffer("masked_bias", torch.tensor(-1e9))

@@ -83,7 +93,9 @@ def __init__(self, config):
raise ValueError(
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.scale_attn = torch.sqrt(
torch.tensor(self.head_dim, dtype=torch.float32)
).to(torch.get_default_dtype())
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)

self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
@@ -92,8 +104,8 @@ def __init__(self, config):
self.rotary_dim = config.rotary_dim

def _split_heads(self, x, n_head, dim_head, mp_num):
reshaped = x.reshape(x.shape[:-1] + (n_head//mp_num, dim_head))
reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + reshaped.shape[-1:])
reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
return reshaped

def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
@@ -105,7 +117,9 @@ def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
raise ValueError(
f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}"
)
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)

@@ -120,7 +134,9 @@ def _attn(

# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
causal_mask = self.bias[
:, :, key_length - query_length : key_length, :key_length
].to(torch.bool)

# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
@@ -129,7 +145,9 @@ def _attn(
attn_weights = torch.matmul(query, key.transpose(-1, -2))

attn_weights = attn_weights / self.scale_attn
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
attn_weights = torch.where(
causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)
)

if attention_mask is not None:
# Apply the attention mask
@@ -164,10 +182,16 @@ def forward(

local_dim = self.head_dim * self.num_attention_heads // mp_num
query, value, key = torch.split(qkv_split, local_dim, dim=-1)
query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
query = self._split_heads(
query, self.num_attention_heads, self.head_dim, mp_num=mp_num
)
key = self._split_heads(
key, self.num_attention_heads, self.head_dim, mp_num=mp_num
)

value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
value = self._split_heads(
value, self.num_attention_heads, self.head_dim, mp_num=mp_num
)
value = value.permute(0, 2, 1, 3)

seq_len = key.shape[1]
@@ -210,9 +234,13 @@ def forward(
present = None

# compute self-attention: V x Softmax(QK^T)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output, attn_weights = self._attn(
query, key, value, attention_mask, head_mask
)

attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
attn_output = self._merge_heads(
attn_output, self.num_attention_heads, self.head_dim
)

attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
@@ -225,7 +253,9 @@ def forward(


class CodeGenMLP(nn.Module):
def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
def __init__(
self, intermediate_size, config
): # in MLP: intermediate_size= 4 * embed_dim
super().__init__()
embed_dim = config.n_embd

@@ -324,22 +354,29 @@ def __init__(self, config):
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([CodeGenBlock(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
self.rotary_dim = min(
config.rotary_dim, config.n_ctx // config.num_attention_heads
)
self.init_weights()

# Model parallel
self.model_parallel = False
self.device_map = None


def parallelize(self, device_map=None):
# Check validity of device_map
self.device_map = (
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
get_device_map(len(self.h), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.h))
self.model_parallel = True
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
self.first_device = (
"cpu"
if "cpu" in self.device_map.keys()
else "cuda:" + str(min(self.device_map.keys()))
)
self.last_device = "cuda:" + str(max(self.device_map.keys()))
self.wte = self.wte.to(self.first_device)
# Load onto devices
@@ -350,7 +387,6 @@ def parallelize(self, device_map=None):
# ln_f to last
self.ln_f = self.ln_f.to(self.last_device)


def deparallelize(self):
self.model_parallel = False
self.device_map = None
@@ -382,15 +418,25 @@ def forward(
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
@@ -416,7 +462,12 @@ def forward(
past_length = past_key_values[0][0].size(-2)

if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

# Attention mask.
@@ -467,7 +518,9 @@ def forward(
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
layer_past = tuple(
past_state.to(hidden_states.device) for past_state in layer_past
)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
@@ -514,7 +567,9 @@ def custom_forward(*inputs):
presents = presents + (outputs[1],)

if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
all_self_attentions = all_self_attentions + (
outputs[2 if use_cache else 1],
)

# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
@@ -530,7 +585,16 @@ def custom_forward(*inputs):
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
]
if v is not None
)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
@@ -541,7 +605,11 @@ def custom_forward(*inputs):


class CodeGenForCausalLM(CodeGenPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
_keys_to_ignore_on_load_missing = [
r"h\.\d+\.attn\.masked_bias",
r"h\.\d+\.attn\.bias",
r"lm_head\.weight",
]

def __init__(self, config):
super().__init__(config)
@@ -626,7 +694,9 @@ def forward(
``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

transformer_outputs = self.transformer(
input_ids,
@@ -660,7 +730,9 @@ def forward(
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)

loss = loss.to(hidden_states.dtype)

@@ -677,13 +749,18 @@ def forward(
)

@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
def _reorder_cache(
past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
)
for layer_past in past
)
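
For context, a minimal smoke-test sketch (not part of this diff) showing how the tidied module can be exercised end to end. The import paths follow the package layout shown above; the assumption that the default CodeGenConfig yields a model small enough for a quick CPU run is not confirmed by this PR.

# Minimal sketch, assuming the jaxformer package is importable and that
# CodeGenConfig() provides usable defaults. Not part of this PR.
import torch

from jaxformer.hf.codegen.configuration_codegen import CodeGenConfig
from jaxformer.hf.codegen.modeling_codegen import CodeGenForCausalLM

config = CodeGenConfig()                   # randomly initialised weights, default sizes
model = CodeGenForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    out = model(input_ids, use_cache=True)

print(out.logits.shape)          # expected: (1, 16, config.vocab_size)
print(len(out.past_key_values))  # one (key, value) tuple per transformer layer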