Skip to content

Commit

Permalink
Merge branch 'dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
nmichlo committed Nov 11, 2021
2 parents 9bdd81d + 95290ca commit 3276d57
Show file tree
Hide file tree
Showing 233 changed files with 5,663 additions and 2,642 deletions.
15 changes: 9 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,14 @@ from disent.frameworks.vae import BetaVae
from disent.metrics import metric_dci, metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.nn.transform import ToStandardisedTensor
from disent.dataset.transform import ToImgTensorF32
from disent.schedule import CyclicSchedule


# create the dataset & dataloaders
# - ToStandardisedTensor transforms images from numpy arrays to tensors and performs checks
# - ToImgTensorF32 transforms images from numpy arrays to tensors and performs checks
data = XYObjectData()
dataset = DisentDataset(dataset=data, sampler=SingleSampler(), transform=ToStandardisedTensor())
dataset = DisentDataset(dataset=data, sampler=SingleSampler(), transform=ToImgTensorF32())
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True, num_workers=os.cpu_count())

# create the BetaVAE model
Expand Down Expand Up @@ -261,7 +262,9 @@ module.register_schedule(

# train model
# - for 2048 batches/steps
trainer = pl.Trainer(max_steps=2048, gpus=1 if torch.cuda.is_available() else None, logger=False, checkpoint_callback=False)
trainer = pl.Trainer(
max_steps=2048, gpus=1 if torch.cuda.is_available() else None, logger=False, checkpoint_callback=False
)
trainer.fit(module, dataloader)

# compute disentanglement metrics
Expand Down Expand Up @@ -304,13 +307,13 @@ files (config options) in the subfolders (config groups) in
```yaml
defaults:
# system
- framework: adavae
- framework: adavae_os
- model: vae_conv64
- optimizer: adam
- schedule: none
# data
- dataset: xyobject
- dataset_sampling: full_bb
- sampling: default__bb
- augment: none
# runtime
- metrics: fast
Expand Down
30 changes: 27 additions & 3 deletions disent/dataset/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from functools import wraps
from typing import Optional
from typing import Sequence
from typing import TypeVar
from typing import Union

import numpy as np
Expand All @@ -35,6 +36,7 @@
from disent.dataset.data import GroundTruthData
from disent.dataset.sampling import SingleSampler
from disent.dataset.wrapper import WrappedDataset
from disent.util.deprecate import deprecated
from disent.util.iters import LengthIter
from disent.util.math.random import random_choice_prng

Expand All @@ -53,7 +55,10 @@ class NotGroundTruthDataError(Exception):
"""


def groundtruth_only(func):
T = TypeVar('T')


def groundtruth_only(func: T) -> T:
@wraps(func)
def wrapper(self: 'DisentDataset', *args, **kwargs):
if not self.is_ground_truth:
Expand All @@ -76,8 +81,12 @@ def wrapper(self: 'DisentDataset', *args, **kwargs):
# ========================================================================= #


_DO_COPY = object()


class DisentDataset(Dataset, LengthIter):


def __init__(
self,
dataset: Union[Dataset, GroundTruthData],
Expand All @@ -97,6 +106,20 @@ def __init__(
if not self._sampler.is_init:
self._sampler.init(dataset)

def shallow_copy(
self,
transform=_DO_COPY,
augment=_DO_COPY,
return_indices=_DO_COPY,
) -> 'DisentDataset':
return DisentDataset(
dataset=self._dataset,
sampler=self._sampler,
transform=self._transform if (transform is _DO_COPY) else transform,
augment=self._augment if (augment is _DO_COPY) else augment,
return_indices=self._return_indices if (return_indices is _DO_COPY) else return_indices,
)

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Properties #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
Expand All @@ -114,6 +137,7 @@ def is_ground_truth(self) -> bool:
return isinstance(self._dataset, GroundTruthData)

@property
@deprecated('ground_truth_data property replaced with `gt_data`')
@groundtruth_only
def ground_truth_data(self) -> GroundTruthData:
return self._dataset
Expand Down Expand Up @@ -284,13 +308,13 @@ def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = Fals
@groundtruth_only
def dataset_batch_from_factors(self, factors: np.ndarray, mode: str):
"""Get a batch of observations X from a batch of factors Y."""
indices = self.ground_truth_data.pos_to_idx(factors)
indices = self.gt_data.pos_to_idx(factors)
return self.dataset_batch_from_indices(indices, mode=mode)

@groundtruth_only
def dataset_sample_batch_with_factors(self, num_samples: int, mode: str):
"""Sample a batch of observations X and factors Y."""
factors = self.ground_truth_data.sample_factors(num_samples)
factors = self.gt_data.sample_factors(num_samples)
batch = self.dataset_batch_from_factors(factors, mode=mode)
return batch, default_collate(factors)

Expand Down
1 change: 1 addition & 0 deletions disent/dataset/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@

# groundtruth -- impl synthetic
from disent.dataset.data._groundtruth__xyobject import XYObjectData
from disent.dataset.data._groundtruth__xyobject import XYObjectShadedData
8 changes: 4 additions & 4 deletions disent/dataset/data/_episodes__custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _load_episode_observations(self) -> List[np.ndarray]:
# check variables
option_ids_to_names = {}
ground_truth_keys = None
observation_shape = None
img_shape = None
# load data
episodes = []
for i, raw_episode in enumerate(raw_episodes):
Expand All @@ -102,11 +102,11 @@ def _load_episode_observations(self) -> List[np.ndarray]:
for gt_state in ground_truth_states:
assert ground_truth_keys == gt_state.keys()
# CHECK: observation shapes
if observation_shape is None:
observation_shape = observed_states[0].shape
if img_shape is None:
img_shape = observed_states[0].shape
else:
for observation in observed_states:
assert observation.shape == observation_shape
assert observation.shape == img_shape
# APPEND: all observations into one long episode
rollout.extend(observed_states)
# cleanup unused memory! This is not ideal, but works well.
Expand Down
103 changes: 62 additions & 41 deletions disent/dataset/data/_groundtruth.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def __init__(self, transform=None):
factor_names=self.factor_names,
)

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Overridable Defaults #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

@property
def name(self):
name = self.__class__.__name__
Expand All @@ -70,7 +74,7 @@ def name(self):
return name.lower()

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Overrides #
# State Space #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

@property
Expand All @@ -81,41 +85,33 @@ def factor_names(self) -> Tuple[str, ...]:
def factor_sizes(self) -> Tuple[int, ...]:
raise NotImplementedError()

@property
def observation_shape(self) -> Tuple[int, ...]:
# TODO: deprecate this!
# TODO: observation_shape should be called img_shape
# shape as would be for a non-batched observation
# eg. H x W x C
raise NotImplementedError()
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Properties #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

@property
def x_shape(self) -> Tuple[int, ...]:
# TODO: deprecate this!
# TODO: x_shape should be called obs_shape
# shape as would be for a single observation in a torch batch
# eg. C x H x W
shape = self.observation_shape
return shape[-1], *shape[:-1]
H, W, C = self.img_shape
return (C, H, W)

@property
def img_shape(self) -> Tuple[int, ...]:
# shape as would be for an original image
# eg. H x W x C
return self.observation_shape

@property
def obs_shape(self) -> Tuple[int, ...]:
# shape as would be for a single observation in a torch batch
# eg. C x H x W
return self.x_shape
raise NotImplementedError()

@property
def img_channels(self) -> int:
channels = self.img_shape[-1]
assert channels in (1, 3), f'invalid number of channels for dataset: {self.__class__.__name__}, got: {repr(channels)}, required: 1 or 3'
return channels

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Overrides #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

def __getitem__(self, idx):
obs = self._get_observation(idx)
if self._transform is not None:
Expand Down Expand Up @@ -149,23 +145,23 @@ def sample_random_obs_traversal(self, f_idx: int = None, base_factors=None, num:

class ArrayGroundTruthData(GroundTruthData):

def __init__(self, array, factor_names: Tuple[str, ...], factor_sizes: Tuple[int, ...], array_chn_is_last: bool = True, observation_shape: Optional[Tuple[int, ...]] = None, transform=None):
def __init__(self, array, factor_names: Tuple[str, ...], factor_sizes: Tuple[int, ...], array_chn_is_last: bool = True, x_shape: Optional[Tuple[int, ...]] = None, transform=None):
self.__factor_names = tuple(factor_names)
self.__factor_sizes = tuple(factor_sizes)
self._array = array
# get shape
if observation_shape is not None:
C, H, W = observation_shape
if x_shape is not None:
C, H, W = x_shape
elif array_chn_is_last:
H, W, C = array.shape[1:]
else:
C, H, W = array.shape[1:]
# set observation shape
self.__observation_shape = (H, W, C)
self.__img_shape = (H, W, C)
# initialize
super().__init__(transform=transform)
# check shapes -- it is up to the user to handle which method they choose
assert (array.shape[1:] == self.img_shape) or (array.shape[1:] == self.obs_shape)
assert (array.shape[1:] == self.img_shape) or (array.shape[1:] == self.x_shape)

@property
def array(self):
Expand All @@ -180,22 +176,23 @@ def factor_sizes(self) -> Tuple[int, ...]:
return self.__factor_sizes

@property
def observation_shape(self) -> Tuple[int, ...]:
return self.__observation_shape
def img_shape(self) -> Tuple[int, ...]:
return self.__img_shape

def _get_observation(self, idx):
# TODO: INVESTIGATE! I think this implements a lock,
# hindering multi-threaded environments?
return self._array[idx]

@classmethod
def new_like(cls, array, dataset: GroundTruthData, array_chn_is_last: bool = True):
def new_like(cls, array, gt_data: GroundTruthData, array_chn_is_last: bool = True):
# TODO: should this not copy the x_shape and transform?
return cls(
array=array,
factor_names=dataset.factor_names,
factor_sizes=dataset.factor_sizes,
factor_names=gt_data.factor_names,
factor_sizes=gt_data.factor_sizes,
array_chn_is_last=array_chn_is_last,
observation_shape=None, # infer from array
x_shape=None, # infer from array
transform=None,
)

Expand All @@ -207,15 +204,12 @@ def new_like(cls, array, dataset: GroundTruthData, array_chn_is_last: bool = Tru
# ========================================================================= #


class DiskGroundTruthData(GroundTruthData, metaclass=ABCMeta):
class _DiskDataMixin(object):

"""
Dataset that prepares a list DataObjects into some local directory.
- This directory can be
"""
# attr this class defines in _mixin_disk_init
_data_dir: str

def __init__(self, data_root: Optional[str] = None, prepare: bool = False, transform=None):
super().__init__(transform=transform)
def _mixin_disk_init(self, data_root: Optional[str] = None, prepare: bool = False):
# get root data folder
if data_root is None:
data_root = self.default_data_root
Expand All @@ -242,6 +236,23 @@ def default_data_root(self):
def datafiles(self) -> Sequence[DataFile]:
raise NotImplementedError

@property
def name(self) -> str:
raise NotImplementedError


class DiskGroundTruthData(_DiskDataMixin, GroundTruthData, metaclass=ABCMeta):

"""
Dataset that prepares a list DataObjects into some local directory.
- This directory can be
"""

def __init__(self, data_root: Optional[str] = None, prepare: bool = False, transform=None):
super().__init__(transform=transform)
# get root data folder
self._mixin_disk_init(data_root=data_root, prepare=prepare)


class NumpyFileGroundTruthData(DiskGroundTruthData, metaclass=ABCMeta):
"""
Expand Down Expand Up @@ -282,7 +293,7 @@ def data_key(self) -> Optional[str]:

class _Hdf5DataMixin(object):

# set attributes if _mixin_hdf5_init is called
# attrs this class defines in _mixin_hdf5_init
_in_memory: bool
_attrs: dict
_data: Union[Hdf5Dataset, np.ndarray]
Expand All @@ -303,11 +314,21 @@ def _mixin_hdf5_init(self, h5_path: str, h5_dataset_name: str = 'data', in_memor
# indexing dataset objects returns numpy array
# instantiating np.array from the dataset requires double memory.
self._data = data[:]
self._data.flags.writeable = False
data.close()
else:
# Load the dataset from the disk
self._data = data

def __len__(self):
return len(self._data)

@property
def img_shape(self):
shape = self._data.shape[1:]
assert len(shape) == 3
return shape

# override from GroundTruthData
def _get_observation(self, idx):
return self._data[idx]
Expand Down Expand Up @@ -353,7 +374,7 @@ def __init__(self, h5_path: str, in_memory=False, transform=None):
self._attr_factor_sizes = tuple(int(size) for size in self._attrs['factor_sizes'])
# set size
(B, H, W, C) = self._data.shape
self._observation_shape = (H, W, C)
self._img_shape = (H, W, C)
# initialize!
super().__init__(transform=transform)

Expand All @@ -370,8 +391,8 @@ def factor_sizes(self) -> Tuple[int, ...]:
return self._attr_factor_sizes

@property
def observation_shape(self) -> Tuple[int, ...]:
return self._observation_shape
def img_shape(self) -> Tuple[int, ...]:
return self._img_shape


# ========================================================================= #
Expand Down
2 changes: 1 addition & 1 deletion disent/dataset/data/_groundtruth__cars3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class Cars3dData(NumpyFileGroundTruthData):

factor_names = ('elevation', 'azimuth', 'object_type')
factor_sizes = (4, 24, 183) # TOTAL: 17568
observation_shape = (128, 128, 3)
img_shape = (128, 128, 3)

datafile = DataFileCars3d(
uri='http://www.scottreed.info/files/nips2015-analogy-data.tar.gz',
Expand Down
Loading

0 comments on commit 3276d57

Please sign in to comment.