Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweak datamodule configs to use torchvision dir #37

Merged
merged 2 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions project/configs/datamodule/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,9 @@
from logging import getLogger as get_logger
from pathlib import Path

from hydra_zen import store

from project.utils.env_vars import NETWORK_DIR

logger = get_logger(__name__)

torchvision_dir: Path | None = None
"""Network directory with torchvision datasets."""
# NOTE(review): NETWORK_DIR is imported above from project.utils.env_vars; when it
# is falsy (unset) or the "datasets/torchvision" subdirectory does not exist,
# torchvision_dir is left as None.
if (
    NETWORK_DIR
    # Walrus-bind the candidate path so it can be reused in the next clause
    # and in the assignment below.
    and (_torchvision_dir := NETWORK_DIR / "datasets/torchvision").exists()
    and _torchvision_dir.is_dir()
):
    torchvision_dir = _torchvision_dir


# TODO: Make it possible to extend a structured base via yaml files as well as adding new fields
# (for example, ImagetNet32DataModule has a new constructor argument which can't be set atm in the
Expand Down
1 change: 1 addition & 0 deletions project/configs/datamodule/cifar10.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
defaults:
- vision
_target_: project.datamodules.CIFAR10DataModule
data_dir: ${constant:torchvision_dir,DATA_DIR}
batch_size: 128
train_transforms:
_target_: project.datamodules.image_classification.cifar10.cifar10_train_transforms
1 change: 1 addition & 0 deletions project/configs/datamodule/mnist.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
defaults:
- vision
_target_: project.datamodules.MNISTDataModule
data_dir: ${constant:torchvision_dir,DATA_DIR}
normalize: True
batch_size: 128
train_transforms:
Expand Down
27 changes: 20 additions & 7 deletions project/configs/experiment/example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,14 @@ defaults:
- override /network: resnet18
- override /trainer: default
- override /trainer/callbacks: default
- override /trainer/logger: wandb

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
name: example

seed: ${oc.env:SLURM_PROCID,12345}

trainer:
min_epochs: 1
max_epochs: 10
gradient_clip_val: 0.5


algorithm:
hp:
optimizer:
Expand All @@ -29,4 +25,21 @@ algorithm:
datamodule:
batch_size: 64

name: example

trainer:
min_epochs: 1
max_epochs: 10
gradient_clip_val: 0.5
logger:
wandb:
project: "ResearchTemplate"
name: ${oc.env:SLURM_JOB_ID}_${oc.env:SLURM_PROCID}
save_dir: "${hydra:runtime.output_dir}"
offline: False # set True to store all logs only locally
id: ${oc.env:SLURM_JOB_ID}_${oc.env:SLURM_PROCID} # pass correct id to resume experiment!
# entity: "" # set to name of your wandb team
log_model: False
prefix: ""
job_type: "train"
group: ${oc.env:SLURM_JOB_ID}
tags: ["${name}"]
37 changes: 35 additions & 2 deletions project/utils/env_vars.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import importlib
import os
from logging import getLogger as get_logger
from pathlib import Path

import torch

logger = get_logger(__name__)


SLURM_JOB_ID: int | None = (
int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None
)
Expand Down Expand Up @@ -69,9 +74,37 @@
"""Local Directory where datasets should be extracted on this machine."""


def get_constant(name: str):
torchvision_dir: Path | None = None
"""Network directory with torchvision datasets."""
# Probe the shared network filesystem for a pre-populated torchvision dataset
# directory; leave the constant as None when it cannot be found.
_tv_candidate = NETWORK_DIR / "datasets/torchvision" if NETWORK_DIR else None
if _tv_candidate is not None and _tv_candidate.exists() and _tv_candidate.is_dir():
    torchvision_dir = _tv_candidate


def get_constant(*names: str):
    """Resolver for Hydra to get the value of a constant in this file.

    Each name in `names` is tried in order:

    1. If the name is a global of this module, its value is used (unless it is
       None, in which case the next name is tried).
    2. Otherwise the name is treated as a dotted import path (e.g. ``"os.sep"``):
       the first component is imported and the remaining components are resolved
       as attributes.

    The first non-None value found is returned.

    Raises:
        ValueError: If called with no names at all.
        RuntimeError: If every name is missing, unresolvable, or resolves to None.
    """
    if not names:
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
        raise ValueError("get_constant requires at least one name.")
    for name in names:
        if name in globals():
            obj = globals()[name]
            if obj is None:
                logger.debug(f"Value of {name} is None, moving on to the next value.")
                continue
            return obj
        # Fall back to interpreting `name` as a dotted path into an importable module.
        parts = name.split(".")
        try:
            obj = importlib.import_module(parts[0])
            for part in parts[1:]:
                obj = getattr(obj, part)
        except (ImportError, AttributeError):
            # An unresolvable name should not abort the lookup: skip it so the
            # remaining fallback names can still be tried.
            logger.debug(f"Could not resolve {name}, moving on to the next value.")
            continue
        if obj is not None:
            return obj
        logger.debug(f"Value of {name} is None, moving on to the next value.")

    if len(names) == 1:
        raise RuntimeError(f"Could not find non-None value for name {names[0]}")
    raise RuntimeError(f"Could not find non-None value for names {names}")


NUM_WORKERS = int(
Expand Down
Loading