3 changes: 3 additions & 0 deletions checkpoint
@@ -0,0 +1,3 @@
model_checkpoint_path: "model.ckpt-10"
all_model_checkpoint_paths: "model.ckpt-5"
all_model_checkpoint_paths: "model.ckpt-10"
28 changes: 27 additions & 1 deletion deepmd/tf/entrypoints/train.py
@@ -8,11 +8,16 @@
import json
import logging
import time
from pathlib import (
Path,
)
from typing import (
Any,
Optional,
)

import h5py

from deepmd.common import (
j_loader,
)
@@ -46,6 +51,9 @@
from deepmd.utils.data_system import (
get_data,
)
from deepmd.utils.path import (
DPPath,
)

__all__ = ["train"]

@@ -229,6 +237,19 @@ def _do_work(
# setup data modifier
modifier = get_modifier(jdata["model"].get("modifier", None))

# extract stat_file from training parameters
stat_file_path = None
if not is_compress:
stat_file_raw = jdata["training"].get("stat_file", None)
if stat_file_raw is not None and run_opt.is_chief:
if not Path(stat_file_raw).exists():
if stat_file_raw.endswith((".h5", ".hdf5")):
with h5py.File(stat_file_raw, "w") as f:
pass
else:
Path(stat_file_raw).mkdir()
stat_file_path = DPPath(stat_file_raw, "a")

# decouple the training data from the model compress process
train_data = None
valid_data = None
@@ -261,7 +282,12 @@
origin_type_map = get_data(
jdata["training"]["training_data"], rcut, None, modifier
).get_type_map()
model.build(train_data, stop_batch, origin_type_map=origin_type_map)
model.build(
train_data,
stop_batch,
origin_type_map=origin_type_map,
stat_file_path=stat_file_path,
)

if not is_compress:
# train the model with the provided systems in a cyclic way
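For context, the new block in _do_work only activates when the training section of the input JSON carries a "stat_file" key. A minimal sketch of such an input, written as a Python dict: only the "stat_file" key is taken from this change; every file and system name below is illustrative.

# Hypothetical training section of the input JSON, shown as a Python dict.
jdata = {
    "training": {
        "stat_file": "stats.hdf5",  # cache for data statistics (.h5/.hdf5 file or a directory)
        "training_data": {"systems": ["./data/water"]},  # hypothetical system path
    },
}

# Mirroring the added logic: on the chief rank an empty HDF5 file (or a
# directory for non-HDF5 paths) is created on first use, and the path is then
# wrapped as DPPath(stat_file_raw, "a") so statistics can be appended later.
stat_file_raw = jdata["training"].get("stat_file", None)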
2 changes: 1 addition & 1 deletion deepmd/tf/model/dos.py
@@ -90,7 +90,7 @@ def get_numb_aparam(self) -> int:
"""Get the number of atomic parameters."""
return self.numb_aparam

def data_stat(self, data) -> None:
def data_stat(self, data, stat_file_path=None) -> None:
all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
m_all_stat = merge_sys_stat(all_stat)
self._compute_input_stat(
45 changes: 39 additions & 6 deletions deepmd/tf/model/ener.py
@@ -21,6 +21,9 @@
from deepmd.tf.utils.spin import (
Spin,
)
from deepmd.tf.utils.stat import (
compute_output_stats,
)
from deepmd.tf.utils.type_embed import (
TypeEmbedNet,
)
@@ -135,13 +138,15 @@ def get_numb_aparam(self) -> int:
"""Get the number of atomic parameters."""
return self.numb_aparam

def data_stat(self, data) -> None:
def data_stat(self, data, stat_file_path=None) -> None:
all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
m_all_stat = merge_sys_stat(all_stat)
self._compute_input_stat(
m_all_stat, protection=self.data_stat_protect, mixed_type=data.mixed_type
)
self._compute_output_stat(all_stat, mixed_type=data.mixed_type)
self._compute_output_stat(
all_stat, mixed_type=data.mixed_type, stat_file_path=stat_file_path
)
# self.bias_atom_e = data.compute_energy_shift(self.rcond)

def _compute_input_stat(self, all_stat, protection=1e-2, mixed_type=False) -> None:
@@ -167,11 +172,39 @@ def _compute_input_stat(self, all_stat, protection=1e-2, mixed_type=False) -> None:
)
self.fitting.compute_input_stats(all_stat, protection=protection)

def _compute_output_stat(self, all_stat, mixed_type=False) -> None:
if mixed_type:
self.fitting.compute_output_stats(all_stat, mixed_type=mixed_type)
def _compute_output_stat(
self, all_stat, mixed_type=False, stat_file_path=None
) -> None:
if stat_file_path is not None:
# Use the new stat functionality with file save/load
# Add type_map subdirectory for consistency with PyTorch backend
if self.type_map is not None:
# descriptors and fitting net with different type_map
# should not share the same parameters
stat_file_path = stat_file_path / " ".join(self.type_map)

# Merge system stats for compatibility
m_all_stat = merge_sys_stat(all_stat)

bias_out, std_out = compute_output_stats(
m_all_stat,
self.ntypes,
keys=["energy"],
stat_file_path=stat_file_path,
rcond=getattr(self, "rcond", None),
mixed_type=mixed_type,
)

# Set the computed bias and std in the fitting object
if "energy" in bias_out:
self.fitting.bias_atom_e = bias_out["energy"]

else:
self.fitting.compute_output_stats(all_stat)
# Use the original computation method
if mixed_type:
self.fitting.compute_output_stats(all_stat, mixed_type=mixed_type)
else:
self.fitting.compute_output_stats(all_stat)

def build(
self,
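The key behavioural change in EnerModel._compute_output_stat is that, when a stat file is supplied, the per-type energy bias comes from compute_output_stats (imported above from deepmd.tf.utils.stat) rather than the fitting net's own routine. The sketch below only illustrates the least-squares idea behind such a bias fit; the function and array names are the editor's, not the library's.

import numpy as np

def least_squares_energy_bias(type_counts, energies, rcond=None):
    # type_counts: (nframes, ntypes) atom counts per frame
    # energies:    (nframes,) total energy per frame
    # Solves type_counts @ bias ~= energies in the least-squares sense, which
    # is the usual way a per-type energy shift (bias_atom_e) is obtained.
    bias, _, _, _ = np.linalg.lstsq(type_counts, energies, rcond=rcond)
    return bias  # shape (ntypes,)

# toy data: three frames of a two-type system
counts = np.array([[2.0, 1.0], [1.0, 2.0], [3.0, 0.0]])
etot = np.array([-10.5, -11.2, -13.8])
print(least_squares_energy_bias(counts, etot))

The rcond keyword here mirrors the rcond=getattr(self, "rcond", None) argument forwarded in the diff.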
2 changes: 1 addition & 1 deletion deepmd/tf/model/frozen.py
@@ -200,7 +200,7 @@ def get_rcut(self):
def get_ntypes(self) -> int:
return self.model.get_ntypes()

def data_stat(self, data) -> None:
def data_stat(self, data, stat_file_path=None) -> None:
pass

def init_variables(
2 changes: 1 addition & 1 deletion deepmd/tf/model/linear.py
@@ -90,7 +90,7 @@ def get_ntypes(self) -> int:
raise ValueError("Models have different ntypes")
return self.models[0].get_ntypes()

def data_stat(self, data) -> None:
def data_stat(self, data, stat_file_path=None) -> None:
for model in self.models:
model.data_stat(data)

2 changes: 1 addition & 1 deletion deepmd/tf/model/model.py
@@ -458,7 +458,7 @@ def get_ntypes(self) -> int:
"""Get the number of types."""

@abstractmethod
def data_stat(self, data: dict):
def data_stat(self, data: dict, stat_file_path=None):
"""Data staticis."""

def get_feed_dict(
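Most of the remaining model classes (dos, frozen, linear, pairwise_dprc, tensor) only grow the extra keyword without using it, so the abstract interface above is the natural place to describe it. A hedged sketch of how the updated signature could be documented; the docstring wording is illustrative, not taken from the diff.

from abc import ABC, abstractmethod

class ModelSketch(ABC):
    # Hypothetical stand-in for deepmd.tf.model.model.Model, shown only to
    # document the updated data_stat interface.
    @abstractmethod
    def data_stat(self, data: dict, stat_file_path=None):
        """Compute data statistics, optionally caching them.

        data : dict
            The training data.
        stat_file_path : DPPath, optional
            Where computed statistics are written to and reloaded from;
            None keeps the previous behaviour (recompute on every run).
        """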
2 changes: 1 addition & 1 deletion deepmd/tf/model/pairwise_dprc.py
@@ -317,7 +317,7 @@ def get_rcut(self):
def get_ntypes(self) -> int:
return self.ntypes

def data_stat(self, data) -> None:
def data_stat(self, data, stat_file_path=None) -> None:
self.qm_model.data_stat(data)
self.qmmm_model.data_stat(data)

2 changes: 1 addition & 1 deletion deepmd/tf/model/tensor.py
@@ -82,7 +82,7 @@ def get_sel_type(self):
def get_out_size(self):
return self.fitting.get_out_size()

def data_stat(self, data) -> None:
def data_stat(self, data, stat_file_path=None) -> None:
all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
m_all_stat = merge_sys_stat(all_stat)
self._compute_input_stat(m_all_stat, protection=self.data_stat_protect)
11 changes: 9 additions & 2 deletions deepmd/tf/train/trainer.py
@@ -170,7 +170,14 @@ def get_lr_and_coef(lr_param):
self.ckpt_meta = None
self.model_type = None

def build(self, data=None, stop_batch=0, origin_type_map=None, suffix="") -> None:
def build(
self,
data=None,
stop_batch=0,
origin_type_map=None,
suffix="",
stat_file_path=None,
) -> None:
self.ntypes = self.model.get_ntypes()
self.stop_batch = stop_batch

@@ -209,7 +216,7 @@ def build(self, data=None, stop_batch=0, origin_type_map=None, suffix="") -> None:
# self.saver.restore (in self._init_session) will restore avg and std variables, so data_stat is useless
# init_from_frz_model will restore data_stat variables in `init_variables` method
log.info("data stating... (this step may take long time)")
self.model.data_stat(data)
self.model.data_stat(data, stat_file_path=stat_file_path)

# config the init_frz_model command
if self.run_opt.init_mode == "init_from_frz_model":
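Putting the pieces together, the new keyword travels straight down the call chain: _do_work in train.py prepares the DPPath (or leaves it None), hands it to the trainer, and the trainer forwards it to the model's statistics step. A hedged usage sketch, assuming "trainer" is the trainer instance constructed in train.py and using only the keywords visible in the diff; variable names are illustrative.

# Hypothetical call sequence mirroring the updated build() signature.
trainer.build(
    data=train_data,
    stop_batch=stop_batch,
    origin_type_map=origin_type_map,
    suffix="",
    stat_file_path=stat_file_path,  # new keyword introduced by this change
)
# Inside build(), the path is forwarded unchanged:
#     self.model.data_stat(data, stat_file_path=stat_file_path)
# and EnerModel._compute_output_stat then decides whether to use the cached
# statistics (stat_file_path is not None) or the original fitting-net routine.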