pt: support --init-frz-model (#3350)
Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Feb 28, 2024
1 parent d377ccb commit 897fcc5
Showing 4 changed files with 123 additions and 5 deletions.
deepmd/main.py (2 changes: 1 addition & 1 deletion)
@@ -226,7 +226,7 @@ def main_parser() -> argparse.ArgumentParser:
"--init-frz-model",
type=str,
default=None,
help="(Supported backend: TensorFlow) Initialize the training from the frozen model.",
help="Initialize the training from the frozen model.",
)
parser_train_subgroup.add_argument(
"-t",
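With the TensorFlow-only caveat dropped from the help string, the same flag now applies to every backend. A self-contained sketch of the flag as argparse sees it (the parser here is a minimal stand-in for deepmd.main's train subcommand, not the real parser):

    import argparse

    # Stand-in parser; the real one defines many more train options.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--init-frz-model",
        type=str,
        default=None,
        help="Initialize the training from the frozen model.",
    )
    args = parser.parse_args(["--init-frz-model", "frozen_model.pth"])
    assert args.init_frz_model == "frozen_model.pth"  # kept as a plain path string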
deepmd/pt/entrypoints/main.py (3 changes: 3 additions & 0 deletions)
@@ -65,6 +65,7 @@ def get_trainer(
finetune_model=None,
model_branch="",
force_load=False,
+ init_frz_model=None,
):
# Initialize DDP
local_rank = os.environ.get("LOCAL_RANK")
@@ -200,6 +201,7 @@ def prepare_trainer_input_single(
finetune_model=finetune_model,
force_load=force_load,
shared_links=shared_links,
+ init_frz_model=init_frz_model,
)
return trainer

@@ -243,6 +245,7 @@ def train(FLAGS):
FLAGS.finetune,
FLAGS.model_branch,
FLAGS.force_load,
+ FLAGS.init_frz_model,
)
trainer.run()

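The new keyword simply travels FLAGS -> get_trainer -> Trainer. A condensed sketch of that plumbing, under stated assumptions (the real functions take many more arguments and build the trainer from the JSON config; only the path of init_frz_model is modeled):

    from argparse import Namespace

    class Trainer:
        # Stand-in for deepmd.pt.train.training.Trainer; only the new kwarg is kept.
        def __init__(self, config, init_frz_model=None):
            self.init_frz_model = init_frz_model

        def run(self):
            print(f"training with init_frz_model={self.init_frz_model}")

    def get_trainer(config, init_frz_model=None):
        # In the real code this is forwarded through prepare_trainer_input_single.
        return Trainer(config, init_frz_model=init_frz_model)

    def train(FLAGS):
        get_trainer({}, FLAGS.init_frz_model).run()

    train(Namespace(init_frz_model="frozen_model.pth"))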
deepmd/pt/train/training.py (22 changes: 18 additions & 4 deletions)
@@ -75,6 +75,7 @@ def __init__(
finetune_model=None,
force_load=False,
shared_links=None,
+ init_frz_model=None,
):
"""Construct a DeePMD trainer.
@@ -271,7 +272,7 @@ def get_loss(loss_params, start_lr, _ntypes):
self.warmup_steps = training_params.get("warmup_steps", 0)
self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
assert (
- self.num_steps - self.warmup_steps > 0
+ self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0
), "Warm up steps must be less than total training steps!"
if self.multi_task and config.get("learning_rate_dict", None) is not None:
self.lr_exp = {}
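The relaxed assertion admits warmup_steps == 0 even when num_steps == 0, which the zero-step retraining run in the new test below depends on. A quick check of the predicate:

    def warmup_ok(num_steps, warmup_steps):
        # Mirrors the condition asserted in the hunk above.
        return num_steps - warmup_steps > 0 or warmup_steps == 0

    assert warmup_ok(100, 10)      # ordinary run
    assert warmup_ok(0, 0)         # zero-step run used by the new test
    assert not warmup_ok(10, 10)   # warmup consuming every step is still rejected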
@@ -394,6 +395,9 @@ def get_loss(loss_params, start_lr, _ntypes):
ntest=ntest,
bias_shift=model_params.get("bias_shift", "delta"),
)
+ if init_frz_model is not None:
+     frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
+     self.model.load_state_dict(frz_model.state_dict())

# Set trainable params
self.wrapper.set_trainable_params()
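The restore itself is two calls: deserialize the TorchScript model, then copy its parameters into the freshly built network by name. A self-contained sketch with a toy module standing in for the real model (assumption: the parameter names of the frozen and fresh models must match for load_state_dict to succeed):

    import torch

    net = torch.nn.Linear(3, 1)
    torch.jit.script(net).save("frozen.pth")   # stand-in for a frozen .pth model

    fresh = torch.nn.Linear(3, 1)              # newly constructed, random weights
    frz = torch.jit.load("frozen.pth", map_location="cpu")
    fresh.load_state_dict(frz.state_dict())    # parameters copied key by key

    assert torch.equal(fresh.weight, net.weight)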
@@ -724,6 +728,15 @@ def log_loss_valid(_task_key="Default"):
if (
self.rank == 0 or dist.get_rank() == 0
): # Handle the case if rank 0 aborted and re-assigned
+ if self.num_steps == 0:
+     # when num_steps is 0, the checkpoint is never saved by the training loop
+     self.latest_model = Path(self.save_ckpt + "-0.pt")
+     self.save_model(self.latest_model, lr=0, step=0)
+     log.info(f"Saved model to {self.latest_model}")
+     symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
+     with open("checkpoint", "w") as f:
+         f.write(str(self.latest_model))

if JIT:
pth_model_path = (
"frozen_model.pth" # We use .pth to denote the frozen model
@@ -759,9 +772,10 @@ def get_data(self, is_train=True, task_key="Default"):
batch_data = next(iter(self.training_data))
except StopIteration:
# Refresh the status of the dataloader to start from a new epoch
- self.training_data = BufferedIterator(
-     iter(self.training_dataloader)
- )
+ with torch.device("cpu"):
+     self.training_data = BufferedIterator(
+         iter(self.training_dataloader)
+     )
batch_data = next(iter(self.training_data))
else:
try:
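The only behavioral change in get_data is the torch.device("cpu") context around the dataloader refresh. Since PyTorch 2.0, a device object doubles as a context manager that sets the default device for factory calls inside the block, so tensors created while rebuilding the iterator stay on the host even if the global default device is a GPU:

    import torch

    with torch.device("cpu"):
        x = torch.zeros(2, 2)    # factory calls inside the block default to CPU
    assert x.device.type == "cpu"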
source/tests/pt/test_init_frz_model.py (101 changes: 101 additions & 0 deletions, new file)
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import unittest
from argparse import (
Namespace,
)
from copy import (
deepcopy,
)
from pathlib import (
Path,
)

import numpy as np

from deepmd.pt.entrypoints.main import (
freeze,
get_trainer,
)
from deepmd.pt.infer.deep_eval import (
DeepPot,
)


class TestInitFrzModel(unittest.TestCase):
def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
config = json.load(f)
config["training"]["numb_steps"] = 1
config["training"]["save_freq"] = 1
config["learning_rate"]["start_lr"] = 1.0
config["training"]["training_data"]["systems"] = [
str(Path(__file__).parent / "water/data/single")
]
config["training"]["validation_data"]["systems"] = [
str(Path(__file__).parent / "water/data/single")
]

self.models = []
for imodel in range(2):
if imodel == 1:
config["training"]["numb_steps"] = 0
trainer = get_trainer(deepcopy(config), init_frz_model=self.models[-1])
else:
trainer = get_trainer(deepcopy(config))
trainer.run()

frozen_model = f"frozen_model{imodel}.pth"
ns = Namespace(
model="model.pt",
output=frozen_model,
head=None,
)
freeze(ns)
self.models.append(frozen_model)

def test_dp_test(self):
dp1 = DeepPot(str(self.models[0]))
dp2 = DeepPot(str(self.models[1]))
cell = np.array(
[
5.122106549439247480e00,
4.016537340154059388e-01,
6.951654033828678081e-01,
4.016537340154059388e-01,
6.112136112297989143e00,
8.178091365465004481e-01,
6.951654033828678081e-01,
8.178091365465004481e-01,
6.159552512682983760e00,
]
).reshape(1, 3, 3)
coord = np.array(
[
2.978060152121375648e00,
3.588469695887098077e00,
2.792459820604495491e00,
3.895592322591093115e00,
2.712091020667753760e00,
1.366836847133650501e00,
9.955616170888935690e-01,
4.121324820711413039e00,
1.817239061889086571e00,
3.553661462345699906e00,
5.313046969500791583e00,
6.635182659098815883e00,
6.088601018589653080e00,
6.575011420004332585e00,
6.825240650611076099e00,
]
).reshape(1, -1, 3)
atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1)

e1, f1, v1, ae1, av1 = dp1.eval(coord, cell, atype, atomic=True)
e2, f2, v2, ae2, av2 = dp2.eval(coord, cell, atype, atomic=True)
np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(ae1, ae2, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(av1, av2, rtol=1e-10, atol=1e-10)

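The test trains one model for a single step, freezes it, then retrains with numb_steps = 0 starting from the frozen model and freezes again; the two frozen models must agree to within 1e-10 on energies, forces, virials, and their atomic decompositions. To run just this module with the stdlib runner (assuming the repository layout shown in the diff and an installed deepmd PyTorch backend):

    import unittest

    suite = unittest.defaultTestLoader.discover(
        "source/tests/pt", pattern="test_init_frz_model.py"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)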