Skip to content

Commit

Permalink
* Adding trainer rst
Browse files Browse the repository at this point in the history
* Adding nitpick_ignore files to be compliant with Lightining
* TODO: paramlinks sphinix not working properly/ adding missing rst
  • Loading branch information
Dario Coscia committed Nov 2, 2023
1 parent 0a649e1 commit bb6874d
Show file tree
Hide file tree
Showing 19 changed files with 334 additions and 164 deletions.
Empty file removed docs/.nojekyll
Empty file.
1 change: 1 addition & 0 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ PINA Features
LabelTensor <label_tensor.rst>
Condition <condition.rst>
Plotter <plotter.rst>
Trainer <trainer.rst>

Problem
--------------
Expand Down
3 changes: 1 addition & 2 deletions docs/source/_rst/solvers/solver_interface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,5 @@ SolverInterface
.. automodule:: pina.solvers.solver

.. autoclass:: SolverInterface
:members:
:show-inheritance:
:noindex:
:members:
9 changes: 9 additions & 0 deletions docs/source/_rst/trainer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Trainer
===========
.. currentmodule:: pina.trainer

.. automodule:: pina.trainer

.. autoclass:: Trainer
:show-inheritance:
:members:
30 changes: 18 additions & 12 deletions docs/source/_rst/tutorials/tutorial2/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ The problem definition
----------------------

The two-dimensional Poisson problem is mathematically written as:
:raw-latex:`\begin{equation}
\begin{cases}
\Delta u = \sin{(\pi x)} \sin{(\pi y)} \text{ in } D, \\
u = 0 \text{ on } \Gamma_1 \cup \Gamma_2 \cup \Gamma_3 \cup \Gamma_4,
\end{cases}
\end{equation}` where :math:`D` is a square domain :math:`[0,1]^2`, and

.. math::
\begin{equation}
\begin{cases}
\Delta u = \sin{(\pi x)} \sin{(\pi y)} \text{ in } D, \\
u = 0 \text{ on } \Gamma_1 \cup \Gamma_2 \cup \Gamma_3 \cup \Gamma_4,
\end{cases}
\end{equation}
where :math:`D` is a square domain :math:`[0,1]^2`, and
:math:`\Gamma_i`, with :math:`i=1,...,4`, are the boundaries of the
square.

Expand Down Expand Up @@ -158,9 +162,10 @@ is now defined, with an additional input variable, named extra-feature,
which coincides with the forcing term in the Laplace equation. The set
of input variables to the neural network is:

:raw-latex:`\begin{equation}
[x, y, k(x, y)], \text{ with } k(x, y)=\sin{(\pi x)}\sin{(\pi y)},
\end{equation}`
.. math::
\begin{equation}
[x, y, k(x, y)], \text{ with } k(x, y)=\sin{(\pi x)}\sin{(\pi y)},
\end{equation}
where :math:`x` and :math:`y` are the spatial coordinates and
:math:`k(x, y)` is the added feature.
Expand Down Expand Up @@ -249,9 +254,10 @@ Another way to exploit the extra features is the addition of learnable
parameter inside them. In this way, the added parameters are learned
during the training phase of the neural network. In this case, we use:

:raw-latex:`\begin{equation}
k(x, \mathbf{y}) = \beta \sin{(\alpha x)} \sin{(\alpha y)},
\end{equation}`
.. math::
\begin{equation}
k(x, \mathbf{y}) = \beta \sin{(\alpha x)} \sin{(\alpha y)},
\end{equation}
where :math:`\alpha` and :math:`\beta` are the abovementioned
parameters. Their implementation is quite trivial: by using the class
Expand Down
15 changes: 8 additions & 7 deletions docs/source/_rst/tutorials/tutorial3/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ The problem definition

The problem is written in the following form:

