Skip to content

Commit

Permalink
Bug fix in LabelTensor, Dataset and DataLoader, solve #377, first attempt with PINN
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoOlivo committed Nov 7, 2024
1 parent dbb5476 commit 30e2fa8
Show file tree
Hide file tree
Showing 25 changed files with 397 additions and 511 deletions.
2 changes: 1 addition & 1 deletion pina/condition/data_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class DataConditionInterface(ConditionInterface):
"""

__slots__ = ["input_points", "conditional_variables"]
condition_type = ['unsupervised']

def __init__(self, input_points, conditional_variables=None):
"""
Expand All @@ -23,7 +24,6 @@ def __init__(self, input_points, conditional_variables=None):
super().__init__()
self.input_points = input_points
self.conditional_variables = conditional_variables
self._condition_type = 'unsupervised'

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'conditional_variables'):
Expand Down
3 changes: 1 addition & 2 deletions pina/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ class DomainEquationCondition(ConditionInterface):
"""

__slots__ = ["domain", "equation"]

condition_type = ['physics']
def __init__(self, domain, equation):
"""
TODO
"""
super().__init__()
self.domain = domain
self.equation = equation
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'domain':
Expand Down
3 changes: 1 addition & 2 deletions pina/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ class InputPointsEquationCondition(ConditionInterface):
"""

__slots__ = ["input_points", "equation"]

condition_type = ['physics']
def __init__(self, input_points, equation):
"""
TODO
"""
super().__init__()
self.input_points = input_points
self.equation = equation
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'input_points':
Expand Down
3 changes: 1 addition & 2 deletions pina/condition/input_output_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ class InputOutputPointsCondition(ConditionInterface):
"""

__slots__ = ["input_points", "output_points"]

condition_type = ['supervised']
def __init__(self, input_points, output_points):
"""
TODO
"""
super().__init__()
self.input_points = input_points
self.output_points = output_points
self._condition_type = ['supervised', 'physics']

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'output_points'):
Expand Down
9 changes: 4 additions & 5 deletions pina/data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import logging

from torch.utils.data import Dataset

from ..label_tensor import LabelTensor


Expand Down Expand Up @@ -109,14 +108,14 @@ def initialize(self):
already filled
"""
logging.debug(f'Initialize dataset {self.__class__.__name__}')

