-
Notifications
You must be signed in to change notification settings - Fork 575
feat(tf): add support for stat_file parameter #4926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: devel
Are you sure you want to change the base?
Changes from 4 commits
610c6fa
932223d
8e731c8
2cb3163
a878838
69dbf52
c60793c
995a1d6
ee06a1c
03a4754
17b7a9a
1e4deb2
5864cee
249367c
e8fd06a
7efbdf9
c51189a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -135,13 +135,15 @@ def get_numb_aparam(self) -> int: | |
| """Get the number of atomic parameters.""" | ||
| return self.numb_aparam | ||
|
|
||
| def data_stat(self, data) -> None: | ||
| def data_stat(self, data, stat_file_path=None) -> None: | ||
| all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False) | ||
| m_all_stat = merge_sys_stat(all_stat) | ||
| self._compute_input_stat( | ||
| m_all_stat, protection=self.data_stat_protect, mixed_type=data.mixed_type | ||
| ) | ||
| self._compute_output_stat(all_stat, mixed_type=data.mixed_type) | ||
| self._compute_output_stat( | ||
| all_stat, mixed_type=data.mixed_type, stat_file_path=stat_file_path | ||
| ) | ||
| # self.bias_atom_e = data.compute_energy_shift(self.rcond) | ||
|
|
||
| def _compute_input_stat(self, all_stat, protection=1e-2, mixed_type=False) -> None: | ||
|
|
@@ -167,11 +169,37 @@ def _compute_input_stat(self, all_stat, protection=1e-2, mixed_type=False) -> No | |
| ) | ||
| self.fitting.compute_input_stats(all_stat, protection=protection) | ||
|
|
||
| def _compute_output_stat(self, all_stat, mixed_type=False) -> None: | ||
| if mixed_type: | ||
| self.fitting.compute_output_stats(all_stat, mixed_type=mixed_type) | ||
| def _compute_output_stat( | ||
| self, all_stat, mixed_type=False, stat_file_path=None | ||
| ) -> None: | ||
| if stat_file_path is not None: | ||
| # Use the new stat functionality with file save/load | ||
| from deepmd.tf.utils.stat import ( | ||
| compute_output_stats, | ||
| ) | ||
|
||
|
|
||
| # Merge system stats for compatibility | ||
| m_all_stat = merge_sys_stat(all_stat) | ||
|
|
||
| bias_out, std_out = compute_output_stats( | ||
| m_all_stat, | ||
| self.ntypes, | ||
| keys=["energy"], | ||
| stat_file_path=stat_file_path, | ||
| rcond=getattr(self, "rcond", None), | ||
| mixed_type=mixed_type, | ||
| ) | ||
|
|
||
| # Set the computed bias and std in the fitting object | ||
| if "energy" in bias_out: | ||
| self.fitting.bias_atom_e = bias_out["energy"] | ||
|
|
||
| else: | ||
| self.fitting.compute_output_stats(all_stat) | ||
| # Use the original computation method | ||
| if mixed_type: | ||
| self.fitting.compute_output_stats(all_stat, mixed_type=mixed_type) | ||
| else: | ||
| self.fitting.compute_output_stats(all_stat) | ||
|
|
||
| def build( | ||
| self, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,165 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| import logging | ||
| from typing import ( | ||
| Optional, | ||
| ) | ||
|
|
||
| import numpy as np | ||
|
|
||
| from deepmd.utils.path import ( | ||
| DPPath, | ||
| ) | ||
|
|
||
| log = logging.getLogger(__name__) | ||
|
|
||
|
|
||
def _restore_from_file(
    stat_file_path: DPPath,
    keys: Optional[list[str]] = None,
) -> tuple[Optional[dict], Optional[dict]]:
    """Restore bias and std from stat file.

    For every key, the bias is looked up under ``bias_atom_<key>`` and the
    standard deviation under ``std_atom_<key>`` inside ``stat_file_path``.

    Parameters
    ----------
    stat_file_path : DPPath
        Path to the stat file directory/file. ``None`` means "nothing to
        restore".
    keys : list[str], optional
        Keys to restore statistics for. Defaults to ``["energy"]``.

    Returns
    -------
    ret_bias : dict or None
        Bias values for each key whose bias file exists.
    ret_std : dict or None
        Standard deviation values for each key whose std file exists.

    Notes
    -----
    ``(None, None)`` is returned when ``stat_file_path`` is ``None``, when
    none of the requested bias files exists, or when none of the requested
    std files exists. Otherwise only the files that exist are loaded, so the
    returned dicts may hold a subset of ``keys``.
    """
    # A None default replaces the former mutable list default, which would be
    # shared between calls.
    if keys is None:
        keys = ["energy"]
    if stat_file_path is None:
        return None, None

    bias_files = {kk: stat_file_path / f"bias_atom_{kk}" for kk in keys}
    if not any(fp.is_file() for fp in bias_files.values()):
        return None, None
    std_files = {kk: stat_file_path / f"std_atom_{kk}" for kk in keys}
    if not any(fp.is_file() for fp in std_files.values()):
        return None, None

    # only read the keys that exist
    ret_bias = {kk: fp.load_numpy() for kk, fp in bias_files.items() if fp.is_file()}
    ret_std = {kk: fp.load_numpy() for kk, fp in std_files.items() if fp.is_file()}
    return ret_bias, ret_std
|
|
||
|
|
||
def _save_to_file(
    stat_file_path: DPPath,
    bias_out: dict,
    std_out: dict,
) -> None:
    """Save bias and std to stat file.

    Each bias entry is written to ``bias_atom_<key>`` and each std entry to
    ``std_atom_<key>`` inside ``stat_file_path``, which is created first
    (including parents) if it does not exist.

    Parameters
    ----------
    stat_file_path : DPPath
        Path to the stat file directory/file
    bias_out : dict
        Bias values for each key
    std_out : dict
        Standard deviation values for each key

    Raises
    ------
    ValueError
        If ``stat_file_path`` is ``None``.
    """
    # Raise explicitly rather than assert: assertions are removed under
    # ``python -O``, which would turn this into an AttributeError on None.
    if stat_file_path is None:
        raise ValueError("stat_file_path must not be None when saving statistics")
    stat_file_path.mkdir(exist_ok=True, parents=True)
    for prefix, values in (("bias_atom", bias_out), ("std_atom", std_out)):
        for kk, vv in values.items():
            (stat_file_path / f"{prefix}_{kk}").save_numpy(vv)
|
|
||
|
|
||
def compute_output_stats(
    all_stat: dict,
    ntypes: int,
    keys: Optional[list[str]] = None,
    stat_file_path: Optional[DPPath] = None,
    rcond: Optional[float] = None,
    mixed_type: bool = False,
) -> tuple[dict, dict]:
    """Compute output statistics for TensorFlow models.

    Statistics are first looked up in ``stat_file_path``; only when nothing
    can be restored are they computed from ``all_stat`` and, if a path was
    given, written back to it.

    Parameters
    ----------
    all_stat : dict
        Dictionary containing statistical data. For each requested key the
        entry is concatenated with ``np.concatenate``; ``"natoms_vec"`` is
        read as well.
    ntypes : int
        Number of atom types. Currently not used by the computation.
    keys : list[str], optional
        Keys to compute statistics for. Defaults to ``["energy"]``.
    stat_file_path : DPPath, optional
        Path to save/load statistics
    rcond : float, optional
        Passed through to ``compute_stats_from_redu`` as ``rcond``.
    mixed_type : bool
        Whether mixed type format is used. Currently not used by the
        computation.
        NOTE(review): mixed-type data may need different atom counts than
        ``natoms_vec`` - confirm against the fitting-net implementation.

    Returns
    -------
    bias_out : dict
        Computed (or restored) bias values, one flattened array per key
    std_out : dict
        Computed (or restored) standard deviation values, one flattened
        array per key
    """
    # A None default replaces the former mutable list default. It is resolved
    # here so the helpers below always receive a concrete list.
    if keys is None:
        keys = ["energy"]

    # Try to restore from file first
    bias_out, std_out = _restore_from_file(stat_file_path, keys)
    if bias_out is not None and std_out is not None:
        log.info("Successfully restored statistics from stat file")
        return bias_out, std_out

    # If restore failed, compute from data
    log.info("Computing statistics from training data")

    # Imported lazily so the restore path above does not depend on it.
    from deepmd.utils.out_stat import (
        compute_stats_from_redu,
    )

    bias_out = {}
    std_out = {}
    # The atom counts do not depend on the key, so build them at most once,
    # and only if some requested key is actually present.
    natoms_data = None

    for key in keys:
        if key not in all_stat:
            # Previously skipped silently, leaving the caller without a bias.
            log.warning(
                "No data for key %r in all_stat; skipping its statistics", key
            )
            continue

        if natoms_data is None:
            # Drop the first two columns of natoms_vec.
            # presumably these are the local/total atom counts and the rest
            # are per-type counts - TODO confirm
            natoms_data = np.concatenate(all_stat["natoms_vec"])[:, 2:]

        output_data = np.concatenate(all_stat[key])

        # Compute statistics using existing utility
        bias, std = compute_stats_from_redu(
            output_data.reshape(-1, 1),  # Reshape to column vector
            natoms_data,
            rcond=rcond,
        )

        bias_out[key] = bias.reshape(-1)  # Flatten to 1D
        std_out[key] = std.reshape(-1)  # Flatten to 1D

        log.info(
            "Statistics computed for %s: bias shape %s, std shape %s",
            key,
            bias_out[key].shape,
            std_out[key].shape,
        )

    # Save to file if path provided; nothing is written when no key produced
    # statistics.
    if stat_file_path is not None and bias_out:
        _save_to_file(stat_file_path, bias_out, std_out)
        log.info("Statistics saved to stat file")

    return bias_out, std_out
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot This should be moved to the top of this file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved the imports to the top of the file. Fixed in a878838.