:raw-latex:`\begin{equation}
\begin{cases}
\Delta u(x,y,t) = \frac{\partial^2}{\partial t^2} u(x,y,t) \quad \text{in } D, \\\\
u(x, y, t=0) = \sin(\pi x)\sin(\pi y), \\\\
u(x, y, t) = 0 \quad \text{on } \Gamma_1 \cup \Gamma_2 \cup \Gamma_3 \cup \Gamma_4,
\end{cases}
\end{equation}`
.. math::
\begin{equation}
\begin{cases}
\Delta u(x,y,t) = \frac{\partial^2}{\partial t^2} u(x,y,t) \quad \text{in } D, \\\\
u(x, y, t=0) = \sin(\pi x)\sin(\pi y), \\\\
u(x, y, t) = 0 \quad \text{on } \Gamma_1 \cup \Gamma_2 \cup \Gamma_3 \cup \Gamma_4,
\end{cases}
\end{equation}
where :math:`D` is a square domain :math:`[0,1]^2`, and
:math:`\Gamma_i`, with :math:`i=1,...,4`, are the boundaries of the
Expand Down
161 changes: 155 additions & 6 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@
# ones.
extensions = [
'sphinx.ext.autodoc',
#'sphinx.ext.autosummary',
#'sphinx.ext.coverage',
'sphinx.ext.autosummary',
#'sphinx.ext.graphviz',
#'sphinx.ext.doctest',
'sphinx_paramlinks',
'sphinx.ext.doctest',
'sphinx.ext.napoleon',
'sphinx.ext.intersphinx',
# 'sphinx.ext.todo',
#'sphinx.ext.coverage',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.viewcode',
#'sphinx.ext.ifconfig',
'sphinx.ext.mathjax',
Expand All @@ -53,12 +54,160 @@
intersphinx_mapping = {
'python': ('http://docs.python.org/3', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
# 'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None),
'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None),
'matplotlib': ('http://matplotlib.sourceforge.net/', None),
'torch': ('https://pytorch.org/docs/stable/', None),
'lightning.pytorch': ("https://lightning.ai/docs/pytorch/stable/", None),
"lightning.app": ("https://lightning.ai/docs/app/stable/", None),
"PIL": ("https://pillow.readthedocs.io/en/stable/", None),
"torchmetrics": ("https://torchmetrics.readthedocs.io/en/stable/", None),
"graphcore": ("https://docs.graphcore.ai/en/latest/", None),
"lightning_habana": ("https://lightning-ai.github.io/lightning-Habana/", None),
"tensorboardX": ("https://tensorboardx.readthedocs.io/en/stable/", None),
"lightning.app": ("https://lightning.ai/docs/app/stable/", None),
"lightning.fabric": ("https://lightning.ai/docs/fabric/stable/", None),
}

nitpicky = True

nitpick_ignore = [
("py:class", "typing.Self"),
# missing in generated API
("py:exc", "MisconfigurationException"),
# TODO: generated list of all existing ATM, need to be fixed
("py:class", "AveragedModel"),
("py:class", "CometExperiment"),
("py:meth", "DataModule.__init__"),
("py:class", "HPUAccelerator"),
("py:class", "Tensor"),
("py:class", "_PATH"),
("py:func", "add_argument"),
("py:func", "add_class_arguments"),
("py:meth", "apply_to_collection"),
("py:attr", "best_model_path"),
("py:attr", "best_model_score"),
("py:attr", "checkpoint_path"),
("py:class", "comet_ml.ExistingExperiment"),
("py:class", "comet_ml.Experiment"),
("py:class", "comet_ml.OfflineExperiment"),
("py:meth", "deepspeed.DeepSpeedEngine.backward"),
("py:attr", "example_input_array"),
("py:class", "jsonargparse._core.ArgumentParser"),
("py:class", "jsonargparse._namespace.Namespace"),
("py:class", "jsonargparse.core.ArgumentParser"),
("py:class", "jsonargparse.namespace.Namespace"),
("py:class", "transformer_engine.common.recipe.DelayedScaling"),
("py:class", "lightning.fabric.accelerators.xla.XLAAccelerator"),
("py:class", "lightning.fabric.loggers.csv_logs._ExperimentWriter"),
("py:class", "lightning.fabric.loggers.logger._DummyExperiment"),
("py:class", "lightning.fabric.plugins.precision.transformer_engine.TransformerEnginePrecision"),
("py:class", "lightning.fabric.plugins.precision.bitsandbytes.BitsandbytesPrecision"),
("py:class", "lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin"),
("py:func", "lightning.fabric.utilities.seed.seed_everything"),
("py:class", "lightning.fabric.utilities.types.LRScheduler"),
("py:class", "lightning.fabric.utilities.types.ReduceLROnPlateau"),
("py:class", "lightning.fabric.utilities.types.Steppable"),
("py:class", "lightning.fabric.wrappers._FabricOptimizer"),
("py:class", "lightning.fabric.utilities.throughput.Throughput"),
("py:func", "lightning.fabric.utilities.throughput.measure_flops"),
("py:class", "lightning.fabric.utilities.spike.SpikeDetection"),
("py:meth", "lightning.pytorch.Callback.on_exception"),
("py:class", "lightning.pytorch.LightningModule"),
("py:meth", "lightning.pytorch.LightningModule.on_train_epoch_end"),
("py:meth", "lightning.pytorch.LightningModule.on_validation_epoch_end"),
("py:meth", "lightning.pytorch.LightningModule.save_hyperparameters"),
("py:meth", "lightning.pytorch.LightningModule.test_step"),
("py:meth", "lightning.pytorch.LightningModule.training_step"),
("py:meth", "lightning.pytorch.LightningModule.validation_step"),
("py:obj", "lightning.pytorch.accelerators.MPSAccelerator"),
("py:meth", "lightning.pytorch.accelerators.accelerator.Accelerator.register_accelerators"),
("py:paramref", "lightning.pytorch.callbacks.Checkpoint._sphinx_paramlinks_save_top_k"),
("py:func", "lightning.pytorch.callbacks.RichProgressBar.configure_columns"),
("py:meth", "lightning.pytorch.callbacks.callback.Callback.on_load_checkpoint"),
("py:meth", "lightning.pytorch.callbacks.callback.Callback.on_save_checkpoint"),
("py:class", "lightning.pytorch.callbacks.checkpoint.Checkpoint"),
("py:meth", "lightning.pytorch.callbacks.progress.progress_bar.ProgressBar.get_metrics"),
("py:class", "lightning.pytorch.callbacks.progress.rich_progress.RichProgressBarTheme"),
("py:class", "lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm"),
("py:class", "lightning.pytorch.cli.ReduceLROnPlateau"),
("py:meth", "lightning.pytorch.core.LightningDataModule.setup"),
("py:meth", "lightning.pytorch.core.LightningModule.configure_model"),
("py:meth", "lightning.pytorch.core.LightningModule.save_hyperparameters"),
("py:meth", "lightning.pytorch.core.LightningModule.setup"),
("py:meth", "lightning.pytorch.core.hooks.ModelHooks.on_after_batch_transfer"),
("py:meth", "lightning.pytorch.core.hooks.ModelHooks.setup"),
("py:meth", "lightning.pytorch.core.hooks.ModelHooks.transfer_batch_to_device"),
("py:meth", "lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin.save_hyperparameters"),
("py:class", "lightning.pytorch.loggers.Logger"),
("py:func", "lightning.pytorch.loggers.logger.rank_zero_experiment"),
("py:class", "lightning.pytorch.plugins.environments.cluster_environment.ClusterEnvironment"),
("py:class", "lightning.pytorch.plugins.environments.slurm_environment.SLURMEnvironment"),
("py:class", "lightning.pytorch.plugins.io.wrapper._WrappingCheckpointIO"),
("py:func", "lightning.pytorch.seed_everything"),
("py:class", "lightning.pytorch.serve.servable_module.ServableModule"),
("py:class", "lightning.pytorch.serve.servable_module_validator.ServableModuleValidator"),
("py:mod", "lightning.pytorch.strategies"),
("py:class", "lightning.pytorch.strategies.SingleXLAStrategy"),
("py:meth", "lightning.pytorch.strategies.ddp.DDPStrategy.configure_ddp"),
("py:meth", "lightning.pytorch.strategies.ddp.DDPStrategy.setup_distributed"),
("py:meth", "lightning.pytorch.trainer.trainer.Trainer.lightning_module"),
("py:class", "lightning.pytorch.tuner.lr_finder._LRFinder"),
("py:class", "lightning.pytorch.utilities.CombinedLoader"),
("py:obj", "lightning.pytorch.utilities.deepspeed.ds_checkpoint_dir"),
("py:obj", "lightning.pytorch.utilities.memory.is_cuda_out_of_memory"),
("py:obj", "lightning.pytorch.utilities.memory.is_cudnn_snafu"),
("py:obj", "lightning.pytorch.utilities.memory.is_oom_error"),
("py:obj", "lightning.pytorch.utilities.memory.is_out_of_cpu_memory"),
("py:func", "lightning.pytorch.utilities.rank_zero.rank_zero_only"),
("py:class", "lightning.pytorch.utilities.types.LRSchedulerConfig"),
("py:class", "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig"),
("py:class", "lightning_habana.pytorch.plugins.precision.HPUPrecisionPlugin"),
("py:class", "lightning_habana.pytorch.strategies.HPUParallelStrategy"),
("py:class", "lightning_habana.pytorch.strategies.SingleHPUStrategy"),
("py:obj", "logger.experiment"),
("py:class", "mlflow.tracking.MlflowClient"),
("py:attr", "model"),
("py:meth", "move_data_to_device"),
("py:class", "neptune.Run"),
("py:class", "neptune.handler.Handler"),
("py:meth", "on_after_batch_transfer"),
("py:meth", "on_before_batch_transfer"),
("py:meth", "on_save_checkpoint"),
("py:meth", "optimizer_step"),
("py:class", "out_dict"),
("py:meth", "prepare_data"),
("py:class", "lightning.pytorch.callbacks.device_stats_monitor.DeviceStatsMonitor"),
("py:meth", "setup"),
("py:meth", "test_step"),
("py:meth", "toggle_optimizer"),
("py:class", "torch.ScriptModule"),
("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload"),
("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.MixedPrecision"),
("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy"),
("py:class", "torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler"),
("py:class", "torch.distributed.fsdp.wrap.ModuleWrapPolicy"),
("py:func", "torch.inference_mode"),
("py:meth", "torch.mean"),
("py:func", "torch.nn.Module.eval"),
("py:func", "torch.no_grad"),
("py:class", "torch.optim.lr_scheduler.LRScheduler"),
("py:meth", "torch.set_default_tensor_type"),
("py:class", "torch.utils.data.DistributedSampler"),
("py:class", "torch_xla.distributed.parallel_loader.MpDeviceLoader"),
("py:func", "torch_xla.distributed.xla_multiprocessing.spawn"),
("py:class", "torch._dynamo.OptimizedModule"),
("py:mod", "tqdm"),
("py:meth", "training_step"),
("py:meth", "transfer_batch_to_device"),
("py:class", "types.FrameType"),
("py:class", "typing.TypeGuard"),
("py:meth", "untoggle_optimizer"),
("py:meth", "validation_step"),
("py:class", "wandb.Artifact"),
("py:func", "wandb.init"),
("py:class", "wandb.sdk.lib.RunDisabled"),
("py:class", "wandb.wandb_run.Run"),

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

Expand Down
7 changes: 2 additions & 5 deletions pina/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
__all__ = [
'PINN', 'Trainer', 'LabelTensor', 'Plotter', 'Condition', 'Location',
'CartesianDomain'
'SolverInterface', 'Trainer', 'LabelTensor', 'Plotter', 'Condition'
]

from .meta import *
from .label_tensor import LabelTensor
from .solvers.pinn import PINN
from .solvers.solver import SolverInterface
from .trainer import Trainer
from .plotter import Plotter
from .condition import Condition
from .geometry import Location
from .geometry import CartesianDomain
4 changes: 2 additions & 2 deletions pina/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def dummy(a):

class Condition:
"""
The class `Condition` is used to represent the constraints (physical
The class ``Condition`` is used to represent the constraints (physical
equations, boundary conditions, etc.) that should be satisfied in the
problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.Abstract_Problem` object.
problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
Conditions can be specified in three ways:
1. By specifying the input and output points of the condition; in such a
Expand Down
2 changes: 1 addition & 1 deletion pina/geometry/ellipsoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, ellipsoid_dict, sample_surface=False):
`sample_surface=True` only samples on the ellipsoid surface
frontier are taken. If ``sample_surface=False`` only samples on
the ellipsoid interior are taken, defaults to ``False``.
:type sample_surface: bool, optional
:type sample_surface: bool
.. warning::
Sampling for dimensions greater or equal to 10 could result
Expand Down
6 changes: 3 additions & 3 deletions pina/geometry/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def is_inside(self, point, check_border=False):
Abstract method for checking if a point is inside the location. To be
implemented in the child class.
:param tensor point: A tensor point to be checked.
:param bool check_border: a boolean that determines whether the border
:param torch.Tensor point: A tensor point to be checked.
:param bool check_border: A boolean that determines whether the border
of the location is considered checked to be considered inside or
not. Defaults to False.
not. Defaults to ``False``.
"""
pass
24 changes: 12 additions & 12 deletions pina/label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def __init__(self, x, labels):
labels. Such labels uniquely identify the columns of the tensor,
allowing for an easier manipulation.
:param torch.Tensor x: the data tensor.
:param labels: the labels of the columns.
:type labels: str or iterable(str)
:param torch.Tensor x: The data tensor.
:param labels: The labels of the columns.
:type labels: str | list(str) | tuple(str)
:Example:
>>> from pina import LabelTensor
Expand Down Expand Up @@ -80,7 +80,7 @@ def labels(self):
@labels.setter
def labels(self, labels):
if len(labels) != self.shape[self.ndim - 1]: # small check
raise ValueError('the tensor has not the same number of columns of '
raise ValueError('The tensor has not the same number of columns of '
'the passed labels.')

self._labels = labels # assign the label
Expand All @@ -92,7 +92,7 @@ def clone(self, *args, **kwargs):
Clone the LabelTensor. For more details, see
:meth:`torch.Tensor.clone`.
:return: a copy of the tensor
:return: A copy of the tensor.
:rtype: LabelTensor
"""
try:
Expand Down Expand Up @@ -123,12 +123,12 @@ def select(self, *args, **kwargs):
def extract(self, label_to_extract):
"""
Extract the subset of the original tensor by returning all the columns
corresponding to the passed `label_to_extract`.
corresponding to the passed ``label_to_extract``.
:param label_to_extract: the label(s) to extract.
:type label_to_extract: str or iterable(str)
:raises TypeError: labels are not str
:raises ValueError: label to extract is not in the labels list
:param label_to_extract: The label(s) to extract.
:type label_to_extract: str | list(str) | tuple(str)
:raises TypeError: Labels are not ``str``.
:raises ValueError: Label to extract is not in the labels ``list``.
"""

if isinstance(label_to_extract, str):
Expand Down Expand Up @@ -158,9 +158,9 @@ def append(self, lt, mode='std'):
"""
Return a copy of the merged tensors.
:param LabelTensor lt: the tensor to merge.
:param LabelTensor lt: The tensor to merge.
:param str mode: {'std', 'first', 'cross'}
:return: the merged tensors
:return: The merged tensors.
:rtype: LabelTensor
"""
if set(self.labels).intersection(lt.labels):
Expand Down
Loading

0 comments on commit bb6874d

Please sign in to comment.