diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py
index c85e67a362..4ec9076f02 100644
--- a/deepmd/pd/train/training.py
+++ b/deepmd/pd/train/training.py
@@ -75,6 +75,9 @@
 from deepmd.utils.data import (
     DataRequirementItem,
 )
+from deepmd.utils.finetune import (
+    warn_configuration_mismatch_during_finetune,
+)
 from deepmd.utils.path import (
     DPH5Path,
 )
@@ -117,6 +120,8 @@
         training_params = config["training"]
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
+        # Store model params for finetune warning comparisons
+        self.model_params = model_params
         self.finetune_update_stat = False
         self.model_keys = (
             list(model_params["model_dict"]) if self.multi_task else ["Default"]
@@ -512,6 +517,37 @@
             )
 
         # collect model params from the pretrained model
+        # First check for configuration mismatches and warn if needed
+        pretrained_model_params = state_dict["_extra_state"]["model_params"]
+        for model_key in self.model_keys:
+            finetune_rule_single = self.finetune_links[model_key]
+            _model_key_from = finetune_rule_single.get_model_branch()
+
+            # Get current model descriptor config
+            if self.multi_task:
+                current_descriptor = self.model_params["model_dict"][
+                    model_key
+                ].get("descriptor", {})
+            else:
+                current_descriptor = self.model_params.get("descriptor", {})
+
+            # Get pretrained model descriptor config
+            if "model_dict" in pretrained_model_params:
+                pretrained_descriptor = pretrained_model_params[
+                    "model_dict"
+                ][_model_key_from].get("descriptor", {})
+            else:
+                pretrained_descriptor = pretrained_model_params.get(
+                    "descriptor", {}
+                )
+
+            # Warn about configuration mismatches
+            warn_configuration_mismatch_during_finetune(
+                current_descriptor,
+                pretrained_descriptor,
+                _model_key_from,
+            )
+
         for model_key in self.model_keys:
             finetune_rule_single = self.finetune_links[model_key]
             collect_single_finetune_params(
diff --git a/deepmd/pd/utils/finetune.py b/deepmd/pd/utils/finetune.py
index edac72d9c9..14bca445bf 100644
--- a/deepmd/pd/utils/finetune.py
+++ b/deepmd/pd/utils/finetune.py
@@ -8,6 +8,7 @@
 from deepmd.utils.finetune import (
     FinetuneRuleItem,
+    warn_descriptor_config_differences,
 )
 
 log = logging.getLogger(__name__)
 
@@ -61,6 +62,15 @@
         "descriptor": single_config.get("descriptor", {}).get("trainable", True),
         "fitting_net": single_config.get("fitting_net", {}).get("trainable", True),
     }
+
+    # Warn about descriptor configuration differences before overwriting
+    if "descriptor" in single_config and "descriptor" in single_config_chosen:
+        warn_descriptor_config_differences(
+            single_config["descriptor"],
+            single_config_chosen["descriptor"],
+            model_branch_chosen,
+        )
+
     single_config["descriptor"] = single_config_chosen["descriptor"]
     if not new_fitting:
         single_config["fitting_net"] = single_config_chosen["fitting_net"]
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 8f7c763d0f..003227192d 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -81,6 +81,9 @@
 from torch.utils.data import (
     DataLoader,
 )
+from deepmd.utils.finetune import (
+    warn_configuration_mismatch_during_finetune,
+)
 from deepmd.utils.path import (
     DPH5Path,
 )
@@ -122,6 +125,8 @@
         training_params = config["training"]
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
+        # Store model params for finetune warning comparisons
+        self.model_params = model_params
         self.finetune_update_stat = False
         self.model_keys = (
             list(model_params["model_dict"]) if self.multi_task else ["Default"]
@@ -541,6 +546,37 @@
             )
 
         # collect model params from the pretrained model
