From 9a220a11ccf797855808eeeb7a371dcbbb091e94 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 13 Jun 2024 20:38:16 +0000 Subject: [PATCH] Fix other tiny issues in test code Signed-off-by: Fabrice Normandin --- project/algorithms/bases/algorithm_test.py | 7 ++-- .../image_classification/imagenet32.py | 6 ++-- project/utils/testutils.py | 34 ++++++++++++++----- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/project/algorithms/bases/algorithm_test.py b/project/algorithms/bases/algorithm_test.py index 87c8d06d..5df5ee95 100644 --- a/project/algorithms/bases/algorithm_test.py +++ b/project/algorithms/bases/algorithm_test.py @@ -332,9 +332,11 @@ def _hydra_config( All overrides should have already been applied. """ + # todo: remove this hard-coded check somehow. if "resnet" in network_name and datamodule_name in ["mnist", "fashion_mnist"]: pytest.skip(reason="ResNet's can't be used on MNIST datasets.") + # todo: Get the name of the algorithm from the hydra config? algorithm_name = self.algorithm_name with setup_hydra_for_tests_and_compose( all_overrides=[ @@ -388,8 +390,9 @@ def network( f"type {type(network)}" ) ) - assert isinstance(network, nn.Module) - return network.to(device=device) + if isinstance(network, nn.Module): + network = network.to(device=device) + return network @pytest.fixture(scope="class") def hp(self, experiment_config: Config) -> Algorithm.HParams: # type: ignore diff --git a/project/datamodules/image_classification/imagenet32.py b/project/datamodules/image_classification/imagenet32.py index 8143dd01..825ba493 100644 --- a/project/datamodules/image_classification/imagenet32.py +++ b/project/datamodules/image_classification/imagenet32.py @@ -13,9 +13,9 @@ import numpy as np from PIL import Image from torch.utils.data import DataLoader, Dataset, Subset -from torchvision import transforms +from torchvision.datasets import VisionDataset +from torchvision.transforms import v2 as transforms -from project.datamodules.image_classification.base import ImageClassificationDataModule from project.utils.types import C, H, StageStr, W from ..vision.base import VisionDataModule @@ -27,7 +27,7 @@ def imagenet32_normalization(): return transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) -class ImageNet32Dataset(ImageClassificationDataModule): +class ImageNet32Dataset(VisionDataset): """Downsampled ImageNet 32x32 Dataset.""" url: ClassVar[str] = "https://drive.google.com/uc?id=1XAlD_wshHhGNzaqy8ML-Jk0ZhAm8J5J_" diff --git a/project/utils/testutils.py b/project/utils/testutils.py index e4868aa7..bfe9a203 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -13,6 +13,7 @@ from typing import Any, TypeVar import hydra.errors +import hydra_zen import pytest import torch import yaml @@ -124,6 +125,17 @@ def _parametrized_fixture_method(request: pytest.FixtureRequest): return _parametrized_fixture_method +def get_config_loader(): + from hydra._internal.config_loader_impl import ConfigLoaderImpl + from hydra._internal.utils import create_automatic_config_search_path + + search_path = create_automatic_config_search_path( + calling_file=None, calling_module=None, config_path="pkg://project.configs" + ) + config_loader = ConfigLoaderImpl(config_search_path=search_path) + return config_loader + + def get_all_configs_in_group(group_name: str) -> list[str]: # note: here we're copying a bit of the internal code from Hydra so that we also get the # configs that are just yaml files, in addition to the configs we added programmatically to the @@ -135,14 +147,7 @@ def get_all_configs_in_group(group_name: str) -> list[str]: # names.remove("base") # return names - from hydra._internal.config_loader_impl import ConfigLoaderImpl - from hydra._internal.utils import create_automatic_config_search_path - - search_path = create_automatic_config_search_path( - calling_file=None, calling_module=None, config_path="pkg://project.configs" - ) - config_loader = ConfigLoaderImpl(config_search_path=search_path) - return config_loader.get_group_options(group_name) + return get_config_loader().get_group_options(group_name) def get_all_algorithm_names() -> list[str]: @@ -155,7 +160,20 @@ def get_type_for_config_name(config_group: str, config_name: str, _cs: ConfigSto In the case of inner dataclasses (e.g. Model.HParams), this returns the outer class (Model). """ + + config_loader = get_config_loader() + _, caching_repo = config_loader._parse_overrides_and_create_caching_repo( + config_name=None, overrides=[] + ) + config_result = caching_repo.load_config(f"{config_group}/{config_name}.yaml") + if config_result is not None: + try: + return hydra_zen.get_target(config_result.config) # type: ignore + except TypeError: + pass + config_node = _cs._load(f"{config_group}/{config_name}.yaml") + if "_target_" in config_node.node: target: str = config_node.node["_target_"] module_name, _, class_name = target.rpartition(".")