Skip to content

Commit

Permalink
Update regression files for the datamodules tests
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Jun 25, 2024
1 parent 314c471 commit db7d364
Show file tree
Hide file tree
Showing 16 changed files with 185 additions and 32 deletions.
10 changes: 8 additions & 2 deletions project/datamodules/datamodules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def test_first_batch(
):
# todo: skip this test if the dataset isn't already downloaded (for example on the GitHub CI).
datamodule.prepare_data()

if stage == RunningStage.TRAINING:
datamodule.setup("fit")
dataloader = datamodule.train_dataloader()
Expand Down Expand Up @@ -125,7 +124,14 @@ def test_first_batch(
# moving mnist, y isn't a label, it's another image.
axis.set_title(f"{index=}")

fig.suptitle(f"First batch of datamodule {type(datamodule).__name__}")
split = {
RunningStage.TRAINING: "training",
RunningStage.VALIDATING: "validation",
RunningStage.TESTING: "test",
RunningStage.PREDICTING: "prediction(?)",
}

fig.suptitle(f"First {split[stage]} batch of datamodule {type(datamodule).__name__}")
figure_path, _ = get_test_source_and_temp_file_paths(
extension=".png",
request=request,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
'0':
device: cpu
hash: 7631136576767235544
max: 1.0
mean: 0.468
min: 0.0
shape:
- 128
- 3
- 32
- 32
sum: 184156.109
'1':
device: cpu
hash: 8462625093735455128
max: 9
mean: 4.703
min: 0
shape:
- 128
sum: 602
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
'0':
device: cpu
hash: 4180642819611736479
max: 1.0
mean: 0.463
min: 0.0
shape:
- 128
- 3
- 32
- 32
sum: 181864.641
'1':
device: cpu
hash: -4539052997197868398
max: 9
mean: 4.258
min: 0
shape:
- 128
sum: 545
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
'0':
device: cpu
hash: -2751264324508784427
max: 1.0
mean: 0.292
min: 0.0
shape:
- 128
- 1
- 28
- 28
sum: 29317.309
'1':
device: cpu
hash: 6530176971009424370
max: 9
mean: 4.461
min: 0
shape:
- 128
sum: 571
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
'0':
device: cpu
hash: 225494219076660575
max: 1.0
mean: 0.296
min: 0.0
shape:
- 128
- 1
- 28
- 28
sum: 29740.449
'1':
device: cpu
hash: -4543745818595514203
max: 9
mean: 4.453
min: 0
shape:
- 128
sum: 570
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
'0':
device: cpu
hash: -5724309328014586573
max: 1.0
mean: 0.461
min: 0.0
shape:
- 64
- 3
- 32
- 32
sum: 90649.305
'1':
device: cpu
hash: 2830952008253455204
max: 987
mean: 543.234
min: 49
shape:
- 64
sum: 34767
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
'0':
device: cpu
hash: 4266338311425013668
max: 1.0
mean: 0.427
min: 0.0
shape:
- 64
- 3
- 32
- 32
sum: 83882.633
'1':
device: cpu
hash: 5813156328689991827
max: 973
mean: 484.469
min: 21
shape:
- 64
sum: 31006
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
'0':
device: cpu
hash: 8711678139956893479
max: 2.64
mean: -0.181
min: -2.118
shape:
- 64
- 3
- 224
- 224
sum: -1740804.5
'1':
device: cpu
hash: -3826088756534882585
max: 1
mean: 0.219
min: 0
shape:
- 64
sum: 14
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
'0':
device: cpu
hash: 976242047177418374
max: 2.64
mean: -0.118
min: -2.118
shape:
- 64
- 3
- 224
- 224
sum: -1139394.375
'1':
device: cpu
hash: -5258163774450544391
max: 0
mean: 0.0
min: 0
shape:
- 64
sum: 0
21 changes: 0 additions & 21 deletions project/datamodules/datamodules_test/test_first_batch/mnist.yaml

This file was deleted.

5 changes: 4 additions & 1 deletion project/datamodules/image_classification/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
drop_last=drop_last,
train_transforms=train_transforms or self.train_transform(),
val_transforms=val_transforms or self.val_transform(),
test_transforms=test_transforms,
test_transforms=test_transforms or self.test_transform(),
**kwargs,
)
self.dims = (C(3), H(self.image_size), W(self.image_size))
Expand Down Expand Up @@ -233,6 +233,9 @@ def val_transform(self) -> Callable:
]
)

# todo: what should be the default transformations for the test set? Same as validation, right?
test_transform = val_transform


def prepare_imagenet(
root: Path,
Expand Down
13 changes: 5 additions & 8 deletions project/datamodules/image_classification/imagenet32.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Callable, Sequence
from logging import getLogger
from pathlib import Path
from typing import ClassVar
from typing import ClassVar, Literal

import gdown
import numpy as np
Expand All @@ -19,7 +19,7 @@

from project.datamodules.vision import VisionDataModule
from project.utils.env_vars import DATA_DIR, SCRATCH
from project.utils.types import C, H, PhaseStr, W
from project.utils.types import C, H, W

logger = getLogger(__name__)

Expand Down Expand Up @@ -233,11 +233,8 @@ def prepare_data(self) -> None:
"""Saves files to data_dir."""
super().prepare_data()

def setup(self, stage: PhaseStr | None = None) -> None:
"""Creates train, val, and test dataset."""
if stage not in ["fit", "validate", "val", "test", None]:
raise ValueError(f"Invalid stage: {stage}")

def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None) -> None:
# """Creates train, val, and test dataset."""
if stage:
logger.debug(f"Setting up for stage {stage}")
else:
Expand Down Expand Up @@ -269,7 +266,7 @@ def setup(self, stage: PhaseStr | None = None) -> None:
self.dataset_train = self._split_dataset(base_dataset_train, train=True)
self.dataset_val = self._split_dataset(base_dataset_valid, train=False)

if stage in ["test", None]:
if stage in ["test", "predict", None]:
test_transforms = self.test_transforms or self.default_transforms()
self.dataset_test = self.dataset_cls(
self.data_dir, train=False, transform=test_transforms, **self.EXTRA_ARGS
Expand Down

0 comments on commit db7d364

Please sign in to comment.