+        # First check for configuration mismatches and warn if needed
+        pretrained_model_params = state_dict["_extra_state"]["model_params"]
+        for model_key in self.model_keys:
+            finetune_rule_single = self.finetune_links[model_key]
+            _model_key_from = finetune_rule_single.get_model_branch()
+
+            # Get current model descriptor config
+            if self.multi_task:
+                current_descriptor = self.model_params["model_dict"][
+                    model_key
+                ].get("descriptor", {})
+            else:
+                current_descriptor = self.model_params.get("descriptor", {})
+
+            # Get pretrained model descriptor config
+            if "model_dict" in pretrained_model_params:
+                pretrained_descriptor = pretrained_model_params[
+                    "model_dict"
+                ][_model_key_from].get("descriptor", {})
+            else:
+                pretrained_descriptor = pretrained_model_params.get(
+                    "descriptor", {}
+                )
+
+            # Warn about configuration mismatches
+            warn_configuration_mismatch_during_finetune(
+                current_descriptor,
+                pretrained_descriptor,
+                _model_key_from,
+            )
+
         for model_key in self.model_keys:
             finetune_rule_single = self.finetune_links[model_key]
             collect_single_finetune_params(
diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py
index 96a420bf6a..113c496161 100644
--- a/deepmd/pt/utils/finetune.py
+++ b/deepmd/pt/utils/finetune.py
@@ -11,6 +11,7 @@
 )
 from deepmd.utils.finetune import (
     FinetuneRuleItem,
+    warn_descriptor_config_differences,
 )
 
 log = logging.getLogger(__name__)
@@ -64,6 +65,15 @@
         "descriptor": single_config.get("descriptor", {}).get("trainable", True),
         "fitting_net": single_config.get("fitting_net", {}).get("trainable", True),
     }
+
+    # Warn about descriptor configuration differences before overwriting
+    if "descriptor" in single_config and "descriptor" in single_config_chosen:
+        warn_descriptor_config_differences(
+            single_config["descriptor"],
+            single_config_chosen["descriptor"],
+            model_branch_chosen,
+        )
+
     single_config["descriptor"] = single_config_chosen["descriptor"]
     if not new_fitting:
         single_config["fitting_net"] = single_config_chosen["fitting_net"]
diff --git a/deepmd/utils/finetune.py b/deepmd/utils/finetune.py
index 644da3649d..33242e3786 100644
--- a/deepmd/utils/finetune.py
+++ b/deepmd/utils/finetune.py
@@ -1,9 +1,192 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import logging
 
+from deepmd.utils.argcheck import (
+    normalize,
+)
+
 log = logging.getLogger(__name__)
 
 
