-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add dataset and dataloader for sample points (#195)
* add dataset and dataloader for sample points * unittests
- Loading branch information
Showing
19 changed files
with
582 additions
and
197 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,78 +1,240 @@ | ||
from torch.utils.data import Dataset, DataLoader | ||
import functools | ||
from torch.utils.data import Dataset | ||
import torch | ||
from pina import LabelTensor | ||
|
||
|
||
class PinaDataset(): | ||
class SamplePointDataset(Dataset): | ||
""" | ||
This class is used to create a dataset of sample points. | ||
""" | ||
|
||
def __init__(self, pinn) -> None: | ||
self.pinn = pinn | ||
def __init__(self, problem, device) -> None: | ||
""" | ||
:param dict input_pts: The input points. | ||
""" | ||
super().__init__() | ||
pts_list = [] | ||
self.condition_names = [] | ||
|
||
for name, condition in problem.conditions.items(): | ||
if not hasattr(condition, 'output_points'): | ||
pts_list.append(problem.input_pts[name]) | ||
self.condition_names.append(name) | ||
|
||
self.pts = LabelTensor.vstack(pts_list) | ||
|
||
if self.pts != []: | ||
self.condition_indeces = torch.cat([ | ||
torch.tensor([i]*len(pts_list[i])) | ||
for i in range(len(self.condition_names)) | ||
], dim=0) | ||
else: # if there are no sample points | ||
self.condition_indeces = torch.tensor([]) | ||
self.pts = torch.tensor([]) | ||
|
||
self.pts = self.pts.to(device) | ||
self.condition_indeces = self.condition_indeces.to(device) | ||
|
||
def __len__(self): | ||
return self.pts.shape[0] | ||
|
||
|
||
class DataPointDataset(Dataset): | ||
|
||
def __init__(self, problem, device) -> None: | ||
super().__init__() | ||
input_list = [] | ||
output_list = [] | ||
self.condition_names = [] | ||
|
||
for name, condition in problem.conditions.items(): | ||
if hasattr(condition, 'output_points'): | ||
input_list.append(problem.conditions[name].input_points) | ||
output_list.append(problem.conditions[name].output_points) | ||
self.condition_names.append(name) | ||
|
||
self.input_pts = LabelTensor.vstack(input_list) | ||
self.output_pts = LabelTensor.vstack(output_list) | ||
|
||
if self.input_pts != []: | ||
self.condition_indeces = torch.cat([ | ||
torch.tensor([i]*len(input_list[i])) | ||
for i in range(len(self.condition_names)) | ||
], dim=0) | ||
else: # if there are no data points | ||
self.condition_indeces = torch.tensor([]) | ||
self.input_pts = torch.tensor([]) | ||
self.output_pts = torch.tensor([]) | ||
|
||
self.input_pts = self.input_pts.to(device) | ||
self.output_pts = self.output_pts.to(device) | ||
self.condition_indeces = self.condition_indeces.to(device) | ||
|
||
def __len__(self): | ||
return self.input_pts.shape[0] | ||
|
||
|
||
class SamplePointLoader: | ||
""" | ||
This class is used to create a dataloader to use during the training. | ||
:var condition_names: The names of the conditions. The order is consistent | ||
with the condition indeces in the batches. | ||
:vartype condition_names: list[str] | ||
""" | ||
|
||
def __init__(self, sample_dataset, data_dataset, batch_size=None, shuffle=True) -> None: | ||
""" | ||
Constructor. | ||
:param SamplePointDataset sample_pts: The sample points dataset. | ||
:param int batch_size: The batch size. If ``None``, the batch size is | ||
set to the number of sample points. Default is ``None``. | ||
:param bool shuffle: If ``True``, the sample points are shuffled. | ||
Default is ``True``. | ||
""" | ||
if not isinstance(sample_dataset, SamplePointDataset): | ||
raise TypeError(f'Expected SamplePointDataset, got {type(sample_dataset)}') | ||
if not isinstance(data_dataset, DataPointDataset): | ||
raise TypeError(f'Expected DataPointDataset, got {type(data_dataset)}') | ||
|
||
self.n_data_conditions = len(data_dataset.condition_names) | ||
self.n_phys_conditions = len(sample_dataset.condition_names) | ||
data_dataset.condition_indeces += self.n_phys_conditions | ||
|
||
self._prepare_sample_dataset(sample_dataset, batch_size, shuffle) | ||
self._prepare_data_dataset(data_dataset, batch_size, shuffle) | ||
|
||
self.condition_names = ( | ||
sample_dataset.condition_names + data_dataset.condition_names) | ||
|
||
self.batch_list = [] | ||
for i in range(len(self.batch_sample_pts)): | ||
self.batch_list.append( | ||
('sample', i) | ||
) | ||
|
||
@property | ||
def dataloader(self): | ||
return self._create_dataloader() | ||
for i in range(len(self.batch_input_pts)): | ||
self.batch_list.append( | ||
('data', i) | ||
) | ||
|
||
@property | ||
def dataset(self): | ||
return [self.SampleDataset(key, val) | ||
for key, val in self.input_pts.items()] | ||
if shuffle: | ||
self.random_idx = torch.randperm(len(self.batch_list)) | ||
else: | ||
self.random_idx = torch.arange(len(self.batch_list)) | ||
|
||
def _create_dataloader(self): | ||
"""Private method for creating dataloader | ||
|
||
:return: dataloader | ||
:rtype: torch.utils.data.DataLoader | ||
def _prepare_data_dataset(self, dataset, batch_size, shuffle): | ||
""" | ||
if self.pinn.batch_size is None: | ||
return {key: [{key: val}] for key, val in self.pinn.input_pts.items()} | ||
|
||
def custom_collate(batch): | ||
# extracting pts labels | ||
_, pts = list(batch[0].items())[0] | ||
labels = pts.labels | ||
# calling default torch collate | ||
collate_res = default_collate(batch) | ||
# save collate result in dict | ||
res = {} | ||
for key, val in collate_res.items(): | ||
val.labels = labels | ||
res[key] = val | ||
def __getitem__(self, index): | ||
tensor = self._tensor.select(0, index) | ||
return {self._location: tensor} | ||
|
||
def __len__(self): | ||
return self._len | ||
|
||
|
||
|
||
# TODO: working also for datapoints | ||
class DummyLoader: | ||
|
||
def __init__(self, data, device) -> None: | ||
|
||
# TODO: We need to make a dataset somehow | ||
# and the PINADataset needs to have a method | ||
# to send points to device | ||
# now we simply do it here | ||
# send data to device | ||
def convert_tensors(pts, device): | ||
pts = pts.to(device) | ||
pts.requires_grad_(True) | ||
pts.retain_grad() | ||
return pts | ||
|
||
for location, pts in data.items(): | ||
if isinstance(pts, (tuple, list)): | ||
pts = tuple(map(functools.partial(convert_tensors, device=device),pts)) | ||
else: | ||
pts = pts.to(device) | ||
pts = pts.requires_grad_(True) | ||
pts.retain_grad() | ||
|
||
data[location] = pts | ||
Prepare the dataset for data points. | ||
# iterator | ||
self.data = [data] | ||
:param SamplePointDataset dataset: The dataset. | ||
:param int batch_size: The batch size. | ||
:param bool shuffle: If ``True``, the sample points are shuffled. | ||
""" | ||
self.sample_dataset = dataset | ||
|
||
if len(dataset) == 0: | ||
self.batch_data_conditions = [] | ||
self.batch_input_pts = [] | ||
self.batch_output_pts = [] | ||
return | ||
|
||
if batch_size is None: | ||
batch_size = len(dataset) | ||
batch_num = len(dataset) // batch_size | ||
if len(dataset) % batch_size != 0: | ||
batch_num += 1 | ||
|
||
output_labels = dataset.output_pts.labels | ||
input_labels = dataset.input_pts.labels | ||
self.tensor_conditions = dataset.condition_indeces | ||
|
||
if shuffle: | ||
idx = torch.randperm(dataset.input_pts.shape[0]) | ||
self.input_pts = dataset.input_pts[idx] | ||
self.output_pts = dataset.output_pts[idx] | ||
self.tensor_conditions = dataset.condition_indeces[idx] | ||
|
||
self.batch_input_pts = torch.tensor_split( | ||
dataset.input_pts, batch_num) | ||
self.batch_output_pts = torch.tensor_split( | ||
dataset.output_pts, batch_num) | ||
|
||
for i in range(len(self.batch_input_pts)): | ||
self.batch_input_pts[i].labels = input_labels | ||
self.batch_output_pts[i].labels = output_labels | ||
|
||
self.batch_data_conditions = torch.tensor_split( | ||
self.tensor_conditions, batch_num) | ||
|
||
def _prepare_sample_dataset(self, dataset, batch_size, shuffle): | ||
""" | ||
Prepare the dataset for sample points. | ||
:param DataPointDataset dataset: The dataset. | ||
:param int batch_size: The batch size. | ||
:param bool shuffle: If ``True``, the sample points are shuffled. | ||
""" | ||
|
||
self.sample_dataset = dataset | ||
if len(dataset) == 0: | ||
self.batch_sample_conditions = [] | ||
self.batch_sample_pts = [] | ||
return | ||
|
||
if batch_size is None: | ||
batch_size = len(dataset) | ||
|
||
batch_num = len(dataset) // batch_size | ||
if len(dataset) % batch_size != 0: | ||
batch_num += 1 | ||
|
||
self.tensor_pts = dataset.pts | ||
self.tensor_conditions = dataset.condition_indeces | ||
|
||
# if shuffle: | ||
# idx = torch.randperm(self.tensor_pts.shape[0]) | ||
# self.tensor_pts = self.tensor_pts[idx] | ||
# self.tensor_conditions = self.tensor_conditions[idx] | ||
|
||
self.batch_sample_pts = torch.tensor_split(self.tensor_pts, batch_num) | ||
for i in range(len(self.batch_sample_pts)): | ||
self.batch_sample_pts[i].labels = dataset.pts.labels | ||
|
||
self.batch_sample_conditions = torch.tensor_split( | ||
self.tensor_conditions, batch_num) | ||
|
||
def __iter__(self): | ||
return iter(self.data) | ||
""" | ||
Return an iterator over the points. Any element of the iterator is a | ||
dictionary with the following keys: | ||
- ``pts``: The input sample points. It is a LabelTensor with the | ||
shape ``(batch_size, input_dimension)``. | ||
- ``output``: The output sample points. This key is present only | ||
if data conditions are present. It is a LabelTensor with the | ||
shape ``(batch_size, output_dimension)``. | ||
- ``condition``: The integer condition indeces. It is a tensor | ||
with the shape ``(batch_size, )`` of type ``torch.int64`` and | ||
indicates for any ``pts`` the corresponding problem condition. | ||
:return: An iterator over the points. | ||
:rtype: iter | ||
""" | ||
#for i in self.random_idx: | ||
for i in range(len(self.batch_list)): | ||
type_, idx_ = self.batch_list[i] | ||
|
||
if type_ == 'sample': | ||
d = { | ||
'pts': self.batch_sample_pts[idx_].requires_grad_(True), | ||
'condition': self.batch_sample_conditions[idx_], | ||
} | ||
else: | ||
d = { | ||
'pts': self.batch_input_pts[idx_].requires_grad_(True), | ||
'output': self.batch_output_pts[idx_], | ||
'condition': self.batch_data_conditions[idx_], | ||
} | ||
yield d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.