Fix neighbor-stat for multitask #31

Merged: 1 commit, Feb 29, 2024
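The change keys everything on whether the model section of the input contains a "model_dict" entry. For orientation, a multitask input has roughly the shape below; this is a sketch, and every key except model_dict, data_dict, and model_prob (all visible in the diff) is illustrative rather than taken from this PR.

config = {
    "model": {
        "model_dict": {
            # one entry per task branch; branch names are illustrative
            "water": {"descriptor": {}, "fitting_net": {}},
            "metal": {"descriptor": {}, "fitting_net": {}},
        },
    },
    "training": {
        "data_dict": {
            # per-branch training data, keyed by the same branch names
            "water": {"training_data": {"systems": ["data/water"]}},
            "metal": {"training_data": {"systems": ["data/metal"]}},
        },
        "model_prob": {"water": 0.5, "metal": 0.5},  # branch sampling weights
        "numb_steps": 100000,  # global options shared by all branches
    },
}

multi_task = "model_dict" in config.get("model", {})  # True for this layout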
deepmd/pt/entrypoints/main.py: 30 additions & 6 deletions
@@ -3,6 +3,9 @@
 import json
 import logging
 import os
+from copy import (
+    deepcopy,
+)
 from pathlib import (
     Path,
 )
@@ -72,9 +75,11 @@ def get_trainer(
     model_branch="",
     force_load=False,
     init_frz_model=None,
+    shared_links=None,
 ):
+    multi_task = "model_dict" in config.get("model", {})
     # argcheck
-    if "model_dict" not in config.get("model", {}):
+    if not multi_task:
         config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
         config = normalize(config)
 
@@ -85,7 +90,6 @@ def get_trainer(
         assert dist.is_nccl_available()
         dist.init_process_group(backend="nccl")
 
-    multi_task = "model_dict" in config["model"]
     ckpt = init_model if init_model is not None else restart_model
     config["model"] = change_finetune_model_params(
         ckpt,
@@ -94,9 +98,6 @@ def get_trainer(
         multi_task=multi_task,
         model_branch=model_branch,
     )
-    shared_links = None
-    if multi_task:
-        config["model"], shared_links = preprocess_shared_params(config["model"])
 
     def prepare_trainer_input_single(
         model_params_single, data_dict_single, loss_dict_single, suffix=""
@@ -220,11 +221,33 @@ def train(FLAGS):
     SummaryPrinter()()
     with open(FLAGS.INPUT) as fin:
         config = json.load(fin)
+
+    # update multitask config
+    multi_task = "model_dict" in config["model"]
+    shared_links = None
+    if multi_task:
+        config["model"], shared_links = preprocess_shared_params(config["model"])
+
+    # do neighbor stat
     if not FLAGS.skip_neighbor_stat:
         log.info(
             "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
         )
-        config["model"] = BaseModel.update_sel(config, config["model"])
+        if not multi_task:
+            config["model"] = BaseModel.update_sel(config, config["model"])
+        else:
+            training_jdata = deepcopy(config["training"])
+            training_jdata.pop("data_dict", {})
+            training_jdata.pop("model_prob", {})
+            for model_item in config["model"]["model_dict"]:
+                fake_global_jdata = {
+                    "model": deepcopy(config["model"]["model_dict"][model_item]),
+                    "training": deepcopy(config["training"]["data_dict"][model_item]),
+                }
+                fake_global_jdata["training"].update(training_jdata)
+                config["model"]["model_dict"][model_item] = BaseModel.update_sel(
+                    fake_global_jdata, config["model"]["model_dict"][model_item]
+                )
 
     trainer = get_trainer(
         config,
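The else branch above is the actual fix: instead of handing the whole multitask config to BaseModel.update_sel, which expects a single-task layout, it builds a per-branch pseudo single-task config. Below is a standalone sketch of that merge, with the loop body factored into a hypothetical helper (make_fake_global_jdata is not part of the PR) and update_sel itself left out, since its internals are outside this diff.

from copy import deepcopy

def make_fake_global_jdata(config, model_item):
    # Global training options, minus the multitask-only sections that a
    # single-task update_sel would not understand.
    training_jdata = deepcopy(config["training"])
    training_jdata.pop("data_dict", {})
    training_jdata.pop("model_prob", {})
    # Branch-local model params and branch-local data, with the global
    # options layered over the branch data, mirroring the loop in train().
    fake_global_jdata = {
        "model": deepcopy(config["model"]["model_dict"][model_item]),
        "training": deepcopy(config["training"]["data_dict"][model_item]),
    }
    fake_global_jdata["training"].update(training_jdata)
    return fake_global_jdata

Using the sketch config from the top of this page:

print(make_fake_global_jdata(config, "water"))
# {'model': {'descriptor': {}, 'fitting_net': {}},
#  'training': {'training_data': {'systems': ['data/water']},
#               'numb_steps': 100000}}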
@@ -234,6 +257,7 @@ def train(FLAGS):
         FLAGS.model_branch,
         FLAGS.force_load,
         FLAGS.init_frz_model,
+        shared_links=shared_links,
     )
     trainer.run()
 