Skip to content

Commit

Permalink
Fix Code scanning errors
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 8, 2024
1 parent 993ee55 commit 0b9a002
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 15 deletions.
20 changes: 16 additions & 4 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def prepare_trainer_input_single(
# stat files
hybrid_descrpt = model_params_single["descriptor"]["type"] == "hybrid"
if not hybrid_descrpt:
has_stat_file_path = process_stat_path(
stat_file_path_single, has_stat_file_path = process_stat_path(
data_dict_single.get("stat_file", None),
data_dict_single.get("stat_file_dir", f"stat_files{suffix}"),
model_params_single,
Expand Down Expand Up @@ -180,19 +180,30 @@ def prepare_trainer_input_single(
type_split=type_split,
noise_settings=noise_settings,
)
return train_data_single, validation_data_single, sampled_single
return (
train_data_single,
validation_data_single,
sampled_single,
stat_file_path_single,
)

if not multi_task:
train_data, validation_data, sampled = prepare_trainer_input_single(
(
train_data,
validation_data,
sampled,
stat_file_path,
) = prepare_trainer_input_single(
config["model"], config["training"], config["loss"]
)
else:
train_data, validation_data, sampled = {}, {}, {}
train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {}
for model_key in config["model"]["model_dict"]:
(
train_data[model_key],
validation_data[model_key],
sampled[model_key],
stat_file_path[model_key],
) = prepare_trainer_input_single(
config["model"]["model_dict"][model_key],
config["training"]["data_dict"][model_key],
Expand All @@ -204,6 +215,7 @@ def prepare_trainer_input_single(
config,
train_data,
sampled=sampled,
stat_file_path=stat_file_path,
validation_data=validation_data,
init_model=init_model,
restart_model=restart_model,
Expand Down
11 changes: 7 additions & 4 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
config: Dict[str, Any],
training_data,
sampled=None,
stat_file_path=None,
validation_data=None,
init_model=None,
restart_model=None,
Expand Down Expand Up @@ -180,13 +181,13 @@ def get_data_loader(_training_data, _validation_data, _training_params):
valid_numb_batch,
)

def get_single_model(_model_params, _sampled):
def get_single_model(_model_params, _sampled, _stat_file_path):
model = get_model(deepcopy(_model_params)).to(DEVICE)
if not model_params.get("resuming", False):
model.compute_or_load_stat(
type_map=_model_params["type_map"],
sampled=_sampled,
stat_file_path_dict=model_params.get("stat_file_path", None),
stat_file_path_dict=_stat_file_path,
)
return model

Expand Down Expand Up @@ -237,7 +238,7 @@ def get_loss(loss_params, start_lr, _ntypes):
self.validation_data,
self.valid_numb_batch,
) = get_data_loader(training_data, validation_data, training_params)
self.model = get_single_model(model_params, sampled)
self.model = get_single_model(model_params, sampled, stat_file_path)
else:
(
self.training_dataloader,
Expand All @@ -260,7 +261,9 @@ def get_loss(loss_params, start_lr, _ntypes):
training_params["data_dict"][model_key],
)
self.model[model_key] = get_single_model(
model_params["model_dict"][model_key], sampled[model_key]
model_params["model_dict"][model_key],
sampled[model_key],
stat_file_path[model_key],
)

# Learning rate
Expand Down
11 changes: 4 additions & 7 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def compute_output_bias(energy, natoms, rcond=None):
def process_stat_path(
stat_file_dict, stat_file_dir, model_params_dict, descriptor_cls, fitting_cls
):
model_params_dict["stat_file_dir"] = stat_file_dir
if stat_file_dict is None:
stat_file_dict = {}
if "descriptor" in model_params_dict:
Expand All @@ -127,13 +126,11 @@ def process_stat_path(
**model_params_dict["fitting_net"],
)
stat_file_dict["fitting_net"] = default_stat_file_name_fitting
model_params_dict["stat_file_path"] = {
key: os.path.join(model_params_dict["stat_file_dir"], stat_file_dict[key])
for key in stat_file_dict
stat_file_path = {
key: os.path.join(stat_file_dir, stat_file_dict[key]) for key in stat_file_dict
}

has_stat_file_path_list = [
os.path.exists(model_params_dict["stat_file_path"][key])
for key in stat_file_dict
os.path.exists(stat_file_path[key]) for key in stat_file_dict
]
return False not in has_stat_file_path_list
return stat_file_path, False not in has_stat_file_path_list

0 comments on commit 0b9a002

Please sign in to comment.