From 0b9a002e0285cb8a6f2a0caff71d173d33b12ac1 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 8 Feb 2024 23:38:31 +0800 Subject: [PATCH] Fix Code scanning errors --- deepmd/pt/entrypoints/main.py | 20 ++++++++++++++++---- deepmd/pt/train/training.py | 11 +++++++---- deepmd/pt/utils/stat.py | 11 ++++------- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index f3530584dc..a919fb0db0 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -130,7 +130,7 @@ def prepare_trainer_input_single( # stat files hybrid_descrpt = model_params_single["descriptor"]["type"] == "hybrid" if not hybrid_descrpt: - has_stat_file_path = process_stat_path( + stat_file_path_single, has_stat_file_path = process_stat_path( data_dict_single.get("stat_file", None), data_dict_single.get("stat_file_dir", f"stat_files{suffix}"), model_params_single, @@ -180,19 +180,30 @@ def prepare_trainer_input_single( type_split=type_split, noise_settings=noise_settings, ) - return train_data_single, validation_data_single, sampled_single + return ( + train_data_single, + validation_data_single, + sampled_single, + stat_file_path_single, + ) if not multi_task: - train_data, validation_data, sampled = prepare_trainer_input_single( + ( + train_data, + validation_data, + sampled, + stat_file_path, + ) = prepare_trainer_input_single( config["model"], config["training"], config["loss"] ) else: - train_data, validation_data, sampled = {}, {}, {} + train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {} for model_key in config["model"]["model_dict"]: ( train_data[model_key], validation_data[model_key], sampled[model_key], + stat_file_path[model_key], ) = prepare_trainer_input_single( config["model"]["model_dict"][model_key], config["training"]["data_dict"][model_key], @@ -204,6 +215,7 @@ def prepare_trainer_input_single( config, train_data, sampled=sampled, + stat_file_path=stat_file_path, validation_data=validation_data, init_model=init_model, restart_model=restart_model, diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 494c29b248..b2cac5a5eb 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -68,6 +68,7 @@ def __init__( config: Dict[str, Any], training_data, sampled=None, + stat_file_path=None, validation_data=None, init_model=None, restart_model=None, @@ -180,13 +181,13 @@ def get_data_loader(_training_data, _validation_data, _training_params): valid_numb_batch, ) - def get_single_model(_model_params, _sampled): + def get_single_model(_model_params, _sampled, _stat_file_path): model = get_model(deepcopy(_model_params)).to(DEVICE) if not model_params.get("resuming", False): model.compute_or_load_stat( type_map=_model_params["type_map"], sampled=_sampled, - stat_file_path_dict=model_params.get("stat_file_path", None), + stat_file_path_dict=_stat_file_path, ) return model @@ -237,7 +238,7 @@ def get_loss(loss_params, start_lr, _ntypes): self.validation_data, self.valid_numb_batch, ) = get_data_loader(training_data, validation_data, training_params) - self.model = get_single_model(model_params, sampled) + self.model = get_single_model(model_params, sampled, stat_file_path) else: ( self.training_dataloader, @@ -260,7 +261,9 @@ def get_loss(loss_params, start_lr, _ntypes): training_params["data_dict"][model_key], ) self.model[model_key] = get_single_model( - model_params["model_dict"][model_key], sampled[model_key] + model_params["model_dict"][model_key], + sampled[model_key], + stat_file_path[model_key], ) # Learning rate diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index e7c023d427..932ba9a409 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -110,7 +110,6 @@ def compute_output_bias(energy, natoms, rcond=None): def process_stat_path( stat_file_dict, stat_file_dir, model_params_dict, descriptor_cls, fitting_cls ): - model_params_dict["stat_file_dir"] = stat_file_dir if stat_file_dict is None: stat_file_dict = {} if "descriptor" in model_params_dict: @@ -127,13 +126,11 @@ def process_stat_path( **model_params_dict["fitting_net"], ) stat_file_dict["fitting_net"] = default_stat_file_name_fitting - model_params_dict["stat_file_path"] = { - key: os.path.join(model_params_dict["stat_file_dir"], stat_file_dict[key]) - for key in stat_file_dict + stat_file_path = { + key: os.path.join(stat_file_dir, stat_file_dict[key]) for key in stat_file_dict } has_stat_file_path_list = [ - os.path.exists(model_params_dict["stat_file_path"][key]) - for key in stat_file_dict + os.path.exists(stat_file_path[key]) for key in stat_file_dict ] - return False not in has_stat_file_path_list + return stat_file_path, False not in has_stat_file_path_list