Skip to content

Commit

Permalink
Fix other tiny issues in test code
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Jun 13, 2024
1 parent 61e204f commit 9a220a1
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
7 changes: 5 additions & 2 deletions project/algorithms/bases/algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,11 @@ def _hydra_config(
All overrides should have already been applied.
"""
# todo: remove this hard-coded check somehow.
if "resnet" in network_name and datamodule_name in ["mnist", "fashion_mnist"]:
pytest.skip(reason="ResNet's can't be used on MNIST datasets.")

# todo: Get the name of the algorithm from the hydra config?
algorithm_name = self.algorithm_name
with setup_hydra_for_tests_and_compose(
all_overrides=[
Expand Down Expand Up @@ -388,8 +390,9 @@ def network(
f"type {type(network)}"
)
)
assert isinstance(network, nn.Module)
return network.to(device=device)
if isinstance(network, nn.Module):
network = network.to(device=device)
return network

@pytest.fixture(scope="class")
def hp(self, experiment_config: Config) -> Algorithm.HParams: # type: ignore
Expand Down
6 changes: 3 additions & 3 deletions project/datamodules/image_classification/imagenet32.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from torchvision.datasets import VisionDataset
from torchvision.transforms import v2 as transforms

from project.datamodules.image_classification.base import ImageClassificationDataModule
from project.utils.types import C, H, StageStr, W

from ..vision.base import VisionDataModule
Expand All @@ -27,7 +27,7 @@ def imagenet32_normalization():
return transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))


class ImageNet32Dataset(ImageClassificationDataModule):
class ImageNet32Dataset(VisionDataset):
"""Downsampled ImageNet 32x32 Dataset."""

url: ClassVar[str] = "https://drive.google.com/uc?id=1XAlD_wshHhGNzaqy8ML-Jk0ZhAm8J5J_"
Expand Down
34 changes: 26 additions & 8 deletions project/utils/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any, TypeVar

import hydra.errors
import hydra_zen
import pytest
import torch
import yaml
Expand Down Expand Up @@ -124,6 +125,17 @@ def _parametrized_fixture_method(request: pytest.FixtureRequest):
return _parametrized_fixture_method


def get_config_loader():
from hydra._internal.config_loader_impl import ConfigLoaderImpl
from hydra._internal.utils import create_automatic_config_search_path

search_path = create_automatic_config_search_path(
calling_file=None, calling_module=None, config_path="pkg://project.configs"
)
config_loader = ConfigLoaderImpl(config_search_path=search_path)
return config_loader


def get_all_configs_in_group(group_name: str) -> list[str]:
# note: here we're copying a bit of the internal code from Hydra so that we also get the
# configs that are just yaml files, in addition to the configs we added programmatically to the
Expand All @@ -135,14 +147,7 @@ def get_all_configs_in_group(group_name: str) -> list[str]:
# names.remove("base")
# return names

from hydra._internal.config_loader_impl import ConfigLoaderImpl
from hydra._internal.utils import create_automatic_config_search_path

search_path = create_automatic_config_search_path(
calling_file=None, calling_module=None, config_path="pkg://project.configs"
)
config_loader = ConfigLoaderImpl(config_search_path=search_path)
return config_loader.get_group_options(group_name)
return get_config_loader().get_group_options(group_name)


def get_all_algorithm_names() -> list[str]:
Expand All @@ -155,7 +160,20 @@ def get_type_for_config_name(config_group: str, config_name: str, _cs: ConfigSto
In the case of inner dataclasses (e.g. Model.HParams), this returns the outer class (Model).
"""

config_loader = get_config_loader()
_, caching_repo = config_loader._parse_overrides_and_create_caching_repo(
config_name=None, overrides=[]
)
config_result = caching_repo.load_config(f"{config_group}/{config_name}.yaml")
if config_result is not None:
try:
return hydra_zen.get_target(config_result.config) # type: ignore
except TypeError:
pass

config_node = _cs._load(f"{config_group}/{config_name}.yaml")

if "_target_" in config_node.node:
target: str = config_node.node["_target_"]
module_name, _, class_name = target.rpartition(".")
Expand Down

0 comments on commit 9a220a1

Please sign in to comment.