Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs in DataLoader and tutorial5 notebook #376

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions pina/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
__all__ = [
"Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
"PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
"Trainer", "LabelTensor", "Plotter", "Condition",
"PinaDataModule", 'TorchOptimizer', 'Graph', 'LabelParameter'
]

from .meta import *
from .label_tensor import LabelTensor
from .label_tensor import LabelTensor, LabelParameter
from .solvers.solver import SolverInterface
from .trainer import Trainer
from .plotter import Plotter
from .condition.condition import Condition
from .data import SamplePointDataset

from .data import PinaDataModule
from .data import PinaDataLoader

from .optim import TorchOptimizer
from .optim import TorchScheduler
from .graph import Graph
4 changes: 2 additions & 2 deletions pina/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def store_sample_domains(self, n, mode, variables, sample_locations):
condition = self.problem.conditions[loc]
keys = ["input_points", "equation"]
# if the condition is not ready, we get and store the data
if (not self._is_conditions_ready[loc]):
if not self._is_conditions_ready[loc]:
# if it is the first time we sample
if not self.data_collections[loc]:
already_sampled = []
Expand All @@ -87,7 +87,7 @@ def store_sample_domains(self, n, mode, variables, sample_locations):
condition.domain.sample(n=n, mode=mode, variables=variables)
] + already_sampled
pts = merge_tensors(samples)
if (set(pts.labels).issubset(sorted(self.problem.input_variables))):
if set(pts.labels).issubset(sorted(self.problem.input_variables)):
pts = pts.sort_labels()
if sorted(pts.labels) == sorted(self.problem.input_variables):
self._is_conditions_ready[loc] = True
Expand Down
4 changes: 2 additions & 2 deletions pina/condition/data_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ class DataConditionInterface(ConditionInterface):
"""

__slots__ = ["input_points", "conditional_variables"]
condition_type = ['unsupervised']

def __init__(self, input_points, conditional_variables=None):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.conditional_variables = conditional_variables
self._condition_type = 'unsupervised'

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'conditional_variables'):
Expand Down
5 changes: 2 additions & 3 deletions pina/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ class DomainEquationCondition(ConditionInterface):
"""

__slots__ = ["domain", "equation"]

condition_type = ['physics']
def __init__(self, domain, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.domain = domain
self.equation = equation
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'domain':
Expand Down
5 changes: 2 additions & 3 deletions pina/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ class InputPointsEquationCondition(ConditionInterface):
"""

__slots__ = ["input_points", "equation"]

condition_type = ['physics']
def __init__(self, input_points, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.equation = equation
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'input_points':
Expand Down
5 changes: 2 additions & 3 deletions pina/condition/input_output_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ class InputOutputPointsCondition(ConditionInterface):
"""

__slots__ = ["input_points", "output_points"]

condition_type = ['supervised']
def __init__(self, input_points, output_points):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.output_points = output_points
self._condition_type = ['supervised', 'physics']

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'output_points'):
Expand Down
13 changes: 5 additions & 8 deletions pina/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
Import data classes
"""
__all__ = [
'PinaDataLoader', 'SupervisedDataset', 'SamplePointDataset',
'UnsupervisedDataset', 'Batch', 'PinaDataModule', 'BaseDataset'
'PinaDataModule',
'PinaDataset'
]

from .pina_dataloader import PinaDataLoader
from .supervised_dataset import SupervisedDataset
from .sample_dataset import SamplePointDataset
from .unsupervised_dataset import UnsupervisedDataset
from .pina_batch import Batch


from .data_module import PinaDataModule
from .base_dataset import BaseDataset
from .data_module import PinaDataset
157 changes: 0 additions & 157 deletions pina/data/base_dataset.py

This file was deleted.

Loading
Loading