
[feat] support for DeepseekV2 #129

Open
tmm1 opened this issue Aug 27, 2024 · 4 comments
Labels: feature, help wanted, huggingface

Comments

@tmm1
Contributor

tmm1 commented Aug 27, 2024

🚀 The feature, motivation and pitch

It would be nice to support DeepSeek-V2 models. Unfortunately, the modeling code has not yet been accepted into transformers, so loading it requires trust_remote_code=True.

I'm monkey-patching it myself for now and wanted to leave some notes that may be helpful when support is added officially down the road.

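import sys

from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

# deepseekv2_lce_forward is a custom forward that is not defined in this snippet

# Load on the meta device only to get a handle on the remote-code module
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-Coder-V2-Lite-Base", trust_remote_code=True)
    modeling_mod = sys.modules[model.__class__.__module__]

# Patch module-level names (class swaps only affect models built afterwards)
modeling_mod.apply_rotary_pos_emb = liger_rotary_pos_emb
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
modeling_mod.DeepseekV2MLP = LigerSwiGLUMLP
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward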

One initial issue when swapping in SwiGLU:

  File "/mnt/ML/huggingface/modules/transformers_modules/deepseek-ai/DeepSeek-Coder-V2-Lite-Base/ea9b066cee82f82906fdd58898cb3788b1c5d770/modeling_deepseek.py", line 555, in <listcomp>
    DeepseekV2MLP(
TypeError: LigerSwiGLUMLP.__init__() got an unexpected keyword argument 'intermediate_size'
@tmm1
Contributor Author

tmm1 commented Aug 27, 2024

> modeling_mod.apply_rotary_pos_emb = liger_rotary_pos_emb

This is causing loss calculations to be wildly different for some reason.

I will investigate further.


> TypeError: LigerSwiGLUMLP.__init__() got an unexpected keyword argument 'intermediate_size'

I was able to fix this as follows:

# keep DeepseekV2MLP's constructor, borrow only Liger's forward
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
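A minimal sketch of why borrowing only `forward` sidesteps the `TypeError`, using hypothetical stand-in classes rather than the real Liger and DeepSeek code. The stand-ins assume, as the traceback above suggests, that the model builds its (presumably MoE expert) MLPs as `DeepseekV2MLP(config, intermediate_size=...)`, while the Liger module's constructor takes only `config`. Since a SwiGLU `forward` only reads `gate_proj`, `up_proj` and `down_proj`, it can run on any module that defines those attributes.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUMLPStandIn(nn.Module):
    """Hypothetical stand-in for LigerSwiGLUMLP: sizes come only from `config`."""

    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        # Only touches gate_proj / up_proj / down_proj (plain PyTorch here, no fused kernel)
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class DeepseekMLPStandIn(nn.Module):
    """Hypothetical stand-in for DeepseekV2MLP: callers may override the sizes."""

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        hidden = config.hidden_size if hidden_size is None else hidden_size
        inter = config.intermediate_size if intermediate_size is None else intermediate_size
        self.gate_proj = nn.Linear(hidden, inter, bias=False)
        self.up_proj = nn.Linear(hidden, inter, bias=False)
        self.down_proj = nn.Linear(inter, hidden, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


config = SimpleNamespace(hidden_size=8, intermediate_size=32, moe_intermediate_size=16)

# Swapping the whole class fails on the expert-style constructor call
try:
    SwiGLUMLPStandIn(config, intermediate_size=config.moe_intermediate_size)
    class_swap_failed = False
except TypeError:
    class_swap_failed = True

torch.manual_seed(0)
expert = DeepseekMLPStandIn(config, intermediate_size=config.moe_intermediate_size)
x = torch.randn(2, 3, config.hidden_size)
before = expert(x)

# Borrowing only forward keeps the original constructor (and its per-expert sizes)
DeepseekMLPStandIn.forward = SwiGLUMLPStandIn.forward
after = expert(x)

print(class_swap_failed)                         # True
print(torch.allclose(before, after, atol=1e-6))  # True
```

One caveat with this approach: the borrowed forward hard-codes SiLU, so it only matches the original when the model's `hidden_act` is `silu`, and any config validation done in the replacement class's constructor is skipped.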

@tmm1
Contributor Author

tmm1 commented Aug 27, 2024

> this is causing loss calculations to be wildly different for some reason

The RoPE method seems to be modified in DeepSeek-V2. Compare:

Llama:

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

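DeepseekV2:

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    # cos/sin are indexed by position_ids inside the function
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    # Extra step vs. Llama: de-interleave the head dim,
    # (x0, x1, x2, x3, ...) -> (x0, x2, ..., x1, x3, ...)
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed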
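One hypothesis for the loss mismatch, sketched in plain PyTorch (no Liger kernels involved): the extra `view`/`transpose`/`reshape` in the DeepseekV2 version de-interleaves the head dimension, so the weights expect RoPE on adjacent pairs `(x[2i], x[2i+1])`, whereas a Llama-style `apply_rotary_pos_emb` (the function `liger_rotary_pos_emb` was written to replace) rotates pairs `(x[i], x[i + d/2])`. The helper names and toy sizes below are illustrative, not taken from either codebase.

```python
import torch


def rotate_half(x):
    # Llama-style helper: split the last dim in half and return (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def deinterleave(x):
    # Same view/transpose/reshape as the DeepseekV2 snippet:
    # (x0, x1, x2, x3, ...) -> (x0, x2, ..., x1, x3, ...)
    b, h, s, d = x.shape
    return x.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)


def rope_half_split(x, cos, sin):
    # Llama-style RoPE: rotates pairs (x[i], x[i + d/2])
    return x * cos + rotate_half(x) * sin


def rope_interleaved_reference(x, theta):
    # Reference RoPE on adjacent pairs (x[2i], x[2i+1]); the result is returned
    # in half-split order so it can be compared with rope_half_split
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    c, s = theta.cos(), theta.sin()
    return torch.cat((x_even * c - x_odd * s, x_even * s + x_odd * c), dim=-1)


torch.manual_seed(0)
b, h, seq, d = 1, 2, 4, 8
q = torch.randn(b, h, seq, d)

inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))  # (d/2,)
theta = torch.outer(torch.arange(seq).float(), inv_freq)          # (seq, d/2)
cos = torch.cat((theta.cos(), theta.cos()), dim=-1)              # (seq, d), broadcasts over (b, h)
sin = torch.cat((theta.sin(), theta.sin()), dim=-1)

reference = rope_interleaved_reference(q, theta)
deepseek_style = rope_half_split(deinterleave(q), cos, sin)
llama_style = rope_half_split(q, cos, sin)  # no de-interleave step

print(torch.allclose(deepseek_style, reference, atol=1e-6))  # True
print(torch.allclose(llama_style, reference, atol=1e-6))     # False
```

If this is indeed the cause, one untested option would be to de-interleave `q` and `k` (and index `cos`/`sin` by `position_ids`) before calling a half-split RoPE kernel. Because `q` and `k` are permuted the same way, their dot products are unaffected by the outputs staying in de-interleaved order.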

ByronHsu added the feature, help wanted, and huggingface labels on Aug 28, 2024
@xinyubai1209

DeepSeek-V2 uses MLA (Multi-head Latent Attention) to reduce the KV cache.
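To put rough numbers on the KV-cache saving, here is a per-token, per-layer element count, assuming the sizes reported in the DeepSeek-V2 paper for the full model (128 heads with head dimension 128, `kv_lora_rank` of 512, and a decoupled RoPE key dimension of 64). Standard multi-head attention caches a key and a value vector for every head; MLA caches one compressed latent plus one RoPE key shared across heads.

```python
def mha_kv_elems_per_token(n_heads: int, head_dim: int) -> int:
    # Standard MHA: a full K and V vector for every head
    return 2 * n_heads * head_dim


def mla_kv_elems_per_token(kv_lora_rank: int, rope_head_dim: int) -> int:
    # MLA: one compressed latent c_kv plus one shared decoupled-RoPE key
    return kv_lora_rank + rope_head_dim


mha = mha_kv_elems_per_token(n_heads=128, head_dim=128)
mla = mla_kv_elems_per_token(kv_lora_rank=512, rope_head_dim=64)
print(mha, mla, round(mha / mla, 1))  # 32768 576 56.9
```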

@qingquansong
Collaborator

qingquansong commented Aug 28, 2024

Yeah, the DeepSeek-V2 one is quite interesting, as it uses decoupled RoPE.
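A toy sketch of what "decoupled RoPE" means, loosely following the structure of the DeepseekV2 attention code. The names `q_nope`, `q_pe`, `k_pe` and `kv_lora_rank` mirror that code, but the sizes and the random up-projection `w_uk` are made up for illustration. Each head's query and key are split into a position-free part and a small RoPE part; the key's RoPE part is a single head shared by all heads, and RoPE is applied only to the RoPE slices. The attention score is then a content term plus a positional term.

```python
import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(x, cos, sin):
    # Half-split RoPE, used here only for illustration
    return x * cos + rotate_half(x) * sin


torch.manual_seed(0)
bsz, seq, n_heads = 1, 5, 4
nope_dim, rope_dim, kv_lora_rank = 16, 8, 32  # illustrative sizes only

# Per-head queries, split into a position-free part and a RoPE part
q = torch.randn(bsz, n_heads, seq, nope_dim + rope_dim)
q_nope, q_pe = torch.split(q, [nope_dim, rope_dim], dim=-1)

# Keys: the "nope" part is up-projected per head from a shared latent c_kv,
# while the RoPE part k_pe is a single head shared by all heads
c_kv = torch.randn(bsz, seq, kv_lora_rank)
w_uk = torch.randn(n_heads, kv_lora_rank, nope_dim)  # hypothetical up-projection
k_nope = torch.einsum("bsc,hcd->bhsd", c_kv, w_uk)
k_pe = torch.randn(bsz, 1, seq, rope_dim)

inv_freq = 1.0 / (10000 ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
theta = torch.outer(torch.arange(seq).float(), inv_freq)
cos = torch.cat((theta.cos(), theta.cos()), dim=-1)
sin = torch.cat((theta.sin(), theta.sin()), dim=-1)

# RoPE touches only the RoPE slices
q_pe = apply_rope(q_pe, cos, sin)
k_pe = apply_rope(k_pe, cos, sin)

query = torch.cat((q_nope, q_pe), dim=-1)
key = torch.cat((k_nope, k_pe.expand(-1, n_heads, -1, -1)), dim=-1)
scores = query @ key.transpose(-1, -2)

# The score splits into a content term and a positional term
split_scores = q_nope @ k_nope.transpose(-1, -2) + q_pe @ k_pe.transpose(-1, -2)
print(torch.allclose(scores, split_scores, atol=1e-4))  # True
```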

As for the MLA part, it mainly targets inference speedups by absorbing the low-rank projection matrices into the original linear layers. Feel free to start by implementing the other layers, then improve things gradually in separate PRs. Thanks for the interesting feature request and the rapid kickoff!
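The absorption trick can be written out for a single head. The notation is assumed here, following the DeepSeek-V2 paper: latents $c^{Q}_t$ and $c^{KV}_j$, up-projections $W^{UQ}$ and $W^{UK}$, and RoPE rotation matrices $R_t$, $R_j$. Without RoPE, the product of the two up-projections is a fixed matrix that can be precomputed; with RoPE applied to the same vectors, the position-dependent factor $R_t^{\top} R_j$ sits between them and blocks this, which is why RoPE is moved to a separate, decoupled slice.

```latex
\begin{aligned}
\left(W^{UQ} c^{Q}_t\right)^{\top} \left(W^{UK} c^{KV}_j\right)
  &= {c^{Q}_t}^{\top} \underbrace{\left({W^{UQ}}^{\top} W^{UK}\right)}_{\text{fixed, precomputable}} c^{KV}_j, \\
\left(R_t W^{UQ} c^{Q}_t\right)^{\top} \left(R_j W^{UK} c^{KV}_j\right)
  &= {c^{Q}_t}^{\top} {W^{UQ}}^{\top} \underbrace{\left(R_t^{\top} R_j\right)}_{\text{depends on } t,\, j} W^{UK} c^{KV}_j .
\end{aligned}
```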


4 participants