diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 4907483d1d..8a40f8d238 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -83,6 +83,27 @@ def mixed_types(self) -> bool: """ return self.descriptor.mixed_types() + def set_out_bias(self, out_bias: np.ndarray, add=False) -> None: + """ + Modify the output bias for the atomic model. + + Parameters + ---------- + out_bias : np.ndarray + The new bias to be applied. + add : bool, optional + Whether to add the new bias to the existing one. + If False, the output bias will be directly replaced by the new bias. + If True, the new bias will be added to the existing one. + """ + self.fitting["bias_atom_e"] = ( + out_bias + self.fitting["bias_atom_e"] if add else out_bias + ) + + def get_out_bias(self) -> np.ndarray: + """Return the output bias of the atomic model.""" + return self.fitting["bias_atom_e"] + def forward_atomic( self, extended_coord: np.ndarray, diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 088cf34900..93a885f3ab 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -275,6 +275,27 @@ def get_sel_type(self) -> List[int]: # join all the selected types return list(set().union(*[model.get_sel_type() for model in self.models])) + def set_out_bias(self, out_bias: np.ndarray, add=False) -> None: + """ + Modify the output bias for all the models in the linear atomic model. + + Parameters + ---------- + out_bias : torch.Tensor + The new bias to be applied. + add : bool, optional + Whether to add the new bias to the existing one. + If False, the output bias will be directly replaced by the new bias. + If True, the new bias will be added to the existing one. + """ + for model in self.models: + model.set_out_bias(out_bias, add=add) + + def get_out_bias(self) -> np.ndarray: + """Return the weighted output bias of the linear atomic model.""" + # TODO add get_out_bias for linear atomic model + raise NotImplementedError + def is_aparam_nall(self) -> bool: """Check whether the shape of atomic parameters is (nframes, nall, ndim). diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py index 936c2b0943..3e02a5d076 100644 --- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py @@ -95,6 +95,25 @@ def get_sel_type(self) -> List[int]: If returning an empty list, all atom types are selected. """ + @abstractmethod + def set_out_bias(self, out_bias: t_tensor, add=False) -> None: + """ + Modify the output bias for the atomic model. + + Parameters + ---------- + out_bias : t_tensor + The new bias to be applied. + add : bool, optional + Whether to add the new bias to the existing one. + If False, the output bias will be directly replaced by the new bias. + If True, the new bias will be added to the existing one. + """ + + @abstractmethod + def get_out_bias(self) -> t_tensor: + """Return the output bias of the atomic model.""" + @abstractmethod def is_aparam_nall(self) -> bool: """Check whether the shape of atomic parameters is (nframes, nall, ndim). diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index 99b8ec1eff..30ab58928b 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -126,6 +126,25 @@ def mixed_types(self) -> bool: # to match DPA1 and DPA2. return True + def set_out_bias(self, out_bias: np.ndarray, add=False) -> None: + """ + Modify the output bias for the atomic model. + + Parameters + ---------- + out_bias : torch.Tensor + The new bias to be applied. + add : bool, optional + Whether to add the new bias to the existing one. + If False, the output bias will be directly replaced by the new bias. + If True, the new bias will be added to the existing one. + """ + self.bias_atom_e = out_bias + self.bias_atom_e if add else out_bias + + def get_out_bias(self) -> np.ndarray: + """Return the output bias of the atomic model.""" + return self.bias_atom_e + def serialize(self) -> dict: dd = BaseAtomicModel.serialize(self) dd.update( diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index b9c4971116..7b1463a3b2 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -90,13 +90,13 @@ def get_trainer( dist.init_process_group(backend="nccl") ckpt = init_model if init_model is not None else restart_model - config["model"] = change_finetune_model_params( - ckpt, - finetune_model, - config["model"], - multi_task=multi_task, - model_branch=model_branch, - ) + finetune_links = None + if finetune_model is not None: + config["model"], finetune_links = change_finetune_model_params( + finetune_model, + config["model"], + model_branch=model_branch, + ) config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None) def prepare_trainer_input_single( @@ -194,6 +194,7 @@ def prepare_trainer_input_single( finetune_model=finetune_model, force_load=force_load, shared_links=shared_links, + finetune_links=finetune_links, init_frz_model=init_frz_model, ) return trainer diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py index edae53a771..ccc23b690c 100644 --- a/deepmd/pt/loss/ener.py +++ b/deepmd/pt/loss/ener.py @@ -124,15 +124,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): # more_loss['test_keys'] = [] # showed when doing dp test atom_norm = 1.0 / natoms if self.has_e and "energy" in model_pred and "energy" in label: + find_energy = label.get("find_energy", 0.0) + pref_e = pref_e * find_energy if not self.use_l1_all: l2_ener_loss = torch.mean( torch.square(model_pred["energy"] - label["energy"]) ) if not self.inference: - more_loss["l2_ener_loss"] = l2_ener_loss.detach() + more_loss["l2_ener_loss"] = self.display_if_exist( + l2_ener_loss.detach(), find_energy + ) loss += atom_norm * (pref_e * l2_ener_loss) rmse_e = l2_ener_loss.sqrt() * atom_norm - more_loss["rmse_e"] = rmse_e.detach() + more_loss["rmse_e"] = self.display_if_exist( + rmse_e.detach(), find_energy + ) # more_loss['log_keys'].append('rmse_e') else: # use l1 and for all atoms l1_ener_loss = F.l1_loss( @@ -141,24 +147,31 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): reduction="sum", ) loss += pref_e * l1_ener_loss - more_loss["mae_e"] = F.l1_loss( - model_pred["energy"].reshape(-1), - label["energy"].reshape(-1), - reduction="mean", - ).detach() + more_loss["mae_e"] = self.display_if_exist( + F.l1_loss( + model_pred["energy"].reshape(-1), + label["energy"].reshape(-1), + reduction="mean", + ).detach(), + find_energy, + ) # more_loss['log_keys'].append('rmse_e') if mae: mae_e = ( torch.mean(torch.abs(model_pred["energy"] - label["energy"])) * atom_norm ) - more_loss["mae_e"] = mae_e.detach() + more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy) mae_e_all = torch.mean( torch.abs(model_pred["energy"] - label["energy"]) ) - more_loss["mae_e_all"] = mae_e_all.detach() + more_loss["mae_e_all"] = self.display_if_exist( + mae_e_all.detach(), find_energy + ) if self.has_f and "force" in model_pred and "force" in label: + find_force = label.get("find_force", 0.0) + pref_f = pref_f * find_force if "force_target_mask" in model_pred: force_target_mask = model_pred["force_target_mask"] else: @@ -174,10 +187,12 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): diff_f = label["force"] - model_pred["force"] l2_force_loss = torch.mean(torch.square(diff_f)) if not self.inference: - more_loss["l2_force_loss"] = l2_force_loss.detach() + more_loss["l2_force_loss"] = self.display_if_exist( + l2_force_loss.detach(), find_force + ) loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION) rmse_f = l2_force_loss.sqrt() - more_loss["rmse_f"] = rmse_f.detach() + more_loss["rmse_f"] = self.display_if_exist(rmse_f.detach(), find_force) else: l1_force_loss = F.l1_loss( label["force"], model_pred["force"], reduction="none" @@ -185,29 +200,35 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): if force_target_mask is not None: l1_force_loss *= force_target_mask force_cnt = force_target_mask.squeeze(-1).sum(-1) - more_loss["mae_f"] = ( - l1_force_loss.mean(-1).sum(-1) / force_cnt - ).mean() + more_loss["mae_f"] = self.display_if_exist( + (l1_force_loss.mean(-1).sum(-1) / force_cnt).mean(), find_force + ) l1_force_loss = (l1_force_loss.sum(-1).sum(-1) / force_cnt).sum() else: - more_loss["mae_f"] = l1_force_loss.mean().detach() + more_loss["mae_f"] = self.display_if_exist( + l1_force_loss.mean().detach(), find_force + ) l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum() loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION) if mae: mae_f = torch.mean(torch.abs(diff_f)) - more_loss["mae_f"] = mae_f.detach() + more_loss["mae_f"] = self.display_if_exist(mae_f.detach(), find_force) if self.has_v and "virial" in model_pred and "virial" in label: + find_virial = label.get("find_virial", 0.0) + pref_v = pref_v * find_virial diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9) l2_virial_loss = torch.mean(torch.square(diff_v)) if not self.inference: - more_loss["l2_virial_loss"] = l2_virial_loss.detach() + more_loss["l2_virial_loss"] = self.display_if_exist( + l2_virial_loss.detach(), find_virial + ) loss += atom_norm * (pref_v * l2_virial_loss) rmse_v = l2_virial_loss.sqrt() * atom_norm - more_loss["rmse_v"] = rmse_v.detach() + more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial) if mae: mae_v = torch.mean(torch.abs(diff_v)) * atom_norm - more_loss["mae_v"] = mae_v.detach() + more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial) if not self.inference: more_loss["rmse"] = torch.sqrt(loss.detach()) return model_pred, loss, more_loss diff --git a/deepmd/pt/loss/ener_spin.py b/deepmd/pt/loss/ener_spin.py index 1f10e3cf5f..3bd81adf77 100644 --- a/deepmd/pt/loss/ener_spin.py +++ b/deepmd/pt/loss/ener_spin.py @@ -98,15 +98,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): # more_loss['test_keys'] = [] # showed when doing dp test atom_norm = 1.0 / natoms if self.has_e and "energy" in model_pred and "energy" in label: + find_energy = label.get("find_energy", 0.0) + pref_e = pref_e * find_energy if not self.use_l1_all: l2_ener_loss = torch.mean( torch.square(model_pred["energy"] - label["energy"]) ) if not self.inference: - more_loss["l2_ener_loss"] = l2_ener_loss.detach() + more_loss["l2_ener_loss"] = self.display_if_exist( + l2_ener_loss.detach(), find_energy + ) loss += atom_norm * (pref_e * l2_ener_loss) rmse_e = l2_ener_loss.sqrt() * atom_norm - more_loss["rmse_e"] = rmse_e.detach() + more_loss["rmse_e"] = self.display_if_exist( + rmse_e.detach(), find_energy + ) # more_loss['log_keys'].append('rmse_e') else: # use l1 and for all atoms l1_ener_loss = F.l1_loss( @@ -115,44 +121,61 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): reduction="sum", ) loss += pref_e * l1_ener_loss - more_loss["mae_e"] = F.l1_loss( - model_pred["energy"].reshape(-1), - label["energy"].reshape(-1), - reduction="mean", - ).detach() + more_loss["mae_e"] = self.display_if_exist( + F.l1_loss( + model_pred["energy"].reshape(-1), + label["energy"].reshape(-1), + reduction="mean", + ).detach(), + find_energy, + ) # more_loss['log_keys'].append('rmse_e') if mae: mae_e = ( torch.mean(torch.abs(model_pred["energy"] - label["energy"])) * atom_norm ) - more_loss["mae_e"] = mae_e.detach() + more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy) mae_e_all = torch.mean( torch.abs(model_pred["energy"] - label["energy"]) ) - more_loss["mae_e_all"] = mae_e_all.detach() + more_loss["mae_e_all"] = self.display_if_exist( + mae_e_all.detach(), find_energy + ) if self.has_fr and "force" in model_pred and "force" in label: + find_force_r = label.get("find_force", 0.0) + pref_fr = pref_fr * find_force_r if not self.use_l1_all: diff_fr = label["force"] - model_pred["force"] l2_force_real_loss = torch.mean(torch.square(diff_fr)) if not self.inference: - more_loss["l2_force_r_loss"] = l2_force_real_loss.detach() + more_loss["l2_force_r_loss"] = self.display_if_exist( + l2_force_real_loss.detach(), find_force_r + ) loss += (pref_fr * l2_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION) rmse_fr = l2_force_real_loss.sqrt() - more_loss["rmse_fr"] = rmse_fr.detach() + more_loss["rmse_fr"] = self.display_if_exist( + rmse_fr.detach(), find_force_r + ) if mae: mae_fr = torch.mean(torch.abs(diff_fr)) - more_loss["mae_fr"] = mae_fr.detach() + more_loss["mae_fr"] = self.display_if_exist( + mae_fr.detach(), find_force_r + ) else: l1_force_real_loss = F.l1_loss( label["force"], model_pred["force"], reduction="none" ) - more_loss["mae_fr"] = l1_force_real_loss.mean().detach() + more_loss["mae_fr"] = self.display_if_exist( + l1_force_real_loss.mean().detach(), find_force_r + ) l1_force_real_loss = l1_force_real_loss.sum(-1).mean(-1).sum() loss += (pref_fr * l1_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION) if self.has_fm and "force_mag" in model_pred and "force_mag" in label: + find_force_m = label.get("find_force_mag", 0.0) + pref_fm = pref_fm * find_force_m nframes = model_pred["force_mag"].shape[0] atomic_mask = model_pred["mask_mag"].expand([-1, -1, 3]) label_force_mag = label["force_mag"][atomic_mask].view(nframes, -1, 3) @@ -163,18 +186,26 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): diff_fm = label_force_mag - model_pred_force_mag l2_force_mag_loss = torch.mean(torch.square(diff_fm)) if not self.inference: - more_loss["l2_force_m_loss"] = l2_force_mag_loss.detach() + more_loss["l2_force_m_loss"] = self.display_if_exist( + l2_force_mag_loss.detach(), find_force_m + ) loss += (pref_fm * l2_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION) rmse_fm = l2_force_mag_loss.sqrt() - more_loss["rmse_fm"] = rmse_fm.detach() + more_loss["rmse_fm"] = self.display_if_exist( + rmse_fm.detach(), find_force_m + ) if mae: mae_fm = torch.mean(torch.abs(diff_fm)) - more_loss["mae_fm"] = mae_fm.detach() + more_loss["mae_fm"] = self.display_if_exist( + mae_fm.detach(), find_force_m + ) else: l1_force_mag_loss = F.l1_loss( label_force_mag, model_pred_force_mag, reduction="none" ) - more_loss["mae_fm"] = l1_force_mag_loss.mean().detach() + more_loss["mae_fm"] = self.display_if_exist( + l1_force_mag_loss.mean().detach(), find_force_m + ) l1_force_mag_loss = l1_force_mag_loss.sum(-1).mean(-1).sum() loss += (pref_fm * l1_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION) diff --git a/deepmd/pt/loss/loss.py b/deepmd/pt/loss/loss.py index cc253424ca..7e26f6571a 100644 --- a/deepmd/pt/loss/loss.py +++ b/deepmd/pt/loss/loss.py @@ -28,3 +28,16 @@ def forward(self, input_dict, model, label, natoms, learning_rate): def label_requirement(self) -> List[DataRequirementItem]: """Return data label requirements needed for this loss calculation.""" pass + + @staticmethod + def display_if_exist(loss: torch.Tensor, find_property: float) -> torch.Tensor: + """Display NaN if labeled property is not found. + + Parameters + ---------- + loss : torch.Tensor + the loss tensor + find_property : float + whether the property is found + """ + return loss if bool(find_property) else torch.nan diff --git a/deepmd/pt/loss/tensor.py b/deepmd/pt/loss/tensor.py index 238e6a7796..3dd91d203e 100644 --- a/deepmd/pt/loss/tensor.py +++ b/deepmd/pt/loss/tensor.py @@ -95,6 +95,8 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False and self.tensor_name in model_pred and "atomic_" + self.label_name in label ): + find_local = label.get("find_" + "atomic_" + self.label_name, 0.0) + local_weight = self.local_weight * find_local local_tensor_pred = model_pred[self.tensor_name].reshape( [-1, natoms, self.tensor_size] ) @@ -108,15 +110,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False diff = diff[model_pred["mask"].reshape([-1]).bool()] l2_local_loss = torch.mean(torch.square(diff)) if not self.inference: - more_loss[f"l2_local_{self.tensor_name}_loss"] = l2_local_loss.detach() - loss += self.local_weight * l2_local_loss + more_loss[f"l2_local_{self.tensor_name}_loss"] = self.display_if_exist( + l2_local_loss.detach(), find_local + ) + loss += local_weight * l2_local_loss rmse_local = l2_local_loss.sqrt() - more_loss[f"rmse_local_{self.tensor_name}"] = rmse_local.detach() + more_loss[f"rmse_local_{self.tensor_name}"] = self.display_if_exist( + rmse_local.detach(), find_local + ) if ( self.has_global_weight and "global_" + self.tensor_name in model_pred and self.label_name in label ): + find_global = label.get("find_" + self.label_name, 0.0) + global_weight = self.global_weight * find_global global_tensor_pred = model_pred["global_" + self.tensor_name].reshape( [-1, self.tensor_size] ) @@ -132,12 +140,14 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False atom_num = natoms l2_global_loss = torch.mean(torch.square(diff)) if not self.inference: - more_loss[f"l2_global_{self.tensor_name}_loss"] = ( - l2_global_loss.detach() + more_loss[f"l2_global_{self.tensor_name}_loss"] = self.display_if_exist( + l2_global_loss.detach(), find_global ) - loss += self.global_weight * l2_global_loss + loss += global_weight * l2_global_loss rmse_global = l2_global_loss.sqrt() / atom_num - more_loss[f"rmse_global_{self.tensor_name}"] = rmse_global.detach() + more_loss[f"rmse_global_{self.tensor_name}"] = self.display_if_exist( + rmse_global.detach(), find_global + ) return model_pred, loss, more_loss @property diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index c921538203..fa30655f8a 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import logging from typing import ( + Callable, Dict, List, Optional, Tuple, ) +import numpy as np import torch from deepmd.dpmodel.atomic_model import ( @@ -21,10 +24,21 @@ AtomExcludeMask, PairExcludeMask, ) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) +from deepmd.pt.utils.stat import ( + compute_output_stats, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) from deepmd.utils.path import ( DPPath, ) +log = logging.getLogger(__name__) + BaseAtomicModel_ = make_base_atomic_model(torch.Tensor) @@ -176,6 +190,40 @@ def serialize(self) -> dict: "pair_exclude_types": self.pair_exclude_types, } + def get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]: + """Get a forward wrapper of the atomic model for output bias calculation.""" + model_output_type = list(self.atomic_output_def().keys()) + if "mask" in model_output_type: + model_output_type.pop(model_output_type.index("mask")) + out_name = model_output_type[0] + + def model_forward(coord, atype, box, fparam=None, aparam=None): + with torch.no_grad(): # it's essential for pure torch forward function to use auto_batchsize + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + self.get_rcut(), + self.get_sel(), + mixed_types=self.mixed_types(), + box=box, + ) + atomic_ret = self.forward_common_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + return atomic_ret[out_name].detach() + + return model_forward + def compute_or_load_stat( self, sampled_func, @@ -197,3 +245,62 @@ def compute_or_load_stat( The path to the statistics files. """ raise NotImplementedError + + def change_out_bias( + self, + merged, + origin_type_map, + full_type_map, + bias_adjust_mode="change-by-statistic", + ) -> None: + """Change the output bias according to the input data and the pretrained model. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + origin_type_map : List[str] + The original type_map in dataset, they are targets to change the output bias. + full_type_map : List[str] + The full type_map in pre-trained model + bias_adjust_mode : str + The mode for changing output bias : ['change-by-statistic', 'set-by-statistic'] + 'change-by-statistic' : perform predictions on labels of target dataset, + and do least square on the errors to obtain the target shift as bias. + 'set-by-statistic' : directly use the statistic output bias in the target dataset. + """ + sorter = np.argsort(full_type_map) + missing_types = [t for t in origin_type_map if t not in full_type_map] + assert ( + not missing_types + ), f"Some types are not in the pre-trained model: {list(missing_types)} !" + idx_type_map = sorter[ + np.searchsorted(full_type_map, origin_type_map, sorter=sorter) + ] + original_bias = self.get_out_bias() + if bias_adjust_mode == "change-by-statistic": + delta_bias = compute_output_stats( + merged, + self.get_ntypes(), + model_forward=self.get_forward_wrapper_func(), + ) + self.set_out_bias(delta_bias, add=True) + elif bias_adjust_mode == "set-by-statistic": + bias_atom = compute_output_stats( + merged, + self.get_ntypes(), + ) + self.set_out_bias(bias_atom) + else: + raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode) + bias_atom = self.get_out_bias() + log.info( + f"Change output bias of {origin_type_map!s} " + f"from {to_numpy_array(original_bias[idx_type_map]).reshape(-1)!s} " + f"to {to_numpy_array(bias_atom[idx_type_map]).reshape(-1)!s}." + ) diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 6aa8df7aee..13b8f09a79 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -223,6 +223,27 @@ def wrapped_sampler(): if self.fitting_net is not None: self.fitting_net.compute_output_stats(wrapped_sampler, stat_file_path) + def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None: + """ + Modify the output bias for the atomic model. + + Parameters + ---------- + out_bias : torch.Tensor + The new bias to be applied. + add : bool, optional + Whether to add the new bias to the existing one. + If False, the output bias will be directly replaced by the new bias. + If True, the new bias will be added to the existing one. + """ + self.fitting_net["bias_atom_e"] = ( + out_bias + self.fitting_net["bias_atom_e"] if add else out_bias + ) + + def get_out_bias(self) -> torch.Tensor: + """Return the output bias of the atomic model.""" + return self.fitting_net["bias_atom_e"] + def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" return self.fitting_net.get_dim_fparam() diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index f7216f46ef..f599399e66 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -289,6 +289,27 @@ def _compute_weight( for _ in range(nmodels) ] + def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None: + """ + Modify the output bias for all the models in the linear atomic model. + + Parameters + ---------- + out_bias : torch.Tensor + The new bias to be applied. + add : bool, optional + Whether to add the new bias to the existing one. + If False, the output bias will be directly replaced by the new bias. + If True, the new bias will be added to the existing one. + """ + for model in self.models: + model.set_out_bias(out_bias, add=add) + + def get_out_bias(self) -> torch.Tensor: + """Return the weighted output bias of the linear atomic model.""" + # TODO add get_out_bias for linear atomic model + raise NotImplementedError + def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" # tricky... @@ -390,10 +411,6 @@ def compute_or_load_stat( self.models[0].compute_or_load_stat(sampled_func, stat_file_path) self.models[1].compute_or_load_stat(sampled_func, stat_file_path) - def change_energy_bias(self): - # need to implement - pass - def serialize(self) -> dict: dd = BaseAtomicModel.serialize(self) dd.update( diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 7c7c8a2969..c20abf6a12 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -234,9 +234,24 @@ def compute_or_load_stat( torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1]) ) - def change_energy_bias(self) -> None: - # need to implement - pass + def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None: + """ + Modify the output bias for the atomic model. + + Parameters + ---------- + out_bias : torch.Tensor + The new bias to be applied. + add : bool, optional + Whether to add the new bias to the existing one. + If False, the output bias will be directly replaced by the new bias. + If True, the new bias will be added to the existing one. + """ + self.bias_atom_e = out_bias + self.bias_atom_e if add else out_bias + + def get_out_bias(self) -> torch.Tensor: + """Return the output bias of the atomic model.""" + return self.bias_atom_e def forward_atomic( self, diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 167ad81923..0e89c05b79 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -172,6 +172,41 @@ def forward_common( model_predict = self.output_type_cast(model_predict, input_prec) return model_predict + def change_out_bias( + self, + merged, + origin_type_map, + full_type_map, + bias_adjust_mode="change-by-statistic", + ) -> None: + """Change the output bias of atomic model according to the input data and the pretrained model. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + origin_type_map : List[str] + The original type_map in dataset, they are targets to change the output bias. + full_type_map : List[str] + The full type_map in pre-trained model + bias_adjust_mode : str + The mode for changing output bias : ['change-by-statistic', 'set-by-statistic'] + 'change-by-statistic' : perform predictions on labels of target dataset, + and do least square on the errors to obtain the target shift as bias. + 'set-by-statistic' : directly use the statistic output bias in the target dataset. + """ + self.atomic_model.change_out_bias( + merged, + origin_type_map, + full_type_map, + bias_adjust_mode=bias_adjust_mode, + ) + def forward_common_lower( self, extended_coord, diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index c8edee5b94..00579b957f 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy import logging -import os -import tempfile from abc import ( abstractmethod, ) @@ -15,9 +13,6 @@ import numpy as np import torch -from deepmd.infer.deep_eval import ( - DeepEval, -) from deepmd.pt.model.network.mlp import ( FittingNet, NetworkCollection, @@ -33,7 +28,6 @@ ) from deepmd.pt.utils.env import ( DEFAULT_PRECISION, - DEVICE, PRECISION_DICT, ) from deepmd.pt.utils.exclude_mask import ( @@ -43,12 +37,6 @@ to_numpy_array, to_torch_tensor, ) -from deepmd.utils.data_system import ( - DeepmdDataSystem, -) -from deepmd.utils.finetune import ( - change_energy_bias_lower, -) dtype = env.GLOBAL_PT_FLOAT_PRECISION device = env.DEVICE @@ -88,72 +76,6 @@ def share_params(self, base_class, shared_level, resume=False): else: raise NotImplementedError - def change_energy_bias( - self, - config, - model, - old_type_map: List[str], - new_type_map: List[str], - bias_shift="delta", - ntest=10, - ): - """Change the energy bias according to the input data and the pretrained model. - - Parameters - ---------- - config : Dict - The configuration. - model : EnergyModel - Energy model loaded pre-trained model. - new_type_map : List[str] - The original type_map in dataset, they are targets to change the energy bias. - old_type_map : List[str] - The full type_map in pretrained model - bias_shift : str - The mode for changing energy bias : ['delta', 'statistic'] - 'delta' : perform predictions on energies of target dataset, - and do least sqaure on the errors to obtain the target shift as bias. - 'statistic' : directly use the statistic energy bias in the target dataset. - ntest : int - The number of test samples in a system to change the energy bias. - """ - log.info( - f"Changing energy bias in pretrained model for types {new_type_map!s}... " - "(this step may take long time)" - ) - # data - systems = config["training"]["training_data"]["systems"] - finetune_data = DeepmdDataSystem( - systems=systems, - batch_size=config["training"]["training_data"].get("batch_size", "auto"), - test_size=1, - ) - finetune_data.add("energy", ndof=1, atomic=False, must=True, high_prec=True) - model = torch.jit.script(model) - if model.get_dim_fparam() > 0: - finetune_data.add("fparam", model.get_dim_fparam(), atomic=False, must=True) - if model.get_dim_aparam() > 0: - finetune_data.add("aparam", model.get_dim_aparam(), atomic=True, must=True) - tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, tmp_model.name) - dp = DeepEval(tmp_model.name) - os.unlink(tmp_model.name) - bias = change_energy_bias_lower( - finetune_data, - dp, - new_type_map, - old_type_map, - self.bias_atom_e.detach().cpu().numpy().reshape(-1), - bias_shift=bias_shift, - ntest=ntest, - ) - self.bias_atom_e = ( - torch.from_numpy(bias) - .type_as(self.bias_atom_e) - .reshape(self.bias_atom_e.shape) - .to(DEVICE) - ) - class GeneralFitting(Fitting): """Construct a general fitting net. diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index e55803954d..6ca718724b 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -31,7 +31,7 @@ TensorLoss, ) from deepmd.pt.model.model import ( - DPZBLModel, + EnergyModel, get_model, get_zbl_model, ) @@ -97,6 +97,7 @@ def __init__( finetune_model=None, force_load=False, shared_links=None, + finetune_links=None, init_frz_model=None, ): """Construct a DeePMD trainer. @@ -117,9 +118,7 @@ def __init__( model_params = config["model"] training_params = config["training"] self.multi_task = "model_dict" in model_params - self.finetune_multi_task = model_params.pop( - "finetune_multi_task", False - ) # should use pop for next finetune + self.finetune_links = finetune_links self.model_keys = ( list(model_params["model_dict"]) if self.multi_task else ["Default"] ) @@ -235,23 +234,24 @@ def single_model_stat( _training_data.add_data_requirement(_data_requirement) if _validation_data is not None: _validation_data.add_data_requirement(_data_requirement) - if not resuming and self.rank == 0: - @functools.lru_cache - def get_sample(): - sampled = make_stat_input( - _training_data.systems, - _training_data.dataloaders, - _data_stat_nbatch, - ) - return sampled + @functools.lru_cache + def get_sample(): + sampled = make_stat_input( + _training_data.systems, + _training_data.dataloaders, + _data_stat_nbatch, + ) + return sampled + if not resuming and self.rank == 0: _model.compute_or_load_stat( sampled_func=get_sample, stat_file_path=_stat_file_path, ) if isinstance(_stat_file_path, DPH5Path): _stat_file_path.root.close() + return get_sample def get_single_model( _model_params, @@ -360,7 +360,7 @@ def get_loss(loss_params, start_lr, _ntypes, _model): # Data dp_random.seed(training_params["seed"]) if not self.multi_task: - single_model_stat( + self.get_sample_func = single_model_stat( self.model, model_params.get("data_stat_nbatch", 10), training_data, @@ -390,9 +390,10 @@ def get_loss(loss_params, start_lr, _ntypes, _model): self.validation_dataloader, self.validation_data, self.valid_numb_batch, - ) = {}, {}, {}, {}, {} + self.get_sample_func, + ) = {}, {}, {}, {}, {}, {} for model_key in self.model_keys: - single_model_stat( + self.get_sample_func[model_key] = single_model_stat( self.model[model_key], model_params["model_dict"][model_key].get("data_stat_nbatch", 10), training_data[model_key], @@ -491,60 +492,116 @@ def get_loss(loss_params, start_lr, _ntypes, _model): log.warning( f"Force load mode allowed! These keys are not in ckpt and will re-init: {slim_keys}" ) - elif self.finetune_multi_task: + + if finetune_model is not None: new_state_dict = {} - model_branch_chosen = model_params.pop("model_branch_chosen") - new_fitting = model_params.pop("new_fitting", False) target_state_dict = self.wrapper.state_dict() - target_keys = [ - i for i in target_state_dict.keys() if i != "_extra_state" - ] - for item_key in target_keys: - if new_fitting and ".fitting_net." in item_key: - # print(f'Keep {item_key} in old model!') - new_state_dict[item_key] = ( - target_state_dict[item_key].clone().detach() - ) - else: - new_key = item_key.replace( - ".Default.", f".{model_branch_chosen}." - ) - # print(f'Replace {item_key} with {new_key} in pretrained_model!') - new_state_dict[item_key] = ( - state_dict[new_key].clone().detach() + + def update_single_finetune_params( + _model_key, + _model_key_from, + _new_state_dict, + _origin_state_dict, + _random_state_dict, + _new_fitting=False, + ): + target_keys = [ + i + for i in _random_state_dict.keys() + if i != "_extra_state" and f".{_model_key}." in i + ] + for item_key in target_keys: + if _new_fitting and ".fitting_net." in item_key: + # print(f'Keep {item_key} in old model!') + _new_state_dict[item_key] = ( + _random_state_dict[item_key].clone().detach() + ) + else: + new_key = item_key.replace( + f".{_model_key}.", f".{_model_key_from}." + ) + # print(f'Replace {item_key} with {new_key} in pretrained_model!') + _new_state_dict[item_key] = ( + _origin_state_dict[new_key].clone().detach() + ) + + if not self.multi_task: + model_key = "Default" + model_key_from = self.finetune_links[model_key] + new_fitting = model_params.pop("new_fitting", False) + update_single_finetune_params( + model_key, + model_key_from, + new_state_dict, + state_dict, + target_state_dict, + _new_fitting=new_fitting, + ) + else: + for model_key in self.model_keys: + if model_key in self.finetune_links: + model_key_from = self.finetune_links[model_key] + new_fitting = model_params["model_dict"][model_key].pop( + "new_fitting", False + ) + else: + model_key_from = model_key + new_fitting = False + update_single_finetune_params( + model_key, + model_key_from, + new_state_dict, + state_dict, + target_state_dict, + _new_fitting=new_fitting, ) state_dict = new_state_dict - if finetune_model is not None: state_dict["_extra_state"] = self.wrapper.state_dict()[ "_extra_state" ] - self.wrapper.load_state_dict(state_dict) - # finetune - if finetune_model is not None and model_params["fitting_net"].get( - "type", "ener" - ) in ["ener", "direct_force_ener", "atten_vec_lcc"]: + + def single_model_finetune( + _model, + _model_params, + _sample_func, + ): old_type_map, new_type_map = ( - model_params["type_map"], - model_params["new_type_map"], + _model_params["type_map"], + _model_params["new_type_map"], ) - # TODO: need an interface instead of fetching fitting_net!!!!!!!!! - if hasattr(self.model, "atomic_model") and hasattr( - self.model.atomic_model, "fitting_net" - ): - self.model.atomic_model.fitting_net.change_energy_bias( - config, - self.model, - old_type_map, - new_type_map, - ntest=ntest, - bias_shift=model_params.get("bias_shift", "delta"), + if isinstance(_model, EnergyModel): + _model.change_out_bias( + _sample_func, + bias_adjust_mode=_model_params.get( + "bias_adjust_mode", "change-by-statistic" + ), + origin_type_map=new_type_map, + full_type_map=old_type_map, ) - elif isinstance(self.model, DPZBLModel): - # need to updated - self.model.atomic_model.change_energy_bias() else: - raise NotImplementedError + # need to updated + pass + + # finetune + if not self.multi_task: + single_model_finetune( + self.model, model_params, self.get_sample_func + ) + else: + for model_key in self.model_keys: + if model_key in self.finetune_links: + log.info( + f"Model branch {model_key} will be fine-tuned. This may take a long time..." + ) + single_model_finetune( + self.model[model_key], + model_params["model_dict"][model_key], + self.get_sample_func[model_key], + ) + else: + log.info(f"Model branch {model_key} will resume training.") + if init_frz_model is not None: frz_model = torch.jit.load(init_frz_model, map_location=DEVICE) self.model.load_state_dict(frz_model.state_dict()) @@ -1022,7 +1079,7 @@ def get_data(self, is_train=True, task_key="Default"): if item_key in input_keys: input_dict[item_key] = batch_data[item_key] else: - if item_key not in ["sid", "fid"] and "find_" not in item_key: + if item_key not in ["sid", "fid"]: label_dict[item_key] = batch_data[item_key] log_dict = {} if "fid" in batch_data: @@ -1057,6 +1114,7 @@ def print_header(self, fout, train_results, valid_results): for k in sorted(train_results[model_key].keys()): print_str += prop_fmt % (k + f"_trn_{model_key}") print_str += " %8s\n" % "lr" + print_str += "# If there is no available reference data, rmse_*_{val,trn} will print nan\n" fout.write(print_str) fout.flush() diff --git a/deepmd/pt/utils/auto_batch_size.py b/deepmd/pt/utils/auto_batch_size.py index 5af7760e2a..181d56f2f4 100644 --- a/deepmd/pt/utils/auto_batch_size.py +++ b/deepmd/pt/utils/auto_batch_size.py @@ -1,4 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Tuple, + Union, +) + +import numpy as np import torch from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase @@ -24,3 +31,73 @@ def is_oom_error(self, e: Exception) -> bool: Exception """ return isinstance(e, RuntimeError) and "CUDA out of memory." in e.args[0] + + def execute_all( + self, callable: Callable, total_size: int, natoms: int, *args, **kwargs + ) -> Tuple[Union[np.ndarray, torch.Tensor]]: + """Excuate a method with all given data. + + Parameters + ---------- + callable : Callable + The method should accept *args and **kwargs as input and return the similiar array. + total_size : int + Total size + natoms : int + The number of atoms + *args + Variable length argument list. + **kwargs + If 2D np.ndarray or torch.Tensor, assume the first axis is batch; otherwise do nothing. + """ + + def execute_with_batch_size( + batch_size: int, start_index: int + ) -> Tuple[int, Tuple[torch.Tensor]]: + end_index = start_index + batch_size + end_index = min(end_index, total_size) + return (end_index - start_index), callable( + *[ + ( + vv[start_index:end_index] + if (isinstance(vv, np.ndarray) or isinstance(vv, torch.Tensor)) + and vv.ndim > 1 + else vv + ) + for vv in args + ], + **{ + kk: ( + vv[start_index:end_index] + if (isinstance(vv, np.ndarray) or isinstance(vv, torch.Tensor)) + and vv.ndim > 1 + else vv + ) + for kk, vv in kwargs.items() + }, + ) + + index = 0 + results = [] + while index < total_size: + n_batch, result = self.execute(execute_with_batch_size, index, natoms) + if not isinstance(result, tuple): + result = (result,) + index += n_batch + if n_batch: + for rr in result: + rr.reshape((n_batch, -1)) + results.append(result) + r_list = [] + for r in zip(*results): + if isinstance(r[0], np.ndarray): + r_list.append(np.concatenate(r, axis=0)) + elif isinstance(r[0], torch.Tensor): + r_list.append(torch.cat(r, dim=0)) + else: + raise RuntimeError(f"Unexpected result type {type(r[0])}") + r = tuple(r_list) + if len(r) == 1: + # avoid returning tuple if callable doesn't return tuple + r = r[0] + return r diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py index d555478af4..3f76454442 100644 --- a/deepmd/pt/utils/finetune.py +++ b/deepmd/pt/utils/finetune.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +from copy import ( + deepcopy, +) import torch @@ -10,88 +13,173 @@ log = logging.getLogger(__name__) -def change_finetune_model_params( - ckpt, finetune_model, model_config, multi_task=False, model_branch="" +def change_finetune_model_params_single( + _single_param_target, + _model_param_pretrained, + from_multitask=False, + model_branch="Default", + model_branch_from="", ): - """Load model_params according to the pretrained one. - - Args: - - ckpt & finetune_model: origin model. - - config: Read from json file. - """ - # TODO need support for multitask mode - if finetune_model is not None: - state_dict = torch.load(finetune_model, map_location=env.DEVICE) - if "model" in state_dict: - state_dict = state_dict["model"] - last_model_params = state_dict["_extra_state"]["model_params"] - finetune_multi_task = "model_dict" in last_model_params - trainable_param = { - "descriptor": True, - "fitting_net": True, - } - for net_type in trainable_param: - if net_type in model_config: - trainable_param[net_type] = model_config[net_type].get( - "trainable", True - ) - if not finetune_multi_task: - old_type_map, new_type_map = ( - last_model_params["type_map"], - model_config["type_map"], + single_config = deepcopy(_single_param_target) + trainable_param = { + "descriptor": True, + "fitting_net": True, + } + for net_type in trainable_param: + if net_type in single_config: + trainable_param[net_type] = single_config[net_type].get("trainable", True) + if not from_multitask: + old_type_map, new_type_map = ( + _model_param_pretrained["type_map"], + single_config["type_map"], + ) + assert set(new_type_map).issubset( + old_type_map + ), "Only support for smaller type map when finetuning or resuming." + single_config = deepcopy(_model_param_pretrained) + log.info( + f"Change the '{model_branch}' model configurations according to the pretrained one..." + ) + single_config["new_type_map"] = new_type_map + else: + model_dict_params = _model_param_pretrained["model_dict"] + new_fitting = False + if model_branch_from == "": + model_branch_chosen = next(iter(model_dict_params.keys())) + new_fitting = True + single_config["bias_adjust_mode"] = ( + "set-by-statistic" # fitting net re-init ) - assert set(new_type_map).issubset( - old_type_map - ), "Only support for smaller type map when finetuning or resuming." - model_config = last_model_params - log.info( - "Change the model configurations according to the pretrained one..." + log.warning( + "The fitting net will be re-init instead of using that in the pretrained model! " + "The bias_adjust_mode will be set-by-statistic!" ) - model_config["new_type_map"] = new_type_map else: - model_config["finetune_multi_task"] = finetune_multi_task - model_dict_params = last_model_params["model_dict"] - new_fitting = False - if model_branch == "": - model_branch_chosen = next(iter(model_dict_params.keys())) - new_fitting = True - model_config["bias_shift"] = "statistic" # fitting net re-init - log.warning( - "The fitting net will be re-init instead of using that in the pretrained model! " - "The bias_shift will be statistic!" + model_branch_chosen = model_branch_from + assert model_branch_chosen in model_dict_params, ( + f"No model branch named '{model_branch_chosen}'! " + f"Available ones are {list(model_dict_params.keys())}." + ) + single_config_chosen = deepcopy(model_dict_params[model_branch_chosen]) + old_type_map, new_type_map = ( + single_config_chosen["type_map"], + single_config["type_map"], + ) + assert set(new_type_map).issubset( + old_type_map + ), "Only support for smaller type map when finetuning or resuming." + for key_item in ["type_map", "descriptor"]: + if key_item in single_config_chosen: + single_config[key_item] = single_config_chosen[key_item] + if not new_fitting: + single_config["fitting_net"] = single_config_chosen["fitting_net"] + log.info( + f"Change the '{model_branch}' model configurations according to the model branch " + f"'{model_branch_chosen}' in the pretrained one..." + ) + single_config["new_type_map"] = new_type_map + single_config["model_branch_chosen"] = model_branch_chosen + single_config["new_fitting"] = new_fitting + for net_type in trainable_param: + if net_type in single_config: + single_config[net_type]["trainable"] = trainable_param[net_type] + else: + single_config[net_type] = {"trainable": trainable_param[net_type]} + return single_config + + +def change_finetune_model_params(finetune_model, model_config, model_branch=""): + """ + Load model_params according to the pretrained one. + This function modifies the fine-tuning input in different modes as follows: + 1. Single-task fine-tuning from a single-task pretrained model: + - Updates the model parameters based on the pretrained model. + 2. Single-task fine-tuning from a multi-task pretrained model: + - Updates the model parameters based on the selected branch in the pretrained model. + - The chosen branch can be defined from the command-line or `finetune_head` input parameter. + - If not defined, model parameters in the fitting network will be randomly initialized. + 3. Multi-task fine-tuning from a single-task pretrained model: + - Updates model parameters in each branch based on the single branch ('Default') in the pretrained model. + - If `finetune_head` is not set to 'Default', + model parameters in the fitting network of the branch will be randomly initialized. + 4. Multi-task fine-tuning from a multi-task pretrained model: + - Updates model parameters in each branch based on the selected branch in the pretrained model. + - The chosen branches can be defined from the `finetune_head` input parameter of each model. + - If `finetune_head` is not defined and the model_key is the same as in the pretrained model, + it will resume from the model_key branch without fine-tuning. + - If `finetune_head` is not defined and a new model_key is used, + model parameters in the fitting network of the branch will be randomly initialized. + + Parameters + ---------- + finetune_model + The pretrained model. + model_config + The fine-tuning input parameters. + model_branch + The model branch chosen in command-line mode, only for single-task fine-tuning. + + Returns + ------- + model_config: + Updated model parameters. + finetune_links: + Fine-tuning rules in a dict format, with `model_branch`: `model_branch_from` pairs. + If `model_key` is not in this dict, it will do just resuming instead of fine-tuning. + """ + multi_task = "model_dict" in model_config + state_dict = torch.load(finetune_model, map_location=env.DEVICE) + if "model" in state_dict: + state_dict = state_dict["model"] + last_model_params = state_dict["_extra_state"]["model_params"] + finetune_from_multi_task = "model_dict" in last_model_params + finetune_links = {} + if not multi_task: + # use command-line first + if model_branch == "" and "finetune_head" in model_config: + model_branch = model_config["finetune_head"] + model_config = change_finetune_model_params_single( + model_config, + last_model_params, + from_multitask=finetune_from_multi_task, + model_branch="Default", + model_branch_from=model_branch, + ) + finetune_links["Default"] = ( + model_branch if finetune_from_multi_task else "Default" + ) + else: + assert model_branch == "", ( + "Multi-task fine-tuning does not support command-line branches chosen!" + "Please define the 'finetune_head' in each model params!" + ) + target_keys = model_config["model_dict"].keys() + if not finetune_from_multi_task: + pretrained_keys = ["Default"] + else: + pretrained_keys = last_model_params["model_dict"].keys() + for model_key in target_keys: + if "finetune_head" in model_config["model_dict"][model_key]: + pretrained_key = model_config["model_dict"][model_key]["finetune_head"] + assert pretrained_key in pretrained_keys, ( + f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model!" + f"Available heads are: {list(pretrained_keys)}" ) + model_branch_from = pretrained_key + finetune_links[model_key] = model_branch_from + elif model_key in pretrained_keys: + # not do anything if not defined "finetune_head" in heads that exist in the pretrained model + # this will just do resuming + model_branch_from = model_key else: - model_branch_chosen = model_branch - assert model_branch_chosen in model_dict_params, ( - f"No model branch named '{model_branch_chosen}'! " - f"Available ones are {list(model_dict_params.keys())}." + # if not defined "finetune_head" in new heads, the fitting net will bre randomly initialized + model_branch_from = "" + finetune_links[model_key] = next(iter(pretrained_keys)) + model_config["model_dict"][model_key] = change_finetune_model_params_single( + model_config["model_dict"][model_key], + last_model_params, + from_multitask=finetune_from_multi_task, + model_branch=model_key, + model_branch_from=model_branch_from, ) - old_type_map, new_type_map = ( - model_dict_params[model_branch_chosen]["type_map"], - model_config["type_map"], - ) - assert set(new_type_map).issubset( - old_type_map - ), "Only support for smaller type map when finetuning or resuming." - for key_item in ["type_map", "descriptor"]: - if key_item in model_dict_params[model_branch_chosen]: - model_config[key_item] = model_dict_params[model_branch_chosen][ - key_item - ] - if not new_fitting: - model_config["fitting_net"] = model_dict_params[model_branch_chosen][ - "fitting_net" - ] - log.info( - f"Change the model configurations according to the model branch " - f"{model_branch_chosen} in the pretrained one..." - ) - model_config["new_type_map"] = new_type_map - model_config["model_branch_chosen"] = model_branch_chosen - model_config["new_fitting"] = new_fitting - for net_type in trainable_param: - if net_type in model_config: - model_config[net_type]["trainable"] = trainable_param[net_type] - else: - model_config[net_type] = {"trainable": trainable_param[net_type]} - return model_config + return model_config, finetune_links diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 5e631d9412..1aff4cfb37 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -12,11 +12,14 @@ from deepmd.pt.utils import ( AtomExcludeMask, - env, +) +from deepmd.pt.utils.auto_batch_size import ( + AutoBatchSize, ) from deepmd.pt.utils.utils import ( dict_to_device, to_numpy_array, + to_torch_tensor, ) from deepmd.utils.out_stat import ( compute_stats_from_redu, @@ -76,6 +79,7 @@ def compute_output_stats( stat_file_path: Optional[DPPath] = None, rcond: Optional[float] = None, atom_ener: Optional[List[float]] = None, + model_forward: Optional[Callable[..., torch.Tensor]] = None, ): """ Compute the output statistics (e.g. energy bias) for the fitting net from packed data. @@ -97,7 +101,11 @@ def compute_output_stats( The condition number for the regression of atomic energy. atom_ener : List[float], optional Specifying atomic energy contribution in vacuum. The `set_davg_zero` key in the descrptor should be set. - + model_forward : Callable[..., torch.Tensor], optional + The wrapped forward function of atomic model. + If not None, the model will be utilized to generate the original energy prediction, + which will be subtracted from the energy label of the data. + The difference will then be used to calculate the delta complement energy bias for each type. """ if stat_file_path is not None: stat_file_path = stat_file_path / "bias_atom_e" @@ -129,13 +137,66 @@ def compute_output_stats( ) else: assigned_atom_ener = None - bias_atom_e, _ = compute_stats_from_redu( - merged_energy, - merged_natoms, - assigned_bias=assigned_atom_ener, - rcond=rcond, - ) + if model_forward is None: + # only use statistics result + bias_atom_e, _ = compute_stats_from_redu( + merged_energy, + merged_natoms, + assigned_bias=assigned_atom_ener, + rcond=rcond, + ) + else: + # subtract the model bias and output the delta bias + auto_batch_size = AutoBatchSize() + energy_predict = [] + for system in sampled: + nframes = system["coord"].shape[0] + coord, atype, box, natoms = ( + system["coord"], + system["atype"], + system["box"], + system["natoms"], + ) + fparam = system.get("fparam", None) + aparam = system.get("aparam", None) + + def model_forward_auto_batch_size(*args, **kwargs): + return auto_batch_size.execute_all( + model_forward, + nframes, + system["atype"].shape[-1], + *args, + **kwargs, + ) + + energy = ( + model_forward_auto_batch_size( + coord, atype, box, fparam=fparam, aparam=aparam + ) + .reshape(nframes, -1) + .sum(-1) + ) + energy_predict.append(to_numpy_array(energy).reshape([nframes, 1])) + + energy_predict = np.concatenate(energy_predict) + bias_diff = merged_energy - energy_predict + bias_atom_e, _ = compute_stats_from_redu( + bias_diff, + merged_natoms, + assigned_bias=assigned_atom_ener, + rcond=rcond, + ) + unbias_e = energy_predict + merged_natoms @ bias_atom_e + atom_numbs = merged_natoms.sum(-1) + rmse_ae = np.sqrt( + np.mean( + np.square((unbias_e.ravel() - merged_energy.ravel()) / atom_numbs) + ) + ) + log.info( + f"RMSE of energy per atom after linear regression is: {rmse_ae} eV/atom." + ) if stat_file_path is not None: stat_file_path.save_numpy(bias_atom_e) assert all(x is not None for x in [bias_atom_e]) - return torch.tensor(bias_atom_e, device=env.DEVICE) + return to_torch_tensor(bias_atom_e) diff --git a/deepmd/tf/fit/ener.py b/deepmd/tf/fit/ener.py index b391b00052..2229c51630 100644 --- a/deepmd/tf/fit/ener.py +++ b/deepmd/tf/fit/ener.py @@ -793,11 +793,11 @@ def change_energy_bias( frozen_model, origin_type_map, full_type_map, - bias_shift="delta", + bias_adjust_mode="change-by-statistic", ntest=10, ) -> None: dp = None - if bias_shift == "delta": + if bias_adjust_mode == "change-by-statistic": # init model dp = DeepPotential(frozen_model) self.bias_atom_e = change_energy_bias_lower( @@ -806,7 +806,7 @@ def change_energy_bias( origin_type_map, full_type_map, self.bias_atom_e, - bias_shift=bias_shift, + bias_adjust_mode=bias_adjust_mode, ntest=ntest, ) diff --git a/deepmd/tf/model/ener.py b/deepmd/tf/model/ener.py index 70e0f4d2ba..a493fe0517 100644 --- a/deepmd/tf/model/ener.py +++ b/deepmd/tf/model/ener.py @@ -486,7 +486,7 @@ def change_energy_bias( frozen_model: str, origin_type_map: list, full_type_map: str, - bias_shift: str = "delta", + bias_adjust_mode: str = "change-by-statistic", ) -> None: """Change the energy bias according to the input data and the pretrained model. @@ -500,17 +500,17 @@ def change_energy_bias( The original type_map in dataset, they are targets to change the energy bias. full_type_map : str The full type_map in pretrained model - bias_shift : str - The mode for changing energy bias : ['delta', 'statistic'] - 'delta' : perform predictions on energies of target dataset, + bias_adjust_mode : str + The mode for changing energy bias : ['change-by-statistic', 'set-by-statistic'] + 'change-by-statistic' : perform predictions on energies of target dataset, and do least sqaure on the errors to obtain the target shift as bias. - 'statistic' : directly use the statistic energy bias in the target dataset. + 'set-by-statistic' : directly use the statistic energy bias in the target dataset. """ self.fitting.change_energy_bias( data, frozen_model, origin_type_map, full_type_map, - bias_shift, + bias_adjust_mode, self.data_bias_nsample, ) diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index a0e234a547..951ae09396 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -408,7 +408,7 @@ def change_energy_bias( frozen_model: str, origin_type_map: list, full_type_map: str, - bias_shift: str = "delta", + bias_adjust_mode: str = "change-by-statistic", ) -> None: """Change the energy bias according to the input data and the pretrained model. @@ -422,11 +422,11 @@ def change_energy_bias( The original type_map in dataset, they are targets to change the energy bias. full_type_map : str The full type_map in pretrained model - bias_shift : str - The mode for changing energy bias : ['delta', 'statistic'] - 'delta' : perform predictions on energies of target dataset, + bias_adjust_mode : str + The mode for changing energy bias : ['change-by-statistic', 'set-by-statistic'] + 'change-by-statistic' : perform predictions on energies of target dataset, and do least sqaure on the errors to obtain the target shift as bias. - 'statistic' : directly use the statistic energy bias in the target dataset. + 'set-by-statistic' : directly use the statistic energy bias in the target dataset. """ raise RuntimeError("Not supported") diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index 125b795d2e..931cf87246 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -1114,7 +1114,7 @@ def _init_from_ckpt(self, ckpt_meta: str): self.ckpt_meta = ckpt_meta def _init_from_pretrained_model( - self, data, origin_type_map=None, bias_shift="delta" + self, data, origin_type_map=None, bias_adjust_mode="change-by-statistic" ): """Init the embedding net variables with the given frozen model. @@ -1124,11 +1124,11 @@ def _init_from_pretrained_model( The training data. origin_type_map : list The original type_map in dataset, they are targets to change the energy bias. - bias_shift : str - The mode for changing energy bias : ['delta', 'statistic'] - 'delta' : perform predictions on energies of target dataset, + bias_adjust_mode : str + The mode for changing energy bias : ['change-by-statistic', 'set-by-statistic'] + 'change-by-statistic' : perform predictions on energies of target dataset, and do least sqaure on the errors to obtain the target shift as bias. - 'statistic' : directly use the statistic energy bias in the target dataset. + 'set-by-statistic' : directly use the statistic energy bias in the target dataset. """ try: graph, graph_def = load_graph_def(self.run_opt.finetune) @@ -1159,11 +1159,15 @@ def _init_from_pretrained_model( "(this step may take long time)" ) self._change_energy_bias( - data, self.run_opt.finetune, origin_type_map, bias_shift + data, self.run_opt.finetune, origin_type_map, bias_adjust_mode ) def _change_energy_bias( - self, data, frozen_model, origin_type_map, bias_shift="delta" + self, + data, + frozen_model, + origin_type_map, + bias_adjust_mode="change-by-statistic", ): full_type_map = data.get_type_map() self.model.change_energy_bias( @@ -1171,7 +1175,7 @@ def _change_energy_bias( frozen_model, origin_type_map, full_type_map, - bias_shift=bias_shift, + bias_adjust_mode=bias_adjust_mode, ) diff --git a/deepmd/utils/finetune.py b/deepmd/utils/finetune.py index a454ad72ea..1150fe2701 100644 --- a/deepmd/utils/finetune.py +++ b/deepmd/utils/finetune.py @@ -26,7 +26,7 @@ def change_energy_bias_lower( origin_type_map: List[str], full_type_map: List[str], bias_atom_e: np.ndarray, - bias_shift="delta", + bias_adjust_mode="change-by-statistic", ntest=10, ): """Change the energy bias according to the input data and the pretrained model. @@ -43,11 +43,11 @@ def change_energy_bias_lower( The full type_map in pretrained model bias_atom_e : np.ndarray The old energy bias in the pretrained model. - bias_shift : str - The mode for changing energy bias : ['delta', 'statistic'] - 'delta' : perform predictions on energies of target dataset, + bias_adjust_mode : str + The mode for changing energy bias : ['change-by-statistic', 'set-by-statistic'] + 'change-by-statistic' : perform predictions on energies of target dataset, and do least sqaure on the errors to obtain the target shift as bias. - 'statistic' : directly use the statistic energy bias in the target dataset. + 'set-by-statistic' : directly use the statistic energy bias in the target dataset. ntest : int The number of test samples in a system to change the energy bias. """ @@ -88,7 +88,7 @@ def change_energy_bias_lower( (numb_test, 1), ) ) - if bias_shift == "delta": + if bias_adjust_mode == "change-by-statistic": coord = test_data["coord"][:numb_test].reshape([numb_test, -1]) if sys.pbc: box = test_data["box"][:numb_test] @@ -114,7 +114,7 @@ def change_energy_bias_lower( type_numbs = np.concatenate(type_numbs) energy_ground_truth = np.concatenate(energy_ground_truth) old_bias = bias_atom_e[idx_type_map] - if bias_shift == "delta": + if bias_adjust_mode == "change-by-statistic": energy_predict = np.concatenate(energy_predict) bias_diff = energy_ground_truth - energy_predict delta_bias = np.linalg.lstsq(type_numbs, bias_diff, rcond=None)[0] @@ -129,11 +129,11 @@ def change_energy_bias_lower( log.info( f"RMSE of atomic energy after linear regression is: {rmse_ae} eV/atom." ) - elif bias_shift == "statistic": + elif bias_adjust_mode == "set-by-statistic": statistic_bias = np.linalg.lstsq(type_numbs, energy_ground_truth, rcond=None)[0] bias_atom_e[idx_type_map] = statistic_bias.reshape(-1) else: - raise RuntimeError("Unknown bias_shift mode: " + bias_shift) + raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode) log.info( f"Change energy bias of {origin_type_map!s} from {old_bias!s} to {bias_atom_e[idx_type_map]!s}." ) diff --git a/source/tests/pt/model/test_model.py b/source/tests/pt/model/test_model.py index aa1c0dd969..493d6e2cc3 100644 --- a/source/tests/pt/model/test_model.py +++ b/source/tests/pt/model/test_model.py @@ -352,7 +352,9 @@ def test_consistency(self): } label = { "energy": batch["energy"].to(env.DEVICE), + "find_energy": 1.0, "force": batch["force"].to(env.DEVICE), + "find_force": 1.0, } cur_lr = my_lr.value(self.wanted_step) model_predict, loss, _ = my_loss( diff --git a/source/tests/pt/model/water/multitask.json b/source/tests/pt/model/water/multitask.json index 6baddd672b..2f706e4cd9 100644 --- a/source/tests/pt/model/water/multitask.json +++ b/source/tests/pt/model/water/multitask.json @@ -39,7 +39,8 @@ "resnet_dt": true, "seed": 1, "_comment": " that's all" - } + }, + "data_stat_nbatch": 1 }, "model_2": { "type_map": "my_type_map", @@ -53,7 +54,8 @@ "resnet_dt": true, "seed": 1, "_comment": " that's all" - } + }, + "data_stat_nbatch": 1 } } }, diff --git a/source/tests/pt/test_finetune.py b/source/tests/pt/test_finetune.py index dd72eb4718..79f8c57cb8 100644 --- a/source/tests/pt/test_finetune.py +++ b/source/tests/pt/test_finetune.py @@ -1,6 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import os -import shutil import tempfile import unittest from copy import ( @@ -17,16 +15,16 @@ DeepEval, ) from deepmd.pt.model.model import ( - DPZBLModel, - EnergyModel, get_model, - get_zbl_model, ) -from deepmd.utils.data_system import ( - DeepmdDataSystem, +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, ) -from deepmd.utils.finetune import ( - change_energy_bias_lower, +from deepmd.pt.utils.stat import ( + make_stat_input, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, ) from .model.test_permutation import ( @@ -34,63 +32,38 @@ model_se_e2_a, model_zbl, ) +from .test_stat import ( + energy_data_requirement, +) class FinetuneTest: - def test_finetune_change_energy_bias(self): + def test_finetune_change_out_bias(self): # get model - if "use_srtab" in self.model_config: - model = get_zbl_model(self.model_config) - else: - model = get_model(self.model_config) - if isinstance(model, EnergyModel): - model.get_fitting_net().bias_atom_e = torch.rand_like( - model.get_fitting_net().bias_atom_e - ) - energy_bias_before = deepcopy( - model.get_fitting_net().bias_atom_e.detach().cpu().numpy().reshape(-1) - ) - bias_atom_e_input = deepcopy( - model.get_fitting_net().bias_atom_e.detach().cpu().numpy().reshape(-1) - ) - elif isinstance(model, DPZBLModel): - model.dp_model.get_fitting_net().bias_atom_e = torch.rand_like( - model.dp_model.get_fitting_net().bias_atom_e - ) - energy_bias_before = deepcopy( - model.dp_model.get_fitting_net() - .bias_atom_e.detach() - .cpu() - .numpy() - .reshape(-1) - ) - bias_atom_e_input = deepcopy( - model.dp_model.get_fitting_net() - .bias_atom_e.detach() - .cpu() - .numpy() - .reshape(-1) - ) - else: - bias_atom_e_input = None - - model = torch.jit.script(model) + model = get_model(self.model_config) + fitting_net = model.get_fitting_net() + fitting_net["bias_atom_e"] = torch.rand_like(fitting_net["bias_atom_e"]) + energy_bias_before = deepcopy( + to_numpy_array(fitting_net["bias_atom_e"]).reshape(-1) + ) + + # prepare original model for test + dp = torch.jit.script(model) tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, tmp_model.name) + torch.jit.save(dp, tmp_model.name) dp = DeepEval(tmp_model.name) - ntest = 10 origin_type_map = ["O", "H"] full_type_map = ["O", "H", "B"] # change energy bias - energy_bias_after = change_energy_bias_lower( - self.data, - dp, + model.atomic_model.change_out_bias( + self.sampled, + bias_adjust_mode="change-by-statistic", origin_type_map=origin_type_map, full_type_map=full_type_map, - bias_atom_e=bias_atom_e_input, - bias_shift="delta", - ntest=ntest, + ) + energy_bias_after = deepcopy( + to_numpy_array(fitting_net["bias_atom_e"]).reshape(-1) ) # get ground-truth energy bias change @@ -98,12 +71,17 @@ def test_finetune_change_energy_bias(self): idx_type_map = sorter[ np.searchsorted(full_type_map, origin_type_map, sorter=sorter) ] - test_data = self.data.get_test() - atom_nums = np.tile(np.bincount(test_data["type"][0])[idx_type_map], (ntest, 1)) + ntest = 1 + atom_nums = np.tile( + np.bincount(to_numpy_array(self.sampled[0]["atype"][0]))[idx_type_map], + (ntest, 1), + ) energy = dp.eval( - test_data["coord"][:ntest], test_data["box"][:ntest], test_data["type"][0] + to_numpy_array(self.sampled[0]["coord"][:ntest]), + to_numpy_array(self.sampled[0]["box"][:ntest]), + to_numpy_array(self.sampled[0]["atype"][0]), )[0] - energy_diff = test_data["energy"][:ntest] - energy + energy_diff = to_numpy_array(self.sampled[0]["energy"][:ntest]) - energy finetune_shift = ( energy_bias_after[idx_type_map] - energy_bias_before[idx_type_map] ) @@ -114,60 +92,57 @@ def test_finetune_change_energy_bias(self): # check values np.testing.assert_almost_equal(finetune_shift, ground_truth_shift, decimal=10) - def tearDown(self): - for f in os.listdir("."): - if f.startswith("model") and f.endswith(".pt"): - os.remove(f) - if f in ["lcurve.out"]: - os.remove(f) - if f in ["stat_files"]: - shutil.rmtree(f) - class TestEnergyModelSeA(unittest.TestCase, FinetuneTest): def setUp(self): self.data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.data = DeepmdDataSystem( + self.model_config = model_se_e2_a + self.data = DpLoaderSet( self.data_file, batch_size=1, - test_size=1, + type_map=self.model_config["type_map"], + ) + self.data.add_data_requirement(energy_data_requirement) + self.sampled = make_stat_input( + self.data.systems, + self.data.dataloaders, + nbatches=1, ) - self.data.add("energy", ndof=1, atomic=False, must=True, high_prec=True) - self.model_config = model_se_e2_a - - def tearDown(self) -> None: - FinetuneTest.tearDown(self) @unittest.skip("change bias not implemented yet.") class TestEnergyZBLModelSeA(unittest.TestCase, FinetuneTest): def setUp(self): self.data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.data = DeepmdDataSystem( + self.model_config = model_zbl + self.data = DpLoaderSet( self.data_file, batch_size=1, - test_size=1, + type_map=self.model_config["type_map"], + ) + self.data.add_data_requirement(energy_data_requirement) + self.sampled = make_stat_input( + self.data.systems, + self.data.dataloaders, + nbatches=1, ) - self.data.add("energy", ndof=1, atomic=False, must=True, high_prec=True) - self.model_config = model_zbl - - def tearDown(self) -> None: - FinetuneTest.tearDown(self) class TestEnergyModelDPA2(unittest.TestCase, FinetuneTest): def setUp(self): self.data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.data = DeepmdDataSystem( + self.model_config = model_dpa2 + self.data = DpLoaderSet( self.data_file, batch_size=1, - test_size=1, + type_map=self.model_config["type_map"], + ) + self.data.add_data_requirement(energy_data_requirement) + self.sampled = make_stat_input( + self.data.systems, + self.data.dataloaders, + nbatches=1, ) - self.data.add("energy", ndof=1, atomic=False, must=True, high_prec=True) - self.model_config = model_dpa2 - - def tearDown(self) -> None: - FinetuneTest.tearDown(self) if __name__ == "__main__": diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index 2abb22c2a9..17b05dadc6 100644 --- a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -147,6 +147,14 @@ def setUp(self): "virial": torch.from_numpy(p_virial), } self.label = { + "energy": torch.from_numpy(l_energy), + "find_energy": 1.0, + "force": torch.from_numpy(l_force), + "find_force": 1.0, + "virial": torch.from_numpy(l_virial), + "find_virial": 1.0, + } + self.label_absent = { "energy": torch.from_numpy(l_energy), "force": torch.from_numpy(l_force), "virial": torch.from_numpy(l_virial), @@ -182,14 +190,24 @@ def fake_model(): self.nloc, self.cur_lr, ) + _, my_loss_absent, my_more_loss_absent = mine( + {}, + fake_model, + self.label_absent, + self.nloc, + self.cur_lr, + ) my_loss = my_loss.detach().cpu() + my_loss_absent = my_loss_absent.detach().cpu() self.assertTrue(np.allclose(base_loss, my_loss.numpy())) + self.assertTrue(np.allclose(0.0, my_loss_absent.numpy())) for key in ["ener", "force", "virial"]: self.assertTrue( np.allclose( base_more_loss["l2_%s_loss" % key], my_more_loss["l2_%s_loss" % key] ) ) + self.assertTrue(np.isnan(my_more_loss_absent["l2_%s_loss" % key])) class TestEnerSpinLoss(unittest.TestCase): @@ -326,6 +344,14 @@ def setUp(self): ), } self.label = { + "energy": torch.from_numpy(l_energy), + "find_energy": 1.0, + "force": torch.from_numpy(l_force_real).reshape(nframes, self.nloc, 3), + "find_force": 1.0, + "force_mag": torch.from_numpy(l_force_mag).reshape(nframes, self.nloc, 3), + "find_force_mag": 1.0, + } + self.label_absent = { "energy": torch.from_numpy(l_energy), "force": torch.from_numpy(l_force_real).reshape(nframes, self.nloc, 3), "force_mag": torch.from_numpy(l_force_mag).reshape(nframes, self.nloc, 3), @@ -361,14 +387,24 @@ def fake_model(): self.nloc_tf, # use tf natoms pref self.cur_lr, ) + _, my_loss_absent, my_more_loss_absent = mine( + {}, + fake_model, + self.label_absent, + self.nloc_tf, # use tf natoms pref + self.cur_lr, + ) my_loss = my_loss.detach().cpu() + my_loss_absent = my_loss_absent.detach().cpu() self.assertTrue(np.allclose(base_loss, my_loss.numpy())) + self.assertTrue(np.allclose(0.0, my_loss_absent.numpy())) for key in ["ener", "force_r", "force_m"]: self.assertTrue( np.allclose( base_more_loss["l2_%s_loss" % key], my_more_loss["l2_%s_loss" % key] ) ) + self.assertTrue(np.isnan(my_more_loss_absent["l2_%s_loss" % key])) if __name__ == "__main__": diff --git a/source/tests/pt/test_multitask.py b/source/tests/pt/test_multitask.py index e959e9a128..8bdb42df52 100644 --- a/source/tests/pt/test_multitask.py +++ b/source/tests/pt/test_multitask.py @@ -32,6 +32,7 @@ class MultiTaskTrainTest: def test_multitask_train(self): + # test multitask training trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links) trainer.run() # check model keys @@ -51,6 +52,100 @@ def test_multitask_train(self): multi_state_dict[state_key], multi_state_dict[state_key.replace("model_1", "model_2")], ) + + # test multitask fine-tuning + # add model_3 + self.origin_config["model"]["model_dict"]["model_3"] = deepcopy( + self.origin_config["model"]["model_dict"]["model_2"] + ) + self.origin_config["loss_dict"]["model_3"] = deepcopy( + self.origin_config["loss_dict"]["model_2"] + ) + self.origin_config["training"]["model_prob"]["model_3"] = deepcopy( + self.origin_config["training"]["model_prob"]["model_2"] + ) + self.origin_config["training"]["data_dict"]["model_3"] = deepcopy( + self.origin_config["training"]["data_dict"]["model_2"] + ) + self.origin_config["training"]["data_dict"]["model_3"]["stat_file"] = ( + self.origin_config[ + "training" + ]["data_dict"]["model_3"]["stat_file"].replace("model_2", "model_3") + ) + + # add model_4 + self.origin_config["model"]["model_dict"]["model_4"] = deepcopy( + self.origin_config["model"]["model_dict"]["model_2"] + ) + self.origin_config["loss_dict"]["model_4"] = deepcopy( + self.origin_config["loss_dict"]["model_2"] + ) + self.origin_config["training"]["model_prob"]["model_4"] = deepcopy( + self.origin_config["training"]["model_prob"]["model_2"] + ) + self.origin_config["training"]["data_dict"]["model_4"] = deepcopy( + self.origin_config["training"]["data_dict"]["model_2"] + ) + self.origin_config["training"]["data_dict"]["model_4"]["stat_file"] = ( + self.origin_config[ + "training" + ]["data_dict"]["model_4"]["stat_file"].replace("model_2", "model_4") + ) + + # set finetune rules + # model_1 resuming from model_1 + # pass + + # model_2 fine-tuning from model_2 + self.origin_config["model"]["model_dict"]["model_2"]["finetune_head"] = ( + "model_2" + ) + + # new model_3 fine-tuning from model_2 + self.origin_config["model"]["model_dict"]["model_3"]["finetune_head"] = ( + "model_2" + ) + + # new model_4 fine-tuning with randomly initialized fitting net + # pass + + self.origin_config["model"], shared_links_finetune = preprocess_shared_params( + self.origin_config["model"] + ) + + trainer_finetune = get_trainer( + deepcopy(self.origin_config), + finetune_model=self.config["training"].get("save_ckpt", "model.ckpt") + + ".pt", + shared_links=shared_links_finetune, + ) + + # check parameters + multi_state_dict_finetuned = trainer_finetune.wrapper.model.state_dict() + for state_key in multi_state_dict_finetuned: + if "model_1" in state_key: + torch.testing.assert_close( + multi_state_dict[state_key], + multi_state_dict_finetuned[state_key], + ) + elif "model_2" in state_key and "bias_atom_e" not in state_key: + torch.testing.assert_close( + multi_state_dict[state_key], + multi_state_dict_finetuned[state_key], + ) + elif "model_3" in state_key and "bias_atom_e" not in state_key: + torch.testing.assert_close( + multi_state_dict[state_key.replace("model_3", "model_2")], + multi_state_dict_finetuned[state_key], + ) + elif "model_4" in state_key and "fitting_net" not in state_key: + torch.testing.assert_close( + multi_state_dict[state_key.replace("model_4", "model_2")], + multi_state_dict_finetuned[state_key], + ) + + # check running + trainer_finetune.run() self.tearDown() def tearDown(self): @@ -93,6 +188,7 @@ def setUp(self): ) self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 + self.origin_config = deepcopy(self.config) self.config["model"], self.shared_links = preprocess_shared_params( self.config["model"] ) @@ -131,6 +227,7 @@ def setUp(self): ) self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 + self.origin_config = deepcopy(self.config) self.config["model"], self.shared_links = preprocess_shared_params( self.config["model"] ) @@ -169,6 +266,7 @@ def setUp(self): ) self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 + self.origin_config = deepcopy(self.config) self.config["model"], self.shared_links = preprocess_shared_params( self.config["model"] )