diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 46d284a395..5ee4d28166 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -90,13 +90,13 @@ def get_trainer( dist.init_process_group(backend="nccl") ckpt = init_model if init_model is not None else restart_model - config["model"] = change_finetune_model_params( - ckpt, - finetune_model, - config["model"], - multi_task=multi_task, - model_branch=model_branch, - ) + finetune_links = None + if finetune_model is not None: + config["model"], finetune_links = change_finetune_model_params( + finetune_model, + config["model"], + model_branch=model_branch, + ) config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None) def prepare_trainer_input_single( @@ -194,6 +194,7 @@ def prepare_trainer_input_single( finetune_model=finetune_model, force_load=force_load, shared_links=shared_links, + finetune_links=finetune_links, init_frz_model=init_frz_model, ) return trainer diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index f7216f46ef..482972b5d5 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -390,10 +390,6 @@ def compute_or_load_stat( self.models[0].compute_or_load_stat(sampled_func, stat_file_path) self.models[1].compute_or_load_stat(sampled_func, stat_file_path) - def change_energy_bias(self): - # need to implement - pass - def serialize(self) -> dict: dd = BaseAtomicModel.serialize(self) dd.update( diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index a4aa43ede1..cd139cdead 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -234,10 +234,6 @@ def compute_or_load_stat( torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1]) ) - def change_energy_bias(self) -> None: - # need to implement - pass - def forward_atomic( self, extended_coord: torch.Tensor, diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index 45b120771b..231ddf89a4 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -92,3 +92,8 @@ def forward_lower( else: model_predict = model_ret return model_predict + + def change_out_bias( + self, merged, origin_type_map, full_type_map, bias_shift="delta" + ) -> None: + raise NotImplementedError diff --git a/deepmd/pt/model/model/dos_model.py b/deepmd/pt/model/model/dos_model.py index 680eac41f5..147354d32e 100644 --- a/deepmd/pt/model/model/dos_model.py +++ b/deepmd/pt/model/model/dos_model.py @@ -78,3 +78,8 @@ def forward_lower( else: model_predict = model_ret return model_predict + + def change_out_bias( + self, merged, origin_type_map, full_type_map, bias_shift="delta" + ) -> None: + raise NotImplementedError diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 5217293623..4338297eac 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -1,15 +1,31 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import logging +import os +import tempfile from typing import ( Dict, Optional, ) +import numpy as np import torch +from deepmd.infer.deep_eval import ( + DeepEval, +) +from deepmd.pt.utils.stat import ( + compute_output_stats, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) + from .dp_model 
import ( DPModel, ) +log = logging.getLogger(__name__) + class EnergyModel(DPModel): model_type = "ener" @@ -97,3 +113,58 @@ def forward_lower( else: model_predict = model_ret return model_predict + + def change_out_bias( + self, merged, origin_type_map, full_type_map, bias_shift="delta" + ) -> None: + """Change the energy bias according to the input data and the pretrained model. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + origin_type_map : List[str] + The original type_map of the dataset; these are the types whose energy bias will be changed. + full_type_map : List[str] + The full type_map of the pretrained model. + bias_shift : str + The mode for changing the energy bias: ['delta', 'statistic'] + 'delta': perform predictions on the energies of the target dataset, + and do a least-squares fit on the errors to obtain the target shift as the bias. + 'statistic': directly use the statistical energy bias of the target dataset. + """ + sorter = np.argsort(full_type_map) + idx_type_map = sorter[ + np.searchsorted(full_type_map, origin_type_map, sorter=sorter) + ] + original_bias = self.get_fitting_net()["bias_atom_e"] + if bias_shift == "delta": + tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") + model = torch.jit.script(self) + torch.jit.save(model, tmp_model.name) + dp = DeepEval(tmp_model.name) + os.unlink(tmp_model.name) + delta_bias_e = compute_output_stats( + merged, + self.atomic_model.get_ntypes(), + model=dp, + ) + bias_atom_e = delta_bias_e + original_bias + elif bias_shift == "statistic": + bias_atom_e = compute_output_stats( + merged, + self.atomic_model.get_ntypes(), + ) + else: + raise RuntimeError("Unknown bias_shift mode: " + bias_shift) + log.info( + f"Change energy bias of {origin_type_map!s} " + f"from {to_numpy_array(original_bias[idx_type_map]).reshape(-1)!s} " + f"to {to_numpy_array(bias_atom_e[idx_type_map]).reshape(-1)!s}."
+ ) + self.get_fitting_net()["bias_atom_e"] = bias_atom_e diff --git a/deepmd/pt/model/model/polar_model.py b/deepmd/pt/model/model/polar_model.py index 403058aa47..ec669583b0 100644 --- a/deepmd/pt/model/model/polar_model.py +++ b/deepmd/pt/model/model/polar_model.py @@ -76,3 +76,8 @@ def forward_lower( else: model_predict = model_ret return model_predict + + def change_out_bias( + self, merged, origin_type_map, full_type_map, bias_shift="delta" + ) -> None: + raise NotImplementedError diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index df2f48e2e4..75ecee7601 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -558,3 +558,8 @@ def forward_lower( ].squeeze(-2) # not support virial by far return model_predict + + def change_out_bias( + self, merged, origin_type_map, full_type_map, bias_shift="delta" + ) -> None: + raise NotImplementedError diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 4637178318..128c6e378c 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy import logging -import os -import tempfile from abc import ( abstractmethod, ) @@ -15,9 +13,6 @@ import numpy as np import torch -from deepmd.infer.deep_eval import ( - DeepEval, -) from deepmd.pt.model.network.mlp import ( FittingNet, NetworkCollection, @@ -33,7 +28,6 @@ ) from deepmd.pt.utils.env import ( DEFAULT_PRECISION, - DEVICE, PRECISION_DICT, ) from deepmd.pt.utils.exclude_mask import ( @@ -43,12 +37,6 @@ to_numpy_array, to_torch_tensor, ) -from deepmd.utils.data_system import ( - DeepmdDataSystem, -) -from deepmd.utils.finetune import ( - change_energy_bias_lower, -) dtype = env.GLOBAL_PT_FLOAT_PRECISION device = env.DEVICE @@ -88,72 +76,6 @@ def share_params(self, base_class, shared_level, resume=False): else: raise NotImplementedError - def change_energy_bias( - self, - config, - model, - old_type_map: List[str], - new_type_map: List[str], - bias_shift="delta", - ntest=10, - ): - """Change the energy bias according to the input data and the pretrained model. - - Parameters - ---------- - config : Dict - The configuration. - model : EnergyModel - Energy model loaded pre-trained model. - new_type_map : List[str] - The original type_map in dataset, they are targets to change the energy bias. - old_type_map : List[str] - The full type_map in pretrained model - bias_shift : str - The mode for changing energy bias : ['delta', 'statistic'] - 'delta' : perform predictions on energies of target dataset, - and do least sqaure on the errors to obtain the target shift as bias. - 'statistic' : directly use the statistic energy bias in the target dataset. - ntest : int - The number of test samples in a system to change the energy bias. - """ - log.info( - f"Changing energy bias in pretrained model for types {new_type_map!s}... 
" - "(this step may take long time)" - ) - # data - systems = config["training"]["training_data"]["systems"] - finetune_data = DeepmdDataSystem( - systems=systems, - batch_size=config["training"]["training_data"].get("batch_size", "auto"), - test_size=1, - ) - finetune_data.add("energy", ndof=1, atomic=False, must=True, high_prec=True) - model = torch.jit.script(model) - if model.get_dim_fparam() > 0: - finetune_data.add("fparam", model.get_dim_fparam(), atomic=False, must=True) - if model.get_dim_aparam() > 0: - finetune_data.add("aparam", model.get_dim_aparam(), atomic=True, must=True) - tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, tmp_model.name) - dp = DeepEval(tmp_model.name) - os.unlink(tmp_model.name) - bias = change_energy_bias_lower( - finetune_data, - dp, - new_type_map, - old_type_map, - self.bias_atom_e.detach().cpu().numpy().reshape(-1), - bias_shift=bias_shift, - ntest=ntest, - ) - self.bias_atom_e = ( - torch.from_numpy(bias) - .type_as(self.bias_atom_e) - .reshape(self.bias_atom_e.shape) - .to(DEVICE) - ) - class GeneralFitting(Fitting): """Construct a general fitting net. diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index fc293f70ec..032aa993b0 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -30,7 +30,7 @@ TensorLoss, ) from deepmd.pt.model.model import ( - DPZBLModel, + EnergyModel, get_model, get_zbl_model, ) @@ -96,6 +96,7 @@ def __init__( finetune_model=None, force_load=False, shared_links=None, + finetune_links=None, init_frz_model=None, ): """Construct a DeePMD trainer. @@ -116,8 +117,9 @@ def __init__( model_params = config["model"] training_params = config["training"] self.multi_task = "model_dict" in model_params - self.finetune_multi_task = model_params.pop( - "finetune_multi_task", False + self.finetune_links = finetune_links + self.finetune_from_multi_task = model_params.pop( + "finetune_from_multi_task", False ) # should use pop for next finetune self.model_keys = ( list(model_params["model_dict"]) if self.multi_task else ["Default"] @@ -234,23 +236,24 @@ def single_model_stat( _training_data.add_data_requirement(_data_requirement) if _validation_data is not None: _validation_data.add_data_requirement(_data_requirement) - if not resuming and self.rank == 0: - @functools.lru_cache - def get_sample(): - sampled = make_stat_input( - _training_data.systems, - _training_data.dataloaders, - _data_stat_nbatch, - ) - return sampled + @functools.lru_cache + def get_sample(): + sampled = make_stat_input( + _training_data.systems, + _training_data.dataloaders, + _data_stat_nbatch, + ) + return sampled + if not resuming and self.rank == 0: _model.compute_or_load_stat( sampled_func=get_sample, stat_file_path=_stat_file_path, ) if isinstance(_stat_file_path, DPH5Path): _stat_file_path.root.close() + return get_sample def get_single_model( _model_params, @@ -355,7 +358,7 @@ def get_loss(loss_params, start_lr, _ntypes, _model): # Data dp_random.seed(training_params["seed"]) if not self.multi_task: - single_model_stat( + self.get_sample_func = single_model_stat( self.model, model_params.get("data_stat_nbatch", 10), training_data, @@ -385,9 +388,10 @@ def get_loss(loss_params, start_lr, _ntypes, _model): self.validation_dataloader, self.validation_data, self.valid_numb_batch, - ) = {}, {}, {}, {}, {} + self.get_sample_func, + ) = {}, {}, {}, {}, {}, {} for model_key in self.model_keys: - single_model_stat( + self.get_sample_func[model_key] = single_model_stat( 
self.model[model_key], model_params["model_dict"][model_key].get("data_stat_nbatch", 10), training_data[model_key], @@ -486,60 +490,114 @@ def get_loss(loss_params, start_lr, _ntypes, _model): log.warning( f"Force load mode allowed! These keys are not in ckpt and will re-init: {slim_keys}" ) - elif self.finetune_multi_task: + + if finetune_model is not None: new_state_dict = {} - model_branch_chosen = model_params.pop("model_branch_chosen") - new_fitting = model_params.pop("new_fitting", False) target_state_dict = self.wrapper.state_dict() - target_keys = [ - i for i in target_state_dict.keys() if i != "_extra_state" - ] - for item_key in target_keys: - if new_fitting and ".fitting_net." in item_key: - # print(f'Keep {item_key} in old model!') - new_state_dict[item_key] = ( - target_state_dict[item_key].clone().detach() - ) - else: - new_key = item_key.replace( - ".Default.", f".{model_branch_chosen}." - ) - # print(f'Replace {item_key} with {new_key} in pretrained_model!') - new_state_dict[item_key] = ( - state_dict[new_key].clone().detach() + + def update_single_finetune_params( + _model_key, + _model_key_from, + _new_state_dict, + _origin_state_dict, + _random_state_dict, + _new_fitting=False, + ): + target_keys = [ + i + for i in _random_state_dict.keys() + if i != "_extra_state" and f".{_model_key}." in i + ] + for item_key in target_keys: + if _new_fitting and ".fitting_net." in item_key: + # print(f'Keep {item_key} in old model!') + _new_state_dict[item_key] = ( + _random_state_dict[item_key].clone().detach() + ) + else: + new_key = item_key.replace( + f".{_model_key}.", f".{_model_key_from}." + ) + # print(f'Replace {item_key} with {new_key} in pretrained_model!') + _new_state_dict[item_key] = ( + _origin_state_dict[new_key].clone().detach() + ) + + if not self.multi_task: + model_key = "Default" + model_key_from = self.finetune_links[model_key] + new_fitting = model_params.pop("new_fitting", False) + update_single_finetune_params( + model_key, + model_key_from, + new_state_dict, + state_dict, + target_state_dict, + _new_fitting=new_fitting, + ) + else: + for model_key in self.model_keys: + if model_key in self.finetune_links: + model_key_from = self.finetune_links[model_key] + new_fitting = model_params["model_dict"][model_key].pop( + "new_fitting", False + ) + else: + model_key_from = model_key + new_fitting = False + update_single_finetune_params( + model_key, + model_key_from, + new_state_dict, + state_dict, + target_state_dict, + _new_fitting=new_fitting, ) state_dict = new_state_dict - if finetune_model is not None: state_dict["_extra_state"] = self.wrapper.state_dict()[ "_extra_state" ] - self.wrapper.load_state_dict(state_dict) - # finetune - if finetune_model is not None and model_params["fitting_net"].get( - "type", "ener" - ) in ["ener", "direct_force_ener", "atten_vec_lcc"]: + + def single_model_finetune( + _model, + _model_params, + _sample_func, + ): old_type_map, new_type_map = ( - model_params["type_map"], - model_params["new_type_map"], + _model_params["type_map"], + _model_params["new_type_map"], ) - # TODO: need an interface instead of fetching fitting_net!!!!!!!!! 
- if hasattr(self.model, "atomic_model") and hasattr( - self.model.atomic_model, "fitting_net" - ): - self.model.atomic_model.fitting_net.change_energy_bias( - config, - self.model, - old_type_map, - new_type_map, - ntest=ntest, - bias_shift=model_params.get("bias_shift", "delta"), + if isinstance(_model, EnergyModel): + _model.change_out_bias( + _sample_func, + bias_shift=_model_params.get("bias_shift", "delta"), + origin_type_map=new_type_map, + full_type_map=old_type_map, ) - elif isinstance(self.model, DPZBLModel): - # need to updated - self.model.atomic_model.change_energy_bias() else: + # not implemented for other model types yet raise NotImplementedError + + # finetune + if not self.multi_task: + single_model_finetune( + self.model, model_params, self.get_sample_func + ) + else: + for model_key in self.model_keys: + if model_key in self.finetune_links: + log.info( + f"Model branch {model_key} will be fine-tuned. This may take a long time..." + ) + single_model_finetune( + self.model[model_key], + model_params["model_dict"][model_key], + self.get_sample_func[model_key], + ) + else: + log.info(f"Model branch {model_key} will resume training.") + if init_frz_model is not None: frz_model = torch.jit.load(init_frz_model, map_location=DEVICE) self.model.load_state_dict(frz_model.state_dict()) diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py index d555478af4..05720720ea 100644 --- a/deepmd/pt/utils/finetune.py +++ b/deepmd/pt/utils/finetune.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +from copy import ( + deepcopy, +) import torch @@ -10,88 +13,173 @@ log = logging.getLogger(__name__) -def change_finetune_model_params( - ckpt, finetune_model, model_config, multi_task=False, model_branch="" +def change_finetune_model_params_single( + _single_param_target, + _model_param_pretrained, + from_multitask=False, + model_branch="Default", + model_branch_from="", ): - """Load model_params according to the pretrained one. + single_config = deepcopy(_single_param_target) + trainable_param = { + "descriptor": True, + "fitting_net": True, + } + for net_type in trainable_param: + if net_type in single_config: + trainable_param[net_type] = single_config[net_type].get("trainable", True) + if not from_multitask: + old_type_map, new_type_map = ( + _model_param_pretrained["type_map"], + single_config["type_map"], + ) + assert set(new_type_map).issubset( + old_type_map + ), "Only support for smaller type map when finetuning or resuming." + single_config = deepcopy(_model_param_pretrained) + log.info( + f"Change the '{model_branch}' model configurations according to the pretrained one..." + ) + single_config["new_type_map"] = new_type_map + else: + single_config["finetune_from_multi_task"] = from_multitask + model_dict_params = _model_param_pretrained["model_dict"] + new_fitting = False + if model_branch_from == "": + model_branch_chosen = next(iter(model_dict_params.keys())) + new_fitting = True + single_config["bias_shift"] = "statistic" # fitting net re-init + log.warning( + "The fitting net will be re-initialized instead of using the one in the pretrained model! " + "The bias_shift will be statistic!" + ) + else: + model_branch_chosen = model_branch_from + assert model_branch_chosen in model_dict_params, ( + f"No model branch named '{model_branch_chosen}'! " + f"Available ones are {list(model_dict_params.keys())}."
+ ) + single_config_chosen = deepcopy(model_dict_params[model_branch_chosen]) + old_type_map, new_type_map = ( + single_config_chosen["type_map"], + single_config["type_map"], + ) + assert set(new_type_map).issubset( + old_type_map + ), "Only support for smaller type map when finetuning or resuming." + for key_item in ["type_map", "descriptor"]: + if key_item in single_config_chosen: + single_config[key_item] = single_config_chosen[key_item] + if not new_fitting: + single_config["fitting_net"] = single_config_chosen["fitting_net"] + log.info( + f"Change the '{model_branch}' model configurations according to the model branch " + f"'{model_branch_chosen}' in the pretrained one..." + ) + single_config["new_type_map"] = new_type_map + single_config["model_branch_chosen"] = model_branch_chosen + single_config["new_fitting"] = new_fitting + for net_type in trainable_param: + if net_type in single_config: + single_config[net_type]["trainable"] = trainable_param[net_type] + else: + single_config[net_type] = {"trainable": trainable_param[net_type]} + return single_config - Args: - - ckpt & finetune_model: origin model. - - config: Read from json file. + +def change_finetune_model_params(finetune_model, model_config, model_branch=""): """ - # TODO need support for multitask mode - if finetune_model is not None: - state_dict = torch.load(finetune_model, map_location=env.DEVICE) - if "model" in state_dict: - state_dict = state_dict["model"] - last_model_params = state_dict["_extra_state"]["model_params"] - finetune_multi_task = "model_dict" in last_model_params - trainable_param = { - "descriptor": True, - "fitting_net": True, - } - for net_type in trainable_param: - if net_type in model_config: - trainable_param[net_type] = model_config[net_type].get( - "trainable", True - ) - if not finetune_multi_task: - old_type_map, new_type_map = ( - last_model_params["type_map"], - model_config["type_map"], - ) - assert set(new_type_map).issubset( - old_type_map - ), "Only support for smaller type map when finetuning or resuming." - model_config = last_model_params - log.info( - "Change the model configurations according to the pretrained one..." - ) - model_config["new_type_map"] = new_type_map + Load model_params according to the pretrained one. + + Parameters + ---------- + finetune_model: + The pretrained model. + model_config + The fine-tuning input parameters. + model_branch + The model branch chosen in command-line mode, only for single-task fine-tuning. + This function modifies the fine-tuning input in different modes as follows: + 1. Single-task fine-tuning from a single-task pretrained model: + - Updates the model parameters based on the pretrained model. + 2. Single-task fine-tuning from a multi-task pretrained model: + - Updates the model parameters based on the selected branch in the pretrained model. + - The chosen branch can be defined from the command-line or `finetune_head` input parameter. + - If not defined, model parameters in the fitting network will be randomly initialized. + 3. Multi-task fine-tuning from a single-task pretrained model: + - Updates model parameters in each branch based on the single branch ('Default') in the pretrained model. + - If `finetune_head` is not set to 'Default', + model parameters in the fitting network of the branch will be randomly initialized. + 4. Multi-task fine-tuning from a multi-task pretrained model: + - Updates model parameters in each branch based on the selected branch in the pretrained model. 
+ - The chosen branches can be defined from the `finetune_head` input parameter of each model. + - If `finetune_head` is not defined and the model_key is the same as in the pretrained model, + it will resume from the model_key branch without fine-tuning. + - If `finetune_head` is not defined and a new model_key is used, + model parameters in the fitting network of the branch will be randomly initialized. + + Returns + ------- + model_config: + Updated model parameters. + finetune_links: + Fine-tuning rules in a dict format, with `model_branch`: `model_branch_from` pairs. + If a `model_key` is not in this dict, it will just resume training instead of fine-tuning. + """ + multi_task = "model_dict" in model_config + state_dict = torch.load(finetune_model, map_location=env.DEVICE) + if "model" in state_dict: + state_dict = state_dict["model"] + last_model_params = state_dict["_extra_state"]["model_params"] + finetune_from_multi_task = "model_dict" in last_model_params + finetune_links = {} + if not multi_task: + # use command-line first + if model_branch == "" and "finetune_head" in model_config: + model_branch = model_config["finetune_head"] + model_config = change_finetune_model_params_single( + model_config, + last_model_params, + from_multitask=finetune_from_multi_task, + model_branch="Default", + model_branch_from=model_branch, + ) + finetune_links["Default"] = ( + model_branch if finetune_from_multi_task else "Default" + ) + else: + assert model_branch == "", ( + "Multi-task fine-tuning does not support choosing branches from the command line! " + "Please define 'finetune_head' in the parameters of each model!" + ) + target_keys = model_config["model_dict"].keys() + if not finetune_from_multi_task: + pretrained_keys = ["Default"] else: - model_config["finetune_multi_task"] = finetune_multi_task - model_dict_params = last_model_params["model_dict"] - new_fitting = False - if model_branch == "": - model_branch_chosen = next(iter(model_dict_params.keys())) - new_fitting = True - model_config["bias_shift"] = "statistic" # fitting net re-init - log.warning( - "The fitting net will be re-init instead of using that in the pretrained model! " - "The bias_shift will be statistic!" + pretrained_keys = last_model_params["model_dict"].keys() + for model_key in target_keys: + if "finetune_head" in model_config["model_dict"][model_key]: + pretrained_key = model_config["model_dict"][model_key]["finetune_head"] + assert pretrained_key in pretrained_keys, ( + f"'{pretrained_key}' head chosen for fine-tuning does not exist in the pretrained model! " + f"Available heads are: {list(pretrained_keys)}" ) + model_branch_from = pretrained_key + finetune_links[model_key] = model_branch_from + elif model_key in pretrained_keys: + # do nothing if "finetune_head" is not defined for a head that also exists in the pretrained model + # this will just resume training + model_branch_from = model_key + pass else: - model_branch_chosen = model_branch - assert model_branch_chosen in model_dict_params, ( - f"No model branch named '{model_branch_chosen}'! " - f"Available ones are {list(model_dict_params.keys())}."
+ # if "finetune_head" is not defined for a new head, the fitting net will be randomly initialized + model_branch_from = "" + finetune_links[model_key] = next(iter(pretrained_keys)) + model_config["model_dict"][model_key] = change_finetune_model_params_single( + model_config["model_dict"][model_key], + last_model_params, + from_multitask=finetune_from_multi_task, + model_branch=model_key, + model_branch_from=model_branch_from, ) - old_type_map, new_type_map = ( - model_dict_params[model_branch_chosen]["type_map"], - model_config["type_map"], - ) - assert set(new_type_map).issubset( - old_type_map - ), "Only support for smaller type map when finetuning or resuming." - for key_item in ["type_map", "descriptor"]: - if key_item in model_dict_params[model_branch_chosen]: - model_config[key_item] = model_dict_params[model_branch_chosen][ - key_item - ] - if not new_fitting: - model_config["fitting_net"] = model_dict_params[model_branch_chosen][ - "fitting_net" - ] - log.info( - f"Change the model configurations according to the model branch " - f"{model_branch_chosen} in the pretrained one..." - ) - model_config["new_type_map"] = new_type_map - model_config["model_branch_chosen"] = model_branch_chosen - model_config["new_fitting"] = new_fitting - for net_type in trainable_param: - if net_type in model_config: - model_config[net_type]["trainable"] = trainable_param[net_type] - else: - model_config[net_type] = {"trainable": trainable_param[net_type]} - return model_config + return model_config, finetune_links diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 5e631d9412..de75740c22 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -10,13 +10,16 @@ import numpy as np import torch +from deepmd.infer.deep_eval import ( + DeepEval, +) from deepmd.pt.utils import ( AtomExcludeMask, - env, ) from deepmd.pt.utils.utils import ( dict_to_device, to_numpy_array, + to_torch_tensor, ) from deepmd.utils.out_stat import ( compute_stats_from_redu, ) @@ -76,6 +79,7 @@ def compute_output_stats( stat_file_path: Optional[DPPath] = None, rcond: Optional[float] = None, atom_ener: Optional[List[float]] = None, + model: Optional[DeepEval] = None, ): """ Compute the output statistics (e.g. energy bias) for the fitting net from packed data. @@ -97,7 +101,10 @@ The condition number for the regression of atomic energy. atom_ener : List[float], optional Specifying atomic energy contribution in vacuum. The `set_davg_zero` key in the descrptor should be set. - + model : DeepEval, optional + If not None, the model will be utilized to generate the original energy prediction, + which will be subtracted from the energy label of the data. + The difference will then be used to calculate the delta energy bias for each type.
""" if stat_file_path is not None: stat_file_path = stat_file_path / "bias_atom_e" @@ -129,13 +136,67 @@ def compute_output_stats( ) else: assigned_atom_ener = None - bias_atom_e, _ = compute_stats_from_redu( - merged_energy, - merged_natoms, - assigned_bias=assigned_atom_ener, - rcond=rcond, - ) + if model is None: + # only use statistics result + bias_atom_e, _ = compute_stats_from_redu( + merged_energy, + merged_natoms, + assigned_bias=assigned_atom_ener, + rcond=rcond, + ) + else: + # subtract the model bias and output the delta bias + energy_predict = [] + for system in sampled: + nframes = system["coord"].shape[0] + coord = to_numpy_array(system["coord"]).reshape(nframes, -1) + box = ( + to_numpy_array(system["box"]).reshape(nframes, -1) + if system["box"] is not None + else None + ) + if data_mixed_type: + atype = to_numpy_array(system["atype"]).reshape(nframes, -1) + else: + atype = to_numpy_array(system["atype"]).reshape(nframes, -1)[0] + fparam = ( + to_numpy_array(system["fparam"]) + if "fparam" in system is not None + else None + ) + aparam = ( + to_numpy_array(system["aparam"]) + if "aparam" in system is not None + else None + ) + ret = model.eval( + coord, + box, + atype, + mixed_type=data_mixed_type, + fparam=fparam, + aparam=aparam, + ) + energy_predict.append(ret[0].reshape([nframes, 1])) + energy_predict = np.concatenate(energy_predict) + bias_diff = merged_energy - energy_predict + bias_atom_e, _ = compute_stats_from_redu( + bias_diff, + merged_natoms, + assigned_bias=assigned_atom_ener, + rcond=rcond, + ) + unbias_e = energy_predict + merged_natoms @ bias_atom_e + atom_numbs = merged_natoms.sum(-1) + rmse_ae = np.sqrt( + np.mean( + np.square((unbias_e.ravel() - merged_energy.ravel()) / atom_numbs) + ) + ) + log.info( + f"RMSE of atomic energy after linear regression is: {rmse_ae} eV/atom." 
+ ) if stat_file_path is not None: stat_file_path.save_numpy(bias_atom_e) assert all(x is not None for x in [bias_atom_e]) - return torch.tensor(bias_atom_e, device=env.DEVICE) + return to_torch_tensor(bias_atom_e) diff --git a/source/tests/pt/test_finetune.py b/source/tests/pt/test_finetune.py index dd72eb4718..e77063c2ac 100644 --- a/source/tests/pt/test_finetune.py +++ b/source/tests/pt/test_finetune.py @@ -1,6 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import os -import shutil import tempfile import unittest from copy import ( @@ -17,16 +15,16 @@ DeepEval, ) from deepmd.pt.model.model import ( - DPZBLModel, - EnergyModel, get_model, - get_zbl_model, ) -from deepmd.utils.data_system import ( - DeepmdDataSystem, +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, ) -from deepmd.utils.finetune import ( - change_energy_bias_lower, +from deepmd.pt.utils.stat import ( + make_stat_input, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, ) from .model.test_permutation import ( @@ -34,63 +32,38 @@ model_se_e2_a, model_zbl, ) +from .test_stat import ( + energy_data_requirement, +) class FinetuneTest: - def test_finetune_change_energy_bias(self): + def test_finetune_change_out_bias(self): # get model - if "use_srtab" in self.model_config: - model = get_zbl_model(self.model_config) - else: - model = get_model(self.model_config) - if isinstance(model, EnergyModel): - model.get_fitting_net().bias_atom_e = torch.rand_like( - model.get_fitting_net().bias_atom_e - ) - energy_bias_before = deepcopy( - model.get_fitting_net().bias_atom_e.detach().cpu().numpy().reshape(-1) - ) - bias_atom_e_input = deepcopy( - model.get_fitting_net().bias_atom_e.detach().cpu().numpy().reshape(-1) - ) - elif isinstance(model, DPZBLModel): - model.dp_model.get_fitting_net().bias_atom_e = torch.rand_like( - model.dp_model.get_fitting_net().bias_atom_e - ) - energy_bias_before = deepcopy( - model.dp_model.get_fitting_net() - .bias_atom_e.detach() - .cpu() - .numpy() - .reshape(-1) - ) - bias_atom_e_input = deepcopy( - model.dp_model.get_fitting_net() - .bias_atom_e.detach() - .cpu() - .numpy() - .reshape(-1) - ) - else: - bias_atom_e_input = None - - model = torch.jit.script(model) + model = get_model(self.model_config) + fitting_net = model.get_fitting_net() + fitting_net["bias_atom_e"] = torch.rand_like(fitting_net["bias_atom_e"]) + energy_bias_before = deepcopy( + to_numpy_array(fitting_net["bias_atom_e"]).reshape(-1) + ) + + # prepare original model for test + dp = torch.jit.script(model) tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, tmp_model.name) + torch.jit.save(dp, tmp_model.name) dp = DeepEval(tmp_model.name) - ntest = 10 origin_type_map = ["O", "H"] full_type_map = ["O", "H", "B"] # change energy bias - energy_bias_after = change_energy_bias_lower( - self.data, - dp, + model.change_out_bias( + self.sampled, + bias_shift="delta", origin_type_map=origin_type_map, full_type_map=full_type_map, - bias_atom_e=bias_atom_e_input, - bias_shift="delta", - ntest=ntest, + ) + energy_bias_after = deepcopy( + to_numpy_array(fitting_net["bias_atom_e"]).reshape(-1) ) # get ground-truth energy bias change @@ -98,12 +71,17 @@ def test_finetune_change_energy_bias(self): idx_type_map = sorter[ np.searchsorted(full_type_map, origin_type_map, sorter=sorter) ] - test_data = self.data.get_test() - atom_nums = np.tile(np.bincount(test_data["type"][0])[idx_type_map], (ntest, 1)) + ntest = 1 + atom_nums = np.tile( + 
np.bincount(to_numpy_array(self.sampled[0]["atype"][0]))[idx_type_map], + (ntest, 1), + ) energy = dp.eval( - test_data["coord"][:ntest], test_data["box"][:ntest], test_data["type"][0] + to_numpy_array(self.sampled[0]["coord"][:ntest]), + to_numpy_array(self.sampled[0]["box"][:ntest]), + to_numpy_array(self.sampled[0]["atype"][0]), )[0] - energy_diff = test_data["energy"][:ntest] - energy + energy_diff = to_numpy_array(self.sampled[0]["energy"][:ntest]) - energy finetune_shift = ( energy_bias_after[idx_type_map] - energy_bias_before[idx_type_map] ) @@ -114,60 +92,57 @@ def test_finetune_change_energy_bias(self): # check values np.testing.assert_almost_equal(finetune_shift, ground_truth_shift, decimal=10) - def tearDown(self): - for f in os.listdir("."): - if f.startswith("model") and f.endswith(".pt"): - os.remove(f) - if f in ["lcurve.out"]: - os.remove(f) - if f in ["stat_files"]: - shutil.rmtree(f) - class TestEnergyModelSeA(unittest.TestCase, FinetuneTest): def setUp(self): self.data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.data = DeepmdDataSystem( + self.model_config = model_se_e2_a + self.data = DpLoaderSet( self.data_file, batch_size=1, - test_size=1, + type_map=self.model_config["type_map"], + ) + self.data.add_data_requirement(energy_data_requirement) + self.sampled = make_stat_input( + self.data.systems, + self.data.dataloaders, + nbatches=1, ) - self.data.add("energy", ndof=1, atomic=False, must=True, high_prec=True) - self.model_config = model_se_e2_a - - def tearDown(self) -> None: - FinetuneTest.tearDown(self) @unittest.skip("change bias not implemented yet.") class TestEnergyZBLModelSeA(unittest.TestCase, FinetuneTest): def setUp(self): self.data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.data = DeepmdDataSystem( + self.model_config = model_zbl + self.data = DpLoaderSet( self.data_file, batch_size=1, - test_size=1, + type_map=self.model_config["type_map"], + ) + self.data.add_data_requirement(energy_data_requirement) + self.sampled = make_stat_input( + self.data.systems, + self.data.dataloaders, + nbatches=1, ) - self.data.add("energy", ndof=1, atomic=False, must=True, high_prec=True) - self.model_config = model_zbl - - def tearDown(self) -> None: - FinetuneTest.tearDown(self) class TestEnergyModelDPA2(unittest.TestCase, FinetuneTest): def setUp(self): self.data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.data = DeepmdDataSystem( + self.model_config = model_dpa2 + self.data = DpLoaderSet( self.data_file, batch_size=1, - test_size=1, + type_map=self.model_config["type_map"], + ) + self.data.add_data_requirement(energy_data_requirement) + self.sampled = make_stat_input( + self.data.systems, + self.data.dataloaders, + nbatches=1, ) - self.data.add("energy", ndof=1, atomic=False, must=True, high_prec=True) - self.model_config = model_dpa2 - - def tearDown(self) -> None: - FinetuneTest.tearDown(self) if __name__ == "__main__":