Skip to content

Commit

Permalink
Revert "Fix neighbor-stat for multitask (#31)"
Browse files — browse the repository at this point in the history
This reverts commit cdcfcb2.
  • Branch information
iProzd committed Feb 29, 2024
1 parent cdcfcb2 commit a7d44d1
Showing 1 changed file with 6 additions and 30 deletions.
36 changes: 6 additions & 30 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
import json
import logging
import os
from copy import (
deepcopy,
)
from pathlib import (
Path,
)
Expand Down Expand Up @@ -75,11 +72,9 @@ def get_trainer(
model_branch="",
force_load=False,
init_frz_model=None,
shared_links=None,
):
multi_task = "model_dict" in config.get("model", {})
# argcheck
if not multi_task:
if "model_dict" not in config.get("model", {}):
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)

Expand All @@ -90,6 +85,7 @@ def get_trainer(
assert dist.is_nccl_available()
dist.init_process_group(backend="nccl")

multi_task = "model_dict" in config["model"]
ckpt = init_model if init_model is not None else restart_model
config["model"] = change_finetune_model_params(
ckpt,
Expand All @@ -98,6 +94,9 @@ def get_trainer(
multi_task=multi_task,
model_branch=model_branch,
)
shared_links = None
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])

def prepare_trainer_input_single(
model_params_single, data_dict_single, loss_dict_single, suffix=""
Expand Down Expand Up @@ -221,33 +220,11 @@ def train(FLAGS):
SummaryPrinter()()
with open(FLAGS.INPUT) as fin:
config = json.load(fin)

# update multitask config
multi_task = "model_dict" in config["model"]
shared_links = None
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])

# do neighbor stat
if not FLAGS.skip_neighbor_stat:
log.info(
"Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
)
if not multi_task:
config["model"] = BaseModel.update_sel(config, config["model"])
else:
training_jdata = deepcopy(config["training"])
training_jdata.pop("data_dict", {})
training_jdata.pop("model_prob", {})
for model_item in config["model"]["model_dict"]:
fake_global_jdata = {
"model": deepcopy(config["model"]["model_dict"][model_item]),
"training": deepcopy(config["training"]["data_dict"][model_item]),
}
fake_global_jdata["training"].update(training_jdata)
config["model"]["model_dict"][model_item] = BaseModel.update_sel(
fake_global_jdata, config["model"]["model_dict"][model_item]
)
config["model"] = BaseModel.update_sel(config, config["model"])

trainer = get_trainer(
config,
Expand All @@ -257,7 +234,6 @@ def train(FLAGS):
FLAGS.model_branch,
FLAGS.force_load,
FLAGS.init_frz_model,
shared_links=shared_links,
)
trainer.run()

Expand Down

0 comments on commit a7d44d1

Please sign in to comment.