-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove dead code, add more example docs (#61)
* WIP: Remove dead code, add more example docs Signed-off-by: Fabrice Normandin <[email protected]> * Move documentation pages around a bit Signed-off-by: Fabrice Normandin <[email protected]> * Remove unused `networks.layers` module Signed-off-by: Fabrice Normandin <[email protected]> * Remove unused typing utils Signed-off-by: Fabrice Normandin <[email protected]> * Fix issue in autoref plugin Signed-off-by: Fabrice Normandin <[email protected]> * Improve the installation instructions docs Signed-off-by: Fabrice Normandin <[email protected]> * Fix pre-commit errors Signed-off-by: Fabrice Normandin <[email protected]> * Update conftest.py fixture dependency graph Signed-off-by: Fabrice Normandin <[email protected]> * [WIP] Move docs around quite a bit Signed-off-by: Fabrice Normandin <[email protected]> * Fix broken links in docs Signed-off-by: Fabrice Normandin <[email protected]> * Add a WIP page on the remote submitit launcher Signed-off-by: Fabrice Normandin <[email protected]> * Add a better description of the remote launcher Signed-off-by: Fabrice Normandin <[email protected]> * Add examples Signed-off-by: Fabrice Normandin <[email protected]> * Fix import error in algorithm_tests.py Signed-off-by: Fabrice Normandin <[email protected]> * Improve docs / docstrings / examples Signed-off-by: Fabrice Normandin <[email protected]> * Simplify main.py, add emojis Signed-off-by: Fabrice Normandin <[email protected]> --------- Signed-off-by: Fabrice Normandin <[email protected]>
- Loading branch information
Showing
35 changed files
with
479 additions
and
762 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,20 @@ | ||
* [Home](index.md) | ||
* [Intro](intro.md) | ||
* [Getting Started](install.md) | ||
* Features | ||
* [Intro](intro.md) | ||
* Features 🔥 | ||
* [Magic Config Schemas](features/auto_schema.md) | ||
* [Jax and Torch support with Lightning ⚡](features/jax.md) | ||
* [Launching Jobs on Remote Clusters](features/remote_slurm_launcher.md) | ||
* [Thorough automated testing on SLURM clusters](features/testing.md) | ||
* features/*.md | ||
* Reference | ||
* Reference 🤓 | ||
* reference/* | ||
* Examples | ||
* examples/* | ||
* Examples 🧪 | ||
* [Image Classification (⚡)](examples/supervised_learning.md) | ||
* [Image Classification ([jax](+⚡)](examples/jax_sl_example.md) | ||
* [NLP (🤗+⚡)](examples/nlp.md) | ||
* [RL (jax)](examples/jax_rl_example.md) | ||
* [Running sweeps](examples/sweeps.md) | ||
* [Profiling your code📎](examples/profiling.md) | ||
* [Related projects](related.md) | ||
* [Getting Help](help.md) | ||
* [Contributing](contributing.md) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,11 @@ | ||
# Examples | ||
|
||
TODOs: | ||
Here are some examples to get started with the template. | ||
<!-- | ||
## TODOs: | ||
- [ ] Show examples (that are also to be tested with doctest or similar) of how to add a new algo. | ||
- [ ] Show examples of how to add a new datamodule. | ||
- [ ] Add a link to the RL example once [#13](https://github.com/mila-iqia/ResearchTemplate/issues/13) is done. | ||
- [ ] Add a link to the NLP example once [#14](https://github.com/mila-iqia/ResearchTemplate/issues/14) is done. | ||
- [ ] Add an example of how to use Jax for the dataset/dataloading: | ||
- Either through an RL example, or with `tfds` in [#18](https://github.com/mila-iqia/ResearchTemplate/issues/18) | ||
|
||
## Simple run | ||
|
||
```bash | ||
python project/main.py algorithm=example datamodule=mnist network=fcnet | ||
``` | ||
|
||
## Running a Hyper-Parameter sweep on a SLURM cluster | ||
|
||
```bash | ||
python project/main.py experiment=cluster_sweep_example | ||
``` | ||
- [ ] Add an example of how to use Jax for the dataset/dataloading: --> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Jax + PyTorch-Lightning ⚡ | ||
|
||
## `JaxExample`: a LightningModule that trains a Jax network | ||
|
||
The [JaxExample][project.algorithms.jax_example.JaxExample] algorithm uses a network which is a [flax.linen.Module](https://flax.readthedocs.io/en/latest/). | ||
The network is wrapped with `torch_jax_interop.JaxFunction`, so that it can accept torch tensors as inputs, produces torch tensors as outputs, and the parameters are saved as as `torch.nn.Parameter`s (which use the same underlying memory as the jax arrays). | ||
In this example, the loss function and optimizers are in PyTorch, while the network forward and backward passes are written in Jax. | ||
|
||
The loss that is returned in the training step is used by Lightning in the usual way. The backward | ||
pass uses Jax to calculate the gradients, and the weights are updated by a PyTorch optimizer. | ||
|
||
!!! note | ||
You could also very well do both the forward **and** backward passes in Jax! To do this, [use the 'manual optimization' mode of PyTorch-Lightning](https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html) and perform the parameter updates yourself. For the rest of Lightning to work, just make sure to store the parameters as torch.nn.Parameters. An example of how to do this will be added shortly. | ||
|
||
|
||
|
||
!!! note "What about end-to-end training in Jax?" | ||
|
||
See the [Jax RL Example](../examples/jax_rl_example.md)! :smile: | ||
|
||
### Jax Network | ||
|
||
{{ inline('project.algorithms.jax_example.CNN') }} | ||
|
||
### Jax Algorithm | ||
|
||
{{ inline('project.algorithms.jax_example.JaxExample') }} | ||
|
||
### Configs | ||
|
||
#### JaxExample algorithm config | ||
|
||
{{ inline('project/configs/algorithm/jax_example.yaml') }} | ||
|
||
## Running the example | ||
|
||
```console | ||
$ python project/main.py algorithm=jax_example network=jax_cnn datamodule=cifar10 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# NLP (PyTorch) | ||
|
||
|
||
## Overview | ||
|
||
The [HFExample][project.algorithms.hf_example.HFExample] is a [LightningModule][lightning.pytorch.core.module.LightningModule] for a simple auto-regressive text generation task. | ||
|
||
It accepts a [HFDataModule][project.datamodules.text.HFDataModule] as input, along with a network. | ||
|
||
??? note "Click to show the code for HFExample" | ||
{{ inline('project.algorithms.hf_example.HFExample', 4) }} | ||
|
||
## Config files | ||
|
||
### Algorithm config | ||
|
||
??? note "Click to show the Algorithm config" | ||
Source: project/configs/algorithm/hf_example.yaml | ||
|
||
{{ inline('project/configs/algorithm/hf_example.yaml', 4) }} | ||
|
||
### Datamodule config | ||
|
||
??? note "Click to show the Datamodule config" | ||
Source: project/configs/datamodule/hf_text.yaml | ||
|
||
{{ inline('project/configs/datamodule/hf_text.yaml', 4) }} | ||
|
||
## Running the example | ||
|
||
Here is a configuration file that you can use to launch a simple experiment: | ||
|
||
??? note "Click to show the yaml config file" | ||
Source: project/configs/experiment/hf_example.yaml | ||
|
||
{{ inline('project/configs/experiment/hf_example.yaml', 4) }} | ||
|
||
You can use it like so: | ||
|
||
```console | ||
python project/main.py experiment=example | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Supervised Learning (PyTorch) | ||
|
||
The [ExampleAlgorithm][project.algorithms.ExampleAlgorithm] is a simple [LightningModule][lightning.pytorch.core.module.LightningModule] for image classification. | ||
|
||
??? note "Click to show the code for ExampleAlgorithm" | ||
{{ inline('project.algorithms.example.ExampleAlgorithm', 4) }} | ||
|
||
Here is a configuration file that you can use to launch a simple experiment: | ||
|
||
??? note "Click to show the yaml config file" | ||
{{ inline('project/configs/experiment/example.yaml', 4) }} | ||
|
||
You can use it like so: | ||
|
||
```console | ||
python project/main.py experiment=example | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
.md-grid { | ||
max-width: 100%; | ||
/* max-width: 100%; */ | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,49 @@ | ||
# Using Jax with PyTorch-Lightning | ||
|
||
You can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the nice stuff that comes from using PyTorch-Lightning. | ||
--- | ||
additional_python_references: | ||
- project.algorithms.jax_rl_example | ||
- project.algorithms.example | ||
- project.algorithms.jax_example | ||
- project.algorithms.hf_example | ||
- project.trainers.jax_trainer | ||
--- | ||
|
||
**How does this work?** | ||
Well, we use [torch-jax-interop](https://www.github.com/lebrice/torch_jax_interop), another package developed here at Mila, which allows easy interop between torch and jax code. See the readme on that repo for more details. | ||
# Using Jax with PyTorch-Lightning | ||
|
||
You can use Jax in your network or learning algorithm, for example in your forward / backward passes, to update parameters, etc. but not the training loop itself, since that is handled by the [lightning.Trainer][lightning.pytorch.trainer.trainer.Trainer]. | ||
There are lots of good reasons why you might want to let Lightning handle the training loop. | ||
which are very well described [here](https://lightning.ai/docs/pytorch/stable/). | ||
> 🔥 NOTE: This is a feature that is entirely unique to this template! 🔥 | ||
??? note "What about end-to-end training in Jax?" | ||
This template includes examples that use either Jax, PyTorch, or both! | ||
|
||
See the [Jax RL Example (coming soon!)](https://github.com/mila-iqia/ResearchTemplate/pull/55) | ||
| Example link | Reference | Framework | Lightning? | | ||
| ------------------------------------------------- | ------------------ | ----------- | ------------ | | ||
| [ExampleAlgorithm](../examples/jax_sl_example.md) | `ExampleAlgorithm` | Torch | yes | | ||
| [JaxExample](../examples/jax_sl_example.md) | `JaxExample` | Torch + Jax | yes | | ||
| [HFExample](../examples/nlp.md) | `HFExample` | Torch + 🤗 | yes | | ||
| [JaxRLExample](../examples/jax_rl_example.md) | `JaxRLExample` | Jax | no (almost!) | | ||
|
||
|
||
## `JaxExample`: a LightningModule that uses Jax | ||
In fact, here you can mix and match both Jax and Torch code. For example, you can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the nice stuff that comes from using PyTorch-Lightning. | ||
|
||
The [JaxExample][project.algorithms.jax_example.JaxExample] algorithm uses a network which is a [flax.linen.Module](https://flax.readthedocs.io/en/latest/). | ||
The network is wrapped with `torch_jax_interop.JaxFunction`, so that it can accept torch tensors as inputs, produces torch tensors as outputs, and the parameters are saved as as `torch.nn.Parameter`s (which use the same underlying memory as the jax arrays). | ||
In this example, the loss function and optimizers are in PyTorch, while the network forward and backward passes are written in Jax. | ||
??? note "**How does this work?**" | ||
Well, we use [torch-jax-interop](https://www.github.com/lebrice/torch_jax_interop), another package developed here at Mila 😎, that allows easy interop between torch and jax code. Feel free to take a look at it if you'd like to use it as part of your own project. 😁 | ||
|
||
The loss that is returned in the training step is used by Lightning in the usual way. The backward | ||
pass uses Jax to calculate the gradients, and the weights are updated by a PyTorch optimizer. | ||
|
||
!!! note | ||
You could also very well do both the forward **and** backward passes in Jax! To do this, [use the 'manual optimization' mode of PyTorch-Lightning](https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html) and perform the parameter updates yourself. For the rest of Lightning to work, just make sure to store the parameters as torch.nn.Parameters. An example of how to do this will be added shortly. | ||
|
||
### Jax Network | ||
## Using PyTorch-Lightning to train a Jax network | ||
|
||
{{ inline('project.algorithms.jax_example.CNN') }} | ||
If you'd like to use Jax in your network or learning algorithm, while keeping the same style of | ||
training loop as usual, you can! | ||
|
||
### Jax Algorithm | ||
- Use Jax for the forward / backward passes, the parameter updates, dataset preprocessing, etc. | ||
- Leave the training loop / callbacks / logging / checkpointing / etc to Lightning | ||
|
||
{{ inline('project.algorithms.jax_example.JaxExample') }} | ||
The [lightning.Trainer][lightning.pytorch.trainer.trainer.Trainer] will not be able to tell that you're using Jax! | ||
|
||
### Configs | ||
**Take a look at [this image classification example that uses a Jax network](../examples/jax_sl_example.md).** | ||
|
||
#### JaxExample algorithm config | ||
|
||
{{ inline('project/configs/algorithm/jax_example.yaml') }} | ||
## End-to-end training in Jax: the `JaxTrainer` | ||
|
||
## Running the example | ||
The `JaxTrainer`, used in the [Jax RL Example](../examples/jax_rl_example.md), follows a similar structure as the lightning Trainer. However, instead of training LightningModules, it trains `JaxModule`s. | ||
|
||
```console | ||
$ python project/main.py algorithm=jax_example network=jax_cnn datamodule=cifar10 | ||
``` | ||
The "algorithm" needs to match the `JaxModule` protocol: | ||
- `JaxModule.training_step`: train using a batch of data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Remote Slurm Submitit Launcher | ||
|
||
> 🔥 NOTE: This is a feature that is entirely unique to this template! 🔥 | ||
This template includes a custom submitit launcher, that can be used to launch jobs on *remote* slurm clusters. | ||
This allows you to develop code locally, and easily ship it to a different cluster. | ||
The only prerequisite is that you must have `ssh` access to the remote cluster. | ||
|
||
Under the hood, this uses a [custom `remote-slurm-executor` submitit plugin](https://github.com/lebrice/remote-slurm-executor). | ||
|
||
|
||
This feature allows you to launch jobs on remote slurm clusters using two config groups: | ||
|
||
- The `resources` config group is used to select the job resources: | ||
- `cpu`: CPU job | ||
- `gpu`: GPU job | ||
- The `cluster` config group controls where to run the job: | ||
- `current`: Run on the current cluster. Use this if you're already on a SLURM cluster (e.g. when using `mila code`). This uses the usual `submitit_slurm` launcher. | ||
- `mila`: Launches the job on the Mila cluster. | ||
- `narval`: Remotely launches the job on the Narval cluster | ||
- `cedar`: Remotely launches the job on the Cedar cluster | ||
- `beluga`: Remotely launches the job on the Beluga cluster | ||
|
||
|
||
## Examples | ||
|
||
This assumes that you've already setup SSH access to the clusters (for example using `mila init`). | ||
|
||
|
||
### Local machine -> Mila | ||
|
||
```bash | ||
python project/main.py experiment=example resources=gpu cluster=mila | ||
``` | ||
|
||
### Local machine -> DRAC cluster (narval) | ||
|
||
```bash | ||
python project/main.py experiment=example resources=gpu cluster=narval | ||
``` | ||
|
||
|
||
### Mila -> DRAC cluster (narval) | ||
|
||
This assumes that you've already setup SSH access from `mila` to the DRAC clusters. | ||
|
||
Note that command is about the same as [above](#local-machine---drac-cluster-narval) | ||
|
||
```bash | ||
python project/main.py experiment=example resources=gpu cluster=narval | ||
``` | ||
|
||
|
||
!!! warning | ||
|
||
If you want to launch jobs on a remote cluster, it is (currently) necessary to place the "resources" config **before** the "cluster" config on the command-line. | ||
|
||
|
||
## Launching jobs on the current SLURM cluster | ||
|
||
If you develop on a SLURM cluster, you can use the `cluster=current`, or simply omit the `cluster` config group and only use a config from the `resources` group. | ||
|
||
```bash | ||
(mila) $ python project/main.py experiment=example resources=gpu cluster=current | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.