28 changes: 27 additions & 1 deletion deepmd/tf/entrypoints/train.py
@@ -8,11 +8,16 @@
 import json
 import logging
 import time
+from pathlib import (
+    Path,
+)
 from typing import (
     Any,
     Optional,
 )
 
+import h5py
+
 from deepmd.common import (
     j_loader,
 )
@@ -46,6 +51,9 @@
 from deepmd.utils.data_system import (
     get_data,
 )
+from deepmd.utils.path import (
+    DPPath,
+)
 
 __all__ = ["train"]
 
@@ -229,6 +237,19 @@ def _do_work(
     # setup data modifier
     modifier = get_modifier(jdata["model"].get("modifier", None))
 
+    # extract stat_file from training parameters
+    stat_file_path = None
+    if not is_compress:
+        stat_file_raw = jdata["training"].get("stat_file", None)
+        if stat_file_raw is not None and run_opt.is_chief:
+            if not Path(stat_file_raw).exists():
+                if stat_file_raw.endswith((".h5", ".hdf5")):
+                    with h5py.File(stat_file_raw, "w") as f:
+                        pass
+                else:
+                    Path(stat_file_raw).mkdir()
+            stat_file_path = DPPath(stat_file_raw, "a")
+
     # decouple the training data from the model compress process
     train_data = None
     valid_data = None
@@ -261,7 +282,12 @@
         origin_type_map = get_data(
             jdata["training"]["training_data"], rcut, None, modifier
         ).get_type_map()
-    model.build(train_data, stop_batch, origin_type_map=origin_type_map)
+    model.build(
+        train_data,
+        stop_batch,
+        origin_type_map=origin_type_map,
+        stat_file_path=stat_file_path,
+    )
 
     if not is_compress:
         # train the model with the provided systems in a cyclic way
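For orientation, the new branch in `_do_work` is driven by a single key in the training section of the input script: if `"stat_file"` names an `.h5`/`.hdf5` path, an empty HDF5 file is created (by the chief rank only, per `run_opt.is_chief`) and reopened through `DPPath` in append mode; any other name becomes a directory. A minimal sketch of the parsed input, where every key except `"stat_file"` is ordinary illustrative training boilerplate:

    # Hypothetical jdata as returned by j_loader(); only "stat_file" is new here.
    jdata = {
        "training": {
            "training_data": {"systems": ["data/water"]},  # illustrative path
            "numb_steps": 1000000,
            "stat_file": "stat_files.h5",  # .h5/.hdf5 -> HDF5 store; else a directory
        },
    }
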
2 changes: 1 addition & 1 deletion deepmd/tf/model/dos.py
@@ -90,7 +90,7 @@ def get_numb_aparam(self) -> int:
         """Get the number of atomic parameters."""
         return self.numb_aparam
 
-    def data_stat(self, data) -> None:
+    def data_stat(self, data, stat_file_path=None) -> None:
         all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
         m_all_stat = merge_sys_stat(all_stat)
         self._compute_input_stat(
45 changes: 39 additions & 6 deletions deepmd/tf/model/ener.py
@@ -21,6 +21,9 @@
 from deepmd.tf.utils.spin import (
     Spin,
 )
+from deepmd.tf.utils.stat import (
+    compute_output_stats,
+)
 from deepmd.tf.utils.type_embed import (
     TypeEmbedNet,
 )
@@ -135,13 +138,15 @@ def get_numb_aparam(self) -> int:
         """Get the number of atomic parameters."""
         return self.numb_aparam
 
-    def data_stat(self, data) -> None:
+    def data_stat(self, data, stat_file_path=None) -> None:
         all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
         m_all_stat = merge_sys_stat(all_stat)
         self._compute_input_stat(
             m_all_stat, protection=self.data_stat_protect, mixed_type=data.mixed_type
         )
-        self._compute_output_stat(all_stat, mixed_type=data.mixed_type)
+        self._compute_output_stat(
+            all_stat, mixed_type=data.mixed_type, stat_file_path=stat_file_path
+        )
         # self.bias_atom_e = data.compute_energy_shift(self.rcond)
 
     def _compute_input_stat(self, all_stat, protection=1e-2, mixed_type=False) -> None:
@@ -167,11 +172,39 @@ def _compute_input_stat(self, all_stat, protection=1e-2, mixed_type=False) -> None:
             )
         self.fitting.compute_input_stats(all_stat, protection=protection)
 
-    def _compute_output_stat(self, all_stat, mixed_type=False) -> None:
-        if mixed_type:
-            self.fitting.compute_output_stats(all_stat, mixed_type=mixed_type)
+    def _compute_output_stat(
+        self, all_stat, mixed_type=False, stat_file_path=None
+    ) -> None:
+        if stat_file_path is not None:
+            # Use the new stat functionality with file save/load.
+            # Add a type_map subdirectory for consistency with the PyTorch
+            # backend: descriptors and fitting nets with different type_maps
+            # should not share the same parameters.
+            if self.type_map is not None:
+                stat_file_path = stat_file_path / " ".join(self.type_map)
+
+            # Merge system stats for compatibility
+            m_all_stat = merge_sys_stat(all_stat)
+
+            bias_out, std_out = compute_output_stats(
+                m_all_stat,
+                self.ntypes,
+                keys=["energy"],
+                stat_file_path=stat_file_path,
+                rcond=getattr(self, "rcond", None),
+                mixed_type=mixed_type,
+            )
+
+            # Set the computed bias in the fitting object
+            if "energy" in bias_out:
+                self.fitting.bias_atom_e = bias_out["energy"]
+
         else:
-            self.fitting.compute_output_stats(all_stat)
+            # Use the original computation method
+            if mixed_type:
+                self.fitting.compute_output_stats(all_stat, mixed_type=mixed_type)
+            else:
+                self.fitting.compute_output_stats(all_stat)
 
     def build(
         self,
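The caching layout mirrors the PyTorch backend: output statistics are namespaced by the model's type_map so that models with different type maps never pick up each other's parameters. A rough sketch of the semantics, assuming the `DPPath` behavior from `deepmd.utils.path` (the `/` operator joins sub-paths, and an `.h5` path maps them to HDF5 groups) and that `compute_output_stats` loads a cached bias from this prefix when present and computes and saves it otherwise:

    from deepmd.utils.path import DPPath

    # train.py has already created stat_files.h5 and hands it over in append mode
    stat_file_path = DPPath("stat_files.h5", "a")

    # Per-type_map namespace: a model with type_map ["O", "H"] keeps its
    # statistics under the "O H" group, invisible to other type maps.
    type_map = ["O", "H"]  # illustrative
    stat_file_path = stat_file_path / " ".join(type_map)
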
2 changes: 1 addition & 1 deletion deepmd/tf/model/frozen.py
@@ -200,7 +200,7 @@ def get_rcut(self):
     def get_ntypes(self) -> int:
         return self.model.get_ntypes()
 
-    def data_stat(self, data) -> None:
+    def data_stat(self, data, stat_file_path=None) -> None:
         pass
 
     def init_variables(
2 changes: 1 addition & 1 deletion deepmd/tf/model/linear.py
@@ -90,7 +90,7 @@ def get_ntypes(self) -> int:
             raise ValueError("Models have different ntypes")
         return self.models[0].get_ntypes()
 
-    def data_stat(self, data) -> None:
+    def data_stat(self, data, stat_file_path=None) -> None:
         for model in self.models:
             model.data_stat(data)
 
2 changes: 1 addition & 1 deletion deepmd/tf/model/model.py
@@ -458,7 +458,7 @@ def get_ntypes(self) -> int:
         """Get the number of types."""
 
     @abstractmethod
-    def data_stat(self, data: dict):
+    def data_stat(self, data: dict, stat_file_path=None):
         """Data statistics."""
 
     def get_feed_dict(
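Because the abstract signature now carries the extra argument, every concrete model must accept it even when it has nothing to cache; a minimal sketch of the resulting contract (hypothetical subclass, not part of this diff):

    class MyModel(Model):
        def data_stat(self, data: dict, stat_file_path=None) -> None:
            # Models without cacheable statistics (frozen, linear, and
            # pairwise DPRc in this diff) accept the argument and ignore it;
            # EnerModel threads it through to compute_output_stats.
            pass
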
2 changes: 1 addition & 1 deletion deepmd/tf/model/pairwise_dprc.py
@@ -317,7 +317,7 @@ def get_rcut(self):
     def get_ntypes(self) -> int:
         return self.ntypes
 
-    def data_stat(self, data) -> None:
+    def data_stat(self, data, stat_file_path=None) -> None:
         self.qm_model.data_stat(data)
         self.qmmm_model.data_stat(data)
 
2 changes: 1 addition & 1 deletion deepmd/tf/model/tensor.py
@@ -82,7 +82,7 @@ def get_sel_type(self):
     def get_out_size(self):
         return self.fitting.get_out_size()
 
-    def data_stat(self, data) -> None:
+    def data_stat(self, data, stat_file_path=None) -> None:
         all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
         m_all_stat = merge_sys_stat(all_stat)
         self._compute_input_stat(m_all_stat, protection=self.data_stat_protect)
11 changes: 9 additions & 2 deletions deepmd/tf/train/trainer.py
@@ -170,7 +170,14 @@ def get_lr_and_coef(lr_param):
         self.ckpt_meta = None
         self.model_type = None
 
-    def build(self, data=None, stop_batch=0, origin_type_map=None, suffix="") -> None:
+    def build(
+        self,
+        data=None,
+        stop_batch=0,
+        origin_type_map=None,
+        suffix="",
+        stat_file_path=None,
+    ) -> None:
         self.ntypes = self.model.get_ntypes()
         self.stop_batch = stop_batch
 
@@ -209,7 +216,7 @@ def build(self, data=None, stop_batch=0, origin_type_map=None, suffix="") -> None:
             # self.saver.restore (in self._init_session) will restore avg and std variables, so data_stat is useless
             # init_from_frz_model will restore data_stat variables in `init_variables` method
             log.info("data stating... (this step may take long time)")
-            self.model.data_stat(data)
+            self.model.data_stat(data, stat_file_path=stat_file_path)
 
             # config the init_frz_model command
             if self.run_opt.init_mode == "init_from_frz_model":
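Taken together, the changes thread one optional argument along the existing call chain; a commented summary (names taken from this diff; the save/load behavior inside compute_output_stats is assumed from its use here):

    # deepmd/tf/entrypoints/train.py
    #   stat_file_path = DPPath(jdata["training"]["stat_file"], "a")
    #   model.build(train_data, stop_batch, origin_type_map=..., stat_file_path=stat_file_path)
    # deepmd/tf/train/trainer.py
    #   build(..., stat_file_path=None)
    #     -> self.model.data_stat(data, stat_file_path=stat_file_path)
    # deepmd/tf/model/ener.py
    #   EnerModel.data_stat -> _compute_output_stat
    #     -> compute_output_stats(..., stat_file_path=...)  # load cached bias or compute and save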