From 1e2d330f9250cb82254e1912c96ecefcb019ed02 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin <normandf@mila.quebec> Date: Mon, 18 Nov 2024 19:24:34 +0000 Subject: [PATCH] Fix / rename examples in docs Signed-off-by: Fabrice Normandin <normandf@mila.quebec> --- docs/SUMMARY.md | 6 +++--- ...torch_sl_example.md => image_classification.md} | 0 docs/examples/index.md | 14 +++++++------- ...x_sl_example.md => jax_image_classification.md} | 4 ++-- docs/examples/{jax_rl_example.md => jax_rl.md} | 0 docs/features/jax.md | 4 ++-- 6 files changed, 14 insertions(+), 14 deletions(-) rename docs/examples/{torch_sl_example.md => image_classification.md} (100%) rename docs/examples/{jax_sl_example.md => jax_image_classification.md} (93%) rename docs/examples/{jax_rl_example.md => jax_rl.md} (100%) diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 5dba41f0..a65eb75e 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -7,11 +7,11 @@ * [Thorough automated testing on SLURM clusters](features/testing.md) * features/*.md * [Examples π§ͺ](examples/index.md) - * [Image Classification (β‘)](examples/torch_sl_example.md) - * [Image Classification (jax+β‘)](examples/jax_sl_example.md) + * [Image Classification (β‘)](examples/image_classification.md) + * [Image Classification (jax+β‘)](examples/jax_image_classification.md) * [Text Classification (π€+β‘)](examples/text_classification.md) * [Fine-tuning an LLM (π€+β‘)](examples/llm_finetuning.md) - * [RL (jax)](examples/jax_rl_example.md) + * [Reinforcement Learning (jax)](examples/jax_rl.md) * [Running sweeps](examples/sweeps.md) * [Profiling your codeπ](examples/profiling.md) * examples/*.md diff --git a/docs/examples/torch_sl_example.md b/docs/examples/image_classification.md similarity index 100% rename from docs/examples/torch_sl_example.md rename to docs/examples/image_classification.md diff --git a/docs/examples/index.md b/docs/examples/index.md index ab85abb3..91600c14 100644 --- a/docs/examples/index.md +++ b/docs/examples/index.md @@ -12,10 +12,10 @@ additional_python_references: This template includes examples that use either Jax, PyTorch, or both! -| Example link | Research Area | Reference link | Frameworks | -| ------------------------------------------------- | ------------------------------------------ | ---------------------- | --------------- | -| [Image Classification](torch_sl_example.md) | Supervised Learning (image classification) | `ImageClassifier` | Torch + β‘ | -| [Image Classification (Jax)](jax_sl_example.md) | Supervised Learning (image classification) | `JaxImageClassifier` | Torch + Jax + β‘ | -| [Text Classification](text_classification.md) | NLP (text classification) | `TextClassifier` | Torch + π€ + β‘ | -| [Reinforcement Learning (Jax)](jax_rl_example.md) | RL | `JaxRLExample` | Jax | -| [LLM Fine-tuning](llm_finetuning.md) | NLP (Causal language modeling) | `LLMFineTuningExample` | Torch + π€ + β‘ | +| Example link | Research Area | Reference link | Frameworks | +| --------------------------------------------------------- | ------------------------------------------ | ---------------------- | --------------- | +| [Image Classification](image_classification.md) | Supervised Learning (image classification) | `ImageClassifier` | Torch + β‘ | +| [Image Classification (Jax)](jax_image_classification.md) | Supervised Learning (image classification) | `JaxImageClassifier` | Torch + Jax + β‘ | +| [Text Classification](text_classification.md) | NLP (text classification) | `TextClassifier` | Torch + π€ + β‘ | +| [Reinforcement Learning (Jax)](jax_rl.md) | RL | `JaxRLExample` | Jax | +| [LLM Fine-tuning](llm_finetuning.md) | NLP (Causal language modeling) | `LLMFineTuningExample` | Torch + π€ + β‘ | diff --git a/docs/examples/jax_sl_example.md b/docs/examples/jax_image_classification.md similarity index 93% rename from docs/examples/jax_sl_example.md rename to docs/examples/jax_image_classification.md index 9e214988..ee1ddc99 100644 --- a/docs/examples/jax_sl_example.md +++ b/docs/examples/jax_image_classification.md @@ -22,11 +22,11 @@ pass uses Jax to calculate the gradients, and the weights are updated by a PyTor !!! question "What about end-to-end training in Jax?" - See the [Jax RL Example](../examples/jax_rl_example.md)! :smile: + See the [Jax RL Example](../examples/jax_rl.md)! :smile: ### Jax Network -{{ inline('project.algorithms.jax_image_classifier.CNN') }} +{{ inline('project.algorithms.jax_image_classifier.JaxCNN') }} ### Jax Algorithm diff --git a/docs/examples/jax_rl_example.md b/docs/examples/jax_rl.md similarity index 100% rename from docs/examples/jax_rl_example.md rename to docs/examples/jax_rl.md diff --git a/docs/features/jax.md b/docs/features/jax.md index 37d55a81..41c67fd3 100644 --- a/docs/features/jax.md +++ b/docs/features/jax.md @@ -32,12 +32,12 @@ training loop as usual, you can! The [lightning.Trainer][lightning.pytorch.trainer.trainer.Trainer] will not be able to tell that you're using Jax! -**Take a look at [this image classification example that uses a Jax network](../examples/jax_sl_example.md).** +**Take a look at [this image classification example that uses a Jax network](../examples/jax_image_classification.md).** ## End-to-end training in Jax: the `JaxTrainer` -The `JaxTrainer`, used in the [Jax RL Example](../examples/jax_rl_example.md), follows a similar structure as the lightning Trainer. However, instead of training LightningModules, it trains `JaxModule`s, which are a simplified, jax-based look-alike of `lightning.LightningModule`s. +The `JaxTrainer`, used in the [Jax RL Example](../examples/jax_rl.md), follows a similar structure as the lightning Trainer. However, instead of training LightningModules, it trains `JaxModule`s, which are a simplified, jax-based look-alike of `lightning.LightningModule`s. The "algorithm" needs to match the `JaxModule` protocol: