Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: support multitask finetune #3480

Merged
merged 17 commits into from
Mar 22, 2024
Merged
15 changes: 8 additions & 7 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ def get_trainer(
dist.init_process_group(backend="nccl")

ckpt = init_model if init_model is not None else restart_model
config["model"] = change_finetune_model_params(
ckpt,
finetune_model,
config["model"],
multi_task=multi_task,
model_branch=model_branch,
)
finetune_links = None
if finetune_model is not None:
config["model"], finetune_links = change_finetune_model_params(
finetune_model,
config["model"],
model_branch=model_branch,
)
config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)

def prepare_trainer_input_single(
Expand Down Expand Up @@ -194,6 +194,7 @@ def prepare_trainer_input_single(
finetune_model=finetune_model,
force_load=force_load,
shared_links=shared_links,
finetune_links=finetune_links,
init_frz_model=init_frz_model,
)
return trainer
Expand Down
90 changes: 90 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Optional,
)

import numpy as np
import torch

from deepmd.dpmodel import (
Expand All @@ -19,6 +20,15 @@
from deepmd.pt.model.task.base_fitting import (
BaseFitting,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -223,6 +233,86 @@
if self.fitting_net is not None:
self.fitting_net.compute_output_stats(wrapped_sampler, stat_file_path)

def change_out_bias(
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
self, merged, origin_type_map, full_type_map, bias_shift="delta"
) -> None:
"""Change the energy bias according to the input data and the pretrained model.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
origin_type_map : List[str]
The original type_map in dataset, they are targets to change the energy bias.
full_type_map : List[str]
The full type_map in pre-trained model
bias_shift : str
The mode for changing energy bias : ['delta', 'statistic']
'delta' : perform predictions on energies of target dataset,
and do least sqaure on the errors to obtain the target shift as bias.
'statistic' : directly use the statistic energy bias in the target dataset.
"""
sorter = np.argsort(full_type_map)
missing_types = [t for t in origin_type_map if t not in full_type_map]
assert (
not missing_types
), f"Some types are not in the pre-trained model: {list(missing_types)} !"
idx_type_map = sorter[
np.searchsorted(full_type_map, origin_type_map, sorter=sorter)
]
original_bias = self.fitting_net["bias_atom_e"]
if bias_shift == "delta":

def model_forward(coord, atype, box, fparam=None, aparam=None):
with torch.no_grad(): # it's essential for pure torch forward function to use auto_batchsize
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
mixed_types=self.mixed_types(),
box=box,
)
atomic_ret = self.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
return atomic_ret["energy"].detach()

delta_bias_e = compute_output_stats(
merged,
self.get_ntypes(),
model_forward=model_forward,
)
bias_atom_e = delta_bias_e + original_bias
elif bias_shift == "statistic":
bias_atom_e = compute_output_stats(
merged,
self.get_ntypes(),
)
else:
raise RuntimeError("Unknown bias_shift mode: " + bias_shift)

Check warning on line 308 in deepmd/pt/model/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_atomic_model.py#L308

Added line #L308 was not covered by tests
log.info(
f"Change energy bias of {origin_type_map!s} "
f"from {to_numpy_array(original_bias[idx_type_map]).reshape(-1)!s} "
f"to {to_numpy_array(bias_atom_e[idx_type_map]).reshape(-1)!s}."
)
self.fitting_net["bias_atom_e"] = bias_atom_e

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.fitting_net.get_dim_fparam()
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,9 @@ def compute_or_load_stat(
self.models[0].compute_or_load_stat(sampled_func, stat_file_path)
self.models[1].compute_or_load_stat(sampled_func, stat_file_path)

def change_energy_bias(self):
def change_out_bias(
self, merged, origin_type_map, full_type_map, bias_shift="delta"
) -> None:
# need to implement
pass

Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def compute_or_load_stat(
torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1])
)

def change_energy_bias(self) -> None:
def change_out_bias(
self, merged, origin_type_map, full_type_map, bias_shift="delta"
) -> None:
# need to implement
pass

Expand Down
78 changes: 0 additions & 78 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
import os
import tempfile
from abc import (
abstractmethod,
)
Expand All @@ -15,9 +13,6 @@
import numpy as np
import torch

from deepmd.infer.deep_eval import (
DeepEval,
)
from deepmd.pt.model.network.mlp import (
FittingNet,
NetworkCollection,
Expand All @@ -33,7 +28,6 @@
)
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
DEVICE,
PRECISION_DICT,
)
from deepmd.pt.utils.exclude_mask import (
Expand All @@ -43,12 +37,6 @@
to_numpy_array,
to_torch_tensor,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.finetune import (
change_energy_bias_lower,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE
Expand Down Expand Up @@ -88,72 +76,6 @@ def share_params(self, base_class, shared_level, resume=False):
else:
raise NotImplementedError

def change_energy_bias(
self,
config,
model,
old_type_map: List[str],
new_type_map: List[str],
bias_shift="delta",
ntest=10,
):
"""Change the energy bias according to the input data and the pretrained model.

Parameters
----------
config : Dict
The configuration.
model : EnergyModel
Energy model loaded pre-trained model.
new_type_map : List[str]
The original type_map in dataset, they are targets to change the energy bias.
old_type_map : List[str]
The full type_map in pretrained model
bias_shift : str
The mode for changing energy bias : ['delta', 'statistic']
'delta' : perform predictions on energies of target dataset,
and do least sqaure on the errors to obtain the target shift as bias.
'statistic' : directly use the statistic energy bias in the target dataset.
ntest : int
The number of test samples in a system to change the energy bias.
"""
log.info(
f"Changing energy bias in pretrained model for types {new_type_map!s}... "
"(this step may take long time)"
)
# data
systems = config["training"]["training_data"]["systems"]
finetune_data = DeepmdDataSystem(
systems=systems,
batch_size=config["training"]["training_data"].get("batch_size", "auto"),
test_size=1,
)
finetune_data.add("energy", ndof=1, atomic=False, must=True, high_prec=True)
model = torch.jit.script(model)
if model.get_dim_fparam() > 0:
finetune_data.add("fparam", model.get_dim_fparam(), atomic=False, must=True)
if model.get_dim_aparam() > 0:
finetune_data.add("aparam", model.get_dim_aparam(), atomic=True, must=True)
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
torch.jit.save(model, tmp_model.name)
dp = DeepEval(tmp_model.name)
os.unlink(tmp_model.name)
bias = change_energy_bias_lower(
finetune_data,
dp,
new_type_map,
old_type_map,
self.bias_atom_e.detach().cpu().numpy().reshape(-1),
bias_shift=bias_shift,
ntest=ntest,
)
self.bias_atom_e = (
torch.from_numpy(bias)
.type_as(self.bias_atom_e)
.reshape(self.bias_atom_e.shape)
.to(DEVICE)
)


class GeneralFitting(Fitting):
"""Construct a general fitting net.
Expand Down
Loading