Skip to content

Commit

Permalink
load data on cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
CaRoLZhangxy committed Apr 1, 2024
1 parent 8b1b280 commit 7ce216e
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,15 @@ def get_dataloader_and_buffer(_data, _params):
_data_buffered = BufferedIterator(iter(_dataloader))
else:
_dataloader = DataLoader(
_data,
sampler=_sampler,
batch_size=None,
num_workers=0, # setting to 0 diverges the behavior of its iterator; should be >=1
drop_last=False,
pin_memory=True,
_data,
sampler=_sampler,
batch_size=None,
num_workers=0, # setting to 0 diverges the behavior of its iterator; should be >=1
drop_last=False,
pin_memory=True,
)
with torch.device("cpu"):
_data_buffered = _dataloader

_data_buffered = _dataloader
return _dataloader, _data_buffered

training_dataloader, training_data_buffered = get_dataloader_and_buffer(
Expand Down Expand Up @@ -1048,7 +1048,8 @@ def get_data(self, is_train=True, task_key="Default"):
else:
if is_train:
try:
batch_data = next(iter(self.training_data[task_key]))
with torch.device("cpu"):
batch_data = next(iter(self.training_data[task_key]))
except StopIteration:
# Refresh the status of the dataloader to start from a new epoch
with torch.device("cpu"):
Expand All @@ -1058,7 +1059,8 @@ def get_data(self, is_train=True, task_key="Default"):
if self.validation_data[task_key] is None:
return {}, {}, {}
try:
batch_data = next(iter(self.validation_data[task_key]))
with torch.device("cpu"):
batch_data = next(iter(self.validation_data[task_key]))
except StopIteration:
self.validation_data[task_key] = self.validation_dataloader[
task_key
Expand Down

0 comments on commit 7ce216e

Please sign in to comment.