Skip to content

Commit

Permalink
Rename the algorithm tests class (wip)
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Jul 17, 2024
1 parent 280bbf3 commit 905d3a4
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 950 deletions.
24 changes: 14 additions & 10 deletions docs/generate_reference_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,28 @@


import textwrap
from logging import getLogger as get_logger
from pathlib import Path

import mkdocs_gen_files
import mkdocs_gen_files.nav

from project.utils.env_vars import REPO_ROOTDIR

logger = get_logger(__name__)

module = "project"
modules = [
"project/main.py",
"project/experiment.py",
]
submodules = [
"project.algorithms",
"project.configs",
"project.datamodules",
"project.networks",
"project.utils",
]
# submodules = [
# "project.algorithms",
# "project.configs",
# "project.datamodules",
# "project.networks",
# "project.utils",
# ]


def _get_import_path(module_path: Path) -> str:
Expand All @@ -42,7 +45,7 @@ def main():
def add_doc_for_module(module_path: Path, nav: mkdocs_gen_files.nav.Nav) -> None:
"""TODO."""

assert module_path.is_dir() and (module_path / "__init__.py").exists(), module_path
assert module_path.is_dir() # and (module_path / "__init__.py").exists(), module_path

children = list(
p
Expand All @@ -52,7 +55,7 @@ def add_doc_for_module(module_path: Path, nav: mkdocs_gen_files.nav.Nav) -> None
for child_module_path in children:
child_module_import_path = _get_import_path(child_module_path)
doc_file = child_module_path.relative_to(REPO_ROOTDIR).with_suffix(".md")
write_doc_file = f"reference/{doc_file}"
write_doc_file = "reference" / doc_file

nav[tuple(child_module_import_path.split("."))] = f"{doc_file}"

Expand All @@ -71,11 +74,12 @@ def add_doc_for_module(module_path: Path, nav: mkdocs_gen_files.nav.Nav) -> None
p
for p in module_path.iterdir()
if p.is_dir()
and (p / "__init__.py").exists()
and ((p / "__init__.py").exists() or len(list(p.glob("*.py"))) > 0)
and not p.name.endswith("_test")
and not p.name.startswith((".", "__"))
)
for submodule in submodules:
logger.info(f"Creating doc for {submodule}")
add_doc_for_module(submodule, nav)


Expand Down
176 changes: 0 additions & 176 deletions project/algorithms/testsuites/algo_test.py

This file was deleted.

21 changes: 5 additions & 16 deletions project/algorithms/testsuites/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from typing import Literal, NotRequired, Protocol, TypedDict

import torch
from lightning import Callback, LightningModule, Trainer
from lightning import Callback, Trainer
from torch import Tensor
from typing_extensions import TypeVar

from project.datamodules.image_classification.image_classification import (
ImageClassificationDataModule,
)
from project.utils.types import PyTree
from project.utils.types.protocols import DataModule, Module

Expand Down Expand Up @@ -39,9 +36,6 @@ class Algorithm(Module, Protocol[BatchType, StepOutputType]):
datamodule: DataModule[BatchType]
network: Module

example_input_array = LightningModule.example_input_array
_device: torch.device | None = None

def __init__(
self,
*,
Expand All @@ -51,18 +45,13 @@ def __init__(
super().__init__()
self.datamodule = datamodule
self.network = network
# fix for `self.device` property which defaults to cpu.
self._device = None

if isinstance(datamodule, ImageClassificationDataModule):
self.example_input_array = torch.zeros(
(datamodule.batch_size, *datamodule.dims), device=self.device
)

self.trainer: Trainer

def training_step(self, batch: BatchType, batch_index: int) -> StepOutputType:
"""Performs a training step."""
"""Performs a training step.
See `LightningModule.training_step` for more information.
"""
return self.shared_step(batch=batch, batch_index=batch_index, phase="train")

def validation_step(self, batch: BatchType, batch_index: int) -> StepOutputType:
Expand Down
Loading

0 comments on commit 905d3a4

Please sign in to comment.