Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make all unit / integration tests pass [RT-72] #5

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions project/algorithms/bases/algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ def n_updates(self, datamodule_name: str, network_name: str) -> int:
# improved.
return 5

@pytest.mark.xfail(
raises=NotImplementedError, reason="TODO: Implement this test.", strict=True
)
@pytest.mark.skip(reason="TODO: Implement this test.")
def test_loss_is_reproducible(
self,
algorithm: AlgorithmType,
Expand Down
3 changes: 0 additions & 3 deletions project/algorithms/example_algo_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import ClassVar

from project.algorithms.bases.image_classification_test import ImageClassificationAlgorithmTests

from .example_algo import ExampleAlgorithm
Expand All @@ -8,4 +6,3 @@
class TestExampleAlgorithm(ImageClassificationAlgorithmTests[ExampleAlgorithm]):
    """Runs the shared image-classification algorithm test suite against `ExampleAlgorithm`."""

    # The algorithm class under test; the base suite uses this to instantiate it.
    algorithm_type = ExampleAlgorithm
    # Name identifying this algorithm's config — presumably the Hydra config group
    # entry; TODO confirm it matches the config file name.
    algorithm_name: str = "example_algo"
unsupported_datamodule_names: ClassVar[list[str]] = ["rl"]
25 changes: 11 additions & 14 deletions project/configs/datamodule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
)
from project.datamodules.image_classification.cifar10 import cifar10_train_transforms
from project.datamodules.image_classification.imagenet32 import imagenet32_train_transforms
from project.datamodules.image_classification.inaturalist import (
from project.datamodules.image_classification.mnist import mnist_train_transforms
from project.datamodules.vision.inaturalist import (
INaturalistDataModule,
TargetType,
Version,
)
from project.datamodules.image_classification.mnist import mnist_train_transforms

logger = get_logger(__name__)

FILE = Path(__file__)
REPO_ROOTDIR = FILE.parent
Expand All @@ -31,16 +33,16 @@
break
REPO_ROOTDIR = REPO_ROOTDIR.parent


SLURM_TMPDIR: Path | None = (
Path(os.environ["SLURM_TMPDIR"]) if "SLURM_TMPDIR" in os.environ else None
)
SLURM_JOB_ID: int | None = (
int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None
)

logger = get_logger(__name__)

SLURM_TMPDIR: Path | None = (
Path(os.environ["SLURM_TMPDIR"]) if "SLURM_TMPDIR" in os.environ else None
)
if not SLURM_TMPDIR and SLURM_JOB_ID is not None:
# This can happen when running the integrated VSCode terminal with `mila code`!
if (_tmp := Path("/tmp")).exists():
SLURM_TMPDIR = _tmp

TORCHVISION_DIR: Path | None = None

Expand All @@ -49,11 +51,6 @@
TORCHVISION_DIR = _torchvision_dir


if not SLURM_TMPDIR and SLURM_JOB_ID is not None:
# This can happen when running the integrated VSCode terminal with `mila code`!
_slurm_tmpdir = Path(f"/Tmp/slurm.{SLURM_JOB_ID}.0")
if _slurm_tmpdir.exists():
SLURM_TMPDIR = _slurm_tmpdir
SCRATCH = Path(os.environ["SCRATCH"]) if "SCRATCH" in os.environ else None
DATA_DIR = Path(os.environ.get("DATA_DIR", (SLURM_TMPDIR or SCRATCH or REPO_ROOTDIR) / "data"))

Expand Down
3 changes: 2 additions & 1 deletion project/datamodules/datamodules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
TensorRegressionFixture,
get_test_source_and_temp_file_paths,
)
from project.utils.testutils import run_for_all_datamodules
from project.utils.testutils import run_for_all_datamodules, skip_test_on_github_cloud_CI
from project.utils.types import is_sequence_of

from ..utils.types.protocols import DataModule


