
Commit

Merge branch 'dev' into main
nmichlo committed Oct 4, 2021
2 parents 94cf05f + 68db787 commit 803ac3a
Showing 47 changed files with 532 additions and 143 deletions.
9 changes: 5 additions & 4 deletions .github/workflows/python-test.yml
@@ -5,18 +5,18 @@ name: test

on:
push:
branches: [ main, dev ]
branches: [ "main", "dev", "dev*", "feature*"]
tags: [ '*' ]
pull_request:
branches: [ main, dev ]
branches: [ "main", "dev", "dev*", "feature*"]

jobs:
build:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest] # [ubuntu-latest, windows-latest, macos-latest]
python-version: [3.8]
python-version: ["3.8", "3.9"]

steps:
- uses: actions/checkout@v2
@@ -31,6 +31,7 @@ jobs:
python3 -m pip install --upgrade pip
python3 -m pip install -r requirements.txt
python3 -m pip install -r requirements-test.txt
python3 -m pip install -r requirements-exp.txt
- name: Test with pytest
run: |
@@ -39,6 +40,6 @@ jobs:
- uses: codecov/codecov-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
fail_ci_if_error: false
# codecov automatically merges all generated files
# if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9
8 changes: 5 additions & 3 deletions README.md
@@ -245,13 +245,15 @@ dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True, num_worke
# create the BetaVAE model
# - adjusting the beta, learning rate, and representation size.
module = BetaVae(
make_optimizer_fn=lambda params: Adam(params, lr=1e-4),
make_model_fn=lambda: AutoEncoder(
model=AutoEncoder(
# z_multiplier is needed to output mu & logvar when parameterising normal distribution
encoder=EncoderConv64(x_shape=data.x_shape, z_size=10, z_multiplier=2),
decoder=DecoderConv64(x_shape=data.x_shape, z_size=10),
),
cfg=BetaVae.cfg(loss_reduction='mean_sum', beta=4)
cfg=BetaVae.cfg(
optimizer='adam', optimizer_kwargs=dict(lr=1e-3),
loss_reduction='mean_sum', beta=4,
)
)

# cyclic schedule for target 'beta' in the config/cfg. The initial value from the
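The README snippet above is cut off at the comment about the cyclic schedule for `beta`. Below is a rough, hedged sketch of how such a schedule might be attached under the new API; the `CyclicSchedule` class, its `period` argument, and the `register_schedule` method are assumptions inferred from the comment and from the `_registered_schedules` attribute added in `disent/frameworks/_framework.py` further down, not confirmed by this diff.

```python
# hedged sketch only: attach a cyclic schedule to the 'beta' value in the config.
# `CyclicSchedule` and `register_schedule` are assumed names (see lead-in above).
from disent.schedule import CyclicSchedule

module.register_schedule(
    'beta',
    CyclicSchedule(period=1024),  # hypothetical period, in training steps
)
```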
2 changes: 1 addition & 1 deletion disent/dataset/__init__.py
@@ -23,4 +23,4 @@
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

# wrapper
from disent.dataset._wrapper import DisentDataset
from disent.dataset._base import DisentDataset
43 changes: 28 additions & 15 deletions disent/dataset/_wrapper.py → disent/dataset/_base.py
@@ -67,13 +67,21 @@ def wrapper(self: 'DisentDataset', *args, **kwargs):

class DisentDataset(Dataset, LengthIter):

def __init__(self, dataset: Union[Dataset, GroundTruthData], sampler: Optional[BaseDisentSampler] = None, transform=None, augment=None):
def __init__(
self,
dataset: Union[Dataset, GroundTruthData],
sampler: Optional[BaseDisentSampler] = None,
transform=None,
augment=None,
return_indices: bool = False,
):
super().__init__()
# save attributes
self._dataset = dataset
self._sampler = SingleSampler() if (sampler is None) else sampler
self._transform = transform
self._augment = augment
self._return_indices = return_indices
# initialize sampler
if not self._sampler.is_init:
self._sampler.init(dataset)
@@ -112,7 +120,7 @@ def __getitem__(self, idx):
else:
idxs = (idx,)
# get the observations
return self.dataset_get_observation(*idxs)
return self._dataset_get_observation(*idxs)

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Single Datapoints #
@@ -177,19 +185,18 @@ def dataset_get(self, idx, mode: str):
# Multiple Datapoints #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

def dataset_get_observation(self, *idxs):
def _dataset_get_observation(self, *idxs):
xs, xs_targ = zip(*(self.dataset_get(idx, mode='pair') for idx in idxs))
# handle cases
if self._augment is None:
# makes 5-10% faster
return {
'x_targ': xs_targ,
}
else:
return {
'x': xs,
'x_targ': xs_targ,
}
obs = {'x_targ': xs_targ}
# 5-10% faster
if self._augment is not None:
obs['x'] = xs
# add indices
if self._return_indices:
obs['idx'] = idxs
# done!
return obs

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Batches #
@@ -199,7 +206,7 @@ def dataset_batch_from_indices(self, indices: Sequence[int], mode: str):
"""Get a batch of observations X from a batch of factors Y."""
return default_collate([self.dataset_get(idx, mode=mode) for idx in indices])

def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = False):
def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = False, return_indices: bool = False):
"""Sample a batch of observations X."""
# create seeded pseudo random number generator
# - built in np.random.choice cannot handle large values: https://github.com/numpy/numpy/issues/5299#issuecomment-497915672
@@ -208,7 +215,13 @@ def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = Fals
g = np.random.Generator(np.random.PCG64(seed=np.random.randint(0, 2**32)))
# sample indices
indices = g.choice(len(self), num_samples, replace=replace)
return self.dataset_batch_from_indices(indices, mode=mode)
# return batch
batch = self.dataset_batch_from_indices(indices, mode=mode)
# return values
if return_indices:
return batch, default_collate(indices)
else:
return batch

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Batches -- Ground Truth Only #
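A hedged usage sketch of the two `return_indices` additions above; the dataset class `XYObjectData` and the `'input'` mode string are assumptions used only for illustration.

```python
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData  # assumed example ground-truth dataset

data = XYObjectData()
dataset = DisentDataset(data, return_indices=True)

# single observations now also carry the sampled indices under 'idx'
obs = dataset[0]
print(sorted(obs.keys()))  # ['idx', 'x_targ']  ('x' only appears when augment is set)

# sampled batches can also return the collated indices they were drawn from
batch, indices = dataset.dataset_sample_batch(16, mode='input', return_indices=True)
```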
19 changes: 19 additions & 0 deletions disent/dataset/data/_groundtruth.py
@@ -25,9 +25,12 @@
import logging
import os
from abc import ABCMeta
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import numpy as np
from torch.utils.data import Dataset
@@ -100,6 +103,22 @@ def __getitem__(self, idx):
def _get_observation(self, idx):
raise NotImplementedError

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# EXTRAS #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

def sample_random_obs_traversal(self, f_idx: int = None, base_factors=None, num: int = None, mode='interval', obs_collect_fn=None) -> Tuple[np.ndarray, np.ndarray, Union[List[Any], Any]]:
"""
Same API as sample_random_factor_traversal, but also
returns the corresponding indices and uncollated list of observations
"""
factors = self.sample_random_factor_traversal(f_idx=f_idx, base_factors=base_factors, num=num, mode=mode)
indices = self.pos_to_idx(factors)
obs = [self[i] for i in indices]
if obs_collect_fn is not None:
obs = obs_collect_fn(obs)
return factors, indices, obs


# ========================================================================= #
# Basic Array Ground Truth Dataset #
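A hedged sketch of calling the new `sample_random_obs_traversal` helper; the concrete dataset and the use of `np.stack` as the collect function are illustrative assumptions.

```python
import numpy as np
from disent.dataset.data import XYObjectData  # assumed example ground-truth dataset

data = XYObjectData()

# traverse a randomly chosen factor: returns the factor matrix, the flat indices,
# and the (uncollated) list of observations
factors, indices, obs = data.sample_random_obs_traversal(mode='interval')

# optionally collate the observations, e.g. stack them into a single array
factors, indices, obs = data.sample_random_obs_traversal(
    mode='interval',
    obs_collect_fn=lambda observations: np.stack(observations, axis=0),
)
```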
2 changes: 1 addition & 1 deletion disent/dataset/data/_groundtruth__norb.py
@@ -164,7 +164,7 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, is_te
self._data, _ = read_norb_dataset(dat_path=dat_path, cat_path=cat_path, info_path=info_path)

def _get_observation(self, idx):
return self._data[idx]
return self._data[idx][:, :, None] # data is missing channel dim

@property
def datafiles(self) -> Sequence[DataFileHashedDl]:
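The one-line change above appends the missing channel axis to each greyscale NORB frame. A standalone illustration of the indexing trick, with a hypothetical 96x96 frame:

```python
import numpy as np

frame = np.zeros((96, 96), dtype=np.uint8)  # hypothetical greyscale frame of shape (H, W)
print(frame[:, :, None].shape)              # (96, 96, 1) -- None / np.newaxis adds the channel dim
```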
68 changes: 58 additions & 10 deletions disent/frameworks/_framework.py
@@ -22,7 +22,6 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

import logging
from dataclasses import asdict
from dataclasses import dataclass
from dataclasses import fields
@@ -31,7 +30,9 @@
from typing import Any
from typing import Dict
from typing import final
from typing import Optional
from typing import Tuple
from typing import Type
from typing import Union

import logging
@@ -71,6 +72,26 @@ def __init__(self, cfg: cfg = cfg()):
self.cfg = cfg


# ========================================================================= #
# optimizers #
# ========================================================================= #


def _get_optimizer_list() -> Dict[str, Type[torch.optim.Optimizer]]:
# generate list of optimizers from torch
# - optimizer names are lowercase, eg. adam & rmsprop
optimizers = {}
for k in dir(torch.optim):
optim = getattr(torch.optim, k)
if isinstance(optim, type) and issubclass(optim, torch.optim.Optimizer) and (optim != torch.optim.Optimizer):
optimizers[k.lower()] = optim
return optimizers


# list of optimizers
_OPTIMIZERS = _get_optimizer_list()


# ========================================================================= #
# framework #
# ========================================================================= #
@@ -80,24 +101,51 @@ class DisentFramework(DisentConfigurable, DisentLightningModule):

@dataclass
class cfg(DisentConfigurable.cfg):
pass
# optimizer config
optimizer: Union[str, Type[torch.optim.Optimizer]] = 'adam'
optimizer_kwargs: Optional[Dict[str, Union[str, float, int]]] = None

def __init__(self, make_optimizer_fn, batch_augment=None, cfg: cfg = None):
def __init__(
self,
cfg: cfg = None,
# apply the batch augmentations on the GPU instead
batch_augment: callable = None,
):
# save the config values to the class
super().__init__(cfg=cfg)
# optimiser
assert callable(make_optimizer_fn)
self._make_optimiser_fn = make_optimizer_fn
# batch augmentations: not implemented as dataset transforms because we want to apply these on the GPU
assert (batch_augment is None) or callable(batch_augment)
# get the optimizer
if isinstance(self.cfg.optimizer, str):
if self.cfg.optimizer not in _OPTIMIZERS:
raise KeyError(f'invalid optimizer: {repr(self.cfg.optimizer)}, valid optimizers are: {sorted(_OPTIMIZERS.keys())}, otherwise pass a torch.optim.Optimizer class instead.')
self.cfg.optimizer = _OPTIMIZERS[self.cfg.optimizer]
# check the optimizer values
assert isinstance(self.cfg.optimizer, type) and issubclass(self.cfg.optimizer, torch.optim.Optimizer) and (self.cfg.optimizer != torch.optim.Optimizer)
assert isinstance(self.cfg.optimizer_kwargs, dict) or (self.cfg.optimizer_kwargs is None), f'invalid optimizer_kwargs type, got: {type(self.cfg.optimizer_kwargs)}'
# set default values for optimizer
if self.cfg.optimizer_kwargs is None:
self.cfg.optimizer_kwargs = dict()
if 'lr' not in self.cfg.optimizer_kwargs:
self.cfg.optimizer_kwargs['lr'] = 1e-3
log.info('lr not specified in `optimizer_kwargs`, setting to default value of `1e-3`')
# batch augmentations may not be implemented as dataset
# transforms so we can apply these on the GPU instead
assert callable(batch_augment) or (batch_augment is None)
self._batch_augment = batch_augment
# schedules
# - maybe add support for schedules in the config?
self._registered_schedules = set()
self._active_schedules: Dict[str, Tuple[Any, Schedule]] = {}

@final
def configure_optimizers(self):
# return optimizers
return self._make_optimiser_fn(self.parameters())
optimizer = self.cfg.optimizer
# instantiate the optimizer!
if issubclass(optimizer, torch.optim.Optimizer):
optimizer = optimizer(self.parameters(), **self.cfg.optimizer_kwargs)
elif not isinstance(optimizer, torch.optim.Optimizer):
raise TypeError(f'unsupported optimizer type: {type(optimizer)}')
# return the optimizer
return optimizer

@final
def training_step(self, batch, batch_idx):
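With the changes above, frameworks resolve their optimizer from the config (either a lowercase name such as `'adam'` or a `torch.optim.Optimizer` subclass) instead of a `make_optimizer_fn` callable. Below is a standalone, hedged sketch of what that resolution boils down to; the variable names are illustrative, not part of the library.

```python
import torch

# stand-ins for cfg.optimizer and cfg.optimizer_kwargs
cfg_optimizer = 'adam'        # either a lowercase name or an optimizer class
cfg_optimizer_kwargs = None   # optional kwargs, e.g. dict(lr=1e-4)

# resolve a name to a torch optimizer class (mirrors _get_optimizer_list / _OPTIMIZERS)
optimizers = {
    k.lower(): v for k, v in vars(torch.optim).items()
    if isinstance(v, type) and issubclass(v, torch.optim.Optimizer) and (v is not torch.optim.Optimizer)
}
optimizer_cls = optimizers[cfg_optimizer] if isinstance(cfg_optimizer, str) else cfg_optimizer

# fall back to the default learning rate of 1e-3, as in the diff
kwargs = dict(cfg_optimizer_kwargs or {})
kwargs.setdefault('lr', 1e-3)

# instantiate against some parameters, here a throwaway linear layer
optimizer = optimizer_cls(torch.nn.Linear(4, 4).parameters(), **kwargs)
print(type(optimizer).__name__, optimizer.defaults['lr'])  # Adam 0.001
```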
7 changes: 3 additions & 4 deletions disent/frameworks/ae/_unsupervised__ae.py
@@ -89,11 +89,10 @@ class cfg(DisentFramework.cfg):
disable_rec_loss: bool = False
disable_aug_loss: bool = False

def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None):
super().__init__(make_optimizer_fn, batch_augment=batch_augment, cfg=cfg)
def __init__(self, model: AutoEncoder, cfg: cfg = None, batch_augment=None):
super().__init__(cfg=cfg, batch_augment=batch_augment)
# vae model
assert callable(make_model_fn)
self._model: AutoEncoder = make_model_fn() # TODO: move into property
self._model = model
# check the model
assert isinstance(self._model, AutoEncoder)
assert self._model.z_multiplier == self.REQUIRED_Z_MULTIPLIER, f'model z_multiplier is {repr(self._model.z_multiplier)} but {self.__class__.__name__} requires that it is: {repr(self.REQUIRED_Z_MULTIPLIER)}'
2 changes: 1 addition & 1 deletion disent/frameworks/helper/reconstructions.py
@@ -30,7 +30,7 @@

import torch
import torch.nn.functional as F
from deprecated import deprecated
from disent.util.deprecate import deprecated

from disent.frameworks.helper.util import compute_ave_loss
from disent.nn.modules import DisentModule
17 changes: 14 additions & 3 deletions disent/frameworks/vae/_unsupervised__betavae.py
@@ -42,6 +42,9 @@


class BetaVae(Vae):
"""
beta-VAE: https://arxiv.org/abs/1312.6114
"""

REQUIRED_OBS = 1

@@ -55,20 +58,28 @@ class cfg(Vae.cfg):
# loss = mean_recon_loss + beta * mean_kl_loss
# -- for loss_reduction='mean_sum' we usually have:
# loss = (H*W*C) * mean_recon_loss + beta * (z_size) * mean_kl_loss
# So when switching from one mode to the other, we need to scale beta to preserve these loss ratios.
#
# So when switching from one mode to the other, we need to scale beta to
# preserve these loss ratios:
# -- 'mean_sum' to 'mean':
# beta <- beta * (z_size) / (H*W*C)
# -- 'mean' to 'mean_sum':
# beta <- beta * (H*W*C) / (z_size)
#
# We obtain an equivalent beta for 'mean_sum' to 'mean':
# -- given values: beta=4 for 'mean_sum', with (H*W*C)=(64*64*3) and z_size=9
# beta = beta * ((z_size) / (H*W*C))
# ~= 4 * 0.0007324
# ~= 0.003
#
# This is similar to appendix A.6: `INTERPRETING NORMALISED β` of the beta-Vae paper:
# - Published as a conference paper at ICLR 2017 (22 pages)
# - https://openreview.net/forum?id=Sy2fzU9gl
#
beta: float = 0.003 # approximately equal to mean_sum beta of 4

def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None):
super().__init__(make_optimizer_fn, make_model_fn, batch_augment=batch_augment, cfg=cfg)
def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None):
super().__init__(model=model, cfg=cfg, batch_augment=batch_augment)
assert self.cfg.beta >= 0, 'beta must be >= 0'

# --------------------------------------------------------------------- #
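The long comment block in the `cfg` above derives how `beta` must be rescaled when switching between `loss_reduction='mean'` and `'mean_sum'`. A small worked check of the quoted numbers, using the shapes from the comment (`(H*W*C)=(64*64*3)`, `z_size=9`):

```python
# check the 'mean_sum' -> 'mean' conversion quoted in the cfg comments
H, W, C = 64, 64, 3
z_size = 9

beta_mean_sum = 4.0
ratio = z_size / (H * W * C)
beta_mean = beta_mean_sum * ratio

print(round(ratio, 7))      # 0.0007324
print(round(beta_mean, 4))  # 0.0029  (approx. 0.003, the new default beta above)

# and the reverse, 'mean' -> 'mean_sum'
print(round(beta_mean * (H * W * C) / z_size, 1))  # 4.0
```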
4 changes: 2 additions & 2 deletions disent/frameworks/vae/_unsupervised__dfcvae.py
@@ -69,8 +69,8 @@ class cfg(BetaVae.cfg):
feature_layers: Optional[List[Union[str, int]]] = None
feature_inputs_mode: str = 'none'

def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None):
super().__init__(make_optimizer_fn, make_model_fn, batch_augment=batch_augment, cfg=cfg)
def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None):
super().__init__(model=model, cfg=cfg, batch_augment=batch_augment)
# make dfc loss
# TODO: this should be converted to a reconstruction loss handler that wraps another handler
self._dfc_loss = DfcLossModule(feature_layers=self.cfg.feature_layers, input_mode=self.cfg.feature_inputs_mode)
4 changes: 2 additions & 2 deletions disent/frameworks/vae/_unsupervised__dipvae.py
@@ -55,8 +55,8 @@ class cfg(BetaVae.cfg):
lambda_d: float = 10.
lambda_od: float = 5.

def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None):
super().__init__(make_optimizer_fn, make_model_fn, batch_augment=batch_augment, cfg=cfg)
def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None):
super().__init__(model=model, cfg=cfg, batch_augment=batch_augment)
# checks
assert self.cfg.dip_mode in {'i', 'ii'}, f'unsupported dip_mode={repr(self.cfg.dip_mode)} for {self.__class__.__name__}. Must be one of: {{"i", "ii"}}'
assert self.cfg.dip_beta >= 0, 'dip_beta must be >= 0'
4 changes: 2 additions & 2 deletions disent/frameworks/vae/_unsupervised__infovae.py
@@ -63,8 +63,8 @@ class cfg(Vae.cfg):
# this is optional
maintain_reg_ratio: bool = True

def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None):
super().__init__(make_optimizer_fn, make_model_fn, batch_augment=batch_augment, cfg=cfg)
def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None):
super().__init__(model=model, cfg=cfg, batch_augment=batch_augment)
# checks
assert self.cfg.info_alpha <= 0, f'cfg.info_alpha must be <= zero, current value is: {self.cfg.info_alpha}'
assert self.cfg.loss_reduction == 'mean', 'InfoVAE only supports cfg.loss_reduction == "mean"'
