Remove dead code, add more example docs (#61)
* WIP: Remove dead code, add more example docs

Signed-off-by: Fabrice Normandin <[email protected]>

* Move documentation pages around a bit

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove unused `networks.layers` module

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove unused typing utils

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix issue in autoref plugin

Signed-off-by: Fabrice Normandin <[email protected]>

* Improve the installation instructions docs

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix pre-commit errors

Signed-off-by: Fabrice Normandin <[email protected]>

* Update conftest.py fixture dependency graph

Signed-off-by: Fabrice Normandin <[email protected]>

* [WIP] Move docs around quite a bit

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix broken links in docs

Signed-off-by: Fabrice Normandin <[email protected]>

* Add a WIP page on the remote submitit launcher

Signed-off-by: Fabrice Normandin <[email protected]>

* Add a better description of the remote launcher

Signed-off-by: Fabrice Normandin <[email protected]>

* Add examples

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix import error in algorithm_tests.py

Signed-off-by: Fabrice Normandin <[email protected]>

* Improve docs / docstrings / examples

Signed-off-by: Fabrice Normandin <[email protected]>

* Simplify main.py, add emojis

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
lebrice authored Oct 25, 2024
1 parent 598627b commit 7b3fea2
Showing 35 changed files with 479 additions and 762 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
@@ -67,7 +67,7 @@ repos:
rev: 0.7.17
hooks:
- id: mdformat
exclude: "SUMMARY.md|testing.md|jax.md"
exclude: "docs/" # terrible, I know, but it's messing up everything with mkdocs fences!
args: ["--number"]
additional_dependencies:
- mdformat-gfm
@@ -77,6 +77,7 @@ repos:
- mdformat-config
- mdformat-black
# see https://github.com/KyleKing/mdformat-mkdocs
# Doesn't seem to work!
- mdformat-mkdocs[recommended]>=2.1.0
require_serial: true

19 changes: 13 additions & 6 deletions docs/SUMMARY.md
@@ -1,13 +1,20 @@
* [Home](index.md)
* [Intro](intro.md)
* [Getting Started](install.md)
* Features
* [Intro](intro.md)
* Features 🔥
* [Magic Config Schemas](features/auto_schema.md)
* [Jax and Torch support with Lightning ⚡](features/jax.md)
* [Launching Jobs on Remote Clusters](features/remote_slurm_launcher.md)
* [Thorough automated testing on SLURM clusters](features/testing.md)
* features/*.md
* Reference
* Reference 🤓
* reference/*
* Examples
* examples/*
* Examples 🧪
* [Image Classification (⚡)](examples/supervised_learning.md)
* [Image Classification (jax+⚡)](examples/jax_sl_example.md)
* [NLP (🤗+⚡)](examples/nlp.md)
* [RL (jax)](examples/jax_rl_example.md)
* [Running sweeps](examples/sweeps.md)
* [Profiling your code 📎](examples/profiling.md)
* [Related projects](related.md)
* [Getting Help](help.md)
* [Contributing](contributing.md)
19 changes: 4 additions & 15 deletions docs/examples/examples.md → docs/examples/index.md
@@ -1,22 +1,11 @@
# Examples

TODOs:
Here are some examples to get started with the template.
<!--
## TODOs:
- [ ] Show examples (that are also to be tested with doctest or similar) of how to add a new algo.
- [ ] Show examples of how to add a new datamodule.
- [ ] Add a link to the RL example once [#13](https://github.com/mila-iqia/ResearchTemplate/issues/13) is done.
- [ ] Add a link to the NLP example once [#14](https://github.com/mila-iqia/ResearchTemplate/issues/14) is done.
- [ ] Add an example of how to use Jax for the dataset/dataloading:
- Either through an RL example, or with `tfds` in [#18](https://github.com/mila-iqia/ResearchTemplate/issues/18)

## Simple run

```bash
python project/main.py algorithm=example datamodule=mnist network=fcnet
```

## Running a Hyper-Parameter sweep on a SLURM cluster

```bash
python project/main.py experiment=cluster_sweep_example
```
- [ ] Add an example of how to use Jax for the dataset/dataloading: -->
20 changes: 19 additions & 1 deletion docs/examples/jax_rl_example.md
@@ -4,8 +4,26 @@ additional_python_references:
- project.trainers.jax_trainer
---


# Reinforcement Learning (Jax)

This example follows the same structure as the other examples: an "algorithm" (in this case `JaxRLExample`) is trained with a "trainer".

However, there are some very important differences:

- There is no "datamodule".
- The "trainer" is a `JaxTrainer`, instead of a `lightning.Trainer`.
- The full training loop is written in Jax (see the sketch after this list).
- Some (but not all) PyTorch-Lightning callbacks can still be used with the `JaxTrainer`.
- The `JaxRLExample` class is an algorithm based on `rejax.PPO`.
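
To give a concrete feel for what "the full training loop is written in Jax" means, here is a tiny, self-contained sketch on a toy regression problem. The function and variable names are illustrative, and this is not the actual `JaxRLExample` / `JaxTrainer` code: the point is that the whole loop is a single jit-compiled pure function, with `jax.lax.scan` playing the role of the Python `for` loop.

```python
import jax
import jax.numpy as jnp

LR = 0.1
NUM_STEPS = 100


def loss_fn(params, x, y):
    """Mean squared error of a linear model (a toy stand-in for the real objective)."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def fit(params, x, y):
    """The *entire* training loop is one pure, jit-compiled function."""

    def step(params, _):
        grads = jax.grad(loss_fn)(params, x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
        return params, loss_fn(params, x, y)

    # `lax.scan` plays the role of the python `for` loop, but stays inside jit.
    return jax.lax.scan(step, params, xs=None, length=NUM_STEPS)


x = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.1
params, losses = fit({"w": jnp.zeros(3), "b": jnp.zeros(())}, x, y)
print(losses[-1])  # loss after the last step
```

Because the loop itself is traced and compiled, there is no Python overhead per step, which is exactly why an RL training loop benefits from being written this way.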


## JaxRLExample

The `JaxRLExample` class is an algorithm based on `rejax.PPO`, written entirely in Jax so that it can be trained with the `JaxTrainer`.

## JaxTrainer

The `JaxTrainer` follows a roughly similar structure to the `lightning.Trainer`:

- `JaxTrainer.fit` is called with a `JaxModule` to train the algorithm.
39 changes: 39 additions & 0 deletions docs/examples/jax_sl_example.md
@@ -0,0 +1,39 @@
# Jax + PyTorch-Lightning ⚡

## `JaxExample`: a LightningModule that trains a Jax network

The [JaxExample][project.algorithms.jax_example.JaxExample] algorithm uses a network which is a [flax.linen.Module](https://flax.readthedocs.io/en/latest/).
The network is wrapped with `torch_jax_interop.JaxFunction`, so that it can accept torch tensors as inputs and produce torch tensors as outputs, and its parameters are saved as `torch.nn.Parameter`s (which use the same underlying memory as the jax arrays).
In this example, the loss function and optimizers are in PyTorch, while the network forward and backward passes are written in Jax.

The loss that is returned in the training step is used by Lightning in the usual way. The backward
pass uses Jax to calculate the gradients, and the weights are updated by a PyTorch optimizer.
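
To demystify the wrapping a little, here is a conceptual, copy-based sketch of the same idea using a custom `torch.autograd.Function`: the forward pass calls a Jax function and the backward pass uses `jax.vjp`. This is only an illustration of the mechanism, with names chosen for the example; `torch_jax_interop` does this without round-trip copies and also handles parameters, devices, etc.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def wrap_jax_function(jax_fn):
    """Conceptual sketch: make a pure jax function differentiable by torch autograd."""

    class JaxBridge(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
            x_jax = jnp.asarray(x.detach().cpu().numpy())
            y_jax, vjp_fn = jax.vjp(jax_fn, x_jax)  # forward pass runs in jax
            ctx.vjp_fn = vjp_fn
            return torch.from_numpy(np.array(y_jax))

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
            # backward pass runs in jax; the gradient is handed back to torch
            (grad_x,) = ctx.vjp_fn(jnp.asarray(grad_output.detach().cpu().numpy()))
            return torch.from_numpy(np.array(grad_x))

    return JaxBridge.apply


f = wrap_jax_function(lambda x: jnp.sum(jnp.tanh(x) ** 2))
x = torch.randn(3, requires_grad=True)
f(x).backward()
print(x.grad)  # gradients computed by jax, usable by any torch optimizer
```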

!!! note
You could also very well do both the forward **and** backward passes in Jax! To do this, [use the 'manual optimization' mode of PyTorch-Lightning](https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html) and perform the parameter updates yourself. For the rest of Lightning to work, just make sure to store the parameters as torch.nn.Parameters. An example of how to do this will be added shortly.
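
Until that example is added, here is what the manual optimization mechanics look like on the Lightning side, shown with plain PyTorch. This is only an illustrative sketch with made-up module and layer names; in the all-Jax variant, the gradient computation inside `training_step` would be done with Jax rather than `manual_backward`.

```python
import lightning
import torch
from torch import nn
from torch.nn import functional as F


class ManualOptimizationExample(lightning.LightningModule):
    """Illustrative sketch of Lightning's manual optimization mode."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # turn off Lightning's automatic backward/step
        self.net = nn.Linear(32, 10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        optimizer = self.optimizers()
        optimizer.zero_grad()
        loss = F.cross_entropy(self.net(x), y)
        # In the Jax variant, you would compute gradients with Jax here and assign
        # them to the parameters' .grad attributes before calling optimizer.step().
        self.manual_backward(loss)
        optimizer.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```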



!!! note "What about end-to-end training in Jax?"

See the [Jax RL Example](../examples/jax_rl_example.md)! :smile:

### Jax Network

{{ inline('project.algorithms.jax_example.CNN') }}

### Jax Algorithm

{{ inline('project.algorithms.jax_example.JaxExample') }}

### Configs

#### JaxExample algorithm config

{{ inline('project/configs/algorithm/jax_example.yaml') }}

## Running the example

```console
$ python project/main.py algorithm=jax_example network=jax_cnn datamodule=cifar10
```
42 changes: 42 additions & 0 deletions docs/examples/nlp.md
@@ -0,0 +1,42 @@
# NLP (PyTorch)


## Overview

The [HFExample][project.algorithms.hf_example.HFExample] is a [LightningModule][lightning.pytorch.core.module.LightningModule] for a simple auto-regressive text generation task.

It accepts a [HFDataModule][project.datamodules.text.HFDataModule] as input, along with a network.

??? note "Click to show the code for HFExample"
{{ inline('project.algorithms.hf_example.HFExample', 4) }}
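
If you just want a rough mental model of the moving pieces, a stripped-down causal language-modelling module built directly on 🤗 `transformers` looks roughly like this. This is an illustration only, not the template's actual `HFExample`; the batch keys are assumed to come from a datamodule similar to `HFDataModule`.

```python
import lightning
import torch
from transformers import AutoModelForCausalLM


class TinyCausalLM(lightning.LightningModule):
    """Minimal auto-regressive text-generation module (illustrative sketch)."""

    def __init__(self, model_name: str = "gpt2", lr: float = 3e-5):
        super().__init__()
        self.save_hyperparameters()
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        # `batch` is assumed to contain tokenized text, e.g. input_ids,
        # attention_mask and labels, produced by the datamodule.
        outputs = self.model(**batch)
        self.log("train/loss", outputs.loss)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
```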

## Config files

### Algorithm config

??? note "Click to show the Algorithm config"
Source: project/configs/algorithm/hf_example.yaml

{{ inline('project/configs/algorithm/hf_example.yaml', 4) }}

### Datamodule config

??? note "Click to show the Datamodule config"
Source: project/configs/datamodule/hf_text.yaml

{{ inline('project/configs/datamodule/hf_text.yaml', 4) }}

## Running the example

Here is a configuration file that you can use to launch a simple experiment:

??? note "Click to show the yaml config file"
Source: project/configs/experiment/hf_example.yaml

{{ inline('project/configs/experiment/hf_example.yaml', 4) }}

You can use it like so:

```console
python project/main.py experiment=hf_example
```
17 changes: 17 additions & 0 deletions docs/examples/supervised_learning.md
@@ -0,0 +1,17 @@
# Supervised Learning (PyTorch)

The [ExampleAlgorithm][project.algorithms.ExampleAlgorithm] is a simple [LightningModule][lightning.pytorch.core.module.LightningModule] for image classification.

??? note "Click to show the code for ExampleAlgorithm"
{{ inline('project.algorithms.example.ExampleAlgorithm', 4) }}
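
For orientation, the essence of such an image-classification module is quite small. Here is an illustrative sketch with made-up names, not the actual `ExampleAlgorithm`, which also handles its configs, metrics, and initialization details:

```python
import lightning
import torch
from torch import nn
from torch.nn import functional as F


class SimpleClassifier(lightning.LightningModule):
    """Minimal image-classification module (illustrative sketch)."""

    def __init__(self, network: nn.Module, lr: float = 1e-3):
        super().__init__()
        self.network = network  # e.g. a CNN taking images and returning class logits
        self.lr = lr

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        images, labels = batch
        logits = self.network(images)
        loss = F.cross_entropy(logits, labels)
        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```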

Here is a configuration file that you can use to launch a simple experiment:

??? note "Click to show the yaml config file"
{{ inline('project/configs/experiment/example.yaml', 4) }}

You can use it like so:

```console
python project/main.py experiment=example
```
2 changes: 1 addition & 1 deletion docs/extra.css
@@ -1,3 +1,3 @@
.md-grid {
max-width: 100%;
/* max-width: 100%; */
}
60 changes: 31 additions & 29 deletions docs/features/jax.md
@@ -1,47 +1,49 @@
# Using Jax with PyTorch-Lightning

You can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the nice stuff that comes from using PyTorch-Lightning.
---
additional_python_references:
- project.algorithms.jax_rl_example
- project.algorithms.example
- project.algorithms.jax_example
- project.algorithms.hf_example
- project.trainers.jax_trainer
---

**How does this work?**
Well, we use [torch-jax-interop](https://www.github.com/lebrice/torch_jax_interop), another package developed here at Mila, which allows easy interop between torch and jax code. See the readme on that repo for more details.
# Using Jax with PyTorch-Lightning

You can use Jax in your network or learning algorithm, for example in your forward / backward passes, to update parameters, etc., but not the training loop itself, since that is handled by the [lightning.Trainer][lightning.pytorch.trainer.trainer.Trainer].
There are lots of good reasons why you might want to let Lightning handle the training loop, which are very well described [here](https://lightning.ai/docs/pytorch/stable/).
> 🔥 NOTE: This is a feature that is entirely unique to this template! 🔥
??? note "What about end-to-end training in Jax?"
This template includes examples that use either Jax, PyTorch, or both!

See the [Jax RL Example (coming soon!)](https://github.com/mila-iqia/ResearchTemplate/pull/55)
| Example link | Reference | Framework | Lightning? |
| ------------------------------------------------- | ------------------ | ----------- | ------------ |
| [ExampleAlgorithm](../examples/supervised_learning.md) | `ExampleAlgorithm` | Torch | yes |
| [JaxExample](../examples/jax_sl_example.md) | `JaxExample` | Torch + Jax | yes |
| [HFExample](../examples/nlp.md) | `HFExample` | Torch + 🤗 | yes |
| [JaxRLExample](../examples/jax_rl_example.md) | `JaxRLExample` | Jax | no (almost!) |


## `JaxExample`: a LightningModule that uses Jax
In fact, here you can mix and match both Jax and Torch code. For example, you can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the nice stuff that comes from using PyTorch-Lightning.

The [JaxExample][project.algorithms.jax_example.JaxExample] algorithm uses a network which is a [flax.linen.Module](https://flax.readthedocs.io/en/latest/).
The network is wrapped with `torch_jax_interop.JaxFunction`, so that it can accept torch tensors as inputs, produces torch tensors as outputs, and the parameters are saved as as `torch.nn.Parameter`s (which use the same underlying memory as the jax arrays).
In this example, the loss function and optimizers are in PyTorch, while the network forward and backward passes are written in Jax.
??? note "**How does this work?**"
    Well, we use [torch-jax-interop](https://www.github.com/lebrice/torch_jax_interop), another package developed here at Mila 😎, which allows easy interop between torch and jax code. Feel free to take a look at it if you'd like to use it as part of your own project. 😁

The loss that is returned in the training step is used by Lightning in the usual way. The backward
pass uses Jax to calculate the gradients, and the weights are updated by a PyTorch optimizer.

!!! note
You could also very well do both the forward **and** backward passes in Jax! To do this, [use the 'manual optimization' mode of PyTorch-Lightning](https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html) and perform the parameter updates yourself. For the rest of Lightning to work, just make sure to store the parameters as torch.nn.Parameters. An example of how to do this will be added shortly.

### Jax Network
## Using PyTorch-Lightning to train a Jax network

{{ inline('project.algorithms.jax_example.CNN') }}
If you'd like to use Jax in your network or learning algorithm, while keeping the same style of
training loop as usual, you can!

### Jax Algorithm
- Use Jax for the forward / backward passes, the parameter updates, dataset preprocessing, etc.
- Leave the training loop / callbacks / logging / checkpointing / etc to Lightning

{{ inline('project.algorithms.jax_example.JaxExample') }}
The [lightning.Trainer][lightning.pytorch.trainer.trainer.Trainer] will not be able to tell that you're using Jax!

### Configs
**Take a look at [this image classification example that uses a Jax network](../examples/jax_sl_example.md).**

#### JaxExample algorithm config

{{ inline('project/configs/algorithm/jax_example.yaml') }}
## End-to-end training in Jax: the `JaxTrainer`

## Running the example
The `JaxTrainer`, used in the [Jax RL Example](../examples/jax_rl_example.md), follows a similar structure to the lightning `Trainer`. However, instead of training LightningModules, it trains `JaxModule`s.

```console
$ python project/main.py algorithm=jax_example network=jax_cnn datamodule=cifar10
```
The "algorithm" needs to match the `JaxModule` protocol (sketched below):

- `JaxModule.training_step`: train using a batch of data
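
Here is a minimal sketch of what a protocol like this can look like. The argument and return types below are placeholders chosen for illustration, not the template's actual `JaxModule` definition:

```python
from typing import Any, Protocol

import jax


class JaxModuleLike(Protocol):
    """Sketch of a JaxModule-style protocol (argument/return types are placeholders)."""

    def training_step(
        self, batch: Any, state: Any, rng: jax.Array
    ) -> tuple[Any, jax.Array]:
        """Train on a batch of data; return the updated state and the loss for this step."""
        ...
```
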
65 changes: 65 additions & 0 deletions docs/features/remote_slurm_launcher.md
@@ -0,0 +1,65 @@
# Remote Slurm Submitit Launcher

> 🔥 NOTE: This is a feature that is entirely unique to this template! 🔥

This template includes a custom submitit launcher that can be used to launch jobs on *remote* SLURM clusters.
This allows you to develop code locally and easily ship it to a different cluster.
The only prerequisite is that you must have `ssh` access to the remote cluster.

Under the hood, this uses a [custom `remote-slurm-executor` submitit plugin](https://github.com/lebrice/remote-slurm-executor).
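
For context, this is roughly what submitit itself does when a Python function is launched as a SLURM job. This is a plain, local-cluster sketch using submitit's public API; the partition name and resource values are assumptions, and Hydra normally drives all of this for you through the launcher configs described below.

```python
import submitit


def train(learning_rate: float) -> float:
    # ... your training code runs inside the SLURM job ...
    return learning_rate * 2


executor = submitit.AutoExecutor(folder="logs/submitit")
executor.update_parameters(
    timeout_min=60,
    slurm_partition="main",  # assumption: the partition name depends on your cluster
    gpus_per_node=1,
    cpus_per_task=4,
)
job = executor.submit(train, 1e-3)  # submits an sbatch job
print(job.job_id)
print(job.result())  # waits for the job and fetches the return value
```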


This feature allows you to launch jobs on remote slurm clusters using two config groups:

- The `resources` config group is used to select the job resources:
- `cpu`: CPU job
- `gpu`: GPU job
- The `cluster` config group controls where to run the job:
- `current`: Run on the current cluster. Use this if you're already on a SLURM cluster (e.g. when using `mila code`). This uses the usual `submitit_slurm` launcher.
- `mila`: Launches the job on the Mila cluster.
- `narval`: Remotely launches the job on the Narval cluster.
- `cedar`: Remotely launches the job on the Cedar cluster.
- `beluga`: Remotely launches the job on the Beluga cluster.


## Examples

This assumes that you've already set up SSH access to the clusters (for example using `mila init`).


### Local machine -> Mila

```bash
python project/main.py experiment=example resources=gpu cluster=mila
```

### Local machine -> DRAC cluster (narval)

```bash
python project/main.py experiment=example resources=gpu cluster=narval
```


### Mila -> DRAC cluster (narval)

This assumes that you've already set up SSH access from `mila` to the DRAC clusters.

Note that the command is about the same as [above](#local-machine---drac-cluster-narval):

```bash
python project/main.py experiment=example resources=gpu cluster=narval
```


!!! warning

If you want to launch jobs on a remote cluster, it is (currently) necessary to place the "resources" config **before** the "cluster" config on the command-line.


## Launching jobs on the current SLURM cluster

If you develop on a SLURM cluster, you can use `cluster=current`, or simply omit the `cluster` config group and only use a config from the `resources` group.

```bash
(mila) $ python project/main.py experiment=example resources=gpu cluster=current
```
2 changes: 1 addition & 1 deletion docs/features/testing.md
@@ -25,7 +25,7 @@ This template comes with some [easy-to-use test suites](#test-suites) as well as
- [ ] Describe the Github Actions workflows that come with the template, and how to setup a self-hosted runner for template forks.
- [ ] Add links to relevant documentation -->

## :fire: Automated testing on SLURM clusters with GitHub CI
## Automated testing on SLURM clusters with GitHub CI

> 🔥 NOTE: This is a feature that is entirely unique to this template! 🔥