@skip_test_on_github_cloud_CI
@pytest.mark.timeout(25, func_only=True)
@run_for_all_datamodules()
def test_first_batch(
Expand Down
33 changes: 15 additions & 18 deletions project/datamodules/vision/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,10 @@
from typing_extensions import ParamSpec

from project.utils.types import C, H, StageStr, W

from ...utils.types.protocols import DataModule
from project.utils.types.protocols import DataModule

P = ParamSpec("P")

SLURM_TMPDIR: Path | None = (
Path(os.environ["SLURM_TMPDIR"]) if "SLURM_TMPDIR" in os.environ else None
)
logger = get_logger(__name__)


Expand Down Expand Up @@ -75,8 +71,9 @@ def __init__(
"""

super().__init__()
from project.configs.datamodule import DATA_DIR

self.data_dir = data_dir if data_dir is not None else os.getcwd()
self.data_dir = data_dir if data_dir is not None else DATA_DIR
self.val_split = val_split
if num_workers is None:
num_workers = num_cpus_on_node()
Expand Down Expand Up @@ -232,15 +229,10 @@ def train_dataloader(
return self._data_loader(
self.dataset_train,
_dataloader_fn=_dataloader_fn,
shuffle=self.shuffle,
*args,
**(
dict(
shuffle=self.shuffle,
generator=torch.Generator().manual_seed(self.train_dl_rng_seed),
)
| kwargs
),
persistent_workers=True,
**kwargs,
generator=torch.Generator().manual_seed(self.train_dl_rng_seed),
)

def val_dataloader(
Expand All @@ -255,8 +247,8 @@ def val_dataloader(
self.dataset_val,
_dataloader_fn=_dataloader_fn,
*args,
**(dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed)) | kwargs),
persistent_workers=True,
**kwargs,
generator=torch.Generator().manual_seed(self.val_dl_rng_seed),
)

def test_dataloader(
Expand All @@ -273,14 +265,16 @@ def test_dataloader(
self.dataset_test,
_dataloader_fn=_dataloader_fn,
*args,
**(dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed)) | kwargs),
persistent_workers=True,
**kwargs,
generator=torch.Generator().manual_seed(self.test_dl_rng_seed),
)

def _data_loader(
self,
dataset: Dataset,
_dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
generator: torch.Generator | None = None,
shuffle: bool | None = None,
*dataloader_args: P.args,
**dataloader_kwargs: P.kwargs,
) -> DataLoader:
Expand All @@ -291,6 +285,9 @@ def _data_loader(
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
| (dict(shuffle=shuffle) if shuffle is not None else {})
| (dict(generator=generator) if generator is not None else {})
| (dict(persistent_workers=True) if self.num_workers > 0 else {})
| dataloader_kwargs
)
return _dataloader_fn(dataset, *dataloader_args, **dataloader_kwargs)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import os
import warnings
from collections.abc import Callable
from logging import getLogger as get_logger
Expand All @@ -10,7 +9,7 @@
import torchvision.transforms as T
from torchvision.datasets import INaturalist

from project.datamodules.image_classification.base import ImageClassificationDataModule
from project.datamodules.vision.base import VisionDataModule
from project.utils.types import C, H, W

logger = get_logger(__name__)
Expand All @@ -25,31 +24,14 @@
Version = Version2017_2019 | Version2021


def get_slurm_tmpdir() -> Path:
    """Return the job's temporary directory when running on a SLURM cluster.

    Prefers the ``SLURM_TMPDIR`` environment variable. When only ``SLURM_JOB_ID``
    is set, falls back to the ``/Tmp/slurm.<job_id>.0`` layout used on the Mila
    cluster.

    Returns:
        The path to the job's fast local temporary directory.

    Raises:
        RuntimeError: If ``SLURM_JOB_ID`` isn't set (i.e. not running under SLURM).
        NotImplementedError: If running under SLURM but the Mila-style fallback
            directory doesn't exist (some other cluster's layout).
    """
    if "SLURM_TMPDIR" in os.environ:
        return Path(os.environ["SLURM_TMPDIR"])
    if "SLURM_JOB_ID" not in os.environ:
        # BUG FIX: message previously said "SLURM_JOBID", which is not the name of
        # the variable actually checked above.
        raise RuntimeError(
            "SLURM_JOB_ID environment variable isn't set. Are you running this from a SLURM "
            "cluster?"
        )
    slurm_tmpdir = Path(f"/Tmp/slurm.{os.environ['SLURM_JOB_ID']}.0")
    if not slurm_tmpdir.is_dir():
        raise NotImplementedError(
            f"TODO: You appear to be running this outside the Mila cluster, since SLURM_TMPDIR "
            f"isn't located at {slurm_tmpdir}."
        )
    return slurm_tmpdir


def inat_dataset_dir() -> Path:
    """Return the shared iNaturalist dataset directory on the Mila cluster.

    Raises:
        NotImplementedError: If the cluster-wide datasets directory isn't present
            (i.e. not running on the Mila cluster).
    """
    dataset_dir = Path("/network/datasets/inat")
    if dataset_dir.exists():
        return dataset_dir
    raise NotImplementedError("For now this assumes that we're running on the Mila cluster.")


class INaturalistDataModule(ImageClassificationDataModule):
class INaturalistDataModule(VisionDataModule):
name: ClassVar[str] = "inaturalist"
"""Dataset name."""

Expand Down Expand Up @@ -79,7 +61,10 @@ def __init__(
) -> None:
# assuming that we're on the Mila cluster atm.
self.network_dir = inat_dataset_dir()
slurm_tmpdir = get_slurm_tmpdir()
from project.configs.datamodule import SLURM_TMPDIR

slurm_tmpdir = SLURM_TMPDIR
assert slurm_tmpdir is not None
default_data_dir = slurm_tmpdir / "data"
if data_dir is None:
data_dir = default_data_dir
Expand Down
11 changes: 11 additions & 0 deletions project/utils/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dataclasses
import hashlib
import importlib
import os
from collections.abc import Mapping, Sequence
from contextlib import contextmanager
from logging import getLogger as get_logger
Expand Down Expand Up @@ -34,6 +35,16 @@
from project.utils.types.protocols import DataModule
from project.utils.utils import get_device

# True when running inside a GitHub Actions workflow (GitHub sets GITHUB_ACTIONS="true").
in_github_CI = os.environ.get("GITHUB_ACTIONS") == "true"
# Self-hosted runners are distinguished from cloud runners by CUDA availability —
# NOTE(review): assumes GitHub's cloud runners never expose a CUDA GPU; confirm.
in_self_hosted_github_CI = in_github_CI and torch.cuda.is_available()

# Mark: skip the test on any GitHub CI runner (cloud or self-hosted).
skip_test_on_github_CI = pytest.mark.skipif(in_github_CI, reason="Skipping test on GitHub CI.")
# Mark: skip only on GitHub's cloud runners; still runs on the self-hosted (GPU) runner.
skip_test_on_github_cloud_CI = pytest.mark.skipif(
    in_github_CI and not in_self_hosted_github_CI,
    reason="Skipping test on GitHub cloud CI, but run on the self-hosted runner.",
)
# Mark: skip wherever no CUDA device is available.
needs_gpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="Needs a GPU to run.")

SLOW_DATAMODULES = ["inaturalist", "imagenet32"]

default_marks_for_config_name: dict[str, list[pytest.MarkDecorator]] = {
Expand Down
Loading