Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
Optional,
)

import h5py

from deepmd.common import (
j_loader,
)
Expand Down Expand Up @@ -229,6 +231,27 @@ def _do_work(
# setup data modifier
modifier = get_modifier(jdata["model"].get("modifier", None))

# extract stat_file from training parameters
stat_file_path = None
if not is_compress:
stat_file_raw = jdata["training"].get("stat_file", None)
if stat_file_raw is not None and run_opt.is_chief:
from pathlib import (
Path,
)

from deepmd.utils.path import (
DPPath,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot This should be moved to the top of this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the imports to the top of the file. Fixed in a878838.


if not Path(stat_file_raw).exists():
if stat_file_raw.endswith((".h5", ".hdf5")):
with h5py.File(stat_file_raw, "w") as f:
pass
else:
Path(stat_file_raw).mkdir()
stat_file_path = DPPath(stat_file_raw, "a")

# decouple the training data from the model compress process
train_data = None
valid_data = None
Expand Down Expand Up @@ -261,7 +284,12 @@ def _do_work(
origin_type_map = get_data(
jdata["training"]["training_data"], rcut, None, modifier
).get_type_map()
model.build(train_data, stop_batch, origin_type_map=origin_type_map)
model.build(
train_data,
stop_batch,
origin_type_map=origin_type_map,
stat_file_path=stat_file_path,
)

if not is_compress:
# train the model with the provided systems in a cyclic way
Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/model/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_numb_aparam(self) -> int:
"""Get the number of atomic parameters."""
return self.numb_aparam

def data_stat(self, data) -> None:
def data_stat(self, data, stat_file_path=None) -> None:
all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
m_all_stat = merge_sys_stat(all_stat)
self._compute_input_stat(
Expand Down
40 changes: 34 additions & 6 deletions deepmd/tf/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,15 @@ def get_numb_aparam(self) -> int:
"""Get the number of atomic parameters."""
return self.numb_aparam

def data_stat(self, data) -> None:
def data_stat(self, data, stat_file_path=None) -> None:
all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
m_all_stat = merge_sys_stat(all_stat)
self._compute_input_stat(
m_all_stat, protection=self.data_stat_protect, mixed_type=data.mixed_type
)
self._compute_output_stat(all_stat, mixed_type=data.mixed_type)
self._compute_output_stat(
all_stat, mixed_type=data.mixed_type, stat_file_path=stat_file_path
)
# self.bias_atom_e = data.compute_energy_shift(self.rcond)

def _compute_input_stat(self, all_stat, protection=1e-2, mixed_type=False) -> None:
Expand All @@ -167,11 +169,37 @@ def _compute_input_stat(self, all_stat, protection=1e-2, mixed_type=False) -> No
)
self.fitting.compute_input_stats(all_stat, protection=protection)

def _compute_output_stat(self, all_stat, mixed_type=False) -> None:
if mixed_type:
self.fitting.compute_output_stats(all_stat, mixed_type=mixed_type)
def _compute_output_stat(
self, all_stat, mixed_type=False, stat_file_path=None
) -> None:
if stat_file_path is not None:
# Use the new stat functionality with file save/load
from deepmd.tf.utils.stat import (
compute_output_stats,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot This should be moved to the top of this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the import to the top of the file. Fixed in a878838.


# Merge system stats for compatibility
m_all_stat = merge_sys_stat(all_stat)

bias_out, std_out = compute_output_stats(
m_all_stat,
self.ntypes,
keys=["energy"],
stat_file_path=stat_file_path,
rcond=getattr(self, "rcond", None),
mixed_type=mixed_type,
)

# Set the computed bias and std in the fitting object
if "energy" in bias_out:
self.fitting.bias_atom_e = bias_out["energy"]

else:
self.fitting.compute_output_stats(all_stat)
# Use the original computation method
if mixed_type:
self.fitting.compute_output_stats(all_stat, mixed_type=mixed_type)
else:
self.fitting.compute_output_stats(all_stat)

def build(
self,
Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/model/frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def get_rcut(self):
def get_ntypes(self) -> int:
return self.model.get_ntypes()

def data_stat(self, data) -> None:
def data_stat(self, data, stat_file_path=None) -> None:
pass

def init_variables(
Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/model/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_ntypes(self) -> int:
raise ValueError("Models have different ntypes")
return self.models[0].get_ntypes()

def data_stat(self, data) -> None:
def data_stat(self, data, stat_file_path=None) -> None:
for model in self.models:
model.data_stat(data)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def get_ntypes(self) -> int:
"""Get the number of types."""

@abstractmethod
def data_stat(self, data: dict):
def data_stat(self, data: dict, stat_file_path=None):
"""Data staticis."""

def get_feed_dict(
Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/model/pairwise_dprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def get_rcut(self):
def get_ntypes(self) -> int:
return self.ntypes

def data_stat(self, data) -> None:
def data_stat(self, data, stat_file_path=None) -> None:
self.qm_model.data_stat(data)
self.qmmm_model.data_stat(data)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/tf/model/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_sel_type(self):
def get_out_size(self):
return self.fitting.get_out_size()

def data_stat(self, data) -> None:
def data_stat(self, data, stat_file_path=None) -> None:
all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
m_all_stat = merge_sys_stat(all_stat)
self._compute_input_stat(m_all_stat, protection=self.data_stat_protect)
Expand Down
11 changes: 9 additions & 2 deletions deepmd/tf/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,14 @@ def get_lr_and_coef(lr_param):
self.ckpt_meta = None
self.model_type = None

def build(self, data=None, stop_batch=0, origin_type_map=None, suffix="") -> None:
def build(
self,
data=None,
stop_batch=0,
origin_type_map=None,
suffix="",
stat_file_path=None,
) -> None:
self.ntypes = self.model.get_ntypes()
self.stop_batch = stop_batch

Expand Down Expand Up @@ -209,7 +216,7 @@ def build(self, data=None, stop_batch=0, origin_type_map=None, suffix="") -> Non
# self.saver.restore (in self._init_session) will restore avg and std variables, so data_stat is useless
# init_from_frz_model will restore data_stat variables in `init_variables` method
log.info("data stating... (this step may take long time)")
self.model.data_stat(data)
self.model.data_stat(data, stat_file_path=stat_file_path)

# config the init_frz_model command
if self.run_opt.init_mode == "init_from_frz_model":
Expand Down
165 changes: 165 additions & 0 deletions deepmd/tf/utils/stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
Optional,
)

import numpy as np

from deepmd.utils.path import (
DPPath,
)

log = logging.getLogger(__name__)


def _restore_from_file(
stat_file_path: DPPath,
keys: list[str] = ["energy"],
) -> Optional[tuple[dict, dict]]:
"""Restore bias and std from stat file.

Parameters
----------
stat_file_path : DPPath
Path to the stat file directory/file
keys : list[str]
Keys to restore statistics for

Returns
-------
ret_bias : dict or None
Bias values for each key
ret_std : dict or None
Standard deviation values for each key
"""
if stat_file_path is None:
return None, None
stat_files = [stat_file_path / f"bias_atom_{kk}" for kk in keys]
if all(not (ii.is_file()) for ii in stat_files):
return None, None
stat_files = [stat_file_path / f"std_atom_{kk}" for kk in keys]
if all(not (ii.is_file()) for ii in stat_files):
return None, None

ret_bias = {}
ret_std = {}
for kk in keys:
fp = stat_file_path / f"bias_atom_{kk}"
# only read the key that exists
if fp.is_file():
ret_bias[kk] = fp.load_numpy()
for kk in keys:
fp = stat_file_path / f"std_atom_{kk}"
# only read the key that exists
if fp.is_file():
ret_std[kk] = fp.load_numpy()
return ret_bias, ret_std


def _save_to_file(
stat_file_path: DPPath,
bias_out: dict,
std_out: dict,
) -> None:
"""Save bias and std to stat file.

Parameters
----------
stat_file_path : DPPath
Path to the stat file directory/file
bias_out : dict
Bias values for each key
std_out : dict
Standard deviation values for each key
"""
assert stat_file_path is not None
stat_file_path.mkdir(exist_ok=True, parents=True)
for kk, vv in bias_out.items():
fp = stat_file_path / f"bias_atom_{kk}"
fp.save_numpy(vv)
for kk, vv in std_out.items():
fp = stat_file_path / f"std_atom_{kk}"
fp.save_numpy(vv)


def compute_output_stats(
all_stat: dict,
ntypes: int,
keys: list[str] = ["energy"],
stat_file_path: Optional[DPPath] = None,
rcond: Optional[float] = None,
mixed_type: bool = False,
) -> tuple[dict, dict]:
"""Compute output statistics for TensorFlow models.

This is a simplified version of the PyTorch compute_output_stats function
adapted for TensorFlow models.

Parameters
----------
all_stat : dict
Dictionary containing statistical data
ntypes : int
Number of atom types
keys : list[str]
Keys to compute statistics for
stat_file_path : DPPath, optional
Path to save/load statistics
rcond : float, optional
Condition number for regression
mixed_type : bool
Whether mixed type format is used

Returns
-------
bias_out : dict
Computed bias values
std_out : dict
Computed standard deviation values
"""
# Try to restore from file first
bias_out, std_out = _restore_from_file(stat_file_path, keys)

if bias_out is not None and std_out is not None:
log.info("Successfully restored statistics from stat file")
return bias_out, std_out

# If restore failed, compute from data
log.info("Computing statistics from training data")

from deepmd.utils.out_stat import (
compute_stats_from_redu,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot imports should be put on the top of the file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the import to the top-level imports section. Fixed in 1e4deb2.


bias_out = {}
std_out = {}

for key in keys:
if key in all_stat:
# Get energy and natoms data
energy_data = np.concatenate(all_stat[key])
natoms_data = np.concatenate(all_stat["natoms_vec"])[
:, 2:
] # Skip first 2 elements

# Compute statistics using existing utility
bias, std = compute_stats_from_redu(
energy_data.reshape(-1, 1), # Reshape to column vector
natoms_data,
rcond=rcond,
)

bias_out[key] = bias.reshape(-1) # Flatten to 1D
std_out[key] = std.reshape(-1) # Flatten to 1D

log.info(
f"Statistics computed for {key}: bias shape {bias_out[key].shape}, std shape {std_out[key].shape}"
)

# Save to file if path provided
if stat_file_path is not None and bias_out:
_save_to_file(stat_file_path, bias_out, std_out)
log.info("Statistics saved to stat file")

return bias_out, std_out
4 changes: 1 addition & 3 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -3175,9 +3175,7 @@ def training_args(
data_args = [
arg_training_data,
arg_validation_data,
Argument(
"stat_file", str, optional=True, doc=doc_only_pt_supported + doc_stat_file
),
Argument("stat_file", str, optional=True, doc=doc_stat_file),
]
args = (
data_args
Expand Down
Loading