From 1debf4f7a00c3dbe9e60237be470ff4a04a5db91 Mon Sep 17 00:00:00 2001
From: Duo <50307526+iProzd@users.noreply.github.com>
Date: Fri, 1 Mar 2024 01:23:12 +0800
Subject: [PATCH] Fix h5py files in multitask DDP

---
 deepmd/pt/entrypoints/main.py | 13 ++++++++++---
 deepmd/pt/train/training.py   |  4 ++--
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index 12a3a01187..023bc5305e 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -101,7 +101,7 @@ def get_trainer(
     config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)
 
     def prepare_trainer_input_single(
-        model_params_single, data_dict_single, loss_dict_single, suffix=""
+        model_params_single, data_dict_single, loss_dict_single, suffix="", rank=0
     ):
         training_dataset_params = data_dict_single["training_data"]
         type_split = False
@@ -115,7 +115,9 @@ def prepare_trainer_input_single(
 
         # stat files
         stat_file_path_single = data_dict_single.get("stat_file", None)
-        if stat_file_path_single is not None:
+        if rank != 0:
+            stat_file_path_single = None
+        elif stat_file_path_single is not None:
             if Path(stat_file_path_single).is_dir():
                 raise ValueError(
                     f"stat_file should be a file, not a directory: {stat_file_path_single}"
@@ -153,13 +155,17 @@ def prepare_trainer_input_single(
             stat_file_path_single,
         )
 
+    rank = dist.get_rank() if dist.is_initialized() else 0
     if not multi_task:
         (
             train_data,
             validation_data,
             stat_file_path,
         ) = prepare_trainer_input_single(
-            config["model"], config["training"], config["loss"]
+            config["model"],
+            config["training"],
+            config["loss"],
+            rank=rank,
         )
     else:
         train_data, validation_data, stat_file_path = {}, {}, {}
@@ -173,6 +179,7 @@ def prepare_trainer_input_single(
                 config["training"]["data_dict"][model_key],
                 config["loss_dict"][model_key],
                 suffix=f"_{model_key}",
+                rank=rank,
             )
 
     trainer = training.Trainer(
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index b8d13e6f25..1003b499d6 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -193,7 +193,7 @@ def get_single_model(
             _training_data.add_data_requirement(_data_requirement)
             if _validation_data is not None:
                 _validation_data.add_data_requirement(_data_requirement)
-        if not resuming:
+        if not resuming and self.rank == 0:
 
             @functools.lru_cache
             def get_sample():
@@ -429,7 +429,7 @@ def get_loss(loss_params, start_lr, _ntypes):
 
         # Multi-task share params
         if shared_links is not None:
-            self.wrapper.share_params(shared_links, resume=resuming)
+            self.wrapper.share_params(shared_links, resume=resuming or self.rank != 0)
 
         if dist.is_initialized():
             torch.cuda.set_device(LOCAL_RANK)
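
Note on the approach: under multitask DDP, every rank previously opened the HDF5 stat file, and concurrent h5py writers can corrupt it (or a reader can see a half-written file). The patch gates all stat-file I/O on rank 0: non-zero ranks get stat_file_path = None, skip the statistics computation (if not resuming and self.rank == 0), and call share_params with resume=True so they behave as if resuming; rank 0's statistics then reach the other ranks through DDP's usual parameter broadcast. Below is a minimal sketch of that single-writer idiom, independent of this patch; the helper name resolve_stat_file_path is illustrative, not deepmd-kit API.

    # Rank-0 gating for a file that must have exactly one writer under DDP.
    # `resolve_stat_file_path` is a hypothetical helper, not deepmd-kit API.
    from typing import Optional

    import torch.distributed as dist


    def resolve_stat_file_path(stat_file_path: Optional[str]) -> Optional[str]:
        """Keep the stat-file path on rank 0 only; other ranks get None."""
        rank = dist.get_rank() if dist.is_initialized() else 0
        if rank != 0:
            # Non-zero ranks must not touch the HDF5 file: concurrent h5py
            # writers corrupt it, and readers can observe partial writes.
            return None
        return stat_file_path

With this shape, downstream code only needs a `path is None` check to decide whether to compute and serialize statistics, which is the same branch structure the patch introduces in prepare_trainer_input_single.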