
Commit 2b33ba9

CISC authored and pwilkin committed
model : add BailingMoeV2 support (ggml-org#16063)

* add BailingMoeV2 support
* update llm types
* undo
* undo
* update llm types
* add model collection link
* update
* almost working
* correct group selection and rename n_group_exp
* avoid large top_k and use argmax instead for now; if we had something like argmax2 that would be equivalent, but this works fine until then
* poke
* skip group selection when there are no tokens
* fix 1T conversion
* hopefully fixed expert group selection (third time's the charm?)
* make expert group selection generally available: the new LLaDA2Moe model uses this method too, so make it available regardless of architecture
* allow n_expert_groups to be 1 (Kimi K2)
* address review suggestions
1 parent e2aad4c commit 2b33ba9

15 files changed: +521 -10 lines changed
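The bullet about avoiding a large top_k refers to group-limited expert routing: experts are split into `n_group` buckets, each bucket is scored by its single best expert (the "argmax instead of top_k" note), only the best `topk_group` buckets survive, and the usual top-k expert pick runs inside the survivors. A minimal PyTorch sketch of that idea (names, shapes, and scoring details are illustrative assumptions, not the llama.cpp implementation):

```python
import torch

def group_limited_topk(scores: torch.Tensor, n_group: int, topk_group: int, n_used: int):
    # scores: [n_tokens, n_expert] router scores; experts split evenly into n_group buckets
    n_tokens, n_expert = scores.shape
    assert n_expert % n_group == 0
    grouped = scores.view(n_tokens, n_group, n_expert // n_group)
    # "argmax instead of top_k": score each group by its single best expert
    group_score = grouped.max(dim=-1).values                  # [n_tokens, n_group]
    top_groups = group_score.topk(topk_group, dim=-1).indices
    keep = torch.zeros(n_tokens, n_group, dtype=torch.bool)
    keep.scatter_(1, top_groups, True)                        # mark surviving groups
    masked = grouped.masked_fill(~keep.unsqueeze(-1), float("-inf"))
    return masked.view(n_tokens, n_expert).topk(n_used, dim=-1)

# n_expert_groups == 1 (e.g. Kimi K2) degenerates to a plain top-k over all experts
values, indices = group_limited_topk(torch.randn(2, 16), n_group=4, topk_group=2, n_used=4)
```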

README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -138,6 +138,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
 - [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
 - [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
+- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
 
 #### Multimodal
```

convert_hf_to_gguf.py

Lines changed: 99 additions & 2 deletions

```diff
@@ -892,8 +892,8 @@ def get_vocab_base_pre(self, tokenizer) -> str:
             # ref: https://huggingface.co/JetBrains/Mellum-4b-base
             res = "mellum"
         if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
-            # ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
-            res = "llada-moe"
+            # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
+            res = "bailingmoe2"
         if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
             # ref: https://huggingface.co/ibm-granite/granite-docling-258M
             res = "granite-docling"
@@ -8085,6 +8085,103 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("BailingMoeV2ForCausalLM")
+class BailingMoeV2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.BAILINGMOE2
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0):
+            self.block_count = self.hparams["num_hidden_layers"] + nextn_layers
+            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if (rope_dim := hparams.get("head_dim")) is None:
+            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+        else:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
+        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_count(hparams["num_experts"])
+        self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
+        self.gguf_writer.add_expert_group_count(hparams["n_group"])
+        self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
+        if hparams["score_function"] == "sigmoid":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        elif hparams["score_function"] == "softmax":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
+
+        if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
+            self.gguf_writer.add_nextn_predict_layers(nextn_layers)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "mlp.experts" in name:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            tensors: list[tuple[str, Tensor]] = []
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+
+            return tensors
+
+        if name.endswith(".expert_bias"):
+            name = name.replace(".expert_bias", ".expert_bias.bias")
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
 class GroveMoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GROVEMOE
```
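The modify_tensors logic above buffers each layer's per-expert weights and, once all of them have arrived, stacks every projection into one 3D tensor per layer. A toy standalone sketch of that shape transformation (dimensions made up for illustration):

```python
import torch

n_experts, n_ff, n_embd = 4, 8, 16
# one [n_ff, n_embd] matrix per expert, as stored in the HF checkpoint
per_expert = [torch.randn(n_ff, n_embd) for _ in range(n_experts)]
merged = torch.stack(per_expert, dim=0)  # -> [n_experts, n_ff, n_embd]
assert merged.shape == (n_experts, n_ff, n_embd)
```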

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -139,7 +139,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
     {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
     {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
-    {"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
+    {"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
     {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
 ]
```
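For context, this table drives convert_hf_to_gguf_update.py, which regenerates the pre-tokenizer fingerprints (chkhsh) matched in get_vocab_base_pre. A rough sketch of the mechanism, assuming the fingerprint is a SHA-256 over the token ids of a fixed probe string (the real script uses a long adversarial probe, not the placeholder below):

```python
from hashlib import sha256
from transformers import AutoTokenizer  # assumption: the tokenizer loads via transformers

tokenizer = AutoTokenizer.from_pretrained("inclusionAI/Ling-mini-base-2.0")
probe = "Hello, world! äöü 123"  # placeholder, not the script's actual probe text
chkhsh = sha256(str(tokenizer.encode(probe)).encode()).hexdigest()
print(chkhsh)  # compared against the literals in get_vocab_base_pre
```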

gguf-py/gguf/constants.py

Lines changed: 33 additions & 0 deletions

```diff
@@ -102,6 +102,8 @@ class LLM:
         EXPERT_COUNT            = "{arch}.expert_count"
         EXPERT_USED_COUNT       = "{arch}.expert_used_count"
         EXPERT_SHARED_COUNT     = "{arch}.expert_shared_count"
+        EXPERT_GROUP_COUNT      = "{arch}.expert_group_count"
+        EXPERT_GROUP_USED_COUNT = "{arch}.expert_group_used_count"
         EXPERT_WEIGHTS_SCALE    = "{arch}.expert_weights_scale"
         EXPERT_WEIGHTS_NORM     = "{arch}.expert_weights_norm"
         EXPERT_GATING_FUNC      = "{arch}.expert_gating_func"
@@ -401,6 +403,7 @@ class MODEL_ARCH(IntEnum):
     WAVTOKENIZER_DEC = auto()
     PLM              = auto()
     BAILINGMOE       = auto()
+    BAILINGMOE2      = auto()
     DOTS1            = auto()
     ARCEE            = auto()
     ERNIE4_5         = auto()
@@ -748,6 +751,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
     MODEL_ARCH.PLM:              "plm",
     MODEL_ARCH.BAILINGMOE:       "bailingmoe",
+    MODEL_ARCH.BAILINGMOE2:      "bailingmoe2",
     MODEL_ARCH.DOTS1:            "dots1",
     MODEL_ARCH.ARCEE:            "arcee",
     MODEL_ARCH.ERNIE4_5:         "ernie4_5",
@@ -2568,6 +2572,35 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
     ],
+    MODEL_ARCH.BAILINGMOE2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.NEXTN_EH_PROJ,
+        MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+        MODEL_TENSOR.NEXTN_ENORM,
+        MODEL_TENSOR.NEXTN_HNORM,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.DOTS1: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
```
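The Python enum list here and the C++ table in src/llama-arch.cpp below must agree on the serialized names. A quick sketch of how the Python side renders an on-disk tensor name, assuming the gguf-py TENSOR_NAMES table:

```python
import gguf

# TENSOR_NAMES maps the enum to a "blk.{bid}..." template; the converter
# appends the usual ".weight"/".bias" suffix
name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.FFN_GATE_EXP].format(bid=3) + ".weight"
print(name)  # expected: blk.3.ffn_gate_exps.weight
```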

gguf-py/gguf/gguf_writer.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -755,6 +755,12 @@ def add_expert_used_count(self, count: int) -> None:
     def add_expert_shared_count(self, count: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
 
+    def add_expert_group_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_GROUP_COUNT.format(arch=self.arch), count)
+
+    def add_expert_group_used_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_GROUP_USED_COUNT.format(arch=self.arch), count)
+
     def add_expert_weights_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
```
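A converted file can be spot-checked for the new metadata keys with the gguf Python package; a minimal sketch, assuming GGUFReader's field layout (the path is a placeholder):

```python
from gguf import GGUFReader

reader = GGUFReader("ling-mini-2.0.gguf")  # placeholder path
for key in ("bailingmoe2.expert_group_count", "bailingmoe2.expert_group_used_count"):
    field = reader.get_field(key)
    if field is not None:
        print(key, "=", field.parts[field.data[0]][0])  # uint32 payload
```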

gguf-py/gguf/tensor_mapping.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -174,6 +174,7 @@ class TensorNameMap:
             "h.{bid}.self_attention.query_key_value",                             # bloom
             "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
             "model.layers.{bid}.self_attn.query_key_value",                       # persimmon
+            "model.layers.{bid}.attention.query_key_value",                       # bailingmoe2
             "h.{bid}.attn.c_attn",                                                # gpt2
             "transformer.h.{bid}.mixer.Wqkv",                                     # phi2
             "encoder.layers.{bid}.attn.Wqkv",                                     # nomic-bert
@@ -260,6 +261,7 @@ class TensorNameMap:
             "transformer.h.{bid}.attn.out_proj",                        # gpt-j
             "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
             "model.layers.{bid}.self_attn.dense",                       # persimmon
+            "model.layers.{bid}.attention.dense",                       # bailingmoe2
             "h.{bid}.attn.c_proj",                                      # gpt2
             "transformer.h.{bid}.mixer.out_proj",                       # phi2
             "model.layers.layers.{bid}.self_attn.o_proj",               # plamo
@@ -373,6 +375,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_EXP_PROBS_B: (
             "model.layers.{bid}.mlp.gate.e_score_correction",        # deepseek-v3 dots1
             "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
+            "model.layers.{bid}.mlp.gate.expert_bias",               # bailingmoe2
             "model.layers.{bid}.feed_forward.expert_bias",           # lfm2moe
         ),
@@ -549,6 +552,7 @@ class TensorNameMap:
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
             "model.layers.{bid}.self_attn.q_layernorm",     # persimmon
             "model.layers.{bid}.self_attn.query_layernorm", # hunyuan
+            "model.layers.{bid}.attention.query_layernorm", # bailingmoe2
             "model.layers.{bid}.self_attn.q_norm",          # cohere olmoe chameleon olmo2
             "layers.{bid}.self_attn.q_norm",                # embeddinggemma
             "transformer.blocks.{bid}.attn.q_ln",           # sea-lion
@@ -563,6 +567,7 @@ class TensorNameMap:
             "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
             "model.layers.{bid}.self_attn.k_layernorm",   # persimmon
             "model.layers.{bid}.self_attn.key_layernorm", # hunyuan
+            "model.layers.{bid}.attention.key_layernorm", # bailingmoe2
             "model.layers.{bid}.self_attn.k_norm",        # cohere olmoe chameleon olmo2
             "layers.{bid}.self_attn.k_norm",              # embeddinggemma
             "transformer.blocks.{bid}.attn.k_ln",         # sea-lion
@@ -584,6 +589,7 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
             "encoder.layer.{bid}.mlp.layernorm",          # jina-bert-v2
             "encoder.layer.{bid}.layer_norm_2",           # jina-v2-code
+            "model.layers.{bid}.final_layernorm",         # bailingmoe2
         ),
 
         MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
```
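These aliases feed map_tensor_name in the converter, which resolves a checkpoint name to its GGUF equivalent. A small sketch of the lookup (block count arbitrary; assuming the gguf-py TensorNameMap API used by convert_hf_to_gguf.py):

```python
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.BAILINGMOE2, 20)  # 20 blocks, arbitrary
name = tmap.get_name("model.layers.3.attention.query_key_value.weight",
                     try_suffixes=(".weight", ".bias"))
print(name)  # expected: blk.3.attn_qkv.weight
```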

src/llama-arch.cpp

Lines changed: 35 additions & 0 deletions

```diff
@@ -86,6 +86,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
     { LLM_ARCH_PLM,              "plm"              },
     { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
+    { LLM_ARCH_BAILINGMOE2,      "bailingmoe2"      },
     { LLM_ARCH_DOTS1,            "dots1"            },
     { LLM_ARCH_ARCEE,            "arcee"            },
     { LLM_ARCH_ERNIE4_5,         "ernie4_5"         },
@@ -136,6 +137,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_COUNT,            "%s.expert_count"            },
     { LLM_KV_EXPERT_USED_COUNT,       "%s.expert_used_count"       },
     { LLM_KV_EXPERT_SHARED_COUNT,     "%s.expert_shared_count"     },
+    { LLM_KV_EXPERT_GROUP_COUNT,      "%s.expert_group_count"      },
+    { LLM_KV_EXPERT_GROUP_USED_COUNT, "%s.expert_group_used_count" },
     { LLM_KV_EXPERT_WEIGHTS_SCALE,    "%s.expert_weights_scale"    },
     { LLM_KV_EXPERT_WEIGHTS_NORM,     "%s.expert_weights_norm"     },
     { LLM_KV_EXPERT_GATING_FUNC,      "%s.expert_gating_func"      },
@@ -1979,6 +1982,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_BAILINGMOE2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,             "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,            "output_norm" },
+            { LLM_TENSOR_OUTPUT,                 "output" },
+            { LLM_TENSOR_ATTN_NORM,              "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q_NORM,            "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,            "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_QKV,               "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,               "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP,           "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B,        "blk.%d.exp_probs_b" },
+            { LLM_TENSOR_FFN_NORM,               "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,               "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,               "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,                 "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXPS,          "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,          "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,            "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,         "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,         "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,           "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_NEXTN_EH_PROJ,          "blk.%d.nextn.eh_proj" },
+            { LLM_TENSOR_NEXTN_EMBED_TOKENS,     "blk.%d.nextn.embed_tokens" },
+            { LLM_TENSOR_NEXTN_ENORM,            "blk.%d.nextn.enorm" },
+            { LLM_TENSOR_NEXTN_HNORM,            "blk.%d.nextn.hnorm" },
+            { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
+            { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
+            { LLM_TENSOR_LAYER_OUT_NORM,         "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_DOTS1,
         {
```

src/llama-arch.h

Lines changed: 3 additions & 0 deletions

```diff
@@ -90,6 +90,7 @@ enum llm_arch {
     LLM_ARCH_WAVTOKENIZER_DEC,
     LLM_ARCH_PLM,
     LLM_ARCH_BAILINGMOE,
+    LLM_ARCH_BAILINGMOE2,
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
     LLM_ARCH_ERNIE4_5,
@@ -140,6 +141,8 @@ enum llm_kv {
     LLM_KV_EXPERT_COUNT,
     LLM_KV_EXPERT_USED_COUNT,
     LLM_KV_EXPERT_SHARED_COUNT,
+    LLM_KV_EXPERT_GROUP_COUNT,
+    LLM_KV_EXPERT_GROUP_USED_COUNT,
     LLM_KV_EXPERT_WEIGHTS_SCALE,
     LLM_KV_EXPERT_WEIGHTS_NORM,
     LLM_KV_EXPERT_GATING_FUNC,
```
