Commit
Merge branch 'dev' into main
nmichlo committed Jun 4, 2021
2 parents 0f45fef + aae5453 commit b2663a1
Showing 136 changed files with 2,965 additions and 1,648 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -135,4 +135,5 @@ logs/

# custom - root folder only
/data/dataset
/docs/examples/data
*.pkl
98 changes: 78 additions & 20 deletions README.md
@@ -39,6 +39,22 @@

----------------------

### Table Of Contents

- [Overview](#overview)
- [Getting Started](#getting-started)
- [Features](#features)
* [Frameworks](#frameworks)
* [Metrics](#metrics)
* [Datasets](#datasets)
* [Schedules & Annealing](#schedules--annealing)
- [Examples](#examples)
* [Python Example](#python-example)
* [Hydra Config Example](#hydra-config-example)
- [Why?](#why)

----------------------

### Overview

Disent is a modular disentangled representation learning framework for auto-encoders, built upon pytorch-lightning. This framework consists of various composable components that can be used to build and benchmark disentanglement pipelines.
@@ -159,7 +175,7 @@ add your own, or you have a request.

</p></details>

#### Datasets

Various common datasets used in disentanglement research are implemented, as well as new synthetic datasets that are generated programmatically on the fly. These are convenient and lightweight, requiring no storage space.

@@ -181,7 +197,7 @@ Various common datasets used in disentanglement research are implemented, as wel
- Input based transforms are supported.
- Input and Target CPU and GPU based augmentations are supported.

#### Schedules & Annealing

Hyper-parameter annealing is supported through the use of schedules. The currently implemented schedules include:

@@ -192,21 +208,6 @@ Hyper-parameter annealing is supported through the use of schedules. The current
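As an illustration of the idea (not disent's actual API), a cyclic cosine schedule for a hyper-parameter such as beta can be sketched in a few lines:

```python
import math


def cyclic_cosine(step: int, period: int, low: float = 0.0, high: float = 1.0) -> float:
    # position within the current cycle, in [0, 1)
    t = (step % period) / period
    # cosine ramp: 0 at the start of the cycle, 1 at the midpoint, back to 0
    ratio = 0.5 - 0.5 * math.cos(2 * math.pi * t)
    # interpolate the hyper-parameter between its low and high values
    return low + (high - low) * ratio
```

Each training step would then set e.g. `framework.cfg.beta = cyclic_cosine(step, period=1000, low=0.001, high=4.0)`; the parameter names here are hypothetical.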

----------------------


### Architecture

**disent**
@@ -225,7 +226,9 @@

----------------------

### Examples

#### Python Example

The following is a basic working example of disent that trains a BetaVAE with a cyclic
beta schedule and evaluates the trained model with various metrics.
@@ -241,9 +244,10 @@
from disent.data.groundtruth import XYObjectData
from disent.dataset.groundtruth import GroundTruthDataset
from disent.frameworks.vae import BetaVae
from disent.metrics import metric_dci, metric_mig
from disent.model.ae import EncoderConv64, DecoderConv64
from disent.model import AutoEncoder
from disent.nn.transform import ToStandardisedTensor
from disent.schedule import CyclicSchedule

# We use this internally to test this script.
# You can remove all references to this in your own code.
@@ -298,4 +302,58 @@
print('metrics:', metrics)

Visit the [docs](https://disent.dontpanic.sh) for more examples!

#### Hydra Config Example

The entrypoint for basic experiments is `experiment/run.py`.

Some configuration will be required, but basic experiments can
be adjusted by modifying the [Hydra](https://github.com/facebookresearch/hydra) (1.0) config
files in `experiment/config`.

Modifying the main `experiment/config/config.yaml` is all you
need for most basic experiments. The main config file contains
a defaults list whose entries correspond to yaml configuration
files (config options) in the subfolders (config groups), i.e.
`experiment/config/<config_group>/<option>.yaml`.

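For example, the config groups and options named in the defaults list below would correspond to a directory layout roughly like this (a hypothetical subset, shown only to illustrate the group/option structure):

```
experiment/config/
├── config.yaml
├── framework/
│   ├── adavae.yaml
│   └── betavae.yaml
├── dataset/
│   ├── xysquares.yaml
│   └── shapes3d.yaml
└── ...
```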
```yaml
defaults:
# experiment
- framework: adavae
- model: conv64alt
- optimizer: adam
- dataset: xysquares
- augment: none
- sampling: full_bb
- metrics: fast
- schedule: beta_cyclic
# runtime
- run_length: long
- run_location: local
- run_callbacks: vis
- run_logging: none
```

Modify any of these values to adjust how the basic experiment
is run. For example, change `framework: adavae` to `framework: betavae`, or
change the dataset from `xysquares` to `shapes3d`.

[Weights and Biases](https://docs.wandb.ai/quickstart) is supported by changing `run_logging: none` to
`run_logging: wandb`. However, you will need to log in from the command line first.

----------------------

### Why?

- Created as part of my Computer Science MSc scheduled for completion in 2021.

- I needed custom, high-quality implementations of various VAEs.

- A pytorch version of [disentanglement_lib](https://github.com/google-research/disentanglement_lib).

- I didn't have time to wait for [Weakly-Supervised Disentanglement Without Compromises](https://arxiv.org/abs/2002.02886) to release
  their code as part of disentanglement_lib. (As of September 2020 it has been released, but it has unresolved [discrepancies](https://github.com/google-research/disentanglement_lib/issues/31).)

- disentanglement_lib still uses the outdated TensorFlow 1.0, and its flow of data is unintuitive because of its use of [Gin Config](https://github.com/google/gin-config).

----------------------
226 changes: 226 additions & 0 deletions disent/data/datafile.py
@@ -0,0 +1,226 @@
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
# MIT License
#
# Copyright (c) 2021 Nathan Juraj Michlo
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

import os
from abc import ABCMeta
from typing import Callable
from typing import Dict
from typing import final
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import numpy as np

from disent.data.hdf5 import hdf5_resave_file
from disent.util.cache import stalefile
from disent.util.function import wrapped_partial
from disent.util.in_out import retrieve_file
from disent.util.paths import filename_from_url
from disent.util.paths import modify_file_name


# ========================================================================= #
# data objects #
# ========================================================================= #


class DataFile(object, metaclass=ABCMeta):
"""
base DataFile that does nothing, if the file does
not exist or it has the incorrect hash, then that's your problem!
"""

def __init__(self, file_name: str):
self._file_name = file_name

@final
@property
def out_name(self) -> str:
return self._file_name

def prepare(self, out_dir: str) -> str:
# TODO: maybe check that the file exists or not and raise a FileNotFoundError?
pass


class DataFileHashed(DataFile, metaclass=ABCMeta):
"""
Abstract Class
- Base DataFile class that guarantees a file to exist,
if the file does not exist, or the hash of the file is
incorrect, then the file is re-generated.
"""

def __init__(
self,
file_name: str,
file_hash: Optional[Union[str, Dict[str, str]]],
hash_type: str = 'md5',
hash_mode: str = 'fast',
):
super().__init__(file_name=file_name)
self._file_hash = file_hash
self._hash_type = hash_type
self._hash_mode = hash_mode

def prepare(self, out_dir: str) -> str:
@stalefile(file=os.path.join(out_dir, self._file_name), hash=self._file_hash, hash_type=self._hash_type, hash_mode=self._hash_mode)
def wrapped(out_file):
self._prepare(out_dir=out_dir, out_file=out_file)
return wrapped()

def _prepare(self, out_dir: str, out_file: str) -> str:
# TODO: maybe raise a FileNotFoundError or a HashError instead?
raise NotImplementedError


class DataFileHashedDl(DataFileHashed):
"""
Download a file
- uri can also be a file to perform a copy instead of download,
useful for example if you want to retrieve a file from a network drive.
"""

def __init__(
self,
uri: str,
uri_hash: Optional[Union[str, Dict[str, str]]],
uri_name: Optional[str] = None,
hash_type: str = 'md5',
hash_mode: str = 'fast',
):
super().__init__(
file_name=filename_from_url(uri) if (uri_name is None) else uri_name,
file_hash=uri_hash,
hash_type=hash_type,
hash_mode=hash_mode
)
self._uri = uri

def _prepare(self, out_dir: str, out_file: str):
retrieve_file(src_uri=self._uri, dst_path=out_file, overwrite_existing=True)


class DataFileHashedDlGen(DataFileHashed, metaclass=ABCMeta):
"""
Abstract class
- download a file and perform some processing on that file.
"""

def __init__(
self,
# download & save files
uri: str,
uri_hash: Optional[Union[str, Dict[str, str]]],
file_hash: Optional[Union[str, Dict[str, str]]],
# save paths
uri_name: Optional[str] = None,
file_name: Optional[str] = None,
# hash settings
hash_type: str = 'md5',
hash_mode: str = 'fast',
):
self._dl_obj = DataFileHashedDl(
uri=uri,
uri_hash=uri_hash,
uri_name=uri_name,
hash_type=hash_type,
hash_mode=hash_mode,
)
super().__init__(
file_name=modify_file_name(self._dl_obj.out_name, prefix='gen') if (file_name is None) else file_name,
file_hash=file_hash,
hash_type=hash_type,
hash_mode=hash_mode,
)

def _prepare(self, out_dir: str, out_file: str):
inp_file = self._dl_obj.prepare(out_dir=out_dir)
self._generate(inp_file=inp_file, out_file=out_file)

def _generate(self, inp_file: str, out_file: str):
raise NotImplementedError


class DataFileHashedDlH5(DataFileHashedDlGen):
"""
Downloads an hdf5 file and pre-processes it into the specified chunk_size.
"""

def __init__(
self,
# download & save files
uri: str,
uri_hash: Optional[Union[str, Dict[str, str]]],
file_hash: Optional[Union[str, Dict[str, str]]],
# h5 re-save settings
hdf5_dataset_name: str,
hdf5_chunk_size: Tuple[int, ...],
hdf5_compression: Optional[str] = 'gzip',
hdf5_compression_lvl: Optional[int] = 4,
hdf5_dtype: Optional[Union[np.dtype, str]] = None,
hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None,
hdf5_obs_shape: Optional[Sequence[int]] = None,
# save paths
uri_name: Optional[str] = None,
file_name: Optional[str] = None,
# hash settings
hash_type: str = 'md5',
hash_mode: str = 'fast',
):
super().__init__(
file_name=file_name,
file_hash=file_hash,
uri=uri,
uri_hash=uri_hash,
uri_name=uri_name,
hash_type=hash_type,
hash_mode=hash_mode,
)
self._hdf5_resave_file = wrapped_partial(
hdf5_resave_file,
dataset_name=hdf5_dataset_name,
chunk_size=hdf5_chunk_size,
compression=hdf5_compression,
compression_lvl=hdf5_compression_lvl,
out_dtype=hdf5_dtype,
out_mutator=hdf5_mutator,
obs_shape=hdf5_obs_shape,
)
# save the dataset name
self._dataset_name = hdf5_dataset_name

@property
def dataset_name(self) -> str:
return self._dataset_name

def _generate(self, inp_file: str, out_file: str):
self._hdf5_resave_file(inp_path=inp_file, out_path=out_file)


# ========================================================================= #
# END #
# ========================================================================= #
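Outside of disent, the prepare-if-stale pattern that these classes implement (check a target file's hash, regenerate it only when missing or stale) can be sketched with the standard library alone; the function and file names below are hypothetical, and the copy step stands in for the real download/generate step:

```python
import hashlib
import os
import shutil


def file_md5(path: str) -> str:
    # hash the full file contents in chunks
    # (disent's 'fast' hash_mode would instead hash only part of the file)
    h = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(65536), b''):
            h.update(chunk)
    return h.hexdigest()


def prepare(src_path: str, out_dir: str, file_hash: str) -> str:
    # regenerate the target only if it is missing or its hash is stale,
    # mirroring DataFileHashed.prepare wrapped by the @stalefile decorator
    out_file = os.path.join(out_dir, os.path.basename(src_path))
    if (not os.path.isfile(out_file)) or (file_md5(out_file) != file_hash):
        # stand-in for the actual download / copy / generate step
        shutil.copyfile(src_path, out_file)
    return out_file
```

A second call to `prepare` with the same arguments is then a no-op, which is what makes dataset preparation in this style cheap to re-run.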
