Skip to content

Commit

Permalink
Fix issue in cifar10, add note about protocol
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Nov 19, 2024
1 parent 965dfef commit bc3b1d2
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 19 deletions.
3 changes: 1 addition & 2 deletions project/datamodules/image_classification/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from project.datamodules.image_classification.image_classification import (
ImageClassificationDataModule,
)
from project.datamodules.vision import VisionDataModule
from project.utils.typing_utils import C, H, W


Expand Down Expand Up @@ -40,7 +39,7 @@ def cifar10_unnormalization(x: torch.Tensor) -> torch.Tensor:
return (x * std) + mean


class CIFAR10DataModule(ImageClassificationDataModule, VisionDataModule):
class CIFAR10DataModule(ImageClassificationDataModule):
"""
.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/
Plot-of-a-Subset-of-Images-from-the-CIFAR-10-Dataset.png
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@
ImageBatchType = TypeVar("ImageBatchType", bound=tuple[Image, Tensor])


# todo: this should probably be a protocol. The only issue with that is that we do `issubclass` in
# tests to determine which datamodule configs are for image classification, so we can't do that
# with a Protocol.


class ImageClassificationDataModule(
VisionDataModule[ImageBatchType], ClassificationDataModule[ImageBatchType]
):
"""Lightning data modules for image classification."""

# This just adds the `num_classes` property to `VisionDataModule`.

num_classes: int
"""Number of classes in the dataset."""

Expand Down
6 changes: 4 additions & 2 deletions project/datamodules/image_classification/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from torchvision.models.resnet import ResNet152_Weights
from torchvision.transforms import v2 as transforms

from project.datamodules.vision import VisionDataModule
from project.datamodules.image_classification.image_classification import (
ImageClassificationDataModule,
)
from project.utils.env_vars import DATA_DIR, NETWORK_DIR, NUM_WORKERS
from project.utils.typing_utils import C, H, W

Expand All @@ -36,7 +38,7 @@ def imagenet_normalization():
ImageIndex = NewType("ImageIndex", int)


class ImageNetDataModule(VisionDataModule):
class ImageNetDataModule(ImageClassificationDataModule):
"""ImageNet datamodule.
Extracted from https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/datamodules/imagenet_datamodule.py
Expand Down
6 changes: 4 additions & 2 deletions project/datamodules/image_classification/imagenet32.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from torchvision.datasets import VisionDataset
from torchvision.transforms import v2 as transforms

from project.datamodules.vision import VisionDataModule
from project.datamodules.image_classification.image_classification import (
ImageClassificationDataModule,
)
from project.utils.env_vars import DATA_DIR, SCRATCH
from project.utils.typing_utils import C, H, W

Expand Down Expand Up @@ -167,7 +169,7 @@ def _load_dataset(self):
self._data_loaded = True


class ImageNet32DataModule(VisionDataModule):
class ImageNet32DataModule(ImageClassificationDataModule):
"""TODO: Add a `val_split` argument, that supports a value of `0`."""

name: ClassVar[str] = "imagenet32"
Expand Down
6 changes: 2 additions & 4 deletions project/datamodules/image_classification/inaturalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
import torchvision.transforms as T
from torchvision.datasets import INaturalist

from project.datamodules.image_classification.image_classification import (
ImageClassificationDataModule,
)
from project.datamodules.vision import VisionDataModule
from project.utils.env_vars import DATA_DIR, NUM_WORKERS, SLURM_TMPDIR
from project.utils.typing_utils import C, H, W

Expand All @@ -34,7 +32,7 @@ def inat_dataset_dir() -> Path:
return network_dir


class INaturalistDataModule(ImageClassificationDataModule):
class INaturalistDataModule(VisionDataModule):
name: ClassVar[str] = "inaturalist"
"""Dataset name."""

Expand Down
26 changes: 17 additions & 9 deletions project/utils/hydra_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,23 @@ def get_all_configs_in_group_of_type(
)
}

return [
name
for name, object_type in names_to_types.items()
if (
issubclass(object_type, config_target_type)
if include_subclasses
else object_type in config_target_type
)
]
def _matches_protocol(object: type, protocol: type) -> bool:
return isinstance(object, protocol) # todo: weird!

compatible_config_names = []
for name, object_type in names_to_types.items():
if not include_subclasses:
if object_type in config_target_type:
compatible_config_names.append(name)
continue
for t in config_target_type:
if (
issubclass(t, typing.Protocol) and _matches_protocol(object_type, t)
) or issubclass(object_type, t):
compatible_config_names.append(name)
break

return compatible_config_names


def get_all_configs_in_group_with_target(group_name: str, some_type: type) -> list[str]:
Expand Down

0 comments on commit bc3b1d2

Please sign in to comment.