diff --git a/diffsynth_engine/models/wan/wan_text_encoder.py b/diffsynth_engine/models/wan/wan_text_encoder.py
index 496bd2c..9696532 100644
--- a/diffsynth_engine/models/wan/wan_text_encoder.py
+++ b/diffsynth_engine/models/wan/wan_text_encoder.py
@@ -198,22 +198,22 @@ def __init__(self, num_encoder_layers: int = 24):

     def _from_diffusers(self, state_dict):
         rename_dict = {
-            "enc.output_norm.weight": "norm.weight",
-            "token_embd.weight": "token_embedding.weight",
+            "shared.weight": "token_embedding.weight",
+            "encoder.final_layer_norm.weight": "norm.weight",
         }
         for i in range(self.num_encoder_layers):
             rename_dict.update(
                 {
-                    f"enc.blk.{i}.attn_q.weight": f"blocks.{i}.attn.q.weight",
-                    f"enc.blk.{i}.attn_k.weight": f"blocks.{i}.attn.k.weight",
-                    f"enc.blk.{i}.attn_v.weight": f"blocks.{i}.attn.v.weight",
-                    f"enc.blk.{i}.attn_o.weight": f"blocks.{i}.attn.o.weight",
-                    f"enc.blk.{i}.ffn_up.weight": f"blocks.{i}.ffn.fc1.weight",
-                    f"enc.blk.{i}.ffn_down.weight": f"blocks.{i}.ffn.fc2.weight",
-                    f"enc.blk.{i}.ffn_gate.weight": f"blocks.{i}.ffn.gate.0.weight",
-                    f"enc.blk.{i}.attn_norm.weight": f"blocks.{i}.norm1.weight",
-                    f"enc.blk.{i}.ffn_norm.weight": f"blocks.{i}.norm2.weight",
-                    f"enc.blk.{i}.attn_rel_b.weight": f"blocks.{i}.pos_embedding.embedding.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.q.weight": f"blocks.{i}.attn.q.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.k.weight": f"blocks.{i}.attn.k.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.v.weight": f"blocks.{i}.attn.v.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.o.weight": f"blocks.{i}.attn.o.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.relative_attention_bias.weight": f"blocks.{i}.pos_embedding.embedding.weight",
+                    f"encoder.block.{i}.layer.0.layer_norm.weight": f"blocks.{i}.norm1.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight": f"blocks.{i}.ffn.gate.0.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight": f"blocks.{i}.ffn.fc1.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight": f"blocks.{i}.ffn.fc2.weight",
+                    f"encoder.block.{i}.layer.1.layer_norm.weight": f"blocks.{i}.norm2.weight",
                 }
             )

@@ -224,7 +224,7 @@ def _from_diffusers(self, state_dict):
         return new_state_dict

     def convert(self, state_dict):
-        if "enc.output_norm.weight" in state_dict:
+        if "encoder.final_layer_norm.weight" in state_dict:
             logger.info("use diffusers format state dict")
             return self._from_diffusers(state_dict)
         return state_dict
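
Review note: the old keys followed the GGUF naming scheme (`enc.blk.{i}.attn_q.weight`, `token_embd.weight`); the new keys match the layout used by `transformers` T5/UMT5 encoder checkpoints (`encoder.block.{i}.layer.0.SelfAttention.*`, `encoder.block.{i}.layer.1.DenseReluDense.*`). A minimal sketch of how the renaming behaves is below; the pass-through behavior for unmapped keys is an assumption, since the loop that builds `new_state_dict` sits outside this hunk.

```python
# Sketch only: mirrors the rename logic of _from_diffusers for one layer.
# Tensor shapes are dummies; real shapes come from the checkpoint.
import torch

rename_dict = {
    "shared.weight": "token_embedding.weight",
    "encoder.final_layer_norm.weight": "norm.weight",
    "encoder.block.0.layer.0.SelfAttention.q.weight": "blocks.0.attn.q.weight",
    "encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "blocks.0.ffn.gate.0.weight",
    # ... remaining per-layer entries as in the patch ...
}

state_dict = {
    "shared.weight": torch.zeros(8, 4),
    "encoder.block.0.layer.0.SelfAttention.q.weight": torch.zeros(4, 4),
}
# Assumed pass-through for keys not in rename_dict (not shown in this patch).
new_state_dict = {rename_dict.get(k, k): v for k, v in state_dict.items()}
assert "token_embedding.weight" in new_state_dict
assert "blocks.0.attn.q.weight" in new_state_dict
```

One consequence worth flagging: `convert` now detects the external format via `encoder.final_layer_norm.weight`, so checkpoints using the old GGUF-style key `enc.output_norm.weight` will no longer be routed through `_from_diffusers`.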