Commit 7a8d6c0

[Fix] Fix bug of clean_param_name
The previous `clean_param_name` only matched "._checkpoint_wrapped_module." with a leading dot. However, for layers wrapped with the checkpoint wrapper, parameter names that start with "_checkpoint_wrapped_module" (no leading dot) were never cleaned because that prefix does not match.

ghstack-source-id: 220732d
Pull-Request: #1452
1 parent b400515 commit 7a8d6c0
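
Both wrapper markers come from PyTorch: activation checkpointing inserts "_checkpoint_wrapped_module" into module paths, and `torch.compile` inserts "_orig_mod". Whether the checkpoint marker carries a leading dot depends on where the iteration starts. A minimal sketch (illustrative, not part of this commit) that reproduces both shapes with PyTorch's `checkpoint_wrapper`:

    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

    block = checkpoint_wrapper(nn.Sequential(nn.Linear(4, 4)))

    # Iterating the wrapped block itself: names start with the wrapper marker and
    # have no leading dot, so the old check `"._checkpoint_wrapped_module." in name`
    # never fires.
    print([name for name, _ in block.named_modules()])
    # ['', '_checkpoint_wrapped_module', '_checkpoint_wrapped_module.0']

    # Iterating from a parent module: the marker appears dot-separated, which is
    # the only shape the old replacement handled.
    model = nn.Sequential(block)
    print([name for name, _ in model.named_parameters()])
    # ['0._checkpoint_wrapped_module.0.weight', '0._checkpoint_wrapped_module.0.bias']

Dropping the leading dot from the search string (and replacing with an empty string instead of ".") makes both shapes collapse to the plain module path.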

7 files changed: +35 -14 lines changed

xtuner/v1/engine/train_engine.py

Lines changed: 5 additions & 4 deletions
@@ -411,12 +411,13 @@ def step_optimizer(self, grad_norm):
         self.optimizer.zero_grad()
         return grad_norm
 
+    # TODO: Should be removed
     @staticmethod
     def clean_param_name(name: str) -> str:
-        if "._checkpoint_wrapped_module." in name:
-            name = name.replace("._checkpoint_wrapped_module.", ".")
-        if "._orig_mod." in name:
-            name = name.replace("._orig_mod.", ".")
+        if "_checkpoint_wrapped_module." in name:
+            name = name.replace("_checkpoint_wrapped_module.", "")
+        if "_orig_mod." in name:
+            name = name.replace("_orig_mod.", "")
         return name
 
     def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16):

xtuner/v1/model/base.py

Lines changed: 14 additions & 6 deletions
@@ -85,6 +85,9 @@ class XTunerBaseModelConfig(PydanticBaseModel):
     def hf_config(self) -> PretrainedConfig | None:
         raise NotImplementedError
 
+    def build(self):
+        raise NotImplementedError
+
 
 DEFAULT_FLOAT8_CFG = {
     "xtuner.v1.float8.fsdp_utils.tensor_to_per_block_fp8_scales": TorchCompileOption(fullgraph=True),
@@ -250,7 +253,11 @@ def __init__(self, config: XTunerBaseModelConfig):
     def set_hf(self, hf_path: str | Path):
         self._hf_path = Path(hf_path)
 
-    def from_hf(self, hf_path: str | Path, strict: bool = True) -> tuple:
+    def from_hf(
+        self, hf_path: str | Path, strict: bool = True
+    ) -> tuple[
+        Annotated[set[str], "loaded keys"], Annotated[set[str], "unloaded keys"], Annotated[set[str], "missing keys"]
+    ]:
         self._hf_path = Path(hf_path)
 
         if isinstance(hf_path, Path):
@@ -348,7 +355,7 @@ def init_weights(self):
         from xtuner.v1.utils import default_init_weights
 
         initialized_params = default_init_weights(self)
-        if missing := {name for name, _ in self.named_parameters()} - initialized_params:
+        if missing := {self._clean_param_name(name) for name, _ in self.named_parameters()} - initialized_params:
             raise RuntimeError(f"{missing} is not initialized")
 
     def _init_load_spec(self) -> None:
@@ -797,11 +804,12 @@ def _get_same_hf_param(
         if buffer_tensor_list:
             yield buffer_name_list, buffer_tensor_list
 
+    # TODO: Using `xtuenr.v1.utils.misc.clean_param_name`
     def _clean_param_name(self, name: str) -> str:
-        if "._checkpoint_wrapped_module." in name:
-            name = name.replace("._checkpoint_wrapped_module.", ".")
-        if "._orig_mod." in name:
-            name = name.replace("._orig_mod.", ".")
+        if "_checkpoint_wrapped_module." in name:
+            name = name.replace("_checkpoint_wrapped_module.", "")
+        if "_orig_mod." in name:
+            name = name.replace("_orig_mod.", "")
         return name
 
     def _group_param_by_load_spec(self, load_enum: LoadEnum):

xtuner/v1/model/compose/qwen3_vl/modeling_vision.py

Lines changed: 1 addition & 0 deletions
@@ -276,6 +276,7 @@ def init_weights(self):
 
         for layer_idx, layer in enumerate(self.blocks):
            for name, module in layer.named_modules():
+                name = self._clean_param_name(name)
                 if isinstance(module, nn.Linear):
                     init_params(module.weight,
                                 partial(torch.nn.init.normal_, mean=0.0, std=self.config.initializer_range))

xtuner/v1/train/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@
 from xtuner.v1.engine import LossLog, OtherLog, TrainEngine
 from xtuner.v1.engine.vision_compose_train_engine import VisionComposeTrainEngine
 from xtuner.v1.loss import CELossConfig, CELossContext
-from xtuner.v1.model.base import ModelItem, TransformerConfig, XTunerBaseModelConfig
+from xtuner.v1.model.base import ModelItem, XTunerBaseModelConfig
 from xtuner.v1.model.compose.base import BaseComposeConfig
 from xtuner.v1.model.moe.moe import MoEConfig
 from xtuner.v1.patch import patch_default_save_plan

xtuner/v1/utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
     XTUNER_DETERMINISTIC,
     FunctionEnum,
     SharedMemory,
+    clean_param_name,
     get_function_type,
     get_padding_length,
     is_hf_model_path,
@@ -57,4 +58,5 @@
     "monkey_unpatch_torch_reductions",
     "ray_method",
     "profile_time",
+    "clean_param_name",
 ]

xtuner/v1/utils/init_weight.py

Lines changed: 4 additions & 3 deletions
@@ -5,7 +5,8 @@
 import torch.nn as nn
 from torch.distributed.tensor import DTensor, distribute_tensor
 
-from xtuner.v1.utils import get_device
+from .device import get_device
+from .misc import clean_param_name
 
 
 DEVICE = get_device()
@@ -51,15 +52,15 @@ def _default_init_atom(name: str, module: nn.Module):
         if hasattr(module, "bias") and module.bias is not None:
             bias = cast(torch.Tensor, module.bias)
             init_params(bias, nn.init.zeros_)
-            initialized_params.add(f"{name}.bias")
+            initialized_params.add(clean_param_name(f"{name}.bias"))
 
         if hasattr(module, "weight") and module.weight is not None:
             weight = cast(torch.Tensor, module.weight)
             if "norm" in name:
                 init_params(weight, nn.init.ones_)
             else:
                 init_params(weight, partial(nn.init.normal_, mean=0.0, std=0.02))
-            initialized_params.add(f"{name}.weight")
+            initialized_params.add(clean_param_name(f"{name}.weight"))
 
     _init_weights_recursive("", module)
     return initialized_params
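
The `init_weight.py` change is the other half of the `init_weights` check in the `base.py` hunk above: the keys recorded by `default_init_weights` (built from `named_modules()` names) and the keys derived from `named_parameters()` must be normalized the same way, otherwise checkpoint-wrapped layers would be reported as uninitialized. A rough, self-contained illustration with a hypothetical parameter name:

    from xtuner.v1.utils import clean_param_name  # exported by this commit

    # Hypothetical fully qualified name from a checkpoint-wrapped block.
    raw = {"blocks.0._checkpoint_wrapped_module.attn.weight"}
    cleaned = {clean_param_name(n) for n in raw}          # {'blocks.0.attn.weight'}

    # If only one side were cleaned, the set difference would be non-empty and
    # init_weights would raise "... is not initialized".
    print(cleaned - raw)                                  # {'blocks.0.attn.weight'}
    # With both sides cleaned, as after this commit, nothing is reported missing.
    print(cleaned - {clean_param_name(n) for n in raw})   # set()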

xtuner/v1/utils/misc.py

Lines changed: 8 additions & 0 deletions
@@ -192,3 +192,11 @@ def get_function_full_qualname(function: FunctionType) -> str:
 
     full_qualname = f"{module_name}.{qualname}"
     return full_qualname
+
+
+def clean_param_name(name: str) -> str:
+    if "_checkpoint_wrapped_module." in name:
+        name = name.replace("_checkpoint_wrapped_module.", "")
+    if "_orig_mod." in name:
+        name = name.replace("_orig_mod.", "")
+    return name
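
With the helper now shared from `xtuner.v1.utils`, its behaviour on the three name shapes it is meant to normalize looks like this (the names themselves are made up for illustration):

    from xtuner.v1.utils import clean_param_name

    # Checkpoint wrapper at the root of the iterated module (no leading dot).
    print(clean_param_name("_checkpoint_wrapped_module.attn.weight"))            # attn.weight
    # Checkpoint wrapper nested under a parent module (dot-separated).
    print(clean_param_name("layers.0._checkpoint_wrapped_module.attn.weight"))   # layers.0.attn.weight
    # torch.compile keeps the original module under `_orig_mod`.
    print(clean_param_name("_orig_mod.layers.0.mlp.weight"))                     # layers.0.mlp.weight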
