Skip to content

Commit

Permalink
pt:support multitask finetune
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Mar 17, 2024
1 parent 4b3a77b commit 67b838d
Show file tree
Hide file tree
Showing 13 changed files with 512 additions and 324 deletions.
15 changes: 8 additions & 7 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ def get_trainer(
dist.init_process_group(backend="nccl")

ckpt = init_model if init_model is not None else restart_model
config["model"] = change_finetune_model_params(
ckpt,
finetune_model,
config["model"],
multi_task=multi_task,
model_branch=model_branch,
)
finetune_links = None
if finetune_model is not None:
config["model"], finetune_links = change_finetune_model_params(
finetune_model,

Check warning on line 96 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L93-L96

Added lines #L93 - L96 were not covered by tests
config["model"],
model_branch=model_branch,
)
config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)

def prepare_trainer_input_single(
Expand Down Expand Up @@ -194,6 +194,7 @@ def prepare_trainer_input_single(
finetune_model=finetune_model,
force_load=force_load,
shared_links=shared_links,
finetune_links=finetune_links,
init_frz_model=init_frz_model,
)
return trainer
Expand Down
4 changes: 0 additions & 4 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,6 @@ def compute_or_load_stat(
self.models[0].compute_or_load_stat(sampled_func, stat_file_path)
self.models[1].compute_or_load_stat(sampled_func, stat_file_path)

def change_energy_bias(self):
    """Placeholder for changing the energy bias; currently does nothing.

    The method body is not implemented yet, so a call leaves the model
    untouched and returns ``None``.
    """
    # need to implement
    return None

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 0 additions & 4 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,6 @@ def compute_or_load_stat(
torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1])
)

def change_energy_bias(self) -> None:
    """Placeholder for changing the energy bias; currently does nothing.

    The method body is not implemented yet, so a call leaves the model
    untouched and returns ``None``.
    """
    # need to implement
    return None

