From 4346dd2989cf447aa808c8c8bb8151ccbb6c7cb4 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 4 Dec 2024 08:52:10 -0500 Subject: [PATCH] Minor changes to docstrings for docs Signed-off-by: Fabrice Normandin --- docs/macros.py | 14 ++++++++++++++ project/datamodules/__init__.py | 2 +- .../datamodules/image_classification/imagenet.py | 2 +- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/docs/macros.py b/docs/macros.py index 23368429..ae712eb7 100644 --- a/docs/macros.py +++ b/docs/macros.py @@ -8,9 +8,23 @@ from pathlib import Path from typing import Any +import torch + if typing.TYPE_CHECKING: from mkdocs_macros.plugin import MacrosPlugin +import lightning +from mkdocs_autoref_plugin.autoref_plugin import default_reference_sources + +default_reference_sources.extend( + [ + lightning.Trainer, + lightning.LightningModule, + lightning.LightningDataModule, + torch.nn.Module, + ] +) + logger = logging.getLogger(__name__) diff --git a/project/datamodules/__init__.py b/project/datamodules/__init__.py index 65bb8580..bc1b12d3 100644 --- a/project/datamodules/__init__.py +++ b/project/datamodules/__init__.py @@ -1,6 +1,6 @@ """Datamodules (datasets + preprocessing + dataloading) -See the :ref:`lightning.LightningDataModule` class for more information. +See the `lightning.LightningDataModule` class for more information. """ from .image_classification import ImageClassificationDataModule diff --git a/project/datamodules/image_classification/imagenet.py b/project/datamodules/image_classification/imagenet.py index 9c774262..769dc9c3 100644 --- a/project/datamodules/image_classification/imagenet.py +++ b/project/datamodules/image_classification/imagenet.py @@ -46,10 +46,10 @@ class ImageNetDataModule(ImageClassificationDataModule): - Made this a subclass of VisionDataModule Notes: + - train_dataloader uses the train split of imagenet2012 and puts away a portion of it for the validation split. - val_dataloader uses the part of the train split of imagenet2012 that was not used for training via `num_imgs_per_val_class` - - TODO: needs to pass split='val' to UnlabeledImagenet. - test_dataloader uses the validation split of imagenet2012 for testing. - TODO: need to pass num_imgs_per_class=-1 for test dataset and split="test". """