Commit 9bdd81d

Merge branch 'dev' into main

nmichlo committed Oct 4, 2021
2 parents b2b3378 + 89a8422
Showing 53 changed files with 1,915 additions and 236 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-test.yml
@@ -31,7 +31,7 @@ jobs:
python3 -m pip install --upgrade pip
python3 -m pip install -r requirements.txt
python3 -m pip install -r requirements-test.txt
python3 -m pip install -r requirements-exp.txt
python3 -m pip install -r requirements-experiment.txt
- name: Test with pytest
run: |
1 change: 0 additions & 1 deletion README.md
@@ -116,7 +116,6 @@ are not available from `pip install`.
- `experiment/config`: root folder for [hydra](https://github.com/facebookresearch/hydra) config files
- `experiment/util`: various helper code for experiments


----------------------

## Features
80 changes: 61 additions & 19 deletions disent/dataset/_base.py
@@ -34,7 +34,9 @@
from disent.dataset.sampling import BaseDisentSampler
from disent.dataset.data import GroundTruthData
from disent.dataset.sampling import SingleSampler
from disent.dataset.wrapper import WrappedDataset
from disent.util.iters import LengthIter
from disent.util.math.random import random_choice_prng


# ========================================================================= #
@@ -60,6 +62,15 @@ def wrapper(self: 'DisentDataset', *args, **kwargs):
return wrapper


def wrapped_only(func):
@wraps(func)
def wrapper(self: 'DisentDataset', *args, **kwargs):
if not self.is_wrapped_data:
raise NotGroundTruthDataError(f'Check `is_wrapped_data` first before calling `{func.__name__}`, the dataset wrapped by {repr(self.__class__.__name__)} is not a {repr(WrappedDataset.__name__)}, instead got: {repr(self._dataset)}.')
return func(self, *args, **kwargs)
return wrapper


# ========================================================================= #
# Dataset Wrapper #
# ========================================================================= #
@@ -107,6 +118,51 @@ def is_ground_truth(self) -> bool:
def ground_truth_data(self) -> GroundTruthData:
return self._dataset

@property
@groundtruth_only
def gt_data(self) -> GroundTruthData:
# TODO: deprecate this or the long version
return self._dataset

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Wrapped Dataset #
# -- TODO: this is a bit hacky #
# -- Allows us to compute disentanglement metrics over datasets #
# derived from ground truth data #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

@property
def is_wrapped_data(self):
return isinstance(self._dataset, WrappedDataset)

@property
def is_wrapped_gt_data(self):
return isinstance(self._dataset, WrappedDataset) and isinstance(self._dataset.data, GroundTruthData)

@property
@wrapped_only
def wrapped_data(self):
self._dataset: WrappedDataset
return self._dataset.data

@property
@wrapped_only
def wrapped_gt_data(self):
self._dataset: WrappedDataset
return self._dataset.gt_data

@wrapped_only
def unwrapped_disent_dataset(self) -> 'DisentDataset':
sampler = self._sampler.uninit_copy()
assert type(sampler) is type(self._sampler)
return DisentDataset(
dataset=self.wrapped_data,
sampler=sampler,
transform=self._transform,
augment=self._augment,
return_indices=self._return_indices,
)

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# Dataset #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
@@ -202,19 +258,17 @@ def _dataset_get_observation(self, *idxs):
# Batches #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

# TODO: default_collate should be replaced with a function
# that can handle tensors and ndarrays, and return accordingly

def dataset_batch_from_indices(self, indices: Sequence[int], mode: str):
"""Get a batch of observations X from a batch of factors Y."""
return default_collate([self.dataset_get(idx, mode=mode) for idx in indices])

def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = False, return_indices: bool = False):
"""Sample a batch of observations X."""
# create seeded pseudo random number generator
# - built in np.random.choice cannot handle large values: https://github.com/numpy/numpy/issues/5299#issuecomment-497915672
# - PCG64 is the default: https://numpy.org/doc/stable/reference/random/bit_generators/index.html
# - PCG64 has good statistical properties and is fast: https://numpy.org/doc/stable/reference/random/performance.html
g = np.random.Generator(np.random.PCG64(seed=np.random.randint(0, 2**32)))
# sample indices
indices = g.choice(len(self), num_samples, replace=replace)
# built in np.random.choice cannot handle large values: https://github.com/numpy/numpy/issues/5299#issuecomment-497915672
indices = random_choice_prng(len(self), size=num_samples, replace=replace)
# return batch
batch = self.dataset_batch_from_indices(indices, mode=mode)
# return values
@@ -256,18 +310,6 @@ def _batch_to_observation(batch, obs_shape):
return batch


# ========================================================================= #
# EXTRA #
# ========================================================================= #

# TODO fix references to this!
# class GroundTruthDatasetAndFactors(GroundTruthDataset):
# def dataset_get_observation(self, *idxs):
# return {
# **super().dataset_get_observation(*idxs),
# 'factors': tuple(self.idx_to_pos(idxs))
# }

# ========================================================================= #
# END #
# ========================================================================= #
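
`dataset_sample_batch` now delegates index sampling to `random_choice_prng`, which this diff imports from `disent.util.math.random` but does not show. Going by the removed comments (the built-in `np.random.choice` cannot handle large populations, and PCG64 is NumPy's fast default bit generator), a minimal sketch of such a helper, assuming the seeding behaviour of the removed code:

```python
import numpy as np

# hypothetical sketch -- the real disent.util.math.random.random_choice_prng is not shown in this diff
def random_choice_prng(a, size=None, replace: bool = True, seed: int = None):
    # randomly seed the generator if no seed is given, mirroring the removed code
    if seed is None:
        seed = np.random.randint(0, 2**32)
    # a PCG64-backed Generator supports populations larger than the legacy
    # np.random.choice: https://github.com/numpy/numpy/issues/5299
    g = np.random.Generator(np.random.PCG64(seed=seed))
    return g.choice(a, size=size, replace=replace)
```

Called as in the diff above: `indices = random_choice_prng(len(self), size=num_samples, replace=replace)`.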
18 changes: 8 additions & 10 deletions disent/dataset/data/__init__.py
@@ -24,11 +24,17 @@

# custom episodes -- base
from disent.dataset.data._episodes import BaseEpisodesData

# custom episodes -- impl
from disent.dataset.data._episodes__custom import EpisodesPickledData
from disent.dataset.data._episodes__custom import EpisodesDownloadZippedPickledData

# raw -- groundtruth
from disent.dataset.data._groundtruth import ArrayGroundTruthData
from disent.dataset.data._groundtruth import SelfContainedHdf5GroundTruthData

# raw
from disent.dataset.data._raw import ArrayDataset
from disent.dataset.data._raw import Hdf5Dataset

# groundtruth -- base
from disent.dataset.data._groundtruth import GroundTruthData
from disent.dataset.data._groundtruth import DiskGroundTruthData
@@ -44,11 +50,3 @@

# groundtruth -- impl synthetic
from disent.dataset.data._groundtruth__xyobject import XYObjectData

# raw -- groundtruth
# TODO: hdf5 version
from disent.dataset.data._groundtruth import ArrayGroundTruthData

# raw
from disent.dataset.data._raw import ArrayDataset
from disent.dataset.data._raw import Hdf5Dataset
122 changes: 108 additions & 14 deletions disent/dataset/data/_groundtruth.py
@@ -83,17 +83,39 @@ def factor_sizes(self) -> Tuple[int, ...]:

@property
def observation_shape(self) -> Tuple[int, ...]:
# TODO: deprecate this!
# TODO: observation_shape should be called img_shape
# shape as would be for a non-batched observation
# eg. H x W x C
raise NotImplementedError()

@property
def x_shape(self) -> Tuple[int, ...]:
# TODO: deprecate this!
# TODO: x_shape should be called obs_shape
# shape as would be for a single observation in a torch batch
# eg. C x H x W
shape = self.observation_shape
return shape[-1], *shape[:-1]

@property
def img_shape(self) -> Tuple[int, ...]:
# shape as would be for an original image
# eg. H x W x C
return self.observation_shape

@property
def obs_shape(self) -> Tuple[int, ...]:
# shape as would be for a single observation in a torch batch
# eg. C x H x W
return self.x_shape

@property
def img_channels(self) -> int:
channels = self.img_shape[-1]
assert channels in (1, 3), f'invalid number of channels for dataset: {self.__class__.__name__}, got: {repr(channels)}, required: 1 or 3'
return channels

def __getitem__(self, idx):
obs = self._get_observation(idx)
if self._transform is not None:
@@ -127,14 +149,27 @@ def sample_random_obs_traversal(self, f_idx: int = None, base_factors=None, num:

class ArrayGroundTruthData(GroundTruthData):

def __init__(self, array, factor_names: Tuple[str, ...], factor_sizes: Tuple[int, ...], observation_shape: Optional[Tuple[int, ...]] = None, transform=None):
def __init__(self, array, factor_names: Tuple[str, ...], factor_sizes: Tuple[int, ...], array_chn_is_last: bool = True, observation_shape: Optional[Tuple[int, ...]] = None, transform=None):
self.__factor_names = tuple(factor_names)
self.__factor_sizes = tuple(factor_sizes)
print(array.shape)
self.__observation_shape = tuple(observation_shape if (observation_shape is not None) else array.shape[1:])
self._array = array
# get shape
if observation_shape is not None:
C, H, W = observation_shape
elif array_chn_is_last:
H, W, C = array.shape[1:]
else:
C, H, W = array.shape[1:]
# set observation shape
self.__observation_shape = (H, W, C)
# initialize
super().__init__(transform=transform)
# check shapes -- it is up to the user to handle which method they choose
assert (array.shape[1:] == self.img_shape) or (array.shape[1:] == self.obs_shape)

@property
def array(self):
return self._array

@property
def factor_names(self) -> Tuple[str, ...]:
@@ -149,18 +184,22 @@ def observation_shape(self) -> Tuple[int, ...]:
return self.__observation_shape

def _get_observation(self, idx):
# TODO: INVESTIGATE! I think this implements a lock,
# hindering multi-threaded environments?
return self._array[idx]

@classmethod
def new_like(cls, array, dataset: GroundTruthData):
def new_like(cls, array, dataset: GroundTruthData, array_chn_is_last: bool = True):
return cls(
array=array,
factor_names=dataset.factor_names,
factor_sizes=dataset.factor_sizes,
array_chn_is_last=array_chn_is_last,
observation_shape=None, # infer from array
transform=None,
)


# ========================================================================= #
# disk ground truth data #
# TODO: data & datafile preparation should be split out from #
@@ -241,22 +280,23 @@ def data_key(self) -> Optional[str]:
return None


class Hdf5GroundTruthData(DiskGroundTruthData, metaclass=ABCMeta):
"""
Dataset that loads an Hdf5 file from a DataObject
- requires that the data object has the `out_dataset_name` attribute
that points to the hdf5 dataset in the file to load.
"""
class _Hdf5DataMixin(object):

def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_memory=False, transform=None):
super().__init__(data_root=data_root, prepare=prepare, transform=transform)
# set attributes if _mixin_hdf5_init is called
_in_memory: bool
_attrs: dict
_data: Union[Hdf5Dataset, np.ndarray]

def _mixin_hdf5_init(self, h5_path: str, h5_dataset_name: str = 'data', in_memory: bool = False):
# variables
self._in_memory = in_memory
# load the h5py dataset
data = Hdf5Dataset(
h5_path=os.path.join(self.data_dir, self.datafile.out_name),
h5_dataset_name=self.datafile.dataset_name,
h5_path=h5_path,
h5_dataset_name=h5_dataset_name,
)
# load attributes
self._attrs = data.get_attrs()
# handle different memory modes
if self._in_memory:
# Load the entire dataset into memory if required
@@ -268,9 +308,27 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_memory=False, transform=None):
# Load the dataset from the disk
self._data = data

# override from GroundTruthData
def _get_observation(self, idx):
return self._data[idx]


class Hdf5GroundTruthData(_Hdf5DataMixin, DiskGroundTruthData, metaclass=ABCMeta):
"""
Dataset that loads an Hdf5 file from a DataObject
- requires that the data object has the `out_dataset_name` attribute
that points to the hdf5 dataset in the file to load.
"""

def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_memory=False, transform=None):
super().__init__(data_root=data_root, prepare=prepare, transform=transform)
# initialize mixin
self._mixin_hdf5_init(
h5_path=os.path.join(self.data_dir, self.datafile.out_name),
h5_dataset_name=self.datafile.dataset_name,
in_memory=in_memory,
)

@property
def datafiles(self) -> Sequence[DataFileHashedDlH5]:
return [self.datafile]
@@ -280,6 +338,42 @@ def datafile(self) -> DataFileHashedDlH5:
raise NotImplementedError


class SelfContainedHdf5GroundTruthData(_Hdf5DataMixin, GroundTruthData):

def __init__(self, h5_path: str, in_memory=False, transform=None):
# initialize mixin
self._mixin_hdf5_init(
h5_path=h5_path,
h5_dataset_name='data',
in_memory=in_memory,
)
# load attrs
self._attr_name = self._attrs['dataset_name'].decode("utf-8")
self._attr_factor_names = tuple(name.decode("utf-8") for name in self._attrs['factor_names'])
self._attr_factor_sizes = tuple(int(size) for size in self._attrs['factor_sizes'])
# set size
(B, H, W, C) = self._data.shape
self._observation_shape = (H, W, C)
# initialize!
super().__init__(transform=transform)

@property
def name(self) -> str:
return self._attr_name

@property
def factor_names(self) -> Tuple[str, ...]:
return self._attr_factor_names

@property
def factor_sizes(self) -> Tuple[int, ...]:
return self._attr_factor_sizes

@property
def observation_shape(self) -> Tuple[int, ...]:
return self._observation_shape


# ========================================================================= #
# END #
# ========================================================================= #
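
The new `SelfContainedHdf5GroundTruthData` recovers everything it needs (the `dataset_name`, `factor_names` and `factor_sizes` attributes, plus the `data` dataset) from a single HDF5 file, so no matching class definition is required to reload a dataset. A usage sketch with a hypothetical file path; note the naming convention added above, where `img_shape` is `(H, W, C)` and `obs_shape` is the channels-first `(C, H, W)` of a torch observation:

```python
from disent.dataset.data import SelfContainedHdf5GroundTruthData

# 'out/xyobject.h5' is a hypothetical file containing a `data` dataset and the
# `dataset_name` / `factor_names` / `factor_sizes` attributes described above
data = SelfContainedHdf5GroundTruthData(h5_path='out/xyobject.h5', in_memory=False)

print(data.name)          # decoded from the `dataset_name` attribute
print(data.factor_names)  # decoded from the `factor_names` attribute
print(data.factor_sizes)  # decoded from the `factor_sizes` attribute
print(data.img_shape)     # (H, W, C) -- stored image layout
print(data.obs_shape)     # (C, H, W) -- layout of an observation in a torch batch
```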
2 changes: 1 addition & 1 deletion disent/dataset/data/_groundtruth__dsprites.py
@@ -65,7 +65,7 @@ class DSpritesData(Hdf5GroundTruthData):
hdf5_dataset_name='imgs',
hdf5_chunk_size=(1, 64, 64, 1),
hdf5_dtype='uint8',
hdf5_mutator=lambda x: x * 255,
hdf5_mutator=lambda x: (x * 255)[..., None], # data is of shape (-1, 64, 64), so we add the channel dimension
hdf5_obs_shape=(64, 64, 1),
)

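
As the new comment notes, the raw dSprites `imgs` array is stored as `(-1, 64, 64)` binary images, so the mutator has to both rescale values to the `uint8` range and append the missing channel axis. A quick illustration of the shape change:

```python
import numpy as np

x = np.zeros((10, 64, 64), dtype=np.uint8)  # raw batch: no channel dimension
y = (x * 255)[..., None]                    # rescale 0/1 -> 0/255, then add a trailing channel axis
assert y.shape == (10, 64, 64, 1)           # matches hdf5_obs_shape=(64, 64, 1) per observation
```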
3 changes: 3 additions & 0 deletions disent/dataset/data/_raw.py
@@ -123,6 +123,9 @@ def close(self):
del self._hdf5_file
del self._hdf5_data

def get_attrs(self) -> dict:
return dict(self._hdf5_data.attrs)


# ========================================================================= #
# END #
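
`get_attrs` copies the dataset's HDF5 attributes into a plain dict; `_Hdf5DataMixin` above stores this as `self._attrs`, and `SelfContainedHdf5GroundTruthData` decodes its metadata from it. A sketch of the round trip in plain h5py, with hypothetical attribute values:

```python
import h5py
import numpy as np

# write a file with the attributes SelfContainedHdf5GroundTruthData expects (values are hypothetical)
with h5py.File('example.h5', 'w') as f:
    d = f.create_dataset('data', data=np.zeros((10, 64, 64, 3), dtype='uint8'))
    d.attrs['dataset_name'] = 'example'
    d.attrs['factor_names'] = ['x', 'y']
    d.attrs['factor_sizes'] = [5, 2]

# reading the attributes back as a plain dict is all that get_attrs does
with h5py.File('example.h5', 'r') as f:
    attrs = dict(f['data'].attrs)
print(attrs['dataset_name'], attrs['factor_names'], attrs['factor_sizes'])
```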
… (diffs for the remaining 46 changed files not shown)