diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index 844061d0ef..ab35e32012 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -3,9 +3,6 @@
 import json
 import logging
 import os
-from copy import (
-    deepcopy,
-)
 from pathlib import (
     Path,
 )
@@ -75,11 +72,9 @@ def get_trainer(
     model_branch="",
     force_load=False,
     init_frz_model=None,
-    shared_links=None,
 ):
-    multi_task = "model_dict" in config.get("model", {})
     # argcheck
-    if not multi_task:
+    if "model_dict" not in config.get("model", {}):
         config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
     config = normalize(config)
 
@@ -90,6 +85,7 @@ def get_trainer(
         assert dist.is_nccl_available()
         dist.init_process_group(backend="nccl")
 
+    multi_task = "model_dict" in config["model"]
     ckpt = init_model if init_model is not None else restart_model
     config["model"] = change_finetune_model_params(
         ckpt,
@@ -98,6 +94,9 @@ def get_trainer(
         multi_task=multi_task,
         model_branch=model_branch,
     )
+    shared_links = None
+    if multi_task:
+        config["model"], shared_links = preprocess_shared_params(config["model"])
 
     def prepare_trainer_input_single(
         model_params_single, data_dict_single, loss_dict_single, suffix=""
@@ -221,33 +220,11 @@ def train(FLAGS):
     SummaryPrinter()()
     with open(FLAGS.INPUT) as fin:
         config = json.load(fin)
-
-    # update multitask config
-    multi_task = "model_dict" in config["model"]
-    shared_links = None
-    if multi_task:
-        config["model"], shared_links = preprocess_shared_params(config["model"])
-
-    # do neighbor stat
     if not FLAGS.skip_neighbor_stat:
         log.info(
             "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
         )
-        if not multi_task:
-            config["model"] = BaseModel.update_sel(config, config["model"])
-        else:
-            training_jdata = deepcopy(config["training"])
-            training_jdata.pop("data_dict", {})
-            training_jdata.pop("model_prob", {})
-            for model_item in config["model"]["model_dict"]:
-                fake_global_jdata = {
-                    "model": deepcopy(config["model"]["model_dict"][model_item]),
-                    "training": deepcopy(config["training"]["data_dict"][model_item]),
-                }
-                fake_global_jdata["training"].update(training_jdata)
-                config["model"]["model_dict"][model_item] = BaseModel.update_sel(
-                    fake_global_jdata, config["model"]["model_dict"][model_item]
-                )
+        config["model"] = BaseModel.update_sel(config, config["model"])
 
     trainer = get_trainer(
         config,
@@ -257,7 +234,6 @@ def train(FLAGS):
         FLAGS.model_branch,
         FLAGS.force_load,
         FLAGS.init_frz_model,
-        shared_links=shared_links,
     )
     trainer.run()
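
Note: after this refactor, multi-task detection and preprocess_shared_params run inside get_trainer, so callers no longer thread shared_links through. A minimal usage sketch (the input file name is hypothetical; the call shape follows the new get_trainer signature and the train() body above):

    import json

    from deepmd.pt.entrypoints.main import get_trainer

    # Hypothetical multi-task input: get_trainer now checks
    # `"model_dict" in config["model"]` itself and derives shared_links
    # internally instead of receiving it as an argument.
    with open("input_multitask.json") as fin:
        config = json.load(fin)

    trainer = get_trainer(config)  # no shared_links kwarg anymore
    trainer.run()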