diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index ab35e32012..844061d0ef 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -3,6 +3,9 @@ import json import logging import os +from copy import ( + deepcopy, +) from pathlib import ( Path, ) @@ -72,9 +75,11 @@ def get_trainer( model_branch="", force_load=False, init_frz_model=None, + shared_links=None, ): + multi_task = "model_dict" in config.get("model", {}) # argcheck - if "model_dict" not in config.get("model", {}): + if not multi_task: config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") config = normalize(config) @@ -85,7 +90,6 @@ def get_trainer( assert dist.is_nccl_available() dist.init_process_group(backend="nccl") - multi_task = "model_dict" in config["model"] ckpt = init_model if init_model is not None else restart_model config["model"] = change_finetune_model_params( ckpt, @@ -94,9 +98,6 @@ def get_trainer( multi_task=multi_task, model_branch=model_branch, ) - shared_links = None - if multi_task: - config["model"], shared_links = preprocess_shared_params(config["model"]) def prepare_trainer_input_single( model_params_single, data_dict_single, loss_dict_single, suffix="" @@ -220,11 +221,33 @@ def train(FLAGS): SummaryPrinter()() with open(FLAGS.INPUT) as fin: config = json.load(fin) + + # update multitask config + multi_task = "model_dict" in config["model"] + shared_links = None + if multi_task: + config["model"], shared_links = preprocess_shared_params(config["model"]) + + # do neighbor stat if not FLAGS.skip_neighbor_stat: log.info( "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)" ) - config["model"] = BaseModel.update_sel(config, config["model"]) + if not multi_task: + config["model"] = BaseModel.update_sel(config, config["model"]) + else: + training_jdata = deepcopy(config["training"]) + training_jdata.pop("data_dict", {}) + training_jdata.pop("model_prob", {}) + for model_item in config["model"]["model_dict"]: + fake_global_jdata = { + "model": deepcopy(config["model"]["model_dict"][model_item]), + "training": deepcopy(config["training"]["data_dict"][model_item]), + } + fake_global_jdata["training"].update(training_jdata) + config["model"]["model_dict"][model_item] = BaseModel.update_sel( + fake_global_jdata, config["model"]["model_dict"][model_item] + ) trainer = get_trainer( config, @@ -234,6 +257,7 @@ def train(FLAGS): FLAGS.model_branch, FLAGS.force_load, FLAGS.init_frz_model, + shared_links=shared_links, ) trainer.run()