Merge branch 'dataset_tasks' into dev

nmichlo committed Jun 4, 2021
2 parents 654c3bd + d2b92fd, commit e15fb4a

Showing 128 changed files with 2,887 additions and 1,620 deletions.
47 changes: 45 additions & 2 deletions README.md
```diff
@@ -241,9 +241,10 @@ from disent.data.groundtruth import XYObjectData
 from disent.dataset.groundtruth import GroundTruthDataset
 from disent.frameworks.vae import BetaVae
 from disent.metrics import metric_dci, metric_mig
-from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder
+from disent.model.ae import EncoderConv64, DecoderConv64
+from disent.model import AutoEncoder
+from disent.nn.transform import ToStandardisedTensor
 from disent.schedule import CyclicSchedule
-from disent.transform import ToStandardisedTensor
 
 # We use this internally to test this script.
 # You can remove all references to this in your own code.
```
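
The module moves in this hunk are mechanical for downstream code; a before/after sketch of the affected imports, taken directly from the diff above:

```python
# old locations (before this commit)
# from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder
# from disent.transform import ToStandardisedTensor

# new locations (after this commit)
from disent.model.ae import EncoderConv64, DecoderConv64
from disent.model import AutoEncoder
from disent.nn.transform import ToStandardisedTensor
```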
@@ -298,4 +299,46 @@ print('metrics:', metrics)

Visit the [docs](https://disent.dontpanic.sh) for more examples!


----------------------

### Hydra Experiment Example

The entrypoint for basic experiments is `experiment/run.py`.

Some configuration will be required, but basic experiments can
be adjusted by modifying the [Hydra 1.0](https://github.com/facebookresearch/hydra)
config files in `experiment/config`.

Modifying the main `experiment/config/config.yaml` is all you
need for most basic experiments. This file contains a defaults
list whose entries correspond to yaml configuration files (config
options) in the subfolders (config groups), i.e.
`experiment/config/<config_group>/<option>.yaml`.

```yaml
defaults:
  # experiment
  - framework: adavae
  - model: conv64alt
  - optimizer: adam
  - dataset: xysquares
  - augment: none
  - sampling: full_bb
  - metrics: fast
  - schedule: beta_cyclic
  # runtime
  - run_length: long
  - run_location: local
  - run_callbacks: vis
  - run_logging: none
```

Change any of these values to adjust how the basic experiment
is run. For example, change `framework: adavae` to `framework: betavae`, or
change the dataset from `xysquares` to `shapes3d`, as in the sketch below.
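
A minimal sketch of such a change, assuming the `betavae` and `shapes3d`
option files exist under the `framework` and `dataset` config groups (only
the affected entries are shown):

```yaml
defaults:
  # experiment
  - framework: betavae   # was: adavae
  - dataset: shapes3d    # was: xysquares
  # ... remaining entries unchanged
```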

[Weights and Biases](https://docs.wandb.ai/quickstart) is supported by changing `run_logging: none` to
`run_logging: wandb`. However, you will first need to log in from the command line.
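
A sketch of the corresponding change (run `wandb login` once beforehand so
that logging can authenticate):

```yaml
defaults:
  # runtime
  - run_logging: wandb   # was: none
```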

----------------------
226 changes: 226 additions & 0 deletions disent/data/datafile.py
@@ -0,0 +1,226 @@

```python
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
# MIT License
#
# Copyright (c) 2021 Nathan Juraj Michlo
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

import os
from abc import ABCMeta
from typing import Callable
from typing import Dict
from typing import final
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import numpy as np

from disent.data.hdf5 import hdf5_resave_file
from disent.util.cache import stalefile
from disent.util.function import wrapped_partial
from disent.util.in_out import retrieve_file
from disent.util.paths import filename_from_url
from disent.util.paths import modify_file_name


# ========================================================================= #
# data objects #
# ========================================================================= #


class DataFile(object, metaclass=ABCMeta):
    """
    Base DataFile that does nothing. If the file does
    not exist or it has the incorrect hash, then that's your problem!
    """

    def __init__(self, file_name: str):
        self._file_name = file_name

    @final
    @property
    def out_name(self) -> str:
        return self._file_name

    def prepare(self, out_dir: str) -> str:
        # TODO: maybe check that the file exists or not and raise a FileNotFoundError?
        pass


class DataFileHashed(DataFile, metaclass=ABCMeta):
    """
    Abstract Class
    - Base DataFile class that guarantees a file to exist,
      if the file does not exist, or the hash of the file is
      incorrect, then the file is re-generated.
    """

    def __init__(
        self,
        file_name: str,
        file_hash: Optional[Union[str, Dict[str, str]]],
        hash_type: str = 'md5',
        hash_mode: str = 'fast',
    ):
        super().__init__(file_name=file_name)
        self._file_hash = file_hash
        self._hash_type = hash_type
        self._hash_mode = hash_mode

    def prepare(self, out_dir: str) -> str:
        # the stalefile decorator only invokes the wrapped function if the target
        # file is missing or its hash does not match, otherwise generation is skipped
        @stalefile(file=os.path.join(out_dir, self._file_name), hash=self._file_hash, hash_type=self._hash_type, hash_mode=self._hash_mode)
        def wrapped(out_file):
            self._prepare(out_dir=out_dir, out_file=out_file)
        return wrapped()

    def _prepare(self, out_dir: str, out_file: str) -> str:
        # TODO: maybe raise a FileNotFoundError or a HashError instead?
        raise NotImplementedError


class DataFileHashedDl(DataFileHashed):
    """
    Download a file
    - uri can also be a file to perform a copy instead of download,
      useful for example if you want to retrieve a file from a network drive.
    """

    def __init__(
        self,
        uri: str,
        uri_hash: Optional[Union[str, Dict[str, str]]],
        uri_name: Optional[str] = None,
        hash_type: str = 'md5',
        hash_mode: str = 'fast',
    ):
        super().__init__(
            file_name=filename_from_url(uri) if (uri_name is None) else uri_name,
            file_hash=uri_hash,
            hash_type=hash_type,
            hash_mode=hash_mode
        )
        self._uri = uri

    def _prepare(self, out_dir: str, out_file: str):
        retrieve_file(src_uri=self._uri, dst_path=out_file, overwrite_existing=True)


class DataFileHashedDlGen(DataFileHashed, metaclass=ABCMeta):
    """
    Abstract class
    - download a file and perform some processing on that file.
    """

    def __init__(
        self,
        # download & save files
        uri: str,
        uri_hash: Optional[Union[str, Dict[str, str]]],
        file_hash: Optional[Union[str, Dict[str, str]]],
        # save paths
        uri_name: Optional[str] = None,
        file_name: Optional[str] = None,
        # hash settings
        hash_type: str = 'md5',
        hash_mode: str = 'fast',
    ):
        self._dl_obj = DataFileHashedDl(
            uri=uri,
            uri_hash=uri_hash,
            uri_name=uri_name,
            hash_type=hash_type,
            hash_mode=hash_mode,
        )
        super().__init__(
            file_name=modify_file_name(self._dl_obj.out_name, prefix='gen') if (file_name is None) else file_name,
            file_hash=file_hash,
            hash_type=hash_type,
            hash_mode=hash_mode,
        )

    def _prepare(self, out_dir: str, out_file: str):
        # make sure the download exists first, then generate the target file from it
        inp_file = self._dl_obj.prepare(out_dir=out_dir)
        self._generate(inp_file=inp_file, out_file=out_file)

    def _generate(self, inp_file: str, out_file: str):
        raise NotImplementedError


class DataFileHashedDlH5(DataFileHashedDlGen):
    """
    Downloads an hdf5 file and pre-processes it into the specified chunk_size.
    """

    def __init__(
        self,
        # download & save files
        uri: str,
        uri_hash: Optional[Union[str, Dict[str, str]]],
        file_hash: Optional[Union[str, Dict[str, str]]],
        # h5 re-save settings
        hdf5_dataset_name: str,
        hdf5_chunk_size: Tuple[int, ...],
        hdf5_compression: Optional[str] = 'gzip',
        hdf5_compression_lvl: Optional[int] = 4,
        hdf5_dtype: Optional[Union[np.dtype, str]] = None,
        hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None,
        hdf5_obs_shape: Optional[Sequence[int]] = None,
        # save paths
        uri_name: Optional[str] = None,
        file_name: Optional[str] = None,
        # hash settings
        hash_type: str = 'md5',
        hash_mode: str = 'fast',
    ):
        super().__init__(
            file_name=file_name,
            file_hash=file_hash,
            uri=uri,
            uri_hash=uri_hash,
            uri_name=uri_name,
            hash_type=hash_type,
            hash_mode=hash_mode,
        )
        # bind the h5 re-save settings so that _generate only needs the input and output paths
        self._hdf5_resave_file = wrapped_partial(
            hdf5_resave_file,
            dataset_name=hdf5_dataset_name,
            chunk_size=hdf5_chunk_size,
            compression=hdf5_compression,
            compression_lvl=hdf5_compression_lvl,
            out_dtype=hdf5_dtype,
            out_mutator=hdf5_mutator,
            obs_shape=hdf5_obs_shape,
        )
        # save the dataset name
        self._dataset_name = hdf5_dataset_name

    @property
    def dataset_name(self) -> str:
        return self._dataset_name

    def _generate(self, inp_file: str, out_file: str):
        self._hdf5_resave_file(inp_path=inp_file, out_path=out_file)


# ========================================================================= #
# END #
# ========================================================================= #
```
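
As a usage illustration (not part of this commit; the URI, hashes, dataset name, and output directory below are all placeholders), a datafile is declared once and `prepare()` is then called to download, verify, and re-generate files only when needed:

```python
from disent.data.datafile import DataFileHashedDl, DataFileHashedDlH5

# hypothetical: download (or copy) a file, verified against an md5 hash
raw = DataFileHashedDl(
    uri='https://example.com/raw_dataset.h5',        # placeholder URL
    uri_hash='0123456789abcdef0123456789abcdef',     # placeholder md5
)

# hypothetical: download an hdf5 file, then re-save it with one observation per chunk
data = DataFileHashedDlH5(
    uri='https://example.com/raw_dataset.h5',        # placeholder URL
    uri_hash='0123456789abcdef0123456789abcdef',     # placeholder md5 of the download
    file_hash='fedcba9876543210fedcba9876543210',    # placeholder md5 of the re-saved file
    hdf5_dataset_name='images',                      # placeholder hdf5 dataset key
    hdf5_chunk_size=(1, 64, 64, 3),                  # chunk size to re-save with
)

# prepare() downloads into out_dir and (re-)generates the output only if the
# target file is missing or its hash is stale, returning the prepared file path
path = data.prepare(out_dir='data/dataset')
```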
13 changes: 12 additions & 1 deletion disent/data/episodes/_base.py
```diff
@@ -25,8 +25,14 @@
 from typing import List, Tuple
 import numpy as np
 
 
 from disent.dataset.groundtruth._triplet import sample_radius
-from disent.util import LengthIter
+from disent.util.iters import LengthIter
+
+
+# ========================================================================= #
+# option episodes #
+# ========================================================================= #
+
 
 class BaseOptionEpisodesData(LengthIter):
```
```diff
@@ -86,3 +92,8 @@ def sample_episode_indices(episode, idx, n=1, radius=None):
 
     def _load_episode_observations(self) -> List[np.ndarray]:
         raise NotImplementedError
+
+
+# ========================================================================= #
+# END #
+# ========================================================================= #
```
29 changes: 25 additions & 4 deletions disent/data/episodes/_option_episodes.py
```diff
@@ -22,16 +22,26 @@
 # SOFTWARE.
 # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
 
+import logging
 import os
-from typing import List, Tuple
+from typing import List
+from typing import Tuple
 
 import numpy as np
 
 from disent.data.episodes._base import BaseOptionEpisodesData
-from disent.data.util.in_out import download_file, basename_from_url
-import logging
+from disent.util.in_out import download_file
+from disent.util.paths import filename_from_url
+
+
+log = logging.getLogger(__name__)
+
+
+# ========================================================================= #
+# option episodes #
+# ========================================================================= #
 
 
 class OptionEpisodesPickledData(BaseOptionEpisodesData):
 
     def __init__(self, required_file: str):
```
```diff
@@ -40,6 +50,10 @@ def __init__(self, required_file: str):
         # load data
         super().__init__()
 
+    # TODO: convert this to data files?
+    # TODO: convert this to data files?
+    # TODO: convert this to data files?
+
     def _load_episode_observations(self) -> List[np.ndarray]:
         import pickle
         # load the raw data!
```
```diff
@@ -107,6 +121,10 @@ def _load_episode_observations(self) -> List[np.ndarray]:
 
 class OptionEpisodesDownloadZippedPickledData(OptionEpisodesPickledData):
 
+    # TODO: convert this to data files?
+    # TODO: convert this to data files?
+    # TODO: convert this to data files?
+
     def __init__(self, required_file: str, download_url=None, force_download=False):
         self._download_and_extract_if_needed(download_url=download_url, required_file=required_file, force_download=force_download)
         super().__init__(required_file=required_file)
```
```diff
@@ -118,7 +136,7 @@ def _download_and_extract_if_needed(self, download_url: str, required_file: str,
         if not isinstance(download_url, str):
             return
         # download file, but skip if file already exists
-        save_path = os.path.join(os.path.dirname(required_file), basename_from_url(download_url))
+        save_path = os.path.join(os.path.dirname(required_file), filename_from_url(download_url))
         if force_download or not os.path.exists(save_path):
             log.info(f'Downloading: {download_url=} to {save_path=}')
             download_file(download_url, save_path=save_path)
```
```diff
@@ -141,3 +159,6 @@ def _download_and_extract_if_needed(self, download_url: str, required_file: str,
         # ~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~
 
 
+# ========================================================================= #
+# END #
+# ========================================================================= #
```
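
A hedged usage sketch for the downloading episodes class (the path and URL are placeholders, and note that the module is private by convention):

```python
from disent.data.episodes._option_episodes import OptionEpisodesDownloadZippedPickledData

# hypothetical values, for illustration only:
# construction downloads and extracts the zip if needed, then loads the episodes
episodes = OptionEpisodesDownloadZippedPickledData(
    required_file='data/episodes/episodes.pkl',           # placeholder path to the pickled episodes
    download_url='https://example.com/episodes.pkl.zip',  # placeholder URL of the zipped pickle
)
```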