diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 0359071d71..361bc4b0b6 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -14,6 +14,7 @@ ) import h5py +import numpy as np import torch import torch.distributed as dist import torch.multiprocessing @@ -106,29 +107,34 @@ def construct_dataset(system): self.dataloaders = [] self.batch_sizes = [] - for system in self.systems: + if isinstance(batch_size, str): + if batch_size == "auto": + rule = 32 + elif batch_size.startswith("auto:"): + rule = int(batch_size.split(":")[1]) + else: + rule = None + log.error("Unsupported batch size type") + for ii in self.systems: + ni = ii._natoms + bsi = rule // ni + if bsi * ni < rule: + bsi += 1 + self.batch_sizes.append(bsi) + elif isinstance(batch_size, list): + self.batch_sizes = batch_size + else: + self.batch_sizes = batch_size * np.ones(len(systems), dtype=int) + assert len(self.systems) == len(self.batch_sizes) + for system, batch_size in zip(self.systems, self.batch_sizes): if dist.is_initialized(): system_sampler = DistributedSampler(system) self.sampler_list.append(system_sampler) else: system_sampler = None - if isinstance(batch_size, str): - if batch_size == "auto": - rule = 32 - elif batch_size.startswith("auto:"): - rule = int(batch_size.split(":")[1]) - else: - rule = None - log.error("Unsupported batch size type") - self.batch_size = rule // system._natoms - if self.batch_size * system._natoms < rule: - self.batch_size += 1 - else: - self.batch_size = batch_size - self.batch_sizes.append(self.batch_size) system_dataloader = DataLoader( dataset=system, - batch_size=self.batch_size, + batch_size=int(batch_size), num_workers=0, # Should be 0 to avoid too many threads forked sampler=system_sampler, collate_fn=collate_batch,