Skip to content

Commit f0cf2b9

Browse files
njzjz-bot and njzjz authored
fix(train): allow zero-step training with bias adjustment (#5477)
Problem

- `numb_steps=0` is a valid no-optimization path that should save the initial checkpoint.
- When `change_bias_after_training` is enabled, the post-training bias adjustment still ran after zero steps and evaluated learning-rate/checkpoint metadata at step `-1`.

Change

- Skip post-training bias adjustment unless at least one training step has run.
- Keep the existing zero-step initial checkpoint save path for both PyTorch and Paddle backends.
- Add PT/PD regression tests that run zero-step training with `change_bias_after_training=true` and verify the saved `*-0` checkpoint metadata.

Notes

- `python3 -m pytest ...` could not run in this workspace because pytest is not installed in the available Python environment.
- `uvx ruff check deepmd/pd/train/training.py deepmd/pt/train/training.py source/tests/pd/test_training.py source/tests/pt/test_training.py` passed.
- `uvx ruff format --check deepmd/pd/train/training.py deepmd/pt/train/training.py source/tests/pd/test_training.py source/tests/pt/test_training.py` passed.
- Closes #4988.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Bug Fixes**
  * Prevented unintended bias-adjustment during zero-step PyTorch training so the initial checkpoint is created and recorded correctly.
* **Refactor**
  * Clarified the post-training bias-adjustment conditional in Paddle for readability (no behavior change).
* **Tests**
  * Added tests for zero-step training with bias-adjustment enabled for both Paddle and PyTorch, verifying initial checkpoint creation and training metadata.
<!-- review_stack_entry_start -->
[![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/deepmodeling/deepmd-kit/pull/5477?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack)
<!-- review_stack_entry_end -->
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent 4b6506d commit f0cf2b9

4 files changed

Lines changed: 68 additions & 2 deletions

File tree

deepmd/pd/train/training.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1038,7 +1038,11 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
10381038
if JIT:
10391039
break
10401040

1041-
if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0):
1041+
if (
1042+
self.change_bias_after_training
1043+
and self.num_steps > self.start_step
1044+
and (self.rank == 0 or dist.get_rank() == 0)
1045+
):
10421046
if not self.multi_task:
10431047
self.model = model_change_out_bias(
10441048
self.model,

deepmd/pt/train/training.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1846,7 +1846,11 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
18461846
if JIT:
18471847
break
18481848

1849-
if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0):
1849+
if (
1850+
self.change_bias_after_training
1851+
and self.num_steps > self.start_step
1852+
and (self.rank == 0 or dist.get_rank() == 0)
1853+
):
18501854
if not self.multi_task:
18511855
self.model = model_change_out_bias(
18521856
self.model,

source/tests/pd/test_training.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@
99
from pathlib import (
1010
Path,
1111
)
12+
from unittest.mock import (
13+
patch,
14+
)
1215

1316
import numpy as np
17+
import paddle
1418

1519
from deepmd.pd.entrypoints.main import (
1620
get_trainer,
@@ -163,6 +167,33 @@ def setUp(self) -> None:
163167
self.config["training"]["save_freq"] = 1
164168
enable_prim(True)
165169

170+
@patch("deepmd.pd.train.training.model_change_out_bias")
171+
def test_zero_step_with_change_bias_saves_initial_checkpoint(
172+
self, mocked_change_out_bias
173+
) -> None:
174+
def keep_model(model, *_args, **_kwargs):
175+
return model
176+
177+
mocked_change_out_bias.side_effect = keep_model
178+
config = deepcopy(self.config)
179+
config["training"]["numb_steps"] = 0
180+
config["training"]["change_bias_after_training"] = True
181+
trainer = get_trainer(config)
182+
trainer.run()
183+
184+
expected_model = Path(trainer.save_ckpt + "-0.pd")
185+
self.assertEqual(expected_model, trainer.latest_model)
186+
self.assertTrue(expected_model.exists())
187+
self.assertEqual(
188+
expected_model,
189+
Path(Path("checkpoint").read_text().strip()),
190+
)
191+
checkpoint = paddle.load(str(expected_model))
192+
train_infos = checkpoint["model"]["_extra_state"]["train_infos"]
193+
self.assertEqual(0, train_infos["step"])
194+
self.assertEqual(0.0, train_infos["lr"])
195+
mocked_change_out_bias.assert_not_called()
196+
166197
def tearDown(self) -> None:
167198
DPTrainTest.tearDown(self)
168199

source/tests/pt/test_training.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,33 @@ def test_yaml_input(self) -> None:
265265
)
266266
self.assertTrue(Path("out.json").exists())
267267

268+
@patch("deepmd.pt.train.training.model_change_out_bias")
269+
def test_zero_step_with_change_bias_saves_initial_checkpoint(
270+
self, mocked_change_out_bias
271+
) -> None:
272+
def keep_model(model, *_args, **_kwargs):
273+
return model
274+
275+
mocked_change_out_bias.side_effect = keep_model
276+
config = deepcopy(self.config)
277+
config["training"]["numb_steps"] = 0
278+
config["training"]["change_bias_after_training"] = True
279+
trainer = get_trainer(config)
280+
trainer.run()
281+
282+
expected_model = Path(trainer.save_ckpt + "-0.pt")
283+
self.assertEqual(expected_model, trainer.latest_model)
284+
self.assertTrue(expected_model.exists())
285+
self.assertEqual(
286+
expected_model,
287+
Path(Path("checkpoint").read_text().strip()),
288+
)
289+
checkpoint = torch.load(expected_model, map_location="cpu", weights_only=True)
290+
train_infos = checkpoint["model"]["_extra_state"]["train_infos"]
291+
self.assertEqual(0, train_infos["step"])
292+
self.assertEqual(0.0, train_infos["lr"])
293+
mocked_change_out_bias.assert_not_called()
294+
268295
def tearDown(self) -> None:
269296
DPTrainTest.tearDown(self)
270297
for ff in ["out.json", "input.yaml"]:

0 commit comments

Comments
 (0)