diff --git a/project/configs/datamodule/__init__.py b/project/configs/datamodule/__init__.py index b8bd32a5..d9b68bc5 100644 --- a/project/configs/datamodule/__init__.py +++ b/project/configs/datamodule/__init__.py @@ -1,21 +1,9 @@ from logging import getLogger as get_logger -from pathlib import Path from hydra_zen import store -from project.utils.env_vars import NETWORK_DIR - logger = get_logger(__name__) -torchvision_dir: Path | None = None -"""Network directory with torchvision datasets.""" -if ( - NETWORK_DIR - and (_torchvision_dir := NETWORK_DIR / "datasets/torchvision").exists() - and _torchvision_dir.is_dir() -): - torchvision_dir = _torchvision_dir - # TODO: Make it possible to extend a structured base via yaml files as well as adding new fields # (for example, ImagetNet32DataModule has a new constructor argument which can't be set atm in the diff --git a/project/configs/datamodule/cifar10.yaml b/project/configs/datamodule/cifar10.yaml index e8d3fb78..0ca045ea 100644 --- a/project/configs/datamodule/cifar10.yaml +++ b/project/configs/datamodule/cifar10.yaml @@ -1,6 +1,7 @@ defaults: - vision _target_: project.datamodules.CIFAR10DataModule +data_dir: ${constant:torchvision_dir,DATA_DIR} batch_size: 128 train_transforms: _target_: project.datamodules.image_classification.cifar10.cifar10_train_transforms diff --git a/project/configs/datamodule/mnist.yaml b/project/configs/datamodule/mnist.yaml index c9a16639..a7554ec3 100644 --- a/project/configs/datamodule/mnist.yaml +++ b/project/configs/datamodule/mnist.yaml @@ -1,6 +1,7 @@ defaults: - vision _target_: project.datamodules.MNISTDataModule +data_dir: ${constant:torchvision_dir,DATA_DIR} normalize: True batch_size: 128 train_transforms: diff --git a/project/configs/experiment/example.yaml b/project/configs/experiment/example.yaml index 591dfbf7..e11f37c3 100644 --- a/project/configs/experiment/example.yaml +++ b/project/configs/experiment/example.yaml @@ -9,18 +9,14 @@ defaults: - override /network: resnet18 - override /trainer: default - override /trainer/callbacks: default + - override /trainer/logger: wandb # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters +name: example seed: ${oc.env:SLURM_PROCID,12345} -trainer: - min_epochs: 1 - max_epochs: 10 - gradient_clip_val: 0.5 - - algorithm: hp: optimizer: @@ -29,4 +25,21 @@ algorithm: datamodule: batch_size: 64 -name: example + +trainer: + min_epochs: 1 + max_epochs: 10 + gradient_clip_val: 0.5 + logger: + wandb: + project: "ResearchTemplate" + name: ${oc.env:SLURM_JOB_ID}_${oc.env:SLURM_PROCID} + save_dir: "${hydra:runtime.output_dir}" + offline: False # set True to store all logs only locally + id: ${oc.env:SLURM_JOB_ID}_${oc.env:SLURM_PROCID} # pass correct id to resume experiment! + # entity: "" # set to name of your wandb team + log_model: False + prefix: "" + job_type: "train" + group: ${oc.env:SLURM_JOB_ID} + tags: ["${name}"] diff --git a/project/utils/env_vars.py b/project/utils/env_vars.py index 1246b35f..2c8336eb 100644 --- a/project/utils/env_vars.py +++ b/project/utils/env_vars.py @@ -1,8 +1,13 @@ +import importlib import os +from logging import getLogger as get_logger from pathlib import Path import torch +logger = get_logger(__name__) + + SLURM_JOB_ID: int | None = ( int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None ) @@ -69,9 +74,37 @@ """Local Directory where datasets should be extracted on this machine.""" -def get_constant(name: str): +torchvision_dir: Path | None = None +"""Network directory with torchvision datasets.""" +if ( + NETWORK_DIR + and (_torchvision_dir := NETWORK_DIR / "datasets/torchvision").exists() + and _torchvision_dir.is_dir() +): + torchvision_dir = _torchvision_dir + + +def get_constant(*names: str): """Resolver for Hydra to get the value of a constant in this file.""" - return globals()[name] + assert names + for name in names: + if name in globals(): + obj = globals()[name] + if obj is None: + logger.debug(f"Value of {name} is None, moving on to the next value.") + continue + return obj + parts = name.split(".") + obj = importlib.import_module(parts[0]) + for part in parts[1:]: + obj = getattr(obj, part) + if obj is not None: + return obj + logger.debug(f"Value of {name} is None, moving on to the next value.") + + if len(names) == 1: + raise RuntimeError(f"Could not find non-None value for name {names[0]}") + raise RuntimeError(f"Could not find non-None value for names {names}") NUM_WORKERS = int(