diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index f74c4769bf..077c253ac8 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -653,6 +653,12 @@ def run(self) -> None: prof.start() def step(_step_id, task_key="Default") -> None: + if self.multi_task: + model_index = dp_random.choice( + np.arange(self.num_model, dtype=np.int_), + p=self.model_prob, + ) + task_key = self.model_keys[model_index] # PyTorch Profiler if self.enable_profiler or self.profiling: prof.step() @@ -929,24 +935,8 @@ def log_loss_valid(_task_key="Default"): self.t0 = time.time() self.total_train_time = 0.0 - for step_id in range(self.num_steps): - if step_id < self.start_step: - continue - if self.multi_task: - chosen_index_list = dp_random.choice( - np.arange( - self.num_model, dtype=np.int32 - ), # int32 should be enough for # models... - p=np.array(self.model_prob), - size=self.world_size, - replace=True, - ) - assert chosen_index_list.size == self.world_size - model_index = chosen_index_list[self.rank] - model_key = self.model_keys[model_index] - else: - model_key = "Default" - step(step_id, model_key) + for step_id in range(self.start_step, self.num_steps): + step(step_id) if JIT: break diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 9920622792..4ff161b3ba 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -28,6 +28,7 @@ ) from deepmd.pt.utils import ( + dp_random, env, ) from deepmd.pt.utils.dataset import ( @@ -50,6 +51,7 @@ def setup_seed(seed) -> None: torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True + dp_random.seed(seed) class DpLoaderSet(Dataset):