diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 6fe8562ad6..1e25da77fb 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -128,61 +128,45 @@ def get_opt_param(params): return opt_type, opt_param def get_data_loader(_training_data, _validation_data, _training_params): - if "auto_prob" in _training_params["training_data"]: - train_sampler = get_weighted_sampler( - _training_data, _training_params["training_data"]["auto_prob"] - ) - elif "sys_probs" in _training_params["training_data"]: - train_sampler = get_weighted_sampler( - _training_data, - _training_params["training_data"]["sys_probs"], - sys_prob=True, - ) - else: - train_sampler = get_weighted_sampler(_training_data, "prob_sys_size") - - if train_sampler is None: - log.warning( - "Sampler not specified!" - ) # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration. - training_dataloader = DataLoader( - _training_data, - sampler=train_sampler, - batch_size=None, - num_workers=NUM_WORKERS, # setting to 0 diverges the behavior of its iterator; should be >=1 - drop_last=False, - pin_memory=True, - ) - with torch.device("cpu"): - training_data_buffered = BufferedIterator(iter(training_dataloader)) - if _validation_data is not None: - if "auto_prob" in _training_params["validation_data"]: - valid_sampler = get_weighted_sampler( - _validation_data, - _training_params["validation_data"]["auto_prob"], + def get_dataloader_and_buffer(_data, _params): + if "auto_prob" in _training_params["training_data"]: + _sampler = get_weighted_sampler( + _data, _params["training_data"]["auto_prob"] ) - elif "sys_probs" in _training_params["validation_data"]: - valid_sampler = get_weighted_sampler( - _validation_data, - _training_params["validation_data"]["sys_probs"], + elif "sys_probs" in _training_params["training_data"]: + _sampler = get_weighted_sampler( + _data, + _params["training_data"]["sys_probs"], sys_prob=True, ) else: - valid_sampler = get_weighted_sampler( - _validation_data, "prob_sys_size" - ) - validation_dataloader = DataLoader( - _validation_data, - sampler=valid_sampler, + _sampler = get_weighted_sampler(_data, "prob_sys_size") + + if _sampler is None: + log.warning( + "Sampler not specified!" + ) # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration. + _dataloader = DataLoader( + _data, + sampler=_sampler, batch_size=None, - num_workers=min(NUM_WORKERS, 1), + num_workers=NUM_WORKERS, # setting to 0 diverges the behavior of its iterator; should be >=1 drop_last=False, pin_memory=True, ) with torch.device("cpu"): - validation_data_buffered = BufferedIterator( - iter(validation_dataloader) - ) + _data_buffered = BufferedIterator(iter(_dataloader)) + return _dataloader, _data_buffered + + training_dataloader, training_data_buffered = get_dataloader_and_buffer( + _training_data, _training_params + ) + + if _validation_data is not None: + ( + validation_dataloader, + validation_data_buffered, + ) = get_dataloader_and_buffer(_validation_data, _training_params) if _training_params.get("validation_data", None) is not None: valid_numb_batch = _training_params["validation_data"].get( "numb_btch", 1