DataLoader for GPU
ndem0 committed Nov 6, 2023
1 parent 541b6c4 commit 32db54d
Showing 6 changed files with 49 additions and 29 deletions.
23 changes: 16 additions & 7 deletions pina/dataset.py
@@ -8,7 +8,7 @@ class SamplePointDataset(Dataset):
This class is used to create a dataset of sample points.
"""

def __init__(self, problem) -> None:
def __init__(self, problem, device) -> None:
"""
:param dict input_pts: The input points.
"""
@@ -31,14 +31,17 @@ def __init__(self, problem) -> None:
else: # if there are no sample points
self.condition_indeces = torch.tensor([])
self.pts = torch.tensor([])

self.pts = self.pts.to(device)
self.condition_indeces = self.condition_indeces.to(device)

def __len__(self):
return self.pts.shape[0]


class DataPointDataset(Dataset):

def __init__(self, problem) -> None:
def __init__(self, problem, device) -> None:
super().__init__()
input_list = []
output_list = []
@@ -63,6 +66,10 @@ def __init__(self, problem) -> None:
self.input_pts = torch.tensor([])
self.output_pts = torch.tensor([])

self.input_pts = self.input_pts.to(device)
self.output_pts = self.output_pts.to(device)
self.condition_indeces = self.condition_indeces.to(device)

def __len__(self):
return self.input_pts.shape[0]

@@ -142,6 +149,7 @@ def _prepare_data_dataset(self, dataset, batch_size, shuffle):

output_labels = dataset.output_pts.labels
input_labels = dataset.input_pts.labels
self.tensor_conditions = dataset.condition_indeces

if shuffle:
idx = torch.randperm(dataset.input_pts.shape[0])
@@ -186,10 +194,10 @@ def _prepare_sample_dataset(self, dataset, batch_size, shuffle):
self.tensor_pts = dataset.pts
self.tensor_conditions = dataset.condition_indeces

if shuffle:
idx = torch.randperm(self.tensor_pts.shape[0])
self.tensor_pts = self.tensor_pts[idx]
self.tensor_conditions = self.tensor_conditions[idx]
# if shuffle:
# idx = torch.randperm(self.tensor_pts.shape[0])
# self.tensor_pts = self.tensor_pts[idx]
# self.tensor_conditions = self.tensor_conditions[idx]

self.batch_sample_pts = torch.tensor_split(self.tensor_pts, batch_num)
for i in range(len(self.batch_sample_pts)):
@@ -214,7 +222,8 @@ def __iter__(self):
:return: An iterator over the points.
:rtype: iter
"""
for i in self.random_idx:
#for i in self.random_idx:
for i in range(len(self.batch_list)):
type_, idx_ = self.batch_list[i]

if type_ == 'sample':
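The pattern in the hunks above — each dataset takes a device argument and moves its tensors there right after construction, so every batch served later already lives on the training device — can be summarized with a minimal, self-contained sketch. The class below is illustrative only, not the actual PINA implementation; the name DeviceTensorDataset is hypothetical.

# Minimal sketch (not the PINA classes): a Dataset that moves its tensors
# to the requested device at construction time, mirroring the change above.
import torch
from torch.utils.data import Dataset

class DeviceTensorDataset(Dataset):
    def __init__(self, pts, condition_indeces, device):
        super().__init__()
        # move everything once, so later indexing returns on-device tensors
        self.pts = pts.to(device)
        self.condition_indeces = condition_indeces.to(device)

    def __len__(self):
        return self.pts.shape[0]

    def __getitem__(self, idx):
        return self.pts[idx], self.condition_indeces[idx]

# usage sketch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset = DeviceTensorDataset(torch.rand(140, 2), torch.zeros(140, dtype=torch.long), device)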
3 changes: 2 additions & 1 deletion pina/loss.py
@@ -111,8 +111,9 @@ def __init__(self, p=2, reduction = 'mean', relative = False):

# check consistency
check_consistency(p, (str,int,float))
self.p = p
check_consistency(relative, bool)

self.p = p
self.relative = relative

def forward(self, input, target):
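This hunk only reorders the constructor so that every consistency check runs before any attribute is assigned. A generic sketch of that pattern follows; the class and error messages are hypothetical, not PINA's LpLoss.

# Sketch: validate all arguments first, assign attributes only afterwards,
# so a half-initialized object is never left behind on bad input.
class ExampleLoss:
    def __init__(self, p=2, relative=False):
        if not isinstance(p, (str, int, float)):
            raise ValueError('p must be str, int or float')
        if not isinstance(relative, bool):
            raise ValueError('relative must be bool')
        self.p = p
        self.relative = relative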
2 changes: 1 addition & 1 deletion pina/solvers/pinn.py
@@ -138,8 +138,8 @@ def training_step(self, batch, batch_idx):
else:
raise ValueError("Batch size not supported")


loss = loss.as_subclass(torch.Tensor)
loss = loss

condition_losses.append(loss * condition.data_weight)

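For reference, the dropped call Tensor.as_subclass re-views a tensor as another tensor class without copying data; the removed line presumably cast the loss back to a plain torch.Tensor before weighting. A standalone sketch of the call, independent of PINA (TaggedTensor is a made-up subclass for illustration):

# Sketch: as_subclass re-interprets a tensor as another Tensor subclass
# (or back to plain torch.Tensor) while sharing the same storage.
import torch

class TaggedTensor(torch.Tensor):
    pass

t = torch.ones(3)
tagged = t.as_subclass(TaggedTensor)      # view t as TaggedTensor
plain = tagged.as_subclass(torch.Tensor)  # and back to a plain Tensor
assert isinstance(tagged, TaggedTensor)
assert type(plain) is torch.Tensor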
22 changes: 16 additions & 6 deletions pina/trainer.py
@@ -23,20 +23,30 @@ def __init__(self, solver, batch_size=None, **kwargs):
'discretise_domain function before train '
'in the provided locations.')

# TODO: make a better dataloader for train
self._create_or_update_loader()

# this method is used here because, if resampling is needed
# during training, there is no need to touch the
# trainer dataloader, just call the method.
def _create_or_update_loader(self):
dataset_phys = SamplePointDataset(self._model.problem)
dataset_data = DataPointDataset(self._model.problem)
"""
This method is used here because is resampling is needed
during training, there is no need to define to touch the
trainer dataloader, just call the method.
"""
devices = self._accelerator_connector._parallel_devices

if len(devices) > 1:
raise RuntimeError('Parallel training is not supported yet.')

device = devices[0]
dataset_phys = SamplePointDataset(self._model.problem, device)
dataset_data = DataPointDataset(self._model.problem, device)
self._loader = SamplePointLoader(
dataset_phys, dataset_data, batch_size=self.batch_size,
shuffle=True)

def train(self, **kwargs):
"""
Train the solver.
"""
return super().fit(self._model, train_dataloaders=self._loader, **kwargs)

@property
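With the loader now built on the trainer's single parallel device, GPU training is selected through the usual accelerator argument forwarded to Lightning. A hedged usage sketch, assuming a PINN solver named pinn built as in the tests below; it is not part of the commit:

# Usage sketch: the Trainer creates the datasets on its device, so picking a
# CUDA accelerator is enough to keep sample and data points on the GPU.
from pina.trainer import Trainer

trainer = Trainer(solver=pinn, batch_size=20, accelerator='cuda', max_epochs=1)
trainer.train()

# Per the docstring above, if the problem is resampled during training it
# should be enough to rebuild the loader by calling
# trainer._create_or_update_loader() again.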
20 changes: 10 additions & 10 deletions tests/test_dataset.py
@@ -56,7 +56,7 @@ class Poisson(SpatialProblem):
poisson.discretise_domain(10, 'grid', locations=boundaries)

def test_sample():
sample_dataset = SamplePointDataset(poisson)
sample_dataset = SamplePointDataset(poisson, device='cpu')
assert len(sample_dataset) == 140
assert sample_dataset.pts.shape == (140, 2)
assert sample_dataset.pts.labels == ['x', 'y']
@@ -65,7 +65,7 @@ def test_sample():
assert sample_dataset.condition_indeces.min() == torch.tensor(0)

def test_data():
dataset = DataPointDataset(poisson)
dataset = DataPointDataset(poisson, device='cpu')
assert len(dataset) == 61
assert dataset.input_pts.shape == (61, 2)
assert dataset.input_pts.labels == ['x', 'y']
@@ -76,16 +76,16 @@ def test_data():
assert dataset.condition_indeces.min() == torch.tensor(0)

def test_loader():
sample_dataset = SamplePointDataset(poisson)
data_dataset = DataPointDataset(poisson)
sample_dataset = SamplePointDataset(poisson, device='cpu')
data_dataset = DataPointDataset(poisson, device='cpu')
loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)

for batch in loader:
assert len(batch) in [2, 3]
assert batch['pts'].shape[0] <= 10
assert batch['pts'].requires_grad == True
assert batch['pts'].labels == ['x', 'y']

loader2 = SamplePointLoader(sample_dataset, data_dataset, batch_size=None)
assert len(list(loader2)) == 2

@@ -94,8 +94,8 @@ def test_loader2():
del poisson.conditions['data2']
del poisson2.conditions['data']
poisson2.discretise_domain(10, 'grid', locations=boundaries)
sample_dataset = SamplePointDataset(poisson)
data_dataset = DataPointDataset(poisson)
sample_dataset = SamplePointDataset(poisson, device='cpu')
data_dataset = DataPointDataset(poisson, device='cpu')
loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)

for batch in loader:
@@ -111,12 +111,12 @@ def test_loader3():
del poisson.conditions['gamma3']
del poisson.conditions['gamma4']
del poisson.conditions['D']
sample_dataset = SamplePointDataset(poisson)
data_dataset = DataPointDataset(poisson)
sample_dataset = SamplePointDataset(poisson, device='cpu')
data_dataset = DataPointDataset(poisson, device='cpu')
loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)

for batch in loader:
assert len(batch) == 2 # only phys conditions
assert batch['pts'].shape[0] <= 10
assert batch['pts'].requires_grad == True
assert batch['pts'].labels == ['x', 'y']
assert batch['pts'].labels == ['x', 'y']
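These tests pin the datasets to the CPU. A GPU counterpart could follow the same constructors, guarded so it only runs where CUDA is available; the sketch below is hypothetical and not part of the commit (test_sample_gpu is a made-up name).

# Sketch of a CUDA-guarded counterpart to test_sample: same constructor,
# tensors expected to land on the GPU.
import pytest
import torch

@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
def test_sample_gpu():
    sample_dataset = SamplePointDataset(poisson, device='cuda')
    assert sample_dataset.pts.device.type == 'cuda'
    assert len(sample_dataset) == 140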
8 changes: 4 additions & 4 deletions tests/test_solvers/test_pinn.py
@@ -45,9 +45,9 @@ class Poisson(SpatialProblem):
'D': Condition(
input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
equation=my_laplace),
# 'data': Condition(
# input_points=in_,
# output_points=out_),
'data': Condition(
input_points=in_,
output_points=out_),
'data2': Condition(
input_points=in2_,
output_points=out2_)
@@ -97,7 +97,7 @@ def test_train_cpu():
n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
trainer = Trainer(solver=pinn, max_epochs=50, accelerator='cpu', batch_size=20)
trainer = Trainer(solver=pinn, max_epochs=1, accelerator='cuda', batch_size=20)
trainer.train()

def test_train_restore():
