Merge branch 'dev' into main

nmichlo committed Nov 11, 2021
2 parents 3276d57 + a127b24 commit 82cd508

Showing 5 changed files with 157 additions and 77 deletions.
68 changes: 41 additions & 27 deletions README.md
@@ -90,17 +90,21 @@ Please use the following citation if you use Disent in your own research:

## Architecture

The disent directory structure:

- `disent/dataset`: dataset wrappers, datasets & sampling strategies
  + `disent/dataset/data`: raw datasets
  + `disent/dataset/sampling`: sampling strategies for `DisentDataset`
- `disent/framework`: frameworks, including Auto-Encoders and VAEs
- `disent/metric`: metrics for evaluating disentanglement using ground truth datasets
- `disent/model`: common encoder and decoder models used for VAE research
- `disent/nn`: torch components for building models including layers, transforms, losses and general maths
- `disent/schedule`: annealing schedules that can be registered to a framework
- `disent/util`: helper classes, functions, callbacks, anything unrelated to a pytorch system/model/framework.
The disent module structure:

- `disent.dataset`: dataset wrappers, datasets & sampling strategies
  + `disent.dataset.data`: raw datasets
  + `disent.dataset.sampling`: sampling strategies for `DisentDataset` when multiple elements are required by frameworks, e.g. for triplet loss
  + `disent.dataset.transform`: common data transforms and augmentations
  + `disent.dataset.wrapper`: wrapped datasets are no longer ground-truth datasets; these may have some elements masked out. We can still unwrap these classes to obtain the original datasets for benchmarking.
- `disent.frameworks`: frameworks, including Auto-Encoders and VAEs
  + `disent.frameworks.ae`: Auto-Encoder based frameworks
  + `disent.frameworks.vae`: Variational Auto-Encoder based frameworks
- `disent.metrics`: metrics for evaluating disentanglement using ground truth datasets
- `disent.model`: common encoder and decoder models used for VAE research
- `disent.nn`: torch components for building models including layers, transforms, losses and general maths
- `disent.schedule`: annealing schedules that can be registered to a framework
- `disent.util`: helper classes, functions, callbacks, anything unrelated to a pytorch system/model/framework.
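
For orientation, here is a minimal sketch of how a few of these modules compose. This snippet is not part of this commit's diff; the names are the same ones used in the basic example further down:

```python3
# compose a dataset pipeline from the modules listed above:
# - a raw ground-truth dataset from `disent.dataset.data`
# - a sampling strategy from `disent.dataset.sampling`
# - a transform from `disent.dataset.transform`
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.dataset.transform import ToImgTensorF32

dataset = DisentDataset(
    dataset=XYObjectData(),      # raw dataset: a single square with varying factors
    sampler=SingleSampler(),     # frameworks that need pairs/triplets swap this out
    transform=ToImgTensorF32(),  # numpy image -> float32 torch tensor, with checks
)
```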

**Please Note The API Is Still Unstable ⚠️**

@@ -113,6 +117,7 @@ Easily run experiments with hydra config, these files
are not available from `pip install`.

- `experiment/run.py`: entrypoint for running basic experiments with [hydra](https://github.com/facebookresearch/hydra) config
- `experiment/config/config.yaml`: main configuration file; this is probably what you want to edit!
- `experiment/config`: root folder for [hydra](https://github.com/facebookresearch/hydra) config files
- `experiment/util`: various helper code for experiments

@@ -182,6 +187,8 @@ low-memory disk-based access.

- **Ground Truth Synthetic**:
  + 🧵 XYObject: *A simplistic version of dSprites with a single square.*
  + 🧵 XYObjectShaded: *The exact same dataset as XYObject, but the ground-truth factors have a different representation.*
  + 🧵 DSpritesImagenet: *A version of dSprites with the foreground or background deterministically masked out using tiny-imagenet data.*
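
A rough sketch of inspecting these datasets (this assumes `factor_names` and `factor_sizes` attributes on the ground-truth data classes, which are not shown in this diff):

```python3
from disent.dataset.data import XYObjectData

data = XYObjectData()
print(data.factor_names)  # assumed attribute: names of the ground-truth factors
print(data.factor_sizes)  # assumed attribute: number of discrete values per factor
print(data.x_shape)       # observation shape, used to size encoders & decoders
```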

<p align="center">
<img width="384" src="docs/img/xy-object-traversal.png" alt="XYObject Dataset Factor Traversals">
@@ -211,26 +218,27 @@ The currently implemented schedules include:
The following is a basic working example of disent that trains a BetaVAE with a cyclic
beta schedule and evaluates the trained model with various metrics.

<details><summary><b>Basic Example</b></summary>
<details><summary><b>💾 Basic Example</b></summary>
<p>

```python3
import os
import pytorch_lightning as pl
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.dataset.transform import ToImgTensorF32
from disent.frameworks.vae import BetaVae
from disent.metrics import metric_dci, metric_mig
from disent.metrics import metric_dci
from disent.metrics import metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.model.ae import DecoderConv64
from disent.model.ae import EncoderConv64
from disent.schedule import CyclicSchedule


# create the dataset & dataloaders
# - ToImgTensorF32 transforms images from numpy arrays to tensors and performs checks
data = XYObjectData()
@@ -246,8 +254,10 @@ module = BetaVae(
        decoder=DecoderConv64(x_shape=data.x_shape, z_size=10),
    ),
    cfg=BetaVae.cfg(
        optimizer='adam', optimizer_kwargs=dict(lr=1e-3),
        loss_reduction='mean_sum', beta=4,
        optimizer='adam',
        optimizer_kwargs=dict(lr=1e-3),
        loss_reduction='mean_sum',
        beta=4,
    )
)

```

@@ -301,26 +311,30 @@ a defaults list with entries corresponding to yaml configuration
files (config options) in the subfolders (config groups) in
`experiment/config/<config_group>/<option>.yaml`.

<details><summary><b>Config Defaults Example</b></summary>
<details><summary><b>💾 Config Defaults Example</b></summary>
<p>

```yaml
defaults:
  # data
  - sampling: default__bb
  - dataset: xyobject
  - augment: none
  # system
  - framework: adavae_os
  - model: vae_conv64
  # training
  - optimizer: adam
  - schedule: none
  # data
  - dataset: xyobject
  - sampling: default__bb
  - augment: none
  # runtime
  - schedule: beta_cyclic
  - metrics: fast
  - run_length: short
  - run_location: local
  # logs
  - run_callbacks: vis
  - run_logging: wandb
  # runtime
  - run_location: local
  - run_launcher: local
  - run_action: train

# <rest of config.yaml left out>
...
```
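
As a rough sketch (assuming hydra's `initialize`/`compose` API from hydra-core >= 1.1; the override values below are illustrative), the same defaults list can be composed from Python instead of the CLI:

```python3
# compose the experiment config programmatically -- illustrative only
from hydra import compose, initialize

with initialize(config_path="experiment/config"):
    # each `group=option` override selects `experiment/config/<group>/<option>.yaml`
    cfg = compose(config_name="config", overrides=["framework=adavae_os", "schedule=none"])

print(list(cfg.keys()))  # actual structure depends on the yaml files above
```
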
67 changes: 67 additions & 0 deletions docs/examples/readme_example.py
@@ -0,0 +1,67 @@
import os
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.dataset.transform import ToImgTensorF32
from disent.frameworks.vae import BetaVae
from disent.metrics import metric_dci
from disent.metrics import metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64
from disent.model.ae import EncoderConv64
from disent.schedule import CyclicSchedule

# create the dataset & dataloaders
# - ToImgTensorF32 transforms images from numpy arrays to tensors and performs checks
data = XYObjectData()
dataset = DisentDataset(dataset=data, sampler=SingleSampler(), transform=ToImgTensorF32())
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True, num_workers=os.cpu_count())

# create the BetaVAE model
# - adjusting the beta, learning rate, and representation size.
module = BetaVae(
    model=AutoEncoder(
        # z_multiplier is needed to output mu & logvar when parameterising normal distribution
        encoder=EncoderConv64(x_shape=data.x_shape, z_size=10, z_multiplier=2),
        decoder=DecoderConv64(x_shape=data.x_shape, z_size=10),
    ),
    cfg=BetaVae.cfg(
        optimizer='adam',
        optimizer_kwargs=dict(lr=1e-3),
        loss_reduction='mean_sum',
        beta=4,
    )
)

# cyclic schedule for target 'beta' in the config/cfg. The initial value from the
# config is saved and multiplied by the ratio from the schedule on each step.
# - based on: https://arxiv.org/abs/1903.10145
module.register_schedule(
    'beta', CyclicSchedule(
        period=1024,  # repeat every: trainer.global_step % period
    )
)

# train model
# - for 2048 batches/steps
trainer = pl.Trainer(
    max_steps=2048, gpus=1 if torch.cuda.is_available() else None, logger=False, checkpoint_callback=False
)
trainer.fit(module, dataloader)

# compute disentanglement metrics
# - we cannot guarantee which device the representation is on
# - this will take a while to run
get_repr = lambda x: module.encode(x.to(module.device))

metrics = {
    **metric_dci(dataset, get_repr, num_train=1000, num_test=500, show_progress=True),
    **metric_mig(dataset, get_repr, num_train=2000),
}

# evaluate
print('metrics:', metrics)
77 changes: 35 additions & 42 deletions experiment/run.py
@@ -34,8 +34,6 @@
from omegaconf import DictConfig
from omegaconf import ListConfig
from omegaconf import OmegaConf
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.loggers import LoggerCollection
from pytorch_lightning.loggers import WandbLogger

from disent import metrics
@@ -45,12 +43,12 @@
from disent.nn.weights import init_model_weights
from disent.util.seeds import seed
from disent.util.strings.fmt import make_box_str
from disent.util.strings import colors as c
from disent.util.lightning.callbacks import LoggerProgressCallback
from disent.util.lightning.callbacks import VaeMetricLoggingCallback
from disent.util.lightning.callbacks import VaeLatentCycleLoggingCallback
from disent.util.lightning.callbacks import VaeGtDistsLoggingCallback
from experiment.util.hydra_data import HydraDataModule
from experiment.util.hydra_utils import make_non_strict
from experiment.util.run_utils import log_error_and_exit
from experiment.util.run_utils import safe_unset_debug_logger
from experiment.util.run_utils import safe_unset_debug_trainer
@@ -244,11 +242,9 @@ def hydra_create_and_update_framework_config(cfg) -> DisentConfigurable.cfg:
    # warn if some of the cfg variables were not overridden
    missing_keys = sorted(set(framework_cfg.get_keys()) - (set(cfg.framework.cfg.keys())))
    if missing_keys:
        log.error(f'Framework {repr(cfg.framework.name)} is missing config keys for:')
        log.warning(f'{c.RED}Framework {repr(cfg.framework.name)} is missing config keys for:{c.RST}')
        for k in missing_keys:
            log.error(f'{repr(k)}')
    # update config params in case we missed variables in the cfg
    cfg.framework.cfg.update(framework_cfg.to_dict())
            log.warning(f'{c.RED}{repr(k)}{c.RST}')
    # return config
    return framework_cfg

@@ -281,37 +277,33 @@ def hydra_create_framework(framework_cfg: DisentConfigurable.cfg, datamodule, cf
# ========================================================================= #


def prepare_data(cfg: DictConfig, config_path: str = None):
def action_prepare_data(cfg: DictConfig):
    # get the time the run started
    time_string = datetime.today().strftime('%Y-%m-%d--%H-%M-%S')
    log.info(f'Starting run at time: {time_string}')
    raise NotImplementedError

    # # allow the cfg to be edited
    # cfg = make_non_strict(cfg)
    # # deterministic seed
    # seed(cfg.job.setdefault('seed', None))
    # # print useful info
    # log.info(f"Current working directory : {os.getcwd()}")
    # log.info(f"Orig working directory : {hydra.utils.get_original_cwd()}")
    # # hydra config does not support variables in defaults lists, we handle this manually
    # cfg = merge_specializations(cfg, config_path=CONFIG_PATH if (config_path is None) else config_path, required=['_dataset_sampler_'])
    # # check data preparation
    # prepare_data_per_node = cfg.trainer.setdefault('prepare_data_per_node', True)
    # hydra_check_datadir(prepare_data_per_node, cfg)
    # # print the config
    # log.info(f'Dataset Config Is:\n{make_box_str(OmegaConf.to_yaml({"dataset": cfg.dataset}))}')
    # # prepare data
    # datamodule = HydraDataModule(cfg)
    # datamodule.prepare_data()


def train(cfg: DictConfig, config_path: str = None):
    # deterministic seed
    seed(cfg.settings.job.seed)
    # print useful info
    log.info(f"Current working directory : {os.getcwd()}")
    log.info(f"Orig working directory : {hydra.utils.get_original_cwd()}")
    # check data preparation
    hydra_check_data_paths(cfg)
    # print the config
    log.info(f'Dataset Config Is:\n{make_box_str(OmegaConf.to_yaml({"dataset": cfg.dataset}))}')
    # prepare data
    datamodule = HydraDataModule(cfg)
    datamodule.prepare_data()


def action_train(cfg: DictConfig):

    # get the time the run started
    time_string = datetime.today().strftime('%Y-%m-%d--%H-%M-%S')
    log.info(f'Starting run at time: {time_string}')

    # print initial config
    log.info(f'Initial Config For Action: {cfg.action}\n\nCONFIG:{make_box_str(OmegaConf.to_yaml(cfg), char_v=":", char_h=".")}')

    # -~-~-~-~-~-~-~-~-~-~-~-~- #

    # cleanup from old runs:
@@ -324,9 +316,6 @@ def train(cfg: DictConfig, config_path: str = None):

    # -~-~-~-~-~-~-~-~-~-~-~-~- #

    # allow the cfg to be edited
    cfg = make_non_strict(cfg)

    # deterministic seed
    seed(cfg.settings.job.seed)

@@ -404,11 +393,21 @@

# available actions
ACTIONS = {
    'prepare_data': prepare_data,
    'train': train,
    'prepare_data': action_prepare_data,
    'train': action_train,
}


def run_action(cfg: DictConfig):
    action_key = cfg.action
    # get the action
    if action_key not in ACTIONS:
        raise KeyError(f'The given action: {repr(action_key)} is invalid, must be one of: {sorted(ACTIONS.keys())}')
    action = ACTIONS[action_key]
    # run the action
    action(cfg)


# ========================================================================= #
# MAIN #
# ========================================================================= #
@@ -436,13 +435,7 @@ def _error_resolver(msg: str):
@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def hydra_main(cfg: DictConfig):
    try:
        action_key = cfg.action
        # get the action
        if action_key not in ACTIONS:
            raise KeyError(f'The given action: {repr(action_key)} is invalid, must be one of: {sorted(ACTIONS.keys())}')
        action = ACTIONS[action_key]
        # run the action
        action(cfg)
        run_action(cfg)
    except Exception as e:
        log_error_and_exit(err_type='experiment error', err_msg=str(e), exc_info=True)
    except:
2 changes: 1 addition & 1 deletion setup.py
@@ -48,7 +48,7 @@
    author="Nathan Juraj Michlo",
    author_email="[email protected]",

    version="0.3.0",
    version="0.3.1",
    python_requires=">=3.8",  # we make use of standard library features only in 3.8
    packages=setuptools.find_packages(),

20 changes: 13 additions & 7 deletions tests/test_experiment.py
@@ -26,22 +26,28 @@
import os.path

import hydra
import pytest

import experiment.run as experiment_run
from tests.util import temp_sys_args


# ========================================================================= #
# TESTS #
# ========================================================================= #


def test_experiment_run():
    # used by run() internally
    experiment_run.CONFIG_PATH = os.path.join(os.path.dirname(experiment_run.__file__), 'config')

@pytest.mark.parametrize('args', [
    ['run_action=prepare_data'],
    ['run_action=train'],
])
def test_experiment_run(args):
    os.environ['HYDRA_FULL_ERROR'] = '1'
    with temp_sys_args([experiment_run.__file__]):
        # why does this not work when config is absolute?
        hydra_main = hydra.main(config_path='config', config_name='config_test')(experiment_run.train)

    # TODO: why does this not work when config_path is absolute?
    # ie. config_path=os.path.join(os.path.dirname(experiment_run.__file__), 'config')
    with temp_sys_args([experiment_run.__file__, *args]):
        hydra_main = hydra.main(config_path='config', config_name='config_test')(experiment_run.run_action)
        hydra_main()


