Fix h5py files in multitask DDP
iProzd committed Feb 29, 2024
1 parent e17546a commit 1debf4f
Showing 2 changed files with 12 additions and 5 deletions.
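The change gates all HDF5 stat-file handling on the DDP process rank, so that only rank 0 ever opens the file. A minimal sketch of the pattern, assuming torch.distributed as the backend (the helper get_stat_file_path is illustrative, not part of the patch):

import torch.distributed as dist

def get_stat_file_path(configured_path):
    # Only rank 0 keeps the configured HDF5 path; every other rank gets
    # None and therefore never touches the file. Single-process runs
    # (no initialized process group) count as rank 0.
    rank = dist.get_rank() if dist.is_initialized() else 0
    return configured_path if rank == 0 else None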
deepmd/pt/entrypoints/main.py (10 additions, 3 deletions)
@@ -101,7 +101,7 @@ def get_trainer(
     config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)

     def prepare_trainer_input_single(
-        model_params_single, data_dict_single, loss_dict_single, suffix=""
+        model_params_single, data_dict_single, loss_dict_single, suffix="", rank=0
     ):
         training_dataset_params = data_dict_single["training_data"]
         type_split = False
@@ -115,7 +115,9 @@ def prepare_trainer_input_single(

         # stat files
         stat_file_path_single = data_dict_single.get("stat_file", None)
-        if stat_file_path_single is not None:
+        if rank != 0:
+            stat_file_path_single = None
+        elif stat_file_path_single is not None:
             if Path(stat_file_path_single).is_dir():
                 raise ValueError(
                     f"stat_file should be a file, not a directory: {stat_file_path_single}"
@@ -153,13 +155,17 @@ def prepare_trainer_input_single(
             stat_file_path_single,
         )

+    rank = dist.get_rank() if dist.is_initialized() else 0
     if not multi_task:
         (
             train_data,
             validation_data,
             stat_file_path,
         ) = prepare_trainer_input_single(
-            config["model"], config["training"], config["loss"]
+            config["model"],
+            config["training"],
+            config["loss"],
+            rank=rank,
         )
     else:
         train_data, validation_data, stat_file_path = {}, {}, {}
@@ -173,6 +179,7 @@ def prepare_trainer_input_single(
                 config["training"]["data_dict"][model_key],
                 config["loss_dict"][model_key],
                 suffix=f"_{model_key}",
+                rank=rank,
             )

     trainer = training.Trainer(
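The gating matters because HDF5 does not support concurrent writers by default: a second DDP rank opening the same stat file in append mode races with rank 0 on the file lock. A minimal illustration of the failure mode being avoided (the file name is hypothetical):

import h5py

# Rank 0 writes statistics into the stat file.
with h5py.File("stat.h5", "a") as f:
    f.create_dataset("bias", data=[0.0])
# A second process doing the same concurrently would typically fail with
# OSError (unable to lock file) or leave the file in a corrupt state.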
deepmd/pt/train/training.py (2 additions, 2 deletions)
@@ -193,7 +193,7 @@ def get_single_model(
             _training_data.add_data_requirement(_data_requirement)
         if _validation_data is not None:
             _validation_data.add_data_requirement(_data_requirement)
-        if not resuming:
+        if not resuming and self.rank == 0:

             @functools.lru_cache
             def get_sample():
@@ -429,7 +429,7 @@ def get_loss(loss_params, start_lr, _ntypes):

         # Multi-task share params
         if shared_links is not None:
-            self.wrapper.share_params(shared_links, resume=resuming)
+            self.wrapper.share_params(shared_links, resume=resuming or self.rank != 0)

         if dist.is_initialized():
             torch.cuda.set_device(LOCAL_RANK)
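With these two changes, non-zero ranks skip the data-statistics pass entirely (if not resuming and self.rank == 0) and treat multi-task parameter sharing as a resume, so they never need the stat file; rank 0's values presumably reach them through DDP's standard broadcast of parameters and buffers from rank 0. A minimal sketch of that propagation, assuming an initialized process group (sync_stats is an illustrative helper, not from the patch):

import torch
import torch.distributed as dist

def sync_stats(stat_tensor: torch.Tensor) -> None:
    # Rank 0 holds the freshly computed statistics; broadcasting them
    # means no other rank ever has to read or write the HDF5 file.
    if dist.is_initialized():
        dist.broadcast(stat_tensor, src=0)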
