@@ -892,8 +892,8 @@ def get_vocab_base_pre(self, tokenizer) -> str:
892892 # ref: https://huggingface.co/JetBrains/Mellum-4b-base
893893 res = "mellum"
894894 if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206" :
895- # ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
896- res = "llada-moe "
895+ # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
896+ res = "bailingmoe2 "
897897 if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e" :
898898 # ref: https://huggingface.co/ibm-granite/granite-docling-258M
899899 res = "granite-docling"
@@ -8085,6 +8085,103 @@ def prepare_tensors(self):
80858085 raise ValueError (f"Unprocessed experts: { experts } " )
80868086
80878087
@ModelBase.register("BailingMoeV2ForCausalLM")
class BailingMoeV2Model(TextModel):
    """Converter registered for the ``BailingMoeV2ForCausalLM`` architecture.

    Emits the ``BAILINGMOE2`` GGUF architecture. Per-expert MLP weights are
    buffered block by block and written out as stacked tensors (one each for
    ``down_proj``, ``gate_proj`` and ``up_proj``).
    """

    model_arch = gguf.MODEL_ARCH.BAILINGMOE2

    def __init__(self, *args, **kwargs):
        """Initialise the base converter and account for next-token-prediction layers.

        When ``num_nextn_predict_layers`` is present and non-zero in the
        hparams, ``block_count`` is extended by that amount and the tensor
        name map is rebuilt for the new block count.
        """
        super().__init__(*args, **kwargs)
        if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0):
            self.block_count = self.hparams["num_hidden_layers"] + nextn_layers
            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

    def set_vocab(self):
        """Write the vocabulary using the GPT-2 style helper of the base class."""
        self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        """Write the architecture hyper-parameters to the GGUF writer.

        Raises:
            KeyError: if a required hparams entry is missing.
            ValueError: if ``score_function`` is neither ``"sigmoid"`` nor
                ``"softmax"``.
        """
        super().set_gguf_parameters()
        hparams = self.hparams

        # Fall back to the per-head width derived from the hidden size when
        # the config does not carry an explicit `head_dim`.
        if (rope_dim := hparams.get("head_dim")) is None:
            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]

        # Only a fraction of each head is rotated; 0.5 is used when the config
        # does not specify `partial_rotary_factor`.
        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))

        # `rope_scaling` may be absent or null; the scaling kind may be stored
        # under either "rope_type" or "type".
        rope_scaling = self.hparams.get("rope_scaling") or {}
        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
        else:
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)

        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
        # Shared-expert width: explicit config value if present, otherwise the
        # per-expert width multiplied by the number of shared experts.
        self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
        self.gguf_writer.add_expert_count(hparams["num_experts"])
        self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
        self.gguf_writer.add_expert_group_count(hparams["n_group"])
        self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])

        if hparams["score_function"] == "sigmoid":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
        elif hparams["score_function"] == "softmax":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
        else:
            raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")

        # Written whenever the key exists, even if its value is 0.
        if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
            self.gguf_writer.add_nextn_predict_layers(nextn_layers)

    # Per-block buffers of not-yet-merged expert tensors, keyed by their
    # original tensor name; allocated lazily on the first expert tensor seen.
    _experts: list[dict[str, Tensor]] | None = None

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Map one source tensor to zero or more output ``(name, tensor)`` pairs.

        Tensors whose name contains ``"mlp.experts"`` are buffered per block.
        Once ``num_experts * 3`` tensors have been collected for a block, the
        experts are stacked along a new leading dimension for each of
        ``down_proj``, ``gate_proj`` and ``up_proj`` and returned under their
        mapped names; until then an empty list is returned. Any other tensor
        is returned immediately under its mapped name.
        """
        if "mlp.experts" in name:
            n_experts = self.hparams["num_experts"]
            assert bid is not None

            tensors: list[tuple[str, Tensor]] = []

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:
                # merge the experts into a single 3d tensor
                for w_name in ["down_proj", "gate_proj", "up_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        # Must match the source tensor name exactly, since it
                        # is the key under which the tensor was buffered.
                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
                        datas.append(self._experts[bid][ename])
                        # Drop the buffered entry so prepare_tensors() can
                        # detect experts that were never merged.
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)

                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

                    new_name = self.map_tensor_name(merged_name)

                    tensors.append((new_name, data_torch))

            return tensors

        # Rename a trailing ".expert_bias" to ".expert_bias.bias" before the
        # name is passed to the tensor-name mapping.
        if name.endswith(".expert_bias"):
            name = name.replace(".expert_bias", ".expert_bias.bias")

        return [(self.map_tensor_name(name), data_torch)]

    def prepare_tensors(self):
        """Run the base tensor preparation, then verify no expert tensor was left unmerged.

        Raises:
            ValueError: if any buffered expert tensor remains after processing.
        """
        super().prepare_tensors()

        if self._experts is not None:
            # flatten `list[dict[str, Tensor]]` into `list[str]`
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")
8183+
8184+
80888185@ModelBase .register ("GroveMoeForCausalLM" , "modeling_grove_moe.GroveMoeForCausalLM" )
80898186class GroveMoeModel (TextModel ):
80908187 model_arch = gguf .MODEL_ARCH .GROVEMOE
0 commit comments