def forward_atomic(
self,
extended_coord: torch.Tensor,
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,8 @@ def forward_lower(
else:
model_predict = model_ret
return model_predict

def change_out_bias(
    self,
    merged,
    origin_type_map,
    full_type_map,
    bias_shift="delta",
) -> None:
    """Changing the output bias is not implemented for this model.

    Parameters
    ----------
    merged, origin_type_map, full_type_map, bias_shift
        Accepted but unused.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError()

Check warning on line 99 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L99

Added line #L99 was not covered by tests
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/dos_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,8 @@ def forward_lower(
else:
model_predict = model_ret
return model_predict

def change_out_bias(
    self,
    merged,
    origin_type_map,
    full_type_map,
    bias_shift="delta",
) -> None:
    """Changing the output bias is not implemented for this model.

    Parameters
    ----------
    merged, origin_type_map, full_type_map, bias_shift
        Accepted but unused.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError()

Check warning on line 85 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L85

Added line #L85 was not covered by tests
71 changes: 71 additions & 0 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import tempfile

Check warning on line 4 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L2-L4

Added lines #L2 - L4 were not covered by tests
from typing import (
Dict,
Optional,
)

import numpy as np

Check warning on line 10 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L10

Added line #L10 was not covered by tests
import torch

from deepmd.infer.deep_eval import (

Check warning on line 13 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L13

Added line #L13 was not covered by tests
DeepEval,
)
from deepmd.pt.utils.stat import (

Check warning on line 16 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L16

Added line #L16 was not covered by tests
compute_output_stats,
)
from deepmd.pt.utils.utils import (

Check warning on line 19 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L19

Added line #L19 was not covered by tests
to_numpy_array,
)

from .dp_model import (
DPModel,
)

log = logging.getLogger(__name__)

Check warning on line 27 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L27

Added line #L27 was not covered by tests


class EnergyModel(DPModel):
model_type = "ener"
Expand Down Expand Up @@ -97,3 +113,58 @@ def forward_lower(
else:
model_predict = model_ret
return model_predict

def change_out_bias(
    self, merged, origin_type_map, full_type_map, bias_shift="delta"
) -> None:
    """Change the energy bias according to the input data and the pretrained model.

    Parameters
    ----------
    merged : Union[Callable[[], List[dict]], List[dict]]
        - List[dict]: A list of data samples from various data systems.
            Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
            originating from the `i`-th data system.
        - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
            only when needed. Since the sampling process can be slow and memory-intensive,
            the lazy function helps by only sampling once.
    origin_type_map : List[str]
        The original type_map in dataset, they are targets to change the energy bias.
    full_type_map : List[str]
        The full type_map in pre-trained model
    bias_shift : str
        The mode for changing energy bias : ['delta', 'statistic']
        'delta' : perform predictions on energies of target dataset,
            and do least squares on the errors to obtain the target shift as bias.
        'statistic' : directly use the statistic energy bias in the target dataset.

    Raises
    ------
    RuntimeError
        If `bias_shift` is neither 'delta' nor 'statistic'.
    """
    # Map every dataset type to its position in the (possibly unsorted) full
    # type map. These indices are only used for the log message below.
    sorter = np.argsort(full_type_map)
    idx_type_map = sorter[
        np.searchsorted(full_type_map, origin_type_map, sorter=sorter)
    ]
    original_bias = self.get_fitting_net()["bias_atom_e"]
    if bias_shift == "delta":
        # Only a unique path is needed: close the handle right away so it is
        # not leaked (and so the file can be rewritten/removed on platforms
        # that forbid touching an open file).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as tmp_model:
            tmp_model_path = tmp_model.name
        try:
            model = torch.jit.script(self)
            torch.jit.save(model, tmp_model_path)
            dp = DeepEval(tmp_model_path)
        finally:
            # Remove the temporary model even if scripting/saving/loading fails.
            os.unlink(tmp_model_path)
        delta_bias_e = compute_output_stats(
            merged,
            self.atomic_model.get_ntypes(),
            model=dp,
        )
        # The statistics computed against the current model are treated as a
        # shift on top of the existing bias.
        bias_atom_e = delta_bias_e + original_bias
    elif bias_shift == "statistic":
        bias_atom_e = compute_output_stats(
            merged,
            self.atomic_model.get_ntypes(),
        )
    else:
        raise RuntimeError("Unknown bias_shift mode: " + bias_shift)
    log.info(
        f"Change energy bias of {origin_type_map!s} "
        f"from {to_numpy_array(original_bias[idx_type_map]).reshape(-1)!s} "
        f"to {to_numpy_array(bias_atom_e[idx_type_map]).reshape(-1)!s}."
    )
    self.get_fitting_net()["bias_atom_e"] = bias_atom_e

Check warning on line 170 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L170

Added line #L170 was not covered by tests
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/polar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,8 @@ def forward_lower(
else:
model_predict = model_ret
return model_predict

def change_out_bias(
    self,
    merged,
    origin_type_map,
    full_type_map,
    bias_shift="delta",
) -> None:
    """Changing the output bias is not implemented for this model.

    Parameters
    ----------
    merged, origin_type_map, full_type_map, bias_shift
        Accepted but unused.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError()

Check warning on line 83 in deepmd/pt/model/model/polar_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/polar_model.py#L83

Added line #L83 was not covered by tests
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,8 @@ def forward_lower(
].squeeze(-2)
# not support virial by far
return model_predict

def change_out_bias(
    self,
    merged,
    origin_type_map,
    full_type_map,
    bias_shift="delta",
) -> None:
    """Changing the output bias is not implemented for this model.

    Parameters
    ----------
    merged, origin_type_map, full_type_map, bias_shift
        Accepted but unused.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError()

Check warning on line 565 in deepmd/pt/model/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/spin_model.py#L565

Added line #L565 was not covered by tests
78 changes: 0 additions & 78 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
import os
import tempfile
from abc import (
abstractmethod,
)
Expand All @@ -15,9 +13,6 @@
import numpy as np
import torch

from deepmd.infer.deep_eval import (
DeepEval,
)
from deepmd.pt.model.network.mlp import (
FittingNet,
NetworkCollection,
Expand All @@ -33,7 +28,6 @@
)
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
DEVICE,
PRECISION_DICT,
)
from deepmd.pt.utils.exclude_mask import (
Expand All @@ -43,12 +37,6 @@
to_numpy_array,
to_torch_tensor,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.finetune import (
change_energy_bias_lower,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE
Expand Down Expand Up @@ -88,72 +76,6 @@ def share_params(self, base_class, shared_level, resume=False):
else:
raise NotImplementedError

def change_energy_bias(
    self,
    config,
    model,
    old_type_map: List[str],
    new_type_map: List[str],
    bias_shift="delta",
    ntest=10,
):
    """Change the energy bias according to the input data and the pretrained model.

    Parameters
    ----------
    config : Dict
        The configuration.
    model : EnergyModel
        Energy model loaded pre-trained model.
    old_type_map : List[str]
        The full type_map in pretrained model
    new_type_map : List[str]
        The original type_map in dataset, they are targets to change the energy bias.
    bias_shift : str
        The mode for changing energy bias : ['delta', 'statistic']
        'delta' : perform predictions on energies of target dataset,
            and do least squares on the errors to obtain the target shift as bias.
        'statistic' : directly use the statistic energy bias in the target dataset.
    ntest : int
        The number of test samples in a system to change the energy bias.
    """
    log.info(
        f"Changing energy bias in pretrained model for types {new_type_map!s}... "
        "(this step may take long time)"
    )
    # data
    systems = config["training"]["training_data"]["systems"]
    finetune_data = DeepmdDataSystem(
        systems=systems,
        batch_size=config["training"]["training_data"].get("batch_size", "auto"),
        test_size=1,
    )
    finetune_data.add("energy", ndof=1, atomic=False, must=True, high_prec=True)
    model = torch.jit.script(model)
    # Frame/atomic parameters are only requested when the model declares them.
    if model.get_dim_fparam() > 0:
        finetune_data.add("fparam", model.get_dim_fparam(), atomic=False, must=True)
    if model.get_dim_aparam() > 0:
        finetune_data.add("aparam", model.get_dim_aparam(), atomic=True, must=True)
    # Only a unique path is needed: close the handle right away so it is not
    # leaked (and so the file can be rewritten/removed on platforms that
    # forbid touching an open file).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as tmp_model:
        tmp_model_path = tmp_model.name
    try:
        torch.jit.save(model, tmp_model_path)
        dp = DeepEval(tmp_model_path)
    finally:
        # Remove the temporary model even if saving or loading fails.
        os.unlink(tmp_model_path)
    bias = change_energy_bias_lower(
        finetune_data,
        dp,
        new_type_map,
        old_type_map,
        self.bias_atom_e.detach().cpu().numpy().reshape(-1),
        bias_shift=bias_shift,
        ntest=ntest,
    )
    # Store the new bias with the dtype and shape of the existing buffer.
    self.bias_atom_e = (
        torch.from_numpy(bias)
        .type_as(self.bias_atom_e)
        .reshape(self.bias_atom_e.shape)
        .to(DEVICE)
    )


class GeneralFitting(Fitting):
"""Construct a general fitting net.
Expand Down
Loading

0 comments on commit 67b838d

Please sign in to comment.