Skip to content

Commit

Permalink
Merge branch 'refactor' into dataset_tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
nmichlo committed Jun 4, 2021
2 parents 029f64a + eac7c05 commit d2b92fd
Show file tree
Hide file tree
Showing 65 changed files with 1,647 additions and 1,193 deletions.
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,4 +299,46 @@ print('metrics:', metrics)

Visit the [docs](https://disent.dontpanic.sh) for more examples!


----------------------

### Hydra Experiment Example

The entrypoint for basic experiments is `experiment/run.py`.

Some configuration will be required, but basic experiments can
be adjusted by modifying the [Hydra Config 1.0](https://github.com/facebookresearch/hydra)
files in `experiment/config`.

Modifying the main `experiment/config/config.yaml` is all you
need for most basic experiments. The main config file contains
a defaults list with entries corresponding to yaml configuration
files (config options) in the subfolders (config groups) in
`experiment/config/<config_group>/<option>.yaml`.

```yaml
defaults:
# experiment
- framework: adavae
- model: conv64alt
- optimizer: adam
- dataset: xysquares
- augment: none
- sampling: full_bb
- metrics: fast
- schedule: beta_cyclic
# runtime
- run_length: long
- run_location: local
- run_callbacks: vis
- run_logging: none
```
Easily modify any of these values to adjust how the basic experiment
will be run. For example, change `framework: adavae` to `framework: betavae`, or
change the dataset from `xysquares` to `shapes3d`.

[Weights and Biases](https://docs.wandb.ai/quickstart) is supported by changing `run_logging: none` to
`run_logging: wandb`. However, you will need to login from the command line.

----------------------
226 changes: 226 additions & 0 deletions disent/data/datafile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
# MIT License
#
# Copyright (c) 2021 Nathan Juraj Michlo
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

import os
from abc import ABCMeta
from typing import Callable
from typing import Dict
from typing import final
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import numpy as np

from disent.data.hdf5 import hdf5_resave_file
from disent.util.cache import stalefile
from disent.util.function import wrapped_partial
from disent.util.in_out import retrieve_file
from disent.util.paths import filename_from_url
from disent.util.paths import modify_file_name


# ========================================================================= #
# data objects #
# ========================================================================= #


class DataFile(object, metaclass=ABCMeta):
    """
    Base DataFile that performs no checks -- if the file does
    not exist or it has the incorrect hash, then that's your problem!
    """

    def __init__(self, file_name: str):
        # name of the target file, relative to the output directory passed to `prepare`
        self._file_name = file_name

    @final
    @property
    def out_name(self) -> str:
        """The name of the file that `prepare` produces/points to."""
        return self._file_name

    def prepare(self, out_dir: str) -> str:
        """
        Ensure the file is available in `out_dir` and return its full path.

        The base implementation performs no preparation or validation, it
        simply returns the path where the file is expected to be found.
        (Fix: previously this fell through and returned `None`, violating
        the declared `-> str` return annotation.)
        """
        # TODO: maybe check that the file exists or not and raise a FileNotFoundError?
        return os.path.join(out_dir, self._file_name)


class DataFileHashed(DataFile, metaclass=ABCMeta):
    """
    Abstract Class
    - Base DataFile class that guarantees a file to exist,
      if the file does not exist, or the hash of the file is
      incorrect, then the file is re-generated.
    """

    def __init__(
        self,
        file_name: str,
        file_hash: Optional[Union[str, Dict[str, str]]],
        hash_type: str = 'md5',
        hash_mode: str = 'fast',
    ):
        super().__init__(file_name=file_name)
        # staleness-check settings: expected hash, algorithm, and hashing mode
        self._file_hash = file_hash
        self._hash_type = hash_type
        self._hash_mode = hash_mode

    def prepare(self, out_dir: str) -> str:
        # `stalefile` only invokes the decorated function when the target file
        # is missing or its hash does not match, otherwise it is skipped
        target = os.path.join(out_dir, self._file_name)

        @stalefile(file=target, hash=self._file_hash, hash_type=self._hash_type, hash_mode=self._hash_mode)
        def regenerate(out_file):
            self._prepare(out_dir=out_dir, out_file=out_file)

        return regenerate()

    def _prepare(self, out_dir: str, out_file: str) -> str:
        # subclasses must (re)generate the file at `out_file`
        # TODO: maybe raise a FileNotFoundError or a HashError instead?
        raise NotImplementedError


class DataFileHashedDl(DataFileHashed):
    """
    Download a file
    - uri can also be a file to perform a copy instead of download,
      useful for example if you want to retrieve a file from a network drive.
    """

    def __init__(
        self,
        uri: str,
        uri_hash: Optional[Union[str, Dict[str, str]]],
        uri_name: Optional[str] = None,
        hash_type: str = 'md5',
        hash_mode: str = 'fast',
    ):
        # derive the local file name from the uri unless explicitly overridden
        name = uri_name if (uri_name is not None) else filename_from_url(uri)
        super().__init__(
            file_name=name,
            file_hash=uri_hash,
            hash_type=hash_type,
            hash_mode=hash_mode,
        )
        self._uri = uri

    def _prepare(self, out_dir: str, out_file: str):
        # fetch (or copy, if the uri is a local path) the source file,
        # replacing any stale local copy
        retrieve_file(src_uri=self._uri, dst_path=out_file, overwrite_existing=True)


class DataFileHashedDlGen(DataFileHashed, metaclass=ABCMeta):
    """
    Abstract class
    - download a file and perform some processing on that file.
    """

    def __init__(
        self,
        # download & save files
        uri: str,
        uri_hash: Optional[Union[str, Dict[str, str]]],
        file_hash: Optional[Union[str, Dict[str, str]]],
        # save paths
        uri_name: Optional[str] = None,
        file_name: Optional[str] = None,
        # hash settings
        hash_type: str = 'md5',
        hash_mode: str = 'fast',
    ):
        # the downloader handles fetching & verifying the raw source file
        self._dl_obj = DataFileHashedDl(
            uri=uri,
            uri_hash=uri_hash,
            uri_name=uri_name,
            hash_type=hash_type,
            hash_mode=hash_mode,
        )
        # by default, the generated file is named after the download with a 'gen' prefix
        gen_name = file_name if (file_name is not None) else modify_file_name(self._dl_obj.out_name, prefix='gen')
        super().__init__(
            file_name=gen_name,
            file_hash=file_hash,
            hash_type=hash_type,
            hash_mode=hash_mode,
        )

    def _prepare(self, out_dir: str, out_file: str):
        # make sure the raw source file exists, then process it into `out_file`
        src_file = self._dl_obj.prepare(out_dir=out_dir)
        self._generate(inp_file=src_file, out_file=out_file)

    def _generate(self, inp_file: str, out_file: str):
        # subclasses must transform the downloaded `inp_file` into `out_file`
        raise NotImplementedError


class DataFileHashedDlH5(DataFileHashedDlGen):
    """
    Downloads an hdf5 file and pre-processes it into the specified chunk_size.
    """

    def __init__(
        self,
        # download & save files
        uri: str,
        uri_hash: Optional[Union[str, Dict[str, str]]],
        file_hash: Optional[Union[str, Dict[str, str]]],
        # h5 re-save settings
        hdf5_dataset_name: str,
        hdf5_chunk_size: Tuple[int, ...],
        hdf5_compression: Optional[str] = 'gzip',
        hdf5_compression_lvl: Optional[int] = 4,
        hdf5_dtype: Optional[Union[np.dtype, str]] = None,
        hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None,
        hdf5_obs_shape: Optional[Sequence[int]] = None,
        # save paths
        uri_name: Optional[str] = None,
        file_name: Optional[str] = None,
        # hash settings
        hash_type: str = 'md5',
        hash_mode: str = 'fast',
    ):
        super().__init__(
            uri=uri,
            uri_hash=uri_hash,
            uri_name=uri_name,
            file_name=file_name,
            file_hash=file_hash,
            hash_type=hash_type,
            hash_mode=hash_mode,
        )
        # pre-bind all the h5 re-save settings so `_generate` only needs the paths
        self._hdf5_resave_file = wrapped_partial(
            hdf5_resave_file,
            dataset_name=hdf5_dataset_name,
            chunk_size=hdf5_chunk_size,
            compression=hdf5_compression,
            compression_lvl=hdf5_compression_lvl,
            out_dtype=hdf5_dtype,
            out_mutator=hdf5_mutator,
            obs_shape=hdf5_obs_shape,
        )
        # remember the dataset name so consumers can locate the data inside the h5 file
        self._dataset_name = hdf5_dataset_name

    @property
    def dataset_name(self) -> str:
        # name of the dataset stored within the hdf5 file
        return self._dataset_name

    def _generate(self, inp_file: str, out_file: str):
        # re-save the downloaded h5 file using the pre-configured chunking/compression
        self._hdf5_resave_file(inp_path=inp_file, out_path=out_file)


# ========================================================================= #
# END #
# ========================================================================= #
13 changes: 12 additions & 1 deletion disent/data/episodes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,14 @@
from typing import List, Tuple
import numpy as np


from disent.dataset.groundtruth._triplet import sample_radius
from disent.util import LengthIter
from disent.util.iters import LengthIter


# ========================================================================= #
# option episodes #
# ========================================================================= #


class BaseOptionEpisodesData(LengthIter):
Expand Down Expand Up @@ -86,3 +92,8 @@ def sample_episode_indices(episode, idx, n=1, radius=None):

def _load_episode_observations(self) -> List[np.ndarray]:
raise NotImplementedError


# ========================================================================= #
# END #
# ========================================================================= #
29 changes: 25 additions & 4 deletions disent/data/episodes/_option_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,26 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

import logging
import os
from typing import List, Tuple
from typing import List
from typing import Tuple

import numpy as np

from disent.data.episodes._base import BaseOptionEpisodesData
from disent.data.util.in_out import download_file, basename_from_url
import logging
from disent.util.in_out import download_file
from disent.util.paths import filename_from_url


log = logging.getLogger(__name__)


# ========================================================================= #
# option episodes #
# ========================================================================= #


class OptionEpisodesPickledData(BaseOptionEpisodesData):

def __init__(self, required_file: str):
Expand All @@ -40,6 +50,10 @@ def __init__(self, required_file: str):
# load data
super().__init__()

# TODO: convert this to data files?
# TODO: convert this to data files?
# TODO: convert this to data files?

def _load_episode_observations(self) -> List[np.ndarray]:
import pickle
# load the raw data!
Expand Down Expand Up @@ -107,6 +121,10 @@ def _load_episode_observations(self) -> List[np.ndarray]:

class OptionEpisodesDownloadZippedPickledData(OptionEpisodesPickledData):

# TODO: convert this to data files?
# TODO: convert this to data files?
# TODO: convert this to data files?

def __init__(self, required_file: str, download_url=None, force_download=False):
self._download_and_extract_if_needed(download_url=download_url, required_file=required_file, force_download=force_download)
super().__init__(required_file=required_file)
Expand All @@ -118,7 +136,7 @@ def _download_and_extract_if_needed(self, download_url: str, required_file: str,
if not isinstance(download_url, str):
return
# download file, but skip if file already exists
save_path = os.path.join(os.path.dirname(required_file), basename_from_url(download_url))
save_path = os.path.join(os.path.dirname(required_file), filename_from_url(download_url))
if force_download or not os.path.exists(save_path):
log.info(f'Downloading: {download_url=} to {save_path=}')
download_file(download_url, save_path=save_path)
Expand All @@ -141,3 +159,6 @@ def _download_and_extract_if_needed(self, download_url: str, required_file: str,
# ~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~


# ========================================================================= #
# END #
# ========================================================================= #
19 changes: 10 additions & 9 deletions disent/data/groundtruth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

from .base import GroundTruthData
from disent.data.groundtruth.base import GroundTruthData

# others
# from ._cars3d import Cars3dData
from ._dsprites import DSpritesData
# from ._mpi3d import Mpi3dData
# from ._norb import SmallNorbData
from ._shapes3d import Shapes3dData
from ._xyobject import XYObjectData
from ._xysquares import XYSquaresData, XYSquaresMinimalData
from ._xyblocks import XYBlocksData
from disent.data.groundtruth._cars3d import Cars3dData
from disent.data.groundtruth._dsprites import DSpritesData
from disent.data.groundtruth._mpi3d import Mpi3dData
from disent.data.groundtruth._norb import SmallNorbData
from disent.data.groundtruth._shapes3d import Shapes3dData
from disent.data.groundtruth._xyobject import XYObjectData
from disent.data.groundtruth._xysquares import XYSquaresData, XYSquaresMinimalData
from disent.data.groundtruth._xyblocks import XYBlocksData
Loading

0 comments on commit d2b92fd

Please sign in to comment.