if self.num_el_per_condition:
self.condition_indices = torch.cat([
torch.tensor([i] * self.num_el_per_condition[i],
dtype=torch.uint8)
torch.tensor(
[self.conditions_idx[i]] * self.num_el_per_condition[i],
dtype=torch.uint8)
for i in range(len(self.num_el_per_condition))
],
dim=0)
dim=0)
for slot in self.__slots__:
current_attribute = getattr(self, slot)
if all(isinstance(a, LabelTensor) for a in current_attribute):
Expand Down
21 changes: 11 additions & 10 deletions pina/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def __init__(self,
problem,
device,
train_size=.7,
test_size=.1,
val_size=.2,
test_size=.2,
val_size=.1,
predict_size=0.,
batch_size=None,
shuffle=True,
Expand Down Expand Up @@ -61,28 +61,30 @@ def __init__(self,
if train_size > 0:
self.split_names.append('train')
self.split_length.append(train_size)
self.loader_functions['train_dataloader'] = lambda: PinaDataLoader(
self.splits['train'], self.batch_size, self.condition_names)
self.loader_functions['train_dataloader'] = lambda \
x: PinaDataLoader(self.splits['train'], self.batch_size,
self.condition_names)
if test_size > 0:
self.split_length.append(test_size)
self.split_names.append('test')
self.loader_functions['test_dataloader'] = lambda: PinaDataLoader(
self.loader_functions['test_dataloader'] = lambda x: PinaDataLoader(
self.splits['test'], self.batch_size, self.condition_names)
if val_size > 0:
self.split_length.append(val_size)
self.split_names.append('val')
self.loader_functions['val_dataloader'] = lambda: PinaDataLoader(
self.loader_functions['val_dataloader'] = lambda x: PinaDataLoader(
self.splits['val'], self.batch_size, self.condition_names)
if predict_size > 0:
self.split_length.append(predict_size)
self.split_names.append('predict')
self.loader_functions['predict_dataloader'] = lambda: PinaDataLoader(
self.loader_functions[
'predict_dataloader'] = lambda x: PinaDataLoader(
self.splits['predict'], self.batch_size, self.condition_names)
self.splits = {k: {} for k in self.split_names}
self.shuffle = shuffle

for k, v in self.loader_functions.items():
setattr(self, k, v)
setattr(self, k, v.__get__(self, PinaDataModule))

def prepare_data(self):
if self.datasets is None:
Expand Down Expand Up @@ -140,12 +142,11 @@ def dataset_split(dataset, lengths, seed=None, shuffle=True):
indices = torch.randperm(sum(lengths))
dataset.apply_shuffle(indices)

indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist()
offsets = [
sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
]
return [
PinaSubset(dataset, indices[offset:offset + length])
PinaSubset(dataset, slice(offset, offset + length))
for offset, length in zip(offsets, lengths)
]

Expand Down
16 changes: 11 additions & 5 deletions pina/data/pina_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,27 @@ def __len__(self):
:rtype: int
"""
length = 0
for dataset in dir(self):
for dataset in self.attributes:
attribute = getattr(self, dataset)
if isinstance(attribute, list):
length += len(getattr(self, dataset))
length += len(attribute)
return length

def __getattribute__(self, item):
if item in super().__getattribute__('attributes'):
dataset = super().__getattribute__(item)
index = super().__getattribute__(item + '_idx')
return PinaSubset(dataset.dataset, dataset.indices[index])
if isinstance(dataset, PinaSubset):
dataset_index = dataset.indices
if isinstance(dataset_index, slice):
index = slice(dataset_index.start + index.start,
min(dataset_index.start + index.stop,
dataset_index.stop))
return PinaSubset(dataset.dataset, index,
require_grad=self.require_grad)
return super().__getattribute__(item)

def __getattr__(self, item):
if item == 'data' and len(self.attributes) == 1:
item = self.attributes[0]
return super().__getattribute__(item)
return self.__getattribute__(item)
raise AttributeError(f"'Batch' object has no attribute '{item}'")
52 changes: 40 additions & 12 deletions pina/data/pina_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This module is used to create an iterable object used during training
"""
import math

from .pina_batch import Batch


Expand All @@ -26,6 +27,7 @@ def __init__(self, dataset_dict, batch_size, condition_names) -> None:
"""
self.condition_names = condition_names
self.dataset_dict = dataset_dict
self.batch_size = batch_size
self._init_batches(batch_size)

def _init_batches(self, batch_size=None):
Expand All @@ -36,20 +38,46 @@ def _init_batches(self, batch_size=None):
n_elements = sum(len(v) for v in self.dataset_dict.values())
if batch_size is None:
batch_size = n_elements
indexes_dict = {}
self.batch_size = n_elements
n_batches = int(math.ceil(n_elements / batch_size))
for k, v in self.dataset_dict.items():
if n_batches != 1:
indexes_dict[k] = math.floor(len(v) / (n_batches - 1))
else:
indexes_dict[k] = len(v)
for i in range(n_batches):
temp_dict = {}
for k, v in indexes_dict.items():
if i != n_batches - 1:
temp_dict[k] = slice(i * v, (i + 1) * v)
indexes_dict = {
k: math.floor(len(v) / n_batches) if n_batches != 1 else len(v) for
k, v in self.dataset_dict.items()}

dataset_names = list(self.dataset_dict.keys())
num_el_per_batch = [{i: indexes_dict[i] for i in dataset_names} for _
in range(n_batches - 1)] + [
{i: 0 for i in dataset_names}]
reminders = {
i: len(self.dataset_dict[i]) - indexes_dict[i] * (n_batches - 1) for
i in dataset_names}
dataset_names = iter(dataset_names)
name = next(dataset_names, None)
for batch in num_el_per_batch:
tot_num_el = sum(batch.values())
batch_reminder = batch_size - tot_num_el
for _ in range(batch_reminder):
if name is None:
break
if reminders[name] > 0:
batch[name] += 1
reminders[name] -= 1
else:
temp_dict[k] = slice(i * v, len(self.dataset_dict[k]))
name = next(dataset_names, None)
if name is None:
break
batch[name] += 1
reminders[name] -= 1

reminders, dataset_names, indexes_dict = None, None, None # free memory
actual_indices = {k: 0 for k in self.dataset_dict.keys()}
for batch in num_el_per_batch:
temp_dict = {}
total_length = 0
for k, v in batch.items():
temp_dict[k] = slice(actual_indices[k], actual_indices[k] + v)
actual_indices[k] = actual_indices[k] + v
total_length += v
self.batches.append(
Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict))

Expand Down
19 changes: 16 additions & 3 deletions pina/data/pina_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class PinaSubset:
"""
__slots__ = ['dataset', 'indices', 'require_grad']

def __init__(self, dataset, indices, require_grad=True):
def __init__(self, dataset, indices, require_grad=False):
"""
TODO
"""
Expand All @@ -23,14 +23,27 @@ def __len__(self):
"""
TODO
"""
if isinstance(self.indices, slice):
return self.indices.stop - self.indices.start
return len(self.indices)

def __getattr__(self, name):
tensor = self.dataset.__getattribute__(name)
if isinstance(tensor, (LabelTensor, Tensor)):
tensor = tensor[[self.indices]].to(self.dataset.device)
if isinstance(self.indices, slice):
tensor = tensor[self.indices]
if (tensor.device != self.dataset.device
and tensor.dtype == float32):
tensor = tensor.to(self.dataset.device)
elif isinstance(self.indices, list):
tensor = tensor[[self.indices]].to(self.dataset.device)
else:
raise ValueError(f"Indices type {type(self.indices)} not "
f"supported")
return tensor.requires_grad_(
self.require_grad) if tensor.dtype == float32 else tensor
if isinstance(tensor, list):
return [tensor[i] for i in self.indices]
if isinstance(self.indices, list):
return [tensor[i] for i in self.indices]
return tensor[self.indices]
raise AttributeError(f"No attribute named {name}")
6 changes: 3 additions & 3 deletions pina/domain/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,15 @@ def _single_points_sample(n, variables):
return result

if self.fixed_ and (not self.range_):
return _single_points_sample(n, variables)
return _single_points_sample(n, variables).sort_labels()

if variables == "all":
variables = list(self.range_.keys()) + list(self.fixed_.keys())

if mode in ["grid", "chebyshev"]:
return _1d_sampler(n, mode, variables)
return _1d_sampler(n, mode, variables).sort_labels()
elif mode in ["random", "lh", "latin"]:
return _Nd_sampler(n, mode, variables)
return _Nd_sampler(n, mode, variables).sort_labels()
else:
raise ValueError(f"mode={mode} is not valid.")

Expand Down
2 changes: 1 addition & 1 deletion pina/domain/operation_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _check_dimensions(self, geometries):
:type geometries: list[Location]
"""
for geometry in geometries:
if geometry.variables != geometries[0].variables:
if sorted(geometry.variables) != sorted(geometries[0].variables):
raise NotImplementedError(
f"The geometries need to have same dimensions and labels."
)
3 changes: 1 addition & 2 deletions pina/domain/union_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,4 @@ def sample(self, n, mode="random", variables="all"):
# in case the number of sampled points is smaller than the number of geometries
if len(sampled_points) >= n:
break

return LabelTensor(torch.cat(sampled_points), labels=self.variables)
return LabelTensor.cat(sampled_points)
Loading

0 comments on commit 30e2fa8

Please sign in to comment.