From f2815656aa85d987420897f0f7e5cec1000e532b Mon Sep 17 00:00:00 2001 From: caic99 Date: Wed, 27 Nov 2024 14:43:47 +0800 Subject: [PATCH 1/5] chore: refactor training loop --- deepmd/pt/train/training.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index f74c4769bf..b48c4993af 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -653,6 +653,12 @@ def run(self) -> None: prof.start() def step(_step_id, task_key="Default") -> None: + if self.multi_task: + model_index = dp_random.choice( + np.arange(self.num_model), + p=self.model_prob, + ) + task_key = self.model_keys[model_index] # PyTorch Profiler if self.enable_profiler or self.profiling: prof.step() @@ -929,24 +935,8 @@ def log_loss_valid(_task_key="Default"): self.t0 = time.time() self.total_train_time = 0.0 - for step_id in range(self.num_steps): - if step_id < self.start_step: - continue - if self.multi_task: - chosen_index_list = dp_random.choice( - np.arange( - self.num_model, dtype=np.int32 - ), # int32 should be enough for # models... - p=np.array(self.model_prob), - size=self.world_size, - replace=True, - ) - assert chosen_index_list.size == self.world_size - model_index = chosen_index_list[self.rank] - model_key = self.model_keys[model_index] - else: - model_key = "Default" - step(step_id, model_key) + for step_id in range(self.start_step, self.num_steps): + step(step_id) if JIT: break From 64e310a977e8b09fd7ff47332f07ca1ad0dcb436 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Wed, 27 Nov 2024 15:43:27 +0800 Subject: [PATCH 2/5] Update deepmd/pt/train/training.py Signed-off-by: Chun Cai --- deepmd/pt/train/training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index b48c4993af..077c253ac8 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -655,7 +655,7 @@ def run(self) -> None: def step(_step_id, task_key="Default") -> None: if self.multi_task: model_index = dp_random.choice( - np.arange(self.num_model), + np.arange(self.num_model, dtype=np.int_), p=self.model_prob, ) task_key = self.model_keys[model_index] From b88ab3da798f508bb6675a448babd88bff753161 Mon Sep 17 00:00:00 2001 From: caic99 Date: Thu, 28 Nov 2024 16:48:20 +0800 Subject: [PATCH 3/5] also seeding numpy --- deepmd/pt/utils/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 9920622792..203dcdc18b 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -50,7 +50,7 @@ def setup_seed(seed) -> None: torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True - + np.random.seed(seed) class DpLoaderSet(Dataset): """A dataset for storing DataLoaders to multiple Systems. From e0ace97f0addf74bff1ae81cb0d2ef660491738f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Nov 2024 08:50:03 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/dataloader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 203dcdc18b..bc1dc5dec4 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -52,6 +52,7 @@ def setup_seed(seed) -> None: torch.backends.cudnn.deterministic = True np.random.seed(seed) + class DpLoaderSet(Dataset): """A dataset for storing DataLoaders to multiple Systems. From e6f81a5d530eaa5b2eff7a9543d1800e0d88b01b Mon Sep 17 00:00:00 2001 From: caic99 Date: Thu, 28 Nov 2024 17:03:39 +0800 Subject: [PATCH 5/5] fix ut --- deepmd/pt/utils/dataloader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index bc1dc5dec4..4ff161b3ba 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -28,6 +28,7 @@ ) from deepmd.pt.utils import ( + dp_random, env, ) from deepmd.pt.utils.dataset import ( @@ -50,7 +51,7 @@ def setup_seed(seed) -> None: torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True - np.random.seed(seed) + dp_random.seed(seed) class DpLoaderSet(Dataset):