+def _embed_descriptor(descriptor: dict) -> dict:
+    """Embed *descriptor* in a minimal model config accepted by ``normalize``.
+
+    The fitting net, type map and training sections are placeholders whose
+    only purpose is to make the config pass argument checking; a fresh dict
+    is built on every call so the two configs under comparison never share
+    nested state.
+    """
+    return {
+        "model": {
+            "fitting_net": {"neuron": [240, 240, 240]},
+            "type_map": ["H", "O"],
+            "descriptor": descriptor.copy(),
+        },
+        "training": {"training_data": {"systems": ["fake"]}, "numb_steps": 100},
+    }
+
+
+def _normalize_descriptor_pair(
+    input_descriptor: dict,
+    pretrained_descriptor: dict,
+) -> tuple:
+    """Normalize two descriptor configs for a like-for-like comparison.
+
+    Normalization fills in default values, so configs that differ only in
+    unset defaults compare equal and do not trigger a warning.
+
+    Parameters
+    ----------
+    input_descriptor : dict
+        Descriptor configuration from input.json
+    pretrained_descriptor : dict
+        Descriptor configuration from pretrained model
+
+    Returns
+    -------
+    tuple
+        The two descriptors, normalized when possible; if normalization
+        fails (e.g. an unknown descriptor type), the raw dicts are
+        returned so callers can still compare them directly.
+    """
+    try:
+        normalized_input = normalize(
+            _embed_descriptor(input_descriptor), multi_task=False
+        )["model"]["descriptor"]
+        normalized_pretrained = normalize(
+            _embed_descriptor(pretrained_descriptor), multi_task=False
+        )["model"]["descriptor"]
+    except Exception:
+        # If normalization fails, fall back to the original configs.
+        return input_descriptor, pretrained_descriptor
+    return normalized_input, normalized_pretrained
+
+
+def _collect_descriptor_differences(
+    input_descriptor: dict,
+    pretrained_descriptor: dict,
+    overwrite: bool,
+) -> list:
+    """Return one formatted line per key that differs between the configs.
+
+    Parameters
+    ----------
+    input_descriptor : dict
+        Descriptor configuration from input.json
+    pretrained_descriptor : dict
+        Descriptor configuration from pretrained model
+    overwrite : bool
+        If True, phrase each difference as "input -> pretrained" (the
+        input value is about to be overwritten); otherwise phrase it as
+        a side-by-side mismatch.
+    """
+    differences = []
+    # Keys present in the input config: changed or missing in pretrained.
+    for key, input_value in input_descriptor.items():
+        if key in pretrained_descriptor:
+            pretrained_value = pretrained_descriptor[key]
+            if input_value != pretrained_value:
+                if overwrite:
+                    differences.append(
+                        f"  {key}: {input_value} -> {pretrained_value}"
+                    )
+                else:
+                    differences.append(
+                        f"  {key}: {input_value} (input) vs {pretrained_value} (pretrained)"
+                    )
+        elif overwrite:
+            differences.append(f"  {key}: {input_value} -> (removed)")
+        else:
+            differences.append(f"  {key}: {input_value} (input only)")
+    # Keys only present in the pretrained config.
+    for key, pretrained_value in pretrained_descriptor.items():
+        if key not in input_descriptor:
+            if overwrite:
+                differences.append(f"  {key}: (added) -> {pretrained_value}")
+            else:
+                differences.append(
+                    f"  {key}: {pretrained_value} (pretrained only)"
+                )
+    return differences
+
+
+def warn_descriptor_config_differences(
+    input_descriptor: dict,
+    pretrained_descriptor: dict,
+    model_branch: str = "Default",
+) -> None:
+    """
+    Warn about differences between input descriptor config and pretrained model's descriptor config.
+
+    This function is used when --use-pretrain-script option is used and input configuration
+    will be overwritten with the pretrained model's configuration.
+
+    Parameters
+    ----------
+    input_descriptor : dict
+        Descriptor configuration from input.json
+    pretrained_descriptor : dict
+        Descriptor configuration from pretrained model
+    model_branch : str
+        Model branch name for logging context
+    """
+    input_descriptor, pretrained_descriptor = _normalize_descriptor_pair(
+        input_descriptor, pretrained_descriptor
+    )
+    if input_descriptor == pretrained_descriptor:
+        return
+
+    differences = _collect_descriptor_differences(
+        input_descriptor, pretrained_descriptor, overwrite=True
+    )
+    if differences:
+        log.warning(
+            f"Descriptor configuration in input.json differs from pretrained model "
+            f"(branch '{model_branch}'). The input configuration will be overwritten "
+            f"with the pretrained model's configuration:\n" + "\n".join(differences)
+        )
+
+
+def warn_configuration_mismatch_during_finetune(
+    input_descriptor: dict,
+    pretrained_descriptor: dict,
+    model_branch: str = "Default",
+) -> None:
+    """
+    Warn about configuration mismatches between input descriptor and pretrained model
+    when fine-tuning without --use-pretrain-script option.
+
+    This function warns when configurations differ and state_dict initialization
+    will only pick relevant keys from the pretrained model (e.g., first 6 layers
+    from a 16-layer model).
+
+    Parameters
+    ----------
+    input_descriptor : dict
+        Descriptor configuration from input.json
+    pretrained_descriptor : dict
+        Descriptor configuration from pretrained model
+    model_branch : str
+        Model branch name for logging context
+    """
+    input_descriptor, pretrained_descriptor = _normalize_descriptor_pair(
+        input_descriptor, pretrained_descriptor
+    )
+    if input_descriptor == pretrained_descriptor:
+        return
+
+    differences = _collect_descriptor_differences(
+        input_descriptor, pretrained_descriptor, overwrite=False
+    )
+    if differences:
+        log.warning(
+            f"Descriptor configuration mismatch detected between input.json and pretrained model "
+            f"(branch '{model_branch}'). State dict initialization will only use compatible parameters "
+            f"from the pretrained model. Mismatched configuration:\n"
+            + "\n".join(differences)
+        )
+
+
 class FinetuneRuleItem:
     def __init__(
         self,