Skip to content

Commit

Permalink
Fix neighbor-stat for multitask
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 29, 2024
1 parent cce52da commit abc4bb1
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import json
import logging
import os
from copy import (
deepcopy,
)
from pathlib import (
Path,
)
Expand Down Expand Up @@ -72,9 +75,11 @@ def get_trainer(
model_branch="",
force_load=False,
init_frz_model=None,
shared_links=None,
):
multi_task = "model_dict" in config.get("model", {})
# argcheck
if "model_dict" not in config.get("model", {}):
if not multi_task:
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)

Expand All @@ -85,7 +90,6 @@ def get_trainer(
assert dist.is_nccl_available()
dist.init_process_group(backend="nccl")

multi_task = "model_dict" in config["model"]
ckpt = init_model if init_model is not None else restart_model
config["model"] = change_finetune_model_params(
ckpt,
Expand All @@ -94,9 +98,6 @@ def get_trainer(
multi_task=multi_task,
model_branch=model_branch,
)
shared_links = None
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])

def prepare_trainer_input_single(
model_params_single, data_dict_single, loss_dict_single, suffix=""
Expand Down Expand Up @@ -220,11 +221,33 @@ def train(FLAGS):
SummaryPrinter()()
with open(FLAGS.INPUT) as fin:
config = json.load(fin)

# update multitask config
multi_task = "model_dict" in config["model"]
shared_links = None
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])

# do neighbor stat
if not FLAGS.skip_neighbor_stat:
log.info(
"Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
)
config["model"] = BaseModel.update_sel(config, config["model"])
if not multi_task:
config["model"] = BaseModel.update_sel(config, config["model"])
else:
training_jdata = deepcopy(config["training"])
training_jdata.pop("data_dict", {})
training_jdata.pop("model_prob", {})
for model_item in config["model"]["model_dict"]:
fake_global_jdata = {
"model": deepcopy(config["model"]["model_dict"][model_item]),
"training": deepcopy(config["training"]["data_dict"][model_item]),
}
fake_global_jdata["training"].update(training_jdata)
config["model"]["model_dict"][model_item] = BaseModel.update_sel(
fake_global_jdata, config["model"]["model_dict"][model_item]
)

trainer = get_trainer(
config,
Expand All @@ -234,6 +257,7 @@ def train(FLAGS):
FLAGS.model_branch,
FLAGS.force_load,
FLAGS.init_frz_model,
shared_links=shared_links,
)
trainer.run()

Expand Down

0 comments on commit abc4bb1

Please sign in to comment.