From 7b3fea2efd612a668f1eafa74e62bf518bcb1bc0 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 25 Oct 2024 11:30:02 -0400 Subject: [PATCH] Remove dead code, add more example docs (#61) * WIP: Remove dead code, add more example docs Signed-off-by: Fabrice Normandin * Move documentation pages around a bit Signed-off-by: Fabrice Normandin * Remove unused `networks.layers` module Signed-off-by: Fabrice Normandin * Remove unused typing utils Signed-off-by: Fabrice Normandin * Fix issue in autoref plugin Signed-off-by: Fabrice Normandin * Improve the installation instructions docs Signed-off-by: Fabrice Normandin * Fix pre-commit errors Signed-off-by: Fabrice Normandin * Update conftest.py fixture dependency graph Signed-off-by: Fabrice Normandin * [WIP] Move docs around quite a bit Signed-off-by: Fabrice Normandin * Fix broken links in docs Signed-off-by: Fabrice Normandin * Add a WIP page on the remote submitit launcher Signed-off-by: Fabrice Normandin * Add a better description of the remote launcher Signed-off-by: Fabrice Normandin * Add examples Signed-off-by: Fabrice Normandin * Fix import error in algorithm_tests.py Signed-off-by: Fabrice Normandin * Improve docs / docstrings / examples Signed-off-by: Fabrice Normandin * Simplify main.py, add emojis Signed-off-by: Fabrice Normandin --------- Signed-off-by: Fabrice Normandin --- .pre-commit-config.yaml | 3 +- docs/SUMMARY.md | 19 +- docs/examples/{examples.md => index.md} | 19 +- docs/examples/jax_rl_example.md | 20 +- docs/examples/jax_sl_example.md | 39 ++++ docs/examples/nlp.md | 42 ++++ docs/examples/supervised_learning.md | 17 ++ docs/extra.css | 2 +- docs/features/jax.md | 60 ++--- docs/features/remote_slurm_launcher.md | 65 ++++++ docs/features/testing.md | 2 +- docs/index.md | 118 +++++++--- docs/install.md | 64 ------ docs/intro.md | 38 +++- mkdocs.yml | 23 +- project/algorithms/example.py | 11 +- project/algorithms/testsuites/algorithm.py | 51 ----- .../algorithms/testsuites/algorithm_tests.py | 3 +- ...{albert-cola-glue.yaml => hf_example.yaml} | 0 project/conftest.py | 14 +- project/datamodules/datamodules_test.py | 4 +- .../image_classification/imagenet.py | 21 +- project/main.py | 88 ++++---- project/networks/fcnet.py | 4 +- project/networks/layers/__init__.py | 10 - project/networks/layers/layers.py | 212 ------------------ project/networks/layers/sequential.py | 97 -------- project/utils/__init__.py | 6 - project/utils/autoref_plugin.py | 2 + project/utils/autoref_plugin_test.py | 1 + project/utils/testutils.py | 140 ++---------- project/utils/typing_utils/__init__.py | 22 +- project/utils/typing_utils/protocols.py | 14 +- pyproject.toml | 2 +- uv.lock | 8 +- 35 files changed, 479 insertions(+), 762 deletions(-) rename docs/examples/{examples.md => index.md} (56%) create mode 100644 docs/examples/jax_sl_example.md create mode 100644 docs/examples/nlp.md create mode 100644 docs/examples/supervised_learning.md create mode 100644 docs/features/remote_slurm_launcher.md delete mode 100644 docs/install.md delete mode 100644 project/algorithms/testsuites/algorithm.py rename project/configs/experiment/{albert-cola-glue.yaml => hf_example.yaml} (100%) delete mode 100644 project/networks/layers/__init__.py delete mode 100644 project/networks/layers/layers.py delete mode 100644 project/networks/layers/sequential.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f658459d..0c524b33 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -67,7 +67,7 @@ repos: rev: 0.7.17 hooks: - id: 
mdformat - exclude: "SUMMARY.md|testing.md|jax.md" + exclude: "docs/" # terrible, I know, but it's messing up everything with mkdocs fences! args: ["--number"] additional_dependencies: - mdformat-gfm @@ -77,6 +77,7 @@ repos: - mdformat-config - mdformat-black # see https://github.com/KyleKing/mdformat-mkdocs + # Doesn't seem to work! - mdformat-mkdocs[recommended]>=2.1.0 require_serial: true
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 37616f04..e0143d93 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -1,13 +1,20 @@ * [Home](index.md) - * [Intro](intro.md) - * [Getting Started](install.md) - * Features +* [Intro](intro.md) + * Features 🔥 * [Magic Config Schemas](features/auto_schema.md) + * [Jax and Torch support with Lightning ⚡](features/jax.md) + * [Launching Jobs on Remote Clusters](features/remote_slurm_launcher.md) + * [Thorough automated testing on SLURM clusters](features/testing.md) * features/*.md - * Reference + * Reference 🤓 * reference/* - * Examples - * examples/* + * Examples 🧪 + * [Image Classification (⚡)](examples/supervised_learning.md) + * [Image Classification (jax+⚡)](examples/jax_sl_example.md) + * [NLP (🤗+⚡)](examples/nlp.md) + * [RL (jax)](examples/jax_rl_example.md) + * [Running sweeps](examples/sweeps.md) + * [Profiling your code 📎](examples/profiling.md) * [Related projects](related.md) * [Getting Help](help.md) * [Contributing](contributing.md)
diff --git a/docs/examples/examples.md b/docs/examples/index.md similarity index 56% rename from docs/examples/examples.md rename to docs/examples/index.md index 04c4ee19..5f934629 100644 --- a/docs/examples/examples.md +++ b/docs/examples/index.md @@ -1,22 +1,11 @@ # Examples -TODOs: +Here are some examples to get started with the template. +
diff --git a/docs/examples/jax_rl_example.md b/docs/examples/jax_rl_example.md index 7d130243..5feb6d1c 100644 --- a/docs/examples/jax_rl_example.md +++ b/docs/examples/jax_rl_example.md @@ -4,8 +4,26 @@ additional_python_references: - project.trainers.jax_trainer --- + # Reinforcement Learning (Jax) +This example follows the same structure as the other examples: +- An "algorithm" (in this case `JaxRLExample`) is trained with a "trainer"; + + +However, there are some very important differences: +- There is no "datamodule". +- The "Trainer" is a `JaxTrainer`, instead of a `lightning.Trainer`. + - The full training loop is written in Jax; + - Some (but not all) PyTorch-Lightning callbacks can still be used with the JaxTrainer; +- The `JaxRLExample` class is an algorithm based on rejax.PPO. + + +## JaxRLExample + +The `JaxRLExample` class is a `JaxModule` whose training logic is based on `rejax.PPO`. + ## JaxTrainer -The `JaxTrainer` is +The `JaxTrainer` follows a roughly similar structure to the `lightning.Trainer`: +- `JaxTrainer.fit` is called with a `JaxModule` to train the algorithm.
diff --git a/docs/examples/jax_sl_example.md b/docs/examples/jax_sl_example.md new file mode 100644 index 00000000..4a78650f --- /dev/null +++ b/docs/examples/jax_sl_example.md @@ -0,0 +1,39 @@ +# Jax + PyTorch-Lightning ⚡ + +## `JaxExample`: a LightningModule that trains a Jax network + +The [JaxExample][project.algorithms.jax_example.JaxExample] algorithm uses a network which is a [flax.linen.Module](https://flax.readthedocs.io/en/latest/). +The network is wrapped with `torch_jax_interop.JaxFunction`, so that it can accept torch tensors as inputs, produce torch tensors as outputs, and have its parameters saved as `torch.nn.Parameter`s (which use the same underlying memory as the jax arrays).
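To make the wrapping idea concrete, here is a minimal, hypothetical sketch of a small `flax.linen` network like the one this example uses. Only the flax/jax side is shown and runnable; the wrapping step is left as a comment, since the exact `torch_jax_interop` call signature should be taken from that package's documentation rather than from this sketch.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyCNN(nn.Module):
    """Small flax network, standing in for the CNN used in this example."""

    num_classes: int = 10

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        x = nn.Conv(features=16, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        return nn.Dense(self.num_classes)(x)


# Initialize the jax parameters from a dummy batch (NHWC layout).
params = TinyCNN().init(jax.random.PRNGKey(0), jnp.zeros((1, 32, 32, 3)))

# In the template, a module like this is then wrapped (via `torch_jax_interop`)
# so its forward pass can be called with torch tensors and its parameters are
# exposed to Lightning as `torch.nn.Parameter`s, as described above.
```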
+In this example, the loss function and optimizers are in PyTorch, while the network forward and backward passes are written in Jax. + +The loss that is returned in the training step is used by Lightning in the usual way. The backward +pass uses Jax to calculate the gradients, and the weights are updated by a PyTorch optimizer. + +!!! note + You could also very well do both the forward **and** backward passes in Jax! To do this, [use the 'manual optimization' mode of PyTorch-Lightning](https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html) and perform the parameter updates yourself. For the rest of Lightning to work, just make sure to store the parameters as torch.nn.Parameters. An example of how to do this will be added shortly. + + + +!!! note "What about end-to-end training in Jax?" + + See the [Jax RL Example](../examples/jax_rl_example.md)! :smile: + +### Jax Network + +{{ inline('project.algorithms.jax_example.CNN') }} + +### Jax Algorithm + +{{ inline('project.algorithms.jax_example.JaxExample') }} + +### Configs + +#### JaxExample algorithm config + +{{ inline('project/configs/algorithm/jax_example.yaml') }} + +## Running the example + +```console +$ python project/main.py algorithm=jax_example network=jax_cnn datamodule=cifar10 +``` diff --git a/docs/examples/nlp.md b/docs/examples/nlp.md new file mode 100644 index 00000000..15915af4 --- /dev/null +++ b/docs/examples/nlp.md @@ -0,0 +1,42 @@ +# NLP (PyTorch) + + +## Overview + +The [HFExample][project.algorithms.hf_example.HFExample] is a [LightningModule][lightning.pytorch.core.module.LightningModule] for a simple auto-regressive text generation task. + +It accepts a [HFDataModule][project.datamodules.text.HFDataModule] as input, along with a network. + +??? note "Click to show the code for HFExample" + {{ inline('project.algorithms.hf_example.HFExample', 4) }} + +## Config files + +### Algorithm config + +??? note "Click to show the Algorithm config" + Source: project/configs/algorithm/hf_example.yaml + + {{ inline('project/configs/algorithm/hf_example.yaml', 4) }} + +### Datamodule config + +??? note "Click to show the Datamodule config" + Source: project/configs/datamodule/hf_text.yaml + + {{ inline('project/configs/datamodule/hf_text.yaml', 4) }} + +## Running the example + +Here is a configuration file that you can use to launch a simple experiment: + +??? note "Click to show the yaml config file" + Source: project/configs/experiment/hf_example.yaml + + {{ inline('project/configs/experiment/hf_example.yaml', 4) }} + +You can use it like so: + +```console +python project/main.py experiment=example +``` diff --git a/docs/examples/supervised_learning.md b/docs/examples/supervised_learning.md new file mode 100644 index 00000000..842b8cc9 --- /dev/null +++ b/docs/examples/supervised_learning.md @@ -0,0 +1,17 @@ +# Supervised Learning (PyTorch) + +The [ExampleAlgorithm][project.algorithms.ExampleAlgorithm] is a simple [LightningModule][lightning.pytorch.core.module.LightningModule] for image classification. + +??? note "Click to show the code for ExampleAlgorithm" + {{ inline('project.algorithms.example.ExampleAlgorithm', 4) }} + +Here is a configuration file that you can use to launch a simple experiment: + +??? 
note "Click to show the yaml config file" + {{ inline('project/configs/experiment/example.yaml', 4) }} + +You can use it like so: + +```console +python project/main.py experiment=example +```
diff --git a/docs/extra.css b/docs/extra.css index c0fb3942..74db49f5 100644 --- a/docs/extra.css +++ b/docs/extra.css @@ -1,3 +1,3 @@ .md-grid { - max-width: 100%; + /* max-width: 100%; */ }
diff --git a/docs/features/jax.md b/docs/features/jax.md index 3938129d..bd7edfab 100644 --- a/docs/features/jax.md +++ b/docs/features/jax.md @@ -1,47 +1,49 @@ -# Using Jax with PyTorch-Lightning - -You can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the nice stuff that comes from using PyTorch-Lightning. +--- +additional_python_references: + - project.algorithms.jax_rl_example + - project.algorithms.example + - project.algorithms.jax_example + - project.algorithms.hf_example + - project.trainers.jax_trainer +--- -**How does this work?** -Well, we use [torch-jax-interop](https://www.github.com/lebrice/torch_jax_interop), another package developed here at Mila, which allows easy interop between torch and jax code. See the readme on that repo for more details. +# Using Jax with PyTorch-Lightning -You can use Jax in your network or learning algorithm, for example in your forward / backward passes, to update parameters, etc. but not the training loop itself, since that is handled by the [lightning.Trainer][lightning.pytorch.trainer.trainer.Trainer]. -There are lots of good reasons why you might want to let Lightning handle the training loop. -which are very well described [here](https://lightning.ai/docs/pytorch/stable/). +> 🔥 NOTE: This is a feature that is entirely unique to this template! 🔥 -??? note "What about end-to-end training in Jax?" +This template includes examples that use either Jax, PyTorch, or both! - See the [Jax RL Example (coming soon!)](https://github.com/mila-iqia/ResearchTemplate/pull/55) +| Example link | Reference | Framework | Lightning? | +| ------------------------------------------------- | ------------------ | ----------- | ------------ | +| [ExampleAlgorithm](../examples/supervised_learning.md) | `ExampleAlgorithm` | Torch | yes | +| [JaxExample](../examples/jax_sl_example.md) | `JaxExample` | Torch + Jax | yes | +| [HFExample](../examples/nlp.md) | `HFExample` | Torch + 🤗 | yes | +| [JaxRLExample](../examples/jax_rl_example.md) | `JaxRLExample` | Jax | no (almost!) | -## `JaxExample`: a LightningModule that uses Jax +In fact, here you can mix and match both Jax and Torch code. For example, you can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the nice stuff that comes from using PyTorch-Lightning. -The [JaxExample][project.algorithms.jax_example.JaxExample] algorithm uses a network which is a [flax.linen.Module](https://flax.readthedocs.io/en/latest/). -The network is wrapped with `torch_jax_interop.JaxFunction`, so that it can accept torch tensors as inputs, produces torch tensors as outputs, and the parameters are saved as as `torch.nn.Parameter`s (which use the same underlying memory as the jax arrays). -In this example, the loss function and optimizers are in PyTorch, while the network forward and backward passes are written in Jax. +??? note "**How does this work?**" + Well, we use [torch-jax-interop](https://www.github.com/lebrice/torch_jax_interop), another package developed here at Mila 😎, that allows easy interop between torch and jax code. Feel free to take a look at it if you'd like to use it as part of your own project. 😁
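To give a concrete sense of what this interop means in practice, here is a tiny, illustrative sketch of zero-copy exchange between torch tensors and jax arrays through the DLPack protocol, which is the kind of conversion such interop typically builds on. This is not the template's own code; `torch-jax-interop` handles conversions like this (plus parameter handling) for you.

```python
import jax.numpy as jnp
import torch

# A tensor created on the torch side...
torch_tensor = torch.arange(4, dtype=torch.float32)

# ...can be viewed as a jax array without copying (DLPack protocol),
jax_array = jnp.from_dlpack(torch_tensor)

# ...and a jax array can likewise be viewed as a torch tensor.
back_to_torch = torch.from_dlpack(jax_array)

print(jax_array, back_to_torch)
```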
-The loss that is returned in the training step is used by Lightning in the usual way. The backward -pass uses Jax to calculate the gradients, and the weights are updated by a PyTorch optimizer. -!!! note - You could also very well do both the forward **and** backward passes in Jax! To do this, [use the 'manual optimization' mode of PyTorch-Lightning](https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html) and perform the parameter updates yourself. For the rest of Lightning to work, just make sure to store the parameters as torch.nn.Parameters. An example of how to do this will be added shortly. -### Jax Network +## Using PyTorch-Lightning to train a Jax network -{{ inline('project.algorithms.jax_example.CNN') }} +If you'd like to use Jax in your network or learning algorithm, while keeping the same style of +training loop as usual, you can! -### Jax Algorithm +- Use Jax for the forward / backward passes, the parameter updates, dataset preprocessing, etc. +- Leave the training loop / callbacks / logging / checkpointing / etc. to Lightning -{{ inline('project.algorithms.jax_example.JaxExample') }} +The [lightning.Trainer][lightning.pytorch.trainer.trainer.Trainer] will not be able to tell that you're using Jax! -### Configs +**Take a look at [this image classification example that uses a Jax network](../examples/jax_sl_example.md).** -#### JaxExample algorithm config -{{ inline('project/configs/algorithm/jax_example.yaml') }} +## End-to-end training in Jax: the `JaxTrainer` -## Running the example +The `JaxTrainer`, used in the [Jax RL Example](../examples/jax_rl_example.md), follows a similar structure to the lightning Trainer. However, instead of training LightningModules, it trains `JaxModule`s. -```console -$ python project/main.py algorithm=jax_example network=jax_cnn datamodule=cifar10 -``` +The "algorithm" needs to match the `JaxModule` protocol: +- `JaxModule.training_step`: train using a batch of data
diff --git a/docs/features/remote_slurm_launcher.md b/docs/features/remote_slurm_launcher.md new file mode 100644 index 00000000..fe56a376 --- /dev/null +++ b/docs/features/remote_slurm_launcher.md @@ -0,0 +1,65 @@ +# Remote Slurm Submitit Launcher + +> 🔥 NOTE: This is a feature that is entirely unique to this template! 🔥 + +This template includes a custom submitit launcher that can be used to launch jobs on *remote* slurm clusters. +This allows you to develop code locally and easily ship it to a different cluster. +The only prerequisite is that you must have `ssh` access to the remote cluster. + +Under the hood, this uses a [custom `remote-slurm-executor` submitit plugin](https://github.com/lebrice/remote-slurm-executor). + + +This feature allows you to launch jobs on remote slurm clusters using two config groups: + +- The `resources` config group is used to select the job resources: + - `cpu`: CPU job + - `gpu`: GPU job +- The `cluster` config group controls where to run the job: + - `current`: Run on the current cluster. Use this if you're already on a SLURM cluster (e.g. when using `mila code`). This uses the usual `submitit_slurm` launcher. + - `mila`: Launches the job on the Mila cluster.
+ - `narval`: Remotely launches the job on the Narval cluster + - `cedar`: Remotely launches the job on the Cedar cluster + - `beluga`: Remotely launches the job on the Beluga cluster + + +## Examples + +This assumes that you've already setup SSH access to the clusters (for example using `mila init`). + + +### Local machine -> Mila + +```bash +python project/main.py experiment=example resources=gpu cluster=mila +``` + +### Local machine -> DRAC cluster (narval) + +```bash +python project/main.py experiment=example resources=gpu cluster=narval +``` + + +### Mila -> DRAC cluster (narval) + +This assumes that you've already setup SSH access from `mila` to the DRAC clusters. + +Note that command is about the same as [above](#local-machine---drac-cluster-narval) + +```bash +python project/main.py experiment=example resources=gpu cluster=narval +``` + + +!!! warning + + If you want to launch jobs on a remote cluster, it is (currently) necessary to place the "resources" config **before** the "cluster" config on the command-line. + + +## Launching jobs on the current SLURM cluster + +If you develop on a SLURM cluster, you can use the `cluster=current`, or simply omit the `cluster` config group and only use a config from the `resources` group. + +```bash +(mila) $ python project/main.py experiment=example resources=gpu cluster=current +``` diff --git a/docs/features/testing.md b/docs/features/testing.md index d4b55b44..8e621fd1 100644 --- a/docs/features/testing.md +++ b/docs/features/testing.md @@ -25,7 +25,7 @@ This template comes with some [easy-to-use test suites](#test-suites) as well as - [ ] Describe the Github Actions workflows that come with the template, and how to setup a self-hosted runner for template forks. - [ ] Add links to relevant documentation --> -## :fire: Automated testing on SLURM clusters with GitHub CI +## Automated testing on SLURM clusters with GitHub CI > πŸ”₯ NOTE: This is a feature that is entirely unique to this template! πŸ”₯ diff --git a/docs/index.md b/docs/index.md index 2f6b8f56..a8527800 100644 --- a/docs/index.md +++ b/docs/index.md @@ -6,7 +6,7 @@ [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/mila-iqia/ResearchTemplate#license) !!! note "Work-in-Progress" - Please note: This is a Work-in-Progress. The goal is to make a first release by the end of summer 2024. + Please note: This is a Work-in-Progress. The goal is to make a first release by the end of fall 2024. This is a research project template. It is meant to be a starting point for ML researchers at [Mila](https://mila.quebec/en). @@ -18,10 +18,10 @@ For more context, see [this introduction to the project.](intro.md). --- - [Get started quickly](install.md) with [a single installation script](#) and get up + [Get started quickly](#starting-a-new-project) with [a single installation script](#) and get up and running in minutes - [:octicons-arrow-right-24: Getting started](install.md) + [:octicons-arrow-right-24: Getting started](#starting-a-new-project) - :test_tube:{ .lg .middle } __Well-tested, robust codebase__ @@ -49,7 +49,7 @@ For more context, see [this introduction to the project.](intro.md). 1. 
The source code for the example is available [here](https://github.com/mila-iqia/ResearchTemplate/blob/master/project/algorithms/example.py) - [:octicons-arrow-right-24: Check out the examples here](examples/examples.md) + [:octicons-arrow-right-24: Check out the examples here](examples/index.md) -## Project layout +## Developing inside a container (advanced) -``` -pyproject.toml # Project metadata and dependencies -project/ - main.py # main entry-point - algorithms/ # learning algorithms - datamodules/ # datasets, processing and loading - networks/ # Neural networks used by algorithms - configs/ # configuration files -docs/ # documentation -conftest.py # Test fixtures and utilities -``` +This repo provides a [Devcontainer](https://code.visualstudio.com/docs/remote/containers) configuration for [Visual Studio Code](https://code.visualstudio.com/) to use a Docker container as a pre-configured development environment. This avoids the struggle of setting up a development environment and makes it reproducible and consistent. + +If that sounds useful to you, we recommend you first make yourself familiar with the [container tutorials](https://code.visualstudio.com/docs/remote/containers-tutorial). The devcontainer.json file assumes that you have a GPU locally by default. If not, you can simply comment out the "--gpus" flag in the `.devcontainer/devcontainer.json` file. + + +1. Set up Docker on your local machine + + On a Linux machine where you have root access, you can install Docker using the following commands: + + ```bash + curl -fsSL https://get.docker.com -o get-docker.sh + sudo sh get-docker.sh + ``` + + On Windows or Mac, follow [these installation instructions](https://code.visualstudio.com/docs/remote/containers#_installation). + +2. (optional) Install the [nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) to use your local machine's GPU(s). + +3. Install the [Dev Containers extension](vscode:extension/ms-vscode-remote.remote-containers) for Visual Studio Code. + +4. When opening the repository in Visual Studio Code, you should be prompted to reopen the repository in a container: + + ![VsCode popup image](https://github.com/mila-iqia/ResearchTemplate/assets/13387299/37d00ce7-1214-44b2-b1d6-411ee286999f) + + Alternatively, you can open the command palette (Ctrl+Shift+P) and select `Dev Containers: Rebuild and Reopen in Container`.
diff --git a/docs/install.md b/docs/install.md deleted file mode 100644 index e23aa6c0..00000000 --- a/docs/install.md +++ /dev/null @@ -1,64 +0,0 @@ -# Installation instructions - -There are two ways to install this project - -1. Using [uv](https://docs.astral.sh/uv/) -2. Using a development container (recommended if you are able to install Docker on your machine) - -## Installation - -1. Clone the repository: - - ```bash - git clone https://www.github.com/mila-iqia/ResearchTemplate - cd ResearchTemplate - ``` - -2. Installing dependencies - - You can install the package using `pip install -e .`, but we recommend using [uv](https://docs.astral.sh/uv/) - package manager. This makes it easier to switch python versions and to add or change the dependencies later on. - - 1. On your machine: - - ```console - curl -LsSf https://astral.sh/uv/install.sh | sh - source ~/.cargo/env - uv sync # Creates a virtual environment and installs dependencies in it. - ``` - - 2.
On the Mila cluster: - - If you're on the `mila` cluster, you can run this setup script (on a *compute* node): - - ```console - # Get a compute node to run an interactive job: - salloc --gres=gpu:1 --cpus-per-task=4 --mem=16G --time=1:00:00 - # Run the installation script. - scripts/mila_setup.sh - ``` - -## Using a development container - -This repo provides a [Devcontainer](https://code.visualstudio.com/docs/remote/containers) configuration for [Visual Studio Code](https://code.visualstudio.com/) to use a Docker container as a pre-configured development environment. This avoids struggles setting up a development environment and makes them reproducible and consistent. Make yourself familiar with the [container tutorials](https://code.visualstudio.com/docs/remote/containers-tutorial) if you want to use them. In order to use GPUs, you can enable them within the `.devcontainer/devcontainer.json` file. - -1. Setup Docker on your local machine - - On an Linux machine where you have root access, you can install Docker using the following commands: - - ```bash - curl -fsSL https://get.docker.com -o get-docker.sh - sudo sh get-docker.sh - ``` - - On Windows or Mac, follow [these installation instructions](https://code.visualstudio.com/docs/remote/containers#_installation) - -2. (optional) Install the [nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) to use your local machine's GPU(s). - -3. Install the [Dev Containers extension](vscode:extension/ms-vscode-remote.remote-containers) for Visual Studio Code. - -4. When opening repository in Visual Studio Code, you should be prompted to reopen the repository in a container: - - ![VsCode popup image](https://github.com/mila-iqia/ResearchTemplate/assets/13387299/37d00ce7-1214-44b2-b1d6-411ee286999f) - - Alternatively, you can open the command palette (Ctrl+Shift+P) and select `Dev Containers: Rebuild and Reopen in Container`. diff --git a/docs/intro.md b/docs/intro.md index f0781dea..5b6b9304 100644 --- a/docs/intro.md +++ b/docs/intro.md @@ -1,8 +1,8 @@ -# Introduction +# Why use this template? -## Why should you use this template? -### Why should you use *a* template in the first place? + +## Why should you use *a* template in the first place? For many good reasons, which are very well described [here in a similar project](https://cookiecutter-data-science.drivendata.org/why/)! 😊 @@ -13,7 +13,7 @@ Other good reads: - [https://12factor.net/](https://12factor.net/) - [https://github.com/ashleve/lightning-hydra-template/tree/main?tab=readme-ov-file#main-ideas](https://github.com/ashleve/lightning-hydra-template/tree/main?tab=readme-ov-file#main-ideas) -### Why should you use *this* template (instead of another)? +## Why should you use *this* template (instead of another)? You are welcome (and encouraged) to use other similar templates which, at the time of writing this, have significantly better documentation. However, there are several advantages to using this particular template: @@ -29,10 +29,26 @@ You are welcome (and encouraged) to use other similar templates which, at the ti This template is aimed for ML researchers that run their jobs on SLURM clusters. The target audience is researchers and students at [Mila](https://mila.quebec). This template should still be useful for others outside of Mila that use PyTorch-Lightning and Hydra. 
-## Main concepts - -### Datamodule - -### Network - -### Algorithm +## Project layout + +``` +pyproject.toml # Project metadata and dependencies +project/ + main.py # main entry-point + algorithms/ # learning algorithms + datamodules/ # datasets, processing and loading + networks/ # Neural networks used by algorithms + configs/ # configuration files +docs/ # documentation +conftest.py # Test fixtures and utilities +``` + +## Libraries used + +This project makes use of the following libraries: + +- [Hydra](https://hydra.cc/) is used to configure the project. It allows you to define configuration files and override them from the command line. +- [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) is used to as the training framework. It provides a high-level interface to organize ML research code. + - πŸ”₯ Please note: You can also use [Jax](https://jax.readthedocs.io/en/latest/) with this repo, as described in the [Jax example](features/jax.md) πŸ”₯ +- [Weights & Biases](https://wandb.ai) is used to log metrics and visualize results. +- [pytest](https://docs.pytest.org/en/stable/) is used for testing. diff --git a/mkdocs.yml b/mkdocs.yml index 84743597..5fec699a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -12,10 +12,18 @@ extra_css: theme: name: material features: + - toc.follow + - toc.integrate - navigation.instant + # - navigation.indexes + # - navigation.path + # - navigation.expand + # - navigation.tabs + # - navigation.tabs.sticky - navigation.instant.prefetch - navigation.instant.preview - content.code.copy + - navigation.tracking palette: # Palette toggle for automatic mode - media: "(prefers-color-scheme)" @@ -61,6 +69,11 @@ markdown_extensions: - name: mermaid class: mermaid format: !!python/name:pymdownx.superfences.fence_code_format + - pymdownx.tabbed: + alternate_style: true + slugify: !!python/object/apply:pymdownx.slugs.slugify + kwds: + case: lower plugins: - search @@ -88,6 +101,7 @@ plugins: - https://docs.python.org/3/objects.inv - https://pytorch.org/docs/stable/objects.inv - https://jax.readthedocs.io/en/latest/objects.inv + - https://mit-ll-responsible-ai.github.io/hydra-zen/objects.inv options: docstring_style: google members_order: source @@ -106,15 +120,10 @@ plugins: video_controls: True css_style: width: "100%" + + # - pymdownx.details # todo: take a look at https://github.com/drivendataorg/cookiecutter-data-science/blob/master/docs/mkdocs.yml # - admonition -# - pymdownx.details -# - pymdownx.superfences -# - pymdownx.tabbed: -# alternate_style: true -# slugify: !!python/object/apply:pymdownx.slugs.slugify -# kwds: -# case: lower # - tables # - toc: # toc_depth: 2 diff --git a/project/algorithms/example.py b/project/algorithms/example.py index aa15d8a6..fc11ff0f 100644 --- a/project/algorithms/example.py +++ b/project/algorithms/example.py @@ -13,8 +13,8 @@ import torch from hydra_zen.typing import Builds, PartialBuilds -from lightning import LightningModule from lightning.pytorch.callbacks.callback import Callback +from lightning.pytorch.core import LightningModule from torch import Tensor from torch.nn import functional as F from torch.optim.optimizer import Optimizer @@ -51,7 +51,8 @@ def __init__( datamodule: Object used to load train/val/test data. See the lightning docs for [LightningDataModule][lightning.pytorch.core.datamodule.LightningDataModule] for more info. - network: The config of the network to instantiate and train. + network: + The config of the network to instantiate and train. optimizer: The config for the Optimizer. 
Instantiating this will return a function \ (a [functools.partial][]) that will create the Optimizer given the hyper-parameters. init_seed: The seed to use when initializing the weights of the network. @@ -89,6 +90,7 @@ def __init__( _ = self.network(self.example_input_array) def forward(self, input: Tensor) -> Tensor: + """Forward pass of the network.""" logits = self.network(input) return logits @@ -116,6 +118,10 @@ def shared_step( return {"loss": loss, "logits": logits, "y": y} def configure_optimizers(self): + """Creates the optimizers. + + See [`lightning.pytorch.core.LightningModule.configure_optimizers`][] for more information. + """ # Instantiate the optimizer config into a functools.partial object. optimizer_partial = instantiate(self.optimizer_config) # Call the functools.partial object, passing the parameters as an argument. @@ -124,6 +130,7 @@ def configure_optimizers(self): return optimizer def configure_callbacks(self) -> Sequence[Callback] | Callback: + """Creates callbacks to be used by default during training.""" return [ ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes) ] diff --git a/project/algorithms/testsuites/algorithm.py b/project/algorithms/testsuites/algorithm.py deleted file mode 100644 index e2e84a65..00000000 --- a/project/algorithms/testsuites/algorithm.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Protocol, TypedDict - -import torch -from lightning import LightningDataModule, LightningModule, Trainer -from torch import Tensor -from typing_extensions import NotRequired, TypeVar - -from project.utils.typing_utils import PyTree -from project.utils.typing_utils.protocols import DataModule, Module - - -class StepOutputDict(TypedDict, total=False): - """A dictionary that shows what an Algorithm can output from - `training/validation/test_step`.""" - - loss: NotRequired[Tensor | float] - """Optional loss tensor that can be returned by those methods.""" - - -BatchType = TypeVar("BatchType", bound=PyTree[torch.Tensor], contravariant=True) -StepOutputType = TypeVar("StepOutputType", bound=StepOutputDict, covariant=True) - - -class Algorithm(Module, Protocol[BatchType, StepOutputType]): - """Protocol that adds more type information to the `lightning.LightningModule` class. - - This adds some type information on top of the LightningModule class, namely: - - `BatchType`: The type of batch that is produced by the dataloaders of the datamodule - - `StepOutputType`, the output type created by the step methods. - - The networks themselves are created separately and passed as a constructor argument. This is - meant to make it easier to compare different learning algorithms on the same network - architecture. 
- """ - - datamodule: LightningDataModule | DataModule[BatchType] - network: Module - - def __init__( - self, - *, - datamodule: DataModule[BatchType], - network: Module, - ): - super().__init__() - self.datamodule = datamodule - self.network = network - self.trainer: Trainer - - training_step = LightningModule.training_step - # validation_step = LightningModule.validation_step diff --git a/project/algorithms/testsuites/algorithm_tests.py b/project/algorithms/testsuites/algorithm_tests.py index a0d45328..fea4825f 100644 --- a/project/algorithms/testsuites/algorithm_tests.py +++ b/project/algorithms/testsuites/algorithm_tests.py @@ -20,7 +20,6 @@ from project.configs.config import Config from project.experiment import instantiate_algorithm -from project.utils.testutils import ParametrizedFixture from project.utils.typing_utils import PyTree, is_sequence_of from project.utils.typing_utils.protocols import DataModule @@ -48,7 +47,7 @@ class LearningAlgorithmTests(Generic[AlgorithmType], ABC): See the [project.algorithms.example_test][] module for an example. """ - algorithm_config: ParametrizedFixture[str] + # algorithm_config: ParametrizedFixture[str] def test_initialization_is_deterministic( self, diff --git a/project/configs/experiment/albert-cola-glue.yaml b/project/configs/experiment/hf_example.yaml similarity index 100% rename from project/configs/experiment/albert-cola-glue.yaml rename to project/configs/experiment/hf_example.yaml diff --git a/project/conftest.py b/project/conftest.py index d365fc2b..c0903f65 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -27,34 +27,28 @@ datamodule_config[ datamodule_config ] -- 'datamodule=A' --> command_line_arguments -network_config[ - network_config: -] -- 'network=B' --> command_line_arguments algorithm_config[ algorithm_config -] -- 'algorithm=C' --> command_line_arguments +] -- 'algorithm=B' --> command_line_arguments overrides[ overrides ] -- 'seed=123' --> command_line_arguments command_line_arguments[ command_line_arguments -] -- load configs for 'datamodule=A network=B algorithm=C seed=123' --> experiment_dictconfig +] -- load configs for 'datamodule=A algorithm=B seed=123' --> experiment_dictconfig experiment_dictconfig[ experiment_dictconfig ] -- instantiate objects from configs --> experiment_config experiment_config[ experiment_config -] --> datamodule & network & algorithm +] --> datamodule & algorithm datamodule[ datamodule ] --> algorithm -network[ - network -] --> algorithm algorithm[ algorithm ] -- is used by --> some_test -algorithm & network & datamodule -- is used by --> some_other_test +algorithm & datamodule -- is used by --> some_other_test ``` """ diff --git a/project/datamodules/datamodules_test.py b/project/datamodules/datamodules_test.py index c0461661..9017889b 100644 --- a/project/datamodules/datamodules_test.py +++ b/project/datamodules/datamodules_test.py @@ -14,7 +14,7 @@ ImageClassificationDataModule, ) from project.datamodules.vision import VisionDataModule -from project.utils.testutils import run_for_all_datamodules +from project.utils.testutils import run_for_all_configs_in_group from project.utils.typing_utils import is_sequence_of @@ -37,7 +37,7 @@ ], ) @pytest.mark.parametrize("overrides", ["algorithm=no_op"], indirect=True) -@run_for_all_datamodules() +@run_for_all_configs_in_group(group_name="datamodule") def test_first_batch( datamodule: LightningDataModule, request: pytest.FixtureRequest, diff --git a/project/datamodules/image_classification/imagenet.py 
b/project/datamodules/image_classification/imagenet.py index b32debec..bcfaa3e1 100644 --- a/project/datamodules/image_classification/imagenet.py +++ b/project/datamodules/image_classification/imagenet.py @@ -83,17 +83,16 @@ def __init__( ): """Creates an ImageNet datamodule (doesn't load or prepare the dataset yet). - Parameters - ---------- - data_dir: path to the imagenet dataset file - val_split: save `val_split`% of the training data *of each class* for validation. - image_size: final image size - num_workers: how many data workers - batch_size: batch_size - shuffle: If true shuffles the data every epoch - pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before \ - returning them - drop_last: If true drops the last incomplete batch + Parameters: + data_dir: path to the imagenet dataset file + val_split: save `val_split`% of the training data *of each class* for validation. + image_size: final image size + num_workers: how many data workers + batch_size: batch_size + shuffle: If true shuffles the data every epoch + pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before \ + returning them + drop_last: If true drops the last incomplete batch """ self.image_size = image_size super().__init__( diff --git a/project/main.py b/project/main.py index b0fa985f..5ceb2650 100644 --- a/project/main.py +++ b/project/main.py @@ -8,17 +8,18 @@ from pathlib import Path import hydra +import lightning import omegaconf import rich -import wandb from lightning import LightningDataModule from omegaconf import DictConfig from project.configs import add_configs_to_hydra_store from project.configs.config import Config -from project.experiment import Experiment, setup_experiment +from project.trainers.jax_trainer import JaxTrainer from project.utils.env_vars import REPO_ROOTDIR from project.utils.hydra_utils import resolve_dictconfig +from project.utils.typing_utils.protocols import DataModule from project.utils.utils import print_config logger = get_logger(__name__) @@ -55,57 +56,72 @@ def main(dict_config: DictConfig) -> dict: config: Config = resolve_dictconfig(dict_config) - experiment: Experiment = setup_experiment(config) + from project.experiment import ( + instantiate_algorithm, + instantiate_datamodule, + instantiate_trainer, + seed_rng, + setup_logging, + ) + + experiment_config = config + setup_logging(experiment_config) + seed_rng(experiment_config) + + trainer = instantiate_trainer(experiment_config) + + datamodule = instantiate_datamodule(experiment_config.datamodule) + + algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule) + # experiment: Experiment = setup_experiment(config) + + import wandb + if wandb.run: wandb.run.config.update({k: v for k, v in os.environ.items() if k.startswith("SLURM")}) wandb.run.config.update( omegaconf.OmegaConf.to_container(dict_config, resolve=False, throw_on_missing=True) ) - metric_name, objective, _metrics = run(experiment) - assert objective is not None - return dict(name=metric_name, type="objective", value=objective) - # return {metric_name: objective} - - -def run(experiment: Experiment) -> tuple[str, float | None, dict]: - """Run the experiment: training followed by evaluation. - - Returns the metrics of the evaluation. - """ - # Train the model using the dataloaders of the datamodule: # The Algorithm gets to "wrap" the datamodule if it wants. 
This might be useful in the # case of RL, where we need to set the actor to use in the environment, as well as # potentially adding Wrappers on top of the environment, or having a replay buffer, etc. # TODO: Add ckpt_path argument to resume a training run. - datamodule = getattr(experiment.algorithm, "datamodule", experiment.datamodule) + datamodule = getattr(algorithm, "datamodule", datamodule) if datamodule is None: - # todo: missing `rng` argument. from project.trainers.jax_trainer import JaxTrainer - if isinstance(experiment.trainer, JaxTrainer): + if isinstance(trainer, JaxTrainer): import jax.random - experiment.trainer.fit(experiment.algorithm, rng=jax.random.key(0)) + trainer.fit(algorithm, rng=jax.random.key(0)) # type: ignore else: - experiment.trainer.fit(experiment.algorithm) + trainer.fit(algorithm) else: assert isinstance(datamodule, LightningDataModule) - experiment.trainer.fit( - experiment.algorithm, + trainer.fit( + algorithm, datamodule=datamodule, ) - metric_name, error, metrics = evaluation(experiment) + metric_name, error, _metrics = evaluation( + trainer=trainer, datamodule=datamodule, algorithm=algorithm + ) + if wandb.run: wandb.finish() - return metric_name, error, metrics + + assert error is not None + return dict(name=metric_name, type="objective", value=error) + # return {metric_name: objective} -def evaluation(experiment: Experiment) -> tuple[str, float | None, dict]: +def evaluation( + trainer: JaxTrainer | lightning.Trainer, datamodule: DataModule, algorithm +) -> tuple[str, float | None, dict]: """Return the classification error. By default, if validation is to be performed, returns the validation error. Returns the @@ -115,14 +131,14 @@ def evaluation(experiment: Experiment) -> tuple[str, float | None, dict]: # TODO Probably log the hydra config with something like this: # exp.trainer.logger.log_hyperparams() # When overfitting on a single batch or only training, we return the train error. - if (experiment.trainer.limit_val_batches == experiment.trainer.limit_test_batches == 0) or ( - experiment.trainer.overfit_batches == 1 # type: ignore + if (trainer.limit_val_batches == trainer.limit_test_batches == 0) or ( + trainer.overfit_batches == 1 # type: ignore ): # We want to report the training error. 
metrics = { - **experiment.trainer.logged_metrics, - **experiment.trainer.callback_metrics, - **experiment.trainer.progress_bar_metrics, + **trainer.logged_metrics, + **trainer.callback_metrics, + **trainer.progress_bar_metrics, } rich.print(metrics) if "train/accuracy" in metrics: @@ -141,18 +157,14 @@ def evaluation(experiment: Experiment) -> tuple[str, float | None, dict]: f"Here are the available metric names:\n" f"{list(metrics.keys())}" ) - assert isinstance(experiment.datamodule, LightningDataModule) + assert isinstance(datamodule, LightningDataModule) - if experiment.trainer.limit_val_batches != 0: - results = experiment.trainer.validate( - model=experiment.algorithm, datamodule=experiment.datamodule - ) + if trainer.limit_val_batches != 0: + results = trainer.validate(model=algorithm, datamodule=datamodule) results_type = "val" else: warnings.warn(RuntimeWarning("About to use the test set for evaluation!")) - results = experiment.trainer.test( - model=experiment.algorithm, datamodule=experiment.datamodule - ) + results = trainer.test(model=algorithm, datamodule=datamodule) results_type = "test" if results is None: diff --git a/project/networks/fcnet.py b/project/networks/fcnet.py index 549d8e59..9830a294 100644 --- a/project/networks/fcnet.py +++ b/project/networks/fcnet.py @@ -6,8 +6,6 @@ import pydantic from torch import nn -from project.networks.layers.layers import Flatten - class FcNet(nn.Sequential): @pydantic.dataclasses.dataclass @@ -66,7 +64,7 @@ def __init__( ): block_layers = [] if block_index == 0: - block_layers.append(Flatten()) + block_layers.append(nn.Flatten()) if in_dims is None: block_layers.append(nn.LazyLinear(out_dims, bias=self.hparams.use_bias)) diff --git a/project/networks/layers/__init__.py b/project/networks/layers/__init__.py deleted file mode 100644 index 79402cd5..00000000 --- a/project/networks/layers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .layers import Branch, Flatten, Lambda, Merge -from .sequential import Sequential - -__all__ = [ - "Flatten", - "Lambda", - "Branch", - "Merge", - "Sequential", -] diff --git a/project/networks/layers/layers.py b/project/networks/layers/layers.py deleted file mode 100644 index a82fd1a0..00000000 --- a/project/networks/layers/layers.py +++ /dev/null @@ -1,212 +0,0 @@ -"""Simple layers you might find useful when creating new networks.""" - -from __future__ import annotations - -import functools -import operator -import typing -from collections.abc import Callable -from logging import getLogger as get_logger -from typing import Any - -import torch -from torch import Tensor, nn -from typing_extensions import ParamSpec - -from project.utils.typing_utils import Module, OutT, T, is_sequence_of - -P = ParamSpec("P", default=[Tensor]) - -logger = get_logger(__name__) - - -class Flatten(nn.Flatten): - def forward(self, input: Tensor): - # NOTE: The input Should have at least 2 dimensions for `nn.Flatten` to work, but it isn't - # the case with a single observation from a single environment. - if input.ndim <= 1: - return input - if input.is_nested: - return torch.nested.as_nested_tensor( - [input_i.reshape([input_i.shape[0], -1]) for input_i in input.unbind()] - ) - return super().forward(input) - - -class Lambda(nn.Module, Module[..., OutT]): - """A simple nn.Module wrapping a function. - - Any positional or keyword arguments passed to the constructor are saved into a `args` and - `kwargs` attribute. During the forward pass, these arguments are then bound to the function `f` - using a `functools.partial`. 
Any additional arguments to the `forward` method are then passed - to the partial. - """ - - __match_args__ = ("f",) - - def __init__( - self, - f: Callable[..., OutT], - *args: Tensor | nn.Module | Any, - **kwargs: Tensor | nn.Module | Any, - ): - super().__init__() - self.f: Callable[..., OutT] = f - self.args: nn.ParameterList | nn.ModuleList | tuple[Any, ...] - if not args: - self.args = args - elif is_sequence_of(args, Tensor): - self.args = nn.ParameterList( - arg - if isinstance(arg, nn.Parameter) - else nn.Parameter(arg, requires_grad=arg.requires_grad) - for arg in args - ) - elif is_sequence_of(args, nn.Module): - self.args = nn.ModuleList(args) - else: - self.args = args - # raise NotImplementedError(f"Args need to either be tensors or modules, not {args}") - self.kwargs: nn.ParameterDict | nn.ModuleDict | dict[str, Any] - if not kwargs: - self.kwargs = kwargs - elif is_sequence_of(kwargs.values(), Tensor): - self.kwargs = nn.ParameterDict( - {k: nn.Parameter(v, requires_grad=v.requires_grad) for k, v in kwargs.items()} # type: ignore - ) - elif is_sequence_of(kwargs.values(), nn.Module): - self.kwargs = nn.ModuleDict(kwargs) # type: ignore - else: - self.kwargs = kwargs - # raise NotImplementedError(f"kwargs need to either be tensors or modules, not {kwargs}") - - def forward(self, *args, **kwargs) -> OutT: - f = functools.partial(self.f, *self.args, **self.kwargs) - return f(*args, **kwargs) - - def __getattr__(self, name: str) -> Any: - if name in ["f", "args", "kwargs"]: - return super().__getattr__(name) - if name in (kwargs := self.kwargs): - logger.debug(f"Getting {name} from kwargs: {kwargs}") - return kwargs[name] - return super().__getattr__(name) - - def extra_repr(self) -> str: - # TODO: Redo, now that we have `args` and `kwargs` in containers. - func_message: str = "" - - if isinstance(self.f, nn.Module): - func_message = "" - elif isinstance(self.f, functools.partial): - assert not self.args and not self.kwargs - partial_args: list[str] = [] - if not isinstance(self.f.func, nn.Module): - partial_args.append(f"{self.f.func.__module__}.{self.f.func.__name__}") - else: - partial_args.append(repr(self.f.func)) - partial_args.extend(repr(x) for x in self.f.args) - partial_args.extend(f"{k}={v!r}" for (k, v) in self.f.keywords.items()) - - partial_qualname = type(self.f).__qualname__ - if type(self.f).__module__ == "functools": - partial_qualname = f"functools.{partial_qualname}" - func_message = f"f={partial_qualname}(" + ", ".join(partial_args) + ")" - elif hasattr(self.f, "__module__") and hasattr(self.f, "__name__"): - module = self.f.__module__ - if module != "builtins": - module += "." 
- func_message = f"f={module}{self.f.__name__}" - else: - func_message = f"f={self.f}" - - args_message = "" - if isinstance(self.args, nn.ParameterList | nn.ModuleList): - args_message = "" - elif self.args: - args_message = ", ".join(f"{arg!r}" for (arg) in self.args) - - kwargs_message = "" - if isinstance(self.kwargs, nn.ParameterDict | nn.ModuleDict): - kwargs_message = "" - elif self.kwargs: - kwargs_message = ", ".join(f"{k}={v!r}" for (k, v) in self.kwargs.items()) - - message = "" - if func_message: - message += func_message - if args_message: - message += (", " if message else "") + args_message - if kwargs_message: - message += (", " if message else "") + kwargs_message - return message - - if typing.TYPE_CHECKING: - __call__ = forward - - -class Branch(nn.Module, Module[P, dict[str, T]]): - """Module that executes each branch and returns a dictionary with the results of each.""" - - def __init__(self, **named_branches: Module[P, T]) -> None: - super().__init__() - self.named_branches = named_branches - self.named_branches = nn.ModuleDict(self.named_branches) # type: ignore - - def forward(self, *args: P.args, **kwargs: P.kwargs) -> dict[str, T]: - outputs: dict[str, T] = {} - # note: could potentially have each branch on a different cuda device? - # or maybe we could put each branch in a different cuda stream? - for name, branch in self.named_branches.items(): - branch_output = branch(*args, **kwargs) - outputs[name] = branch_output - return outputs - - if typing.TYPE_CHECKING: - __call__ = forward - - -class Merge(nn.Module, Module[[tuple[Tensor, ...] | dict[str, Tensor]], OutT]): - """Unpacks the output of the previous block (Branch) before it is fed to the wrapped module.""" - - __match_args__ = ("f",) - - def __init__(self, f: Module[..., OutT]) -> None: - """Unpacks the output of a previous block before it is fed to `f`.""" - super().__init__() - self.f = f - - def forward(self, packed_inputs: tuple[Tensor, ...] 
| dict[str, Tensor]) -> OutT: - if isinstance(packed_inputs, tuple | list): - return self.f(*packed_inputs) # type: ignore - else: - return self.f(**packed_inputs) # type: ignore - - if typing.TYPE_CHECKING: - __call__ = forward - - -class Sample(Lambda, Module[[torch.distributions.Distribution], Tensor]): - """Layer that samples from a distribution.""" - - def __init__(self, differentiable: bool = False) -> None: - super().__init__(f=operator.methodcaller("rsample" if differentiable else "sample")) - self._differentiable = differentiable - - @property - def differentiable(self) -> bool: - return self._differentiable - - @differentiable.setter - def differentiable(self, value: bool) -> None: - self._differentiable = value - self.f = operator.methodcaller("rsample" if value else "sample") - - def forward(self, dist: torch.distributions.Distribution) -> Tensor: - return super().forward(dist) - - def extra_repr(self) -> str: - return f"differentiable={self.differentiable}" - - if typing.TYPE_CHECKING: - __call__ = forward diff --git a/project/networks/layers/sequential.py b/project/networks/layers/sequential.py deleted file mode 100644 index b4427257..00000000 --- a/project/networks/layers/sequential.py +++ /dev/null @@ -1,97 +0,0 @@ -from __future__ import annotations - -from collections import OrderedDict -from collections.abc import Iterator, Sequence -from typing import Any, overload - -from torch import ( - Tensor, # noqa: F401 - nn, -) -from torch._jit_internal import _copy_to_script_wrapper -from typing_extensions import TypeVar - -from project.utils.typing_utils import Module - -ModuleType = TypeVar("ModuleType", bound=Module[..., Any], default=Module[[Tensor], Tensor]) - - -class Sequential(nn.Sequential, Sequence[ModuleType]): - # Small typing fixes for torch.nn.Sequential - - _modules: dict[str, ModuleType] - - @overload - def __init__(self, *args: ModuleType) -> None: ... - - @overload - def __init__(self, **kwargs: ModuleType) -> None: ... - - @overload - def __init__(self, arg: dict[str, ModuleType]) -> None: ... - - def __init__(self, *args, **kwargs): - if args: - assert not kwargs, "can only use *args or **kwargs, not both" - if len(args) == 1 and isinstance(args[0], dict): - new_args = (OrderedDict(args[0]),) - else: - new_args = [] - for arg in args: - if not isinstance(arg, nn.Module) and callable(arg): - from project.algorithms.common.layers import Lambda - - arg = Lambda(arg) - new_args.append(arg) - args = new_args - - if kwargs: - assert not args, "can only use *args or **kwargs, not both" - - from project.algorithms.common.layers import Lambda - - new_kwargs = {} - for name, module in kwargs.items(): - if not isinstance(module, nn.Module) and callable(module): - from project.algorithms.common.layers import Lambda - - module = Lambda(module) - new_kwargs[name] = module - kwargs = new_kwargs - - args = (OrderedDict(kwargs),) - - super().__init__(*args) - self._modules - - @overload - def __getitem__(self, idx: int) -> ModuleType: ... - - @overload - def __getitem__(self, idx: slice) -> Sequential[ModuleType]: ... - - @_copy_to_script_wrapper - def __getitem__(self, idx: int | slice) -> Sequential[ModuleType] | ModuleType: - if isinstance(idx, slice): - # NOTE: Fixing this here, subclass constructors shouldn't be called on getitem with - # slice. 
- return Sequential(OrderedDict(list(self._modules.items())[idx])) - else: - return self._get_item_by_idx(self._modules.values(), idx) - - def __iter__(self) -> Iterator[ModuleType]: - return super().__iter__() # type: ignore - - def __setitem__(self, idx: int, module: ModuleType) -> None: - # Violates the LSP, but eh. - return super().__setitem__(idx, module) - - def forward(self, *args, **kwargs): - out = None - for i, module in enumerate(self): - if i == 0: - out = module(*args, **kwargs) # type: ignore - else: - out = module(out) # type: ignore - assert out is not None - return out diff --git a/project/utils/__init__.py b/project/utils/__init__.py index bb9141c8..e69de29b 100644 --- a/project/utils/__init__.py +++ b/project/utils/__init__.py @@ -1,6 +0,0 @@ -# Import this patch for https://github.com/mit-ll-responsible-ai/hydra-zen/issues/705 to make sure that it gets applied. -from .utils import default_device - -__all__ = [ - "default_device", -] diff --git a/project/utils/autoref_plugin.py b/project/utils/autoref_plugin.py index 0abbd948..e3d69379 100644 --- a/project/utils/autoref_plugin.py +++ b/project/utils/autoref_plugin.py @@ -114,6 +114,8 @@ def on_page_markdown( matches = re.findall(r"`([^`]+)`", line) for match in matches: thing_name = match + if any(char in thing_name for char in ["/", " ", "-"]): + continue if thing_name in known_object_names: # References like `JaxTrainer` (which are in a module that we're aware of). thing = known_objects_for_this_module[known_object_names.index(thing_name)] diff --git a/project/utils/autoref_plugin_test.py b/project/utils/autoref_plugin_test.py index 1cc4e382..a0504337 100644 --- a/project/utils/autoref_plugin_test.py +++ b/project/utils/autoref_plugin_test.py @@ -32,6 +32,7 @@ ), ("`Trainer`", "[`Trainer`][lightning.pytorch.trainer.trainer.Trainer]"), # since `Trainer` is in the `known_things` list, we add the proper ref. + ("`.devcontainer/devcontainer.json`", "`.devcontainer/devcontainer.json`") ], ) def test_autoref_plugin(input: str, expected: str): diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 01033943..3d31606b 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -5,18 +5,14 @@ import itertools import os import typing -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from logging import getLogger as get_logger -from typing import Any, Generic, TypeVar import pytest -import torch import torchvision.models -from torch import nn from project.datamodules.image_classification.fashion_mnist import FashionMNISTDataModule from project.datamodules.image_classification.mnist import MNISTDataModule -from project.datamodules.vision import VisionDataModule from project.utils.env_vars import NETWORK_DIR from project.utils.hydra_config_utils import ( get_all_configs_in_group, @@ -86,97 +82,10 @@ ) }, } +"""Dict with some default marks to add to tests when some config combinations are present. - -def parametrized_fixture(name: str, values: Sequence, ids=None, **kwargs): - """Small helper function that creates a parametrized pytest fixture for the given values. - - NOTE: When writing a fixture in a test class, use `ParametrizedFixture` instead. 
- """ - - @pytest.fixture(name=name, params=values, ids=ids or [f"{name}={v}" for v in values], **kwargs) - def _parametrized_fixture(request: pytest.FixtureRequest): - return request.param - - return _parametrized_fixture - - -T = TypeVar("T") - - -class ParametrizedFixture(Generic[T]): - """Small helper function that creates a parametrized pytest fixture for the given values. - - The name of the fixture will be the name that is used for this variable on a class. - - For example: - - ```python - - class TestFoo: - odd = ParametrizedFixture([True, False]) - - def test_something(self, odd: bool): - '''some kind of test that uses odd''' - - # NOTE: This fixture can also be used by other fixtures: - - @pytest.fixture - def some_number(self, odd: bool): - return 1 if odd else 2 - - def test_foo(self, some_number: int): - '''some kind of test that uses some_number''' - ``` - """ - - def __init__(self, values: list[T], name: str | None = None, **fixture_kwargs): - self.values = values - self.fixture_kwargs = fixture_kwargs - self.name = name - - def __set_name__(self, owner: Any, name: str): - self.name = name - - def __get__(self, obj, objtype=None): - assert self.name is not None - fixture_kwargs = self.fixture_kwargs.copy() - fixture_kwargs.setdefault("ids", [f"{self.name}={v}" for v in self.values]) - - @pytest.fixture(name=self.name, params=self.values, **fixture_kwargs) - def _parametrized_fixture_method(request: pytest.FixtureRequest): - return request.param - - return _parametrized_fixture_method - - -def run_for_all_datamodules( - datamodule_names: list[str] | None = None, - datamodule_name_to_marks: dict[str, pytest.MarkDecorator | list[pytest.MarkDecorator]] - | None = None, -): - """Apply this marker to a test to make it run with all available datasets (datamodules). - - The test should use the `datamodule` fixture, either as an input argument to the test - function or indirectly by using a fixture that depends on the `datamodule` fixture. - - Parameters - ---------- - datamodule_names: List of datamodule names to use for tests. \ - By default, lists out the generic datamodules (the datamodules that aren't specific to a - single algorithm, for example the InfGendatamodules of WakeSleep.) - - datamodule_to_marks: Dictionary from datamodule names to pytest marks (e.g. \ - `pytest.mark.xfail`, `pytest.mark.skip`) to use for that particular datamodule. - """ - return run_for_all_configs_in_group( - group_name="datamodule", - config_name_to_marks=datamodule_name_to_marks, - ) - - -def run_for_all_vision_datamodules(): - return run_for_all_configs_of_type("datamodule", VisionDataModule) +For example, ResNet networks can't be applied to the MNIST datasets. +""" def run_for_all_configs_of_type( @@ -233,16 +142,14 @@ def test_something_else(self): # This will cause an error! pass ``` - Parameters - ---------- - arg_name_or_fixture: The name of the argument to parametrize, or a fixture to parametrize \ - indirectly. - values: The values to be used to parametrize the test. + Parameters: + arg_name_or_fixture: The name of the argument to parametrize, or a fixture to parametrize \ + indirectly. + values: The values to be used to parametrize the test. - Returns - ------- - A `pytest.MarkDecorator` that parametrizes the test with the given values only when the argument - is used (directly or indirectly) by the test. + Returns: + A `pytest.MarkDecorator` that parametrizes the test with the given values only when the \ + argument is used (directly or indirectly) by the test. 
""" if indirect is None: indirect = not isinstance(arg_name_or_fixture, str) @@ -270,14 +177,13 @@ def run_for_all_configs_in_group( The test wrapped test will uses all config from that group if they are used either as an input argument to the test function or if it the input argument to a fixture function. - Parameters - ---------- - datamodule_names: List of datamodule names to use for tests. \ - By default, lists out the generic datamodules (the datamodules that aren't specific to a - single algorithm, for example the InfGendatamodules of WakeSleep.) + Parameters: + group_name: List of datamodule names to use for tests. \ + By default, lists out the generic datamodules (the datamodules that aren't specific \ + to a single algorithm, for example the InfGendatamodules of WakeSleep.) - datamodule_to_marks: Dictionary from datamodule names to pytest marks (e.g. \ - `pytest.mark.xfail`, `pytest.mark.skip`) to use for that particular datamodule. + config_name_to_marks: Dictionary from config names to pytest marks (e.g. \ + `pytest.mark.xfail`, `pytest.mark.skip`) to use for that particular config. """ if config_name_to_marks is None: config_name_to_marks = { @@ -300,15 +206,3 @@ def run_for_all_configs_in_group( ], indirect=True, ) - - -def assert_all_params_initialized(module: nn.Module): - for name, param in module.named_parameters(): - assert not isinstance(param, nn.UninitializedParameter | nn.UninitializedBuffer), name - - -def assert_no_nans_in_params_or_grads(module: nn.Module): - for name, param in module.named_parameters(): - assert not torch.isnan(param).any(), name - if param.grad is not None: - assert not torch.isnan(param.grad).any(), name diff --git a/project/utils/typing_utils/__init__.py b/project/utils/typing_utils/__init__.py index 2fdd1bb6..3070e8d5 100644 --- a/project/utils/typing_utils/__init__.py +++ b/project/utils/typing_utils/__init__.py @@ -1,36 +1,28 @@ +"""Utilities to help annotate the types of values in the project.""" + from __future__ import annotations from collections.abc import Iterable, Mapping, Sequence -from typing import Any, NewType, TypeAlias, TypeGuard +from typing import Any, NewType, TypeGuard -from torch import Tensor -from typing_extensions import TypeVar, TypeVarTuple, Unpack +from typing_extensions import TypeVar -from .protocols import Dataclass, DataModule, Module +from .protocols import DataModule, Module # These are used to show which dim is which. C = NewType("C", int) H = NewType("H", int) W = NewType("W", int) -S = NewType("S", int) -OutT = TypeVar("OutT", default=Tensor, covariant=True) -Ts = TypeVarTuple("Ts", default=Unpack[tuple[Tensor, ...]]) -T = TypeVar("T", default=Tensor) +T = TypeVar("T") K = TypeVar("K") V = TypeVar("V") -NestedDict: TypeAlias = dict[K, V | "NestedDict[K, V]"] NestedMapping = Mapping[K, V | "NestedMapping[K, V]"] PyTree = T | Iterable["PyTree[T]"] | Mapping[Any, "PyTree[T]"] -def is_list_of(object: Any, item_type: type[V] | tuple[type[V], ...]) -> TypeGuard[list[V]]: - """Used to check (and tell the type checker) that `object` is a list of items of this type.""" - return isinstance(object, list) and is_sequence_of(object, item_type) - - def is_sequence_of( object: Any, item_type: type[V] | tuple[type[V], ...] 
) -> TypeGuard[Sequence[V]]: @@ -49,8 +41,6 @@ def is_mapping_of(object: Any, key_type: type[K], value_type: type[V]) -> TypeGu __all__ = [ - "HasInputOutputShapes", "Module", - "Dataclass", "DataModule", ] diff --git a/project/utils/typing_utils/protocols.py b/project/utils/typing_utils/protocols.py index ce2151d5..6a6082b4 100644 --- a/project/utils/typing_utils/protocols.py +++ b/project/utils/typing_utils/protocols.py @@ -1,23 +1,19 @@ from __future__ import annotations -import dataclasses import typing from collections.abc import Iterable -from typing import ClassVar, Literal, ParamSpec, Protocol, TypeVar, runtime_checkable +from typing import Literal, ParamSpec, Protocol, TypeVar, runtime_checkable from torch import nn - -class Dataclass(Protocol): - __dataclass_fields__: ClassVar[dict[str, dataclasses.Field]] - - P = ParamSpec("P") -OutT = TypeVar("OutT") +OutT = TypeVar("OutT", covariant=True) @runtime_checkable class Module(Protocol[P, OutT]): + """Small protocol used to help annotate the input/outputs of `torch.nn.Module`s.""" + def forward(self, *args: P.args, **kwargs: P.kwargs) -> OutT: raise NotImplementedError @@ -38,7 +34,7 @@ def __call__(self, *args: P.args, **kwagrs: P.kwargs) -> OutT: ... to = nn.Module().to -BatchType = TypeVar("BatchType") +BatchType = TypeVar("BatchType", covariant=True) @runtime_checkable diff --git a/pyproject.toml b/pyproject.toml index b78d8f6b..e76ecfeb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "torchvision>=0.19.0", "rich>=13.7.1", "tqdm>=4.66.5", - "hydra-zen>=0.13.0", + "hydra-zen==0.13.1rc1", "gdown>=5.2.0", "hydra-submitit-launcher>=1.2.0", "hydra-colorlog>=1.2.0", diff --git a/uv.lock b/uv.lock index c0e2b898..f56f33cd 100644 --- a/uv.lock +++ b/uv.lock @@ -1480,16 +1480,16 @@ wheels = [ [[package]] name = "hydra-zen" -version = "0.13.0" +version = "0.13.1rc1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "hydra-core" }, { name = "omegaconf" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/83/a9/747cd74012a33f61fbe7b079e91385214d0d0dc707721cd95afb2c823a75/hydra_zen-0.13.0.tar.gz", hash = "sha256:1b53d74aa1f0baa04fafdac6aba7a94ae40929e7b0a5a5081d8740f74322052d", size = 1329558 } +sdist = { url = "https://files.pythonhosted.org/packages/f9/8e/75255e0f0b8042d84df6afe9d3ab2d8b0f45359f3f7a80c2194e7371de2a/hydra_zen-0.13.1rc1.tar.gz", hash = "sha256:37f577e2615fc5692eab73c954f4f8dc1d01386327d208057cc169fbf5d86e0b", size = 1329925 } wheels = [ - { url = "https://files.pythonhosted.org/packages/33/ca/8f984e7a8e39ed2bf1622a1ae0fe26e65011fed150548a9319217e92925f/hydra_zen-0.13.0-py3-none-any.whl", hash = "sha256:6050b62be96d2a47b2abf0e9c0ebcce1e9a4e259e173870338ab049b833f26cf", size = 103784 }, + { url = "https://files.pythonhosted.org/packages/d7/bd/0061300354aed319076b8564d2137967794f3b886065c2a532db87b51180/hydra_zen-0.13.1rc1-py3-none-any.whl", hash = "sha256:c0d5dfb3a47aaf507e7df1eca561766cdadcd1c047107010d5a746c929c2d45f", size = 103951 }, ] [[package]] @@ -3926,7 +3926,7 @@ requires-dist = [ { name = "hydra-core", specifier = ">=1.3.2" }, { name = "hydra-orion-sweeper", specifier = ">=1.6.4" }, { name = "hydra-submitit-launcher", specifier = ">=1.2.0" }, - { name = "hydra-zen", specifier = ">=0.13.0" }, + { name = "hydra-zen", specifier = "==0.13.1rc1" }, { name = "jax", specifier = "==0.4.33" }, { name = "jax", extras = ["cuda12"], marker = "extra == 'gpu'", specifier = ">=0.4.31" }, { name = "jaxlib", specifier = "==0.4.33" 
},
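
The patch above adds a short docstring describing the `Module` protocol in `project/utils/typing_utils/protocols.py` as a way to annotate the input/output types of `torch.nn.Module`s. As a rough, hypothetical sketch (not part of the patch itself; the helper function and its name are invented for illustration, and it assumes the `project` package is importable), it could be used roughly like this:

```python
# Illustrative sketch only: annotating a torch.nn.Module's input/output types
# with the `Module` protocol kept by this patch. `forward_and_count` is a
# hypothetical helper, not something defined in the repository.
import torch
from torch import Tensor, nn

from project.utils.typing_utils import Module


def forward_and_count(network: Module[[Tensor], Tensor], x: Tensor) -> int:
    """`network` is any nn.Module-like callable mapping a Tensor to a Tensor."""
    return network(x).numel()


if __name__ == "__main__":
    # nn.Linear(3, 2) maps a (4, 3) input to a (4, 2) output, so this prints 8.
    print(forward_and_count(nn.Linear(3, 2), torch.randn(4, 3)))
```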