
add Prob sampler #129

Merged (14 commits) on Oct 31, 2023
3 changes: 3 additions & 0 deletions deepmd_pt/entrypoints/main.py
@@ -15,6 +15,8 @@
from deepmd_pt.utils.stat import make_stat_input
from deepmd_pt.utils.multi_task import preprocess_shared_params

from deepmd_pt import __version__


def get_trainer(config, init_model=None, restart_model=None, finetune_model=None, model_branch='', force_load=False):
    # Initialize DDP
@@ -165,6 +167,7 @@ def main(args=None):
        level=logging.WARNING if env.LOCAL_RANK else logging.INFO,
        format=f"%(asctime)-15s {os.environ.get('RANK') or ''} [%(filename)s:%(lineno)d] %(levelname)s %(message)s"
    )
    logging.info('DeepMD version: %s', __version__)
    parser = argparse.ArgumentParser(description='A tool to manage deep models of potential energy surface.')
    subparsers = parser.add_subparsers(dest='command')
    train_parser = subparsers.add_parser('train', help='Train a model.')
20 changes: 17 additions & 3 deletions deepmd_pt/train/training.py
@@ -14,7 +14,7 @@
from deepmd_pt.loss import EnergyStdLoss, DenoiseLoss
from deepmd_pt.model.model import get_model
from deepmd_pt.train.wrapper import ModelWrapper
-from deepmd_pt.utils.dataloader import BufferedIterator
+from deepmd_pt.utils.dataloader import BufferedIterator, get_weighted_sampler
from pathlib import Path
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
@@ -99,9 +99,23 @@ def get_opt_param(params):
    return opt_type, opt_param

def get_data_loader(_training_data, _validation_data, _training_params):
    if 'auto_prob' in _training_params['training_data']:
        train_sampler = get_weighted_sampler(_training_data, _training_params['training_data']['auto_prob'])
    elif 'sys_probs' in _training_params['training_data']:
        train_sampler = get_weighted_sampler(_training_data, _training_params['training_data']['sys_probs'], sys_prob=True)
    else:
        train_sampler = get_weighted_sampler(_training_data, 'prob_sys_size')

    if 'auto_prob' in _training_params['validation_data']:
        valid_sampler = get_weighted_sampler(_validation_data, _training_params['validation_data']['auto_prob'])
    elif 'sys_probs' in _training_params['validation_data']:
        valid_sampler = get_weighted_sampler(_validation_data, _training_params['validation_data']['sys_probs'], sys_prob=True)
    else:
        valid_sampler = get_weighted_sampler(_validation_data, 'prob_sys_size')
    training_dataloader = DataLoader(
        _training_data,
-        sampler=torch.utils.data.RandomSampler(_training_data),
+        sampler=train_sampler,
        batch_size=None,
        num_workers=8,  # setting it to 0 makes the iterator behave differently; should be >= 1
        drop_last=False,
@@ -110,7 +124,7 @@ def get_data_loader(_training_data, _validation_data, _training_params):
    training_data_buffered = BufferedIterator(iter(training_dataloader))
    validation_dataloader = DataLoader(
        _validation_data,
-        sampler=torch.utils.data.RandomSampler(_validation_data),
+        sampler=valid_sampler,
        batch_size=None,
        num_workers=1,
        drop_last=False,
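For reference, a rough sketch of how these branches map onto the training configuration. The nesting mirrors the training_data/validation_data layout read by get_data_loader above; the system paths and numbers are made up:

# Hypothetical excerpt of a training config; only the keys inspected by
# get_data_loader() are shown.
_training_params = {
    'training_data': {
        'systems': ['data/sys_A', 'data/sys_B'],       # made-up paths
        'auto_prob': 'prob_sys_size;0:1:0.6;1:2:0.4',  # first branch
    },
    'validation_data': {
        'systems': ['data/sys_val'],
        'sys_probs': [1.0],                            # second branch (sys_prob=True)
    },
}
# Omitting both keys falls back to plain 'prob_sys_size'.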
75 changes: 68 additions & 7 deletions deepmd_pt/utils/dataloader.py
@@ -2,6 +2,7 @@
import os
import queue
import time
import numpy as np
from threading import Thread
from typing import Callable, Dict, List, Tuple, Type, Union
from multiprocessing.dummy import Pool
@@ -11,7 +12,7 @@
import torch.distributed as dist
from deepmd_pt.utils import env
from deepmd_pt.utils.dataset import DeepmdDataSetForLoader
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
import torch.multiprocessing
@@ -75,6 +76,7 @@ def construct_dataset(system):

        self.sampler_list: List[DistributedSampler] = []
        self.index = []
        self.total_batch = 0

        self.dataloaders = []
        for system in self.systems:
@@ -105,9 +107,8 @@ def construct_dataset(system):
                shuffle=(not dist.is_initialized()) and shuffle,
            )
            self.dataloaders.append(system_dataloader)
-            for _ in range(len(system_dataloader)):
-                self.index.append(len(self.dataloaders) - 1)
+            self.index.append(len(system_dataloader))
+            self.total_batch += len(system_dataloader)
        # Initialize iterator instances for DataLoader
        self.iters = []
        for item in self.dataloaders:
@@ -124,11 +125,11 @@ def set_noise(self, noise_settings):
            system.set_noise(noise_settings)

    def __len__(self):
-        return len(self.index)
+        return len(self.dataloaders)

    def __getitem__(self, idx):
-        # logging.warning(str(torch.distributed.get_rank())+" idx: "+str(idx)+" index: "+str(self.index[idx]))
-        return next(self.iters[self.index[idx]])
+        # logging.warning(str(torch.distributed.get_rank())+" idx: "+str(idx)+" index: "+str(self.index[idx]))
+        return next(self.iters[idx])
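Net effect of the DpLoaderSet changes: index now holds one batch count per system instead of one entry per batch, __len__ counts systems, and __getitem__(idx) pulls the next batch from system idx, so the sampler that generates indices fully controls how often each system is visited. A minimal sketch of that contract, with hypothetical batch counts:

from torch.utils.data import WeightedRandomSampler

index = [10, 30, 20]       # hypothetical per-system batch counts, cf. self.index
total_batch = sum(index)   # 60 draws make up one epoch, cf. self.total_batch

# prob_sys_size weighting: each system in proportion to its batch count.
probs = [n / total_batch for n in index]
sampler = WeightedRandomSampler(probs, total_batch, replacement=True)

# Every draw is a system index, which is exactly what __getitem__ consumes.
for system_idx in sampler:
    assert 0 <= system_idx < len(index)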


_sentinel = object()
@@ -257,3 +258,63 @@ def collate_batch(batch):
        else:
            result[key] = collate_tensor_fn([d[key] for d in batch])
    return result

def get_weighted_sampler(training_data, prob_style, sys_prob=False):
    if not sys_prob:
        if prob_style == "prob_uniform":
            prob_v = 1.0 / float(len(training_data))
            probs = [prob_v for ii in range(len(training_data))]
        elif prob_style == "prob_sys_size":
Collaborator review comment: can this case merge with the else branch? It should be equivalent to "prob_sys_size;0:nsys:1.0". (A numerical check of this equivalence follows the function below.)

            probs = []
            for ii in range(len(training_data.dataloaders)):
                prob_v = float(training_data.index[ii]) / float(training_data.total_batch)
                probs.append(prob_v)
        else:  # prob_sys_size;A:B:p1;C:D:p2
            probs = prob_sys_size_ext(prob_style, len(training_data), training_data.index)
    else:
        probs = process_sys_probs(prob_style, training_data.index)
    logging.info("Generated weighted sampler with prob array: " + str(probs))
    # training_data.total_batch is the size of one epoch; increase it to rebuild the iterators less often
    sampler = WeightedRandomSampler(probs, training_data.total_batch, replacement=True)
    return sampler
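On the review question above: the elif branch weights each system by its share of the total batch count, which is exactly what one block spanning all systems with weight 1.0 produces, so merging it into the else branch would preserve behavior. A quick numerical check with toy counts, assuming prob_sys_size_ext as defined below:

import numpy as np

nbatch = [10, 30, 20]  # hypothetical batches per system
nsys = len(nbatch)

# What the elif branch computes: index[ii] / total_batch.
explicit = np.array(nbatch, dtype=float) / sum(nbatch)

# What the else branch computes for the equivalent one-block spec.
merged = prob_sys_size_ext("prob_sys_size;0:%d:1.0" % nsys, nsys, nbatch)

assert np.allclose(explicit, merged)

Note also that WeightedRandomSampler normalizes its weights internally (it feeds them to torch.multinomial), so the probability arrays built here only need to be non-negative; they do not have to sum exactly to one.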

def prob_sys_size_ext(keywords, nsystems, nbatch):
    block_str = keywords.split(";")[1:]
    block_stt = []
    block_end = []
    block_weights = []
    for ii in block_str:
        stt = int(ii.split(":")[0])
        end = int(ii.split(":")[1])
        weight = float(ii.split(":")[2])
        assert weight >= 0, "the weight of a block should be no less than 0"
        block_stt.append(stt)
        block_end.append(end)
        block_weights.append(weight)
    nblocks = len(block_str)
    block_probs = np.array(block_weights) / np.sum(block_weights)
    sys_probs = np.zeros([nsystems])
    for ii in range(nblocks):
        nbatch_block = nbatch[block_stt[ii] : block_end[ii]]
        tmp_prob = [float(i) for i in nbatch_block] / np.sum(nbatch_block)
        sys_probs[block_stt[ii] : block_end[ii]] = tmp_prob * block_probs[ii]
    return sys_probs
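A worked example of the block syntax, assuming four systems: "prob_sys_size;0:2:0.6;2:4:0.4" gives the first two systems 60% of the probability mass and the last two 40%, with each block split internally by batch count:

import numpy as np

nbatch = [10, 30, 20, 20]  # hypothetical batches per system
probs = prob_sys_size_ext("prob_sys_size;0:2:0.6;2:4:0.4", 4, nbatch)

# Block 0:2 shares 0.6 as 10/40 and 30/40 of it; block 2:4 splits 0.4 evenly.
assert np.allclose(probs, [0.15, 0.45, 0.2, 0.2])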

def process_sys_probs(sys_probs, nbatch):
    sys_probs = np.array(sys_probs)
    type_filter = sys_probs >= 0
    assigned_sum_prob = np.sum(type_filter * sys_probs)
    # 1e-8 is to handle floating point error; see #1917
    assert (
        assigned_sum_prob <= 1.0 + 1e-8
    ), "the sum of assigned probability should be less than 1"
    rest_sum_prob = 1.0 - assigned_sum_prob
    if not np.isclose(rest_sum_prob, 0):
        rest_nbatch = (1 - type_filter) * nbatch
        rest_prob = rest_sum_prob * rest_nbatch / np.sum(rest_nbatch)
        ret_prob = rest_prob + type_filter * sys_probs
    else:
        ret_prob = sys_probs
    assert np.isclose(np.sum(ret_prob), 1), "sum of probs should be 1"
    return ret_prob
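For illustration, a negative entry in sys_probs marks a system whose probability should be filled in automatically: the leftover mass is split among such systems in proportion to their batch counts. A small check with made-up inputs:

import numpy as np

# System 0 is pinned at 0.5; systems 1 and 2 share the remaining 0.5
# in proportion to their batch counts (10 vs. 30).
probs = process_sys_probs([0.5, -1.0, -1.0], np.array([40, 10, 30]))
assert np.allclose(probs, [0.5, 0.125, 0.375])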