From 669f65e4d9817a90c61259bdbda6f4b52895a620 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Wed, 26 May 2021 00:33:51 +0200 Subject: [PATCH 01/34] all triplet_p values are now 1 --- experiment/config/framework/X--augpos_tvae_os.yaml | 2 +- experiment/config/framework/X--tbadavae.yaml | 2 +- experiment/config/framework/X--tgadavae.yaml | 2 +- experiment/config/framework/tae.yaml | 2 +- experiment/config/framework/tvae.yaml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/experiment/config/framework/X--augpos_tvae_os.yaml b/experiment/config/framework/X--augpos_tvae_os.yaml index 048f7959..375dd157 100644 --- a/experiment/config/framework/X--augpos_tvae_os.yaml +++ b/experiment/config/framework/X--augpos_tvae_os.yaml @@ -17,7 +17,7 @@ module: triplet_margin_min: 0.001 triplet_margin_max: 1 triplet_scale: 0.1 - triplet_p: 2 + triplet_p: 1 # settings used elsewhere data_wrap_mode: weak_pair diff --git a/experiment/config/framework/X--tbadavae.yaml b/experiment/config/framework/X--tbadavae.yaml index e57bfa8a..f2b44a3c 100644 --- a/experiment/config/framework/X--tbadavae.yaml +++ b/experiment/config/framework/X--tbadavae.yaml @@ -21,7 +21,7 @@ module: triplet_margin_min: 0.001 triplet_margin_max: 1 triplet_scale: 0.1 - triplet_p: 2 + triplet_p: 1 # settings used elsewhere data_wrap_mode: triplet diff --git a/experiment/config/framework/X--tgadavae.yaml b/experiment/config/framework/X--tgadavae.yaml index dbac246f..1d3e2e27 100644 --- a/experiment/config/framework/X--tgadavae.yaml +++ b/experiment/config/framework/X--tgadavae.yaml @@ -23,7 +23,7 @@ module: triplet_margin_min: 0.001 triplet_margin_max: 1 triplet_scale: 0.1 - triplet_p: 2 + triplet_p: 1 # settings used elsewhere data_wrap_mode: triplet diff --git a/experiment/config/framework/tae.yaml b/experiment/config/framework/tae.yaml index 391c64a5..61d57566 100644 --- a/experiment/config/framework/tae.yaml +++ b/experiment/config/framework/tae.yaml @@ -11,7 +11,7 @@ module: triplet_margin_min: 0.001 triplet_margin_max: 1 triplet_scale: 0.1 - triplet_p: 2 + triplet_p: 1 # settings used elsewhere data_wrap_mode: triplet diff --git a/experiment/config/framework/tvae.yaml b/experiment/config/framework/tvae.yaml index e40df573..085bcc79 100644 --- a/experiment/config/framework/tvae.yaml +++ b/experiment/config/framework/tvae.yaml @@ -17,7 +17,7 @@ module: triplet_margin_min: 0.001 triplet_margin_max: 1 triplet_scale: 0.1 - triplet_p: 2 + triplet_p: 1 # settings used elsewhere data_wrap_mode: triplet From 45ff2167906c270039ad39713506890b644bd65d Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Wed, 26 May 2021 14:53:51 +0200 Subject: [PATCH 02/34] BREAKING-CHANGES for experiments --- experiment/config/config.yaml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/experiment/config/config.yaml b/experiment/config/config.yaml index c4355ede..7517dc33 100644 --- a/experiment/config/config.yaml +++ b/experiment/config/config.yaml @@ -1,12 +1,12 @@ defaults: # experiment - framework: betavae - - model: conv64 + - model: conv64alt - optimizer: radam - dataset: xysquares # allow framework to override settings here, but placing dataset before framework in defaults - augment: none - sampling: full_bb - - metrics: all + - metrics: fast - schedule: none # runtime - run_length: long @@ -35,7 +35,7 @@ framework: overlap_loss: NULL model: - z_size: 9 + z_size: 25 optimizer: lr: 5e-4 @@ -44,4 +44,5 @@ optimizer: # - This key is deleted on load and the correct key on the root config is set similar to defaults. # - Unfortunately this hack needs to exists as hydra does not yet support this kinda of variable interpolation in defaults. specializations: - data_wrapper: ${dataset.data_type}_${framework.data_wrap_mode} +# data_wrapper: ${dataset.data_type}_${framework.data_wrap_mode} + data_wrapper: gt_dist_${framework.data_wrap_mode} From 384017014545f0c0f537617f578fb149aef810fd Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Wed, 26 May 2021 15:26:12 +0200 Subject: [PATCH 03/34] experiment 7 --- experiment/config/config.yaml | 5 ++-- experiment/exp/07_autoencoders/run.sh | 34 +++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) create mode 100644 experiment/exp/07_autoencoders/run.sh diff --git a/experiment/config/config.yaml b/experiment/config/config.yaml index 7517dc33..155f327b 100644 --- a/experiment/config/config.yaml +++ b/experiment/config/config.yaml @@ -29,7 +29,6 @@ framework: module: recon_loss: mse4 loss_reduction: mean - latent_distribution: cauchy optional: latent_distribution: normal # only used by VAEs overlap_loss: NULL @@ -44,5 +43,5 @@ optimizer: # - This key is deleted on load and the correct key on the root config is set similar to defaults. # - Unfortunately this hack needs to exists as hydra does not yet support this kinda of variable interpolation in defaults. specializations: -# data_wrapper: ${dataset.data_type}_${framework.data_wrap_mode} - data_wrapper: gt_dist_${framework.data_wrap_mode} + data_wrapper: ${dataset.data_type}_${framework.data_wrap_mode} +# data_wrapper: gt_dist_${framework.data_wrap_mode} diff --git a/experiment/exp/07_autoencoders/run.sh b/experiment/exp/07_autoencoders/run.sh new file mode 100644 index 00000000..e0642aa4 --- /dev/null +++ b/experiment/exp/07_autoencoders/run.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# ========================================================================= # +# Settings # +# ========================================================================= # + +export PROJECT="exp-autoencoder-versions" +export PARTITION="batch" +export PARALLELISM=24 + +# source the helper file +source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" + +# ========================================================================= # +# Experiment # +# ========================================================================= # + +clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours + +# 1 * (2*3*3*8) == 144 +submit_sweep \ + +DUMMY.repeat=1 \ + +EXTRA.tags='various-auto-encoders' \ + \ + run_length=short,long \ + schedule=adavae_up_ratio_full,adavae_up_all_full,none \ + \ + dataset=xysquares,cars3d,shapes3d \ + framework=ae,tae,X--adaae,X--adanegtae,vae,tvae,adavae,X--adanegtvae \ + model=conv64alt \ + model.z_size=25 \ + \ + specializations.data_wrapper='gt_dist_${framework.data_wrap_mode}' \ + sampling=gt_dist_manhat From 654c3bda3bc86d085bd28b8ae6c4e43adca11271 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Tue, 1 Jun 2021 00:51:42 +0200 Subject: [PATCH 04/34] experiment update update experiment --- experiment/exp/07_autoencoders/run.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/experiment/exp/07_autoencoders/run.sh b/experiment/exp/07_autoencoders/run.sh index e0642aa4..9c8f6d3c 100644 --- a/experiment/exp/07_autoencoders/run.sh +++ b/experiment/exp/07_autoencoders/run.sh @@ -5,8 +5,8 @@ # ========================================================================= # export PROJECT="exp-autoencoder-versions" -export PARTITION="batch" -export PARALLELISM=24 +export PARTITION="stampede" +export PARALLELISM=32 # source the helper file source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" @@ -17,7 +17,7 @@ source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours -# 1 * (2*3*3*8) == 144 +# 1 * (2*2*3*3*8) == 288 submit_sweep \ +DUMMY.repeat=1 \ +EXTRA.tags='various-auto-encoders' \ @@ -31,4 +31,4 @@ submit_sweep \ model.z_size=25 \ \ specializations.data_wrapper='gt_dist_${framework.data_wrap_mode}' \ - sampling=gt_dist_manhat + sampling=gt_dist_manhat,gt_dist_manhat_scaled From 6461912b88960e2774890432dd49a352556e84cd Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 28 May 2021 00:49:39 +0200 Subject: [PATCH 05/34] renamed constructor variables of XYSquaresData + XYSquaresMinimalData minimal implementation --- disent/data/groundtruth/__init__.py | 2 +- disent/data/groundtruth/_xysquares.py | 120 +++++++++++++++--- disent/metrics/_flatness.py | 4 +- disent/metrics/_flatness_components.py | 4 +- disent/util/__init__.py | 5 +- docs/examples/overview_data.py | 2 +- docs/examples/overview_dataset_loader.py | 2 +- docs/examples/overview_dataset_pair.py | 2 +- .../examples/overview_dataset_pair_augment.py | 2 +- docs/examples/overview_dataset_single.py | 2 +- experiment/config/dataset/xysquares.yaml | 10 +- experiment/config/dataset/xysquares_rgb.yaml | 15 +++ experiment/exp/00_data_traversal/run.py | 2 +- experiment/exp/01_visual_overlap/run.py | 16 +-- tests/test_data.py | 52 ++++++++ 15 files changed, 196 insertions(+), 44 deletions(-) create mode 100644 experiment/config/dataset/xysquares_rgb.yaml create mode 100644 tests/test_data.py diff --git a/disent/data/groundtruth/__init__.py b/disent/data/groundtruth/__init__.py index 0822769e..02a742ea 100644 --- a/disent/data/groundtruth/__init__.py +++ b/disent/data/groundtruth/__init__.py @@ -30,5 +30,5 @@ from ._norb import SmallNorbData from ._shapes3d import Shapes3dData from ._xyobject import XYObjectData -from ._xysquares import XYSquaresData +from ._xysquares import XYSquaresData, XYSquaresMinimalData from ._xyblocks import XYBlocksData diff --git a/disent/data/groundtruth/_xysquares.py b/disent/data/groundtruth/_xysquares.py index e4f91745..f59391bf 100644 --- a/disent/data/groundtruth/_xysquares.py +++ b/disent/data/groundtruth/_xysquares.py @@ -23,10 +23,13 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging +from typing import Optional from typing import Tuple -from disent.data.groundtruth.base import GroundTruthData +from typing import Union + import numpy as np +from disent.data.groundtruth.base import GroundTruthData from disent.util import iter_chunks @@ -38,14 +41,62 @@ # ========================================================================= # +class XYSquaresMinimalData(GroundTruthData): + """ + Dataset that generates all possible permutations of 3 (R, G, B) coloured + squares placed on a square grid. This dataset is designed to not overlap + in the reconstruction loss space. + + If you use this in your work, please cite: https://github.com/nmichlo/disent + + NOTE: Unlike XYSquaresData, XYSquaresMinimalData is the bare-minimum class + to generate the same results as the default values for XYSquaresData, + this class is a fair bit faster (~0.8x)! + - All 3 squares are returned, in RGB, each square is size 8, with + non-overlapping grid spacing set to 8 pixels, in total leaving + 8*8*8*8*8*8 factors. Images are uint8 with fill values 0 (bg) + and 255 (fg). + """ + + @property + def factor_names(self) -> Tuple[str, ...]: + return 'x_R', 'y_R', 'x_G', 'y_G', 'x_B', 'y_B' + + @property + def factor_sizes(self) -> Tuple[int, ...]: + return 8, 8, 8, 8, 8, 8 # R, G, B squares + + @property + def observation_shape(self) -> Tuple[int, ...]: + return 64, 64, 3 + + def __getitem__(self, idx): + # get factors + factors = np.reshape(np.unravel_index(idx, self.factor_sizes), (-1, 2)) + # GENERATE + obs = np.zeros(self.observation_shape, dtype=np.uint8) + for i, (fx, fy) in enumerate(factors): + x, y = 8 * fx, 8 * fy + obs[y:y+8, x:x+8, i] = 255 + return obs + + +# ========================================================================= # +# xy multi grid data # +# ========================================================================= # + + class XYSquaresData(GroundTruthData): """ Dataset that generates all possible permutations of 3 (R, G, B) coloured - squares placed on a square grid. - - This dataset is designed to not overlap in the reconstruction loss space. - (if the spacing is set correctly.) + squares placed on a square grid. This dataset is designed to not overlap + in the reconstruction loss space. (if the spacing is set correctly.) + + If you use this in your work, please cite: https://github.com/nmichlo/disent + + NOTE: Unlike XYSquaresMinimalData, XYSquaresData allows adjusting various aspects + of the data that is generated, but the generation process is slower (~1.25x). """ @property @@ -60,18 +111,52 @@ def factor_sizes(self) -> Tuple[int, ...]: def observation_shape(self) -> Tuple[int, ...]: return self._width, self._width, (3 if self._rgb else 1) - def __init__(self, square_size=8, grid_size=64, grid_spacing=None, num_squares=3, rgb=True, no_warnings=False, fill_value=None, max_placements=None): + def __init__( + self, + square_size: int = 8, + image_size: int = 64, + grid_size: Optional[int] = None, + grid_spacing: Optional[int] = None, + num_squares: int = 3, + rgb: bool = True, + fill_value: Optional[Union[float, int]] = None, + dtype: Union[np.dtype, str] = np.uint8, + no_warnings: bool = False, + ): + """ + :param square_size: the size of the individual squares in pixels + :param image_size: the image size in pixels + :param grid_spacing: the step size between square positions on the grid. By + default this is set to square_size which results in non-overlapping + data if `grid_spacing >= square_size` Reducing this value such that + `grid_spacing < square_size` results in overlapping data. + :param num_squares: The number of squares drawn. `1 <= num_squares <= 3` + :param rgb: Image has 3 channels if True, otherwise it is greyscale with 1 channel. + :param no_warnings: If warnings should be disabled if overlapping. + :param fill_value: The foreground value to use for filling squares, the default background value is 0. + :param grid_size: The number of grid positions available for the square to be placed in. The square is centered if this is less than + :param dtype: + """ if grid_spacing is None: grid_spacing = square_size - if grid_spacing < square_size and not no_warnings: + if (grid_spacing < square_size) and not no_warnings: log.warning(f'overlap between squares for reconstruction loss, {grid_spacing} < {square_size}') # color self._rgb = rgb - self._fill_value = fill_value if (fill_value is not None) else 255 - assert isinstance(self._fill_value, int) - assert 0 < self._fill_value <= 255, f'0 < {self._fill_value} <= 255' + self._dtype = np.dtype(dtype) + # check fill values + if self._dtype.kind == 'u': + self._fill_value = 255 if (fill_value is None) else fill_value + assert isinstance(self._fill_value, int) + assert 0 < self._fill_value <= 255, f'0 < {self._fill_value} <= 255' + elif self._dtype.kind == 'f': + self._fill_value = 1.0 if (fill_value is None) else fill_value + assert isinstance(self._fill_value, (int, float)) + assert 0.0 < self._fill_value <= 1.0, f'0.0 < {self._fill_value} <= 1.0' + else: + raise TypeError(f'invalid dtype: {self._dtype}, must be float or unsigned integer') # image sizes - self._width = grid_size + self._width = image_size # number of squares self._num_squares = num_squares assert 1 <= num_squares <= 3, 'Only 1, 2 or 3 squares are supported!' @@ -81,12 +166,15 @@ def __init__(self, square_size=8, grid_size=64, grid_spacing=None, num_squares=3 self._spacing = grid_spacing self._placements = (self._width - self._square_size) // grid_spacing + 1 # maximum placements - if max_placements is not None: - assert isinstance(max_placements, int) - assert max_placements > 0 - self._placements = min(self._placements, max_placements) + if grid_size is not None: + assert isinstance(grid_size, int) + assert grid_size > 0 + if (grid_size > self._placements) and not no_warnings: + log.warning(f'number of possible placements: {self._placements} is less than the given grid size: {grid_size}, reduced grid size from: {grid_size} -> {self._placements}') + self._placements = min(self._placements, grid_size) # center elements self._offset = (self._width - (self._square_size + (self._placements-1)*self._spacing)) // 2 + # initialise parents -- they depend on self.factors super().__init__() def __getitem__(self, idx): @@ -94,7 +182,7 @@ def __getitem__(self, idx): factors = self.idx_to_pos(idx) offset, space, size = self._offset, self._spacing, self._square_size # GENERATE - obs = np.zeros(self.observation_shape, dtype=np.uint8) + obs = np.zeros(self.observation_shape, dtype=self._dtype) for i, (fx, fy) in enumerate(iter_chunks(factors, 2)): x, y = offset + space * fx, offset + space * fy if self._rgb: diff --git a/disent/metrics/_flatness.py b/disent/metrics/_flatness.py index af9aeb06..cbb3daa6 100644 --- a/disent/metrics/_flatness.py +++ b/disent/metrics/_flatness.py @@ -301,10 +301,10 @@ def angles_between(a, b, dim=-1, nan_to_angle=None): # return r # # class XYOverlapData(XYSquaresData): -# def __init__(self, square_size=8, grid_size=64, grid_spacing=None, num_squares=3, rgb=True): +# def __init__(self, square_size=8, image_size=64, grid_spacing=None, num_squares=3, rgb=True): # if grid_spacing is None: # grid_spacing = (square_size+1) // 2 -# super().__init__(square_size=square_size, grid_size=grid_size, grid_spacing=grid_spacing, num_squares=num_squares, rgb=rgb) +# super().__init__(square_size=square_size, image_size=image_size, grid_spacing=grid_spacing, num_squares=num_squares, rgb=rgb) # # # datasets = [XYObjectData(rgb=False, palette='white'), XYSquaresData(), XYOverlapData(), XYObjectData()] # datasets = [XYObjectData()] diff --git a/disent/metrics/_flatness_components.py b/disent/metrics/_flatness_components.py index 27109f45..eddf327c 100644 --- a/disent/metrics/_flatness_components.py +++ b/disent/metrics/_flatness_components.py @@ -333,10 +333,10 @@ def aggregate_measure_distances_along_factor( # return r # # class XYOverlapData(XYSquaresData): -# def __init__(self, square_size=8, grid_size=64, grid_spacing=None, num_squares=3, rgb=True): +# def __init__(self, square_size=8, image_size=64, grid_spacing=None, num_squares=3, rgb=True): # if grid_spacing is None: # grid_spacing = (square_size+1) // 2 -# super().__init__(square_size=square_size, grid_size=grid_size, grid_spacing=grid_spacing, num_squares=num_squares, rgb=rgb) +# super().__init__(square_size=square_size, image_size=image_size, grid_spacing=grid_spacing, num_squares=num_squares, rgb=rgb) # # # datasets = [XYObjectData(rgb=False, palette='white'), XYSquaresData(), XYOverlapData(), XYObjectData()] # datasets = [XYObjectData()] diff --git a/disent/util/__init__.py b/disent/util/__init__.py index 0bb4452b..4214dc88 100644 --- a/disent/util/__init__.py +++ b/disent/util/__init__.py @@ -353,7 +353,10 @@ def __enter__(self): def __exit__(self, *args, **kwargs): self._end_time = time.time_ns() if self._print_name: - log.log(self._log_level, f'{self._print_name}: {self.pretty}') + if self._log_level is None: + print(f'{self._print_name}: {self.pretty}') + else: + log.log(self._log_level, f'{self._print_name}: {self.pretty}') @property def elapsed_ns(self) -> int: diff --git a/docs/examples/overview_data.py b/docs/examples/overview_data.py index 36a3cce4..9393d2ee 100644 --- a/docs/examples/overview_data.py +++ b/docs/examples/overview_data.py @@ -1,6 +1,6 @@ from disent.data.groundtruth import XYSquaresData -data = XYSquaresData(square_size=1, grid_size=2, num_squares=2) +data = XYSquaresData(square_size=1, image_size=2, num_squares=2) print(f'Number of observations: {len(data)} == {data.size}') print(f'Observation shape: {data.observation_shape}') diff --git a/docs/examples/overview_dataset_loader.py b/docs/examples/overview_dataset_loader.py index a8f33ce1..793e2a54 100644 --- a/docs/examples/overview_dataset_loader.py +++ b/docs/examples/overview_dataset_loader.py @@ -3,7 +3,7 @@ from disent.dataset.groundtruth import GroundTruthDatasetPairs from disent.transform import ToStandardisedTensor -data: GroundTruthData = XYSquaresData(square_size=1, grid_size=2, num_squares=2) +data: GroundTruthData = XYSquaresData(square_size=1, image_size=2, num_squares=2) dataset: Dataset = GroundTruthDatasetPairs(data, transform=ToStandardisedTensor(), augment=None) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) diff --git a/docs/examples/overview_dataset_pair.py b/docs/examples/overview_dataset_pair.py index d0fffdc4..b18a57b1 100644 --- a/docs/examples/overview_dataset_pair.py +++ b/docs/examples/overview_dataset_pair.py @@ -3,7 +3,7 @@ from disent.dataset.groundtruth import GroundTruthDatasetPairs from disent.transform import ToStandardisedTensor -data: GroundTruthData = XYSquaresData(square_size=1, grid_size=2, num_squares=2) +data: GroundTruthData = XYSquaresData(square_size=1, image_size=2, num_squares=2) dataset: Dataset = GroundTruthDatasetPairs(data, transform=ToStandardisedTensor(), augment=None) for obs in dataset: diff --git a/docs/examples/overview_dataset_pair_augment.py b/docs/examples/overview_dataset_pair_augment.py index a759eb1b..17080512 100644 --- a/docs/examples/overview_dataset_pair_augment.py +++ b/docs/examples/overview_dataset_pair_augment.py @@ -3,7 +3,7 @@ from disent.dataset.groundtruth import GroundTruthDatasetPairs from disent.transform import ToStandardisedTensor, FftBoxBlur -data: GroundTruthData = XYSquaresData(square_size=1, grid_size=2, num_squares=2) +data: GroundTruthData = XYSquaresData(square_size=1, image_size=2, num_squares=2) dataset: Dataset = GroundTruthDatasetPairs(data, transform=ToStandardisedTensor(), augment=FftBoxBlur(radius=1, p=1.0)) for obs in dataset: diff --git a/docs/examples/overview_dataset_single.py b/docs/examples/overview_dataset_single.py index d9274bbb..a400620a 100644 --- a/docs/examples/overview_dataset_single.py +++ b/docs/examples/overview_dataset_single.py @@ -2,7 +2,7 @@ from disent.data.groundtruth import XYSquaresData, GroundTruthData from disent.dataset.groundtruth import GroundTruthDataset -data: GroundTruthData = XYSquaresData(square_size=1, grid_size=2, num_squares=2) +data: GroundTruthData = XYSquaresData(square_size=1, image_size=2, num_squares=2) dataset: Dataset = GroundTruthDataset(data, transform=None, augment=None) for obs in dataset: diff --git a/experiment/config/dataset/xysquares.yaml b/experiment/config/dataset/xysquares.yaml index a22cc85f..d2c183e3 100644 --- a/experiment/config/dataset/xysquares.yaml +++ b/experiment/config/dataset/xysquares.yaml @@ -1,13 +1,7 @@ # @package _group_ -name: xysquares +name: xysquares_minimal data: - _target_: disent.data.groundtruth.XYSquaresData - square_size: 8 - grid_size: 64 - grid_spacing: 8 - num_squares: 3 - rgb: TRUE - max_placements: 8 + _target_: disent.data.groundtruth.XYSquaresMinimalData transform: _target_: disent.transform.ToStandardisedTensor x_shape: [3, 64, 64] diff --git a/experiment/config/dataset/xysquares_rgb.yaml b/experiment/config/dataset/xysquares_rgb.yaml new file mode 100644 index 00000000..93bb2eac --- /dev/null +++ b/experiment/config/dataset/xysquares_rgb.yaml @@ -0,0 +1,15 @@ +# @package _group_ +name: xysquares_rgb +data: + _target_: disent.data.groundtruth.XYSquaresData + square_size: 8 + grid_size: 64 + grid_spacing: 8 + num_squares: 3 + rgb: TRUE + max_placements: 8 +transform: + _target_: disent.transform.ToStandardisedTensor +x_shape: [3, 64, 64] + +data_type: ground_truth diff --git a/experiment/exp/00_data_traversal/run.py b/experiment/exp/00_data_traversal/run.py index 5411f488..e3467c4b 100644 --- a/experiment/exp/00_data_traversal/run.py +++ b/experiment/exp/00_data_traversal/run.py @@ -119,7 +119,7 @@ def plot_dataset_traversals( # save image for i in ([1, 2, 3, 4, 5, 6, 7, 8] if all_squares else [1, 8]): plot_dataset_traversals( - XYSquaresData(grid_spacing=i, max_placements=8, no_warnings=True), + XYSquaresData(grid_spacing=i, grid_size=8, no_warnings=True), rel_path=f'plots/xy-squares-traversal-spacing{i}', seed=7, add_random_traversal=add_random_traversal, num_cols=num_cols ) diff --git a/experiment/exp/01_visual_overlap/run.py b/experiment/exp/01_visual_overlap/run.py index b0b40477..8f73a684 100644 --- a/experiment/exp/01_visual_overlap/run.py +++ b/experiment/exp/01_visual_overlap/run.py @@ -423,14 +423,14 @@ def plot_unique_count(dfs, save_name: str = None, show_plt: bool = True, fig_l_p dfs = plot_all( exp_name='increasing-overlap-fixed', datas={ - 'XYSquares-1-8': lambda: XYSquaresData(square_size=8, grid_spacing=1, max_placements=8), - 'XYSquares-2-8': lambda: XYSquaresData(square_size=8, grid_spacing=2, max_placements=8), - 'XYSquares-3-8': lambda: XYSquaresData(square_size=8, grid_spacing=3, max_placements=8), - 'XYSquares-4-8': lambda: XYSquaresData(square_size=8, grid_spacing=4, max_placements=8), - 'XYSquares-5-8': lambda: XYSquaresData(square_size=8, grid_spacing=5, max_placements=8), - 'XYSquares-6-8': lambda: XYSquaresData(square_size=8, grid_spacing=6, max_placements=8), - 'XYSquares-7-8': lambda: XYSquaresData(square_size=8, grid_spacing=7, max_placements=8), - 'XYSquares-8-8': lambda: XYSquaresData(square_size=8, grid_spacing=8, max_placements=8), + 'XYSquares-1-8': lambda: XYSquaresData(square_size=8, grid_spacing=1, grid_size=8), + 'XYSquares-2-8': lambda: XYSquaresData(square_size=8, grid_spacing=2, grid_size=8), + 'XYSquares-3-8': lambda: XYSquaresData(square_size=8, grid_spacing=3, grid_size=8), + 'XYSquares-4-8': lambda: XYSquaresData(square_size=8, grid_spacing=4, grid_size=8), + 'XYSquares-5-8': lambda: XYSquaresData(square_size=8, grid_spacing=5, grid_size=8), + 'XYSquares-6-8': lambda: XYSquaresData(square_size=8, grid_spacing=6, grid_size=8), + 'XYSquares-7-8': lambda: XYSquaresData(square_size=8, grid_spacing=7, grid_size=8), + 'XYSquares-8-8': lambda: XYSquaresData(square_size=8, grid_spacing=8, grid_size=8), }, hide_extra_legends=True, **SHARED_SETTINGS diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 00000000..6031accb --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,52 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import numpy as np + +from disent.data.groundtruth import XYSquaresData +from disent.data.groundtruth._xysquares import XYSquaresMinimalData + + +# ========================================================================= # +# TESTS # +# ========================================================================= # + + +def test_xysquares_similarity(): + data_org = XYSquaresData() + data_min = XYSquaresMinimalData() + # check lengths + assert len(data_org) == len(data_min) + n = len(data_min) + # check items + for i in np.random.randint(0, n, size=100): + assert np.allclose(data_org[i], data_min[i]) + # check bounds + assert np.allclose(data_org[0], data_min[0]) + assert np.allclose(data_org[n-1], data_min[n-1]) + + +# ========================================================================= # +# END # +# ========================================================================= # + From 0c75dd12174e92b238d04093230d2b6f98fee8f9 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 28 May 2021 12:43:00 +0200 Subject: [PATCH 06/34] data is no longer stored on classes, but instances of those classes --- disent/data/groundtruth/_cars3d.py | 8 ++------ disent/data/groundtruth/_mpi3d.py | 6 +----- disent/data/groundtruth/_norb.py | 13 +++++-------- experiment/util/hydra_utils.py | 2 +- 4 files changed, 9 insertions(+), 20 deletions(-) diff --git a/disent/data/groundtruth/_cars3d.py b/disent/data/groundtruth/_cars3d.py index d28884d1..24c43fa8 100644 --- a/disent/data/groundtruth/_cars3d.py +++ b/disent/data/groundtruth/_cars3d.py @@ -56,15 +56,11 @@ class Cars3dData(DownloadableGroundTruthData): def __init__(self, data_dir='data/dataset/cars3d', force_download=False): super().__init__(data_dir=data_dir, force_download=force_download) - converted_file = self._make_converted_file(data_dir, force_download) - - if not hasattr(self.__class__, '_DATA'): - # store data on class - self.__class__._DATA = np.load(converted_file)['images'] + self._data = np.load(converted_file)['images'] def __getitem__(self, idx): - return self.__class__._DATA[idx] + return self._data[idx] def _make_converted_file(self, data_dir, force_download): # get files & folders diff --git a/disent/data/groundtruth/_mpi3d.py b/disent/data/groundtruth/_mpi3d.py index 9ad65250..626576e9 100644 --- a/disent/data/groundtruth/_mpi3d.py +++ b/disent/data/groundtruth/_mpi3d.py @@ -74,11 +74,7 @@ def __init__(self, data_dir='data/dataset/mpi3d', force_download=False, subset=' super().__init__(data_dir=data_dir, force_download=force_download) # load data - if not hasattr(self.__class__, '_DATA'): - self.__class__._DATA = {} - if subset not in self.__class__._DATA: - self.__class__._DATA[subset] = np.load(self.dataset_paths[0]) - self._data = self.__class__._DATA[subset] + self._data = np.load(self.dataset_paths[0]) def __getitem__(self, idx): return self._data[idx] diff --git a/disent/data/groundtruth/_norb.py b/disent/data/groundtruth/_norb.py index 5d417a65..bfe4e96e 100644 --- a/disent/data/groundtruth/_norb.py +++ b/disent/data/groundtruth/_norb.py @@ -76,16 +76,13 @@ class SmallNorbData(DownloadableGroundTruthData): def __init__(self, data_dir='data/dataset/smallnorb', force_download=False, is_test=False): super().__init__(data_dir=data_dir, force_download=force_download) assert not is_test, 'Test set not yet supported' - - if not hasattr(self.__class__, '_DATA'): - images, features = self._read_norb_set(is_test) - # sort by features - indices = np.lexsort(features[:, [4, 3, 2, 1, 0]].T) - # store data on class - self.__class__._DATA = images[indices] + # read dataset and sort by features + images, features = self._read_norb_set(is_test) + indices = np.lexsort(features[:, [4, 3, 2, 1, 0]].T) + self._data = images[indices] def __getitem__(self, idx): - return self.__class__._DATA[idx] + return self._data[idx] def _read_norb_set(self, is_test): # get file data corresponding to urls diff --git a/experiment/util/hydra_utils.py b/experiment/util/hydra_utils.py index 603f5841..f2afef8a 100644 --- a/experiment/util/hydra_utils.py +++ b/experiment/util/hydra_utils.py @@ -85,7 +85,7 @@ def merge_specializations(cfg: DictConfig, config_path: str, main_fn: callable, # set and update specializations for group, specialization in cfg.specializations.items(): - assert group not in cfg, f'{group=} already exists on cfg, specialization merging is not supported!' + assert group not in cfg, f'group={repr(group)} already exists on cfg, specialization merging is not supported!' log.info(f'merging specialization: {repr(specialization)}') # load specialization config specialization_cfg = OmegaConf.load(os.path.join(config_root, group, f'{specialization}.yaml')) From 6296e2867d2f495081106ac3e722e01cd6b80d45 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Sun, 30 May 2021 22:18:34 +0200 Subject: [PATCH 07/34] basic tasks -- split out logic from dataset updated tasks update jobs update cached jobs simplified jobs & base datasets + added fixed shapes3d --- disent/data/groundtruth/_shapes3d.py | 49 ++--- disent/data/groundtruth/base.py | 314 ++++++++++++++------------- disent/data/util/hdf5.py | 127 +++++++---- disent/data/util/in_out.py | 264 ++++++++++++++++++---- disent/data/util/jobs.py | 156 +++++++++++++ disent/util/__init__.py | 59 ++++- experiment/exp/util/_tasks.py | 1 + 7 files changed, 695 insertions(+), 275 deletions(-) create mode 100644 disent/data/util/jobs.py diff --git a/disent/data/groundtruth/_shapes3d.py b/disent/data/groundtruth/_shapes3d.py index f68abc06..c198f724 100644 --- a/disent/data/groundtruth/_shapes3d.py +++ b/disent/data/groundtruth/_shapes3d.py @@ -22,7 +22,9 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -from disent.data.groundtruth.base import Hdf5PreprocessedGroundTruthData + +from disent.data.groundtruth.base import DlH5DataObject +from disent.data.groundtruth.base import Hdf5GroundTruthData # ========================================================================= # @@ -30,7 +32,7 @@ # ========================================================================= # -class Shapes3dData(Hdf5PreprocessedGroundTruthData): +class Shapes3dData(Hdf5GroundTruthData): """ 3D Shapes Dataset: - https://github.com/deepmind/3d-shapes @@ -39,41 +41,32 @@ class Shapes3dData(Hdf5PreprocessedGroundTruthData): - direct: https://storage.googleapis.com/3d-shapes/3dshapes.h5 redirect: https://storage.cloud.google.com/3d-shapes/3dshapes.h5 info: https://console.cloud.google.com/storage/browser/_details/3d-shapes/3dshapes.h5 - - reference implementation: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/data/ground_truth/shapes3d.py """ - dataset_url = 'https://storage.googleapis.com/3d-shapes/3dshapes.h5' + name = '3dshapes' factor_names = ('floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation') factor_sizes = (10, 10, 10, 8, 4, 15) # TOTAL: 480000 observation_shape = (64, 64, 3) - hdf5_name = 'images' - # minimum chunk size, no compression but good for random accesses - hdf5_chunk_size = (1, 64, 64, 3) - - def __init__(self, data_dir='data/dataset/3dshapes', in_memory=False, force_download=False, force_preprocess=False): - super().__init__(data_dir=data_dir, in_memory=in_memory, force_download=force_download, force_preprocess=force_preprocess) + data_object = DlH5DataObject( + # processed dataset file + file_name='3dshapes.h5', + file_hashes={'fast': 'e3a1a449b95293d4b2c25edbfcb8e804', 'full': 'b5187ee0d8b519bb33281c5ca549658c'}, + # download file/link + uri='https://storage.googleapis.com/3d-shapes/3dshapes.h5', + uri_hashes={'fast': '85b20ed7cc8dc1f939f7031698d2d2ab', 'full': '099a2078d58cec4daad0702c55d06868'}, + # hash settings + hash_mode='fast', + hash_type='md5', + # h5 re-save settings + hdf5_dataset_name='images', + hdf5_chunk_size=(1, 64, 64, 3), + hdf5_compression='gzip', + hdf5_compression_lvl=4, + ) # ========================================================================= # # END # # ========================================================================= # - - -# if __name__ == '__main__': - # dataset = RandomDataset(Shapes3dData()) - # dataloader = DataLoader(dataset, num_workers=os.cpu_count(), batch_size=256) - # - # for batch in tqdm(dataloader): - # pass - - # # test that dimensions are resampled correctly, and only differ by a certain number of factors, not all. - # for i in range(10): - # idx = np.random.randint(len(dataset)) - # a, b = pair_dataset.sample_pair_factors(idx) - # print(all(dataset.idx_to_pos(idx) == a), '|', a, '&', b, ':', [int(v) for v in (a == b)]) - # a, b = dataset.pos_to_idx([a, b]) - # print(a, b) - # dataset[a], dataset[b] diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index 9892e69f..11f80c88 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -22,13 +22,22 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import dataclasses import logging import os from abc import ABCMeta -from typing import List, Tuple +from typing import Any +from typing import Dict +from typing import Optional +from typing import Sequence +from typing import Tuple + import h5py -from disent.data.util.in_out import basename_from_url, download_file, ensure_dir_exists +from disent.data.util.hdf5 import hdf5_resave_file +from disent.data.util.in_out import ensure_dir_exists +from disent.data.util.in_out import retrieve_file +from disent.data.util.jobs import CachedJobFile from disent.data.util.state_space import StateSpace @@ -46,6 +55,13 @@ def __init__(self): assert len(self.factor_names) == len(self.factor_sizes), 'Dimensionality mismatch of FACTOR_NAMES and FACTOR_DIMS' super().__init__(factor_sizes=self.factor_sizes) + @property + def name(self): + name = self.__class__.__name__ + if name.endswith('Data'): + name = name[:-len('Data')] + return name.lower() + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Overrides # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # @@ -76,143 +92,59 @@ def __getitem__(self, idx): # ========================================================================= # -# dataset helpers # +# disk ground truth data # # ========================================================================= # -class DownloadableGroundTruthData(GroundTruthData, metaclass=ABCMeta): +class DiskGroundTruthData(GroundTruthData, metaclass=ABCMeta): - def __init__(self, data_dir='data/dataset', force_download=False): + def __init__(self, data_root: Optional[str] = None, prepare: bool = False): super().__init__() - # paths - self._data_dir = ensure_dir_exists(data_dir) - self._data_paths = [os.path.join(self._data_dir, basename_from_url(url)) for url in self.dataset_urls] - # meta - self._force_download = force_download - # DOWNLOAD - self._do_download_dataset() - - def _do_download_dataset(self): - for path, url in zip(self.dataset_paths, self.dataset_urls): - no_data = not os.path.exists(path) - # download data - if self._force_download or no_data: - download_file(url, path, overwrite_existing=True) - - @property - def dataset_paths(self) -> List[str]: - """path that the data should be loaded from in the child class""" - return self._data_paths - - @property - def dataset_urls(self) -> List[str]: - raise NotImplementedError() - - -class PreprocessedDownloadableGroundTruthData(DownloadableGroundTruthData, metaclass=ABCMeta): - - def __init__(self, data_dir='data/dataset', force_download=False, force_preprocess=False): - super().__init__(data_dir=data_dir, force_download=force_download) - # paths - self._proc_path = f'{self._data_path}.processed' - self._force_preprocess = force_preprocess - # PROCESS - self._do_download_and_process_dataset() - - def _do_download_dataset(self): - # we skip this in favour of our new method, - # so that we can lazily download the data. - pass - - def _do_download_and_process_dataset(self): - no_data = not os.path.exists(self._data_path) - no_proc = not os.path.exists(self._proc_path) - - # preprocess only if required - do_proc = self._force_preprocess or no_proc - # lazily download if required for preprocessing - do_data = self._force_download or (no_data and do_proc) - - if do_data: - download_file(self.dataset_url, self._data_path, overwrite_existing=True) - - if do_proc: - # TODO: also used in io save file, convert to with syntax. - # save to a temporary location in case there is an error, we then know one occured. - path_dir, path_base = os.path.split(self._proc_path) - ensure_dir_exists(path_dir) - temp_proc_path = os.path.join(path_dir, f'.{path_base}.temp') - - # process stuff - self._preprocess_dataset(path_src=self._data_path, path_dst=temp_proc_path) - - # delete existing file if needed - if os.path.isfile(self._proc_path): - os.remove(self._proc_path) - # move processed file to correct place - os.rename(temp_proc_path, self._proc_path) - - assert os.path.exists(self._proc_path), f'Overridden _preprocess_dataset method did not initialise the required dataset file: dataset_path="{self._proc_path}"' - - @property - def _data_path(self): - assert len(self.dataset_paths) == 1 - return self.dataset_paths[0] - - @property - def dataset_urls(self): - return [self.dataset_url] - - @property - def dataset_url(self): - raise NotImplementedError() + # get root data folder + if data_root is None: + data_root = os.path.abspath(os.environ.get('DISENT_DATA_ROOT', 'data/dataset')) + else: + data_root = os.path.abspath(data_root) + # get class data folder + self._data_dir = ensure_dir_exists(os.path.join(data_root, self.name)) + log.info(f'{self.name}: data_dir_share={repr(self._data_dir)}') + # prepare everything + if prepare: + for data_object in self.data_objects: + data_object.prepare(self.data_dir) @property - def dataset_path(self): - """path that the dataset should be loaded from in the child class""" - return self._proc_path + def data_dir(self): + return self._data_dir @property - def dataset_path_unprocessed(self): - return self._data_path - - def _preprocess_dataset(self, path_src, path_dst): - raise NotImplementedError() - - -class Hdf5PreprocessedGroundTruthData(PreprocessedDownloadableGroundTruthData, metaclass=ABCMeta): - """ - Automatically download and pre-process an hdf5 dataset - into the specific chunk sizes. + def data_objects(self) -> Sequence['DataObject']: + raise NotImplementedError - Often the (non-chunked) dataset will be optimized for random accesses, - while the unprocessed (chunked) dataset will be better for sequential reads. - - The chunk size specifies the region of data to be loaded when accessing a - single element of the dataset, if the chunk size is not correctly set, - unneeded data will be loaded when accessing observations. - - override `hdf5_chunk_size` to set the chunk size, for random access - optimized data this should be set to the minimum observation shape that can - be broadcast across the shape of the dataset. Eg. with observations of shape - (64, 64, 3), set the chunk size to (1, 64, 64, 3). - TODO: Only supports one dataset from the hdf5 file - itself, labels etc need a custom implementation. - """ +class Hdf5GroundTruthData(DiskGroundTruthData, metaclass=ABCMeta): - def __init__(self, data_dir='data/dataset', in_memory=False, force_download=False, force_preprocess=False): - super().__init__(data_dir=data_dir, force_download=force_download, force_preprocess=force_preprocess) + def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_memory=False): + super().__init__(data_root=data_root, prepare=prepare) + # variables + self._h5_path = self.data_object.get_file_path(self.data_dir) + self._h5_dataset_name = self.data_object.hdf5_dataset_name self._in_memory = in_memory - # Load the entire dataset into memory if required if self._in_memory: - with h5py.File(self.dataset_path, 'r', libver='latest', swmr=True) as db: + with h5py.File(self._h5_path, 'r', libver='latest', swmr=True) as db: # indexing dataset objects returns numpy array # instantiating np.array from the dataset requires double memory. - self._memory_data = db[self.hdf5_name][:] + self._memory_data = db[self._h5_dataset_name][:] else: - # is this thread safe? - self._hdf5_file = h5py.File(self.dataset_path, 'r', libver='latest', swmr=True) - self._hdf5_data = self._hdf5_file[self.hdf5_name] + self._hdf5_file, self._hdf5_data = self._make_hdf5() + + def _make_hdf5(self): + # is this thread safe? + # TODO: this may cause a memory leak, it is never closed? + hdf5_file = h5py.File(self._h5_path, 'r', libver='latest', swmr=True) + hdf5_data = hdf5_file[self._h5_dataset_name] + return hdf5_file, hdf5_data def __getitem__(self, idx): if self._in_memory: @@ -220,44 +152,122 @@ def __getitem__(self, idx): else: return self._hdf5_data[idx] - def __del__(self): - # do we need to do this? - if not self._in_memory: - self._hdf5_file.close() - - def _preprocess_dataset(self, path_src, path_dst): - import os - from disent.data.util.hdf5 import hdf5_resave_dataset, hdf5_test_entries_per_second, bytes_to_human - - # resave datasets - with h5py.File(path_src, 'r') as inp_data: - with h5py.File(path_dst, 'w') as out_data: - hdf5_resave_dataset(inp_data, out_data, self.hdf5_name, self.hdf5_chunk_size, self.hdf5_compression, self.hdf5_compression_lvl) - # File Size: - log.info(f'[FILE SIZES] IN: {bytes_to_human(os.path.getsize(path_src))} OUT: {bytes_to_human(os.path.getsize(path_dst))}\n') - # Test Speed: - log.info('[TESTING] Access Speed...') - log.info(f'Random Accesses Per Second: {hdf5_test_entries_per_second(out_data, self.hdf5_name, access_method="random"):.3f}') - @property - def hdf5_compression(self) -> 'str': - return 'gzip' + def data_objects(self) -> Sequence['DlH5DataObject']: + return [self.data_object] @property - def hdf5_compression_lvl(self) -> int: - # default is 4, max of 9 doesnt seem to add much cpu usage on read, but its not worth it data wise? - return 4 + def data_object(self) -> 'DlH5DataObject': + raise NotImplementedError - @property - def hdf5_name(self) -> str: - raise NotImplementedError() + # CUSTOM PICKLE HANDLING -- h5py files are not supported! + # https://docs.python.org/3/library/pickle.html#pickle-state + # https://docs.python.org/3/library/pickle.html#object.__getstate__ + # https://docs.python.org/3/library/pickle.html#object.__setstate__ + # TODO: this might duplicate in-memory stuffs. - @property - def hdf5_chunk_size(self) -> Tuple[int]: - # dramatically affects access speed, but also compression ratio. - raise NotImplementedError() + def __getstate__(self): + state = self.__dict__.copy() + state.pop('_hdf5_file', None) + state.pop('_hdf5_data', None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + if not self._in_memory: + self._hdf5_file, self._hdf5_data = self._make_hdf5() + + +# ========================================================================= # +# data objects # +# ========================================================================= # + + +@dataclasses.dataclass +class DataObject(object): + file_name: str + + def prepare(self, data_dir: str): + pass + + def get_file_path(self, data_dir: str, variant: Optional[str] = None): + suffix = '' if (variant is None) else f'.{variant}' + return os.path.join(data_dir, self.file_name + suffix) + + +@dataclasses.dataclass +class DlDataObject(DataObject): + file_name: str + # download file/link + uri: str + uri_hashes: Dict[str, str] + # hash settings + hash_mode: str + hash_type: str + + def _make_dl_job(self, save_path: str): + return CachedJobFile( + make_file_fn=lambda path: retrieve_file( + src_uri=self.uri, + dst_path=path, + overwrite_existing=True, + ), + path=save_path, + hash=self.uri_hashes[self.hash_mode], + hash_type=self.hash_type, + hash_mode=self.hash_mode, + ) + + def prepare(self, data_dir: str): + dl_job = self._make_dl_job(save_path=self.get_file_path(data_dir=data_dir)) + dl_job.run() + + +@dataclasses.dataclass +class DlH5DataObject(DlDataObject): + file_name: str + file_hashes: Dict[str, str] + # download file/link + uri: str + uri_hashes: Dict[str, str] + # hash settings + hash_mode: str + hash_type: str + # h5 re-save settings + hdf5_dataset_name: str + hdf5_chunk_size: Tuple[int, ...] + hdf5_compression: Optional[str] + hdf5_compression_lvl: Optional[int] + + def _make_h5_job(self, load_path: str, save_path: str): + return CachedJobFile( + make_file_fn=lambda path: hdf5_resave_file( + inp_path=load_path, + out_path=path, + dataset_name=self.hdf5_dataset_name, + chunk_size=self.hdf5_chunk_size, + compression=self.hdf5_compression, + compression_lvl=self.hdf5_compression_lvl, + batch_size=None, + print_mode='minimal', + ), + path=save_path, + hash=self.file_hashes[self.hash_mode], + hash_type=self.hash_type, + hash_mode=self.hash_mode, + ) + + def prepare(self, data_dir: str): + dl_path = self.get_file_path(data_dir=data_dir, variant='ORIG') + h5_path = self.get_file_path(data_dir=data_dir) + dl_job = self._make_dl_job(save_path=dl_path) + h5_job = self._make_h5_job(load_path=dl_path, save_path=h5_path) + h5_job.set_parent(dl_job).run() # ========================================================================= # # END # # ========================================================================= # + + + diff --git a/disent/data/util/hdf5.py b/disent/data/util/hdf5.py index d5a23490..1782750f 100644 --- a/disent/data/util/hdf5.py +++ b/disent/data/util/hdf5.py @@ -21,15 +21,24 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +""" +Utilities for converting and testing different chunk sizes of hdf5 files +""" +import logging import math -import time +import os + +import h5py import numpy as np from tqdm import tqdm -""" -Utilities for converting and testing different chunk sizes of hdf5 files -""" +from disent.data.util.in_out import AtomicFileContext +from disent.util import iter_chunks +from disent.util import Timer + + +log = logging.getLogger(__name__) # ========================================================================= # @@ -37,6 +46,7 @@ # ========================================================================= # +# TODO: cleanup def bytes_to_human(size_bytes, decimals=3, color=True): if size_bytes == 0: return "0B" @@ -48,6 +58,7 @@ def bytes_to_human(size_bytes, decimals=3, color=True): return f"{s:{4+decimals}.{decimals}f} {name}" +# TODO: cleanup def hdf5_print_entry_data_stats(data, dataset, label='STATISTICS', print_mode='all'): dtype = data[dataset].dtype itemsize = data[dataset].dtype.itemsize @@ -75,6 +86,8 @@ def hdf5_print_entry_data_stats(data, dataset, label='STATISTICS', print_mode='a f'[{label:3s}] entry: {str(list(shape)):18s} ({str(dtype):8s}) \033[93m{bytes_to_human(data_per_entry)}\033[0m chunk: {str(list(chunks)):18s} \033[93m{bytes_to_human(data_per_chunk)}\033[0m chunks per entry: {str(list(chunks_per_dim)):18s} \033[93m{bytes_to_human(read_data_per_entry)}\033[0m (\033[91m{chunks_per_entry}\033[0m)' ) + +# TODO: cleanup def hd5f_print_dataset_info(data, dataset, label='DATASET'): if label: tqdm.write(f'[{label}]: \033[92m{dataset}\033[0m') @@ -86,59 +99,87 @@ def hd5f_print_dataset_info(data, dataset, label='DATASET'): ) -def hdf5_resave_dataset(inp_data, out_data, dataset, chunks=None, compression=None, compression_opts=None, batch_size=None, max_entries=None, dry_run=False, print_mode='minimal'): - # print_dataset_info(inp_data, dataset, label='INPUT') +def hdf5_resave_dataset(inp_h5: h5py.File, out_h5: h5py.File, dataset_name, chunk_size=None, compression=None, compression_lvl=None, batch_size=None, print_mode='minimal'): # create new dataset - out_data.create_dataset( - name=dataset, - shape=inp_data[dataset].shape, - dtype=inp_data[dataset].dtype, - chunks=chunks, + out_h5.create_dataset( + name=dataset_name, + shape=inp_h5[dataset_name].shape, + dtype=inp_h5[dataset_name].dtype, + chunks=chunk_size, compression=compression, - compression_opts=compression_opts + compression_opts=compression_lvl, + # non-deterministic time stamps are added to the file if this is not + # disabled, resulting in different hash sums when the file is re-generated! + # https://github.com/h5py/h5py/issues/225 + # https://stackoverflow.com/questions/16019656 + track_times=False, ) - - hdf5_print_entry_data_stats(inp_data, dataset, label=f'IN', print_mode=print_mode) - hdf5_print_entry_data_stats(out_data, dataset, label=f'OUT', print_mode=print_mode) - tqdm.write('') - - if not dry_run: - # choose chunk size - if batch_size is None: - batch_size = inp_data[dataset].chunks[0] - # batched copy - entries = len(inp_data[dataset]) - with tqdm(total=entries) as progress: - for i in range(0, max_entries if max_entries else entries, batch_size): - out_data[dataset][i:i + batch_size] = inp_data[dataset][i:i + batch_size] - progress.update(batch_size) - tqdm.write('') - - -def hdf5_test_entries_per_second(data, dataset, access_method='random', max_entries=48000, timeout=10): - # num entries to test - n = min(len(data[dataset]), max_entries) - + # print stats + hdf5_print_entry_data_stats(inp_h5, dataset_name, label=f'IN', print_mode=print_mode) + hdf5_print_entry_data_stats(out_h5, dataset_name, label=f'OUT', print_mode=print_mode) + # choose batch size for copying data + if batch_size is None: + batch_size = inp_h5[dataset_name].chunks[0] + log.debug(f're-saving h5 dataset using automatic batch size of: {batch_size}') + # batched copy | loads could be parallelized! + entries = len(inp_h5[dataset_name]) + with tqdm(total=entries) as progress: + for i in range(0, entries, batch_size): + out_h5[dataset_name][i:i + batch_size] = inp_h5[dataset_name][i:i + batch_size] + progress.update(batch_size) + + +def hdf5_resave_file(inp_path: str, out_path: str, dataset_name, chunk_size=None, compression=None, compression_lvl=None, batch_size=None, print_mode='minimal'): + # re-save datasets + with h5py.File(inp_path, 'r') as inp_h5: + with AtomicFileContext(out_path, open_mode=None, overwrite=True) as tmp_h5_path: + with h5py.File(tmp_h5_path, 'w') as out_h5: + hdf5_resave_dataset( + inp_h5=inp_h5, + out_h5=out_h5, + dataset_name=dataset_name, + chunk_size=chunk_size, + compression=compression, + compression_lvl=compression_lvl, + batch_size=batch_size, + print_mode=print_mode, + ) + # file size: + log.info(f'[FILE SIZES] IN: {bytes_to_human(os.path.getsize(inp_path))} OUT: {bytes_to_human(os.path.getsize(out_path))}') + + +def hdf5_test_speed(h5_path: str, dataset_name: str, access_method: str = 'random'): + with h5py.File(h5_path, 'r') as out_h5: + log.info('[TESTING] Access Speed...') + log.info(f'Random Accesses Per Second: {hdf5_test_entries_per_second(out_h5, dataset_name, access_method=access_method, max_entries=5_000):.3f}') + + +def hdf5_test_entries_per_second(h5_data: h5py.File, dataset_name, access_method='random', max_entries=48000, timeout=10, batch_size: int = 256): + data = h5_data[dataset_name] # get access method if access_method == 'sequential': - indices = np.arange(n) + indices = np.arange(len(data)) elif access_method == 'random': - indices = np.arange(n) + indices = np.arange(len(data)) np.random.shuffle(indices) else: raise KeyError('Invalid access method') - + # num entries to test + n = min(len(data), max_entries) + indices = indices[:n] # iterate through dataset, exit on timeout or max_entries - start_time = time.time() - for i, idx in enumerate(indices): - entry = data[dataset][idx] - if time.time() - start_time > timeout or i >= max_entries: + t = Timer() + for chunk in iter_chunks(enumerate(indices), chunk_size=batch_size): + with t: + for i, idx in chunk: + entry = data[idx] + if t.elapsed > timeout: break - # calculate score - entries_per_sec = (i + 1) / (time.time() - start_time) + entries_per_sec = (i + 1) / t.elapsed return entries_per_sec + # ========================================================================= # # END # # ========================================================================= # diff --git a/disent/data/util/in_out.py b/disent/data/util/in_out.py index 9708bd40..28bbf0cb 100644 --- a/disent/data/util/in_out.py +++ b/disent/data/util/in_out.py @@ -22,14 +22,175 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import os import logging -import warnings +from typing import Optional +from typing import Tuple log = logging.getLogger(__name__) + +# ========================================================================= # +# file hashing # +# ========================================================================= # + + +def yield_file_bytes(file: str, chunk_size=16384): + with open(file, 'rb') as f: + bytes = True + while bytes: + bytes = f.read(chunk_size) + yield bytes + + +def yield_fast_hash_bytes(file: str, chunk_size=16384, num_chunks=3): + assert num_chunks >= 2 + # return the size in bytes + size = os.path.getsize(file) + yield size.to_bytes(length=64//8, byteorder='big', signed=False) + # return file bytes chunks + if size < chunk_size * num_chunks: + # we cant return chunks because the file is too small, return everything! + yield from yield_file_bytes(file, chunk_size=chunk_size) + else: + # includes evenly spaced start, middle and end chunks + with open(file, 'rb') as f: + for i in range(num_chunks): + pos = (i * (size - chunk_size)) // (num_chunks - 1) + f.seek(pos) + yield f.read(chunk_size) + + +def hash_file(file: str, hash_type='md5', hash_mode='full') -> str: + """ + :param file: the path to the file + :param hash_type: the kind of hash to compute, default is "md5" + :param hash_mode: "full" uses all the bytes in the file to compute the hash, "fast" uses the start, middle, end bytes as well as the size of the file in the hash. + :param chunk_size: number of bytes to read at a time + :return: the hexdigest of the hash + """ + import hashlib + # get file bytes iterator + if hash_mode == 'full': + byte_iter = yield_file_bytes(file=file) + elif hash_mode == 'fast': + byte_iter = yield_fast_hash_bytes(file=file) + else: + raise KeyError(f'invalid hash_mode: {repr(hash_mode)}') + # generate hash + hash = hashlib.new(hash_type) + for bytes in byte_iter: + hash.update(bytes) + hash = hash.hexdigest() + # done + return hash + + +class HashError(Exception): + """ + Raised if the hash of a file was invalid. + """ + + +def validate_file_hash(file: str, hash: str, hash_type='md5', hash_mode='full'): + fhash = hash_file(file=file, hash_type=hash_type, hash_mode=hash_mode) + if fhash != hash: + raise HashError(f'computed {hash_mode} {hash_type} hash: {repr(fhash)} does not match expected hash: {repr(hash)} for file: {repr(file)}') + + +# ========================================================================= # +# Atomic file saving # +# ========================================================================= # + + +class AtomicFileContext(object): + """ + Within the context, data must be written to a temporary file. + Once data has been successfully written, the temporary file + is moved to the location of the given file. + + ``` + with AtomicFileHandler('file.txt') as tmp_file: + with open(tmp_file, 'w') as f: + f.write("hello world!\n") + ``` + """ + + def __init__(self, file: str, open_mode: Optional[str] = None, overwrite: bool = False, makedirs: bool = True, tmp_file: Optional[str] = None): + from pathlib import Path + # check files + if not file: + raise ValueError(f'file must not be empty: {repr(file)}') + if not tmp_file and (tmp_file is not None): + raise ValueError(f'tmp_file must not be empty: {repr(tmp_file)}') + # get files + self.trg_file = Path(file).absolute() + self.tmp_file = Path(f'{self.trg_file}.TEMP' if (tmp_file is None) else tmp_file) + # check that the files are different + if self.trg_file == self.tmp_file: + raise ValueError(f'temporary and target files are the same: {self.tmp_file} == {self.trg_file}') + # other settings + self._makedirs = makedirs + self._overwrite = overwrite + self._open_mode = open_mode + self._resource = None + + def __enter__(self): + # check files exist or not + if self.tmp_file.exists(): + if not self.tmp_file.is_file(): + raise FileExistsError(f'the temporary file exists but is not a file: {self.tmp_file}') + if self.trg_file.exists(): + if not self._overwrite: + raise FileExistsError('the target file already exists, set overwrite=True to ignore this error.') + if not self.trg_file.is_file(): + raise FileExistsError(f'the target file exists but is not a file: {self.trg_file}') + # create the missing directories if needed + if self._makedirs: + self.tmp_file.parent.mkdir(parents=True, exist_ok=True) + # delete any existing temporary files + if self.tmp_file.exists(): + log.debug(f'deleting existing temporary file: {self.tmp_file}') + self.tmp_file.unlink() + # handle the different modes, deleting any existing tmp files + if self._open_mode is not None: + log.debug(f'created new temporary file: {self.tmp_file}') + self._resource = open(self.tmp_file, self._open_mode) + return str(self.tmp_file), self._resource + else: + return str(self.tmp_file) + + def __exit__(self, error_type, error, traceback): + # close the temp file + if self._resource is not None: + self._resource.close() + self._resource = None + # cleanup if there was an error, and exit early + if error_type is not None: + if self.tmp_file.exists(): + self.tmp_file.unlink(missing_ok=True) + log.error(f'An error occurred in {self.__class__.__name__}, cleaned up temporary file: {self.tmp_file}') + else: + log.error(f'An error occurred in {self.__class__.__name__}') + return + # the temp file must have been created! + if not self.tmp_file.exists(): + raise FileNotFoundError(f'the temporary file was not created: {self.tmp_file}') + # delete the target file if it exists and overwrite is enabled: + if self._overwrite: + log.warning(f'overwriting file: {self.trg_file}') + self.trg_file.unlink(missing_ok=True) + # create the missing directories if needed + if self._makedirs: + self.trg_file.parent.mkdir(parents=True, exist_ok=True) + # move the temp file to the target file + log.info(f'moved temporary file to final location: {self.tmp_file} -> {self.trg_file}') + os.rename(self.tmp_file, self.trg_file) + + # ========================================================================= # -# io # +# files/dirs exist # # ========================================================================= # @@ -57,61 +218,80 @@ def ensure_parent_dir_exists(*path): return ensure_dir_exists(*path, is_file=True, absolute=True) -def basename_from_url(url): - import os - from urllib.parse import urlparse - return os.path.basename(urlparse(url).path) +# ========================================================================= # +# files/dirs exist # +# ========================================================================= # -def download_file(url, save_path=None, overwrite_existing=False, chunk_size=4096): +def download_file(url: str, save_path: str, overwrite_existing: bool = False, chunk_size: int = 16384): import requests - import os from tqdm import tqdm - - if save_path is None: - save_path = basename_from_url(url) - log.info(f'inferred save_path="{save_path}"') - - # split path - # TODO: also used in base.py for processing, convert to with syntax. - path_dir, path_base = os.path.split(save_path) - ensure_dir_exists(path_dir) - - if not path_base: - raise Exception('Invalid save path: "{save_path}"') - - # check save path isnt there - if os.path.isfile(save_path): - if overwrite_existing: - warnings.warn(f'Overwriting existing file: "{save_path}"') - else: - raise Exception(f'File already exists: "{save_path}" set overwrite_existing=True to overwrite.') - - # we download to a temporary file in case there is an error - temp_download_path = os.path.join(path_dir, f'.{path_base}.download.temp') - - # open the file for saving - with open(temp_download_path, "wb") as file: + # write the file + with AtomicFileContext(file=save_path, open_mode='wb', overwrite=overwrite_existing) as (_, file): response = requests.get(url, stream=True) total_length = response.headers.get('content-length') - # cast to integer if content-length exists on response if total_length is not None: total_length = int(total_length) - # download with progress bar - with tqdm(total=total_length, desc=f'Downloading "{path_base}"', unit='B', unit_scale=True, unit_divisor=1024) as progress: + log.info(f'Downloading: {url} to: {save_path}') + with tqdm(total=total_length, desc=f'Downloading', unit='B', unit_scale=True, unit_divisor=1024) as progress: for data in response.iter_content(chunk_size=chunk_size): file.write(data) progress.update(chunk_size) - # remove if we can overwrite - if overwrite_existing and os.path.isfile(save_path): - # TODO: is this necessary? - os.remove(save_path) - # rename temp download - os.rename(temp_download_path, save_path) +def copy_file(src: str, dst: str, overwrite_existing: bool = False): + # copy the file + if os.path.abspath(src) == os.path.abspath(dst): + raise FileExistsError(f'input and output paths for copy are the same, skipping: {repr(dst)}') + else: + with AtomicFileContext(file=dst, overwrite=overwrite_existing) as path: + import shutil + shutil.copyfile(src, path) + + +def retrieve_file(src_uri: str, dst_path: str, overwrite_existing: bool = True): + uri, is_url = _uri_parse_file_or_url(src_uri) + if is_url: + download_file(url=uri, save_path=dst_path, overwrite_existing=overwrite_existing) + else: + copy_file(src=uri, dst=dst_path, overwrite_existing=overwrite_existing) + + +# ========================================================================= # +# path utils # +# ========================================================================= # + + +def basename_from_url(url): + import os + from urllib.parse import urlparse + return os.path.basename(urlparse(url).path) + + +def _uri_parse_file_or_url(inp_uri) -> Tuple[str, bool]: + from urllib.parse import urlparse + result = urlparse(inp_uri) + # parse different cases + if result.scheme in ('http', 'https'): + is_url = True + uri = result.geturl() + elif result.scheme in ('file', ''): + is_url = False + if result.scheme == 'file': + if result.netloc: + raise KeyError(f'file uri format is invalid: "{result.geturl()}" two slashes specifies host as: "{result.netloc}" eg. instead of "file://hostname/root_folder/file.txt", please use: "file:/root_folder/file.txt" (no hostname) or "file:///root_folder/file.txt" (empty hostname).') + if not os.path.isabs(result.path): + raise RuntimeError(f'path: {repr(result.path)} obtained from file URI: {repr(inp_uri)} should always be absolute') + uri = result.path + else: + uri = result.geturl() + uri = os.path.abspath(uri) + else: + raise ValueError(f'invalid file or url: {repr(inp_uri)}') + # done + return uri, is_url # ========================================================================= # diff --git a/disent/data/util/jobs.py b/disent/data/util/jobs.py new file mode 100644 index 00000000..594cb5d3 --- /dev/null +++ b/disent/data/util/jobs.py @@ -0,0 +1,156 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +import os +from abc import ABCMeta +from typing import Callable +from typing import NoReturn + +from disent.data.util.in_out import hash_file + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# Base Job # +# ========================================================================= # + + +class CachedJob(object): + """ + Base class for cached jobs. A job is some arbitrary directed chains + of computations where child jobs depend on parent jobs, and jobs + can be skipped if it has already been run and the cache is valid. + + Jobs are always deterministic, and if run and cached should never go out of date. + + NOTE: if needed it would be easy to add support directed acyclic graphs, and sub-graphs + NOTE: this is probably overkill, but it makes the code to write a new dataset nice and clean... + """ + + def __init__(self, job_fn: Callable[[], NoReturn], is_cached_fn: Callable[[], bool]): + self._parent = None + self._child = None + self._job_fn = job_fn + self._is_cached_fn = is_cached_fn + + def __repr__(self): + return f'{self.__class__.__name__}' + + def set_parent(self, parent: 'CachedJob'): + if not isinstance(parent, CachedJob): + raise TypeError(f'{self}: parent job was not an instance of: {CachedJob.__class__}') + if self._parent is not None: + raise RuntimeError(f'{self}: parent has already been set') + if parent._child is not None: + raise RuntimeError(f'{parent}: child has already been set') + self._parent = parent + parent._child = self + return parent + + def set_child(self, child: 'CachedJob'): + child.set_parent(self) + return child + + def run(self, force=False, recursive=False) -> 'CachedJob': + # visit parents always + if recursive: + if self._parent is not None: + self._parent.run(force=force, recursive=recursive) + # skip if fresh + if not force: + if not self._is_cached_fn(): + log.debug(f'{self}: skipped non-stale job') + return self + # don't visit parents if fresh + if not recursive: + if self._parent is not None: + self._parent.run(force=force, recursive=recursive) + # run nodes + log.debug(f'{self}: run stale job') + self._job_fn() + return self + + +# ========================================================================= # +# Base File Job # +# ========================================================================= # + + +class CachedJobFile(CachedJob): + + """ + An abstract cached job that only runs if a file does not exist, + or the files hash sum does not match the given value. + """ + + def __init__( + self, + make_file_fn: Callable[[str], NoReturn], + path: str, + hash: str, + hash_type: str = 'md5', + hash_mode: str = 'full', + ): + # set attributes + self.path = path + self.hash = hash + self.hash_type = hash_type + self.hash_mode = hash_mode + # generate + self._make_file_fn = make_file_fn + # check hash + super().__init__(job_fn=self.__job_fn, is_cached_fn=self.__is_cached_fn) + + def __compute_hash(self) -> str: + return hash_file(self.path, hash_type=self.hash_type, hash_mode=self.hash_mode) + + def __is_cached_fn(self) -> bool: + # stale if the file does not exist + if not os.path.exists(self.path): + log.warning(f'{self}: stale because file does not exist: {repr(self.path)}') + return True + # stale if the hash does not match + fhash = self.__compute_hash() + if self.hash != fhash: + log.warning(f'{self}: stale because computed {self.hash_mode} {self.hash_type} hash: {repr(fhash)} does not match expected hash: {repr(self.hash)} for: {repr(self.path)}') + return True + # not stale, we don't need to do anything! + return False + + def __job_fn(self): + self._make_file_fn(self.path) + # check the hash + fhash = self.__compute_hash() + if self.hash != fhash: + raise RuntimeError(f'{self}: computed {self.hash_mode} {self.hash_type} hash: {repr(fhash)} for newly generated file {repr(self.path)} does not match expected hash: {repr(self.hash)}') + else: + log.debug(f'{self}: successfully generated file: {repr(self.path)} with correct {self.hash_mode} {self.hash_type} hash: {fhash}') + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/__init__.py b/disent/util/__init__.py index 4214dc88..aca606b9 100644 --- a/disent/util/__init__.py +++ b/disent/util/__init__.py @@ -32,6 +32,7 @@ from dataclasses import fields from itertools import islice from pprint import pformat +from random import random from typing import List import numpy as np @@ -340,10 +341,43 @@ def __getitem__(self, item): class Timer: - def __init__(self, print_name=None, log_level=logging.INFO): + + """ + Timer class, can be used with a with statement to + measure the execution time of a block of code! + + Examples: + + 1. get the runtime + ``` + with Timer() as t: + time.sleep(1) + print(t.pretty) + ``` + + 2. automatic print + ``` + with Timer(name="example") as t: + time.sleep(1) + ``` + + 3. reuse timer to measure multiple instances + ``` + t = Timer() + for i in range(100): + with t: + time.sleep(0.95) + if t.elapsed > 3: + break + print(t) + ``` + """ + + def __init__(self, name: str = None, log_level: int = logging.INFO): self._start_time: int = None self._end_time: int = None - self._print_name = print_name + self._total_time = 0 + self.name = name self._log_level = log_level def __enter__(self): @@ -352,19 +386,24 @@ def __enter__(self): def __exit__(self, *args, **kwargs): self._end_time = time.time_ns() - if self._print_name: + # add elapsed time to total time, and reset the timer! + self._total_time += (self._end_time - self._start_time) + self._start_time = None + self._end_time = None + # print results + if self.name: if self._log_level is None: - print(f'{self._print_name}: {self.pretty}') + print(f'{self.name}: {self.pretty}') else: - log.log(self._log_level, f'{self._print_name}: {self.pretty}') + log.log(self._log_level, f'{self.name}: {self.pretty}') @property def elapsed_ns(self) -> int: - if self._start_time is None: - return 0 - if self._end_time is None: - return time.time_ns() - self._start_time - return self._end_time - self._start_time + if self._start_time is not None: + # running + return self._total_time + (time.time_ns() - self._start_time) + # finished running + return self._total_time @property def elapsed_ms(self) -> float: diff --git a/experiment/exp/util/_tasks.py b/experiment/exp/util/_tasks.py index fe7ac624..0794e60e 100644 --- a/experiment/exp/util/_tasks.py +++ b/experiment/exp/util/_tasks.py @@ -34,6 +34,7 @@ # ========================================================================= # # Task Builders # +# - I know this is overkill... but i was having fun... # # ========================================================================= # From 782ff850bcb07eb0f1467723e81a7f9c9cba807f Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 31 May 2021 00:13:13 +0200 Subject: [PATCH 08/34] pickle hdf5 dataset + hdf5 util cleanup --- disent/data/groundtruth/base.py | 62 +++++-------- disent/data/util/hdf5.py | 160 ++++++++++++++++++++------------ disent/data/util/in_out.py | 26 +++++- 3 files changed, 149 insertions(+), 99 deletions(-) diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index 11f80c88..1c6e92ce 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -26,15 +26,16 @@ import logging import os from abc import ABCMeta -from typing import Any +from typing import Callable from typing import Dict from typing import Optional from typing import Sequence from typing import Tuple -import h5py +import numpy as np from disent.data.util.hdf5 import hdf5_resave_file +from disent.data.util.hdf5 import PickleH5pyDataset from disent.data.util.in_out import ensure_dir_exists from disent.data.util.in_out import retrieve_file from disent.data.util.jobs import CachedJobFile @@ -127,30 +128,25 @@ class Hdf5GroundTruthData(DiskGroundTruthData, metaclass=ABCMeta): def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_memory=False): super().__init__(data_root=data_root, prepare=prepare) # variables - self._h5_path = self.data_object.get_file_path(self.data_dir) - self._h5_dataset_name = self.data_object.hdf5_dataset_name self._in_memory = in_memory - # Load the entire dataset into memory if required + # load the h5py dataset + data = PickleH5pyDataset( + h5_path=self.data_object.get_file_path(self.data_dir), + h5_dataset_name=self.data_object.hdf5_dataset_name, + ) + # handle different memroy modes if self._in_memory: - with h5py.File(self._h5_path, 'r', libver='latest', swmr=True) as db: - # indexing dataset objects returns numpy array - # instantiating np.array from the dataset requires double memory. - self._memory_data = db[self._h5_dataset_name][:] + # Load the entire dataset into memory if required + # indexing dataset objects returns numpy array + # instantiating np.array from the dataset requires double memory. + self._data = data[:] + data.close() else: - self._hdf5_file, self._hdf5_data = self._make_hdf5() - - def _make_hdf5(self): - # is this thread safe? - # TODO: this may cause a memory leak, it is never closed? - hdf5_file = h5py.File(self._h5_path, 'r', libver='latest', swmr=True) - hdf5_data = hdf5_file[self._h5_dataset_name] - return hdf5_file, hdf5_data + # Load the dataset from the disk + self._data = data def __getitem__(self, idx): - if self._in_memory: - return self._memory_data[idx] - else: - return self._hdf5_data[idx] + return self._data[idx] @property def data_objects(self) -> Sequence['DlH5DataObject']: @@ -160,23 +156,6 @@ def data_objects(self) -> Sequence['DlH5DataObject']: def data_object(self) -> 'DlH5DataObject': raise NotImplementedError - # CUSTOM PICKLE HANDLING -- h5py files are not supported! - # https://docs.python.org/3/library/pickle.html#pickle-state - # https://docs.python.org/3/library/pickle.html#object.__getstate__ - # https://docs.python.org/3/library/pickle.html#object.__setstate__ - # TODO: this might duplicate in-memory stuffs. - - def __getstate__(self): - state = self.__dict__.copy() - state.pop('_hdf5_file', None) - state.pop('_hdf5_data', None) - return state - - def __setstate__(self, state): - self.__dict__.update(state) - if not self._in_memory: - self._hdf5_file, self._hdf5_data = self._make_hdf5() - # ========================================================================= # # data objects # @@ -238,6 +217,8 @@ class DlH5DataObject(DlDataObject): hdf5_chunk_size: Tuple[int, ...] hdf5_compression: Optional[str] hdf5_compression_lvl: Optional[int] + hdf5_dtype: np.dtype = None + hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None def _make_h5_job(self, load_path: str, save_path: str): return CachedJobFile( @@ -249,7 +230,8 @@ def _make_h5_job(self, load_path: str, save_path: str): compression=self.hdf5_compression, compression_lvl=self.hdf5_compression_lvl, batch_size=None, - print_mode='minimal', + out_dtype=self.hdf5_dtype, + out_mutator=self.hdf5_mutator, ), path=save_path, hash=self.file_hashes[self.hash_mode], @@ -262,7 +244,7 @@ def prepare(self, data_dir: str): h5_path = self.get_file_path(data_dir=data_dir) dl_job = self._make_dl_job(save_path=dl_path) h5_job = self._make_h5_job(load_path=dl_path, save_path=h5_path) - h5_job.set_parent(dl_job).run() + dl_job.set_child(h5_job).run() # ========================================================================= # diff --git a/disent/data/util/hdf5.py b/disent/data/util/hdf5.py index 1782750f..941b38d3 100644 --- a/disent/data/util/hdf5.py +++ b/disent/data/util/hdf5.py @@ -26,7 +26,6 @@ """ import logging -import math import os import h5py @@ -34,7 +33,10 @@ from tqdm import tqdm from disent.data.util.in_out import AtomicFileContext +from disent.data.util.in_out import bytes_to_human +from disent.util import colors as c from disent.util import iter_chunks +from disent.util import LengthIter from disent.util import Timer @@ -42,69 +44,102 @@ # ========================================================================= # -# hdf5 # +# hdf5 pickle dataset # # ========================================================================= # -# TODO: cleanup -def bytes_to_human(size_bytes, decimals=3, color=True): - if size_bytes == 0: - return "0B" - size_name = ("B ", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB") - size_color = (None, 92, 93, 91, 91, 91, 91, 91, 91) - i = int(math.floor(math.log(size_bytes, 1024))) - s = round(size_bytes / math.pow(1024, i), decimals) - name = f'\033[{size_color[i]}m{size_name[i]}\033[0m' if color else size_name[i] - return f"{s:{4+decimals}.{decimals}f} {name}" +class PickleH5pyDataset(LengthIter): + """ + This class supports pickling and unpickling of a read-only + SWMR h5py file and corresponding dataset. + """ + + def __init__(self, h5_path: str, h5_dataset_name: str): + self._h5_path = h5_path + self._h5_dataset_name = h5_dataset_name + self._hdf5_file, self._hdf5_data = self._make_hdf5() + + def _make_hdf5(self): + # TODO: can this cause a memory leak if it is never closed? + hdf5_file = h5py.File(self._h5_path, 'r', libver='latest', swmr=True) + hdf5_data = hdf5_file[self._h5_dataset_name] + return hdf5_file, hdf5_data + + def __len__(self): + return self._hdf5_data.shape[0] + + def __getitem__(self, item): + return self._hdf5_data[item] + + def __enter__(self): + return self + + def __exit__(self, error_type, error, traceback): + self.close() + + # CUSTOM PICKLE HANDLING -- h5py files are not supported! + # https://docs.python.org/3/library/pickle.html#pickle-state + # https://docs.python.org/3/library/pickle.html#object.__getstate__ + # https://docs.python.org/3/library/pickle.html#object.__setstate__ + # TODO: this might duplicate in-memory stuffs. + + def __getstate__(self): + state = self.__dict__.copy() + state.pop('_hdf5_file', None) + state.pop('_hdf5_data', None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._hdf5_file, self._hdf5_data = self._make_hdf5() + + def close(self): + self._hdf5_file.close() + del self._hdf5_file + del self._hdf5_data + + +# ========================================================================= # +# hdf5 # +# ========================================================================= # # TODO: cleanup -def hdf5_print_entry_data_stats(data, dataset, label='STATISTICS', print_mode='all'): - dtype = data[dataset].dtype - itemsize = data[dataset].dtype.itemsize +def hdf5_print_entry_data_stats(h5_dataset: h5py.Dataset, label='STATISTICS'): + dtype = h5_dataset.dtype + itemsize = h5_dataset.dtype.itemsize # chunk - chunks = np.array(data[dataset].chunks) + chunks = np.array(h5_dataset.chunks) data_per_chunk = np.prod(chunks) * itemsize # entry - shape = np.array([1, *data[dataset].shape[1:]]) + shape = np.array([1, *h5_dataset.shape[1:]]) data_per_entry = np.prod(shape) * itemsize # chunks per entry chunks_per_dim = np.ceil(shape / chunks).astype('int') chunks_per_entry = np.prod(chunks_per_dim) read_data_per_entry = data_per_chunk * chunks_per_entry # print info - if print_mode == 'all': - if label: - tqdm.write(f'[{label}]: \033[92m{dataset}\033[0m') - tqdm.write( - f'\t\033[90mentry shape:\033[0m {str(list(shape)):18s} \033[93m{bytes_to_human(data_per_entry)}\033[0m' - f'\n\t\033[90mchunk shape:\033[0m {str(list(chunks)):18s} \033[93m{bytes_to_human(data_per_chunk)}\033[0m' - f'\n\t\033[90mchunks per entry:\033[0m {str(list(chunks_per_dim)):18s} \033[93m{bytes_to_human(read_data_per_entry)}\033[0m (\033[91m{chunks_per_entry}\033[0m)' - ) - elif print_mode == 'minimal': - tqdm.write( - f'[{label:3s}] entry: {str(list(shape)):18s} ({str(dtype):8s}) \033[93m{bytes_to_human(data_per_entry)}\033[0m chunk: {str(list(chunks)):18s} \033[93m{bytes_to_human(data_per_chunk)}\033[0m chunks per entry: {str(list(chunks_per_dim)):18s} \033[93m{bytes_to_human(read_data_per_entry)}\033[0m (\033[91m{chunks_per_entry}\033[0m)' - ) - - -# TODO: cleanup -def hd5f_print_dataset_info(data, dataset, label='DATASET'): - if label: - tqdm.write(f'[{label}]: \033[92m{dataset}\033[0m') tqdm.write( - f'\t\033[90mraw:\033[0m {data[dataset]}' - f'\n\t\033[90mchunks:\033[0m {data[dataset].chunks}' - f'\n\t\033[90mcompression:\033[0m {data[dataset].compression}' - f'\n\t\033[90mcompression lvl:\033[0m {data[dataset].compression_opts}' + f'[{label:3s}] ' + f'entry: {str(list(shape)):18s} ({str(dtype):8s}) {c.lYLW}{bytes_to_human(data_per_entry)}{c.RST} ' + f'chunk: {str(list(chunks)):18s} {c.YLW}{bytes_to_human(data_per_chunk)}{c.RST} ' + f'chunks per entry: {str(list(chunks_per_dim)):18s} {c.YLW}{bytes_to_human(read_data_per_entry)}{c.RST} ({c.RED}{chunks_per_entry:5d}{c.RST}) | ' + f'compression: {repr(h5_dataset.compression)} compression lvl: {repr(h5_dataset.compression_opts)}' ) -def hdf5_resave_dataset(inp_h5: h5py.File, out_h5: h5py.File, dataset_name, chunk_size=None, compression=None, compression_lvl=None, batch_size=None, print_mode='minimal'): +# ========================================================================= # +# hdf5 - resave # +# ========================================================================= # + + +def hdf5_resave_dataset(inp_h5: h5py.File, out_h5: h5py.File, dataset_name, chunk_size=None, compression=None, compression_lvl=None, batch_size=None, out_dtype=None, out_mutator=None): # create new dataset - out_h5.create_dataset( + inp_data = inp_h5[dataset_name] + out_data = out_h5.create_dataset( name=dataset_name, - shape=inp_h5[dataset_name].shape, - dtype=inp_h5[dataset_name].dtype, + shape=inp_data.shape, + dtype=out_dtype if (out_dtype is not None) else inp_data.dtype, chunks=chunk_size, compression=compression, compression_opts=compression_lvl, @@ -115,25 +150,28 @@ def hdf5_resave_dataset(inp_h5: h5py.File, out_h5: h5py.File, dataset_name, chun track_times=False, ) # print stats - hdf5_print_entry_data_stats(inp_h5, dataset_name, label=f'IN', print_mode=print_mode) - hdf5_print_entry_data_stats(out_h5, dataset_name, label=f'OUT', print_mode=print_mode) + tqdm.write('') + hdf5_print_entry_data_stats(inp_data, label=f'IN') + hdf5_print_entry_data_stats(out_data, label=f'OUT') # choose batch size for copying data if batch_size is None: - batch_size = inp_h5[dataset_name].chunks[0] + batch_size = inp_data.chunks[0] log.debug(f're-saving h5 dataset using automatic batch size of: {batch_size}') - # batched copy | loads could be parallelized! - entries = len(inp_h5[dataset_name]) - with tqdm(total=entries) as progress: - for i in range(0, entries, batch_size): - out_h5[dataset_name][i:i + batch_size] = inp_h5[dataset_name][i:i + batch_size] + # get default + if out_mutator is None: + out_mutator = lambda x: x + # save data + with tqdm(total=len(inp_data)) as progress: + for i in range(0, len(inp_data), batch_size): + out_data[i:i + batch_size] = out_mutator(inp_data[i:i + batch_size]) progress.update(batch_size) -def hdf5_resave_file(inp_path: str, out_path: str, dataset_name, chunk_size=None, compression=None, compression_lvl=None, batch_size=None, print_mode='minimal'): +def hdf5_resave_file(inp_path: str, out_path: str, dataset_name, chunk_size=None, compression=None, compression_lvl=None, batch_size=None, out_dtype=None, out_mutator=None): # re-save datasets with h5py.File(inp_path, 'r') as inp_h5: with AtomicFileContext(out_path, open_mode=None, overwrite=True) as tmp_h5_path: - with h5py.File(tmp_h5_path, 'w') as out_h5: + with h5py.File(tmp_h5_path, 'w', libver='latest') as out_h5: hdf5_resave_dataset( inp_h5=inp_h5, out_h5=out_h5, @@ -142,16 +180,16 @@ def hdf5_resave_file(inp_path: str, out_path: str, dataset_name, chunk_size=None compression=compression, compression_lvl=compression_lvl, batch_size=batch_size, - print_mode=print_mode, + out_dtype=out_dtype, + out_mutator=out_mutator, ) # file size: log.info(f'[FILE SIZES] IN: {bytes_to_human(os.path.getsize(inp_path))} OUT: {bytes_to_human(os.path.getsize(out_path))}') -def hdf5_test_speed(h5_path: str, dataset_name: str, access_method: str = 'random'): - with h5py.File(h5_path, 'r') as out_h5: - log.info('[TESTING] Access Speed...') - log.info(f'Random Accesses Per Second: {hdf5_test_entries_per_second(out_h5, dataset_name, access_method=access_method, max_entries=5_000):.3f}') +# ========================================================================= # +# hdf5 - speed tests # +# ========================================================================= # def hdf5_test_entries_per_second(h5_data: h5py.File, dataset_name, access_method='random', max_entries=48000, timeout=10, batch_size: int = 256): @@ -180,6 +218,12 @@ def hdf5_test_entries_per_second(h5_data: h5py.File, dataset_name, access_method return entries_per_sec +def hdf5_test_speed(h5_path: str, dataset_name: str, access_method: str = 'random'): + with h5py.File(h5_path, 'r') as out_h5: + log.info('[TESTING] Access Speed...') + log.info(f'Random Accesses Per Second: {hdf5_test_entries_per_second(out_h5, dataset_name, access_method=access_method, max_entries=5_000):.3f}') + + # ========================================================================= # # END # # ========================================================================= # diff --git a/disent/data/util/in_out.py b/disent/data/util/in_out.py index 28bbf0cb..eb480bef 100644 --- a/disent/data/util/in_out.py +++ b/disent/data/util/in_out.py @@ -22,15 +22,39 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import os import logging +import math +import os from typing import Optional from typing import Tuple +from disent.util import colors as c + log = logging.getLogger(__name__) +# ========================================================================= # +# Formatting # +# ========================================================================= # + + +_BYTES_POW_NAME = ("B ", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB") +_BYTES_POW_COLR = (c.WHT, c.lGRN, c.lYLW, c.lRED, c.lRED, c.lRED, c.lRED, c.lRED, c.lRED) + + +def bytes_to_human(size_bytes, decimals=3, color=True): + if size_bytes == 0: + return "0B" + # round correctly + i = int(math.floor(math.log(size_bytes, 1024))) + s = round(size_bytes / math.pow(1024, i), decimals) + # generate string + name = f'{_BYTES_POW_COLR[i]}{_BYTES_POW_NAME[i]}{c.RST}' if color else f'{_BYTES_POW_NAME[i]}' + # format string + return f"{s:{4+decimals}.{decimals}f} {name}" + + # ========================================================================= # # file hashing # # ========================================================================= # From b033345be2d7f86af0a7e294c020ec747c28e862 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 31 May 2021 00:49:05 +0200 Subject: [PATCH 09/34] fix hdf5 determinism for saving + dsprites --- disent/data/groundtruth/_dsprites.py | 50 ++++++++++++------- disent/data/groundtruth/base.py | 3 +- disent/data/util/hdf5.py | 7 ++- disent/data/util/jobs.py | 2 +- .../run_04_gen_adversarial.py | 8 +-- 5 files changed, 44 insertions(+), 26 deletions(-) diff --git a/disent/data/groundtruth/_dsprites.py b/disent/data/groundtruth/_dsprites.py index 8f6f6093..6ce9d828 100644 --- a/disent/data/groundtruth/_dsprites.py +++ b/disent/data/groundtruth/_dsprites.py @@ -22,7 +22,8 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -from disent.data.groundtruth.base import Hdf5PreprocessedGroundTruthData +from disent.data.groundtruth.base import DlH5DataObject +from disent.data.groundtruth.base import Hdf5GroundTruthData # ========================================================================= # @@ -30,7 +31,7 @@ # ========================================================================= # -class DSpritesData(Hdf5PreprocessedGroundTruthData): +class DSpritesData(Hdf5GroundTruthData): """ DSprites Dataset - beta-VAE: Learning Basic Visual Concepts with a Constrained Variational BaseFramework @@ -45,21 +46,34 @@ class DSpritesData(Hdf5PreprocessedGroundTruthData): # reference implementation: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/data/ground_truth/dsprites.py """ + name = 'dsprites' + + # TODO: reference implementation has colour variants factor_names = ('shape', 'scale', 'orientation', 'position_x', 'position_y') factor_sizes = (3, 6, 40, 32, 32) # TOTAL: 737280 - observation_shape = (64, 64, 1) # TODO: reference implementation has colour variants - - dataset_url = 'https://raw.githubusercontent.com/deepmind/dsprites-dataset/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.hdf5' - - hdf5_name = 'imgs' - # minimum chunk size, no compression but good for random accesses - hdf5_chunk_size = (1, 64, 64) - - def __init__(self, data_dir='data/dataset/dsprites', in_memory=False, force_download=False, force_preprocess=False): - super().__init__(data_dir=data_dir, in_memory=in_memory, force_download=force_download, force_preprocess=force_preprocess) - - def __getitem__(self, idx): - return super().__getitem__(idx) * 255 # for some reason uint8 is used as datatype, but only in range 0-1 + observation_shape = (64, 64, 1) + + # 8de0faa39af431a2dc7828df01121fe6 + # 4e142be8960e05b5da4563be70281e8a + + data_object = DlH5DataObject( + # processed dataset file + file_name='dsprites_ndarray_co1sh3sc6or40x32y32_64x64.hdf5', + file_hashes={'fast': '6d6d43d5f4d5c08c4b99a406289b8ecd', 'full': '1473ac1e1af7fdbc910766b3f9157f7b'}, + # download file/link + uri='https://raw.githubusercontent.com/deepmind/dsprites-dataset/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.hdf5', + uri_hashes={'fast': 'd6ee1e43db715c2f0de3c41e38863347', 'full': 'b331c4447a651c44bf5e8ae09022e230'}, + # hash settings + hash_mode='fast', + hash_type='md5', + # h5 re-save settings + hdf5_dataset_name='imgs', + hdf5_chunk_size=(1, 64, 64), + hdf5_compression='gzip', + hdf5_compression_lvl=4, + hdf5_dtype='uint8', + hdf5_mutator=lambda x: x # lambda batch: batch * 255 + ) # ========================================================================= # @@ -68,7 +82,7 @@ def __getitem__(self, idx): if __name__ == '__main__': - from tqdm import tqdm - for dat in tqdm(DSpritesData(in_memory=True, force_preprocess=True)): - pass + data = DSpritesData(in_memory=False, prepare=True) + for dat in data: + print(dat) diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index 1c6e92ce..7e1bef26 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -31,6 +31,7 @@ from typing import Optional from typing import Sequence from typing import Tuple +from typing import Union import numpy as np @@ -217,7 +218,7 @@ class DlH5DataObject(DlDataObject): hdf5_chunk_size: Tuple[int, ...] hdf5_compression: Optional[str] hdf5_compression_lvl: Optional[int] - hdf5_dtype: np.dtype = None + hdf5_dtype: Optional[Union[np.dtype, str]] = None hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None def _make_h5_job(self, load_path: str, save_path: str): diff --git a/disent/data/util/hdf5.py b/disent/data/util/hdf5.py index 941b38d3..e761f5d4 100644 --- a/disent/data/util/hdf5.py +++ b/disent/data/util/hdf5.py @@ -61,7 +61,7 @@ def __init__(self, h5_path: str, h5_dataset_name: str): def _make_hdf5(self): # TODO: can this cause a memory leak if it is never closed? - hdf5_file = h5py.File(self._h5_path, 'r', libver='latest', swmr=True) + hdf5_file = h5py.File(self._h5_path, 'r', swmr=True) hdf5_data = hdf5_file[self._h5_dataset_name] return hdf5_file, hdf5_data @@ -134,6 +134,9 @@ def hdf5_print_entry_data_stats(h5_dataset: h5py.Dataset, label='STATISTICS'): def hdf5_resave_dataset(inp_h5: h5py.File, out_h5: h5py.File, dataset_name, chunk_size=None, compression=None, compression_lvl=None, batch_size=None, out_dtype=None, out_mutator=None): + # check out_h5 version compatibility + if (isinstance(out_h5.libver, str) and out_h5.libver != 'earliest') or (out_h5.libver[0] != 'earliest'): + raise RuntimeError(f'hdf5 out file has an incompatible libver: {repr(out_h5.libver)} libver should be set to: "earliest"') # create new dataset inp_data = inp_h5[dataset_name] out_data = out_h5.create_dataset( @@ -171,7 +174,7 @@ def hdf5_resave_file(inp_path: str, out_path: str, dataset_name, chunk_size=None # re-save datasets with h5py.File(inp_path, 'r') as inp_h5: with AtomicFileContext(out_path, open_mode=None, overwrite=True) as tmp_h5_path: - with h5py.File(tmp_h5_path, 'w', libver='latest') as out_h5: + with h5py.File(tmp_h5_path, 'w', libver='earliest') as out_h5: # TODO: libver='latest' is not deterministic, even with track_times=False hdf5_resave_dataset( inp_h5=inp_h5, out_h5=out_h5, diff --git a/disent/data/util/jobs.py b/disent/data/util/jobs.py index 594cb5d3..d8905ba9 100644 --- a/disent/data/util/jobs.py +++ b/disent/data/util/jobs.py @@ -146,7 +146,7 @@ def __job_fn(self): # check the hash fhash = self.__compute_hash() if self.hash != fhash: - raise RuntimeError(f'{self}: computed {self.hash_mode} {self.hash_type} hash: {repr(fhash)} for newly generated file {repr(self.path)} does not match expected hash: {repr(self.hash)}') + raise RuntimeError(f'{self}: error because computed {self.hash_mode} {self.hash_type} hash: {repr(fhash)} does not match expected hash: {repr(self.hash)} for: {repr(self.path)}') else: log.debug(f'{self}: successfully generated file: {repr(self.path)} with correct {self.hash_mode} {self.hash_type} hash: {fhash}') diff --git a/experiment/exp/05_adversarial_data/run_04_gen_adversarial.py b/experiment/exp/05_adversarial_data/run_04_gen_adversarial.py index 3e074b40..86d20453 100644 --- a/experiment/exp/05_adversarial_data/run_04_gen_adversarial.py +++ b/experiment/exp/05_adversarial_data/run_04_gen_adversarial.py @@ -111,7 +111,7 @@ def _make_hdf5_dataset(path, dataset, overwrite_mode: str = 'continue') -> str: raise KeyError(f'invalid overwrite_mode={repr(overwrite_mode)}') # open in read write mode log.info(f'Opening hdf5 dataset: overwrite_mode={repr(overwrite_mode)} exists={repr(os.path.exists(path))} path={repr(path)}') - with h5py.File(path, rw_mode, libver='latest') as f: + with h5py.File(path, rw_mode, libver='earliest') as f: # get data num_obs = len(dataset) obs_shape = dataset[0][NAME_OBS][0].shape @@ -136,7 +136,7 @@ def _make_hdf5_dataset(path, dataset, overwrite_mode: str = 'continue') -> str: def _read_hdf5_batch(h5py_path: str, idxs, return_visits=False): batch, visits = [], [] - with h5py.File(h5py_path, 'r', libver='latest', swmr=True) as f: + with h5py.File(h5py_path, 'r', swmr=True) as f: for i in idxs: visits.append(f[NAME_VISITS][i]) batch.append(torch.as_tensor(f[NAME_DATA][i], dtype=torch.float32) / 255) @@ -155,7 +155,7 @@ def _load_hdf5_batch(dataset, h5py_path: str, idxs, initial_noise: Optional[floa observation has not been saved into the hdf5 dataset yet. """ batch, visits = [], [] - with h5py.File(h5py_path, 'r', libver='latest', swmr=True) as f: + with h5py.File(h5py_path, 'r', swmr=True) as f: for i in idxs: v = f[NAME_VISITS][i] if v > 0: @@ -178,7 +178,7 @@ def _save_hdf5_batch(h5py_path: str, batch, idxs): Save a batch to disk. - Can only be used by one thread at a time! """ - with h5py.File(h5py_path, 'r+', libver='latest') as f: + with h5py.File(h5py_path, 'r+', libver='earliest') as f: for obs, idx in zip(batch, idxs): f[NAME_DATA][idx] = torch.clamp(torch.round(obs * 255), 0, 255).to(torch.uint8) f[NAME_VISITS][idx] += 1 From 3ee2eedf745281490ea961f739dbdce271bff17b Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Tue, 1 Jun 2021 23:02:01 +0200 Subject: [PATCH 10/34] temp commit --- disent/data/groundtruth/__init__.py | 6 +- disent/data/groundtruth/_dsprites.py | 5 +- disent/data/groundtruth/_norb.py | 127 +++++++++++++++++++++------ disent/dataset/__init__.py | 20 ----- disent/dataset/_augment_util.py | 15 ++++ disent/frameworks/framework.py | 31 ++++++- disent/metrics/__init__.py | 19 ++-- disent/metrics/_dci.py | 2 +- disent/{ => nn}/model/common.py | 35 ++------ disent/schedule/lerp.py | 31 +++---- disent/util/__init__.py | 52 +---------- disent/util/colors.py | 2 + pytest.ini | 9 ++ tests/test_data.py | 51 +++++++++++ 14 files changed, 247 insertions(+), 158 deletions(-) rename disent/{ => nn}/model/common.py (70%) create mode 100644 pytest.ini diff --git a/disent/data/groundtruth/__init__.py b/disent/data/groundtruth/__init__.py index 02a742ea..5bb81f32 100644 --- a/disent/data/groundtruth/__init__.py +++ b/disent/data/groundtruth/__init__.py @@ -24,10 +24,10 @@ from .base import GroundTruthData # others -from ._cars3d import Cars3dData +# from ._cars3d import Cars3dData from ._dsprites import DSpritesData -from ._mpi3d import Mpi3dData -from ._norb import SmallNorbData +# from ._mpi3d import Mpi3dData +# from ._norb import SmallNorbData from ._shapes3d import Shapes3dData from ._xyobject import XYObjectData from ._xysquares import XYSquaresData, XYSquaresMinimalData diff --git a/disent/data/groundtruth/_dsprites.py b/disent/data/groundtruth/_dsprites.py index 6ce9d828..126fa050 100644 --- a/disent/data/groundtruth/_dsprites.py +++ b/disent/data/groundtruth/_dsprites.py @@ -34,7 +34,7 @@ class DSpritesData(Hdf5GroundTruthData): """ DSprites Dataset - - beta-VAE: Learning Basic Visual Concepts with a Constrained Variational BaseFramework + - beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework (https://github.com/deepmind/dsprites-dataset) Files: @@ -53,9 +53,6 @@ class DSpritesData(Hdf5GroundTruthData): factor_sizes = (3, 6, 40, 32, 32) # TOTAL: 737280 observation_shape = (64, 64, 1) - # 8de0faa39af431a2dc7828df01121fe6 - # 4e142be8960e05b5da4563be70281e8a - data_object = DlH5DataObject( # processed dataset file file_name='dsprites_ndarray_co1sh3sc6or40x32y32_64x64.hdf5', diff --git a/disent/data/groundtruth/_norb.py b/disent/data/groundtruth/_norb.py index bfe4e96e..2e5fadac 100644 --- a/disent/data/groundtruth/_norb.py +++ b/disent/data/groundtruth/_norb.py @@ -21,36 +21,119 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - +import dataclasses import gzip +from typing import Dict + import numpy as np -from disent.data.groundtruth.base import DownloadableGroundTruthData + +from disent.data.groundtruth import GroundTruthData +from disent.data.groundtruth.base import DlDataObject + + +# ========================================================================= # +# Binary Matrix Helper Functions # +# - https://cs.nyu.edu/~ylclab/data/norb-v1.0-small # +# ========================================================================= # + + +_BINARY_MATRIX_TYPES = { + 0x1E3D4C55: 'uint8', # byte matrix + 0x1E3D4C54: 'int32', # integer matrix + 0x1E3D4C56: 'int16', # short matrix + 0x1E3D4C51: 'float32', # single precision matrix + 0x1E3D4C53: 'float64', # double precision matrix + # 0x1E3D4C52: '???', # packed matrix -- not sure what this is? +} + + +def read_binary_matrix_buffer(buffer): + """ + Read the binary matrix data + - modified from disentanglement_lib + + Binary Matrix File Format Specification + * The Header: + - dtype: 4 bytes + - ndim: 4 bytes, little endian + - dim_sizes: (4 * min(3, ndim)) bytes + * Handling the number of dimensions: + - If there are less than 3 dimensions, then dim[1] and dim[2] are both: 1 + - Elif there are 3 or more dimensions, then the header will contain further size information. + * Handling Matrix Data: + - Little endian matrix data comes after the header, + the index of the last dimension changes the fastest. + """ + dtype = int(np.frombuffer(buffer, "int32", 1, 0)) # bytes [0, 4) + ndim = int(np.frombuffer(buffer, "int32", 1, 4)) # bytes [4, 8) + eff_dim = max(3, ndim) # stores minimum of 3 dimensions even for 1D array + dims = np.frombuffer(buffer, "int32", eff_dim, 8)[0:ndim] # bytes [8, 8 + eff_dim * 4) + data = np.frombuffer(buffer, _BINARY_MATRIX_TYPES[dtype], offset=8 + eff_dim * 4) + data = data.reshape(tuple(dims)) + return data + + +def read_binary_matrix_file(file, gzipped: bool = True): + with (gzip.open if gzipped else open)(file, "rb") as f: + return read_binary_matrix_buffer(buffer=f) + + +def resave_binary_matrix_file(inp_path, out_path, gzipped: bool = True): + with AtomicFileContext(out_path, open_mode=None) as temp_out_path: + data = read_binary_matrix_file(file=inp_path, gzipped=gzipped) + np.savez(temp_out_path, data=data) + + +# ========================================================================= # +# Norb Data Tasks # +# ========================================================================= # + + +@dataclasses.dataclass +class BinaryMatrixDataObject(DlDataObject): + file_name: str + file_hashes: Dict[str, str] + # download file/link + uri: str + uri_hashes: Dict[str, str] + # hash settings + hash_mode: str + hash_type: str + + def _make_h5_job(self, load_path: str, save_path: str): + return CachedJobFile( + make_file_fn=lambda path: resave_binary_matrix_file( + inp_path=load_path, + out_path=path, + gzipped=True, + ), + path=save_path, + hash=self.file_hashes[self.hash_mode], + hash_type=self.hash_type, + hash_mode=self.hash_mode, + ) + + def prepare(self, data_dir: str): + dl_path = self.get_file_path(data_dir=data_dir, variant='ORIG') + h5_path = self.get_file_path(data_dir=data_dir) + dl_job = self._make_dl_job(save_path=dl_path) + h5_job = self._make_h5_job(load_path=dl_path, save_path=h5_path) + dl_job.set_child(h5_job).run() + # ========================================================================= # # dataset_norb # # ========================================================================= # -class SmallNorbData(DownloadableGroundTruthData): +class SmallNorbData(GroundTruthData): """ Small NORB Dataset - https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/ - Files: - - direct hdf5: https://raw.githubusercontent.com/deepmind/dsprites-dataset/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.hdf5 - - direct npz: https://raw.githubusercontent.com/deepmind/dsprites-dataset/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz - # reference implementation: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/data/ground_truth/norb.py """ - NORB_TYPES = { - 0x1E3D4C55: 'uint8', # byte matrix - 0x1E3D4C54: 'int32', # integer matrix - # 0x1E3D4C56: 'int16', # short matrix - # 0x1E3D4C51: 'float32', # single precision matrix - # 0x1E3D4C53: 'float64', # double precision matrix - } - # ordered training data (dat, cat, info) NORB_TRAIN_URLS = [ 'https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz', @@ -95,20 +178,8 @@ def _read_norb_set(self, is_test): images = dat[:, 0] # images are in pairs, we only extract the first one of each return images, features - @staticmethod - def _read_norb_file(filename): - """Read the norb data from the compressed file - modified from disentanglement_lib""" - with gzip.open(filename, "rb") as f: - s = f.read() - magic = int(np.frombuffer(s, "int32", 1, 0)) - ndim = int(np.frombuffer(s, "int32", 1, 4)) - eff_dim = max(3, ndim) # stores minimum of 3 dimensions even for 1D array - dims = np.frombuffer(s, "int32", eff_dim, 8)[0:ndim] - data = np.frombuffer(s, SmallNorbData.NORB_TYPES[magic], offset=8 + eff_dim * 4) - data = data.reshape(tuple(dims)) - return data - # ========================================================================= # # END # # ========================================================================= # + diff --git a/disent/dataset/__init__.py b/disent/dataset/__init__.py index 9cb087d0..78634246 100644 --- a/disent/dataset/__init__.py +++ b/disent/dataset/__init__.py @@ -23,23 +23,3 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -# ========================================================================= # -# util # -# ========================================================================= # - - -def split_dataset(dataset, train_ratio=0.8): - """ - splits a dataset randomly into a training (train_ratio) and test set (1-train_ratio). - """ - import torch.utils.data - train_size = int(train_ratio * len(dataset)) - test_size = len(dataset) - train_size - return torch.utils.data.random_split(dataset, [train_size, test_size]) - - -# ========================================================================= # -# END # -# ========================================================================= # - diff --git a/disent/dataset/_augment_util.py b/disent/dataset/_augment_util.py index af8f5b9c..6740dfe9 100644 --- a/disent/dataset/_augment_util.py +++ b/disent/dataset/_augment_util.py @@ -29,6 +29,11 @@ from torch.utils.data.dataloader import default_collate +# ========================================================================= # +# util # +# ========================================================================= # + + class AugmentableDataset(object): @property @@ -140,6 +145,11 @@ def dataset_sample_batch(self, num_samples: int, mode: str): return self.dataset_batch_from_indices(sorted(indices), mode=mode) +# ========================================================================= # +# util # +# ========================================================================= # + + def _batch_to_observation(batch, obs_shape): """ Convert a batch of size 1, to a single observation. @@ -148,3 +158,8 @@ def _batch_to_observation(batch, obs_shape): assert batch.shape == (1, *obs_shape), f'batch.shape={repr(batch.shape)} does not correspond to obs_shape={repr(obs_shape)} with batch dimension added' return batch.reshape(obs_shape) return batch + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/frameworks/framework.py b/disent/frameworks/framework.py index 9d0be107..c88d9af1 100644 --- a/disent/frameworks/framework.py +++ b/disent/frameworks/framework.py @@ -22,10 +22,12 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import warnings +import logging +from dataclasses import asdict from dataclasses import dataclass from dataclasses import fields from numbers import Number +from pprint import pformat from typing import Any from typing import Dict from typing import final @@ -42,6 +44,33 @@ log = logging.getLogger(__name__) +# ========================================================================= # +# framework config # +# ========================================================================= # + + +class DisentConfigurable(object): + + @dataclass + class cfg(object): + def get_keys(self) -> list: + return list(self.to_dict().keys()) + + def to_dict(self) -> dict: + return asdict(self) + + def __str__(self): + return pformat(self.to_dict(), sort_dicts=False) + + def __init__(self, cfg: cfg = cfg()): + if cfg is None: + cfg = self.__class__.cfg() + log.info(f'Initialised default config {cfg=} for {self.__class__.__name__}') + super().__init__() + assert isinstance(cfg, self.__class__.cfg), f'{cfg=} ({type(cfg)}) is not an instance of {self.__class__.cfg}' + self.cfg = cfg + + # ========================================================================= # # framework # # ========================================================================= # diff --git a/disent/metrics/__init__.py b/disent/metrics/__init__.py index 6ebab63d..f0aa9449 100644 --- a/disent/metrics/__init__.py +++ b/disent/metrics/__init__.py @@ -33,23 +33,24 @@ from ._flatness import metric_flatness from ._flatness_components import metric_flatness_components -# helper imports -from disent.util import wrapped_partial as _wrapped_partial - # ========================================================================= # # Fast Metric Settings # # ========================================================================= # +# helper imports +from disent.util import wrapped_partial as _wrapped_partial + + FAST_METRICS = { - 'dci': _wrapped_partial(metric_dci, num_train=1000, num_test=500, boost_mode='sklearn'), # takes - 'factor_vae': _wrapped_partial(metric_factor_vae, num_train=700, num_eval=350, num_variance_estimate=1000), # may not be accurate, but it just takes waay too long otherwise 20+ seconds - 'flatness': _wrapped_partial(metric_flatness, factor_repeats=128), + 'dci': _wrapped_partial(metric_dci, num_train=1000, num_test=500, boost_mode='sklearn'), + 'factor_vae': _wrapped_partial(metric_factor_vae, num_train=700, num_eval=350, num_variance_estimate=1000), # may not be accurate, but it just takes waay too long otherwise 20+ seconds + 'flatness': _wrapped_partial(metric_flatness, factor_repeats=128), 'flatness_components': _wrapped_partial(metric_flatness_components, factor_repeats=128), - 'mig': _wrapped_partial(metric_mig, num_train=2000), - 'sap': _wrapped_partial(metric_sap, num_train=2000, num_test=1000), - 'unsupervised': _wrapped_partial(metric_unsupervised, num_train=2000), + 'mig': _wrapped_partial(metric_mig, num_train=2000), + 'sap': _wrapped_partial(metric_sap, num_train=2000, num_test=1000), + 'unsupervised': _wrapped_partial(metric_unsupervised, num_train=2000), } DEFAULT_METRICS = { diff --git a/disent/metrics/_dci.py b/disent/metrics/_dci.py index 8fde0047..cc70a190 100644 --- a/disent/metrics/_dci.py +++ b/disent/metrics/_dci.py @@ -22,7 +22,7 @@ """ Implementation of Disentanglement, Completeness and Informativeness. -Based on "A BaseFramework for the Quantitative Evaluation of Disentangled +Based on "A Framework for the Quantitative Evaluation of Disentangled Representations" (https://openreview.net/forum?id=By-7dz-AZ). """ diff --git a/disent/model/common.py b/disent/nn/model/common.py similarity index 70% rename from disent/model/common.py rename to disent/nn/model/common.py index 54ddef1b..e1cc0c28 100644 --- a/disent/model/common.py +++ b/disent/nn/model/common.py @@ -23,8 +23,7 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging - -from disent.util import DisentModule +from disent.util.base import DisentModule log = logging.getLogger(__name__) @@ -35,42 +34,26 @@ # ========================================================================= # -class Print(DisentModule): - """From: https://github.com/1Konny/Beta-VAE/blob/master/model.py""" - def __init__(self, layer): - super().__init__() - self.layer = layer - - def forward(self, tensor): - log.debug(self.layer, '|', tensor.shape, '->') - output = self.layer.forward(tensor) - log.debug(output.shape) - return output - - class BatchView(DisentModule): - """From: https://github.com/1Konny/Beta-VAE/blob/master/model.py""" def __init__(self, size): super().__init__() - self.size = (-1, *size) + self._size = (-1, *size) - def forward(self, tensor): - return tensor.view(*self.size) + def forward(self, x): + return x.view(*self._size) class Unsqueeze3D(DisentModule): - """From: https://github.com/amir-abdi/disentanglement-pytorch""" def forward(self, x): - x = x.unsqueeze(-1) - x = x.unsqueeze(-1) - return x + assert x.ndim == 2 + return x.view(*x.shape, 1, 1) # (B, N) -> (B, N, 1, 1) class Flatten3D(DisentModule): - """From: https://github.com/amir-abdi/disentanglement-pytorch""" def forward(self, x): - x = x.view(x.size()[0], -1) - return x + assert x.ndim == 4 + return x.view(*x.shape[0], -1) # (B, C, H, W) -> (B, C*H*W) + # ========================================================================= # # END # diff --git a/disent/schedule/lerp.py b/disent/schedule/lerp.py index 5df86e28..9fea0d14 100644 --- a/disent/schedule/lerp.py +++ b/disent/schedule/lerp.py @@ -52,27 +52,28 @@ def lerp_step(step, max_step, a, b): # ========================================================================= # -# Cyclical Annealing Schedules # -# - https://arxiv.org/abs/1903.10145 # -# - https://github.com/haofuml/cyclical_annealing # -# These functions are not exactly the same, but are more flexible. # +# linear interpolate # # ========================================================================= # -def activate_linear(v): return v -def activate_sigmoid(v): return 1 / (1 + np.exp(-12 * v + 6)) -def activate_cosine(v): return 0.5 * (1 - np.cos(v * math.pi)) +_SCALE_RATIO_FNS = { + 'linear': lambda r: r, + 'sigmoid': lambda r: 1 / (1 + np.exp(-12 * r + 6)), + 'cosine': lambda r: 0.5 * (1 - np.cos(r * math.pi)), +} -_FLERP_ACTIVATIONS = { - 'linear': activate_linear, - 'sigmoid': activate_sigmoid, - 'cosine': activate_cosine, -} +def scale_ratio(r, mode='linear'): + r = np.clip(r, 0., 1.) + return _SCALE_RATIO_FNS[mode](r) -def activate(v, mode='linear'): - return _FLERP_ACTIVATIONS[mode](v) +# ========================================================================= # +# Cyclical Annealing Schedules # +# - https://arxiv.org/abs/1903.10145 # +# - https://github.com/haofuml/cyclical_annealing # +# These functions are not exactly the same, but are more flexible. # +# ========================================================================= # _END_VALUES = { @@ -106,7 +107,7 @@ def cyclical_anneal( # compute increasing values if low_ratio + high_ratio < 1: r = (r - low_ratio) / (1-low_ratio-high_ratio) - r = activate(r, mode=mode) + r = scale_ratio(r, mode=mode) # truncate values r = np.where(low_mask, 0, r) r = np.where(high_mask, 1, r) diff --git a/disent/util/__init__.py b/disent/util/__init__.py index aca606b9..c1d0ae82 100644 --- a/disent/util/__init__.py +++ b/disent/util/__init__.py @@ -105,7 +105,7 @@ def __exit__(self, *args, **kwargs): self._state = None # ========================================================================= # -# IO # +# Conversion # # ========================================================================= # @@ -125,56 +125,6 @@ def to_numpy(array) -> np.ndarray: return np.array(array) -# ========================================================================= # -# IO # -# ========================================================================= # - - -def atomic_save(obj, path): - """ - Save a model to a file, making sure that the file will - never be partly written. - - This prevents the model from getting corrupted in the - event that the process dies or the machine crashes. - - FROM: my obstacle_tower project - """ - import os - import torch - - if os.path.dirname(path): - os.makedirs(os.path.dirname(path), exist_ok=True) - torch.save(obj, path + '.tmp') - os.rename(path + '.tmp', path) - - -def save_model(model, path): - atomic_save(model.state_dict(), path) - log.info(f'[MODEL]: saved {path}') - - -def load_model(model, path, cuda=True, fail_if_missing=True): - """ - FROM: my obstacle_tower project - """ - import os - import torch - - if path and os.path.exists(path): - model.load_state_dict(torch.load( - path, - map_location=torch.device('cuda' if cuda else 'cpu') - )) - log.info(f'[MODEL]: loaded {path} (cuda: {cuda})') - else: - if fail_if_missing: - raise Exception(f'Could not load model, path does not exist: {path}') - if cuda: - model = model.cuda() # this needs to stay despite the above. - log.info('[MODEL]: Moved to GPU') - return model - # ========================================================================= # # Iterators # diff --git a/disent/util/colors.py b/disent/util/colors.py index 0e39f285..725583d9 100644 --- a/disent/util/colors.py +++ b/disent/util/colors.py @@ -27,6 +27,7 @@ # Ansi Colors # # ========================================================================= # + RST = '\033[0m' # dark colors @@ -49,6 +50,7 @@ CYN = '\033[36m' lGRY = '\033[37m' + # ========================================================================= # # END # # ========================================================================= # diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..c831f34b --- /dev/null +++ b/pytest.ini @@ -0,0 +1,9 @@ + +[pytest] +minversion = 6.0 +testpaths = + tests + disent +python_files = + test_*.py + __test__*.py diff --git a/tests/test_data.py b/tests/test_data.py index 6031accb..b9a627da 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -22,7 +22,9 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import numpy as np +import pytest +from disent.data.groundtruth import Shapes3dData from disent.data.groundtruth import XYSquaresData from disent.data.groundtruth._xysquares import XYSquaresMinimalData @@ -30,6 +32,7 @@ # ========================================================================= # # TESTS # # ========================================================================= # +from disent.data.groundtruth.base import Hdf5GroundTruthData def test_xysquares_similarity(): @@ -46,6 +49,54 @@ def test_xysquares_similarity(): assert np.allclose(data_org[n-1], data_min[n-1]) + + + +@pytest.mark.parametrize("num_workers", [0, 1, 2]) +def test_hdf5_multiproc_dataset(num_workers): + from disent.dataset.random import RandomDataset + from torch.utils.data import DataLoader + + xysquares = XYSquaresData(square_size=2, image_size=4) + + + # class TestHdf5Dataset(Hdf5GroundTruthData): + # + # + # factor_names = ('floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation') + # factor_sizes = (10, 10, 10, 8, 4, 15) # TOTAL: 480000 + # observation_shape = (64, 64, 3) + # + # data_object = DlH5DataObject( + # # processed dataset file + # file_name='3dshapes.h5', + # file_hashes={'fast': 'e3a1a449b95293d4b2c25edbfcb8e804', 'full': 'b5187ee0d8b519bb33281c5ca549658c'}, + # # download file/link + # uri='https://storage.googleapis.com/3d-shapes/3dshapes.h5', + # uri_hashes={'fast': '85b20ed7cc8dc1f939f7031698d2d2ab', 'full': '099a2078d58cec4daad0702c55d06868'}, + # # hash settings + # hash_mode='fast', + # hash_type='md5', + # # h5 re-save settings + # hdf5_dataset_name='images', + # hdf5_chunk_size=(1, 64, 64, 3), + # hdf5_compression='gzip', + # hdf5_compression_lvl=4, + # ) + # + # + # + # Shapes3dData() + # dataset = RandomDataset(Shapes3dData(prepare=True)) + # + # dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=2, shuffle=True) + # + # with tqdm(total=len(dataset)) as progress: + # for batch in dataloader: + # progress.update(256) + + + # ========================================================================= # # END # # ========================================================================= # From d7c59f8618799ba02c9419383847e7d81c2487a9 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Tue, 1 Jun 2021 23:15:12 +0200 Subject: [PATCH 11/34] moved base modules --- disent/frameworks/framework.py | 4 +- disent/frameworks/helper/reconstructions.py | 2 +- disent/model/base.py | 2 +- disent/{nn => }/model/common.py | 3 +- disent/nn/__init__.py | 23 ++++++++ disent/nn/modules.py | 55 +++++++++++++++++++ disent/transform/_augment.py | 2 +- disent/util/__init__.py | 46 ---------------- .../run_03_train_disentangle_kernel.py | 4 +- 9 files changed, 87 insertions(+), 54 deletions(-) rename disent/{nn => }/model/common.py (98%) create mode 100644 disent/nn/__init__.py create mode 100644 disent/nn/modules.py diff --git a/disent/frameworks/framework.py b/disent/frameworks/framework.py index c88d9af1..20d2a3c1 100644 --- a/disent/frameworks/framework.py +++ b/disent/frameworks/framework.py @@ -38,8 +38,8 @@ import torch from disent.schedule import Schedule -from disent.util import DisentConfigurable -from disent.util import DisentLightningModule +from disent.nn.modules import DisentLightningModule + log = logging.getLogger(__name__) diff --git a/disent/frameworks/helper/reconstructions.py b/disent/frameworks/helper/reconstructions.py index 7f23ad5f..cb3e4bdc 100644 --- a/disent/frameworks/helper/reconstructions.py +++ b/disent/frameworks/helper/reconstructions.py @@ -35,7 +35,7 @@ from disent.frameworks.helper.reductions import loss_reduction from disent.frameworks.helper.util import compute_ave_loss from disent.transform import FftKernel -from disent.util import DisentModule +from disent.nn.modules import DisentModule from deprecated import deprecated diff --git a/disent/model/base.py b/disent/model/base.py index 30a6acec..a55065cb 100644 --- a/disent/model/base.py +++ b/disent/model/base.py @@ -28,7 +28,7 @@ import numpy as np from torch import Tensor -from disent.util import DisentModule +from disent.nn.modules import DisentModule log = logging.getLogger(__name__) diff --git a/disent/nn/model/common.py b/disent/model/common.py similarity index 98% rename from disent/nn/model/common.py rename to disent/model/common.py index e1cc0c28..55de8714 100644 --- a/disent/nn/model/common.py +++ b/disent/model/common.py @@ -23,7 +23,8 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging -from disent.util.base import DisentModule + +from disent.nn.modules import DisentModule log = logging.getLogger(__name__) diff --git a/disent/nn/__init__.py b/disent/nn/__init__.py new file mode 100644 index 00000000..9a05a479 --- /dev/null +++ b/disent/nn/__init__.py @@ -0,0 +1,23 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ diff --git a/disent/nn/modules.py b/disent/nn/modules.py new file mode 100644 index 00000000..9de5f574 --- /dev/null +++ b/disent/nn/modules.py @@ -0,0 +1,55 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import pytorch_lightning as pl +import torch + + +# ========================================================================= # +# Base Modules # +# ========================================================================= # + + +class DisentModule(torch.nn.Module): + + def _forward_unimplemented(self, *args): + # Annoying fix applied by torch for Module.forward: + # https://github.com/python/mypy/issues/8795 + raise RuntimeError('This should never run!') + + def forward(self, *args, **kwargs): + raise NotImplementedError + + +class DisentLightningModule(pl.LightningModule): + + def _forward_unimplemented(self, *args): + # Annoying fix applied by torch for Module.forward: + # https://github.com/python/mypy/issues/8795 + raise RuntimeError('This should never run!') + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/transform/_augment.py b/disent/transform/_augment.py index e82ec955..938af078 100644 --- a/disent/transform/_augment.py +++ b/disent/transform/_augment.py @@ -33,7 +33,7 @@ import torch import disent -from disent.util import DisentModule +from disent.nn.modules import DisentModule from disent.util.math import torch_box_kernel_2d from disent.util.math import torch_conv2d_channel_wise_fft from disent.util.math import torch_gaussian_kernel_2d diff --git a/disent/util/__init__.py b/disent/util/__init__.py index c1d0ae82..14265d90 100644 --- a/disent/util/__init__.py +++ b/disent/util/__init__.py @@ -425,52 +425,6 @@ def get_memory_usage(): return num_bytes -# ========================================================================= # -# Torch Helper # -# ========================================================================= # - - -class DisentModule(torch.nn.Module): - - def _forward_unimplemented(self, *args): - # Annoying fix applied by torch for Module.forward: - # https://github.com/python/mypy/issues/8795 - raise RuntimeError('This should never run!') - - def forward(self, *args, **kwargs): - raise NotImplementedError - - -class DisentLightningModule(pl.LightningModule): - - def _forward_unimplemented(self, *args): - # Annoying fix applied by torch for Module.forward: - # https://github.com/python/mypy/issues/8795 - raise RuntimeError('This should never run!') - - -class DisentConfigurable(object): - - @dataclass - class cfg(object): - def get_keys(self) -> list: - return list(self.to_dict().keys()) - - def to_dict(self) -> dict: - return asdict(self) - - def __str__(self): - return pformat(self.to_dict(), sort_dicts=False) - - def __init__(self, cfg: cfg = cfg()): - if cfg is None: - cfg = self.__class__.cfg() - log.info(f'Initialised default config {cfg=} for {self.__class__.__name__}') - super().__init__() - assert isinstance(cfg, self.__class__.cfg), f'{cfg=} ({type(cfg)}) is not an instance of {self.__class__.cfg}' - self.cfg = cfg - - # ========================================================================= # # Slot Tuple # # ========================================================================= # diff --git a/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py b/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py index 5a362ba3..3f4c6e7a 100644 --- a/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py +++ b/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py @@ -38,8 +38,8 @@ from torch.utils.data import DataLoader import experiment.exp.util as H -from disent.util import DisentLightningModule -from disent.util import DisentModule +from disent.nn.modules import DisentLightningModule +from disent.nn.modules import DisentModule from disent.util import make_box_str from disent.util import seed from disent.util.math import torch_conv2d_channel_wise_fft From c798bb84cfb527396adc43b4e3c3c076be5f6cf3 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Wed, 2 Jun 2021 00:24:28 +0200 Subject: [PATCH 12/34] large refactor --- README.md | 4 +- disent/frameworks/ae/_supervised__tae.py | 4 +- disent/frameworks/ae/_unsupervised__ae.py | 2 +- .../frameworks/helper/latent_distributions.py | 94 ++------------- disent/frameworks/helper/reconstructions.py | 11 +- disent/frameworks/vae/_supervised__tvae.py | 6 +- .../frameworks/vae/_unsupervised__dfcvae.py | 4 +- .../experimental/_supervised__adaneg_tvae.py | 4 +- .../vae/experimental/_supervised__adatvae.py | 4 +- .../vae/experimental/_supervised__tbadavae.py | 4 +- .../vae/experimental/_supervised__tgadavae.py | 4 +- disent/model/ae/base.py | 73 ------------ disent/{model => nn/loss}/__init__.py | 0 disent/nn/loss/kl.py | 112 ++++++++++++++++++ .../triplet_loss.py => nn/loss/triplet.py} | 0 disent/{model/ae => nn/model}/__init__.py | 16 +-- .../common.py => nn/model/ae/__init__.py} | 53 +++------ disent/{ => nn}/model/ae/_conv64.py | 9 +- disent/{ => nn}/model/ae/_conv64_alt.py | 9 +- disent/{ => nn}/model/ae/_fc.py | 9 +- disent/{ => nn}/model/ae/_simpleconv64.py | 9 +- disent/{ => nn}/model/ae/_simplefc.py | 9 +- disent/{model => nn/model/ae}/base.py | 49 +++++++- disent/nn/modules.py | 26 ++++ .../{frameworks/helper => nn}/reductions.py | 0 disent/{ => nn}/transform/__init__.py | 0 disent/{ => nn}/transform/_augment.py | 6 +- disent/{ => nn}/transform/_transforms.py | 2 +- disent/{ => nn}/transform/functional.py | 0 disent/{ => nn}/transform/groundtruth.py | 0 disent/{model/init.py => nn/weights.py} | 11 +- disent/util/__init__.py | 69 ----------- disent/util/math.py | 26 +++- docs/examples/mnist_example.py | 6 +- docs/examples/overview_dataset_loader.py | 7 +- docs/examples/overview_dataset_pair.py | 5 +- .../examples/overview_dataset_pair_augment.py | 5 +- docs/examples/overview_dataset_single.py | 3 +- docs/examples/overview_framework_adagvae.py | 9 +- docs/examples/overview_framework_ae.py | 9 +- docs/examples/overview_framework_betavae.py | 9 +- .../overview_framework_betavae_scheduled.py | 9 +- docs/examples/overview_metrics.py | 5 +- experiment/exp/01_visual_overlap/run.py | 2 +- experiment/exp/util/_dataset.py | 2 +- experiment/exp/util/_loss.py | 2 +- experiment/run.py | 5 +- experiment/util/hydra_data.py | 2 +- tests/test_frameworks.py | 8 +- tests/test_math.py | 2 +- tests/test_transform.py | 5 +- 51 files changed, 344 insertions(+), 380 deletions(-) delete mode 100644 disent/model/ae/base.py rename disent/{model => nn/loss}/__init__.py (100%) create mode 100644 disent/nn/loss/kl.py rename disent/{frameworks/helper/triplet_loss.py => nn/loss/triplet.py} (100%) rename disent/{model/ae => nn/model}/__init__.py (75%) rename disent/{model/common.py => nn/model/ae/__init__.py} (55%) rename disent/{ => nn}/model/ae/_conv64.py (95%) rename disent/{ => nn}/model/ae/_conv64_alt.py (96%) rename disent/{ => nn}/model/ae/_fc.py (94%) rename disent/{ => nn}/model/ae/_simpleconv64.py (95%) rename disent/{ => nn}/model/ae/_simplefc.py (93%) rename disent/{model => nn/model/ae}/base.py (71%) rename disent/{frameworks/helper => nn}/reductions.py (100%) rename disent/{ => nn}/transform/__init__.py (100%) rename disent/{ => nn}/transform/_augment.py (97%) rename disent/{ => nn}/transform/_transforms.py (98%) rename disent/{ => nn}/transform/functional.py (100%) rename disent/{ => nn}/transform/groundtruth.py (100%) rename disent/{model/init.py => nn/weights.py} (86%) diff --git a/README.md b/README.md index f60034b3..55b76407 100644 --- a/README.md +++ b/README.md @@ -241,9 +241,9 @@ from disent.data.groundtruth import XYObjectData from disent.dataset.groundtruth import GroundTruthDataset from disent.frameworks.vae import BetaVae from disent.metrics import metric_dci, metric_mig -from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder +from disent.nn.model.ae import AutoEncoder, EncoderConv64, DecoderConv64 +from disent.nn.transform import ToStandardisedTensor from disent.schedule import CyclicSchedule -from disent.transform import ToStandardisedTensor # We use this internally to test this script. # You can remove all references to this in your own code. diff --git a/disent/frameworks/ae/_supervised__tae.py b/disent/frameworks/ae/_supervised__tae.py index 47982834..304a97f5 100644 --- a/disent/frameworks/ae/_supervised__tae.py +++ b/disent/frameworks/ae/_supervised__tae.py @@ -33,8 +33,8 @@ import torch from disent.frameworks.ae._unsupervised__ae import Ae -from disent.frameworks.helper.triplet_loss import compute_triplet_loss -from disent.frameworks.helper.triplet_loss import TripletLossConfig +from disent.nn.loss.triplet import compute_triplet_loss +from disent.nn.loss.triplet import TripletLossConfig # ========================================================================= # diff --git a/disent/frameworks/ae/_unsupervised__ae.py b/disent/frameworks/ae/_unsupervised__ae.py index 5ce21bee..091db822 100644 --- a/disent/frameworks/ae/_unsupervised__ae.py +++ b/disent/frameworks/ae/_unsupervised__ae.py @@ -39,7 +39,7 @@ from disent.frameworks.helper.reconstructions import make_reconstruction_loss from disent.frameworks.helper.reconstructions import ReconLossHandler from disent.frameworks.helper.util import detach_all -from disent.model.ae.base import AutoEncoder +from disent.nn.model.ae import AutoEncoder from disent.util import map_all diff --git a/disent/frameworks/helper/latent_distributions.py b/disent/frameworks/helper/latent_distributions.py index fd4a5a77..a5f3b2a6 100644 --- a/disent/frameworks/helper/latent_distributions.py +++ b/disent/frameworks/helper/latent_distributions.py @@ -23,102 +23,24 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -from dataclasses import fields +from typing import final from typing import Sequence -from typing import Tuple, final +from typing import Tuple -import numpy as np import torch +from torch.distributions import Distribution from torch.distributions import Laplace -from torch.distributions import Normal, Distribution +from torch.distributions import Normal -from disent.frameworks.helper.reductions import loss_reduction from disent.frameworks.helper.util import compute_ave_loss - - -# ========================================================================= # -# Helper Functions # -# ========================================================================= # - - -def short_dataclass_repr(self): - vals = { - k: v.shape if isinstance(v, (torch.Tensor, np.ndarray)) else v - for k, v in ((f.name, getattr(self, f.name)) for f in fields(self)) - } - return f'{self.__class__.__name__}({", ".join(f"{k}={v}" for k, v in vals.items())})' - - -# ========================================================================= # -# Kl Loss # -# ========================================================================= # - - -def kl_loss_direct_reverse(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None): - # This is how the original VAE/BetaVAE papers do it. - # - we compute the reverse kl divergence directly instead of approximating it - # - kl(post|prior) - # FORWARD vs. REVERSE kl (https://www.tuananhle.co.uk/notes/reverse-forward-kl.html) - # - If we minimize the kl(post|prior) or the reverse/exclusive KL, the zero-forcing/mode-seeking behavior arises. - # - If we minimize the kl(prior|post) or the forward/inclusive KL, the mass-covering/mean-seeking behavior arises. - return torch.distributions.kl_divergence(posterior, prior) - - -def kl_loss_approx_reverse(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None): - # This is how pytorch-lightning-bolts does it: - # - kl(post|prior) - # See issue: https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues/565 - # - we approximate the reverse kl divergence instead of computing it analytically - assert z_sampled is not None, 'to compute the approximate kl loss, z_sampled needs to be defined (cfg.kl_mode="approx")' - return posterior.log_prob(z_sampled) - prior.log_prob(z_sampled) - - -def kl_loss_direct_forward(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None): - # compute the forward kl - # - kl(prior|post) - return torch.distributions.kl_divergence(prior, posterior) - - -def kl_loss_approx_forward(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None): - # compute the approximate forward kl - # - kl(prior|post) - assert z_sampled is not None, 'to compute the approximate kl loss, z_sampled needs to be defined (cfg.kl_mode="approx")' - return prior.log_prob(z_sampled) - posterior.log_prob(z_sampled) - - -def kl_loss_direct_symmetric(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None): - # compute the (scaled) symmetric kl - # - 0.5 * kl(prior|post) + 0.5 * kl(prior|post) - return 0.5 * kl_loss_direct_reverse(posterior, prior, z_sampled) + 0.5 * kl_loss_direct_forward(posterior, prior, z_sampled) - - -def kl_loss_approx_symmetric(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None): - # compute the approximate (scaled) symmetric kl - # - 0.5 * kl(prior|post) + 0.5 * kl(prior|post) - return 0.5 * kl_loss_approx_reverse(posterior, prior, z_sampled) + 0.5 * kl_loss_approx_forward(posterior, prior, z_sampled) - - -_KL_LOSS_MODES = { - # reverse kl -- how it should be done for VAEs - 'direct': kl_loss_direct_reverse, # alias for reverse modes - 'approx': kl_loss_approx_reverse, # alias for reverse modes - 'direct_reverse': kl_loss_direct_reverse, - 'approx_reverse': kl_loss_approx_reverse, - # forward kl - 'direct_forward': kl_loss_direct_forward, - 'approx_forward': kl_loss_approx_forward, - # symmetric kl - 'direct_symmetric': kl_loss_direct_symmetric, - 'approx_symmetric': kl_loss_approx_symmetric, -} - - -def kl_loss(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None, mode='direct'): - return _KL_LOSS_MODES[mode](posterior, prior, z_sampled) +from disent.nn.loss.kl import kl_loss +from disent.nn.reductions import loss_reduction # ========================================================================= # # Vae Distributions # +# TODO: this should be moved into NNs # +# TODO: encoder modules should directly output distributions! # # ========================================================================= # diff --git a/disent/frameworks/helper/reconstructions.py b/disent/frameworks/helper/reconstructions.py index cb3e4bdc..23cb9ac2 100644 --- a/disent/frameworks/helper/reconstructions.py +++ b/disent/frameworks/helper/reconstructions.py @@ -22,22 +22,21 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import re import warnings from typing import final from typing import Sequence from typing import Union -import re import torch import torch.nn.functional as F +from deprecated import deprecated -from disent.frameworks.helper.reductions import batch_loss_reduction -from disent.frameworks.helper.reductions import loss_reduction from disent.frameworks.helper.util import compute_ave_loss -from disent.transform import FftKernel from disent.nn.modules import DisentModule - -from deprecated import deprecated +from disent.nn.reductions import batch_loss_reduction +from disent.nn.reductions import loss_reduction +from disent.nn.transform import FftKernel # ========================================================================= # diff --git a/disent/frameworks/vae/_supervised__tvae.py b/disent/frameworks/vae/_supervised__tvae.py index cd1298eb..681e1471 100644 --- a/disent/frameworks/vae/_supervised__tvae.py +++ b/disent/frameworks/vae/_supervised__tvae.py @@ -23,7 +23,6 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ from dataclasses import dataclass -from distutils.dist import Distribution from numbers import Number from typing import Any from typing import Dict @@ -31,12 +30,11 @@ from typing import Tuple from typing import Union -import numpy as np import torch from torch.distributions import Normal -from disent.frameworks.helper.triplet_loss import compute_triplet_loss -from disent.frameworks.helper.triplet_loss import TripletLossConfig +from disent.nn.loss.triplet import compute_triplet_loss +from disent.nn.loss.triplet import TripletLossConfig from disent.frameworks.vae._unsupervised__betavae import BetaVae diff --git a/disent/frameworks/vae/_unsupervised__dfcvae.py b/disent/frameworks/vae/_unsupervised__dfcvae.py index 8593a6b6..f9f480f2 100644 --- a/disent/frameworks/vae/_unsupervised__dfcvae.py +++ b/disent/frameworks/vae/_unsupervised__dfcvae.py @@ -36,10 +36,10 @@ from torchvision.models import vgg19_bn from torch.nn import functional as F -from disent.frameworks.helper.reductions import get_mean_loss_scale +from disent.nn.reductions import get_mean_loss_scale from disent.frameworks.helper.util import compute_ave_loss from disent.frameworks.vae._unsupervised__betavae import BetaVae -from disent.transform.functional import check_tensor +from disent.nn.transform.functional import check_tensor # ========================================================================= # diff --git a/disent/frameworks/vae/experimental/_supervised__adaneg_tvae.py b/disent/frameworks/vae/experimental/_supervised__adaneg_tvae.py index d0596054..469b4099 100644 --- a/disent/frameworks/vae/experimental/_supervised__adaneg_tvae.py +++ b/disent/frameworks/vae/experimental/_supervised__adaneg_tvae.py @@ -29,8 +29,8 @@ import torch from torch.distributions import Normal -from disent.frameworks.helper.triplet_loss import configured_dist_triplet -from disent.frameworks.helper.triplet_loss import configured_triplet +from disent.nn.loss.triplet import configured_dist_triplet +from disent.nn.loss.triplet import configured_triplet from disent.frameworks.vae._supervised__tvae import TripletVae from disent.frameworks.vae.experimental._supervised__adatvae import compute_triplet_shared_masks from disent.frameworks.vae.experimental._supervised__adatvae import compute_triplet_shared_masks_from_zs diff --git a/disent/frameworks/vae/experimental/_supervised__adatvae.py b/disent/frameworks/vae/experimental/_supervised__adatvae.py index 42c8e449..70b7745d 100644 --- a/disent/frameworks/vae/experimental/_supervised__adatvae.py +++ b/disent/frameworks/vae/experimental/_supervised__adatvae.py @@ -32,8 +32,8 @@ from torch.distributions import Distribution from torch.distributions import Normal -from disent.frameworks.helper.triplet_loss import configured_dist_triplet -from disent.frameworks.helper.triplet_loss import configured_triplet +from disent.nn.loss.triplet import configured_dist_triplet +from disent.nn.loss.triplet import configured_triplet from disent.frameworks.vae._supervised__tvae import TripletVae from disent.frameworks.vae._weaklysupervised__adavae import AdaVae from disent.frameworks.vae._weaklysupervised__adavae import compute_average_distribution diff --git a/disent/frameworks/vae/experimental/_supervised__tbadavae.py b/disent/frameworks/vae/experimental/_supervised__tbadavae.py index aaa7579a..9e5caf8d 100644 --- a/disent/frameworks/vae/experimental/_supervised__tbadavae.py +++ b/disent/frameworks/vae/experimental/_supervised__tbadavae.py @@ -24,9 +24,9 @@ from dataclasses import dataclass -from disent.frameworks.helper.triplet_loss import compute_triplet_loss from disent.frameworks.vae.experimental._supervised__badavae import BoundedAdaVae -from disent.frameworks.helper.triplet_loss import TripletLossConfig +from disent.nn.loss.triplet import compute_triplet_loss +from disent.nn.loss.triplet import TripletLossConfig # ========================================================================= # diff --git a/disent/frameworks/vae/experimental/_supervised__tgadavae.py b/disent/frameworks/vae/experimental/_supervised__tgadavae.py index 2d2fc0c9..0739e751 100644 --- a/disent/frameworks/vae/experimental/_supervised__tgadavae.py +++ b/disent/frameworks/vae/experimental/_supervised__tgadavae.py @@ -24,9 +24,9 @@ from dataclasses import dataclass -from disent.frameworks.helper.triplet_loss import compute_triplet_loss from disent.frameworks.vae.experimental._supervised__gadavae import GuidedAdaVae -from disent.frameworks.helper.triplet_loss import TripletLossConfig +from disent.nn.loss.triplet import compute_triplet_loss +from disent.nn.loss.triplet import TripletLossConfig # ========================================================================= # diff --git a/disent/model/ae/base.py b/disent/model/ae/base.py deleted file mode 100644 index 557862d0..00000000 --- a/disent/model/ae/base.py +++ /dev/null @@ -1,73 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -from torch import Tensor - -from disent.model.base import BaseDecoderModule, BaseEncoderModule, BaseModule - - -# ========================================================================= # -# gaussian encoder model # -# ========================================================================= # - - -class AutoEncoder(BaseModule): - - def __init__(self, encoder: BaseEncoderModule, decoder: BaseDecoderModule): - assert isinstance(encoder, BaseEncoderModule) - assert isinstance(decoder, BaseDecoderModule) - # check sizes - assert encoder.x_shape == decoder.x_shape, 'x_shape mismatch' - assert encoder.x_size == decoder.x_size, 'x_size mismatch - this should never happen if x_shape matches' - assert encoder.z_size == decoder.z_size, 'z_size mismatch' - # initialise - super().__init__(x_shape=decoder.x_shape, z_size=decoder.z_size, z_multiplier=encoder.z_multiplier) - # assign - self._encoder = encoder - self._decoder = decoder - - def forward(self, x): - raise RuntimeError('This has been disabled') - - def encode(self, x): - z_raw = self._encoder(x) - # extract components if necessary - if self._z_multiplier == 1: - return z_raw - elif self.z_multiplier == 2: - return z_raw[..., :self.z_size], z_raw[..., self.z_size:] - else: - raise KeyError(f'z_multiplier={self.z_multiplier} is unsupported') - - def decode(self, z: Tensor) -> Tensor: - """ - decode the given representation. - the returned tensor does not have an activation applied to it! - """ - return self._decoder(z) - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/model/__init__.py b/disent/nn/loss/__init__.py similarity index 100% rename from disent/model/__init__.py rename to disent/nn/loss/__init__.py diff --git a/disent/nn/loss/kl.py b/disent/nn/loss/kl.py new file mode 100644 index 00000000..669c55dd --- /dev/null +++ b/disent/nn/loss/kl.py @@ -0,0 +1,112 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import torch +from torch.distributions import Distribution +from torch.distributions import Distribution +from torch.distributions import Distribution +from torch.distributions import Distribution +from torch.distributions import Distribution +from torch.distributions import Distribution +from torch.distributions import Distribution +from torch.distributions import Distribution +from torch.distributions import Distribution +from torch.distributions import Distribution +from torch.distributions import Distribution +from torch.distributions import Distribution +from torch.distributions import Distribution +from torch.distributions import Distribution + + +# ========================================================================= # +# Kl Losses # +# ========================================================================= # + + +def kl_loss_direct_reverse(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None): + # This is how the original VAE/BetaVAE papers do it. + # - we compute the reverse kl divergence directly instead of approximating it + # - kl(post|prior) + # FORWARD vs. REVERSE kl (https://www.tuananhle.co.uk/notes/reverse-forward-kl.html) + # - If we minimize the kl(post|prior) or the reverse/exclusive KL, the zero-forcing/mode-seeking behavior arises. + # - If we minimize the kl(prior|post) or the forward/inclusive KL, the mass-covering/mean-seeking behavior arises. + return torch.distributions.kl_divergence(posterior, prior) + + +def kl_loss_approx_reverse(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None): + # This is how pytorch-lightning-bolts does it: + # - kl(post|prior) + # See issue: https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues/565 + # - we approximate the reverse kl divergence instead of computing it analytically + assert z_sampled is not None, 'to compute the approximate kl loss, z_sampled needs to be defined (cfg.kl_mode="approx")' + return posterior.log_prob(z_sampled) - prior.log_prob(z_sampled) + + +def kl_loss_direct_forward(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None): + # compute the forward kl + # - kl(prior|post) + return torch.distributions.kl_divergence(prior, posterior) + + +def kl_loss_approx_forward(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None): + # compute the approximate forward kl + # - kl(prior|post) + assert z_sampled is not None, 'to compute the approximate kl loss, z_sampled needs to be defined (cfg.kl_mode="approx")' + return prior.log_prob(z_sampled) - posterior.log_prob(z_sampled) + + +def kl_loss_direct_symmetric(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None): + # compute the (scaled) symmetric kl + # - 0.5 * kl(prior|post) + 0.5 * kl(prior|post) + return 0.5 * kl_loss_direct_reverse(posterior, prior, z_sampled) + 0.5 * kl_loss_direct_forward(posterior, prior, z_sampled) + + +def kl_loss_approx_symmetric(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None): + # compute the approximate (scaled) symmetric kl + # - 0.5 * kl(prior|post) + 0.5 * kl(prior|post) + return 0.5 * kl_loss_approx_reverse(posterior, prior, z_sampled) + 0.5 * kl_loss_approx_forward(posterior, prior, z_sampled) + + +_KL_LOSS_MODES = { + # reverse kl -- how it should be done for VAEs + 'direct': kl_loss_direct_reverse, # alias for reverse modes + 'approx': kl_loss_approx_reverse, # alias for reverse modes + 'direct_reverse': kl_loss_direct_reverse, + 'approx_reverse': kl_loss_approx_reverse, + # forward kl + 'direct_forward': kl_loss_direct_forward, + 'approx_forward': kl_loss_approx_forward, + # symmetric kl + 'direct_symmetric': kl_loss_direct_symmetric, + 'approx_symmetric': kl_loss_approx_symmetric, +} + + +def kl_loss(posterior: Distribution, prior: Distribution, z_sampled: torch.Tensor = None, mode='direct'): + return _KL_LOSS_MODES[mode](posterior, prior, z_sampled) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/frameworks/helper/triplet_loss.py b/disent/nn/loss/triplet.py similarity index 100% rename from disent/frameworks/helper/triplet_loss.py rename to disent/nn/loss/triplet.py diff --git a/disent/model/ae/__init__.py b/disent/nn/model/__init__.py similarity index 75% rename from disent/model/ae/__init__.py rename to disent/nn/model/__init__.py index 3e411516..fdc7c167 100644 --- a/disent/model/ae/__init__.py +++ b/disent/nn/model/__init__.py @@ -22,15 +22,7 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -from ._conv64 import DecoderConv64 -from ._conv64 import EncoderConv64 -from ._conv64_alt import DecoderConv64Alt -from ._conv64_alt import EncoderConv64Alt -# components -from ._fc import DecoderFC -from ._fc import EncoderFC -from ._simpleconv64 import DecoderSimpleConv64 -from ._simpleconv64 import EncoderSimpleConv64 -from ._simplefc import DecoderSimpleFC -from ._simplefc import EncoderSimpleFC -from .base import AutoEncoder +# disent base modules +from disent.nn.model.ae import AutoEncoder +from disent.nn.model.ae import DisentEncoder +from disent.nn.model.ae import DisentDecoder diff --git a/disent/model/common.py b/disent/nn/model/ae/__init__.py similarity index 55% rename from disent/model/common.py rename to disent/nn/model/ae/__init__.py index 55de8714..fd679e8b 100644 --- a/disent/model/common.py +++ b/disent/nn/model/ae/__init__.py @@ -22,40 +22,19 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import logging - -from disent.nn.modules import DisentModule - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Utility Layers # -# ========================================================================= # - - -class BatchView(DisentModule): - def __init__(self, size): - super().__init__() - self._size = (-1, *size) - - def forward(self, x): - return x.view(*self._size) - - -class Unsqueeze3D(DisentModule): - def forward(self, x): - assert x.ndim == 2 - return x.view(*x.shape, 1, 1) # (B, N) -> (B, N, 1, 1) - - -class Flatten3D(DisentModule): - def forward(self, x): - assert x.ndim == 4 - return x.view(*x.shape[0], -1) # (B, C, H, W) -> (B, C*H*W) - - -# ========================================================================= # -# END # -# ========================================================================= # +# encoders & decoders +from disent.nn.model.ae._conv64 import DecoderConv64 +from disent.nn.model.ae._conv64 import EncoderConv64 +from disent.nn.model.ae._conv64_alt import DecoderConv64Alt +from disent.nn.model.ae._conv64_alt import EncoderConv64Alt +from disent.nn.model.ae._fc import DecoderFC +from disent.nn.model.ae._fc import EncoderFC +from disent.nn.model.ae._simpleconv64 import DecoderSimpleConv64 +from disent.nn.model.ae._simpleconv64 import EncoderSimpleConv64 +from disent.nn.model.ae._simplefc import DecoderSimpleFC +from disent.nn.model.ae._simplefc import EncoderSimpleFC + +# auto-encoder wrapper +from disent.nn.model.ae.base import AutoEncoder +from disent.nn.model.ae.base import DisentEncoder +from disent.nn.model.ae.base import DisentDecoder diff --git a/disent/model/ae/_conv64.py b/disent/nn/model/ae/_conv64.py similarity index 95% rename from disent/model/ae/_conv64.py rename to disent/nn/model/ae/_conv64.py index b4567463..29696214 100644 --- a/disent/model/ae/_conv64.py +++ b/disent/nn/model/ae/_conv64.py @@ -24,8 +24,9 @@ from torch import nn as nn, Tensor -from disent.model.base import BaseEncoderModule, BaseDecoderModule -from disent.model.common import Flatten3D, BatchView +from disent.nn.model.ae.base import DisentEncoder, DisentDecoder +from disent.nn.modules import Flatten3D +from disent.nn.modules import BatchView # ========================================================================= # @@ -33,7 +34,7 @@ # ========================================================================= # -class EncoderConv64(BaseEncoderModule): +class EncoderConv64(DisentEncoder): """ Reference Implementation: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/methods/shared/architectures.py @@ -71,7 +72,7 @@ def encode(self, x) -> (Tensor, Tensor): return self.model(x) -class DecoderConv64(BaseDecoderModule): +class DecoderConv64(DisentDecoder): """ From: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/methods/shared/architectures.py diff --git a/disent/model/ae/_conv64_alt.py b/disent/nn/model/ae/_conv64_alt.py similarity index 96% rename from disent/model/ae/_conv64_alt.py rename to disent/nn/model/ae/_conv64_alt.py index d3709516..2e9e9c0b 100644 --- a/disent/model/ae/_conv64_alt.py +++ b/disent/nn/model/ae/_conv64_alt.py @@ -24,8 +24,9 @@ from torch import nn as nn, Tensor -from disent.model.base import BaseEncoderModule, BaseDecoderModule -from disent.model.common import Flatten3D, BatchView +from disent.nn.model.ae.base import DisentEncoder, DisentDecoder +from disent.nn.modules import Flatten3D +from disent.nn.modules import BatchView def _make_activations(activation='relu', inplace=True, norm='instance', num_features: int = None, norm_pre_act=True): @@ -61,7 +62,7 @@ def _make_activations(activation='relu', inplace=True, norm='instance', num_feat # ========================================================================= # -class EncoderConv64Alt(BaseEncoderModule): +class EncoderConv64Alt(DisentEncoder): """ Reference Implementation: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/methods/shared/architectures.py @@ -99,7 +100,7 @@ def encode(self, x) -> (Tensor, Tensor): return self.model(x) -class DecoderConv64Alt(BaseDecoderModule): +class DecoderConv64Alt(DisentDecoder): """ From: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/methods/shared/architectures.py diff --git a/disent/model/ae/_fc.py b/disent/nn/model/ae/_fc.py similarity index 94% rename from disent/model/ae/_fc.py rename to disent/nn/model/ae/_fc.py index bc73a463..c378a7e8 100644 --- a/disent/model/ae/_fc.py +++ b/disent/nn/model/ae/_fc.py @@ -25,8 +25,9 @@ import numpy as np from torch import nn as nn, Tensor -from disent.model.base import BaseEncoderModule, BaseDecoderModule -from disent.model.common import Flatten3D, BatchView +from disent.nn.model.ae.base import DisentEncoder, DisentDecoder +from disent.nn.modules import Flatten3D +from disent.nn.modules import BatchView # ========================================================================= # @@ -34,7 +35,7 @@ # ========================================================================= # -class EncoderFC(BaseEncoderModule): +class EncoderFC(DisentEncoder): """ Reference Implementation: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/methods/shared/architectures.py @@ -64,7 +65,7 @@ def encode(self, x) -> (Tensor, Tensor): return self.model(x) -class DecoderFC(BaseDecoderModule): +class DecoderFC(DisentDecoder): """ From: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/methods/shared/architectures.py diff --git a/disent/model/ae/_simpleconv64.py b/disent/nn/model/ae/_simpleconv64.py similarity index 95% rename from disent/model/ae/_simpleconv64.py rename to disent/nn/model/ae/_simpleconv64.py index 0e4a46ed..dff42f5c 100644 --- a/disent/model/ae/_simpleconv64.py +++ b/disent/nn/model/ae/_simpleconv64.py @@ -24,8 +24,9 @@ from torch import nn as nn, Tensor -from disent.model.base import BaseEncoderModule, BaseDecoderModule -from disent.model.common import Flatten3D, Unsqueeze3D +from disent.nn.model.ae.base import DisentEncoder, DisentDecoder +from disent.nn.modules import Flatten3D +from disent.nn.modules import Unsqueeze3D # ========================================================================= # @@ -33,7 +34,7 @@ # ========================================================================= # -class EncoderSimpleConv64(BaseEncoderModule): +class EncoderSimpleConv64(DisentEncoder): """ Reference Implementation: https://github.com/amir-abdi/disentanglement-pytorch # TODO: verify, things have changed... @@ -66,7 +67,7 @@ def encode(self, x) -> (Tensor, Tensor): return self.model(x) -class DecoderSimpleConv64(BaseDecoderModule): +class DecoderSimpleConv64(DisentDecoder): """ From: https://github.com/amir-abdi/disentanglement-pytorch # TODO: verify, things have changed... diff --git a/disent/model/ae/_simplefc.py b/disent/nn/model/ae/_simplefc.py similarity index 93% rename from disent/model/ae/_simplefc.py rename to disent/nn/model/ae/_simplefc.py index 852fa1e0..f3e6b2d0 100644 --- a/disent/model/ae/_simplefc.py +++ b/disent/nn/model/ae/_simplefc.py @@ -24,8 +24,9 @@ from torch import nn as nn, Tensor -from disent.model.base import BaseEncoderModule, BaseDecoderModule -from disent.model.common import Flatten3D, BatchView +from disent.nn.model.ae.base import DisentEncoder, DisentDecoder +from disent.nn.modules import Flatten3D +from disent.nn.modules import BatchView # ========================================================================= # @@ -33,7 +34,7 @@ # ========================================================================= # -class EncoderSimpleFC(BaseEncoderModule): +class EncoderSimpleFC(DisentEncoder): """ Custom Fully Connected Encoder. """ @@ -53,7 +54,7 @@ def encode(self, x) -> (Tensor, Tensor): return self.model(x) -class DecoderSimpleFC(BaseDecoderModule): +class DecoderSimpleFC(DisentDecoder): """ Custom Fully Connected Decoder. """ diff --git a/disent/model/base.py b/disent/nn/model/ae/base.py similarity index 71% rename from disent/model/base.py rename to disent/nn/model/ae/base.py index a55065cb..9f9ee32a 100644 --- a/disent/model/base.py +++ b/disent/nn/model/ae/base.py @@ -39,7 +39,7 @@ # ========================================================================= # -class BaseModule(DisentModule): +class DisentLatentsBase(DisentModule): def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1): super().__init__() @@ -78,11 +78,11 @@ def assert_lengths(self, x, z): # ========================================================================= # -# Custom Base nn.Module # +# Base Encoder & Base Decoder # # ========================================================================= # -class BaseEncoderModule(BaseModule): +class DisentEncoder(DisentLatentsBase): @final def forward(self, x) -> Tensor: @@ -102,7 +102,7 @@ def encode(self, x) -> Tensor: raise NotImplementedError -class BaseDecoderModule(BaseModule): +class DisentDecoder(DisentLatentsBase): def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1): assert z_multiplier == 1, 'decoder does not support z_multiplier != 1' @@ -124,6 +124,47 @@ def decode(self, z) -> Tensor: raise NotImplementedError +# ========================================================================= # +# Auto-Encoder Wrapper # +# ========================================================================= # + + +class AutoEncoder(DisentLatentsBase): + + def __init__(self, encoder: DisentEncoder, decoder: DisentDecoder): + assert isinstance(encoder, DisentEncoder) + assert isinstance(decoder, DisentDecoder) + # check sizes + assert encoder.x_shape == decoder.x_shape, 'x_shape mismatch' + assert encoder.x_size == decoder.x_size, 'x_size mismatch - this should never happen if x_shape matches' + assert encoder.z_size == decoder.z_size, 'z_size mismatch' + # initialise + super().__init__(x_shape=decoder.x_shape, z_size=decoder.z_size, z_multiplier=encoder.z_multiplier) + # assign + self._encoder = encoder + self._decoder = decoder + + def forward(self, x): + raise RuntimeError('This has been disabled') + + def encode(self, x): + z_raw = self._encoder(x) + # extract components if necessary + if self._z_multiplier == 1: + return z_raw + elif self.z_multiplier == 2: + return z_raw[..., :self.z_size], z_raw[..., self.z_size:] + else: + raise KeyError(f'z_multiplier={self.z_multiplier} is unsupported') + + def decode(self, z: Tensor) -> Tensor: + """ + decode the given representation. + the returned tensor does not have an activation applied to it! + """ + return self._decoder(z) + + # ========================================================================= # # END # # ========================================================================= # diff --git a/disent/nn/modules.py b/disent/nn/modules.py index 9de5f574..1f4169c6 100644 --- a/disent/nn/modules.py +++ b/disent/nn/modules.py @@ -50,6 +50,32 @@ def _forward_unimplemented(self, *args): raise RuntimeError('This should never run!') +# ========================================================================= # +# Utility Layers # +# ========================================================================= # + + +class BatchView(DisentModule): + def __init__(self, size): + super().__init__() + self._size = (-1, *size) + + def forward(self, x): + return x.view(*self._size) + + +class Unsqueeze3D(DisentModule): + def forward(self, x): + assert x.ndim == 2 + return x.view(*x.shape, 1, 1) # (B, N) -> (B, N, 1, 1) + + +class Flatten3D(DisentModule): + def forward(self, x): + assert x.ndim == 4 + return x.view(x.shape[0], -1) # (B, C, H, W) -> (B, C*H*W) + + # ========================================================================= # # END # # ========================================================================= # diff --git a/disent/frameworks/helper/reductions.py b/disent/nn/reductions.py similarity index 100% rename from disent/frameworks/helper/reductions.py rename to disent/nn/reductions.py diff --git a/disent/transform/__init__.py b/disent/nn/transform/__init__.py similarity index 100% rename from disent/transform/__init__.py rename to disent/nn/transform/__init__.py diff --git a/disent/transform/_augment.py b/disent/nn/transform/_augment.py similarity index 97% rename from disent/transform/_augment.py rename to disent/nn/transform/_augment.py index 938af078..0f4eb8ac 100644 --- a/disent/transform/_augment.py +++ b/disent/nn/transform/_augment.py @@ -235,8 +235,10 @@ def _check_kernel(kernel: torch.Tensor) -> torch.Tensor: # (REGEX, EXAMPLE, FACTORY_FUNC) # - factory function takes at min one arg: fn(reduction) with one arg after that per regex capture group # - regex expressions are tested in order, expressions should be mutually exclusive or ordered such that more specialized versions occur first. - (re.compile(r'^(xy8)_r(47)$'), 'xy8_r47', lambda kern, radius: torch.load(os.path.abspath(os.path.join(disent.__file__, '../../data/adversarial_kernel', 'r47-1_s28800_adam_lr0.003_wd0.0_xy8x8.pt')))), - (re.compile(r'^(xy1)_r(47)$'), 'xy1_r47', lambda kern, radius: torch.load(os.path.abspath(os.path.join(disent.__file__, '../../data/adversarial_kernel', 'r47-1_s28800_adam_lr0.003_wd0.0_xy1x1.pt')))), + (re.compile(r'^(xy8)_r(47)$'), 'xy8_r47', lambda kern, radius: torch.load(os.path.abspath(os.path.join(disent.__file__, + '../../../data/adversarial_kernel', 'r47-1_s28800_adam_lr0.003_wd0.0_xy8x8.pt')))), + (re.compile(r'^(xy1)_r(47)$'), 'xy1_r47', lambda kern, radius: torch.load(os.path.abspath(os.path.join(disent.__file__, + '../../../data/adversarial_kernel', 'r47-1_s28800_adam_lr0.003_wd0.0_xy1x1.pt')))), (re.compile(r'^(box)_r(\d+)$'), 'box_r31', lambda kern, radius: torch_box_kernel_2d(radius=int(radius))[None, ...]), (re.compile(r'^(gau)_r(\d+)$'), 'gau_r31', lambda kern, radius: torch_gaussian_kernel_2d(sigma=int(radius) / 4.0, truncate=4.0)[None, None, ...]), ] diff --git a/disent/transform/_transforms.py b/disent/nn/transform/_transforms.py similarity index 98% rename from disent/transform/_transforms.py rename to disent/nn/transform/_transforms.py index b4e67637..60fb70a8 100644 --- a/disent/transform/_transforms.py +++ b/disent/nn/transform/_transforms.py @@ -23,7 +23,7 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import torch -import disent.transform.functional as F_d +import disent.nn.transform.functional as F_d # ========================================================================= # diff --git a/disent/transform/functional.py b/disent/nn/transform/functional.py similarity index 100% rename from disent/transform/functional.py rename to disent/nn/transform/functional.py diff --git a/disent/transform/groundtruth.py b/disent/nn/transform/groundtruth.py similarity index 100% rename from disent/transform/groundtruth.py rename to disent/nn/transform/groundtruth.py diff --git a/disent/model/init.py b/disent/nn/weights.py similarity index 86% rename from disent/model/init.py rename to disent/nn/weights.py index 991f1f8c..7f181255 100644 --- a/disent/model/init.py +++ b/disent/nn/weights.py @@ -25,7 +25,6 @@ import logging from torch import nn - from disent.util import colors as c @@ -33,11 +32,11 @@ # ========================================================================= # -# Helper # +# Basic Weight Initialisation # # ========================================================================= # -def init_model_weights(model: nn.Module, mode='xavier_normal'): +def init_model_weights(model: nn.Module, mode='xavier_normal', log_level=logging.INFO): count = 0 # get default mode @@ -61,11 +60,11 @@ def init_normal(m): # print messages if init: - log.info(f'| {count:03d} {c.lGRN}INIT{c.RST}: {m.__class__.__name__}') + log.log(log_level, f'| {count:03d} {c.lGRN}INIT{c.RST}: {m.__class__.__name__}') else: - log.info(f'| {count:03d} {c.lRED}SKIP{c.RST}: {m.__class__.__name__}') + log.log(log_level, f'| {count:03d} {c.lRED}SKIP{c.RST}: {m.__class__.__name__}') - log.info(f'Initialising Model Layers: {mode}') + log.log(log_level, f'Initialising Model Layers: {mode}') model.apply(init_normal) return model diff --git a/disent/util/__init__.py b/disent/util/__init__.py index 14265d90..4fadebd8 100644 --- a/disent/util/__init__.py +++ b/disent/util/__init__.py @@ -27,16 +27,10 @@ import os import time from collections import Sequence -from dataclasses import asdict -from dataclasses import dataclass -from dataclasses import fields from itertools import islice -from pprint import pformat -from random import random from typing import List import numpy as np -import pytorch_lightning as pl import torch @@ -125,7 +119,6 @@ def to_numpy(array) -> np.ndarray: return np.array(array) - # ========================================================================= # # Iterators # # ========================================================================= # @@ -425,68 +418,6 @@ def get_memory_usage(): return num_bytes -# ========================================================================= # -# Slot Tuple # -# ========================================================================= # - - -@dataclass -class TupleDataClass: - """ - Like a named tuple + dataclass combination, that is mutable. - -- requires that you still decorate the inherited class with @dataclass - """ - - __field_names_cache = None - - @property - def __field_names(self): - # check for attribute and set on class only - if self.__class__.__field_names_cache is None: - self.__class__.__field_names_cache = tuple(f.name for f in fields(self)) - return self.__class__.__field_names_cache - - def __iter__(self): - for name in self.__field_names: - yield getattr(self, name) - - def __len__(self): - return self.__field_names.__len__() - - def __str__(self): - return str(tuple(self)) - - def __repr__(self): - return f'{self.__class__.__name__}({", ".join(f"{name}={repr(getattr(self, name))}" for name in self.__field_names)})' - - -# ========================================================================= # -# END # -# ========================================================================= # - - -def debug_transform_tensors(obj): - """ - recursively convert all tensors to their shapes for debugging - """ - if isinstance(obj, (torch.Tensor, np.ndarray)): - return obj.shape - elif isinstance(obj, dict): - return {debug_transform_tensors(k): debug_transform_tensors(v) for k, v in obj.items()} - elif isinstance(obj, list): - return list(debug_transform_tensors(v) for v in obj) - elif isinstance(obj, tuple): - return tuple(debug_transform_tensors(v) for v in obj) - elif isinstance(obj, set): - return {debug_transform_tensors(k) for k in obj} - else: - return obj - - -def pprint_tensors(*args, **kwargs): - print(*(debug_transform_tensors(arg) for arg in args), **kwargs) - - # ========================================================================= # # END # # ========================================================================= # diff --git a/disent/util/math.py b/disent/util/math.py index a2db7918..8647852a 100644 --- a/disent/util/math.py +++ b/disent/util/math.py @@ -561,5 +561,29 @@ def torch_conv2d_channel_wise_fft(signal, kernel): # ========================================================================= # -# end # +# DEBUG # # ========================================================================= # + + +def debug_transform_tensors(obj): + """ + recursively convert all tensors to their shapes for debugging + """ + if isinstance(obj, (torch.Tensor, np.ndarray)): + return obj.shape + elif isinstance(obj, dict): + return {debug_transform_tensors(k): debug_transform_tensors(v) for k, v in obj.items()} + elif isinstance(obj, list): + return list(debug_transform_tensors(v) for v in obj) + elif isinstance(obj, tuple): + return tuple(debug_transform_tensors(v) for v in obj) + elif isinstance(obj, set): + return {debug_transform_tensors(k) for k in obj} + else: + return obj + + +# ========================================================================= # +# END # +# ========================================================================= # + diff --git a/docs/examples/mnist_example.py b/docs/examples/mnist_example.py index cbccc2c2..e8e81919 100644 --- a/docs/examples/mnist_example.py +++ b/docs/examples/mnist_example.py @@ -9,10 +9,8 @@ from disent.dataset.random import RandomDataset from disent.frameworks.vae import AdaVae -from disent.model.ae import AutoEncoder -from disent.model.ae import DecoderConv64Alt -from disent.model.ae import EncoderConv64Alt -from disent.transform import ToStandardisedTensor +from disent.nn.model.ae import AutoEncoder, DecoderConv64Alt, EncoderConv64Alt +from disent.nn.transform import ToStandardisedTensor from disent.util import is_test_run diff --git a/docs/examples/overview_dataset_loader.py b/docs/examples/overview_dataset_loader.py index 793e2a54..1e3148b8 100644 --- a/docs/examples/overview_dataset_loader.py +++ b/docs/examples/overview_dataset_loader.py @@ -1,7 +1,8 @@ -from torch.utils.data import Dataset, DataLoader -from disent.data.groundtruth import XYSquaresData, GroundTruthData +from torch.utils.data import DataLoader, Dataset +from disent.data.groundtruth import GroundTruthData, XYSquaresData from disent.dataset.groundtruth import GroundTruthDatasetPairs -from disent.transform import ToStandardisedTensor +from disent.nn.transform import ToStandardisedTensor + data: GroundTruthData = XYSquaresData(square_size=1, image_size=2, num_squares=2) dataset: Dataset = GroundTruthDatasetPairs(data, transform=ToStandardisedTensor(), augment=None) diff --git a/docs/examples/overview_dataset_pair.py b/docs/examples/overview_dataset_pair.py index b18a57b1..b3ee375a 100644 --- a/docs/examples/overview_dataset_pair.py +++ b/docs/examples/overview_dataset_pair.py @@ -1,7 +1,8 @@ from torch.utils.data import Dataset -from disent.data.groundtruth import XYSquaresData, GroundTruthData +from disent.data.groundtruth import GroundTruthData, XYSquaresData from disent.dataset.groundtruth import GroundTruthDatasetPairs -from disent.transform import ToStandardisedTensor +from disent.nn.transform import ToStandardisedTensor + data: GroundTruthData = XYSquaresData(square_size=1, image_size=2, num_squares=2) dataset: Dataset = GroundTruthDatasetPairs(data, transform=ToStandardisedTensor(), augment=None) diff --git a/docs/examples/overview_dataset_pair_augment.py b/docs/examples/overview_dataset_pair_augment.py index 17080512..8ef099de 100644 --- a/docs/examples/overview_dataset_pair_augment.py +++ b/docs/examples/overview_dataset_pair_augment.py @@ -1,7 +1,8 @@ from torch.utils.data import Dataset -from disent.data.groundtruth import XYSquaresData, GroundTruthData +from disent.data.groundtruth import GroundTruthData, XYSquaresData from disent.dataset.groundtruth import GroundTruthDatasetPairs -from disent.transform import ToStandardisedTensor, FftBoxBlur +from disent.nn.transform import FftBoxBlur, ToStandardisedTensor + data: GroundTruthData = XYSquaresData(square_size=1, image_size=2, num_squares=2) dataset: Dataset = GroundTruthDatasetPairs(data, transform=ToStandardisedTensor(), augment=FftBoxBlur(radius=1, p=1.0)) diff --git a/docs/examples/overview_dataset_single.py b/docs/examples/overview_dataset_single.py index a400620a..38e8b4f0 100644 --- a/docs/examples/overview_dataset_single.py +++ b/docs/examples/overview_dataset_single.py @@ -1,7 +1,8 @@ from torch.utils.data import Dataset -from disent.data.groundtruth import XYSquaresData, GroundTruthData +from disent.data.groundtruth import GroundTruthData, XYSquaresData from disent.dataset.groundtruth import GroundTruthDataset + data: GroundTruthData = XYSquaresData(square_size=1, image_size=2, num_squares=2) dataset: Dataset = GroundTruthDataset(data, transform=None, augment=None) diff --git a/docs/examples/overview_framework_adagvae.py b/docs/examples/overview_framework_adagvae.py index 64c5c373..1da81ea5 100644 --- a/docs/examples/overview_framework_adagvae.py +++ b/docs/examples/overview_framework_adagvae.py @@ -1,13 +1,14 @@ import pytorch_lightning as pl from torch.optim import Adam -from torch.utils.data import Dataset, DataLoader -from disent.data.groundtruth import XYSquaresData, GroundTruthData +from torch.utils.data import DataLoader, Dataset +from disent.data.groundtruth import GroundTruthData, XYSquaresData from disent.dataset.groundtruth import GroundTruthDatasetOrigWeakPairs from disent.frameworks.vae import AdaVae -from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder -from disent.transform import ToStandardisedTensor +from disent.nn.model.ae import AutoEncoder, DecoderConv64, EncoderConv64 +from disent.nn.transform import ToStandardisedTensor from disent.util import is_test_run + data: GroundTruthData = XYSquaresData() dataset: Dataset = GroundTruthDatasetOrigWeakPairs(data, transform=ToStandardisedTensor()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) diff --git a/docs/examples/overview_framework_ae.py b/docs/examples/overview_framework_ae.py index f18a5ab2..04696479 100644 --- a/docs/examples/overview_framework_ae.py +++ b/docs/examples/overview_framework_ae.py @@ -1,13 +1,14 @@ import pytorch_lightning as pl from torch.optim import Adam -from torch.utils.data import Dataset, DataLoader -from disent.data.groundtruth import XYSquaresData, GroundTruthData +from torch.utils.data import DataLoader, Dataset +from disent.data.groundtruth import GroundTruthData, XYSquaresData from disent.dataset.groundtruth import GroundTruthDataset from disent.frameworks.ae import Ae -from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder -from disent.transform import ToStandardisedTensor +from disent.nn.model.ae import AutoEncoder, DecoderConv64, EncoderConv64 +from disent.nn.transform import ToStandardisedTensor from disent.util import is_test_run + data: GroundTruthData = XYSquaresData() dataset: Dataset = GroundTruthDataset(data, transform=ToStandardisedTensor()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) diff --git a/docs/examples/overview_framework_betavae.py b/docs/examples/overview_framework_betavae.py index e8664655..60b4ff1d 100644 --- a/docs/examples/overview_framework_betavae.py +++ b/docs/examples/overview_framework_betavae.py @@ -1,13 +1,14 @@ import pytorch_lightning as pl from torch.optim import Adam -from torch.utils.data import Dataset, DataLoader -from disent.data.groundtruth import XYSquaresData, GroundTruthData +from torch.utils.data import DataLoader, Dataset +from disent.data.groundtruth import GroundTruthData, XYSquaresData from disent.dataset.groundtruth import GroundTruthDataset from disent.frameworks.vae import BetaVae -from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder -from disent.transform import ToStandardisedTensor +from disent.nn.model.ae import AutoEncoder, DecoderConv64, EncoderConv64 +from disent.nn.transform import ToStandardisedTensor from disent.util import is_test_run + data: GroundTruthData = XYSquaresData() dataset: Dataset = GroundTruthDataset(data, transform=ToStandardisedTensor()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) diff --git a/docs/examples/overview_framework_betavae_scheduled.py b/docs/examples/overview_framework_betavae_scheduled.py index a3e06031..79e3df58 100644 --- a/docs/examples/overview_framework_betavae_scheduled.py +++ b/docs/examples/overview_framework_betavae_scheduled.py @@ -1,14 +1,15 @@ import pytorch_lightning as pl from torch.optim import Adam -from torch.utils.data import Dataset, DataLoader -from disent.data.groundtruth import XYSquaresData, GroundTruthData +from torch.utils.data import DataLoader, Dataset +from disent.data.groundtruth import GroundTruthData, XYSquaresData from disent.dataset.groundtruth import GroundTruthDataset from disent.frameworks.vae import BetaVae -from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder +from disent.nn.model.ae import AutoEncoder, DecoderConv64, EncoderConv64 +from disent.nn.transform import ToStandardisedTensor from disent.schedule import CyclicSchedule -from disent.transform import ToStandardisedTensor from disent.util import is_test_run + data: GroundTruthData = XYSquaresData() dataset: Dataset = GroundTruthDataset(data, transform=ToStandardisedTensor()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) diff --git a/docs/examples/overview_metrics.py b/docs/examples/overview_metrics.py index 1dfb4eab..2ee55a86 100644 --- a/docs/examples/overview_metrics.py +++ b/docs/examples/overview_metrics.py @@ -5,10 +5,11 @@ from disent.dataset.groundtruth import GroundTruthDataset from disent.frameworks.vae import BetaVae from disent.metrics import metric_dci, metric_mig -from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder -from disent.transform import ToStandardisedTensor +from disent.nn.model.ae import AutoEncoder, DecoderConv64, EncoderConv64 +from disent.nn.transform import ToStandardisedTensor from disent.util import is_test_run + data = XYObjectData() dataset = GroundTruthDataset(data, transform=ToStandardisedTensor()) dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True) diff --git a/experiment/exp/01_visual_overlap/run.py b/experiment/exp/01_visual_overlap/run.py index 8f73a684..aa3ed2ca 100644 --- a/experiment/exp/01_visual_overlap/run.py +++ b/experiment/exp/01_visual_overlap/run.py @@ -38,7 +38,7 @@ import experiment.exp.util as H from disent.data.groundtruth import * from disent.dataset.groundtruth import GroundTruthDataset -from disent.transform import ToStandardisedTensor +from disent.nn.transform import ToStandardisedTensor from disent.util import to_numpy diff --git a/experiment/exp/util/_dataset.py b/experiment/exp/util/_dataset.py index c15efaea..f30fc149 100644 --- a/experiment/exp/util/_dataset.py +++ b/experiment/exp/util/_dataset.py @@ -42,7 +42,7 @@ from disent.data.groundtruth import XYSquaresData from disent.dataset.groundtruth import GroundTruthDataset from disent.dataset.groundtruth import GroundTruthDatasetAndFactors -from disent.transform import ToStandardisedTensor +from disent.nn.transform import ToStandardisedTensor from disent.util import TempNumpySeed from disent.visualize.visualize_util import make_animated_image_grid from disent.visualize.visualize_util import make_image_grid diff --git a/experiment/exp/util/_loss.py b/experiment/exp/util/_loss.py index 0051bda1..0e54ba3e 100644 --- a/experiment/exp/util/_loss.py +++ b/experiment/exp/util/_loss.py @@ -26,7 +26,7 @@ import torch_optimizer from torch.nn import functional as F -from disent.frameworks.helper.reductions import batch_loss_reduction +from disent.nn.reductions import batch_loss_reduction # ========================================================================= # diff --git a/experiment/run.py b/experiment/run.py index f02a3960..4b5eff87 100644 --- a/experiment/run.py +++ b/experiment/run.py @@ -22,7 +22,6 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import dataclasses import logging import os @@ -38,8 +37,8 @@ from disent import metrics from disent.frameworks.framework import BaseFramework -from disent.model.ae.base import AutoEncoder -from disent.model.init import init_model_weights +from disent.nn.model.ae import AutoEncoder +from disent.nn.weights import init_model_weights from disent.util import DisentConfigurable from disent.util import make_box_str from experiment.util.callbacks import LoggerProgressCallback diff --git a/experiment/util/hydra_data.py b/experiment/util/hydra_data.py index 8c03bfd7..598b1545 100644 --- a/experiment/util/hydra_data.py +++ b/experiment/util/hydra_data.py @@ -28,7 +28,7 @@ from omegaconf import DictConfig from disent.dataset._augment_util import AugmentableDataset -from disent.transform.groundtruth import GroundTruthDatasetBatchAugment +from disent.nn.transform import GroundTruthDatasetBatchAugment from experiment.util.hydra_utils import instantiate_recursive diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index 8df74e59..c6eabeb6 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -38,10 +38,10 @@ from disent.frameworks.ae.experimental import * from disent.frameworks.vae import * from disent.frameworks.vae.experimental import * -from disent.model.ae import AutoEncoder -from disent.model.ae import DecoderConv64 -from disent.model.ae import EncoderConv64 -from disent.transform import ToStandardisedTensor +from disent.nn.model.ae import AutoEncoder +from disent.nn.model.ae import DecoderConv64 +from disent.nn.model.ae import EncoderConv64 +from disent.nn.transform import ToStandardisedTensor # ========================================================================= # diff --git a/tests/test_math.py b/tests/test_math.py index 756dd61f..12835511 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -30,7 +30,7 @@ from disent.data.groundtruth import XYSquaresData from disent.dataset.groundtruth import GroundTruthDataset -from disent.transform import ToStandardisedTensor +from disent.nn.transform import ToStandardisedTensor from disent.util.math import torch_conv2d_channel_wise from disent.util.math import torch_conv2d_channel_wise_fft from disent.util import to_numpy diff --git a/tests/test_transform.py b/tests/test_transform.py index 0e901acf..8d91657a 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -24,8 +24,8 @@ import pytest import torch -from disent.transform._augment import _expand_to_min_max_tuples -from disent.transform._augment import FftGaussianBlur +from disent.nn.transform import FftGaussianBlur +from disent.nn.transform._augment import _expand_to_min_max_tuples from disent.util.math import torch_gaussian_kernel from disent.util.math import torch_gaussian_kernel_2d @@ -58,6 +58,7 @@ def test_fft_guassian_blur_sigmas(): with pytest.raises(Exception): _expand_to_min_max_tuples([0.0, [1.0, 2.0]]) + def test_fft_guassian_blur(): fn = FftGaussianBlur(sigma=1.0, truncate=3.0) fn(torch.randn(256, 3, 64, 64)) From 5035b4edbba093f784d0cee47bd9e52e69548d32 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Wed, 2 Jun 2021 00:44:23 +0200 Subject: [PATCH 13/34] dataset fixes --- disent/dataset/__init__.py | 7 +++++-- disent/dataset/{_augment_util.py => _base.py} | 13 +++++++------ disent/dataset/groundtruth/_single.py | 12 +++++++++--- disent/dataset/random/_random_dataset.py | 7 +++---- disent/util/__init__.py | 4 ++-- 5 files changed, 26 insertions(+), 17 deletions(-) rename disent/dataset/{_augment_util.py => _base.py} (97%) diff --git a/disent/dataset/__init__.py b/disent/dataset/__init__.py index 78634246..9ee9532f 100644 --- a/disent/dataset/__init__.py +++ b/disent/dataset/__init__.py @@ -1,5 +1,3 @@ - - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # MIT License # @@ -23,3 +21,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + + +# expose base dataset +from disent.dataset._base import DisentDataset + diff --git a/disent/dataset/_augment_util.py b/disent/dataset/_base.py similarity index 97% rename from disent/dataset/_augment_util.py rename to disent/dataset/_base.py index 6740dfe9..869071aa 100644 --- a/disent/dataset/_augment_util.py +++ b/disent/dataset/_base.py @@ -22,19 +22,23 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import numpy as np from abc import abstractmethod -from typing import Optional, List +from typing import List +from typing import Optional +import numpy as np +from torch.utils.data import Dataset from torch.utils.data.dataloader import default_collate +from disent.util import LengthIter + # ========================================================================= # # util # # ========================================================================= # -class AugmentableDataset(object): +class DisentDataset(Dataset, LengthIter): @property @abstractmethod @@ -49,9 +53,6 @@ def augment(self) -> Optional[callable]: def _get_augmentable_observation(self, idx): raise NotImplementedError - def __len__(self): - raise NotImplementedError - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Single Datapoints # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # diff --git a/disent/dataset/groundtruth/_single.py b/disent/dataset/groundtruth/_single.py index 0e1d17db..7aa73c58 100644 --- a/disent/dataset/groundtruth/_single.py +++ b/disent/dataset/groundtruth/_single.py @@ -24,11 +24,12 @@ import logging from typing import Tuple + import numpy as np -from torch.utils.data import Dataset from torch.utils.data.dataloader import default_collate + from disent.data.groundtruth.base import GroundTruthData -from disent.dataset._augment_util import AugmentableDataset +from disent.dataset import DisentDataset log = logging.getLogger(__name__) @@ -39,7 +40,7 @@ # ========================================================================= # -class GroundTruthDataset(Dataset, GroundTruthData, AugmentableDataset): +class GroundTruthDataset(DisentDataset, GroundTruthData): # TODO: these transformations should be a wrapper around any dataset. # for example: dataset = AugmentedDataset(GroundTruthDataset(XYGridData())) @@ -111,6 +112,11 @@ def dataset_sample_batch(self, num_samples: int, mode: str): # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # +# ========================================================================= # +# EXTRA # +# ========================================================================= # + + class GroundTruthDatasetAndFactors(GroundTruthDataset): def dataset_get_observation(self, *idxs): return { diff --git a/disent/dataset/random/_random_dataset.py b/disent/dataset/random/_random_dataset.py index 475bf315..c0b45fba 100644 --- a/disent/dataset/random/_random_dataset.py +++ b/disent/dataset/random/_random_dataset.py @@ -21,12 +21,11 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + from typing import Sequence import numpy as np -from torch.utils.data import Dataset -from disent.dataset._augment_util import AugmentableDataset -from disent.util import LengthIter +from disent.dataset import DisentDataset # ========================================================================= # @@ -34,7 +33,7 @@ # ========================================================================= # -class RandomDataset(Dataset, LengthIter, AugmentableDataset): +class RandomDataset(DisentDataset): def __init__( self, diff --git a/disent/util/__init__.py b/disent/util/__init__.py index 4fadebd8..857b98d3 100644 --- a/disent/util/__init__.py +++ b/disent/util/__init__.py @@ -272,10 +272,10 @@ def __iter__(self): yield self[i] def __len__(self): - raise NotImplemented() + raise NotImplementedError() def __getitem__(self, item): - raise NotImplemented() + raise NotImplementedError() # ========================================================================= # From 006a225546f391522c0e5a3a769e62f7e71c0221 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Wed, 2 Jun 2021 00:50:32 +0200 Subject: [PATCH 14/34] cleaned up frameworks imports --- disent/frameworks/__init__.py | 5 +++++ disent/frameworks/{framework.py => _framework.py} | 2 +- disent/frameworks/ae/_unsupervised__ae.py | 6 +++--- disent/nn/model/ae/_conv64.py | 4 ++-- disent/nn/model/ae/_conv64_alt.py | 4 ++-- disent/nn/model/ae/_fc.py | 4 ++-- experiment/run.py | 6 +++--- 7 files changed, 18 insertions(+), 13 deletions(-) rename disent/frameworks/{framework.py => _framework.py} (99%) diff --git a/disent/frameworks/__init__.py b/disent/frameworks/__init__.py index 9a05a479..5c9bdaf1 100644 --- a/disent/frameworks/__init__.py +++ b/disent/frameworks/__init__.py @@ -21,3 +21,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + + +# export +from disent.frameworks._framework import DisentConfigurable +from disent.frameworks._framework import DisentFramework diff --git a/disent/frameworks/framework.py b/disent/frameworks/_framework.py similarity index 99% rename from disent/frameworks/framework.py rename to disent/frameworks/_framework.py index 20d2a3c1..081575d1 100644 --- a/disent/frameworks/framework.py +++ b/disent/frameworks/_framework.py @@ -76,7 +76,7 @@ def __init__(self, cfg: cfg = cfg()): # ========================================================================= # -class BaseFramework(DisentConfigurable, DisentLightningModule): +class DisentFramework(DisentConfigurable, DisentLightningModule): @dataclass class cfg(DisentConfigurable.cfg): diff --git a/disent/frameworks/ae/_unsupervised__ae.py b/disent/frameworks/ae/_unsupervised__ae.py index 091db822..efc275d4 100644 --- a/disent/frameworks/ae/_unsupervised__ae.py +++ b/disent/frameworks/ae/_unsupervised__ae.py @@ -35,7 +35,7 @@ import torch -from disent.frameworks.framework import BaseFramework +from disent.frameworks import DisentFramework from disent.frameworks.helper.reconstructions import make_reconstruction_loss from disent.frameworks.helper.reconstructions import ReconLossHandler from disent.frameworks.helper.util import detach_all @@ -51,7 +51,7 @@ # ========================================================================= # -class Ae(BaseFramework): +class Ae(DisentFramework): """ Basic Auto Encoder ------------------ @@ -77,7 +77,7 @@ class Ae(BaseFramework): REQUIRED_OBS = 1 @dataclass - class cfg(BaseFramework.cfg): + class cfg(DisentFramework.cfg): recon_loss: str = 'mse' # multiple reduction modes exist for the various loss components. # - 'sum': sum over the entire batch diff --git a/disent/nn/model/ae/_conv64.py b/disent/nn/model/ae/_conv64.py index 29696214..c46ecffa 100644 --- a/disent/nn/model/ae/_conv64.py +++ b/disent/nn/model/ae/_conv64.py @@ -45,7 +45,7 @@ def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1): """ Convolutional encoder used in beta-VAE paper for the chairs data. Based on row 3 of Table 1 on page 13 of "beta-VAE: Learning Basic Visual - Concepts with a Constrained Variational BaseFramework" + Concepts with a Constrained Variational Framework" (https://openreview.net/forum?id=Sy2fzU9gl) """ # checks @@ -83,7 +83,7 @@ def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1): """ Convolutional decoder used in beta-VAE paper for the chairs data. Based on row 3 of Table 1 on page 13 of "beta-VAE: Learning Basic Visual - Concepts with a Constrained Variational BaseFramework" + Concepts with a Constrained Variational Framework" (https://openreview.net/forum?id=Sy2fzU9gl) """ assert tuple(x_shape[1:]) == (64, 64), 'This model only works with image size 64x64.' diff --git a/disent/nn/model/ae/_conv64_alt.py b/disent/nn/model/ae/_conv64_alt.py index 2e9e9c0b..a735eaa2 100644 --- a/disent/nn/model/ae/_conv64_alt.py +++ b/disent/nn/model/ae/_conv64_alt.py @@ -73,7 +73,7 @@ def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1, activation='le """ Convolutional encoder used in beta-VAE paper for the chairs data. Based on row 3 of Table 1 on page 13 of "beta-VAE: Learning Basic Visual - Concepts with a Constrained Variational BaseFramework" + Concepts with a Constrained Variational Framework" (https://openreview.net/forum?id=Sy2fzU9gl) """ # checks @@ -111,7 +111,7 @@ def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1, activation='le """ Convolutional decoder used in beta-VAE paper for the chairs data. Based on row 3 of Table 1 on page 13 of "beta-VAE: Learning Basic Visual - Concepts with a Constrained Variational BaseFramework" + Concepts with a Constrained Variational Framework" (https://openreview.net/forum?id=Sy2fzU9gl) """ assert tuple(x_shape[1:]) == (64, 64), 'This model only works with image size 64x64.' diff --git a/disent/nn/model/ae/_fc.py b/disent/nn/model/ae/_fc.py index c378a7e8..0142d62d 100644 --- a/disent/nn/model/ae/_fc.py +++ b/disent/nn/model/ae/_fc.py @@ -46,7 +46,7 @@ def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1): """ Fully connected encoder used in beta-VAE paper for the dSprites data. Based on row 1 of Table 1 on page 13 of "beta-VAE: Learning Basic Visual - Concepts with a Constrained Variational BaseFramework" + Concepts with a Constrained Variational Framework" (https://openreview.net/forum?id=Sy2fzU9gl). """ # checks @@ -76,7 +76,7 @@ def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1): """ Fully connected encoder used in beta-VAE paper for the dSprites data. Based on row 1 of Table 1 on page 13 of "beta-VAE: Learning Basic Visual - Concepts with a Constrained Variational BaseFramework" + Concepts with a Constrained Variational Framework" (https://openreview.net/forum?id=Sy2fzU9gl) """ super().__init__(x_shape=x_shape, z_size=z_size, z_multiplier=z_multiplier) diff --git a/experiment/run.py b/experiment/run.py index 4b5eff87..056a5c83 100644 --- a/experiment/run.py +++ b/experiment/run.py @@ -36,10 +36,10 @@ from pytorch_lightning.loggers import WandbLogger from disent import metrics -from disent.frameworks.framework import BaseFramework +from disent.frameworks import DisentConfigurable +from disent.frameworks import DisentFramework from disent.nn.model.ae import AutoEncoder from disent.nn.weights import init_model_weights -from disent.util import DisentConfigurable from disent.util import make_box_str from experiment.util.callbacks import LoggerProgressCallback from experiment.util.callbacks import VaeDisentanglementLoggingCallback @@ -195,7 +195,7 @@ def hydra_append_correlation_callback(callbacks, cfg): )) -def hydra_register_schedules(module: BaseFramework, cfg): +def hydra_register_schedules(module: DisentFramework, cfg): if cfg.schedules is None: cfg.schedules = {} if cfg.schedules: From 7738e6252247cbc6e15228e2c008356ef72b16ef Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Wed, 2 Jun 2021 01:03:40 +0200 Subject: [PATCH 15/34] moved models back into their own folder --- README.md | 3 ++- disent/frameworks/ae/_unsupervised__ae.py | 2 +- disent/{nn => }/model/__init__.py | 8 +++--- .../{nn/model/ae/base.py => model/_base.py} | 6 +++-- disent/{nn => }/model/ae/__init__.py | 25 ++++++++----------- disent/{nn => }/model/ae/_conv64.py | 3 ++- disent/{nn => }/model/ae/_conv64_alt.py | 3 ++- disent/{nn => }/model/ae/_fc.py | 3 ++- disent/{nn => }/model/ae/_simpleconv64.py | 3 ++- disent/{nn => }/model/ae/_simplefc.py | 3 ++- docs/examples/mnist_example.py | 3 ++- docs/examples/overview_framework_adagvae.py | 3 ++- docs/examples/overview_framework_ae.py | 3 ++- docs/examples/overview_framework_betavae.py | 3 ++- .../overview_framework_betavae_scheduled.py | 3 ++- docs/examples/overview_metrics.py | 3 ++- experiment/run.py | 2 +- tests/test_frameworks.py | 6 ++--- 18 files changed, 47 insertions(+), 38 deletions(-) rename disent/{nn => }/model/__init__.py (89%) rename disent/{nn/model/ae/base.py => model/_base.py} (98%) rename disent/{nn => }/model/ae/__init__.py (64%) rename disent/{nn => }/model/ae/_conv64.py (98%) rename disent/{nn => }/model/ae/_conv64_alt.py (98%) rename disent/{nn => }/model/ae/_fc.py (98%) rename disent/{nn => }/model/ae/_simpleconv64.py (98%) rename disent/{nn => }/model/ae/_simplefc.py (97%) diff --git a/README.md b/README.md index 55b76407..bd8c52f7 100644 --- a/README.md +++ b/README.md @@ -241,7 +241,8 @@ from disent.data.groundtruth import XYObjectData from disent.dataset.groundtruth import GroundTruthDataset from disent.frameworks.vae import BetaVae from disent.metrics import metric_dci, metric_mig -from disent.nn.model.ae import AutoEncoder, EncoderConv64, DecoderConv64 +from disent.model.ae import EncoderConv64, DecoderConv64 +from disent.model import AutoEncoder from disent.nn.transform import ToStandardisedTensor from disent.schedule import CyclicSchedule diff --git a/disent/frameworks/ae/_unsupervised__ae.py b/disent/frameworks/ae/_unsupervised__ae.py index efc275d4..69a17615 100644 --- a/disent/frameworks/ae/_unsupervised__ae.py +++ b/disent/frameworks/ae/_unsupervised__ae.py @@ -39,7 +39,7 @@ from disent.frameworks.helper.reconstructions import make_reconstruction_loss from disent.frameworks.helper.reconstructions import ReconLossHandler from disent.frameworks.helper.util import detach_all -from disent.nn.model.ae import AutoEncoder +from disent.model import AutoEncoder from disent.util import map_all diff --git a/disent/nn/model/__init__.py b/disent/model/__init__.py similarity index 89% rename from disent/nn/model/__init__.py rename to disent/model/__init__.py index fdc7c167..133b7f1b 100644 --- a/disent/nn/model/__init__.py +++ b/disent/model/__init__.py @@ -22,7 +22,7 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# disent base modules -from disent.nn.model.ae import AutoEncoder -from disent.nn.model.ae import DisentEncoder -from disent.nn.model.ae import DisentDecoder +# encoders & decoders +from disent.model._base import AutoEncoder +from disent.model._base import DisentEncoder +from disent.model._base import DisentDecoder diff --git a/disent/nn/model/ae/base.py b/disent/model/_base.py similarity index 98% rename from disent/nn/model/ae/base.py rename to disent/model/_base.py index 9f9ee32a..c8056971 100644 --- a/disent/nn/model/ae/base.py +++ b/disent/model/_base.py @@ -22,6 +22,7 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + import logging from typing import final @@ -35,7 +36,7 @@ # ========================================================================= # -# Custom Base nn.Module # +# Custom Base Module Involving Inputs & Representations # # ========================================================================= # @@ -103,7 +104,7 @@ def encode(self, x) -> Tensor: class DisentDecoder(DisentLatentsBase): - + def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1): assert z_multiplier == 1, 'decoder does not support z_multiplier != 1' super().__init__(x_shape=x_shape, z_size=z_size, z_multiplier=z_multiplier) @@ -168,3 +169,4 @@ def decode(self, z: Tensor) -> Tensor: # ========================================================================= # # END # # ========================================================================= # + diff --git a/disent/nn/model/ae/__init__.py b/disent/model/ae/__init__.py similarity index 64% rename from disent/nn/model/ae/__init__.py rename to disent/model/ae/__init__.py index fd679e8b..59ed8667 100644 --- a/disent/nn/model/ae/__init__.py +++ b/disent/model/ae/__init__.py @@ -23,18 +23,13 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # encoders & decoders -from disent.nn.model.ae._conv64 import DecoderConv64 -from disent.nn.model.ae._conv64 import EncoderConv64 -from disent.nn.model.ae._conv64_alt import DecoderConv64Alt -from disent.nn.model.ae._conv64_alt import EncoderConv64Alt -from disent.nn.model.ae._fc import DecoderFC -from disent.nn.model.ae._fc import EncoderFC -from disent.nn.model.ae._simpleconv64 import DecoderSimpleConv64 -from disent.nn.model.ae._simpleconv64 import EncoderSimpleConv64 -from disent.nn.model.ae._simplefc import DecoderSimpleFC -from disent.nn.model.ae._simplefc import EncoderSimpleFC - -# auto-encoder wrapper -from disent.nn.model.ae.base import AutoEncoder -from disent.nn.model.ae.base import DisentEncoder -from disent.nn.model.ae.base import DisentDecoder +from disent.model.ae._conv64 import DecoderConv64 +from disent.model.ae._conv64 import EncoderConv64 +from disent.model.ae._conv64_alt import DecoderConv64Alt +from disent.model.ae._conv64_alt import EncoderConv64Alt +from disent.model.ae._fc import DecoderFC +from disent.model.ae._fc import EncoderFC +from disent.model.ae._simpleconv64 import DecoderSimpleConv64 +from disent.model.ae._simpleconv64 import EncoderSimpleConv64 +from disent.model.ae._simplefc import DecoderSimpleFC +from disent.model.ae._simplefc import EncoderSimpleFC diff --git a/disent/nn/model/ae/_conv64.py b/disent/model/ae/_conv64.py similarity index 98% rename from disent/nn/model/ae/_conv64.py rename to disent/model/ae/_conv64.py index c46ecffa..0917e8a3 100644 --- a/disent/nn/model/ae/_conv64.py +++ b/disent/model/ae/_conv64.py @@ -24,7 +24,8 @@ from torch import nn as nn, Tensor -from disent.nn.model.ae.base import DisentEncoder, DisentDecoder +from disent.model import DisentDecoder +from disent.model import DisentEncoder from disent.nn.modules import Flatten3D from disent.nn.modules import BatchView diff --git a/disent/nn/model/ae/_conv64_alt.py b/disent/model/ae/_conv64_alt.py similarity index 98% rename from disent/nn/model/ae/_conv64_alt.py rename to disent/model/ae/_conv64_alt.py index a735eaa2..921293e5 100644 --- a/disent/nn/model/ae/_conv64_alt.py +++ b/disent/model/ae/_conv64_alt.py @@ -24,7 +24,8 @@ from torch import nn as nn, Tensor -from disent.nn.model.ae.base import DisentEncoder, DisentDecoder +from disent.model import DisentDecoder +from disent.model import DisentEncoder from disent.nn.modules import Flatten3D from disent.nn.modules import BatchView diff --git a/disent/nn/model/ae/_fc.py b/disent/model/ae/_fc.py similarity index 98% rename from disent/nn/model/ae/_fc.py rename to disent/model/ae/_fc.py index 0142d62d..ed2ebb1f 100644 --- a/disent/nn/model/ae/_fc.py +++ b/disent/model/ae/_fc.py @@ -25,7 +25,8 @@ import numpy as np from torch import nn as nn, Tensor -from disent.nn.model.ae.base import DisentEncoder, DisentDecoder +from disent.model import DisentDecoder +from disent.model import DisentEncoder from disent.nn.modules import Flatten3D from disent.nn.modules import BatchView diff --git a/disent/nn/model/ae/_simpleconv64.py b/disent/model/ae/_simpleconv64.py similarity index 98% rename from disent/nn/model/ae/_simpleconv64.py rename to disent/model/ae/_simpleconv64.py index dff42f5c..9312935e 100644 --- a/disent/nn/model/ae/_simpleconv64.py +++ b/disent/model/ae/_simpleconv64.py @@ -24,7 +24,8 @@ from torch import nn as nn, Tensor -from disent.nn.model.ae.base import DisentEncoder, DisentDecoder +from disent.model import DisentDecoder +from disent.model import DisentEncoder from disent.nn.modules import Flatten3D from disent.nn.modules import Unsqueeze3D diff --git a/disent/nn/model/ae/_simplefc.py b/disent/model/ae/_simplefc.py similarity index 97% rename from disent/nn/model/ae/_simplefc.py rename to disent/model/ae/_simplefc.py index f3e6b2d0..86dbb0e1 100644 --- a/disent/nn/model/ae/_simplefc.py +++ b/disent/model/ae/_simplefc.py @@ -24,7 +24,8 @@ from torch import nn as nn, Tensor -from disent.nn.model.ae.base import DisentEncoder, DisentDecoder +from disent.model import DisentDecoder +from disent.model import DisentEncoder from disent.nn.modules import Flatten3D from disent.nn.modules import BatchView diff --git a/docs/examples/mnist_example.py b/docs/examples/mnist_example.py index e8e81919..36a52249 100644 --- a/docs/examples/mnist_example.py +++ b/docs/examples/mnist_example.py @@ -9,7 +9,8 @@ from disent.dataset.random import RandomDataset from disent.frameworks.vae import AdaVae -from disent.nn.model.ae import AutoEncoder, DecoderConv64Alt, EncoderConv64Alt +from disent.model.ae import DecoderConv64Alt, EncoderConv64Alt +from disent.model import AutoEncoder from disent.nn.transform import ToStandardisedTensor from disent.util import is_test_run diff --git a/docs/examples/overview_framework_adagvae.py b/docs/examples/overview_framework_adagvae.py index 1da81ea5..7f7412ff 100644 --- a/docs/examples/overview_framework_adagvae.py +++ b/docs/examples/overview_framework_adagvae.py @@ -4,7 +4,8 @@ from disent.data.groundtruth import GroundTruthData, XYSquaresData from disent.dataset.groundtruth import GroundTruthDatasetOrigWeakPairs from disent.frameworks.vae import AdaVae -from disent.nn.model.ae import AutoEncoder, DecoderConv64, EncoderConv64 +from disent.model.ae import DecoderConv64, EncoderConv64 +from disent.model import AutoEncoder from disent.nn.transform import ToStandardisedTensor from disent.util import is_test_run diff --git a/docs/examples/overview_framework_ae.py b/docs/examples/overview_framework_ae.py index 04696479..7d8404c6 100644 --- a/docs/examples/overview_framework_ae.py +++ b/docs/examples/overview_framework_ae.py @@ -4,7 +4,8 @@ from disent.data.groundtruth import GroundTruthData, XYSquaresData from disent.dataset.groundtruth import GroundTruthDataset from disent.frameworks.ae import Ae -from disent.nn.model.ae import AutoEncoder, DecoderConv64, EncoderConv64 +from disent.model.ae import DecoderConv64, EncoderConv64 +from disent.model import AutoEncoder from disent.nn.transform import ToStandardisedTensor from disent.util import is_test_run diff --git a/docs/examples/overview_framework_betavae.py b/docs/examples/overview_framework_betavae.py index 60b4ff1d..11289190 100644 --- a/docs/examples/overview_framework_betavae.py +++ b/docs/examples/overview_framework_betavae.py @@ -4,7 +4,8 @@ from disent.data.groundtruth import GroundTruthData, XYSquaresData from disent.dataset.groundtruth import GroundTruthDataset from disent.frameworks.vae import BetaVae -from disent.nn.model.ae import AutoEncoder, DecoderConv64, EncoderConv64 +from disent.model.ae import DecoderConv64, EncoderConv64 +from disent.model import AutoEncoder from disent.nn.transform import ToStandardisedTensor from disent.util import is_test_run diff --git a/docs/examples/overview_framework_betavae_scheduled.py b/docs/examples/overview_framework_betavae_scheduled.py index 79e3df58..00a37305 100644 --- a/docs/examples/overview_framework_betavae_scheduled.py +++ b/docs/examples/overview_framework_betavae_scheduled.py @@ -4,7 +4,8 @@ from disent.data.groundtruth import GroundTruthData, XYSquaresData from disent.dataset.groundtruth import GroundTruthDataset from disent.frameworks.vae import BetaVae -from disent.nn.model.ae import AutoEncoder, DecoderConv64, EncoderConv64 +from disent.model.ae import DecoderConv64, EncoderConv64 +from disent.model import AutoEncoder from disent.nn.transform import ToStandardisedTensor from disent.schedule import CyclicSchedule from disent.util import is_test_run diff --git a/docs/examples/overview_metrics.py b/docs/examples/overview_metrics.py index 2ee55a86..96312e35 100644 --- a/docs/examples/overview_metrics.py +++ b/docs/examples/overview_metrics.py @@ -5,7 +5,8 @@ from disent.dataset.groundtruth import GroundTruthDataset from disent.frameworks.vae import BetaVae from disent.metrics import metric_dci, metric_mig -from disent.nn.model.ae import AutoEncoder, DecoderConv64, EncoderConv64 +from disent.model.ae import DecoderConv64, EncoderConv64 +from disent.model import AutoEncoder from disent.nn.transform import ToStandardisedTensor from disent.util import is_test_run diff --git a/experiment/run.py b/experiment/run.py index 056a5c83..592b5d51 100644 --- a/experiment/run.py +++ b/experiment/run.py @@ -38,7 +38,7 @@ from disent import metrics from disent.frameworks import DisentConfigurable from disent.frameworks import DisentFramework -from disent.nn.model.ae import AutoEncoder +from disent.model import AutoEncoder from disent.nn.weights import init_model_weights from disent.util import make_box_str from experiment.util.callbacks import LoggerProgressCallback diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index c6eabeb6..d3ca0d19 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -38,9 +38,9 @@ from disent.frameworks.ae.experimental import * from disent.frameworks.vae import * from disent.frameworks.vae.experimental import * -from disent.nn.model.ae import AutoEncoder -from disent.nn.model.ae import DecoderConv64 -from disent.nn.model.ae import EncoderConv64 +from disent.model import AutoEncoder +from disent.model.ae import DecoderConv64 +from disent.model.ae import EncoderConv64 from disent.nn.transform import ToStandardisedTensor From 64535d91edec45370934e8546182f293e25edfc8 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Wed, 2 Jun 2021 02:45:58 +0200 Subject: [PATCH 16/34] working norb dataset --- disent/data/groundtruth/_dsprites.py | 14 +-- disent/data/groundtruth/_norb.py | 157 +++++++++++++-------------- disent/data/groundtruth/_shapes3d.py | 12 +- disent/data/groundtruth/base.py | 80 +++++++++----- disent/data/util/jobs.py | 6 +- 5 files changed, 137 insertions(+), 132 deletions(-) diff --git a/disent/data/groundtruth/_dsprites.py b/disent/data/groundtruth/_dsprites.py index 126fa050..04da491f 100644 --- a/disent/data/groundtruth/_dsprites.py +++ b/disent/data/groundtruth/_dsprites.py @@ -54,22 +54,16 @@ class DSpritesData(Hdf5GroundTruthData): observation_shape = (64, 64, 1) data_object = DlH5DataObject( - # processed dataset file - file_name='dsprites_ndarray_co1sh3sc6or40x32y32_64x64.hdf5', - file_hashes={'fast': '6d6d43d5f4d5c08c4b99a406289b8ecd', 'full': '1473ac1e1af7fdbc910766b3f9157f7b'}, # download file/link uri='https://raw.githubusercontent.com/deepmind/dsprites-dataset/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.hdf5', - uri_hashes={'fast': 'd6ee1e43db715c2f0de3c41e38863347', 'full': 'b331c4447a651c44bf5e8ae09022e230'}, - # hash settings - hash_mode='fast', - hash_type='md5', + uri_hash={'fast': 'd6ee1e43db715c2f0de3c41e38863347', 'full': 'b331c4447a651c44bf5e8ae09022e230'}, + # processed dataset file + file_hash={'fast': '6d6d43d5f4d5c08c4b99a406289b8ecd', 'full': '1473ac1e1af7fdbc910766b3f9157f7b'}, # h5 re-save settings hdf5_dataset_name='imgs', hdf5_chunk_size=(1, 64, 64), - hdf5_compression='gzip', - hdf5_compression_lvl=4, hdf5_dtype='uint8', - hdf5_mutator=lambda x: x # lambda batch: batch * 255 + hdf5_mutator=lambda x: x * 255 ) diff --git a/disent/data/groundtruth/_norb.py b/disent/data/groundtruth/_norb.py index 2e5fadac..a923841e 100644 --- a/disent/data/groundtruth/_norb.py +++ b/disent/data/groundtruth/_norb.py @@ -21,13 +21,15 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import dataclasses + import gzip -from typing import Dict +from typing import Optional +from typing import Sequence +from typing import Tuple import numpy as np -from disent.data.groundtruth import GroundTruthData +from disent.data.groundtruth.base import DiskGroundTruthData from disent.data.groundtruth.base import DlDataObject @@ -47,7 +49,7 @@ } -def read_binary_matrix_buffer(buffer): +def read_binary_matrix_bytes(bytes): """ Read the binary matrix data - modified from disentanglement_lib @@ -64,61 +66,59 @@ def read_binary_matrix_buffer(buffer): - Little endian matrix data comes after the header, the index of the last dimension changes the fastest. """ - dtype = int(np.frombuffer(buffer, "int32", 1, 0)) # bytes [0, 4) - ndim = int(np.frombuffer(buffer, "int32", 1, 4)) # bytes [4, 8) - eff_dim = max(3, ndim) # stores minimum of 3 dimensions even for 1D array - dims = np.frombuffer(buffer, "int32", eff_dim, 8)[0:ndim] # bytes [8, 8 + eff_dim * 4) - data = np.frombuffer(buffer, _BINARY_MATRIX_TYPES[dtype], offset=8 + eff_dim * 4) + # header: dtype, ndim, dim_sizes + dtype = int(np.frombuffer(bytes, dtype='int32', count=1, offset=0)) # bytes [0, 4) + ndim = int(np.frombuffer(bytes, dtype='int32', count=1, offset=4)) # bytes [4, 8) + stored_ndim = max(3, ndim) # stores minimum of 3 dimensions even for 1D array + dims = np.frombuffer(bytes, dtype='int32', count=stored_ndim, offset=8)[0:ndim] # bytes [8, 8 + eff_dim * 4) + # matrix: data + data = np.frombuffer(bytes, dtype=_BINARY_MATRIX_TYPES[dtype], count=-1, offset=8 + stored_ndim * 4) data = data.reshape(tuple(dims)) + # done return data def read_binary_matrix_file(file, gzipped: bool = True): + # this does not seem to copy the bytes, which saves memory with (gzip.open if gzipped else open)(file, "rb") as f: - return read_binary_matrix_buffer(buffer=f) - - -def resave_binary_matrix_file(inp_path, out_path, gzipped: bool = True): - with AtomicFileContext(out_path, open_mode=None) as temp_out_path: - data = read_binary_matrix_file(file=inp_path, gzipped=gzipped) - np.savez(temp_out_path, data=data) + return read_binary_matrix_bytes(bytes=f.read()) # ========================================================================= # -# Norb Data Tasks # +# Norb Functions # # ========================================================================= # -@dataclasses.dataclass -class BinaryMatrixDataObject(DlDataObject): - file_name: str - file_hashes: Dict[str, str] - # download file/link - uri: str - uri_hashes: Dict[str, str] - # hash settings - hash_mode: str - hash_type: str - - def _make_h5_job(self, load_path: str, save_path: str): - return CachedJobFile( - make_file_fn=lambda path: resave_binary_matrix_file( - inp_path=load_path, - out_path=path, - gzipped=True, - ), - path=save_path, - hash=self.file_hashes[self.hash_mode], - hash_type=self.hash_type, - hash_mode=self.hash_mode, - ) - - def prepare(self, data_dir: str): - dl_path = self.get_file_path(data_dir=data_dir, variant='ORIG') - h5_path = self.get_file_path(data_dir=data_dir) - dl_job = self._make_dl_job(save_path=dl_path) - h5_job = self._make_h5_job(load_path=dl_path, save_path=h5_path) - dl_job.set_child(h5_job).run() +def read_norb_dataset(dat_path: str, cat_path: str, info_path: str, gzipped=True, sort=True) -> Tuple[np.ndarray, np.ndarray]: + """ + Load The Normalised Dataset + * dat: + - images (5 categories, 5 instances, 6 lightings, 9 elevations, and 18 azimuths) + * cat: + - initial ground truth factor: + 0. category of images (0 for animal, 1 for human, 2 for plane, 3 for truck, 4 for car). + * info: + - additional ground truth factors: + 1. the instance in the category (0 to 9) + 2. the elevation (0 to 8, which mean cameras are 30, 35,40,45,50,55,60,65,70 degrees from the horizontal respectively) + 3. the azimuth (0,2,4,...,34, multiply by 10 to get the azimuth in degrees) + 4. the lighting condition (0 to 5) + """ + # read the dataset + dat = read_binary_matrix_file(dat_path, gzipped=gzipped) + cat = read_binary_matrix_file(cat_path, gzipped=gzipped) + info = read_binary_matrix_file(info_path, gzipped=gzipped) + # collect the ground truth factors + factors = np.column_stack([cat, info]) # append info to categories + factors[:, 3] = factors[:, 3] / 2 # azimuth values are even numbers, convert to indices + images = dat[:, 0] # images are in pairs, only use the first. TODO: what is the second of each? + # order the images and factors + if sort: + indices = np.lexsort(factors[:, [4, 3, 2, 1, 0]].T) + images = images[indices] + factors = factors[indices] + # done! + return images, factors # ========================================================================= # @@ -126,57 +126,48 @@ def prepare(self, data_dir: str): # ========================================================================= # -class SmallNorbData(GroundTruthData): +class SmallNorbData(DiskGroundTruthData): """ Small NORB Dataset - https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/ # reference implementation: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/data/ground_truth/norb.py + # TODO: add ability to randomly sample the instance so that this corresponds to disentanglement_lib """ - # ordered training data (dat, cat, info) - NORB_TRAIN_URLS = [ - 'https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz', - 'https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz', - 'https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz', - ] - - # ordered testing data (dat, cat, info) - NORB_TEST_URLS = [ - 'https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz', - 'https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz', - 'https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz', - ] - - dataset_urls = [*NORB_TRAIN_URLS, *NORB_TEST_URLS] - - # TODO: add ability to randomly sample the instance so - # that this corresponds to disentanglement_lib - factor_names = ('category', 'instance', 'elevation', 'azimuth', 'lighting_condition') + name = 'smallnorb' + + factor_names = ('category', 'instance', 'elevation', 'rotation', 'lighting') factor_sizes = (5, 5, 9, 18, 6) # TOTAL: 24300 observation_shape = (96, 96, 1) - def __init__(self, data_dir='data/dataset/smallnorb', force_download=False, is_test=False): - super().__init__(data_dir=data_dir, force_download=force_download) - assert not is_test, 'Test set not yet supported' + TRAIN_DATA_OBJECTS = { + 'dat': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz', uri_hash={'fast': '92560cccc7bcbd6512805e435448b62d', 'full': '66054832f9accfe74a0f4c36a75bc0a2'}), + 'cat': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz', uri_hash={'fast': '348fc3ccefd651d69f500611988b5dcd', 'full': '23c8b86101fbf0904a000b43d3ed2fd9'}), + 'info': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz', uri_hash={'fast': 'f1b170c16925867c05f58608eb33ba7f', 'full': '51dee1210a742582ff607dfd94e332e3'}), + } + + TEST_DATA_OBJECTS = { + 'dat': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz', uri_hash={'fast': '9aee0b474a4fc2a2ec392b463efb8858', 'full': 'e4ad715691ed5a3a5f138751a4ceb071'}), + 'cat': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz', uri_hash={'fast': '8cfae0679f5fa2df7a0aedfce90e5673', 'full': '5aa791cd7e6016cf957ce9bdb93b8603'}), + 'info': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz', uri_hash={'fast': 'd2703a3f95e7b9a970ad52e91f0aaf6a', 'full': 'a9454f3864d7fd4bb3ea7fc3eb84924e'}), + } + + def __init__(self, data_root: Optional[str] = 'data/TEMP/dataset', prepare: bool = False, is_test=True): + self._is_test = is_test + # initialize + super().__init__(data_root=data_root, prepare=prepare) # read dataset and sort by features - images, features = self._read_norb_set(is_test) - indices = np.lexsort(features[:, [4, 3, 2, 1, 0]].T) - self._data = images[indices] + dat_path, cat_path, info_path = (obj.get_file_path(data_dir=self.data_dir) for obj in self.data_objects) + self._data, _ = read_norb_dataset(dat_path=dat_path, cat_path=cat_path, info_path=info_path) def __getitem__(self, idx): return self._data[idx] - def _read_norb_set(self, is_test): - # get file data corresponding to urls - dat, cat, info = [ - self._read_norb_file(self.dataset_paths[self.dataset_urls.index(url)]) - for url in (self.NORB_TEST_URLS if is_test else self.NORB_TRAIN_URLS) - ] - features = np.column_stack([cat, info]) # append info to categories - features[:, 3] = features[:, 3] / 2 # azimuth values are even numbers, convert to indices - images = dat[:, 0] # images are in pairs, we only extract the first one of each - return images, features + @property + def data_objects(self) -> Sequence[DlDataObject]: + norb_objects = self.TEST_DATA_OBJECTS if self._is_test else self.TRAIN_DATA_OBJECTS + return norb_objects['dat'], norb_objects['cat'], norb_objects['info'] # ========================================================================= # diff --git a/disent/data/groundtruth/_shapes3d.py b/disent/data/groundtruth/_shapes3d.py index c198f724..3eaf8628 100644 --- a/disent/data/groundtruth/_shapes3d.py +++ b/disent/data/groundtruth/_shapes3d.py @@ -50,20 +50,14 @@ class Shapes3dData(Hdf5GroundTruthData): observation_shape = (64, 64, 3) data_object = DlH5DataObject( - # processed dataset file - file_name='3dshapes.h5', - file_hashes={'fast': 'e3a1a449b95293d4b2c25edbfcb8e804', 'full': 'b5187ee0d8b519bb33281c5ca549658c'}, # download file/link uri='https://storage.googleapis.com/3d-shapes/3dshapes.h5', - uri_hashes={'fast': '85b20ed7cc8dc1f939f7031698d2d2ab', 'full': '099a2078d58cec4daad0702c55d06868'}, - # hash settings - hash_mode='fast', - hash_type='md5', + uri_hash={'fast': '85b20ed7cc8dc1f939f7031698d2d2ab', 'full': '099a2078d58cec4daad0702c55d06868'}, + # processed dataset file + file_hash={'fast': 'e3a1a449b95293d4b2c25edbfcb8e804', 'full': 'b5187ee0d8b519bb33281c5ca549658c'}, # h5 re-save settings hdf5_dataset_name='images', hdf5_chunk_size=(1, 64, 64, 3), - hdf5_compression='gzip', - hdf5_compression_lvl=4, ) diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index 7e1bef26..a169522c 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -37,6 +37,7 @@ from disent.data.util.hdf5 import hdf5_resave_file from disent.data.util.hdf5 import PickleH5pyDataset +from disent.data.util.in_out import basename_from_url from disent.data.util.in_out import ensure_dir_exists from disent.data.util.in_out import retrieve_file from disent.data.util.jobs import CachedJobFile @@ -163,9 +164,10 @@ def data_object(self) -> 'DlH5DataObject': # ========================================================================= # -@dataclasses.dataclass class DataObject(object): - file_name: str + + def __init__(self, file_name: str): + self.file_name = file_name def prepare(self, data_dir: str): pass @@ -175,15 +177,24 @@ def get_file_path(self, data_dir: str, variant: Optional[str] = None): return os.path.join(data_dir, self.file_name + suffix) -@dataclasses.dataclass class DlDataObject(DataObject): - file_name: str - # download file/link - uri: str - uri_hashes: Dict[str, str] - # hash settings - hash_mode: str - hash_type: str + + def __init__( + self, + # download file/link + uri: str, + uri_hash: Union[str, Dict[str, str]], + # save path + file_name: Optional[str] = None, # automatically obtain file name from url if None + # hash settings + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + super().__init__(file_name=basename_from_url(uri) if (file_name is None) else file_name) + self.uri = uri + self.uri_hash = uri_hash + self.hash_mode = hash_mode + self.hash_type = hash_type def _make_dl_job(self, save_path: str): return CachedJobFile( @@ -193,7 +204,7 @@ def _make_dl_job(self, save_path: str): overwrite_existing=True, ), path=save_path, - hash=self.uri_hashes[self.hash_mode], + hash=self.uri_hash, hash_type=self.hash_type, hash_mode=self.hash_mode, ) @@ -203,23 +214,36 @@ def prepare(self, data_dir: str): dl_job.run() -@dataclasses.dataclass class DlH5DataObject(DlDataObject): - file_name: str - file_hashes: Dict[str, str] - # download file/link - uri: str - uri_hashes: Dict[str, str] - # hash settings - hash_mode: str - hash_type: str - # h5 re-save settings - hdf5_dataset_name: str - hdf5_chunk_size: Tuple[int, ...] - hdf5_compression: Optional[str] - hdf5_compression_lvl: Optional[int] - hdf5_dtype: Optional[Union[np.dtype, str]] = None - hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None + + def __init__( + self, + # download file/link + uri: str, + uri_hash: Union[str, Dict[str, str]], + # save hash + file_hash: Union[str, Dict[str, str]], + # h5 re-save settings + hdf5_dataset_name: str, + hdf5_chunk_size: Tuple[int, ...], + hdf5_compression: Optional[str] = 'gzip', + hdf5_compression_lvl: Optional[int] = 4, + hdf5_dtype: Optional[Union[np.dtype, str]] = None, + hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None, + # save path + file_name: Optional[str] = None, # automatically obtain file name from url if None + # hash settings + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + super().__init__(file_name=file_name, uri=uri, uri_hash=uri_hash, hash_mode=hash_mode, hash_type=hash_type) + self.file_hash = file_hash + self.hdf5_dataset_name = hdf5_dataset_name + self.hdf5_chunk_size = hdf5_chunk_size + self.hdf5_compression = hdf5_compression + self.hdf5_compression_lvl = hdf5_compression_lvl + self.hdf5_dtype = hdf5_dtype + self.hdf5_mutator = hdf5_mutator def _make_h5_job(self, load_path: str, save_path: str): return CachedJobFile( @@ -235,7 +259,7 @@ def _make_h5_job(self, load_path: str, save_path: str): out_mutator=self.hdf5_mutator, ), path=save_path, - hash=self.file_hashes[self.hash_mode], + hash=self.file_hash, hash_type=self.hash_type, hash_mode=self.hash_mode, ) diff --git a/disent/data/util/jobs.py b/disent/data/util/jobs.py index d8905ba9..c32c60a6 100644 --- a/disent/data/util/jobs.py +++ b/disent/data/util/jobs.py @@ -26,7 +26,9 @@ import os from abc import ABCMeta from typing import Callable +from typing import Dict from typing import NoReturn +from typing import Union from disent.data.util.in_out import hash_file @@ -111,13 +113,13 @@ def __init__( self, make_file_fn: Callable[[str], NoReturn], path: str, - hash: str, + hash: Union[str, Dict[str, str]], hash_type: str = 'md5', hash_mode: str = 'full', ): # set attributes self.path = path - self.hash = hash + self.hash: str = hash if isinstance(hash, str) else hash[hash_mode] self.hash_type = hash_type self.hash_mode = hash_mode # generate From a547cb103e76aea7b083d7fbb350b2d137473ba7 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Wed, 2 Jun 2021 03:10:18 +0200 Subject: [PATCH 17/34] optional hashes, and possibly fixed mpi3d --- disent/data/groundtruth/_cars3d.py | 4 +-- disent/data/groundtruth/_mpi3d.py | 56 ++++++++++++++++-------------- disent/data/groundtruth/_norb.py | 3 +- disent/data/groundtruth/base.py | 8 ++--- disent/data/util/jobs.py | 11 ++++-- 5 files changed, 44 insertions(+), 38 deletions(-) diff --git a/disent/data/groundtruth/_cars3d.py b/disent/data/groundtruth/_cars3d.py index 24c43fa8..95fcedab 100644 --- a/disent/data/groundtruth/_cars3d.py +++ b/disent/data/groundtruth/_cars3d.py @@ -27,7 +27,7 @@ import shutil import numpy as np from scipy.io import loadmat -from disent.data.groundtruth.base import DownloadableGroundTruthData +from disent.data.groundtruth.base import DiskGroundTruthData log = logging.getLogger(__name__) @@ -37,7 +37,7 @@ # ========================================================================= # -class Cars3dData(DownloadableGroundTruthData): +class Cars3dData(DiskGroundTruthData): """ Cars3D Dataset - Deep Visual Analogy-Making (https://papers.nips.cc/paper/5845-deep-visual-analogy-making) diff --git a/disent/data/groundtruth/_mpi3d.py b/disent/data/groundtruth/_mpi3d.py index 626576e9..4f7e882a 100644 --- a/disent/data/groundtruth/_mpi3d.py +++ b/disent/data/groundtruth/_mpi3d.py @@ -23,8 +23,13 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging +from typing import Optional +from typing import Sequence + import numpy as np -from disent.data.groundtruth.base import DownloadableGroundTruthData + +from disent.data.groundtruth.base import DiskGroundTruthData +from disent.data.groundtruth.base import DlDataObject log = logging.getLogger(__name__) @@ -33,52 +38,49 @@ # ========================================================================= # -class Mpi3dData(DownloadableGroundTruthData): +class Mpi3dData(DiskGroundTruthData): """ MPI3D Dataset - https://github.com/rr-learning/disentanglement_dataset - Files: - - toy: https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_toy.npz - - realistic: https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_realistic.npz - - real: https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_real.npz - reference implementation: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/data/ground_truth/mpi3d.py """ - URLS = { - 'toy': 'https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_toy.npz', - 'realistic': 'https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_realistic.npz', - 'real': 'https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_real.npz', + MPI3D_DATASETS = { + 'toy': DlDataObject(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_toy.npz', uri_hash=None), + 'realistic': DlDataObject(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_realistic.npz', uri_hash=None), + 'real': DlDataObject(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_real.npz', uri_hash=None), } - factor_names = ("object_color", "object_shape", "object_size", "camera_height", "background_color", "first_dof", "second_dof") + factor_names = ('object_color', 'object_shape', 'object_size', 'camera_height', 'background_color', 'first_dof', 'second_dof') factor_sizes = (4, 4, 2, 3, 3, 40, 40) # TOTAL: 460800 observation_shape = (64, 64, 3) - @property - def dataset_urls(self): - return [Mpi3dData.URLS[self.subset]] - - def __init__(self, data_dir='data/dataset/mpi3d', force_download=False, subset='realistic', in_memory=False): - # check subset - assert subset in Mpi3dData.URLS, f'Invalid subset: {subset=} must be one of: {set(Mpi3dData.URLS.values())}' - self.subset = subset - - # TODO: add support for converting to h5py for fast disk access - assert in_memory, f'{in_memory=} is not yet supported' + def __init__(self, data_root: Optional[str] = None, prepare: bool = False, subset='realistic', in_memory=False): + # check subset is correct + assert subset in self.MPI3D_DATASETS, f'Invalid MPI3D subset: {repr(subset)} must be one of: {set(self.MPI3D_DATASETS.keys())}' + self._subset = subset + # handle different cases if in_memory: log.warning('[WARNING]: mpi3d files are extremely large (over 11GB), you are trying to load these into memory.') - + else: + raise NotImplementedError('TODO: add support for converting to h5py for fast disk access') # TODO! # initialise - super().__init__(data_dir=data_dir, force_download=force_download) - + super().__init__(data_root=data_root, prepare=prepare) # load data - self._data = np.load(self.dataset_paths[0]) + self._data = np.load(self.data_object.get_file_path(data_dir=self.data_dir)) def __getitem__(self, idx): return self._data[idx] + @property + def data_object(self) -> DlDataObject: + return self.MPI3D_DATASETS[self._subset] + + @property + def data_objects(self) -> Sequence[DlDataObject]: + return [self.data_object] + # ========================================================================= # # END # diff --git a/disent/data/groundtruth/_norb.py b/disent/data/groundtruth/_norb.py index a923841e..b4b12e2c 100644 --- a/disent/data/groundtruth/_norb.py +++ b/disent/data/groundtruth/_norb.py @@ -153,7 +153,7 @@ class SmallNorbData(DiskGroundTruthData): 'info': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz', uri_hash={'fast': 'd2703a3f95e7b9a970ad52e91f0aaf6a', 'full': 'a9454f3864d7fd4bb3ea7fc3eb84924e'}), } - def __init__(self, data_root: Optional[str] = 'data/TEMP/dataset', prepare: bool = False, is_test=True): + def __init__(self, data_root: Optional[str] = 'data/TEMP/dataset', prepare: bool = False, is_test=False): self._is_test = is_test # initialize super().__init__(data_root=data_root, prepare=prepare) @@ -173,4 +173,3 @@ def data_objects(self) -> Sequence[DlDataObject]: # ========================================================================= # # END # # ========================================================================= # - diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index a169522c..373bed9c 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -133,7 +133,7 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_me self._in_memory = in_memory # load the h5py dataset data = PickleH5pyDataset( - h5_path=self.data_object.get_file_path(self.data_dir), + h5_path=self.data_object.get_file_path(data_dir=self.data_dir), h5_dataset_name=self.data_object.hdf5_dataset_name, ) # handle different memroy modes @@ -183,7 +183,7 @@ def __init__( self, # download file/link uri: str, - uri_hash: Union[str, Dict[str, str]], + uri_hash: Optional[Union[str, Dict[str, str]]], # save path file_name: Optional[str] = None, # automatically obtain file name from url if None # hash settings @@ -220,9 +220,9 @@ def __init__( self, # download file/link uri: str, - uri_hash: Union[str, Dict[str, str]], + uri_hash: Optional[Union[str, Dict[str, str]]], # save hash - file_hash: Union[str, Dict[str, str]], + file_hash: Optional[Union[str, Dict[str, str]]], # h5 re-save settings hdf5_dataset_name: str, hdf5_chunk_size: Tuple[int, ...], diff --git a/disent/data/util/jobs.py b/disent/data/util/jobs.py index c32c60a6..aa41a6bf 100644 --- a/disent/data/util/jobs.py +++ b/disent/data/util/jobs.py @@ -28,6 +28,7 @@ from typing import Callable from typing import Dict from typing import NoReturn +from typing import Optional from typing import Union from disent.data.util.in_out import hash_file @@ -119,7 +120,7 @@ def __init__( ): # set attributes self.path = path - self.hash: str = hash if isinstance(hash, str) else hash[hash_mode] + self.hash: Optional[str] = hash if ((hash is None) or isinstance(hash, str)) else hash[hash_mode] self.hash_type = hash_type self.hash_mode = hash_mode # generate @@ -137,7 +138,9 @@ def __is_cached_fn(self) -> bool: return True # stale if the hash does not match fhash = self.__compute_hash() - if self.hash != fhash: + if self.hash is None: + log.warning(f'{self}: not stale because it exists and no target hash was given. current {self.hash_mode} {self.hash_type} hash is: {fhash} for: {repr(self.path)}') + elif self.hash != fhash: log.warning(f'{self}: stale because computed {self.hash_mode} {self.hash_type} hash: {repr(fhash)} does not match expected hash: {repr(self.hash)} for: {repr(self.path)}') return True # not stale, we don't need to do anything! @@ -147,7 +150,9 @@ def __job_fn(self): self._make_file_fn(self.path) # check the hash fhash = self.__compute_hash() - if self.hash != fhash: + if self.hash is None: + log.warning(f'{self}: could not verify generated file because no target hash was given. current {self.hash_mode} {self.hash_type} hash is: {fhash} for: {repr(self.path)}') + elif self.hash != fhash: raise RuntimeError(f'{self}: error because computed {self.hash_mode} {self.hash_type} hash: {repr(fhash)} does not match expected hash: {repr(self.hash)} for: {repr(self.path)}') else: log.debug(f'{self}: successfully generated file: {repr(self.path)} with correct {self.hash_mode} {self.hash_type} hash: {fhash}') From d825e3767a392ca0a73e6189ef022ed6a1601e4d Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Wed, 2 Jun 2021 14:20:10 +0200 Subject: [PATCH 18/34] fix cars3d --- disent/data/groundtruth/_cars3d.py | 134 +++++++++++++++++------------ disent/data/groundtruth/_mpi3d.py | 19 ++-- disent/data/groundtruth/base.py | 116 ++++++++++++++++++++----- disent/data/util/in_out.py | 9 +- 4 files changed, 183 insertions(+), 95 deletions(-) diff --git a/disent/data/groundtruth/_cars3d.py b/disent/data/groundtruth/_cars3d.py index 95fcedab..6cdc0bf1 100644 --- a/disent/data/groundtruth/_cars3d.py +++ b/disent/data/groundtruth/_cars3d.py @@ -25,87 +25,111 @@ import logging import os import shutil +from tempfile import TemporaryDirectory + import numpy as np from scipy.io import loadmat -from disent.data.groundtruth.base import DiskGroundTruthData + +from disent.data.groundtruth.base import NumpyGroundTruthData +from disent.data.groundtruth.base import ProcessedDataObject +from disent.data.util.in_out import AtomicFileContext +from disent.data.util.jobs import CachedJobFile + log = logging.getLogger(__name__) +# ========================================================================= # +# cars 3d data processing # +# ========================================================================= # + + +def load_cars3d_folder(raw_data_dir): + """ + nips2015-analogy-data.tar.gz contains: + 1. /data/cars + - list.txt: [ordered list of mat files "car_***_mesh" without the extension] + - car_***_mesh.mat: [MATLAB file with keys: "im" (128, 128, 3, 24, 4), "mask" (128, 128, 24, 4)] + 2. /data/sprites + 3. /data/shapes48.mat + """ + # load image paths + with open(os.path.join(raw_data_dir, 'cars/list.txt'), 'r') as img_names: + img_paths = [os.path.join(raw_data_dir, f'cars/{name.strip()}.mat') for name in img_names.readlines()] + # load images + images = np.stack([loadmat(img_path)['im'] for img_path in img_paths], axis=0) + # check size + assert images.shape == (183, 128, 128, 3, 24, 4) + # reshape & transpose: (183, 128, 128, 3, 24, 4) -> (4, 24, 183, 128, 128, 3) -> (17568, 128, 128, 3) + return images.transpose([5, 4, 0, 1, 2, 3]).reshape([-1, 128, 128, 3]) + + +def resave_cars3d_archive(orig_zipped_file, new_save_file, overwrite=False): + """ + Convert a cars3d archive 'nips2015-analogy-data.tar.gz' to a numpy file, + uncompressing the contents of the archive into a temporary directory in the same folder. + """ + with TemporaryDirectory(prefix='raw_cars3d_', dir=os.path.dirname(orig_zipped_file)) as temp_dir: + # extract zipfile and get path + log.info(f"Extracting into temporary directory: {temp_dir}") + shutil.unpack_archive(filename=orig_zipped_file, extract_dir=temp_dir) + # load image paths & resave + with AtomicFileContext(new_save_file, overwrite=overwrite) as temp_file: + images = load_cars3d_folder(raw_data_dir=os.path.join(temp_dir, 'data')) + # TODO: resize images? + np.savez(temp_file, images=images) + + +# ========================================================================= # +# cars3d data object # +# ========================================================================= # + + +class Cars3dDataObject(ProcessedDataObject): + def _make_proc_job(self, load_path: str, save_path: str): + return CachedJobFile( + make_file_fn=lambda save_path: resave_cars3d_archive(orig_zipped_file=load_path, new_save_file=save_path, overwrite=True), + path=save_path, + hash=self.file_hash, + hash_type=self.hash_type, + hash_mode=self.hash_mode + ) + + # ========================================================================= # # dataset_cars3d # # ========================================================================= # -class Cars3dData(DiskGroundTruthData): +class Cars3dData(NumpyGroundTruthData): """ Cars3D Dataset - Deep Visual Analogy-Making (https://papers.nips.cc/paper/5845-deep-visual-analogy-making) http://www.scottreed.info - Files: - - http://www.scottreed.info/files/nips2015-analogy-data.tar.gz - # reference implementation: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/data/ground_truth/cars3d.py """ + + name = 'cars3d' + factor_names = ('elevation', 'azimuth', 'object_type') factor_sizes = (4, 24, 183) # TOTAL: 17568 observation_shape = (128, 128, 3) - dataset_urls = ['http://www.scottreed.info/files/nips2015-analogy-data.tar.gz'] - - def __init__(self, data_dir='data/dataset/cars3d', force_download=False): - super().__init__(data_dir=data_dir, force_download=force_download) - converted_file = self._make_converted_file(data_dir, force_download) - self._data = np.load(converted_file)['images'] - - def __getitem__(self, idx): - return self._data[idx] - - def _make_converted_file(self, data_dir, force_download): - # get files & folders - zip_path = self.dataset_paths[0] - dataset_dir = os.path.splitext(os.path.splitext(zip_path)[0])[0] # remove .tar & .gz, name of directory after renaming - images_dir = os.path.join(dataset_dir, 'cars') # mesh folder inside renamed folder - converted_file = os.path.join(dataset_dir, 'cars.npz') - - if not os.path.exists(converted_file) or force_download: - # extract data if required - if (not os.path.exists(images_dir)) or force_download: - extract_dir = os.path.join(data_dir, 'data') # directory after extracting, before renaming - # make sure the extract directory doesnt exist - if os.path.exists(extract_dir): - shutil.rmtree(extract_dir) - if os.path.exists(dataset_dir): - shutil.rmtree(dataset_dir) - # extract the files - log.info(f'[UNZIPPING]: {zip_path} to {dataset_dir}') - shutil.unpack_archive(zip_path, data_dir) - # rename dir - shutil.move(extract_dir, dataset_dir) - - images = self._load_cars3d_images(images_dir) - log.info(f'[CONVERTING]: {converted_file}') - np.savez(os.path.splitext(converted_file)[0], images=images) - - return converted_file - - @staticmethod - def _load_cars3d_images(images_dir): - images = [] - log.info(f'[LOADING]: {images_dir}') - with open(os.path.join(images_dir, 'list.txt'), 'r') as img_names: - for i, img_name in enumerate(img_names): - img_path = os.path.join(images_dir, f'{img_name.strip()}.mat') - img = loadmat(img_path)['im'] - img = img[..., None].transpose([4, 3, 5, 0, 1, 2]) # (128, 128, 3, 24, 4, 1) -> (4, 24, 1, 128, 128, 3) - images.append(img) - return np.concatenate(images, axis=2).reshape([-1, 128, 128, 3]) # (4, 24, 183, 128, 128, 3) -> (17568, 1, 128, 128, 3) + data_key = 'images' + data_object = Cars3dDataObject( + uri='http://www.scottreed.info/files/nips2015-analogy-data.tar.gz', + uri_hash={'fast': 'fe77d39e3fa9d77c31df2262660c2a67', 'full': '4e866a7919c1beedf53964e6f7a23686'}, + file_name='cars3d.npz', + file_hash={'fast': 'ef5d86d1572ddb122b466ec700b3abf2', 'full': 'dc03319a0b9118fbe0e23d13220a745b'}, + hash_mode='fast' + ) # ========================================================================= # # END # # ========================================================================= # + if __name__ == '__main__': - Cars3dData() + Cars3dData(prepare=True) diff --git a/disent/data/groundtruth/_mpi3d.py b/disent/data/groundtruth/_mpi3d.py index 4f7e882a..7220cdeb 100644 --- a/disent/data/groundtruth/_mpi3d.py +++ b/disent/data/groundtruth/_mpi3d.py @@ -24,12 +24,10 @@ import logging from typing import Optional -from typing import Sequence -import numpy as np - -from disent.data.groundtruth.base import DiskGroundTruthData from disent.data.groundtruth.base import DlDataObject +from disent.data.groundtruth.base import NumpyGroundTruthData + log = logging.getLogger(__name__) @@ -38,7 +36,7 @@ # ========================================================================= # -class Mpi3dData(DiskGroundTruthData): +class Mpi3dData(NumpyGroundTruthData): """ MPI3D Dataset - https://github.com/rr-learning/disentanglement_dataset @@ -46,6 +44,8 @@ class Mpi3dData(DiskGroundTruthData): reference implementation: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/data/ground_truth/mpi3d.py """ + name = 'mpi3d' + MPI3D_DATASETS = { 'toy': DlDataObject(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_toy.npz', uri_hash=None), 'realistic': DlDataObject(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_realistic.npz', uri_hash=None), @@ -67,20 +67,11 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, subse raise NotImplementedError('TODO: add support for converting to h5py for fast disk access') # TODO! # initialise super().__init__(data_root=data_root, prepare=prepare) - # load data - self._data = np.load(self.data_object.get_file_path(data_dir=self.data_dir)) - - def __getitem__(self, idx): - return self._data[idx] @property def data_object(self) -> DlDataObject: return self.MPI3D_DATASETS[self._subset] - @property - def data_objects(self) -> Sequence[DlDataObject]: - return [self.data_object] - # ========================================================================= # # END # diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index 373bed9c..e6970f72 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -117,7 +117,7 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False): data_object.prepare(self.data_dir) @property - def data_dir(self): + def data_dir(self) -> str: return self._data_dir @property @@ -125,6 +125,33 @@ def data_objects(self) -> Sequence['DataObject']: raise NotImplementedError +class NumpyGroundTruthData(DiskGroundTruthData, metaclass=ABCMeta): + + def __init__(self, data_root: Optional[str] = None, prepare: bool = False): + super().__init__(data_root=data_root, prepare=prepare) + # load dataset + self._data = np.load(self.data_object.get_file_path(self.data_dir)) + if self.data_key is not None: + self._data = self._data[self.data_key] + + def __getitem__(self, idx): + return self._data[idx] + + @property + def data_objects(self) -> Sequence['DataObject']: + return [self.data_object] + + @property + def data_object(self) -> 'DataObject': + raise NotImplementedError + + @property + def data_key(self) -> Optional[str]: + # can override this! + return None + + + class Hdf5GroundTruthData(DiskGroundTruthData, metaclass=ABCMeta): def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_memory=False): @@ -166,15 +193,18 @@ def data_object(self) -> 'DlH5DataObject': class DataObject(object): - def __init__(self, file_name: str): - self.file_name = file_name + @property + def file_name(self) -> str: + raise NotImplementedError def prepare(self, data_dir: str): pass - def get_file_path(self, data_dir: str, variant: Optional[str] = None): - suffix = '' if (variant is None) else f'.{variant}' - return os.path.join(data_dir, self.file_name + suffix) + def get_file_path(self, data_dir: str, file_name: Optional[str] = None, prefix: str = '', postfix: str = ''): + file_name = self.file_name if (file_name is None) else file_name + if prefix or postfix: + file_name = os.path.join(os.path.dirname(file_name), f'{prefix}{os.path.basename(file_name)}{postfix}') + return os.path.join(data_dir, file_name) class DlDataObject(DataObject): @@ -184,18 +214,22 @@ def __init__( # download file/link uri: str, uri_hash: Optional[Union[str, Dict[str, str]]], - # save path - file_name: Optional[str] = None, # automatically obtain file name from url if None + uri_name: Optional[str] = None, # automatically obtain uri name from url if None # hash settings hash_type: str = 'md5', hash_mode: str = 'fast', ): - super().__init__(file_name=basename_from_url(uri) if (file_name is None) else file_name) + super().__init__() self.uri = uri self.uri_hash = uri_hash + self.uri_name = basename_from_url(uri) if (uri_name is None) else uri_name self.hash_mode = hash_mode self.hash_type = hash_type + @property + def file_name(self) -> str: + return self.uri_name + def _make_dl_job(self, save_path: str): return CachedJobFile( make_file_fn=lambda path: retrieve_file( @@ -210,11 +244,46 @@ def _make_dl_job(self, save_path: str): ) def prepare(self, data_dir: str): - dl_job = self._make_dl_job(save_path=self.get_file_path(data_dir=data_dir)) + dl_job = self._make_dl_job(save_path=os.path.join(data_dir, self.uri_name)) dl_job.run() -class DlH5DataObject(DlDataObject): +class ProcessedDataObject(DlDataObject): + + def __init__( + self, + # save path + file_name: str, + file_hash: Optional[Union[str, Dict[str, str]]], + # download file/link + uri: str, + uri_hash: Optional[Union[str, Dict[str, str]]], + uri_name: Optional[str] = None, # automatically obtain file name from url if None + # hash settings + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + super().__init__(uri=uri, uri_hash=uri_hash, uri_name=uri_name, hash_mode=hash_mode, hash_type=hash_type) + assert file_name is not None + self._file_name = file_name + self.file_hash = file_hash + + @property + def file_name(self) -> str: + return self._file_name + + def _make_proc_job(self, load_path: str, save_path: str): + raise NotImplementedError + + def prepare(self, data_dir: str): + dl_path = os.path.join(data_dir, self.uri_name) + proc_path = os.path.join(data_dir, self.file_name) + dl_job = self._make_dl_job(save_path=dl_path) + proc_job = self._make_proc_job(load_path=dl_path, save_path=proc_path) + dl_job.set_child(proc_job).run() + + +class DlH5DataObject(ProcessedDataObject): def __init__( self, @@ -236,8 +305,16 @@ def __init__( hash_type: str = 'md5', hash_mode: str = 'fast', ): - super().__init__(file_name=file_name, uri=uri, uri_hash=uri_hash, hash_mode=hash_mode, hash_type=hash_type) - self.file_hash = file_hash + file_name = basename_from_url(uri) if (file_name is None) else file_name + super().__init__( + file_name=file_name, + file_hash=file_hash, + uri=uri, + uri_hash=uri_hash, + uri_name=file_name + '.ORIG', + hash_type=hash_type, + hash_mode=hash_mode, + ) self.hdf5_dataset_name = hdf5_dataset_name self.hdf5_chunk_size = hdf5_chunk_size self.hdf5_compression = hdf5_compression @@ -245,11 +322,11 @@ def __init__( self.hdf5_dtype = hdf5_dtype self.hdf5_mutator = hdf5_mutator - def _make_h5_job(self, load_path: str, save_path: str): + def _make_proc_job(self, load_path: str, save_path: str): return CachedJobFile( - make_file_fn=lambda path: hdf5_resave_file( + make_file_fn=lambda save_path: hdf5_resave_file( inp_path=load_path, - out_path=path, + out_path=save_path, dataset_name=self.hdf5_dataset_name, chunk_size=self.hdf5_chunk_size, compression=self.hdf5_compression, @@ -264,13 +341,6 @@ def _make_h5_job(self, load_path: str, save_path: str): hash_mode=self.hash_mode, ) - def prepare(self, data_dir: str): - dl_path = self.get_file_path(data_dir=data_dir, variant='ORIG') - h5_path = self.get_file_path(data_dir=data_dir) - dl_job = self._make_dl_job(save_path=dl_path) - h5_job = self._make_h5_job(load_path=dl_path, save_path=h5_path) - dl_job.set_child(h5_job).run() - # ========================================================================= # # END # diff --git a/disent/data/util/in_out.py b/disent/data/util/in_out.py index eb480bef..0b52ad8a 100644 --- a/disent/data/util/in_out.py +++ b/disent/data/util/in_out.py @@ -139,9 +139,11 @@ class AtomicFileContext(object): with open(tmp_file, 'w') as f: f.write("hello world!\n") ``` + + # TODO: can this be cleaned up with the TemporaryDirectory and TemporaryFile classes? """ - def __init__(self, file: str, open_mode: Optional[str] = None, overwrite: bool = False, makedirs: bool = True, tmp_file: Optional[str] = None): + def __init__(self, file: str, open_mode: Optional[str] = None, overwrite: bool = False, makedirs: bool = True, tmp_file: Optional[str] = None, tmp_prefix='_TEMP_.', tmp_postfix=''): from pathlib import Path # check files if not file: @@ -150,7 +152,8 @@ def __init__(self, file: str, open_mode: Optional[str] = None, overwrite: bool = raise ValueError(f'tmp_file must not be empty: {repr(tmp_file)}') # get files self.trg_file = Path(file).absolute() - self.tmp_file = Path(f'{self.trg_file}.TEMP' if (tmp_file is None) else tmp_file) + tmp_file = Path(self.trg_file if (tmp_file is None) else tmp_file) + self.tmp_file = tmp_file.parent.joinpath(f'{tmp_prefix}{tmp_file.name}{tmp_postfix}') # check that the files are different if self.trg_file == self.tmp_file: raise ValueError(f'temporary and target files are the same: {self.tmp_file} == {self.trg_file}') @@ -167,7 +170,7 @@ def __enter__(self): raise FileExistsError(f'the temporary file exists but is not a file: {self.tmp_file}') if self.trg_file.exists(): if not self._overwrite: - raise FileExistsError('the target file already exists, set overwrite=True to ignore this error.') + raise FileExistsError(f'the target file already exists: {self.trg_file}, set overwrite=True to ignore this error.') if not self.trg_file.is_file(): raise FileExistsError(f'the target file exists but is not a file: {self.trg_file}') # create the missing directories if needed From 5e8ce56403d39537ab4374aa403344f4d7053f5a Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Wed, 2 Jun 2021 15:57:17 +0200 Subject: [PATCH 19/34] minor fixes --- disent/data/groundtruth/_cars3d.py | 4 +++- disent/data/groundtruth/_dsprites.py | 7 ++----- disent/data/groundtruth/_mpi3d.py | 4 ++++ disent/data/groundtruth/_norb.py | 4 ++++ disent/data/groundtruth/_shapes3d.py | 3 +++ disent/data/groundtruth/_xyblocks.py | 8 -------- disent/data/groundtruth/_xyobject.py | 8 -------- disent/data/groundtruth/base.py | 7 ++++++- disent/data/util/in_out.py | 21 ++++++++++++++++++--- 9 files changed, 40 insertions(+), 26 deletions(-) diff --git a/disent/data/groundtruth/_cars3d.py b/disent/data/groundtruth/_cars3d.py index 6cdc0bf1..c7bbf12f 100644 --- a/disent/data/groundtruth/_cars3d.py +++ b/disent/data/groundtruth/_cars3d.py @@ -76,7 +76,6 @@ def resave_cars3d_archive(orig_zipped_file, new_save_file, overwrite=False): # load image paths & resave with AtomicFileContext(new_save_file, overwrite=overwrite) as temp_file: images = load_cars3d_folder(raw_data_dir=os.path.join(temp_dir, 'data')) - # TODO: resize images? np.savez(temp_file, images=images) @@ -86,6 +85,9 @@ def resave_cars3d_archive(orig_zipped_file, new_save_file, overwrite=False): class Cars3dDataObject(ProcessedDataObject): + """ + download the cars3d dataset and convert it to a numpy file. + """ def _make_proc_job(self, load_path: str, save_path: str): return CachedJobFile( make_file_fn=lambda save_path: resave_cars3d_archive(orig_zipped_file=load_path, new_save_file=save_path, overwrite=True), diff --git a/disent/data/groundtruth/_dsprites.py b/disent/data/groundtruth/_dsprites.py index 04da491f..06629a61 100644 --- a/disent/data/groundtruth/_dsprites.py +++ b/disent/data/groundtruth/_dsprites.py @@ -58,7 +58,7 @@ class DSpritesData(Hdf5GroundTruthData): uri='https://raw.githubusercontent.com/deepmind/dsprites-dataset/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.hdf5', uri_hash={'fast': 'd6ee1e43db715c2f0de3c41e38863347', 'full': 'b331c4447a651c44bf5e8ae09022e230'}, # processed dataset file - file_hash={'fast': '6d6d43d5f4d5c08c4b99a406289b8ecd', 'full': '1473ac1e1af7fdbc910766b3f9157f7b'}, + file_hash={'fast': '7a6e83ebf35f93a1cd9ae0210112b421', 'full': '27c674fb5170dcd6a1f9853b66c5785d'}, # h5 re-save settings hdf5_dataset_name='imgs', hdf5_chunk_size=(1, 64, 64), @@ -73,7 +73,4 @@ class DSpritesData(Hdf5GroundTruthData): if __name__ == '__main__': - - data = DSpritesData(in_memory=False, prepare=True) - for dat in data: - print(dat) + DSpritesData(prepare=True) diff --git a/disent/data/groundtruth/_mpi3d.py b/disent/data/groundtruth/_mpi3d.py index 7220cdeb..b14a2d1f 100644 --- a/disent/data/groundtruth/_mpi3d.py +++ b/disent/data/groundtruth/_mpi3d.py @@ -76,3 +76,7 @@ def data_object(self) -> DlDataObject: # ========================================================================= # # END # # ========================================================================= # + + +if __name__ == '__main__': + Mpi3dData(prepare=True, in_memory=False) diff --git a/disent/data/groundtruth/_norb.py b/disent/data/groundtruth/_norb.py index b4b12e2c..8680e968 100644 --- a/disent/data/groundtruth/_norb.py +++ b/disent/data/groundtruth/_norb.py @@ -173,3 +173,7 @@ def data_objects(self) -> Sequence[DlDataObject]: # ========================================================================= # # END # # ========================================================================= # + + +if __name__ == '__main__': + SmallNorbData(prepare=True) diff --git a/disent/data/groundtruth/_shapes3d.py b/disent/data/groundtruth/_shapes3d.py index 3eaf8628..5d0f85b7 100644 --- a/disent/data/groundtruth/_shapes3d.py +++ b/disent/data/groundtruth/_shapes3d.py @@ -64,3 +64,6 @@ class Shapes3dData(Hdf5GroundTruthData): # ========================================================================= # # END # # ========================================================================= # + +if __name__ == '__main__': + Shapes3dData(prepare=True) diff --git a/disent/data/groundtruth/_xyblocks.py b/disent/data/groundtruth/_xyblocks.py index 07e04e77..169bb25d 100644 --- a/disent/data/groundtruth/_xyblocks.py +++ b/disent/data/groundtruth/_xyblocks.py @@ -143,11 +143,3 @@ def __getitem__(self, idx): # ========================================================================= # # END # # ========================================================================= # - - -# if __name__ == '__main__': - # data = XYBlocksData(64, [1, 2, 3], rgb=True, palette='rgb', invert_bg=False) # 110592 // 256 = 432 - # print(len(data)) - # for obs in tqdm(data): - # pass - # # print(obs[:, :, 0]) diff --git a/disent/data/groundtruth/_xyobject.py b/disent/data/groundtruth/_xyobject.py index 72e83a7f..fe3e4ae8 100644 --- a/disent/data/groundtruth/_xyobject.py +++ b/disent/data/groundtruth/_xyobject.py @@ -140,11 +140,3 @@ def __getitem__(self, idx): # ========================================================================= # # END # # ========================================================================= # - -# if __name__ == '__main__': -# print(len(XYScaleColorData())) -# for i in XYScaleColorData(6, 2, 2, 4, 2): -# print(i[:, :, 0]) -# print(i[:, :, 1]) -# print(i[:, :, 2]) -# print() diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index e6970f72..d617cc06 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -105,7 +105,7 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False): super().__init__() # get root data folder if data_root is None: - data_root = os.path.abspath(os.environ.get('DISENT_DATA_ROOT', 'data/dataset')) + data_root = self.default_data_root else: data_root = os.path.abspath(data_root) # get class data folder @@ -120,6 +120,10 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False): def data_dir(self) -> str: return self._data_dir + @property + def default_data_root(self): + return os.path.abspath(os.environ.get('DISENT_DATA_ROOT', 'data/dataset')) + @property def data_objects(self) -> Sequence['DataObject']: raise NotImplementedError @@ -188,6 +192,7 @@ def data_object(self) -> 'DlH5DataObject': # ========================================================================= # # data objects # +# TODO: clean this up, this could be simplified! # # ========================================================================= # diff --git a/disent/data/util/in_out.py b/disent/data/util/in_out.py index 0b52ad8a..ceb3d01b 100644 --- a/disent/data/util/in_out.py +++ b/disent/data/util/in_out.py @@ -25,8 +25,10 @@ import logging import math import os +from typing import Dict from typing import Optional from typing import Tuple +from typing import Union from disent.util import colors as c @@ -117,10 +119,14 @@ class HashError(Exception): """ -def validate_file_hash(file: str, hash: str, hash_type='md5', hash_mode='full'): +def validate_file_hash(file: str, hash: Union[str, Dict[str, str]], hash_type='md5', hash_mode='full'): + if isinstance(hash, dict): + hash = hash[hash_mode] fhash = hash_file(file=file, hash_type=hash_type, hash_mode=hash_mode) if fhash != hash: - raise HashError(f'computed {hash_mode} {hash_type} hash: {repr(fhash)} does not match expected hash: {repr(hash)} for file: {repr(file)}') + msg = f'computed {hash_mode} {hash_type} hash: {repr(fhash)} does not match expected hash: {repr(hash)} for file: {repr(file)}' + log.error(msg) + raise HashError(msg) # ========================================================================= # @@ -143,7 +149,16 @@ class AtomicFileContext(object): # TODO: can this be cleaned up with the TemporaryDirectory and TemporaryFile classes? """ - def __init__(self, file: str, open_mode: Optional[str] = None, overwrite: bool = False, makedirs: bool = True, tmp_file: Optional[str] = None, tmp_prefix='_TEMP_.', tmp_postfix=''): + def __init__( + self, + file: str, + open_mode: Optional[str] = None, + overwrite: bool = False, + makedirs: bool = True, + tmp_file: Optional[str] = None, + tmp_prefix: str = '_TEMP_.', + tmp_postfix: str = '', + ): from pathlib import Path # check files if not file: From 154389e767f584eabac0899e3d851e38338f4729 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Thu, 3 Jun 2021 13:00:32 +0200 Subject: [PATCH 20/34] simplified dataset objects -- removed jobs --- disent/data/groundtruth/_cars3d.py | 20 +- disent/data/groundtruth/_dsprites.py | 3 + disent/data/groundtruth/_mpi3d.py | 1 + disent/data/groundtruth/_norb.py | 5 +- disent/data/groundtruth/_shapes3d.py | 2 + disent/data/groundtruth/base.py | 401 ++++++++++++++++++++------- disent/data/util/hdf5.py | 4 +- disent/data/util/in_out.py | 121 ++++++-- disent/data/util/jobs.py | 163 ----------- 9 files changed, 425 insertions(+), 295 deletions(-) delete mode 100644 disent/data/util/jobs.py diff --git a/disent/data/groundtruth/_cars3d.py b/disent/data/groundtruth/_cars3d.py index c7bbf12f..66b3c9b8 100644 --- a/disent/data/groundtruth/_cars3d.py +++ b/disent/data/groundtruth/_cars3d.py @@ -30,10 +30,9 @@ import numpy as np from scipy.io import loadmat +from disent.data.groundtruth.base import DlGenDataObject from disent.data.groundtruth.base import NumpyGroundTruthData -from disent.data.groundtruth.base import ProcessedDataObject -from disent.data.util.in_out import AtomicFileContext -from disent.data.util.jobs import CachedJobFile +from disent.data.util.in_out import AtomicSaveFile log = logging.getLogger(__name__) @@ -74,7 +73,7 @@ def resave_cars3d_archive(orig_zipped_file, new_save_file, overwrite=False): log.info(f"Extracting into temporary directory: {temp_dir}") shutil.unpack_archive(filename=orig_zipped_file, extract_dir=temp_dir) # load image paths & resave - with AtomicFileContext(new_save_file, overwrite=overwrite) as temp_file: + with AtomicSaveFile(new_save_file, overwrite=overwrite) as temp_file: images = load_cars3d_folder(raw_data_dir=os.path.join(temp_dir, 'data')) np.savez(temp_file, images=images) @@ -84,18 +83,12 @@ def resave_cars3d_archive(orig_zipped_file, new_save_file, overwrite=False): # ========================================================================= # -class Cars3dDataObject(ProcessedDataObject): +class Cars3dDataObject(DlGenDataObject): """ download the cars3d dataset and convert it to a numpy file. """ - def _make_proc_job(self, load_path: str, save_path: str): - return CachedJobFile( - make_file_fn=lambda save_path: resave_cars3d_archive(orig_zipped_file=load_path, new_save_file=save_path, overwrite=True), - path=save_path, - hash=self.file_hash, - hash_type=self.hash_type, - hash_mode=self.hash_mode - ) + def _generate(self, inp_file: str, out_file: str): + resave_cars3d_archive(orig_zipped_file=inp_file, new_save_file=out_file, overwrite=True) # ========================================================================= # @@ -134,4 +127,5 @@ class Cars3dData(NumpyGroundTruthData): if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) Cars3dData(prepare=True) diff --git a/disent/data/groundtruth/_dsprites.py b/disent/data/groundtruth/_dsprites.py index 06629a61..82ac6858 100644 --- a/disent/data/groundtruth/_dsprites.py +++ b/disent/data/groundtruth/_dsprites.py @@ -22,6 +22,8 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import logging + from disent.data.groundtruth.base import DlH5DataObject from disent.data.groundtruth.base import Hdf5GroundTruthData @@ -73,4 +75,5 @@ class DSpritesData(Hdf5GroundTruthData): if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) DSpritesData(prepare=True) diff --git a/disent/data/groundtruth/_mpi3d.py b/disent/data/groundtruth/_mpi3d.py index b14a2d1f..b7619a49 100644 --- a/disent/data/groundtruth/_mpi3d.py +++ b/disent/data/groundtruth/_mpi3d.py @@ -79,4 +79,5 @@ def data_object(self) -> DlDataObject: if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) Mpi3dData(prepare=True, in_memory=False) diff --git a/disent/data/groundtruth/_norb.py b/disent/data/groundtruth/_norb.py index 8680e968..f0d2a997 100644 --- a/disent/data/groundtruth/_norb.py +++ b/disent/data/groundtruth/_norb.py @@ -23,6 +23,8 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import gzip +import logging +import os from typing import Optional from typing import Sequence from typing import Tuple @@ -158,7 +160,7 @@ def __init__(self, data_root: Optional[str] = 'data/TEMP/dataset', prepare: bool # initialize super().__init__(data_root=data_root, prepare=prepare) # read dataset and sort by features - dat_path, cat_path, info_path = (obj.get_file_path(data_dir=self.data_dir) for obj in self.data_objects) + dat_path, cat_path, info_path = (os.path.join(self.data_dir, obj.out_name) for obj in self.data_objects) self._data, _ = read_norb_dataset(dat_path=dat_path, cat_path=cat_path, info_path=info_path) def __getitem__(self, idx): @@ -176,4 +178,5 @@ def data_objects(self) -> Sequence[DlDataObject]: if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) SmallNorbData(prepare=True) diff --git a/disent/data/groundtruth/_shapes3d.py b/disent/data/groundtruth/_shapes3d.py index 5d0f85b7..19b03e60 100644 --- a/disent/data/groundtruth/_shapes3d.py +++ b/disent/data/groundtruth/_shapes3d.py @@ -22,6 +22,7 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import logging from disent.data.groundtruth.base import DlH5DataObject from disent.data.groundtruth.base import Hdf5GroundTruthData @@ -66,4 +67,5 @@ class Shapes3dData(Hdf5GroundTruthData): # ========================================================================= # if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) Shapes3dData(prepare=True) diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index d617cc06..b878ec3a 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -22,12 +22,12 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import dataclasses import logging import os from abc import ABCMeta from typing import Callable from typing import Dict +from typing import NoReturn from typing import Optional from typing import Sequence from typing import Tuple @@ -38,10 +38,13 @@ from disent.data.util.hdf5 import hdf5_resave_file from disent.data.util.hdf5 import PickleH5pyDataset from disent.data.util.in_out import basename_from_url +from disent.data.util.in_out import modify_file_name +from disent.data.util.in_out import stalefile +from disent.data.util.in_out import download_file from disent.data.util.in_out import ensure_dir_exists from disent.data.util.in_out import retrieve_file -from disent.data.util.jobs import CachedJobFile from disent.data.util.state_space import StateSpace +from disent.util import wrapped_partial log = logging.getLogger(__name__) @@ -134,7 +137,7 @@ class NumpyGroundTruthData(DiskGroundTruthData, metaclass=ABCMeta): def __init__(self, data_root: Optional[str] = None, prepare: bool = False): super().__init__(data_root=data_root, prepare=prepare) # load dataset - self._data = np.load(self.data_object.get_file_path(self.data_dir)) + self._data = np.load(os.path.join(self.data_dir, self.data_object.out_name)) if self.data_key is not None: self._data = self._data[self.data_key] @@ -155,7 +158,6 @@ def data_key(self) -> Optional[str]: return None - class Hdf5GroundTruthData(DiskGroundTruthData, metaclass=ABCMeta): def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_memory=False): @@ -164,10 +166,10 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_me self._in_memory = in_memory # load the h5py dataset data = PickleH5pyDataset( - h5_path=self.data_object.get_file_path(data_dir=self.data_dir), - h5_dataset_name=self.data_object.hdf5_dataset_name, + h5_path=os.path.join(self.data_dir, self.data_object.out_name), + h5_dataset_name=self.data_object.out_dataset_name, ) - # handle different memroy modes + # handle different memory modes if self._in_memory: # Load the entire dataset into memory if required # indexing dataset objects returns numpy array @@ -198,104 +200,98 @@ def data_object(self) -> 'DlH5DataObject': class DataObject(object): + def __init__( + self, + file_name: str, + file_hash: Optional[Union[str, Dict[str, str]]], + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + self._file_name = file_name + self._file_hash = file_hash + self._hash_type = hash_type + self._hash_mode = hash_mode + @property - def file_name(self) -> str: - raise NotImplementedError + def out_name(self) -> str: + return self._file_name - def prepare(self, data_dir: str): - pass + def prepare(self, out_dir: str) -> str: + @stalefile(file=os.path.join(out_dir, self._file_name), hash=self._file_hash, hash_type=self._hash_type, hash_mode=self._hash_mode) + def wrapped(out_file): + self._prepare(out_dir=out_dir, out_file=out_file) + return wrapped() - def get_file_path(self, data_dir: str, file_name: Optional[str] = None, prefix: str = '', postfix: str = ''): - file_name = self.file_name if (file_name is None) else file_name - if prefix or postfix: - file_name = os.path.join(os.path.dirname(file_name), f'{prefix}{os.path.basename(file_name)}{postfix}') - return os.path.join(data_dir, file_name) + def _prepare(self, out_dir: str, out_file: str) -> str: + raise NotImplementedError class DlDataObject(DataObject): def __init__( self, - # download file/link uri: str, uri_hash: Optional[Union[str, Dict[str, str]]], - uri_name: Optional[str] = None, # automatically obtain uri name from url if None - # hash settings + uri_name: Optional[str] = None, hash_type: str = 'md5', hash_mode: str = 'fast', ): - super().__init__() - self.uri = uri - self.uri_hash = uri_hash - self.uri_name = basename_from_url(uri) if (uri_name is None) else uri_name - self.hash_mode = hash_mode - self.hash_type = hash_type - - @property - def file_name(self) -> str: - return self.uri_name - - def _make_dl_job(self, save_path: str): - return CachedJobFile( - make_file_fn=lambda path: retrieve_file( - src_uri=self.uri, - dst_path=path, - overwrite_existing=True, - ), - path=save_path, - hash=self.uri_hash, - hash_type=self.hash_type, - hash_mode=self.hash_mode, + super().__init__( + file_name=basename_from_url(uri) if (uri_name is None) else uri_name, + file_hash=uri_hash, + hash_type=hash_type, + hash_mode=hash_mode ) + self._uri = uri - def prepare(self, data_dir: str): - dl_job = self._make_dl_job(save_path=os.path.join(data_dir, self.uri_name)) - dl_job.run() + def _prepare(self, out_dir: str, out_file: str): + retrieve_file(src_uri=self._uri, dst_path=out_file, overwrite_existing=True) -class ProcessedDataObject(DlDataObject): +class DlGenDataObject(DataObject): def __init__( self, - # save path - file_name: str, - file_hash: Optional[Union[str, Dict[str, str]]], - # download file/link + # download & save files uri: str, uri_hash: Optional[Union[str, Dict[str, str]]], - uri_name: Optional[str] = None, # automatically obtain file name from url if None + file_hash: Optional[Union[str, Dict[str, str]]], + # save paths + uri_name: Optional[str] = None, + file_name: Optional[str] = None, # hash settings hash_type: str = 'md5', hash_mode: str = 'fast', ): - super().__init__(uri=uri, uri_hash=uri_hash, uri_name=uri_name, hash_mode=hash_mode, hash_type=hash_type) - assert file_name is not None - self._file_name = file_name - self.file_hash = file_hash + self._dl_obj = DlDataObject( + uri=uri, + uri_hash=uri_hash, + uri_name=uri_name, + hash_type=hash_type, + hash_mode=hash_mode, + ) + super().__init__( + file_name=modify_file_name(self._dl_obj.out_name, prefix='gen') if (file_name is None) else file_name, + file_hash=file_hash, + hash_type=hash_type, + hash_mode=hash_mode, + ) - @property - def file_name(self) -> str: - return self._file_name + def _prepare(self, out_dir: str, out_file: str): + inp_file = self._dl_obj.prepare(out_dir=out_dir) + self._generate(inp_file=inp_file, out_file=out_file) - def _make_proc_job(self, load_path: str, save_path: str): + def _generate(self, inp_file: str, out_file: str): raise NotImplementedError - def prepare(self, data_dir: str): - dl_path = os.path.join(data_dir, self.uri_name) - proc_path = os.path.join(data_dir, self.file_name) - dl_job = self._make_dl_job(save_path=dl_path) - proc_job = self._make_proc_job(load_path=dl_path, save_path=proc_path) - dl_job.set_child(proc_job).run() - -class DlH5DataObject(ProcessedDataObject): +class DlH5DataObject(DlGenDataObject): def __init__( self, - # download file/link + # download & save files uri: str, uri_hash: Optional[Union[str, Dict[str, str]]], - # save hash file_hash: Optional[Union[str, Dict[str, str]]], # h5 re-save settings hdf5_dataset_name: str, @@ -304,47 +300,260 @@ def __init__( hdf5_compression_lvl: Optional[int] = 4, hdf5_dtype: Optional[Union[np.dtype, str]] = None, hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None, - # save path - file_name: Optional[str] = None, # automatically obtain file name from url if None + # save paths + uri_name: Optional[str] = None, + file_name: Optional[str] = None, # hash settings hash_type: str = 'md5', hash_mode: str = 'fast', ): - file_name = basename_from_url(uri) if (file_name is None) else file_name super().__init__( file_name=file_name, file_hash=file_hash, uri=uri, uri_hash=uri_hash, - uri_name=file_name + '.ORIG', + uri_name=uri_name, hash_type=hash_type, hash_mode=hash_mode, ) - self.hdf5_dataset_name = hdf5_dataset_name - self.hdf5_chunk_size = hdf5_chunk_size - self.hdf5_compression = hdf5_compression - self.hdf5_compression_lvl = hdf5_compression_lvl - self.hdf5_dtype = hdf5_dtype - self.hdf5_mutator = hdf5_mutator - - def _make_proc_job(self, load_path: str, save_path: str): - return CachedJobFile( - make_file_fn=lambda save_path: hdf5_resave_file( - inp_path=load_path, - out_path=save_path, - dataset_name=self.hdf5_dataset_name, - chunk_size=self.hdf5_chunk_size, - compression=self.hdf5_compression, - compression_lvl=self.hdf5_compression_lvl, - batch_size=None, - out_dtype=self.hdf5_dtype, - out_mutator=self.hdf5_mutator, - ), - path=save_path, - hash=self.file_hash, - hash_type=self.hash_type, - hash_mode=self.hash_mode, + self._hdf5_resave_file = wrapped_partial( + hdf5_resave_file, + dataset_name=hdf5_dataset_name, + chunk_size=hdf5_chunk_size, + compression=hdf5_compression, + compression_lvl=hdf5_compression_lvl, + out_dtype=hdf5_dtype, + out_mutator=hdf5_mutator, ) + # save the dataset name + self._out_dataset_name = hdf5_dataset_name + + @property + def out_dataset_name(self) -> str: + return self._out_dataset_name + + def _generate(self, inp_file: str, out_file: str): + self._hdf5_resave_file(inp_path=inp_file, out_path=out_file) + + +# class DataObject(object): +# +# @property +# def file_name(self) -> str: +# raise NotImplementedError +# +# def prepare(self, data_dir: str): +# pass +# +# def get_path(self, data_dir, *attrs): +# paths = [os.path.join(data_dir, getattr(self, attr)) for attr in attrs] +# if len(paths) == 1: +# return paths[0] +# return paths +# +# +# class DlDataObject(DataObject): +# +# def __init__( +# self, +# # download file/link +# uri: str, +# uri_hash: Optional[Union[str, Dict[str, str]]], +# uri_name: Optional[str] = None, # automatically obtain uri name from url if None +# # hash settings +# hash_type: str = 'md5', +# hash_mode: str = 'fast', +# ): +# def _prepare(data_dir: str): +# dl_path = os.path.join(data_dir, self._uri_name) +# # cached download task +# @stalefile(file=dl_path, hash=uri_hash, hash_type=hash_type, hash_mode=hash_mode) +# def download(): +# retrieve_file(src_uri=uri, dst_path=dl_path, overwrite_existing=True) +# # run task +# return download() +# +# # instance variables +# self._uri_name = basename_from_url(uri) if (uri_name is None) else uri_name +# self._out_name = self._uri_name +# self._prepare = _prepare +# +# @property +# def file_name(self) -> str: +# return self._out_name +# +# def prepare(self, data_dir: str): +# return self._prepare(data_dir) +# +# +# class DlGenDataObject(DlDataObject): +# +# def __init__( +# self, +# # save path +# file_name: str, +# file_hash: Optional[Union[str, Dict[str, str]]], +# # download file/link +# uri: str, +# uri_hash: Optional[Union[str, Dict[str, str]]], +# uri_name: Optional[str] = None, # automatically obtain file name from url if None +# # hash settings +# hash_type: str = 'md5', +# hash_mode: str = 'fast', +# ): +# super().__init__(uri=uri, uri_hash=uri_hash, uri_name=uri_name, hash_mode=hash_mode, hash_type=hash_type) +# +# def _prepare(data_dir: str): +# proc_path = os.path.join(data_dir, file_name) +# # cached process task +# @stalefile(file=proc_path, hash=file_hash, hash_type=hash_type, hash_mode=hash_mode) +# def process(): +# self._process_file(inp_file=self._prepare_dl(data_dir), out_file=proc_path) +# # run task +# return process() +# +# self._prepare, self._prepare_dl = _prepare, self._prepare +# self._out_name = file_name +# +# def _process_file(self, inp_file: str, out_file: str): +# raise NotImplementedError +# +# +# class DlH5DataObject(DlGenDataObject): +# +# def __init__( +# self, +# # download file/link +# uri: str, +# uri_hash: Optional[Union[str, Dict[str, str]]], +# # save hash +# file_hash: Optional[Union[str, Dict[str, str]]], +# # h5 re-save settings +# hdf5_dataset_name: str, +# hdf5_chunk_size: Tuple[int, ...], +# hdf5_compression: Optional[str] = 'gzip', +# hdf5_compression_lvl: Optional[int] = 4, +# hdf5_dtype: Optional[Union[np.dtype, str]] = None, +# hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None, +# # save path +# file_name: Optional[str] = None, # automatically obtain file name from url if None +# # hash settings +# hash_type: str = 'md5', +# hash_mode: str = 'fast', +# ): +# file_name = basename_from_url(uri) if (file_name is None) else file_name +# uri_name = f'dl.{file_name}' +# self.hdf5_dataset_name = hdf5_dataset_name +# +# super().__init__( +# file_name=file_name, +# file_hash=file_hash, +# uri=uri, +# uri_hash=uri_hash, +# uri_name=uri_name, +# hash_type=hash_type, +# hash_mode=hash_mode +# ) +# +# def prepare_file(inp_file: str, out_file: str): +# hdf5_resave_file( +# inp_path=inp_file, +# out_path=out_file, +# dataset_name=hdf5_dataset_name, +# chunk_size=hdf5_chunk_size, +# compression=hdf5_compression, +# compression_lvl=hdf5_compression_lvl, +# batch_size=None, +# out_dtype=hdf5_dtype, +# out_mutator=hdf5_mutator, +# ) +# +# self._process_file = prepare_file + + +DataFilePrepare = Callable[[str, str], NoReturn] +DataPrepare = Callable[[str], str] + + +def stalefile_prepare( + out_name, + out_hash: Union[str, Dict[str, str]], + hash_type: str = 'md5', + hash_mode: str = 'fast', +) -> Callable[[DataFilePrepare], DataPrepare]: + def wrapper(func: DataFilePrepare) -> DataPrepare: + def prepare(data_dir: str) -> str: + out_file = os.path.join(data_dir, out_name) + @stalefile(file=out_file, hash=out_hash, hash_type=hash_type, hash_mode=hash_mode) + def run(): + func(data_dir, out_file) + return run() + return prepare + return wrapper + + +# def data_object_downloader( +# uri: str, +# uri_hash: Optional[Union[str, Dict[str, str]]], +# uri_name: Optional[str] = None, +# hash_type: str = 'md5', +# hash_mode: str = 'fast', +# ) -> DataObject: +# +# uri_name = basename_from_url(uri) if (uri_name is None) else uri_name +# +# @prepare_stalefile(out_name=uri_name, out_hash=uri_hash, hash_type=hash_type, hash_mode=hash_mode) +# def download(out_dir: str, out_path: str): +# retrieve_file(src_uri=uri, dst_path=out_path, overwrite_existing=True) +# return out_path +# +# return download +# +# +# def data_object_download_h5( +# # save hash +# file_hash: Optional[Union[str, Dict[str, str]]], +# # download file/link +# uri: str, +# uri_hash: Optional[Union[str, Dict[str, str]]], +# # h5 re-save settings +# hdf5_dataset_name: str, +# hdf5_chunk_size: Tuple[int, ...], +# hdf5_compression: Optional[str] = 'gzip', +# hdf5_compression_lvl: Optional[int] = 4, +# hdf5_dtype: Optional[Union[np.dtype, str]] = None, +# hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None, +# # file names +# file_name: Optional[str] = None, +# uri_name: Optional[str] = None, +# # hash settings +# hash_type: str = 'md5', +# hash_mode: str = 'fast', +# ) -> DataObject: +# +# file_name = basename_from_url(uri) if (file_name is None) else file_name +# uri_name = f'dl.{basename_from_url(uri)}' if (uri_name is None) else uri_name +# +# @prepare_stalefile(out_name=uri_name, out_hash=uri_hash, hash_type=hash_type, hash_mode=hash_mode) +# def download_to(out_dir: str, out_path: str): +# retrieve_file(src_uri=uri, dst_path=out_path, overwrite_existing=True) +# +# @prepare_stalefile(out_name=file_name, out_hash=file_hash, hash_type=hash_type, hash_mode=hash_mode) +# def process_to(out_dir: str, out_path: str): +# inp_path = download_to(out_dir) +# hdf5_resave_file( +# inp_path=inp_path, +# out_path=out_path, +# dataset_name=hdf5_dataset_name, +# chunk_size=hdf5_chunk_size, +# compression=hdf5_compression, +# compression_lvl=hdf5_compression_lvl, +# out_dtype=hdf5_dtype, +# out_mutator=hdf5_mutator, +# ) +# +# return process_to + # ========================================================================= # diff --git a/disent/data/util/hdf5.py b/disent/data/util/hdf5.py index e761f5d4..6fbb4390 100644 --- a/disent/data/util/hdf5.py +++ b/disent/data/util/hdf5.py @@ -32,7 +32,7 @@ import numpy as np from tqdm import tqdm -from disent.data.util.in_out import AtomicFileContext +from disent.data.util.in_out import AtomicSaveFile from disent.data.util.in_out import bytes_to_human from disent.util import colors as c from disent.util import iter_chunks @@ -173,7 +173,7 @@ def hdf5_resave_dataset(inp_h5: h5py.File, out_h5: h5py.File, dataset_name, chun def hdf5_resave_file(inp_path: str, out_path: str, dataset_name, chunk_size=None, compression=None, compression_lvl=None, batch_size=None, out_dtype=None, out_mutator=None): # re-save datasets with h5py.File(inp_path, 'r') as inp_h5: - with AtomicFileContext(out_path, open_mode=None, overwrite=True) as tmp_h5_path: + with AtomicSaveFile(out_path, open_mode=None, overwrite=True) as tmp_h5_path: with h5py.File(tmp_h5_path, 'w', libver='earliest') as out_h5: # TODO: libver='latest' is not deterministic, even with track_times=False hdf5_resave_dataset( inp_h5=inp_h5, diff --git a/disent/data/util/in_out.py b/disent/data/util/in_out.py index ceb3d01b..1810fb05 100644 --- a/disent/data/util/in_out.py +++ b/disent/data/util/in_out.py @@ -25,10 +25,15 @@ import logging import math import os +from functools import wraps +from pathlib import Path +from typing import Callable from typing import Dict +from typing import NoReturn from typing import Optional from typing import Tuple from typing import Union +from uuid import uuid4 from disent.util import colors as c @@ -57,6 +62,18 @@ def bytes_to_human(size_bytes, decimals=3, color=True): return f"{s:{4+decimals}.{decimals}f} {name}" +def modify_file_name(file: Union[str, Path], prefix: str = None, suffix: str = None, sep='.') -> Union[str, Path]: + # get path components + path = Path(file) + assert path.name, f'file name cannot be empty: {repr(path)}, for name: {repr(path.name)}' + # create new path + prefix = '' if (prefix is None) else f'{prefix}{sep}' + suffix = '' if (suffix is None) else f'{sep}{suffix}' + new_path = path.parent.joinpath(f'{prefix}{path.name}{suffix}') + # return path + return str(new_path) if isinstance(file, str) else new_path + + # ========================================================================= # # file hashing # # ========================================================================= # @@ -88,15 +105,21 @@ def yield_fast_hash_bytes(file: str, chunk_size=16384, num_chunks=3): yield f.read(chunk_size) -def hash_file(file: str, hash_type='md5', hash_mode='full') -> str: +def hash_file(file: str, hash_type='md5', hash_mode='full', missing_ok=True) -> str: """ :param file: the path to the file :param hash_type: the kind of hash to compute, default is "md5" :param hash_mode: "full" uses all the bytes in the file to compute the hash, "fast" uses the start, middle, end bytes as well as the size of the file in the hash. :param chunk_size: number of bytes to read at a time :return: the hexdigest of the hash + :raises FileNotFoundError """ import hashlib + # check the file exists + if not os.path.isfile(file): + if missing_ok: + return '' + raise FileNotFoundError(f'could not compute hash for missing file: {repr(file)}') # get file bytes iterator if hash_mode == 'full': byte_iter = yield_file_bytes(file=file) @@ -119,14 +142,74 @@ class HashError(Exception): """ -def validate_file_hash(file: str, hash: Union[str, Dict[str, str]], hash_type='md5', hash_mode='full'): - if isinstance(hash, dict): - hash = hash[hash_mode] - fhash = hash_file(file=file, hash_type=hash_type, hash_mode=hash_mode) +def get_hash(hash: Union[str, Dict[str, str]], hash_mode: str) -> str: + return hash[hash_mode] if isinstance(hash, dict) else hash + + +def validate_file_hash(file: str, hash: Union[str, Dict[str, str]], hash_type: str = 'md5', hash_mode: str = 'full', missing_ok=True): + """ + :raises FileNotFoundError, HashError + """ + hash = get_hash(hash=hash, hash_mode=hash_mode) + # compute the hash + fhash = hash_file(file=file, hash_type=hash_type, hash_mode=hash_mode, missing_ok=missing_ok) + # check the hash if fhash != hash: - msg = f'computed {hash_mode} {hash_type} hash: {repr(fhash)} does not match expected hash: {repr(hash)} for file: {repr(file)}' - log.error(msg) - raise HashError(msg) + raise HashError(f'computed {hash_mode} {hash_type} hash: {repr(fhash)} does not match expected hash: {repr(hash)} for file: {repr(file)}') + + +def is_valid_file_hash(file: str, hash: Union[str, Dict[str, str]], hash_type: str = 'md5', hash_mode: str = 'full', missing_ok=True): + try: + validate_file_hash(file=file, hash=hash, hash_type=hash_type, hash_mode=hash_mode, missing_ok=missing_ok) + except HashError: + return False + return True + + +# ========================================================================= # +# Function Caching # +# ========================================================================= # + + +class stalefile(object): + + def __init__( + self, + file: str, + hash: Optional[Union[str, Dict[str, str]]], + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + self.file = file + self.hash = get_hash(hash=hash, hash_mode=hash_mode) + self.hash_type = hash_type + self.hash_mode = hash_mode + + def __call__(self, func: Callable[[str], NoReturn]) -> Callable[[], str]: + @wraps(func) + def wrapper() -> str: + if self.is_stale(): + log.debug(f'calling wrapped function: {func} because the file is stale: {repr(self.file)}') + func(self.file) + validate_file_hash(self.file, hash=self.hash, hash_type=self.hash_type, hash_mode=self.hash_mode) + else: + log.debug(f'skipped wrapped function: {func} because the file is fresh: {repr(self.file)}') + return self.file + return wrapper + + def is_stale(self): + fhash = hash_file(file=self.file, hash_type=self.hash_type, hash_mode=self.hash_mode, missing_ok=True) + if not fhash: + log.info(f'file is stale because it does not exist: {repr(self.file)}') + return True + if fhash != self.hash: + log.info(f'file is stale because the computed {self.hash_mode} {self.hash_type} hash: {fhash} does not match the target hash: {self.hash} for file: {repr(self.file)}') + return True + log.info(f'file is fresh: {repr(self.file)}') + return False + + def __bool__(self): + return self.is_stale() # ========================================================================= # @@ -134,11 +217,13 @@ def validate_file_hash(file: str, hash: Union[str, Dict[str, str]], hash_type='m # ========================================================================= # -class AtomicFileContext(object): +class AtomicSaveFile(object): """ Within the context, data must be written to a temporary file. Once data has been successfully written, the temporary file - is moved to the location of the given file. + is moved to the location of the target file. + + The temporary file is created in the same directory as the target file. ``` with AtomicFileHandler('file.txt') as tmp_file: @@ -155,20 +240,16 @@ def __init__( open_mode: Optional[str] = None, overwrite: bool = False, makedirs: bool = True, - tmp_file: Optional[str] = None, - tmp_prefix: str = '_TEMP_.', - tmp_postfix: str = '', + tmp_prefix: Optional[str] = '.temp.', + tmp_suffix: Optional[str] = None, ): from pathlib import Path # check files if not file: raise ValueError(f'file must not be empty: {repr(file)}') - if not tmp_file and (tmp_file is not None): - raise ValueError(f'tmp_file must not be empty: {repr(tmp_file)}') # get files self.trg_file = Path(file).absolute() - tmp_file = Path(self.trg_file if (tmp_file is None) else tmp_file) - self.tmp_file = tmp_file.parent.joinpath(f'{tmp_prefix}{tmp_file.name}{tmp_postfix}') + self.tmp_file = modify_file_name(self.trg_file, prefix=f'{tmp_prefix}{uuid4()}', suffix=tmp_suffix) # check that the files are different if self.trg_file == self.tmp_file: raise ValueError(f'temporary and target files are the same: {self.tmp_file} == {self.trg_file}') @@ -269,7 +350,7 @@ def download_file(url: str, save_path: str, overwrite_existing: bool = False, ch import requests from tqdm import tqdm # write the file - with AtomicFileContext(file=save_path, open_mode='wb', overwrite=overwrite_existing) as (_, file): + with AtomicSaveFile(file=save_path, open_mode='wb', overwrite=overwrite_existing) as (_, file): response = requests.get(url, stream=True) total_length = response.headers.get('content-length') # cast to integer if content-length exists on response @@ -288,12 +369,12 @@ def copy_file(src: str, dst: str, overwrite_existing: bool = False): if os.path.abspath(src) == os.path.abspath(dst): raise FileExistsError(f'input and output paths for copy are the same, skipping: {repr(dst)}') else: - with AtomicFileContext(file=dst, overwrite=overwrite_existing) as path: + with AtomicSaveFile(file=dst, overwrite=overwrite_existing) as path: import shutil shutil.copyfile(src, path) -def retrieve_file(src_uri: str, dst_path: str, overwrite_existing: bool = True): +def retrieve_file(src_uri: str, dst_path: str, overwrite_existing: bool = False): uri, is_url = _uri_parse_file_or_url(src_uri) if is_url: download_file(url=uri, save_path=dst_path, overwrite_existing=overwrite_existing) diff --git a/disent/data/util/jobs.py b/disent/data/util/jobs.py deleted file mode 100644 index aa41a6bf..00000000 --- a/disent/data/util/jobs.py +++ /dev/null @@ -1,163 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -import os -from abc import ABCMeta -from typing import Callable -from typing import Dict -from typing import NoReturn -from typing import Optional -from typing import Union - -from disent.data.util.in_out import hash_file - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Base Job # -# ========================================================================= # - - -class CachedJob(object): - """ - Base class for cached jobs. A job is some arbitrary directed chains - of computations where child jobs depend on parent jobs, and jobs - can be skipped if it has already been run and the cache is valid. - - Jobs are always deterministic, and if run and cached should never go out of date. - - NOTE: if needed it would be easy to add support directed acyclic graphs, and sub-graphs - NOTE: this is probably overkill, but it makes the code to write a new dataset nice and clean... - """ - - def __init__(self, job_fn: Callable[[], NoReturn], is_cached_fn: Callable[[], bool]): - self._parent = None - self._child = None - self._job_fn = job_fn - self._is_cached_fn = is_cached_fn - - def __repr__(self): - return f'{self.__class__.__name__}' - - def set_parent(self, parent: 'CachedJob'): - if not isinstance(parent, CachedJob): - raise TypeError(f'{self}: parent job was not an instance of: {CachedJob.__class__}') - if self._parent is not None: - raise RuntimeError(f'{self}: parent has already been set') - if parent._child is not None: - raise RuntimeError(f'{parent}: child has already been set') - self._parent = parent - parent._child = self - return parent - - def set_child(self, child: 'CachedJob'): - child.set_parent(self) - return child - - def run(self, force=False, recursive=False) -> 'CachedJob': - # visit parents always - if recursive: - if self._parent is not None: - self._parent.run(force=force, recursive=recursive) - # skip if fresh - if not force: - if not self._is_cached_fn(): - log.debug(f'{self}: skipped non-stale job') - return self - # don't visit parents if fresh - if not recursive: - if self._parent is not None: - self._parent.run(force=force, recursive=recursive) - # run nodes - log.debug(f'{self}: run stale job') - self._job_fn() - return self - - -# ========================================================================= # -# Base File Job # -# ========================================================================= # - - -class CachedJobFile(CachedJob): - - """ - An abstract cached job that only runs if a file does not exist, - or the files hash sum does not match the given value. - """ - - def __init__( - self, - make_file_fn: Callable[[str], NoReturn], - path: str, - hash: Union[str, Dict[str, str]], - hash_type: str = 'md5', - hash_mode: str = 'full', - ): - # set attributes - self.path = path - self.hash: Optional[str] = hash if ((hash is None) or isinstance(hash, str)) else hash[hash_mode] - self.hash_type = hash_type - self.hash_mode = hash_mode - # generate - self._make_file_fn = make_file_fn - # check hash - super().__init__(job_fn=self.__job_fn, is_cached_fn=self.__is_cached_fn) - - def __compute_hash(self) -> str: - return hash_file(self.path, hash_type=self.hash_type, hash_mode=self.hash_mode) - - def __is_cached_fn(self) -> bool: - # stale if the file does not exist - if not os.path.exists(self.path): - log.warning(f'{self}: stale because file does not exist: {repr(self.path)}') - return True - # stale if the hash does not match - fhash = self.__compute_hash() - if self.hash is None: - log.warning(f'{self}: not stale because it exists and no target hash was given. current {self.hash_mode} {self.hash_type} hash is: {fhash} for: {repr(self.path)}') - elif self.hash != fhash: - log.warning(f'{self}: stale because computed {self.hash_mode} {self.hash_type} hash: {repr(fhash)} does not match expected hash: {repr(self.hash)} for: {repr(self.path)}') - return True - # not stale, we don't need to do anything! - return False - - def __job_fn(self): - self._make_file_fn(self.path) - # check the hash - fhash = self.__compute_hash() - if self.hash is None: - log.warning(f'{self}: could not verify generated file because no target hash was given. current {self.hash_mode} {self.hash_type} hash is: {fhash} for: {repr(self.path)}') - elif self.hash != fhash: - raise RuntimeError(f'{self}: error because computed {self.hash_mode} {self.hash_type} hash: {repr(fhash)} does not match expected hash: {repr(self.hash)} for: {repr(self.path)}') - else: - log.debug(f'{self}: successfully generated file: {repr(self.path)} with correct {self.hash_mode} {self.hash_type} hash: {fhash}') - - -# ========================================================================= # -# END # -# ========================================================================= # From cade4f3aae4331d0a75927639b5ccb0108cc2d7e Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Thu, 3 Jun 2021 13:32:24 +0200 Subject: [PATCH 21/34] cleaned up and commented base GroundTruthData --- disent/data/groundtruth/base.py | 293 +++++++------------------------- disent/util/__init__.py | 3 +- 2 files changed, 64 insertions(+), 232 deletions(-) diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index b878ec3a..e46e2eb7 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -27,6 +27,7 @@ from abc import ABCMeta from typing import Callable from typing import Dict +from typing import final from typing import NoReturn from typing import Optional from typing import Sequence @@ -56,6 +57,9 @@ class GroundTruthData(StateSpace): + """ + Dataset that corresponds to some state space or ground truth factors + """ def __init__(self): assert len(self.factor_names) == len(self.factor_sizes), 'Dimensionality mismatch of FACTOR_NAMES and FACTOR_DIMS' @@ -99,11 +103,18 @@ def __getitem__(self, idx): # ========================================================================= # # disk ground truth data # +# TODO: data & data_object preparation should be split out from # +# GroundTruthData, instead GroundTruthData should be a wrapper # # ========================================================================= # class DiskGroundTruthData(GroundTruthData, metaclass=ABCMeta): + """ + Dataset that prepares a list DataObjects into some local directory. + - This directory can be + """ + def __init__(self, data_root: Optional[str] = None, prepare: bool = False): super().__init__() # get root data folder @@ -133,6 +144,10 @@ def data_objects(self) -> Sequence['DataObject']: class NumpyGroundTruthData(DiskGroundTruthData, metaclass=ABCMeta): + """ + Dataset that loads a numpy file from a DataObject + - if the dataset is contained in a key, set the `data_key` property + """ def __init__(self, data_root: Optional[str] = None, prepare: bool = False): super().__init__(data_root=data_root, prepare=prepare) @@ -159,6 +174,11 @@ def data_key(self) -> Optional[str]: class Hdf5GroundTruthData(DiskGroundTruthData, metaclass=ABCMeta): + """ + Dataset that loads an Hdf5 file from a DataObject + - requires that the data object has the `out_dataset_name` attribute + that points to the hdf5 dataset in the file to load. + """ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_memory=False): super().__init__(data_root=data_root, prepare=prepare) @@ -194,11 +214,35 @@ def data_object(self) -> 'DlH5DataObject': # ========================================================================= # # data objects # -# TODO: clean this up, this could be simplified! # # ========================================================================= # -class DataObject(object): +class DataObject(object, metaclass=ABCMeta): + """ + base DataObject that does nothing, if the file does + not exist or it has the incorrect hash, then that's your problem! + """ + + def __init__(self, file_name: str): + self._file_name = file_name + + @final + @property + def out_name(self) -> str: + return self._file_name + + def prepare(self, out_dir: str) -> str: + # TODO: maybe check that the file exists or not and raise a FileNotFoundError? + pass + + +class HashedDataObject(DataObject, metaclass=ABCMeta): + """ + Abstract Class + - Base DataObject class that guarantees a file to exist, + if the file does not exist, or the hash of the file is + incorrect, then the file is re-generated. + """ def __init__( self, @@ -207,15 +251,11 @@ def __init__( hash_type: str = 'md5', hash_mode: str = 'fast', ): - self._file_name = file_name + super().__init__(file_name=file_name) self._file_hash = file_hash self._hash_type = hash_type self._hash_mode = hash_mode - @property - def out_name(self) -> str: - return self._file_name - def prepare(self, out_dir: str) -> str: @stalefile(file=os.path.join(out_dir, self._file_name), hash=self._file_hash, hash_type=self._hash_type, hash_mode=self._hash_mode) def wrapped(out_file): @@ -223,10 +263,16 @@ def wrapped(out_file): return wrapped() def _prepare(self, out_dir: str, out_file: str) -> str: + # TODO: maybe raise a FileNotFoundError or a HashError instead? raise NotImplementedError -class DlDataObject(DataObject): +class DlDataObject(HashedDataObject): + """ + Download a file + - uri can also be a file to perform a copy instead of download, + useful for example if you want to retrieve a file from a network drive. + """ def __init__( self, @@ -248,7 +294,11 @@ def _prepare(self, out_dir: str, out_file: str): retrieve_file(src_uri=self._uri, dst_path=out_file, overwrite_existing=True) -class DlGenDataObject(DataObject): +class DlGenDataObject(HashedDataObject, metaclass=ABCMeta): + """ + Abstract class + - download a file and perform some processing on that file. + """ def __init__( self, @@ -286,6 +336,9 @@ def _generate(self, inp_file: str, out_file: str): class DlH5DataObject(DlGenDataObject): + """ + Downloads an hdf5 file and pre-processes it into the specified chunk_size. + """ def __init__( self, @@ -336,229 +389,7 @@ def _generate(self, inp_file: str, out_file: str): self._hdf5_resave_file(inp_path=inp_file, out_path=out_file) -# class DataObject(object): -# -# @property -# def file_name(self) -> str: -# raise NotImplementedError -# -# def prepare(self, data_dir: str): -# pass -# -# def get_path(self, data_dir, *attrs): -# paths = [os.path.join(data_dir, getattr(self, attr)) for attr in attrs] -# if len(paths) == 1: -# return paths[0] -# return paths -# -# -# class DlDataObject(DataObject): -# -# def __init__( -# self, -# # download file/link -# uri: str, -# uri_hash: Optional[Union[str, Dict[str, str]]], -# uri_name: Optional[str] = None, # automatically obtain uri name from url if None -# # hash settings -# hash_type: str = 'md5', -# hash_mode: str = 'fast', -# ): -# def _prepare(data_dir: str): -# dl_path = os.path.join(data_dir, self._uri_name) -# # cached download task -# @stalefile(file=dl_path, hash=uri_hash, hash_type=hash_type, hash_mode=hash_mode) -# def download(): -# retrieve_file(src_uri=uri, dst_path=dl_path, overwrite_existing=True) -# # run task -# return download() -# -# # instance variables -# self._uri_name = basename_from_url(uri) if (uri_name is None) else uri_name -# self._out_name = self._uri_name -# self._prepare = _prepare -# -# @property -# def file_name(self) -> str: -# return self._out_name -# -# def prepare(self, data_dir: str): -# return self._prepare(data_dir) -# -# -# class DlGenDataObject(DlDataObject): -# -# def __init__( -# self, -# # save path -# file_name: str, -# file_hash: Optional[Union[str, Dict[str, str]]], -# # download file/link -# uri: str, -# uri_hash: Optional[Union[str, Dict[str, str]]], -# uri_name: Optional[str] = None, # automatically obtain file name from url if None -# # hash settings -# hash_type: str = 'md5', -# hash_mode: str = 'fast', -# ): -# super().__init__(uri=uri, uri_hash=uri_hash, uri_name=uri_name, hash_mode=hash_mode, hash_type=hash_type) -# -# def _prepare(data_dir: str): -# proc_path = os.path.join(data_dir, file_name) -# # cached process task -# @stalefile(file=proc_path, hash=file_hash, hash_type=hash_type, hash_mode=hash_mode) -# def process(): -# self._process_file(inp_file=self._prepare_dl(data_dir), out_file=proc_path) -# # run task -# return process() -# -# self._prepare, self._prepare_dl = _prepare, self._prepare -# self._out_name = file_name -# -# def _process_file(self, inp_file: str, out_file: str): -# raise NotImplementedError -# -# -# class DlH5DataObject(DlGenDataObject): -# -# def __init__( -# self, -# # download file/link -# uri: str, -# uri_hash: Optional[Union[str, Dict[str, str]]], -# # save hash -# file_hash: Optional[Union[str, Dict[str, str]]], -# # h5 re-save settings -# hdf5_dataset_name: str, -# hdf5_chunk_size: Tuple[int, ...], -# hdf5_compression: Optional[str] = 'gzip', -# hdf5_compression_lvl: Optional[int] = 4, -# hdf5_dtype: Optional[Union[np.dtype, str]] = None, -# hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None, -# # save path -# file_name: Optional[str] = None, # automatically obtain file name from url if None -# # hash settings -# hash_type: str = 'md5', -# hash_mode: str = 'fast', -# ): -# file_name = basename_from_url(uri) if (file_name is None) else file_name -# uri_name = f'dl.{file_name}' -# self.hdf5_dataset_name = hdf5_dataset_name -# -# super().__init__( -# file_name=file_name, -# file_hash=file_hash, -# uri=uri, -# uri_hash=uri_hash, -# uri_name=uri_name, -# hash_type=hash_type, -# hash_mode=hash_mode -# ) -# -# def prepare_file(inp_file: str, out_file: str): -# hdf5_resave_file( -# inp_path=inp_file, -# out_path=out_file, -# dataset_name=hdf5_dataset_name, -# chunk_size=hdf5_chunk_size, -# compression=hdf5_compression, -# compression_lvl=hdf5_compression_lvl, -# batch_size=None, -# out_dtype=hdf5_dtype, -# out_mutator=hdf5_mutator, -# ) -# -# self._process_file = prepare_file - - -DataFilePrepare = Callable[[str, str], NoReturn] -DataPrepare = Callable[[str], str] - - -def stalefile_prepare( - out_name, - out_hash: Union[str, Dict[str, str]], - hash_type: str = 'md5', - hash_mode: str = 'fast', -) -> Callable[[DataFilePrepare], DataPrepare]: - def wrapper(func: DataFilePrepare) -> DataPrepare: - def prepare(data_dir: str) -> str: - out_file = os.path.join(data_dir, out_name) - @stalefile(file=out_file, hash=out_hash, hash_type=hash_type, hash_mode=hash_mode) - def run(): - func(data_dir, out_file) - return run() - return prepare - return wrapper - - -# def data_object_downloader( -# uri: str, -# uri_hash: Optional[Union[str, Dict[str, str]]], -# uri_name: Optional[str] = None, -# hash_type: str = 'md5', -# hash_mode: str = 'fast', -# ) -> DataObject: -# -# uri_name = basename_from_url(uri) if (uri_name is None) else uri_name -# -# @prepare_stalefile(out_name=uri_name, out_hash=uri_hash, hash_type=hash_type, hash_mode=hash_mode) -# def download(out_dir: str, out_path: str): -# retrieve_file(src_uri=uri, dst_path=out_path, overwrite_existing=True) -# return out_path -# -# return download -# -# -# def data_object_download_h5( -# # save hash -# file_hash: Optional[Union[str, Dict[str, str]]], -# # download file/link -# uri: str, -# uri_hash: Optional[Union[str, Dict[str, str]]], -# # h5 re-save settings -# hdf5_dataset_name: str, -# hdf5_chunk_size: Tuple[int, ...], -# hdf5_compression: Optional[str] = 'gzip', -# hdf5_compression_lvl: Optional[int] = 4, -# hdf5_dtype: Optional[Union[np.dtype, str]] = None, -# hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None, -# # file names -# file_name: Optional[str] = None, -# uri_name: Optional[str] = None, -# # hash settings -# hash_type: str = 'md5', -# hash_mode: str = 'fast', -# ) -> DataObject: -# -# file_name = basename_from_url(uri) if (file_name is None) else file_name -# uri_name = f'dl.{basename_from_url(uri)}' if (uri_name is None) else uri_name -# -# @prepare_stalefile(out_name=uri_name, out_hash=uri_hash, hash_type=hash_type, hash_mode=hash_mode) -# def download_to(out_dir: str, out_path: str): -# retrieve_file(src_uri=uri, dst_path=out_path, overwrite_existing=True) -# -# @prepare_stalefile(out_name=file_name, out_hash=file_hash, hash_type=hash_type, hash_mode=hash_mode) -# def process_to(out_dir: str, out_path: str): -# inp_path = download_to(out_dir) -# hdf5_resave_file( -# inp_path=inp_path, -# out_path=out_path, -# dataset_name=hdf5_dataset_name, -# chunk_size=hdf5_chunk_size, -# compression=hdf5_compression, -# compression_lvl=hdf5_compression_lvl, -# out_dtype=hdf5_dtype, -# out_mutator=hdf5_mutator, -# ) -# -# return process_to - - - # ========================================================================= # # END # # ========================================================================= # - - diff --git a/disent/util/__init__.py b/disent/util/__init__.py index 857b98d3..c9c659d8 100644 --- a/disent/util/__init__.py +++ b/disent/util/__init__.py @@ -27,6 +27,7 @@ import os import time from collections import Sequence +from contextlib import ContextDecorator from itertools import islice from typing import List @@ -283,7 +284,7 @@ def __getitem__(self, item): # ========================================================================= # -class Timer: +class Timer(ContextDecorator): """ Timer class, can be used with a with statement to From 85e0d62b078e54b3a44a81916c79750d280d2d72 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Thu, 3 Jun 2021 14:15:22 +0200 Subject: [PATCH 22/34] moved reductions --- disent/frameworks/helper/latent_distributions.py | 2 +- disent/frameworks/helper/reconstructions.py | 4 ++-- disent/frameworks/vae/_unsupervised__dfcvae.py | 2 +- disent/nn/{reductions.py => loss/reduction.py} | 0 experiment/exp/util/_loss.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) rename disent/nn/{reductions.py => loss/reduction.py} (100%) diff --git a/disent/frameworks/helper/latent_distributions.py b/disent/frameworks/helper/latent_distributions.py index a5f3b2a6..ad7ac0a9 100644 --- a/disent/frameworks/helper/latent_distributions.py +++ b/disent/frameworks/helper/latent_distributions.py @@ -34,7 +34,7 @@ from disent.frameworks.helper.util import compute_ave_loss from disent.nn.loss.kl import kl_loss -from disent.nn.reductions import loss_reduction +from disent.nn.loss.reduction import loss_reduction # ========================================================================= # diff --git a/disent/frameworks/helper/reconstructions.py b/disent/frameworks/helper/reconstructions.py index 23cb9ac2..4aca1138 100644 --- a/disent/frameworks/helper/reconstructions.py +++ b/disent/frameworks/helper/reconstructions.py @@ -34,8 +34,8 @@ from disent.frameworks.helper.util import compute_ave_loss from disent.nn.modules import DisentModule -from disent.nn.reductions import batch_loss_reduction -from disent.nn.reductions import loss_reduction +from disent.nn.loss.reduction import batch_loss_reduction +from disent.nn.loss.reduction import loss_reduction from disent.nn.transform import FftKernel diff --git a/disent/frameworks/vae/_unsupervised__dfcvae.py b/disent/frameworks/vae/_unsupervised__dfcvae.py index f9f480f2..6da2cfa0 100644 --- a/disent/frameworks/vae/_unsupervised__dfcvae.py +++ b/disent/frameworks/vae/_unsupervised__dfcvae.py @@ -36,7 +36,7 @@ from torchvision.models import vgg19_bn from torch.nn import functional as F -from disent.nn.reductions import get_mean_loss_scale +from disent.nn.loss.reduction import get_mean_loss_scale from disent.frameworks.helper.util import compute_ave_loss from disent.frameworks.vae._unsupervised__betavae import BetaVae from disent.nn.transform.functional import check_tensor diff --git a/disent/nn/reductions.py b/disent/nn/loss/reduction.py similarity index 100% rename from disent/nn/reductions.py rename to disent/nn/loss/reduction.py diff --git a/experiment/exp/util/_loss.py b/experiment/exp/util/_loss.py index 0e54ba3e..73ddfff8 100644 --- a/experiment/exp/util/_loss.py +++ b/experiment/exp/util/_loss.py @@ -26,7 +26,7 @@ import torch_optimizer from torch.nn import functional as F -from disent.nn.reductions import batch_loss_reduction +from disent.nn.loss.reduction import batch_loss_reduction # ========================================================================= # From 6a4a935517b882719414bba298a87686a8935c19 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Thu, 3 Jun 2021 14:16:52 +0200 Subject: [PATCH 23/34] moved softsort loss functions --- disent/frameworks/vae/experimental/_unsupervised__dorvae.py | 5 ++--- disent/{util/math_loss.py => nn/loss/softsort.py} | 0 experiment/exp/05_adversarial_data/run_01_sort_loss.py | 4 ++-- .../05_adversarial_data/run_03_train_disentangle_kernel.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) rename disent/{util/math_loss.py => nn/loss/softsort.py} (100%) diff --git a/disent/frameworks/vae/experimental/_unsupervised__dorvae.py b/disent/frameworks/vae/experimental/_unsupervised__dorvae.py index dca64f30..0be6b6e2 100644 --- a/disent/frameworks/vae/experimental/_unsupervised__dorvae.py +++ b/disent/frameworks/vae/experimental/_unsupervised__dorvae.py @@ -33,10 +33,9 @@ from disent.frameworks.helper.reconstructions import make_reconstruction_loss from disent.frameworks.helper.reconstructions import ReconLossHandler from disent.frameworks.vae._supervised__tvae import TripletVae -from disent.frameworks.vae._unsupervised__betavae import BetaVae from disent.frameworks.vae._weaklysupervised__adavae import AdaVae -from disent.util.math_loss import torch_mse_rank_loss -from disent.util.math_loss import spearman_rank_loss +from disent.nn.loss.softsort import torch_mse_rank_loss +from disent.nn.loss.softsort import spearman_rank_loss from experiment.util.hydra_utils import instantiate_recursive diff --git a/disent/util/math_loss.py b/disent/nn/loss/softsort.py similarity index 100% rename from disent/util/math_loss.py rename to disent/nn/loss/softsort.py diff --git a/experiment/exp/05_adversarial_data/run_01_sort_loss.py b/experiment/exp/05_adversarial_data/run_01_sort_loss.py index efe668bc..ff36af1a 100644 --- a/experiment/exp/05_adversarial_data/run_01_sort_loss.py +++ b/experiment/exp/05_adversarial_data/run_01_sort_loss.py @@ -27,8 +27,8 @@ from torch.utils.data import DataLoader import experiment.exp.util as H -from disent.util.math_loss import multi_spearman_rank_loss -from disent.util.math_loss import torch_soft_rank +from disent.nn.loss.softsort import multi_spearman_rank_loss +from disent.nn.loss.softsort import torch_soft_rank # ========================================================================= # diff --git a/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py b/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py index 3f4c6e7a..3144caa4 100644 --- a/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py +++ b/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py @@ -43,7 +43,7 @@ from disent.util import make_box_str from disent.util import seed from disent.util.math import torch_conv2d_channel_wise_fft -from disent.util.math_loss import spearman_rank_loss +from disent.nn.loss.softsort import spearman_rank_loss from experiment.run import hydra_append_progress_callback from experiment.run import hydra_check_cuda from experiment.run import hydra_make_logger From f375da558aef964812fe446a890151a5507a3aba Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Thu, 3 Jun 2021 14:17:26 +0200 Subject: [PATCH 24/34] move differentiable torch math functions --- .../frameworks/vae/_unsupervised__dipvae.py | 2 +- disent/metrics/_flatness_components.py | 4 ++-- .../math.py => nn/functional/__init__.py} | 11 +++++----- .../functional/_generic_tensors.py} | 0 disent/nn/transform/_augment.py | 6 +++--- .../run_02_check_aug_gt_dists.py | 6 +++--- .../run_03_train_disentangle_kernel.py | 2 +- tests/test_math.py | 20 +++++++++---------- tests/test_math_generic.py | 10 +++++----- tests/test_transform.py | 4 ++-- 10 files changed, 33 insertions(+), 32 deletions(-) rename disent/{util/math.py => nn/functional/__init__.py} (98%) rename disent/{util/math_generic.py => nn/functional/_generic_tensors.py} (100%) diff --git a/disent/frameworks/vae/_unsupervised__dipvae.py b/disent/frameworks/vae/_unsupervised__dipvae.py index 6da93fc4..f58c9a86 100644 --- a/disent/frameworks/vae/_unsupervised__dipvae.py +++ b/disent/frameworks/vae/_unsupervised__dipvae.py @@ -30,7 +30,7 @@ from disent.frameworks.helper.util import compute_ave_loss_and_logs from disent.frameworks.vae._unsupervised__betavae import BetaVae -from disent.util.math import torch_cov_matrix +from disent.nn.functional import torch_cov_matrix # ========================================================================= # diff --git a/disent/metrics/_flatness_components.py b/disent/metrics/_flatness_components.py index eddf327c..bbbfb573 100644 --- a/disent/metrics/_flatness_components.py +++ b/disent/metrics/_flatness_components.py @@ -35,8 +35,8 @@ from disent.metrics._flatness import filter_inactive_factors from disent.util import iter_chunks from disent.util import to_numpy -from disent.util.math import torch_mean_generalized -from disent.util.math import torch_pca +from disent.nn.functional import torch_mean_generalized +from disent.nn.functional import torch_pca log = logging.getLogger(__name__) diff --git a/disent/util/math.py b/disent/nn/functional/__init__.py similarity index 98% rename from disent/util/math.py rename to disent/nn/functional/__init__.py index 8647852a..be20449a 100644 --- a/disent/util/math.py +++ b/disent/nn/functional/__init__.py @@ -21,19 +21,20 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging import warnings from typing import List from typing import Optional from typing import Union -import logging import numpy as np import torch -from disent.util.math_generic import generic_as_int32 -from disent.util.math_generic import generic_max -from disent.util.math_generic import TypeGenericTensor -from disent.util.math_generic import TypeGenericTorch +from disent.nn.functional._generic_tensors import generic_as_int32 +from disent.nn.functional._generic_tensors import generic_max +from disent.nn.functional._generic_tensors import TypeGenericTensor +from disent.nn.functional._generic_tensors import TypeGenericTorch log = logging.getLogger(__name__) diff --git a/disent/util/math_generic.py b/disent/nn/functional/_generic_tensors.py similarity index 100% rename from disent/util/math_generic.py rename to disent/nn/functional/_generic_tensors.py diff --git a/disent/nn/transform/_augment.py b/disent/nn/transform/_augment.py index 0f4eb8ac..bc0cf7cf 100644 --- a/disent/nn/transform/_augment.py +++ b/disent/nn/transform/_augment.py @@ -34,9 +34,9 @@ import disent from disent.nn.modules import DisentModule -from disent.util.math import torch_box_kernel_2d -from disent.util.math import torch_conv2d_channel_wise_fft -from disent.util.math import torch_gaussian_kernel_2d +from disent.nn.functional import torch_box_kernel_2d +from disent.nn.functional import torch_conv2d_channel_wise_fft +from disent.nn.functional import torch_gaussian_kernel_2d # ========================================================================= # diff --git a/experiment/exp/05_adversarial_data/run_02_check_aug_gt_dists.py b/experiment/exp/05_adversarial_data/run_02_check_aug_gt_dists.py index 2b588bc5..a1f85efd 100644 --- a/experiment/exp/05_adversarial_data/run_02_check_aug_gt_dists.py +++ b/experiment/exp/05_adversarial_data/run_02_check_aug_gt_dists.py @@ -30,9 +30,9 @@ import torch.nn.functional as F import experiment.exp.util as H -from disent.util.math import torch_conv2d_channel_wise_fft -from disent.util.math import torch_box_kernel_2d -from disent.util.math import torch_gaussian_kernel_2d +from disent.nn.functional import torch_conv2d_channel_wise_fft +from disent.nn.functional import torch_box_kernel_2d +from disent.nn.functional import torch_gaussian_kernel_2d # ========================================================================= # diff --git a/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py b/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py index 3144caa4..3bc387b6 100644 --- a/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py +++ b/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py @@ -42,7 +42,7 @@ from disent.nn.modules import DisentModule from disent.util import make_box_str from disent.util import seed -from disent.util.math import torch_conv2d_channel_wise_fft +from disent.nn.functional import torch_conv2d_channel_wise_fft from disent.nn.loss.softsort import spearman_rank_loss from experiment.run import hydra_append_progress_callback from experiment.run import hydra_check_cuda diff --git a/tests/test_math.py b/tests/test_math.py index 12835511..3dc4959d 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -30,18 +30,18 @@ from disent.data.groundtruth import XYSquaresData from disent.dataset.groundtruth import GroundTruthDataset +from disent.nn.functional import torch_conv2d_channel_wise +from disent.nn.functional import torch_conv2d_channel_wise_fft +from disent.nn.functional import torch_corr_matrix +from disent.nn.functional import torch_cov_matrix +from disent.nn.functional import torch_dct +from disent.nn.functional import torch_dct2 +from disent.nn.functional import torch_gaussian_kernel_2d +from disent.nn.functional import torch_idct +from disent.nn.functional import torch_idct2 +from disent.nn.functional import torch_mean_generalized from disent.nn.transform import ToStandardisedTensor -from disent.util.math import torch_conv2d_channel_wise -from disent.util.math import torch_conv2d_channel_wise_fft from disent.util import to_numpy -from disent.util.math import torch_dct -from disent.util.math import torch_dct2 -from disent.util.math import torch_gaussian_kernel_2d -from disent.util.math import torch_idct -from disent.util.math import torch_idct2 -from disent.util.math import torch_corr_matrix -from disent.util.math import torch_cov_matrix -from disent.util.math import torch_mean_generalized # ========================================================================= # diff --git a/tests/test_math_generic.py b/tests/test_math_generic.py index 3c064655..9f692b05 100644 --- a/tests/test_math_generic.py +++ b/tests/test_math_generic.py @@ -26,11 +26,11 @@ import pytest import torch -from disent.util.math_generic import generic_as_int32 -from disent.util.math_generic import generic_max -from disent.util.math_generic import generic_min -from disent.util.math_generic import generic_ndim -from disent.util.math_generic import generic_shape +from disent.nn.functional._generic_tensors import generic_as_int32 +from disent.nn.functional._generic_tensors import generic_max +from disent.nn.functional._generic_tensors import generic_min +from disent.nn.functional._generic_tensors import generic_ndim +from disent.nn.functional._generic_tensors import generic_shape # ========================================================================= # diff --git a/tests/test_transform.py b/tests/test_transform.py index 8d91657a..b6a4ced2 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -26,8 +26,8 @@ from disent.nn.transform import FftGaussianBlur from disent.nn.transform._augment import _expand_to_min_max_tuples -from disent.util.math import torch_gaussian_kernel -from disent.util.math import torch_gaussian_kernel_2d +from disent.nn.functional import torch_gaussian_kernel +from disent.nn.functional import torch_gaussian_kernel_2d # ========================================================================= # From 71f509afb6aa1e39d9cecc11cb56ca305f431ca3 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Thu, 3 Jun 2021 14:51:20 +0200 Subject: [PATCH 25/34] split out utilities into submodules --- disent/data/episodes/_base.py | 2 +- disent/data/groundtruth/_xysquares.py | 2 +- disent/data/groundtruth/base.py | 2 +- disent/data/util/hdf5.py | 6 +- disent/data/util/state_space.py | 2 +- disent/dataset/_base.py | 2 +- disent/frameworks/ae/_unsupervised__ae.py | 2 +- disent/frameworks/helper/util.py | 6 +- disent/frameworks/vae/_unsupervised__vae.py | 2 +- disent/metrics/__init__.py | 2 +- disent/metrics/_flatness.py | 2 +- disent/metrics/_flatness_components.py | 2 +- disent/util/__init__.py | 346 ------------------ disent/util/function.py | 44 +++ disent/util/iters.py | 128 +++++++ disent/util/profiling.py | 162 ++++++++ disent/util/seeds.py | 83 +++++ disent/util/strings.py | 93 +++++ experiment/exp/00_data_traversal/run.py | 2 +- .../run_03_train_disentangle_kernel.py | 7 +- .../run_04_gen_adversarial.py | 6 +- experiment/exp/06_metric/make_graphs.py | 2 +- experiment/exp/util/_dataset.py | 2 +- experiment/run.py | 2 +- experiment/run_dataset_visualiser.py | 2 +- experiment/util/callbacks/callbacks_vae.py | 6 +- 26 files changed, 541 insertions(+), 376 deletions(-) create mode 100644 disent/util/function.py create mode 100644 disent/util/iters.py create mode 100644 disent/util/profiling.py create mode 100644 disent/util/seeds.py create mode 100644 disent/util/strings.py diff --git a/disent/data/episodes/_base.py b/disent/data/episodes/_base.py index dbc4ba6a..9e9e7958 100644 --- a/disent/data/episodes/_base.py +++ b/disent/data/episodes/_base.py @@ -26,7 +26,7 @@ import numpy as np from disent.dataset.groundtruth._triplet import sample_radius -from disent.util import LengthIter +from disent.util.iters import LengthIter class BaseOptionEpisodesData(LengthIter): diff --git a/disent/data/groundtruth/_xysquares.py b/disent/data/groundtruth/_xysquares.py index f59391bf..eec48852 100644 --- a/disent/data/groundtruth/_xysquares.py +++ b/disent/data/groundtruth/_xysquares.py @@ -30,7 +30,7 @@ import numpy as np from disent.data.groundtruth.base import GroundTruthData -from disent.util import iter_chunks +from disent.util.iters import iter_chunks log = logging.getLogger(__name__) diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index e46e2eb7..296a833c 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -45,7 +45,7 @@ from disent.data.util.in_out import ensure_dir_exists from disent.data.util.in_out import retrieve_file from disent.data.util.state_space import StateSpace -from disent.util import wrapped_partial +from disent.util.function import wrapped_partial log = logging.getLogger(__name__) diff --git a/disent/data/util/hdf5.py b/disent/data/util/hdf5.py index 6fbb4390..e68ccbb9 100644 --- a/disent/data/util/hdf5.py +++ b/disent/data/util/hdf5.py @@ -35,9 +35,9 @@ from disent.data.util.in_out import AtomicSaveFile from disent.data.util.in_out import bytes_to_human from disent.util import colors as c -from disent.util import iter_chunks -from disent.util import LengthIter -from disent.util import Timer +from disent.util.iters import iter_chunks +from disent.util.iters import LengthIter +from disent.util.profiling import Timer log = logging.getLogger(__name__) diff --git a/disent/data/util/state_space.py b/disent/data/util/state_space.py index 813c9f51..d4cb6931 100644 --- a/disent/data/util/state_space.py +++ b/disent/data/util/state_space.py @@ -23,7 +23,7 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import numpy as np -from disent.util import LengthIter +from disent.util.iters import LengthIter from disent.visualize.visualize_util import get_factor_traversal diff --git a/disent/dataset/_base.py b/disent/dataset/_base.py index 869071aa..c5ac1703 100644 --- a/disent/dataset/_base.py +++ b/disent/dataset/_base.py @@ -30,7 +30,7 @@ from torch.utils.data import Dataset from torch.utils.data.dataloader import default_collate -from disent.util import LengthIter +from disent.util.iters import LengthIter # ========================================================================= # diff --git a/disent/frameworks/ae/_unsupervised__ae.py b/disent/frameworks/ae/_unsupervised__ae.py index 69a17615..7f74baa8 100644 --- a/disent/frameworks/ae/_unsupervised__ae.py +++ b/disent/frameworks/ae/_unsupervised__ae.py @@ -40,7 +40,7 @@ from disent.frameworks.helper.reconstructions import ReconLossHandler from disent.frameworks.helper.util import detach_all from disent.model import AutoEncoder -from disent.util import map_all +from disent.util.iters import map_all log = logging.getLogger(__name__) diff --git a/disent/frameworks/helper/util.py b/disent/frameworks/helper/util.py index e842b478..0a1f6058 100644 --- a/disent/frameworks/helper/util.py +++ b/disent/frameworks/helper/util.py @@ -29,9 +29,9 @@ import torch -from disent.util import aggregate_dict -from disent.util import collect_dicts -from disent.util import map_all +from disent.util.iters import aggregate_dict +from disent.util.iters import collect_dicts +from disent.util.iters import map_all # ========================================================================= # diff --git a/disent/frameworks/vae/_unsupervised__vae.py b/disent/frameworks/vae/_unsupervised__vae.py index e366c9e1..964c376d 100644 --- a/disent/frameworks/vae/_unsupervised__vae.py +++ b/disent/frameworks/vae/_unsupervised__vae.py @@ -42,7 +42,7 @@ from disent.frameworks.helper.latent_distributions import make_latent_distribution from disent.frameworks.helper.util import detach_all -from disent.util import map_all +from disent.util.iters import map_all # ========================================================================= # diff --git a/disent/metrics/__init__.py b/disent/metrics/__init__.py index f0aa9449..99321195 100644 --- a/disent/metrics/__init__.py +++ b/disent/metrics/__init__.py @@ -40,7 +40,7 @@ # helper imports -from disent.util import wrapped_partial as _wrapped_partial +from disent.util.function import wrapped_partial as _wrapped_partial FAST_METRICS = { diff --git a/disent/metrics/_flatness.py b/disent/metrics/_flatness.py index cbb3daa6..6b7aaae4 100644 --- a/disent/metrics/_flatness.py +++ b/disent/metrics/_flatness.py @@ -36,7 +36,7 @@ from torch.utils.data.dataloader import default_collate from disent.dataset.groundtruth import GroundTruthDataset -from disent.util import iter_chunks +from disent.util.iters import iter_chunks log = logging.getLogger(__name__) diff --git a/disent/metrics/_flatness_components.py b/disent/metrics/_flatness_components.py index bbbfb573..5906ad6a 100644 --- a/disent/metrics/_flatness_components.py +++ b/disent/metrics/_flatness_components.py @@ -33,7 +33,7 @@ from disent.metrics._flatness import encode_all_along_factor from disent.metrics._flatness import encode_all_factors from disent.metrics._flatness import filter_inactive_factors -from disent.util import iter_chunks +from disent.util.iters import iter_chunks from disent.util import to_numpy from disent.nn.functional import torch_mean_generalized from disent.nn.functional import torch_pca diff --git a/disent/util/__init__.py b/disent/util/__init__.py index c9c659d8..51aaf150 100644 --- a/disent/util/__init__.py +++ b/disent/util/__init__.py @@ -22,14 +22,8 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import functools import logging import os -import time -from collections import Sequence -from contextlib import ContextDecorator -from itertools import islice -from typing import List import numpy as np import torch @@ -58,47 +52,6 @@ def _set_test_run(): os.environ['DISENT_TEST_RUN'] = 'True' -# ========================================================================= # -# seeds # -# ========================================================================= # - - -def seed(long=777): - """ - https://pytorch.org/docs/stable/notes/randomness.html - """ - if long is None: - log.warning(f'[SEEDING]: no seed was specified. Seeding skipped!') - return - torch.manual_seed(long) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - np.random.seed(long) - log.info(f'[SEEDED]: {long}') - - -class TempNumpySeed(object): - def __init__(self, seed=None, offset=0): - if seed is not None: - try: - seed = int(seed) - except: - raise ValueError(f'{seed=} is not int-like!') - self._seed = seed - if seed is not None: - self._seed += offset - self._state = None - - def __enter__(self): - if self._seed is not None: - self._state = np.random.get_state() - np.random.seed(self._seed) - - def __exit__(self, *args, **kwargs): - if self._seed is not None: - np.random.set_state(self._state) - self._state = None - # ========================================================================= # # Conversion # # ========================================================================= # @@ -120,305 +73,6 @@ def to_numpy(array) -> np.ndarray: return np.array(array) -# ========================================================================= # -# Iterators # -# ========================================================================= # - - -def chunked(arr, chunk_size: int, include_remainder=True): - """ - return an array of array chucks of size chunk_size. - This is NOT an iterable, and returns all the data. - """ - size = (len(arr) + chunk_size - 1) if include_remainder else len(arr) - return [arr[chunk_size*i:chunk_size*(i+1)] for i in range(size // chunk_size)] - - -def iter_chunks(items, chunk_size: int, include_remainder=True): - """ - iterable version of chunked. - that does not evaluate unneeded elements - """ - items = iter(items) - for first in items: - chunk = [first, *islice(items, chunk_size-1)] - if len(chunk) >= chunk_size or include_remainder: - yield chunk - - -def iter_rechunk(chunks, chunk_size: int, include_remainder=True): - """ - takes in chunks and returns chunks of a new size. - - Does not evaluate unneeded chunks - """ - return iter_chunks( - (item for chunk in chunks for item in chunk), # flatten chunks - chunk_size=chunk_size, - include_remainder=include_remainder - ) - - -# TODO: not actually an iterator -def map_all(fn, *arg_lists, starmap: bool = True, collect_returned: bool = False, common_kwargs: dict = None): - assert arg_lists, 'an empty list of args was passed' - # check all lengths are the same - num = len(arg_lists[0]) - assert num > 0 - assert all(len(items) == num for items in arg_lists) - # update kwargs - if common_kwargs is None: - common_kwargs = {} - # map everything - if starmap: - results = (fn(*args, **common_kwargs) for args in zip(*arg_lists)) - else: - results = (fn(args, **common_kwargs) for args in zip(*arg_lists)) - # zip everything - if collect_returned: - return tuple(zip(*results)) - else: - return tuple(results) - - -def collect_dicts(results: List[dict]): - # collect everything - keys = results[0].keys() - values = zip(*([result[k] for k in keys] for result in results)) - return {k: list(v) for k, v in zip(keys, values)} - - -# TODO: this shouldn't be here -def aggregate_dict(results: dict, reduction='mean'): - assert reduction == 'mean', 'mean is the only mode supported' - return { - k: sum(v) / len(v) for k, v in results.items() - } - - -# ========================================================================= # -# STRINGS # -# ========================================================================= # - - -def make_separator_str(text, header=None, width=100, char_v='#', char_h='=', char_corners=None): - """ - function wraps text between two lines or inside a box with lines on either side. - FROM: my obstacle_tower project - """ - if char_corners is None: - char_corners = char_v - assert len(char_v) == len(char_corners) - assert len(char_h) == 1 - import textwrap - import pprint - - def append_wrapped(text): - for line in text.splitlines(): - for wrapped in (textwrap.wrap(line, w, tabsize=4) if line.strip() else ['']): - lines.append(f'{char_v} {wrapped:{w}s} {char_v}') - - w = width-4 - lines = [] - sep = f'{char_corners} {char_h*w} {char_corners}' - lines.append(f'\n{sep}') - if header: - append_wrapped(header) - lines.append(sep) - if type(text) != str: - text = pprint.pformat(text, width=w) - append_wrapped(text) - lines.append(f'{sep}\n') - return '\n'.join(lines) - - -def make_box_str(text, header=None, width=100, char_v='|', char_h='-', char_corners='#'): - """ - like print_separator but is isntead a box - FROM: my obstacle_tower project - """ - return make_separator_str(text, header=header, width=width, char_v=char_v, char_h=char_h, char_corners=char_corners) - - -def concat_lines(*strings, sep=' | '): - """ - Join multi-line strings together horizontally, with the - specified separator between them. - """ - - def pad_width(lines): - max_len = max(len(line) for line in lines) - return [f'{s:{max_len}}' for s in lines] - - def pad_height(list_of_lines): - max_lines = max(len(lines) for lines in list_of_lines) - return [(lines + ([''] * (max_lines - len(lines)))) for lines in list_of_lines] - - list_of_lines = [str(string).splitlines() for string in strings] - list_of_lines = pad_height(list_of_lines) - list_of_lines = [pad_width(lines) for lines in list_of_lines] - return '\n'.join(sep.join(rows) for rows in zip(*list_of_lines)) - - -# ========================================================================= # -# Iterable # -# ========================================================================= # - - -class LengthIter(Sequence): - - def __iter__(self): - # this takes priority over __getitem__, otherwise __getitem__ would need to - # raise an IndexError if out of bounds to signal the end of iteration - for i in range(len(self)): - yield self[i] - - def __len__(self): - raise NotImplementedError() - - def __getitem__(self, item): - raise NotImplementedError() - - -# ========================================================================= # -# Context Manager Timer # -# ========================================================================= # - - -class Timer(ContextDecorator): - - """ - Timer class, can be used with a with statement to - measure the execution time of a block of code! - - Examples: - - 1. get the runtime - ``` - with Timer() as t: - time.sleep(1) - print(t.pretty) - ``` - - 2. automatic print - ``` - with Timer(name="example") as t: - time.sleep(1) - ``` - - 3. reuse timer to measure multiple instances - ``` - t = Timer() - for i in range(100): - with t: - time.sleep(0.95) - if t.elapsed > 3: - break - print(t) - ``` - """ - - def __init__(self, name: str = None, log_level: int = logging.INFO): - self._start_time: int = None - self._end_time: int = None - self._total_time = 0 - self.name = name - self._log_level = log_level - - def __enter__(self): - self._start_time = time.time_ns() - return self - - def __exit__(self, *args, **kwargs): - self._end_time = time.time_ns() - # add elapsed time to total time, and reset the timer! - self._total_time += (self._end_time - self._start_time) - self._start_time = None - self._end_time = None - # print results - if self.name: - if self._log_level is None: - print(f'{self.name}: {self.pretty}') - else: - log.log(self._log_level, f'{self.name}: {self.pretty}') - - @property - def elapsed_ns(self) -> int: - if self._start_time is not None: - # running - return self._total_time + (time.time_ns() - self._start_time) - # finished running - return self._total_time - - @property - def elapsed_ms(self) -> float: - return self.elapsed_ns / 1_000_000 - - @property - def elapsed(self) -> float: - return self.elapsed_ns / 1_000_000_000 - - @property - def pretty(self) -> str: - return Timer.prettify_time(self.elapsed_ns) - - def __int__(self): return self.elapsed_ns - def __float__(self): return self.elapsed - def __str__(self): return self.pretty - def __repr__(self): return self.pretty - - @staticmethod - def prettify_time(ns: int) -> str: - if ns == 0: - return 'N/A' - elif ns < 0: - return 'NaN' - # get power of 1000 - pow = min(3, int(np.log10(ns) // 3)) - time = ns / 1000**pow - # get pretty string! - if pow < 3 or time < 60: - # less than 1 minute - name = ['ns', 'µs', 'ms', 's'][pow] - return f'{time:.3f}{name}' - else: - # 1 or more minutes - s = int(time) - d, s = divmod(s, 86400) - h, s = divmod(s, 3600) - m, s = divmod(s, 60) - if d > 0: return f'{d}d:{h}h:{m}m' - elif h > 0: return f'{h}h:{m}m:{s}s' - else: return f'{m}m:{s}s' - - -# ========================================================================= # -# Function Helper # -# ========================================================================= # - - -def wrapped_partial(func, *args, **kwargs): - """ - Like functools.partial but keeps the same __name__ and __doc__ - on the returned function. - """ - partial_func = functools.partial(func, *args, **kwargs) - functools.update_wrapper(partial_func, func) - return partial_func - - -# ========================================================================= # -# Memory Usage # -# ========================================================================= # - - -def get_memory_usage(): - import os - import psutil - process = psutil.Process(os.getpid()) - num_bytes = process.memory_info().rss # in bytes - return num_bytes - - # ========================================================================= # # END # # ========================================================================= # diff --git a/disent/util/function.py b/disent/util/function.py new file mode 100644 index 00000000..ac8266ea --- /dev/null +++ b/disent/util/function.py @@ -0,0 +1,44 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + + +# ========================================================================= # +# Function Helper # +# ========================================================================= # + + +def wrapped_partial(func, *args, **kwargs): + """ + Like functools.partial but keeps the same __name__ and __doc__ + on the returned function. + """ + import functools + partial_func = functools.partial(func, *args, **kwargs) + functools.update_wrapper(partial_func, func) + return partial_func + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/iters.py b/disent/util/iters.py new file mode 100644 index 00000000..6888e7fa --- /dev/null +++ b/disent/util/iters.py @@ -0,0 +1,128 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +from itertools import islice +from typing import List +from typing import Sequence + + +# ========================================================================= # +# Iterators # +# ========================================================================= # + + +def chunked(arr, chunk_size: int, include_remainder=True): + """ + return an array of array chucks of size chunk_size. + This is NOT an iterable, and returns all the data. + """ + size = (len(arr) + chunk_size - 1) if include_remainder else len(arr) + return [arr[chunk_size*i:chunk_size*(i+1)] for i in range(size // chunk_size)] + + +def iter_chunks(items, chunk_size: int, include_remainder=True): + """ + iterable version of chunked. + that does not evaluate unneeded elements + """ + items = iter(items) + for first in items: + chunk = [first, *islice(items, chunk_size-1)] + if len(chunk) >= chunk_size or include_remainder: + yield chunk + + +def iter_rechunk(chunks, chunk_size: int, include_remainder=True): + """ + takes in chunks and returns chunks of a new size. + - Does not evaluate unneeded chunks + """ + return iter_chunks( + (item for chunk in chunks for item in chunk), # flatten chunks + chunk_size=chunk_size, + include_remainder=include_remainder + ) + + +def map_all(fn, *arg_lists, starmap: bool = True, collect_returned: bool = False, common_kwargs: dict = None): + # TODO: not actually an iterator + assert arg_lists, 'an empty list of args was passed' + # check all lengths are the same + num = len(arg_lists[0]) + assert num > 0 + assert all(len(items) == num for items in arg_lists) + # update kwargs + if common_kwargs is None: + common_kwargs = {} + # map everything + if starmap: + results = (fn(*args, **common_kwargs) for args in zip(*arg_lists)) + else: + results = (fn(args, **common_kwargs) for args in zip(*arg_lists)) + # zip everything + if collect_returned: + return tuple(zip(*results)) + else: + return tuple(results) + + +def collect_dicts(results: List[dict]): + # collect everything + keys = results[0].keys() + values = zip(*([result[k] for k in keys] for result in results)) + return {k: list(v) for k, v in zip(keys, values)} + + +def aggregate_dict(results: dict, reduction='mean'): + # TODO: this shouldn't be here + assert reduction == 'mean', 'mean is the only mode supported' + return { + k: sum(v) / len(v) for k, v in results.items() + } + + +# ========================================================================= # +# Base Class # +# ========================================================================= # + + +class LengthIter(Sequence): + + def __iter__(self): + # this takes priority over __getitem__, otherwise __getitem__ would need to + # raise an IndexError if out of bounds to signal the end of iteration + for i in range(len(self)): + yield self[i] + + def __len__(self): + raise NotImplementedError() + + def __getitem__(self, item): + raise NotImplementedError() + + +# ========================================================================= # +# END # +# ========================================================================= # + diff --git a/disent/util/profiling.py b/disent/util/profiling.py new file mode 100644 index 00000000..2721387b --- /dev/null +++ b/disent/util/profiling.py @@ -0,0 +1,162 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +import time +from math import log10 +from contextlib import ContextDecorator + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# Memory Usage # +# ========================================================================= # + + +def get_memory_usage(): + import os + import psutil + process = psutil.Process(os.getpid()) + num_bytes = process.memory_info().rss # in bytes + return num_bytes + + +# ========================================================================= # +# Context Manager Timer # +# ========================================================================= # + + +class Timer(ContextDecorator): + + """ + Timer class, can be used with a with statement to + measure the execution time of a block of code! + + Examples: + + 1. get the runtime + ``` + with Timer() as t: + time.sleep(1) + print(t.pretty) + ``` + + 2. automatic print + ``` + with Timer(name="example") as t: + time.sleep(1) + ``` + + 3. reuse timer to measure multiple instances + ``` + t = Timer() + for i in range(100): + with t: + time.sleep(0.95) + if t.elapsed > 3: + break + print(t) + ``` + """ + + def __init__(self, name: str = None, log_level: int = logging.INFO): + self._start_time: int = None + self._end_time: int = None + self._total_time = 0 + self.name = name + self._log_level = log_level + + def __enter__(self): + self._start_time = time.time_ns() + return self + + def __exit__(self, *args, **kwargs): + self._end_time = time.time_ns() + # add elapsed time to total time, and reset the timer! + self._total_time += (self._end_time - self._start_time) + self._start_time = None + self._end_time = None + # print results + if self.name: + if self._log_level is None: + print(f'{self.name}: {self.pretty}') + else: + log.log(self._log_level, f'{self.name}: {self.pretty}') + + @property + def elapsed_ns(self) -> int: + if self._start_time is not None: + # running + return self._total_time + (time.time_ns() - self._start_time) + # finished running + return self._total_time + + @property + def elapsed_ms(self) -> float: + return self.elapsed_ns / 1_000_000 + + @property + def elapsed(self) -> float: + return self.elapsed_ns / 1_000_000_000 + + @property + def pretty(self) -> str: + return Timer.prettify_time(self.elapsed_ns) + + def __int__(self): return self.elapsed_ns + def __float__(self): return self.elapsed + def __str__(self): return self.pretty + def __repr__(self): return self.pretty + + @staticmethod + def prettify_time(ns: int) -> str: + if ns == 0: + return 'N/A' + elif ns < 0: + return 'NaN' + # get power of 1000 + pow = min(3, int(log10(ns) // 3)) + time = ns / 1000**pow + # get pretty string! + if pow < 3 or time < 60: + # less than 1 minute + name = ['ns', 'µs', 'ms', 's'][pow] + return f'{time:.3f}{name}' + else: + # 1 or more minutes + s = int(time) + d, s = divmod(s, 86400) + h, s = divmod(s, 3600) + m, s = divmod(s, 60) + if d > 0: return f'{d}d:{h}h:{m}m' + elif h > 0: return f'{h}h:{m}m:{s}s' + else: return f'{m}m:{s}s' + + +# ========================================================================= # +# END # +# ========================================================================= # + diff --git a/disent/util/seeds.py b/disent/util/seeds.py new file mode 100644 index 00000000..45338cbc --- /dev/null +++ b/disent/util/seeds.py @@ -0,0 +1,83 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +import numpy as np + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# seeds # +# ========================================================================= # + + +def seed(long=777): + """ + https://pytorch.org/docs/stable/notes/randomness.html + """ + if long is None: + log.warning(f'[SEEDING]: no seed was specified. Seeding skipped!') + return + # seed torch - it can be slow to import + try: + import torch + torch.manual_seed(long) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + except ImportError: + log.warning(f'[SEEDING]: torch is not installed. Skipped seeding torch methods!') + # seed numpy + np.random.seed(long) + # done! + log.info(f'[SEEDED]: {long}') + + +class TempNumpySeed(object): + def __init__(self, seed=None, offset=0): + if seed is not None: + try: + seed = int(seed) + except: + raise ValueError(f'{seed=} is not int-like!') + self._seed = seed + if seed is not None: + self._seed += offset + self._state = None + + def __enter__(self): + if self._seed is not None: + self._state = np.random.get_state() + np.random.seed(self._seed) + + def __exit__(self, *args, **kwargs): + if self._seed is not None: + np.random.set_state(self._state) + self._state = None + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/strings.py b/disent/util/strings.py new file mode 100644 index 00000000..118f6f10 --- /dev/null +++ b/disent/util/strings.py @@ -0,0 +1,93 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + + +# ========================================================================= # +# STRINGS # +# ========================================================================= # + + +def make_separator_str(text, header=None, width=100, char_v='#', char_h='=', char_corners=None): + """ + function wraps text between two lines or inside a box with lines on either side. + FROM: my obstacle_tower project + """ + if char_corners is None: + char_corners = char_v + assert len(char_v) == len(char_corners) + assert len(char_h) == 1 + import textwrap + import pprint + + def append_wrapped(text): + for line in text.splitlines(): + for wrapped in (textwrap.wrap(line, w, tabsize=4) if line.strip() else ['']): + lines.append(f'{char_v} {wrapped:{w}s} {char_v}') + + w = width-4 + lines = [] + sep = f'{char_corners} {char_h*w} {char_corners}' + lines.append(f'\n{sep}') + if header: + append_wrapped(header) + lines.append(sep) + if type(text) != str: + text = pprint.pformat(text, width=w) + append_wrapped(text) + lines.append(f'{sep}\n') + return '\n'.join(lines) + + +def make_box_str(text, header=None, width=100, char_v='|', char_h='-', char_corners='#'): + """ + like print_separator but is isntead a box + FROM: my obstacle_tower project + """ + return make_separator_str(text, header=header, width=width, char_v=char_v, char_h=char_h, char_corners=char_corners) + + +def concat_lines(*strings, sep=' | '): + """ + Join multi-line strings together horizontally, with the + specified separator between them. + """ + + def pad_width(lines): + max_len = max(len(line) for line in lines) + return [f'{s:{max_len}}' for s in lines] + + def pad_height(list_of_lines): + max_lines = max(len(lines) for lines in list_of_lines) + return [(lines + ([''] * (max_lines - len(lines)))) for lines in list_of_lines] + + list_of_lines = [str(string).splitlines() for string in strings] + list_of_lines = pad_height(list_of_lines) + list_of_lines = [pad_width(lines) for lines in list_of_lines] + return '\n'.join(sep.join(rows) for rows in zip(*list_of_lines)) + + +# ========================================================================= # +# END # +# ========================================================================= # + diff --git a/experiment/exp/00_data_traversal/run.py b/experiment/exp/00_data_traversal/run.py index e3467c4b..9ee9373c 100644 --- a/experiment/exp/00_data_traversal/run.py +++ b/experiment/exp/00_data_traversal/run.py @@ -36,7 +36,7 @@ from disent.data.groundtruth import SmallNorbData from disent.data.groundtruth import XYSquaresData from disent.dataset.groundtruth import GroundTruthDataset -from disent.util import TempNumpySeed +from disent.util.seeds import TempNumpySeed # ========================================================================= # diff --git a/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py b/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py index 3bc387b6..ab1adda5 100644 --- a/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py +++ b/experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py @@ -37,11 +37,12 @@ from torch.nn import Parameter from torch.utils.data import DataLoader +import disent.util.seeds import experiment.exp.util as H from disent.nn.modules import DisentLightningModule from disent.nn.modules import DisentModule -from disent.util import make_box_str -from disent.util import seed +from disent.util.strings import make_box_str +from disent.util.seeds import seed from disent.nn.functional import torch_conv2d_channel_wise_fft from disent.nn.loss.softsort import spearman_rank_loss from experiment.run import hydra_append_progress_callback @@ -234,7 +235,7 @@ def run_disentangle_dataset_kernel(cfg): callbacks = [] hydra_append_progress_callback(callbacks, cfg) # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - seed(cfg.exp.seed) + seed(disent.util.seeds.seed) # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # # initialise dataset and get factor names to disentangle dataset = H.make_dataset(cfg.data.name, factors=True, data_dir=cfg.dataset.data_dir) diff --git a/experiment/exp/05_adversarial_data/run_04_gen_adversarial.py b/experiment/exp/05_adversarial_data/run_04_gen_adversarial.py index 86d20453..ed7f3832 100644 --- a/experiment/exp/05_adversarial_data/run_04_gen_adversarial.py +++ b/experiment/exp/05_adversarial_data/run_04_gen_adversarial.py @@ -40,9 +40,9 @@ import experiment.exp.util as H from disent.data.util.in_out import ensure_parent_dir_exists -from disent.util import seed -from disent.util import TempNumpySeed -from disent.util import Timer +from disent.util.seeds import seed +from disent.util.seeds import TempNumpySeed +from disent.util.profiling import Timer log = logging.getLogger(__name__) diff --git a/experiment/exp/06_metric/make_graphs.py b/experiment/exp/06_metric/make_graphs.py index aebb1104..f28fc72e 100644 --- a/experiment/exp/06_metric/make_graphs.py +++ b/experiment/exp/06_metric/make_graphs.py @@ -37,7 +37,7 @@ import experiment.exp.util as H from disent.metrics._flatness_components import compute_axis_score from disent.metrics._flatness_components import compute_linear_score -from disent.util import seed +from disent.util.seeds import seed # ========================================================================= # diff --git a/experiment/exp/util/_dataset.py b/experiment/exp/util/_dataset.py index f30fc149..edba0dda 100644 --- a/experiment/exp/util/_dataset.py +++ b/experiment/exp/util/_dataset.py @@ -43,7 +43,7 @@ from disent.dataset.groundtruth import GroundTruthDataset from disent.dataset.groundtruth import GroundTruthDatasetAndFactors from disent.nn.transform import ToStandardisedTensor -from disent.util import TempNumpySeed +from disent.util.seeds import TempNumpySeed from disent.visualize.visualize_util import make_animated_image_grid from disent.visualize.visualize_util import make_image_grid from experiment.exp.util._tasks import IN diff --git a/experiment/run.py b/experiment/run.py index 592b5d51..2ec066b0 100644 --- a/experiment/run.py +++ b/experiment/run.py @@ -40,7 +40,7 @@ from disent.frameworks import DisentFramework from disent.model import AutoEncoder from disent.nn.weights import init_model_weights -from disent.util import make_box_str +from disent.util.strings import make_box_str from experiment.util.callbacks import LoggerProgressCallback from experiment.util.callbacks import VaeDisentanglementLoggingCallback from experiment.util.callbacks import VaeLatentCycleLoggingCallback diff --git a/experiment/run_dataset_visualiser.py b/experiment/run_dataset_visualiser.py index aaa713e6..9eec699d 100644 --- a/experiment/run_dataset_visualiser.py +++ b/experiment/run_dataset_visualiser.py @@ -35,7 +35,7 @@ from omegaconf import DictConfig from omegaconf import OmegaConf -from disent.util import make_box_str +from disent.util.strings import make_box_str from disent.visualize.visualize_util import make_image_grid from experiment.run import hydra_check_datadir from experiment.run import HydraDataModule diff --git a/experiment/util/callbacks/callbacks_vae.py b/experiment/util/callbacks/callbacks_vae.py index b8b66d14..1b7ebc53 100644 --- a/experiment/util/callbacks/callbacks_vae.py +++ b/experiment/util/callbacks/callbacks_vae.py @@ -38,9 +38,9 @@ from disent.dataset.groundtruth import GroundTruthDataset from disent.frameworks.ae import Ae from disent.frameworks.vae import Vae -from disent.util import iter_chunks -from disent.util import TempNumpySeed -from disent.util import Timer +from disent.util.iters import iter_chunks +from disent.util.seeds import TempNumpySeed +from disent.util.profiling import Timer from disent.util import to_numpy from disent.visualize.visualize_model import latent_cycle_grid_animation from disent.visualize.visualize_util import make_image_grid From 377a3cb10e47a4fdc3ee3ce4380d2a651823b1c4 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Thu, 3 Jun 2021 15:42:00 +0200 Subject: [PATCH 26/34] split and move disent.data.utils into disent.utils + move data classes --- disent/data/dataobj.py | 223 +++++++++ disent/data/episodes/_option_episodes.py | 5 +- disent/data/groundtruth/__init__.py | 19 +- disent/data/groundtruth/_cars3d.py | 4 +- disent/data/groundtruth/_dsprites.py | 2 +- disent/data/groundtruth/_mpi3d.py | 2 +- disent/data/groundtruth/_norb.py | 2 +- disent/data/groundtruth/_shapes3d.py | 2 +- disent/data/groundtruth/_xyblocks.py | 3 + disent/data/groundtruth/_xyobject.py | 4 +- disent/data/groundtruth/base.py | 209 +-------- .../state_space.py => groundtruth/states.py} | 0 disent/data/{util => }/hdf5.py | 77 ++-- disent/data/util/__init__.py | 23 - disent/data/util/in_out.py | 422 ------------------ disent/util/cache.py | 94 ++++ disent/util/hashing.py | 140 ++++++ disent/util/in_out.py | 180 ++++++++ disent/util/paths.py | 118 +++++ disent/util/strings.py | 30 +- .../run_04_gen_adversarial.py | 2 +- experiment/exp/util/_io_util.py | 2 +- tests/test_state_space.py | 2 +- 23 files changed, 861 insertions(+), 704 deletions(-) create mode 100644 disent/data/dataobj.py rename disent/data/{util/state_space.py => groundtruth/states.py} (100%) rename disent/data/{util => }/hdf5.py (93%) delete mode 100644 disent/data/util/__init__.py delete mode 100644 disent/data/util/in_out.py create mode 100644 disent/util/cache.py create mode 100644 disent/util/hashing.py create mode 100644 disent/util/in_out.py create mode 100644 disent/util/paths.py diff --git a/disent/data/dataobj.py b/disent/data/dataobj.py new file mode 100644 index 00000000..57376f55 --- /dev/null +++ b/disent/data/dataobj.py @@ -0,0 +1,223 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import os +from abc import ABCMeta +from typing import Callable +from typing import Dict +from typing import final +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np + +from disent.data.hdf5 import hdf5_resave_file +from disent.util.cache import stalefile +from disent.util.function import wrapped_partial +from disent.util.in_out import retrieve_file +from disent.util.paths import filename_from_url +from disent.util.paths import modify_file_name + + +# ========================================================================= # +# data objects # +# ========================================================================= # + + +class DataObject(object, metaclass=ABCMeta): + """ + base DataObject that does nothing, if the file does + not exist or it has the incorrect hash, then that's your problem! + """ + + def __init__(self, file_name: str): + self._file_name = file_name + + @final + @property + def out_name(self) -> str: + return self._file_name + + def prepare(self, out_dir: str) -> str: + # TODO: maybe check that the file exists or not and raise a FileNotFoundError? + pass + + +class HashedDataObject(DataObject, metaclass=ABCMeta): + """ + Abstract Class + - Base DataObject class that guarantees a file to exist, + if the file does not exist, or the hash of the file is + incorrect, then the file is re-generated. + """ + + def __init__( + self, + file_name: str, + file_hash: Optional[Union[str, Dict[str, str]]], + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + super().__init__(file_name=file_name) + self._file_hash = file_hash + self._hash_type = hash_type + self._hash_mode = hash_mode + + def prepare(self, out_dir: str) -> str: + @stalefile(file=os.path.join(out_dir, self._file_name), hash=self._file_hash, hash_type=self._hash_type, hash_mode=self._hash_mode) + def wrapped(out_file): + self._prepare(out_dir=out_dir, out_file=out_file) + return wrapped() + + def _prepare(self, out_dir: str, out_file: str) -> str: + # TODO: maybe raise a FileNotFoundError or a HashError instead? + raise NotImplementedError + + +class DlDataObject(HashedDataObject): + """ + Download a file + - uri can also be a file to perform a copy instead of download, + useful for example if you want to retrieve a file from a network drive. + """ + + def __init__( + self, + uri: str, + uri_hash: Optional[Union[str, Dict[str, str]]], + uri_name: Optional[str] = None, + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + super().__init__( + file_name=filename_from_url(uri) if (uri_name is None) else uri_name, + file_hash=uri_hash, + hash_type=hash_type, + hash_mode=hash_mode + ) + self._uri = uri + + def _prepare(self, out_dir: str, out_file: str): + retrieve_file(src_uri=self._uri, dst_path=out_file, overwrite_existing=True) + + +class DlGenDataObject(HashedDataObject, metaclass=ABCMeta): + """ + Abstract class + - download a file and perform some processing on that file. + """ + + def __init__( + self, + # download & save files + uri: str, + uri_hash: Optional[Union[str, Dict[str, str]]], + file_hash: Optional[Union[str, Dict[str, str]]], + # save paths + uri_name: Optional[str] = None, + file_name: Optional[str] = None, + # hash settings + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + self._dl_obj = DlDataObject( + uri=uri, + uri_hash=uri_hash, + uri_name=uri_name, + hash_type=hash_type, + hash_mode=hash_mode, + ) + super().__init__( + file_name=modify_file_name(self._dl_obj.out_name, prefix='gen') if (file_name is None) else file_name, + file_hash=file_hash, + hash_type=hash_type, + hash_mode=hash_mode, + ) + + def _prepare(self, out_dir: str, out_file: str): + inp_file = self._dl_obj.prepare(out_dir=out_dir) + self._generate(inp_file=inp_file, out_file=out_file) + + def _generate(self, inp_file: str, out_file: str): + raise NotImplementedError + + +class DlH5DataObject(DlGenDataObject): + """ + Downloads an hdf5 file and pre-processes it into the specified chunk_size. + """ + + def __init__( + self, + # download & save files + uri: str, + uri_hash: Optional[Union[str, Dict[str, str]]], + file_hash: Optional[Union[str, Dict[str, str]]], + # h5 re-save settings + hdf5_dataset_name: str, + hdf5_chunk_size: Tuple[int, ...], + hdf5_compression: Optional[str] = 'gzip', + hdf5_compression_lvl: Optional[int] = 4, + hdf5_dtype: Optional[Union[np.dtype, str]] = None, + hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None, + # save paths + uri_name: Optional[str] = None, + file_name: Optional[str] = None, + # hash settings + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + super().__init__( + file_name=file_name, + file_hash=file_hash, + uri=uri, + uri_hash=uri_hash, + uri_name=uri_name, + hash_type=hash_type, + hash_mode=hash_mode, + ) + self._hdf5_resave_file = wrapped_partial( + hdf5_resave_file, + dataset_name=hdf5_dataset_name, + chunk_size=hdf5_chunk_size, + compression=hdf5_compression, + compression_lvl=hdf5_compression_lvl, + out_dtype=hdf5_dtype, + out_mutator=hdf5_mutator, + ) + # save the dataset name + self._out_dataset_name = hdf5_dataset_name + + @property + def out_dataset_name(self) -> str: + return self._out_dataset_name + + def _generate(self, inp_file: str, out_file: str): + self._hdf5_resave_file(inp_path=inp_file, out_path=out_file) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/data/episodes/_option_episodes.py b/disent/data/episodes/_option_episodes.py index b8a0f2d4..37fc9ded 100644 --- a/disent/data/episodes/_option_episodes.py +++ b/disent/data/episodes/_option_episodes.py @@ -26,7 +26,8 @@ from typing import List, Tuple import numpy as np from disent.data.episodes._base import BaseOptionEpisodesData -from disent.data.util.in_out import download_file, basename_from_url +from disent.util.in_out import download_file +from disent.util.paths import filename_from_url import logging log = logging.getLogger(__name__) @@ -118,7 +119,7 @@ def _download_and_extract_if_needed(self, download_url: str, required_file: str, if not isinstance(download_url, str): return # download file, but skip if file already exists - save_path = os.path.join(os.path.dirname(required_file), basename_from_url(download_url)) + save_path = os.path.join(os.path.dirname(required_file), filename_from_url(download_url)) if force_download or not os.path.exists(save_path): log.info(f'Downloading: {download_url=} to {save_path=}') download_file(download_url, save_path=save_path) diff --git a/disent/data/groundtruth/__init__.py b/disent/data/groundtruth/__init__.py index 5bb81f32..4941ce81 100644 --- a/disent/data/groundtruth/__init__.py +++ b/disent/data/groundtruth/__init__.py @@ -22,13 +22,14 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -from .base import GroundTruthData +from disent.data.groundtruth.base import GroundTruthData + # others -# from ._cars3d import Cars3dData -from ._dsprites import DSpritesData -# from ._mpi3d import Mpi3dData -# from ._norb import SmallNorbData -from ._shapes3d import Shapes3dData -from ._xyobject import XYObjectData -from ._xysquares import XYSquaresData, XYSquaresMinimalData -from ._xyblocks import XYBlocksData +from disent.data.groundtruth._cars3d import Cars3dData +from disent.data.groundtruth._dsprites import DSpritesData +from disent.data.groundtruth._mpi3d import Mpi3dData +from disent.data.groundtruth._norb import SmallNorbData +from disent.data.groundtruth._shapes3d import Shapes3dData +from disent.data.groundtruth._xyobject import XYObjectData +from disent.data.groundtruth._xysquares import XYSquaresData, XYSquaresMinimalData +from disent.data.groundtruth._xyblocks import XYBlocksData diff --git a/disent/data/groundtruth/_cars3d.py b/disent/data/groundtruth/_cars3d.py index 66b3c9b8..ef2d30bf 100644 --- a/disent/data/groundtruth/_cars3d.py +++ b/disent/data/groundtruth/_cars3d.py @@ -30,9 +30,9 @@ import numpy as np from scipy.io import loadmat -from disent.data.groundtruth.base import DlGenDataObject +from disent.data.dataobj import DlGenDataObject from disent.data.groundtruth.base import NumpyGroundTruthData -from disent.data.util.in_out import AtomicSaveFile +from disent.util.in_out import AtomicSaveFile log = logging.getLogger(__name__) diff --git a/disent/data/groundtruth/_dsprites.py b/disent/data/groundtruth/_dsprites.py index 82ac6858..90d2e431 100644 --- a/disent/data/groundtruth/_dsprites.py +++ b/disent/data/groundtruth/_dsprites.py @@ -24,7 +24,7 @@ import logging -from disent.data.groundtruth.base import DlH5DataObject +from disent.data.dataobj import DlH5DataObject from disent.data.groundtruth.base import Hdf5GroundTruthData diff --git a/disent/data/groundtruth/_mpi3d.py b/disent/data/groundtruth/_mpi3d.py index b7619a49..dbb1efbd 100644 --- a/disent/data/groundtruth/_mpi3d.py +++ b/disent/data/groundtruth/_mpi3d.py @@ -25,7 +25,7 @@ import logging from typing import Optional -from disent.data.groundtruth.base import DlDataObject +from disent.data.dataobj import DlDataObject from disent.data.groundtruth.base import NumpyGroundTruthData diff --git a/disent/data/groundtruth/_norb.py b/disent/data/groundtruth/_norb.py index f0d2a997..0f1a7487 100644 --- a/disent/data/groundtruth/_norb.py +++ b/disent/data/groundtruth/_norb.py @@ -31,8 +31,8 @@ import numpy as np +from disent.data.dataobj import DlDataObject from disent.data.groundtruth.base import DiskGroundTruthData -from disent.data.groundtruth.base import DlDataObject # ========================================================================= # diff --git a/disent/data/groundtruth/_shapes3d.py b/disent/data/groundtruth/_shapes3d.py index 19b03e60..2647fe7f 100644 --- a/disent/data/groundtruth/_shapes3d.py +++ b/disent/data/groundtruth/_shapes3d.py @@ -24,7 +24,7 @@ import logging -from disent.data.groundtruth.base import DlH5DataObject +from disent.data.dataobj import DlH5DataObject from disent.data.groundtruth.base import Hdf5GroundTruthData diff --git a/disent/data/groundtruth/_xyblocks.py b/disent/data/groundtruth/_xyblocks.py index 169bb25d..f00f8c36 100644 --- a/disent/data/groundtruth/_xyblocks.py +++ b/disent/data/groundtruth/_xyblocks.py @@ -24,9 +24,12 @@ import logging from typing import Tuple + import numpy as np + from disent.data.groundtruth.base import GroundTruthData + log = logging.getLogger(__name__) diff --git a/disent/data/groundtruth/_xyobject.py b/disent/data/groundtruth/_xyobject.py index fe3e4ae8..7cb77b22 100644 --- a/disent/data/groundtruth/_xyobject.py +++ b/disent/data/groundtruth/_xyobject.py @@ -23,9 +23,11 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ from typing import Tuple -from disent.data.groundtruth.base import GroundTruthData + import numpy as np +from disent.data.groundtruth.base import GroundTruthData + # ========================================================================= # # xy grid data # diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index 296a833c..c237fb69 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -25,27 +25,17 @@ import logging import os from abc import ABCMeta -from typing import Callable -from typing import Dict -from typing import final -from typing import NoReturn from typing import Optional from typing import Sequence from typing import Tuple -from typing import Union import numpy as np -from disent.data.util.hdf5 import hdf5_resave_file -from disent.data.util.hdf5 import PickleH5pyDataset -from disent.data.util.in_out import basename_from_url -from disent.data.util.in_out import modify_file_name -from disent.data.util.in_out import stalefile -from disent.data.util.in_out import download_file -from disent.data.util.in_out import ensure_dir_exists -from disent.data.util.in_out import retrieve_file -from disent.data.util.state_space import StateSpace -from disent.util.function import wrapped_partial +from disent.data.dataobj import DataObject +from disent.data.dataobj import DlH5DataObject +from disent.data.groundtruth.states import StateSpace +from disent.data.hdf5 import PickleH5pyData +from disent.util.paths import ensure_dir_exists log = logging.getLogger(__name__) @@ -139,7 +129,7 @@ def default_data_root(self): return os.path.abspath(os.environ.get('DISENT_DATA_ROOT', 'data/dataset')) @property - def data_objects(self) -> Sequence['DataObject']: + def data_objects(self) -> Sequence[DataObject]: raise NotImplementedError @@ -160,11 +150,11 @@ def __getitem__(self, idx): return self._data[idx] @property - def data_objects(self) -> Sequence['DataObject']: + def data_objects(self) -> Sequence[DataObject]: return [self.data_object] @property - def data_object(self) -> 'DataObject': + def data_object(self) -> DataObject: raise NotImplementedError @property @@ -185,7 +175,7 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_me # variables self._in_memory = in_memory # load the h5py dataset - data = PickleH5pyDataset( + data = PickleH5pyData( h5_path=os.path.join(self.data_dir, self.data_object.out_name), h5_dataset_name=self.data_object.out_dataset_name, ) @@ -204,191 +194,14 @@ def __getitem__(self, idx): return self._data[idx] @property - def data_objects(self) -> Sequence['DlH5DataObject']: + def data_objects(self) -> Sequence[DlH5DataObject]: return [self.data_object] @property - def data_object(self) -> 'DlH5DataObject': + def data_object(self) -> DlH5DataObject: raise NotImplementedError -# ========================================================================= # -# data objects # -# ========================================================================= # - - -class DataObject(object, metaclass=ABCMeta): - """ - base DataObject that does nothing, if the file does - not exist or it has the incorrect hash, then that's your problem! - """ - - def __init__(self, file_name: str): - self._file_name = file_name - - @final - @property - def out_name(self) -> str: - return self._file_name - - def prepare(self, out_dir: str) -> str: - # TODO: maybe check that the file exists or not and raise a FileNotFoundError? - pass - - -class HashedDataObject(DataObject, metaclass=ABCMeta): - """ - Abstract Class - - Base DataObject class that guarantees a file to exist, - if the file does not exist, or the hash of the file is - incorrect, then the file is re-generated. - """ - - def __init__( - self, - file_name: str, - file_hash: Optional[Union[str, Dict[str, str]]], - hash_type: str = 'md5', - hash_mode: str = 'fast', - ): - super().__init__(file_name=file_name) - self._file_hash = file_hash - self._hash_type = hash_type - self._hash_mode = hash_mode - - def prepare(self, out_dir: str) -> str: - @stalefile(file=os.path.join(out_dir, self._file_name), hash=self._file_hash, hash_type=self._hash_type, hash_mode=self._hash_mode) - def wrapped(out_file): - self._prepare(out_dir=out_dir, out_file=out_file) - return wrapped() - - def _prepare(self, out_dir: str, out_file: str) -> str: - # TODO: maybe raise a FileNotFoundError or a HashError instead? - raise NotImplementedError - - -class DlDataObject(HashedDataObject): - """ - Download a file - - uri can also be a file to perform a copy instead of download, - useful for example if you want to retrieve a file from a network drive. - """ - - def __init__( - self, - uri: str, - uri_hash: Optional[Union[str, Dict[str, str]]], - uri_name: Optional[str] = None, - hash_type: str = 'md5', - hash_mode: str = 'fast', - ): - super().__init__( - file_name=basename_from_url(uri) if (uri_name is None) else uri_name, - file_hash=uri_hash, - hash_type=hash_type, - hash_mode=hash_mode - ) - self._uri = uri - - def _prepare(self, out_dir: str, out_file: str): - retrieve_file(src_uri=self._uri, dst_path=out_file, overwrite_existing=True) - - -class DlGenDataObject(HashedDataObject, metaclass=ABCMeta): - """ - Abstract class - - download a file and perform some processing on that file. - """ - - def __init__( - self, - # download & save files - uri: str, - uri_hash: Optional[Union[str, Dict[str, str]]], - file_hash: Optional[Union[str, Dict[str, str]]], - # save paths - uri_name: Optional[str] = None, - file_name: Optional[str] = None, - # hash settings - hash_type: str = 'md5', - hash_mode: str = 'fast', - ): - self._dl_obj = DlDataObject( - uri=uri, - uri_hash=uri_hash, - uri_name=uri_name, - hash_type=hash_type, - hash_mode=hash_mode, - ) - super().__init__( - file_name=modify_file_name(self._dl_obj.out_name, prefix='gen') if (file_name is None) else file_name, - file_hash=file_hash, - hash_type=hash_type, - hash_mode=hash_mode, - ) - - def _prepare(self, out_dir: str, out_file: str): - inp_file = self._dl_obj.prepare(out_dir=out_dir) - self._generate(inp_file=inp_file, out_file=out_file) - - def _generate(self, inp_file: str, out_file: str): - raise NotImplementedError - - -class DlH5DataObject(DlGenDataObject): - """ - Downloads an hdf5 file and pre-processes it into the specified chunk_size. - """ - - def __init__( - self, - # download & save files - uri: str, - uri_hash: Optional[Union[str, Dict[str, str]]], - file_hash: Optional[Union[str, Dict[str, str]]], - # h5 re-save settings - hdf5_dataset_name: str, - hdf5_chunk_size: Tuple[int, ...], - hdf5_compression: Optional[str] = 'gzip', - hdf5_compression_lvl: Optional[int] = 4, - hdf5_dtype: Optional[Union[np.dtype, str]] = None, - hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None, - # save paths - uri_name: Optional[str] = None, - file_name: Optional[str] = None, - # hash settings - hash_type: str = 'md5', - hash_mode: str = 'fast', - ): - super().__init__( - file_name=file_name, - file_hash=file_hash, - uri=uri, - uri_hash=uri_hash, - uri_name=uri_name, - hash_type=hash_type, - hash_mode=hash_mode, - ) - self._hdf5_resave_file = wrapped_partial( - hdf5_resave_file, - dataset_name=hdf5_dataset_name, - chunk_size=hdf5_chunk_size, - compression=hdf5_compression, - compression_lvl=hdf5_compression_lvl, - out_dtype=hdf5_dtype, - out_mutator=hdf5_mutator, - ) - # save the dataset name - self._out_dataset_name = hdf5_dataset_name - - @property - def out_dataset_name(self) -> str: - return self._out_dataset_name - - def _generate(self, inp_file: str, out_file: str): - self._hdf5_resave_file(inp_path=inp_file, out_path=out_file) - - # ========================================================================= # # END # # ========================================================================= # diff --git a/disent/data/util/state_space.py b/disent/data/groundtruth/states.py similarity index 100% rename from disent/data/util/state_space.py rename to disent/data/groundtruth/states.py diff --git a/disent/data/util/hdf5.py b/disent/data/hdf5.py similarity index 93% rename from disent/data/util/hdf5.py rename to disent/data/hdf5.py index e68ccbb9..7edfcd5e 100644 --- a/disent/data/util/hdf5.py +++ b/disent/data/hdf5.py @@ -32,8 +32,8 @@ import numpy as np from tqdm import tqdm -from disent.data.util.in_out import AtomicSaveFile -from disent.data.util.in_out import bytes_to_human +from disent.util.in_out import AtomicSaveFile +from disent.util.strings import bytes_to_human from disent.util import colors as c from disent.util.iters import iter_chunks from disent.util.iters import LengthIter @@ -48,7 +48,7 @@ # ========================================================================= # -class PickleH5pyDataset(LengthIter): +class PickleH5pyData(LengthIter): """ This class supports pickling and unpickling of a read-only SWMR h5py file and corresponding dataset. @@ -99,35 +99,6 @@ def close(self): del self._hdf5_data -# ========================================================================= # -# hdf5 # -# ========================================================================= # - - -# TODO: cleanup -def hdf5_print_entry_data_stats(h5_dataset: h5py.Dataset, label='STATISTICS'): - dtype = h5_dataset.dtype - itemsize = h5_dataset.dtype.itemsize - # chunk - chunks = np.array(h5_dataset.chunks) - data_per_chunk = np.prod(chunks) * itemsize - # entry - shape = np.array([1, *h5_dataset.shape[1:]]) - data_per_entry = np.prod(shape) * itemsize - # chunks per entry - chunks_per_dim = np.ceil(shape / chunks).astype('int') - chunks_per_entry = np.prod(chunks_per_dim) - read_data_per_entry = data_per_chunk * chunks_per_entry - # print info - tqdm.write( - f'[{label:3s}] ' - f'entry: {str(list(shape)):18s} ({str(dtype):8s}) {c.lYLW}{bytes_to_human(data_per_entry)}{c.RST} ' - f'chunk: {str(list(chunks)):18s} {c.YLW}{bytes_to_human(data_per_chunk)}{c.RST} ' - f'chunks per entry: {str(list(chunks_per_dim)):18s} {c.YLW}{bytes_to_human(read_data_per_entry)}{c.RST} ({c.RED}{chunks_per_entry:5d}{c.RST}) | ' - f'compression: {repr(h5_dataset.compression)} compression lvl: {repr(h5_dataset.compression_opts)}' - ) - - # ========================================================================= # # hdf5 - resave # # ========================================================================= # @@ -195,25 +166,24 @@ def hdf5_resave_file(inp_path: str, out_path: str, dataset_name, chunk_size=None # ========================================================================= # -def hdf5_test_entries_per_second(h5_data: h5py.File, dataset_name, access_method='random', max_entries=48000, timeout=10, batch_size: int = 256): - data = h5_data[dataset_name] +def hdf5_test_entries_per_second(h5_dataset: h5py.Dataset, access_method='random', max_entries=48000, timeout=10, batch_size: int = 256): # get access method if access_method == 'sequential': - indices = np.arange(len(data)) + indices = np.arange(len(h5_dataset)) elif access_method == 'random': - indices = np.arange(len(data)) + indices = np.arange(len(h5_dataset)) np.random.shuffle(indices) else: raise KeyError('Invalid access method') # num entries to test - n = min(len(data), max_entries) + n = min(len(h5_dataset), max_entries) indices = indices[:n] # iterate through dataset, exit on timeout or max_entries t = Timer() for chunk in iter_chunks(enumerate(indices), chunk_size=batch_size): with t: for i, idx in chunk: - entry = data[idx] + entry = h5_dataset[idx] if t.elapsed > timeout: break # calculate score @@ -224,7 +194,36 @@ def hdf5_test_entries_per_second(h5_data: h5py.File, dataset_name, access_method def hdf5_test_speed(h5_path: str, dataset_name: str, access_method: str = 'random'): with h5py.File(h5_path, 'r') as out_h5: log.info('[TESTING] Access Speed...') - log.info(f'Random Accesses Per Second: {hdf5_test_entries_per_second(out_h5, dataset_name, access_method=access_method, max_entries=5_000):.3f}') + log.info(f'Random Accesses Per Second: {hdf5_test_entries_per_second(out_h5[dataset_name], access_method=access_method, max_entries=5_000):.3f}') + + +# ========================================================================= # +# hdf5 - stats # +# ========================================================================= # + + +# TODO: cleanup +def hdf5_print_entry_data_stats(h5_dataset: h5py.Dataset, label='STATISTICS'): + dtype = h5_dataset.dtype + itemsize = h5_dataset.dtype.itemsize + # chunk + chunks = np.array(h5_dataset.chunks) + data_per_chunk = np.prod(chunks) * itemsize + # entry + shape = np.array([1, *h5_dataset.shape[1:]]) + data_per_entry = np.prod(shape) * itemsize + # chunks per entry + chunks_per_dim = np.ceil(shape / chunks).astype('int') + chunks_per_entry = np.prod(chunks_per_dim) + read_data_per_entry = data_per_chunk * chunks_per_entry + # print info + tqdm.write( + f'[{label:3s}] ' + f'entry: {str(list(shape)):18s} ({str(dtype):8s}) {c.lYLW}{bytes_to_human(data_per_entry)}{c.RST} ' + f'chunk: {str(list(chunks)):18s} {c.YLW}{bytes_to_human(data_per_chunk)}{c.RST} ' + f'chunks per entry: {str(list(chunks_per_dim)):18s} {c.YLW}{bytes_to_human(read_data_per_entry)}{c.RST} ({c.RED}{chunks_per_entry:5d}{c.RST}) | ' + f'compression: {repr(h5_dataset.compression)} compression lvl: {repr(h5_dataset.compression_opts)}' + ) # ========================================================================= # diff --git a/disent/data/util/__init__.py b/disent/data/util/__init__.py deleted file mode 100644 index 9a05a479..00000000 --- a/disent/data/util/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ diff --git a/disent/data/util/in_out.py b/disent/data/util/in_out.py deleted file mode 100644 index 1810fb05..00000000 --- a/disent/data/util/in_out.py +++ /dev/null @@ -1,422 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -import math -import os -from functools import wraps -from pathlib import Path -from typing import Callable -from typing import Dict -from typing import NoReturn -from typing import Optional -from typing import Tuple -from typing import Union -from uuid import uuid4 - -from disent.util import colors as c - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Formatting # -# ========================================================================= # - - -_BYTES_POW_NAME = ("B ", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB") -_BYTES_POW_COLR = (c.WHT, c.lGRN, c.lYLW, c.lRED, c.lRED, c.lRED, c.lRED, c.lRED, c.lRED) - - -def bytes_to_human(size_bytes, decimals=3, color=True): - if size_bytes == 0: - return "0B" - # round correctly - i = int(math.floor(math.log(size_bytes, 1024))) - s = round(size_bytes / math.pow(1024, i), decimals) - # generate string - name = f'{_BYTES_POW_COLR[i]}{_BYTES_POW_NAME[i]}{c.RST}' if color else f'{_BYTES_POW_NAME[i]}' - # format string - return f"{s:{4+decimals}.{decimals}f} {name}" - - -def modify_file_name(file: Union[str, Path], prefix: str = None, suffix: str = None, sep='.') -> Union[str, Path]: - # get path components - path = Path(file) - assert path.name, f'file name cannot be empty: {repr(path)}, for name: {repr(path.name)}' - # create new path - prefix = '' if (prefix is None) else f'{prefix}{sep}' - suffix = '' if (suffix is None) else f'{sep}{suffix}' - new_path = path.parent.joinpath(f'{prefix}{path.name}{suffix}') - # return path - return str(new_path) if isinstance(file, str) else new_path - - -# ========================================================================= # -# file hashing # -# ========================================================================= # - - -def yield_file_bytes(file: str, chunk_size=16384): - with open(file, 'rb') as f: - bytes = True - while bytes: - bytes = f.read(chunk_size) - yield bytes - - -def yield_fast_hash_bytes(file: str, chunk_size=16384, num_chunks=3): - assert num_chunks >= 2 - # return the size in bytes - size = os.path.getsize(file) - yield size.to_bytes(length=64//8, byteorder='big', signed=False) - # return file bytes chunks - if size < chunk_size * num_chunks: - # we cant return chunks because the file is too small, return everything! - yield from yield_file_bytes(file, chunk_size=chunk_size) - else: - # includes evenly spaced start, middle and end chunks - with open(file, 'rb') as f: - for i in range(num_chunks): - pos = (i * (size - chunk_size)) // (num_chunks - 1) - f.seek(pos) - yield f.read(chunk_size) - - -def hash_file(file: str, hash_type='md5', hash_mode='full', missing_ok=True) -> str: - """ - :param file: the path to the file - :param hash_type: the kind of hash to compute, default is "md5" - :param hash_mode: "full" uses all the bytes in the file to compute the hash, "fast" uses the start, middle, end bytes as well as the size of the file in the hash. - :param chunk_size: number of bytes to read at a time - :return: the hexdigest of the hash - :raises FileNotFoundError - """ - import hashlib - # check the file exists - if not os.path.isfile(file): - if missing_ok: - return '' - raise FileNotFoundError(f'could not compute hash for missing file: {repr(file)}') - # get file bytes iterator - if hash_mode == 'full': - byte_iter = yield_file_bytes(file=file) - elif hash_mode == 'fast': - byte_iter = yield_fast_hash_bytes(file=file) - else: - raise KeyError(f'invalid hash_mode: {repr(hash_mode)}') - # generate hash - hash = hashlib.new(hash_type) - for bytes in byte_iter: - hash.update(bytes) - hash = hash.hexdigest() - # done - return hash - - -class HashError(Exception): - """ - Raised if the hash of a file was invalid. - """ - - -def get_hash(hash: Union[str, Dict[str, str]], hash_mode: str) -> str: - return hash[hash_mode] if isinstance(hash, dict) else hash - - -def validate_file_hash(file: str, hash: Union[str, Dict[str, str]], hash_type: str = 'md5', hash_mode: str = 'full', missing_ok=True): - """ - :raises FileNotFoundError, HashError - """ - hash = get_hash(hash=hash, hash_mode=hash_mode) - # compute the hash - fhash = hash_file(file=file, hash_type=hash_type, hash_mode=hash_mode, missing_ok=missing_ok) - # check the hash - if fhash != hash: - raise HashError(f'computed {hash_mode} {hash_type} hash: {repr(fhash)} does not match expected hash: {repr(hash)} for file: {repr(file)}') - - -def is_valid_file_hash(file: str, hash: Union[str, Dict[str, str]], hash_type: str = 'md5', hash_mode: str = 'full', missing_ok=True): - try: - validate_file_hash(file=file, hash=hash, hash_type=hash_type, hash_mode=hash_mode, missing_ok=missing_ok) - except HashError: - return False - return True - - -# ========================================================================= # -# Function Caching # -# ========================================================================= # - - -class stalefile(object): - - def __init__( - self, - file: str, - hash: Optional[Union[str, Dict[str, str]]], - hash_type: str = 'md5', - hash_mode: str = 'fast', - ): - self.file = file - self.hash = get_hash(hash=hash, hash_mode=hash_mode) - self.hash_type = hash_type - self.hash_mode = hash_mode - - def __call__(self, func: Callable[[str], NoReturn]) -> Callable[[], str]: - @wraps(func) - def wrapper() -> str: - if self.is_stale(): - log.debug(f'calling wrapped function: {func} because the file is stale: {repr(self.file)}') - func(self.file) - validate_file_hash(self.file, hash=self.hash, hash_type=self.hash_type, hash_mode=self.hash_mode) - else: - log.debug(f'skipped wrapped function: {func} because the file is fresh: {repr(self.file)}') - return self.file - return wrapper - - def is_stale(self): - fhash = hash_file(file=self.file, hash_type=self.hash_type, hash_mode=self.hash_mode, missing_ok=True) - if not fhash: - log.info(f'file is stale because it does not exist: {repr(self.file)}') - return True - if fhash != self.hash: - log.info(f'file is stale because the computed {self.hash_mode} {self.hash_type} hash: {fhash} does not match the target hash: {self.hash} for file: {repr(self.file)}') - return True - log.info(f'file is fresh: {repr(self.file)}') - return False - - def __bool__(self): - return self.is_stale() - - -# ========================================================================= # -# Atomic file saving # -# ========================================================================= # - - -class AtomicSaveFile(object): - """ - Within the context, data must be written to a temporary file. - Once data has been successfully written, the temporary file - is moved to the location of the target file. - - The temporary file is created in the same directory as the target file. - - ``` - with AtomicFileHandler('file.txt') as tmp_file: - with open(tmp_file, 'w') as f: - f.write("hello world!\n") - ``` - - # TODO: can this be cleaned up with the TemporaryDirectory and TemporaryFile classes? - """ - - def __init__( - self, - file: str, - open_mode: Optional[str] = None, - overwrite: bool = False, - makedirs: bool = True, - tmp_prefix: Optional[str] = '.temp.', - tmp_suffix: Optional[str] = None, - ): - from pathlib import Path - # check files - if not file: - raise ValueError(f'file must not be empty: {repr(file)}') - # get files - self.trg_file = Path(file).absolute() - self.tmp_file = modify_file_name(self.trg_file, prefix=f'{tmp_prefix}{uuid4()}', suffix=tmp_suffix) - # check that the files are different - if self.trg_file == self.tmp_file: - raise ValueError(f'temporary and target files are the same: {self.tmp_file} == {self.trg_file}') - # other settings - self._makedirs = makedirs - self._overwrite = overwrite - self._open_mode = open_mode - self._resource = None - - def __enter__(self): - # check files exist or not - if self.tmp_file.exists(): - if not self.tmp_file.is_file(): - raise FileExistsError(f'the temporary file exists but is not a file: {self.tmp_file}') - if self.trg_file.exists(): - if not self._overwrite: - raise FileExistsError(f'the target file already exists: {self.trg_file}, set overwrite=True to ignore this error.') - if not self.trg_file.is_file(): - raise FileExistsError(f'the target file exists but is not a file: {self.trg_file}') - # create the missing directories if needed - if self._makedirs: - self.tmp_file.parent.mkdir(parents=True, exist_ok=True) - # delete any existing temporary files - if self.tmp_file.exists(): - log.debug(f'deleting existing temporary file: {self.tmp_file}') - self.tmp_file.unlink() - # handle the different modes, deleting any existing tmp files - if self._open_mode is not None: - log.debug(f'created new temporary file: {self.tmp_file}') - self._resource = open(self.tmp_file, self._open_mode) - return str(self.tmp_file), self._resource - else: - return str(self.tmp_file) - - def __exit__(self, error_type, error, traceback): - # close the temp file - if self._resource is not None: - self._resource.close() - self._resource = None - # cleanup if there was an error, and exit early - if error_type is not None: - if self.tmp_file.exists(): - self.tmp_file.unlink(missing_ok=True) - log.error(f'An error occurred in {self.__class__.__name__}, cleaned up temporary file: {self.tmp_file}') - else: - log.error(f'An error occurred in {self.__class__.__name__}') - return - # the temp file must have been created! - if not self.tmp_file.exists(): - raise FileNotFoundError(f'the temporary file was not created: {self.tmp_file}') - # delete the target file if it exists and overwrite is enabled: - if self._overwrite: - log.warning(f'overwriting file: {self.trg_file}') - self.trg_file.unlink(missing_ok=True) - # create the missing directories if needed - if self._makedirs: - self.trg_file.parent.mkdir(parents=True, exist_ok=True) - # move the temp file to the target file - log.info(f'moved temporary file to final location: {self.tmp_file} -> {self.trg_file}') - os.rename(self.tmp_file, self.trg_file) - - -# ========================================================================= # -# files/dirs exist # -# ========================================================================= # - - -def ensure_dir_exists(*path, is_file=False, absolute=False): - import os - # join path - path = os.path.join(*path) - # to abs path - if absolute: - path = os.path.abspath(path) - # remove file - dirs = os.path.dirname(path) if is_file else path - # create missing directory - if os.path.exists(dirs): - if not os.path.isdir(dirs): - raise IOError(f'path is not a directory: {dirs}') - else: - os.makedirs(dirs, exist_ok=True) - log.info(f'created missing directories: {dirs}') - # return directory - return path - - -def ensure_parent_dir_exists(*path): - return ensure_dir_exists(*path, is_file=True, absolute=True) - - -# ========================================================================= # -# files/dirs exist # -# ========================================================================= # - - -def download_file(url: str, save_path: str, overwrite_existing: bool = False, chunk_size: int = 16384): - import requests - from tqdm import tqdm - # write the file - with AtomicSaveFile(file=save_path, open_mode='wb', overwrite=overwrite_existing) as (_, file): - response = requests.get(url, stream=True) - total_length = response.headers.get('content-length') - # cast to integer if content-length exists on response - if total_length is not None: - total_length = int(total_length) - # download with progress bar - log.info(f'Downloading: {url} to: {save_path}') - with tqdm(total=total_length, desc=f'Downloading', unit='B', unit_scale=True, unit_divisor=1024) as progress: - for data in response.iter_content(chunk_size=chunk_size): - file.write(data) - progress.update(chunk_size) - - -def copy_file(src: str, dst: str, overwrite_existing: bool = False): - # copy the file - if os.path.abspath(src) == os.path.abspath(dst): - raise FileExistsError(f'input and output paths for copy are the same, skipping: {repr(dst)}') - else: - with AtomicSaveFile(file=dst, overwrite=overwrite_existing) as path: - import shutil - shutil.copyfile(src, path) - - -def retrieve_file(src_uri: str, dst_path: str, overwrite_existing: bool = False): - uri, is_url = _uri_parse_file_or_url(src_uri) - if is_url: - download_file(url=uri, save_path=dst_path, overwrite_existing=overwrite_existing) - else: - copy_file(src=uri, dst=dst_path, overwrite_existing=overwrite_existing) - - -# ========================================================================= # -# path utils # -# ========================================================================= # - - -def basename_from_url(url): - import os - from urllib.parse import urlparse - return os.path.basename(urlparse(url).path) - - -def _uri_parse_file_or_url(inp_uri) -> Tuple[str, bool]: - from urllib.parse import urlparse - result = urlparse(inp_uri) - # parse different cases - if result.scheme in ('http', 'https'): - is_url = True - uri = result.geturl() - elif result.scheme in ('file', ''): - is_url = False - if result.scheme == 'file': - if result.netloc: - raise KeyError(f'file uri format is invalid: "{result.geturl()}" two slashes specifies host as: "{result.netloc}" eg. instead of "file://hostname/root_folder/file.txt", please use: "file:/root_folder/file.txt" (no hostname) or "file:///root_folder/file.txt" (empty hostname).') - if not os.path.isabs(result.path): - raise RuntimeError(f'path: {repr(result.path)} obtained from file URI: {repr(inp_uri)} should always be absolute') - uri = result.path - else: - uri = result.geturl() - uri = os.path.abspath(uri) - else: - raise ValueError(f'invalid file or url: {repr(inp_uri)}') - # done - return uri, is_url - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/util/cache.py b/disent/util/cache.py new file mode 100644 index 00000000..93b4b5fb --- /dev/null +++ b/disent/util/cache.py @@ -0,0 +1,94 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +from functools import wraps +from typing import Callable +from typing import Dict +from typing import NoReturn +from typing import Optional +from typing import Union + +from disent.util.hashing import normalise_hash +from disent.util.hashing import hash_file +from disent.util.hashing import validate_file_hash + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# Function Caching # +# ========================================================================= # + + +class stalefile(object): + """ + decorator that only runs the wrapped function if a + file does not exist, or its hash does not match. + """ + + def __init__( + self, + file: str, + hash: Optional[Union[str, Dict[str, str]]], + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + self.file = file + self.hash = normalise_hash(hash=hash, hash_mode=hash_mode) + self.hash_type = hash_type + self.hash_mode = hash_mode + + def __call__(self, func: Callable[[str], NoReturn]) -> Callable[[], str]: + @wraps(func) + def wrapper() -> str: + if self.is_stale(): + log.debug(f'calling wrapped function: {func} because the file is stale: {repr(self.file)}') + func(self.file) + validate_file_hash(self.file, hash=self.hash, hash_type=self.hash_type, hash_mode=self.hash_mode) + else: + log.debug(f'skipped wrapped function: {func} because the file is fresh: {repr(self.file)}') + return self.file + return wrapper + + def is_stale(self): + fhash = hash_file(file=self.file, hash_type=self.hash_type, hash_mode=self.hash_mode, missing_ok=True) + if not fhash: + log.info(f'file is stale because it does not exist: {repr(self.file)}') + return True + if fhash != self.hash: + log.info(f'file is stale because the computed {self.hash_mode} {self.hash_type} hash: {fhash} does not match the target hash: {self.hash} for file: {repr(self.file)}') + return True + log.info(f'file is fresh: {repr(self.file)}') + return False + + def __bool__(self): + return self.is_stale() + + +# ========================================================================= # +# END # +# ========================================================================= # + diff --git a/disent/util/hashing.py b/disent/util/hashing.py new file mode 100644 index 00000000..3ad3ae7e --- /dev/null +++ b/disent/util/hashing.py @@ -0,0 +1,140 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import os +from typing import Dict +from typing import Union + + +# ========================================================================= # +# file hashing # +# ========================================================================= # + + +def _yield_file_bytes(file: str, chunk_size=16384): + with open(file, 'rb') as f: + bytes = True + while bytes: + bytes = f.read(chunk_size) + yield bytes + + +def _yield_fast_hash_bytes(file: str, chunk_size=16384, num_chunks=3): + assert num_chunks >= 2 + # return the size in bytes + size = os.path.getsize(file) + yield size.to_bytes(length=64//8, byteorder='big', signed=False) + # return file bytes chunks + if size < chunk_size * num_chunks: + # we cant return chunks because the file is too small, return everything! + yield from _yield_file_bytes(file, chunk_size=chunk_size) + else: + # includes evenly spaced start, middle and end chunks + with open(file, 'rb') as f: + for i in range(num_chunks): + pos = (i * (size - chunk_size)) // (num_chunks - 1) + f.seek(pos) + yield f.read(chunk_size) + + +# ========================================================================= # +# file hashing # +# ========================================================================= # + + +def hash_file(file: str, hash_type='md5', hash_mode='full', missing_ok=True) -> str: + """ + :param file: the path to the file + :param hash_type: the kind of hash to compute, default is "md5" + :param hash_mode: "full" uses all the bytes in the file to compute the hash, "fast" uses the start, middle, end bytes as well as the size of the file in the hash. + :param chunk_size: number of bytes to read at a time + :return: the hexdigest of the hash + :raises FileNotFoundError + """ + import hashlib + # check the file exists + if not os.path.isfile(file): + if missing_ok: + return '' + raise FileNotFoundError(f'could not compute hash for missing file: {repr(file)}') + # get file bytes iterator + if hash_mode == 'full': + byte_iter = _yield_file_bytes(file=file) + elif hash_mode == 'fast': + byte_iter = _yield_fast_hash_bytes(file=file) + else: + raise KeyError(f'invalid hash_mode: {repr(hash_mode)}') + # generate hash + hash = hashlib.new(hash_type) + for bytes in byte_iter: + hash.update(bytes) + hash = hash.hexdigest() + # done + return hash + + +# ========================================================================= # +# file hashing utils # +# ========================================================================= # + + +class HashError(Exception): + """ + Raised if the hash of a file was invalid. + """ + + +def normalise_hash(hash: Union[str, Dict[str, str]], hash_mode: str) -> str: + """ + file hashes depend on the mode. + - Allow hashes to be dictionaries that map the hash_mode to the hash. + This function returns the correct hash if it is a dictionary. + """ + return hash[hash_mode] if isinstance(hash, dict) else hash + + +def validate_file_hash(file: str, hash: Union[str, Dict[str, str]], hash_type: str = 'md5', hash_mode: str = 'full', missing_ok=True): + """ + :raises FileNotFoundError, HashError + """ + hash = normalise_hash(hash=hash, hash_mode=hash_mode) + # compute the hash + fhash = hash_file(file=file, hash_type=hash_type, hash_mode=hash_mode, missing_ok=missing_ok) + # check the hash + if fhash != hash: + raise HashError(f'computed {hash_mode} {hash_type} hash: {repr(fhash)} does not match expected hash: {repr(hash)} for file: {repr(file)}') + + +def is_valid_file_hash(file: str, hash: Union[str, Dict[str, str]], hash_type: str = 'md5', hash_mode: str = 'full', missing_ok=True): + try: + validate_file_hash(file=file, hash=hash, hash_type=hash_type, hash_mode=hash_mode, missing_ok=missing_ok) + except HashError: + return False + return True + + +# ========================================================================= # +# file hashing # +# ========================================================================= # + diff --git a/disent/util/in_out.py b/disent/util/in_out.py new file mode 100644 index 00000000..4a329613 --- /dev/null +++ b/disent/util/in_out.py @@ -0,0 +1,180 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +import os +from typing import Optional +from uuid import uuid4 + +from disent.util.paths import uri_parse_file_or_url +from disent.util.paths import modify_file_name + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# Atomic file saving # +# ========================================================================= # + + +class AtomicSaveFile(object): + """ + Within the context, data must be written to a temporary file. + Once data has been successfully written, the temporary file + is moved to the location of the target file. + + The temporary file is created in the same directory as the target file. + + ``` + with AtomicFileHandler('file.txt') as tmp_file: + with open(tmp_file, 'w') as f: + f.write("hello world!\n") + ``` + + # TODO: can this be cleaned up with the TemporaryDirectory and TemporaryFile classes? + """ + + def __init__( + self, + file: str, + open_mode: Optional[str] = None, + overwrite: bool = False, + makedirs: bool = True, + tmp_prefix: Optional[str] = '.temp.', + tmp_suffix: Optional[str] = None, + ): + from pathlib import Path + # check files + if not file: + raise ValueError(f'file must not be empty: {repr(file)}') + # get files + self.trg_file = Path(file).absolute() + self.tmp_file = modify_file_name(self.trg_file, prefix=f'{tmp_prefix}{uuid4()}', suffix=tmp_suffix) + # check that the files are different + if self.trg_file == self.tmp_file: + raise ValueError(f'temporary and target files are the same: {self.tmp_file} == {self.trg_file}') + # other settings + self._makedirs = makedirs + self._overwrite = overwrite + self._open_mode = open_mode + self._resource = None + + def __enter__(self): + # check files exist or not + if self.tmp_file.exists(): + if not self.tmp_file.is_file(): + raise FileExistsError(f'the temporary file exists but is not a file: {self.tmp_file}') + if self.trg_file.exists(): + if not self._overwrite: + raise FileExistsError(f'the target file already exists: {self.trg_file}, set overwrite=True to ignore this error.') + if not self.trg_file.is_file(): + raise FileExistsError(f'the target file exists but is not a file: {self.trg_file}') + # create the missing directories if needed + if self._makedirs: + self.tmp_file.parent.mkdir(parents=True, exist_ok=True) + # delete any existing temporary files + if self.tmp_file.exists(): + log.debug(f'deleting existing temporary file: {self.tmp_file}') + self.tmp_file.unlink() + # handle the different modes, deleting any existing tmp files + if self._open_mode is not None: + log.debug(f'created new temporary file: {self.tmp_file}') + self._resource = open(self.tmp_file, self._open_mode) + return str(self.tmp_file), self._resource + else: + return str(self.tmp_file) + + def __exit__(self, error_type, error, traceback): + # close the temp file + if self._resource is not None: + self._resource.close() + self._resource = None + # cleanup if there was an error, and exit early + if error_type is not None: + if self.tmp_file.exists(): + self.tmp_file.unlink(missing_ok=True) + log.error(f'An error occurred in {self.__class__.__name__}, cleaned up temporary file: {self.tmp_file}') + else: + log.error(f'An error occurred in {self.__class__.__name__}') + return + # the temp file must have been created! + if not self.tmp_file.exists(): + raise FileNotFoundError(f'the temporary file was not created: {self.tmp_file}') + # delete the target file if it exists and overwrite is enabled: + if self._overwrite: + log.warning(f'overwriting file: {self.trg_file}') + self.trg_file.unlink(missing_ok=True) + # create the missing directories if needed + if self._makedirs: + self.trg_file.parent.mkdir(parents=True, exist_ok=True) + # move the temp file to the target file + log.info(f'moved temporary file to final location: {self.tmp_file} -> {self.trg_file}') + os.rename(self.tmp_file, self.trg_file) + + +# ========================================================================= # +# files/dirs exist # +# ========================================================================= # + + +def download_file(url: str, save_path: str, overwrite_existing: bool = False, chunk_size: int = 16384): + import requests + from tqdm import tqdm + # write the file + with AtomicSaveFile(file=save_path, open_mode='wb', overwrite=overwrite_existing) as (_, file): + response = requests.get(url, stream=True) + total_length = response.headers.get('content-length') + # cast to integer if content-length exists on response + if total_length is not None: + total_length = int(total_length) + # download with progress bar + log.info(f'Downloading: {url} to: {save_path}') + with tqdm(total=total_length, desc=f'Downloading', unit='B', unit_scale=True, unit_divisor=1024) as progress: + for data in response.iter_content(chunk_size=chunk_size): + file.write(data) + progress.update(chunk_size) + + +def copy_file(src: str, dst: str, overwrite_existing: bool = False): + # copy the file + if os.path.abspath(src) == os.path.abspath(dst): + raise FileExistsError(f'input and output paths for copy are the same, skipping: {repr(dst)}') + else: + with AtomicSaveFile(file=dst, overwrite=overwrite_existing) as path: + import shutil + shutil.copyfile(src, path) + + +def retrieve_file(src_uri: str, dst_path: str, overwrite_existing: bool = False): + uri, is_url = uri_parse_file_or_url(src_uri) + if is_url: + download_file(url=uri, save_path=dst_path, overwrite_existing=overwrite_existing) + else: + copy_file(src=uri, dst=dst_path, overwrite_existing=overwrite_existing) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/paths.py b/disent/util/paths.py new file mode 100644 index 00000000..1ad76ded --- /dev/null +++ b/disent/util/paths.py @@ -0,0 +1,118 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +import os +from pathlib import Path +from typing import Tuple +from typing import Union + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# Formatting # +# ========================================================================= # + + +def modify_file_name(file: Union[str, Path], prefix: str = None, suffix: str = None, sep='.') -> Union[str, Path]: + # get path components + path = Path(file) + assert path.name, f'file name cannot be empty: {repr(path)}, for name: {repr(path.name)}' + # create new path + prefix = '' if (prefix is None) else f'{prefix}{sep}' + suffix = '' if (suffix is None) else f'{sep}{suffix}' + new_path = path.parent.joinpath(f'{prefix}{path.name}{suffix}') + # return path + return str(new_path) if isinstance(file, str) else new_path + + +# ========================================================================= # +# files/dirs exist # +# ========================================================================= # + + +def ensure_dir_exists(*join_paths: str, is_file=False, absolute=False): + import os + # join path + path = os.path.join(*join_paths) + # to abs path + if absolute: + path = os.path.abspath(path) + # remove file + dirs = os.path.dirname(path) if is_file else path + # create missing directory + if os.path.exists(dirs): + if not os.path.isdir(dirs): + raise IOError(f'path is not a directory: {dirs}') + else: + os.makedirs(dirs, exist_ok=True) + log.info(f'created missing directories: {dirs}') + # return directory + return path + + +def ensure_parent_dir_exists(*join_paths: str): + return ensure_dir_exists(*join_paths, is_file=True, absolute=True) + + +# ========================================================================= # +# URI utils # +# ========================================================================= # + + +def filename_from_url(url: str): + import os + from urllib.parse import urlparse + return os.path.basename(urlparse(url).path) + + +def uri_parse_file_or_url(inp_uri: str) -> Tuple[str, bool]: + from urllib.parse import urlparse + result = urlparse(inp_uri) + # parse different cases + if result.scheme in ('http', 'https'): + is_url = True + uri = result.geturl() + elif result.scheme in ('file', ''): + is_url = False + if result.scheme == 'file': + if result.netloc: + raise KeyError(f'file uri format is invalid: "{result.geturl()}" two slashes specifies host as: "{result.netloc}" eg. instead of "file://hostname/root_folder/file.txt", please use: "file:/root_folder/file.txt" (no hostname) or "file:///root_folder/file.txt" (empty hostname).') + if not os.path.isabs(result.path): + raise RuntimeError(f'path: {repr(result.path)} obtained from file URI: {repr(inp_uri)} should always be absolute') + uri = result.path + else: + uri = result.geturl() + uri = os.path.abspath(uri) + else: + raise ValueError(f'invalid file or url: {repr(inp_uri)}') + # done + return uri, is_url + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/strings.py b/disent/util/strings.py index 118f6f10..bc9be4f3 100644 --- a/disent/util/strings.py +++ b/disent/util/strings.py @@ -22,6 +22,35 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import math +from disent.util import colors as c + + +# ========================================================================= # +# Byte Formatting # +# ========================================================================= # + + +_BYTES_COLR = (c.WHT, c.lGRN, c.lYLW, c.lRED, c.lRED, c.lRED, c.lRED, c.lRED, c.lRED) +_BYTES_NAME = { + 1024: ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"), + 1000: ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"), +} + + +def bytes_to_human(size_bytes: int, decimals: int = 3, color: bool = True, mul: int = 1024) -> str: + if size_bytes == 0: + return "0B" + if mul not in _BYTES_NAME: + raise ValueError(f'invalid bytes multiplier: {repr(mul)} must be one of: {list(_BYTES_NAME.keys())}') + # round correctly + i = int(math.floor(math.log(size_bytes, mul))) + s = round(size_bytes / math.pow(mul, i), decimals) + # generate string + name = f'{_BYTES_COLR[i]}{_BYTES_NAME[mul][i]}{c.RST}' if color else f'{_BYTES_NAME[mul][i]}' + # format string + return f"{s:{4+decimals}.{decimals}f} {name}" + # ========================================================================= # # STRINGS # @@ -90,4 +119,3 @@ def pad_height(list_of_lines): # ========================================================================= # # END # # ========================================================================= # - diff --git a/experiment/exp/05_adversarial_data/run_04_gen_adversarial.py b/experiment/exp/05_adversarial_data/run_04_gen_adversarial.py index ed7f3832..4b8a6204 100644 --- a/experiment/exp/05_adversarial_data/run_04_gen_adversarial.py +++ b/experiment/exp/05_adversarial_data/run_04_gen_adversarial.py @@ -39,7 +39,7 @@ from tqdm import tqdm import experiment.exp.util as H -from disent.data.util.in_out import ensure_parent_dir_exists +from disent.util.paths import ensure_parent_dir_exists from disent.util.seeds import seed from disent.util.seeds import TempNumpySeed from disent.util.profiling import Timer diff --git a/experiment/exp/util/_io_util.py b/experiment/exp/util/_io_util.py index befc5a96..5fce91b4 100644 --- a/experiment/exp/util/_io_util.py +++ b/experiment/exp/util/_io_util.py @@ -31,7 +31,7 @@ import torch -from disent.data.util.in_out import ensure_parent_dir_exists +from disent.util.paths import ensure_parent_dir_exists # ========================================================================= # diff --git a/tests/test_state_space.py b/tests/test_state_space.py index c3dc3d89..f3316ae3 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -24,7 +24,7 @@ import numpy as np -from disent.data.util.state_space import StateSpace +from disent.data.groundtruth.states import StateSpace # ========================================================================= # From c4503b9bb2ea4d630bbfc55f42cdf47d87f78b8c Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Thu, 3 Jun 2021 19:24:09 +0200 Subject: [PATCH 27/34] update h5py tests --- disent/data/hdf5.py | 5 ++- tests/test_data.py | 89 +++++++++++++++++++++------------------------ 2 files changed, 45 insertions(+), 49 deletions(-) diff --git a/disent/data/hdf5.py b/disent/data/hdf5.py index 7edfcd5e..9c30ace8 100644 --- a/disent/data/hdf5.py +++ b/disent/data/hdf5.py @@ -21,6 +21,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + """ Utilities for converting and testing different chunk sizes of hdf5 files """ @@ -48,10 +49,12 @@ # ========================================================================= # -class PickleH5pyData(LengthIter): +class PickleH5pyFile(LengthIter): """ This class supports pickling and unpickling of a read-only SWMR h5py file and corresponding dataset. + + WARNING: this should probably not be used across multiple hosts? """ def __init__(self, h5_path: str, h5_dataset_name: str): diff --git a/tests/test_data.py b/tests/test_data.py index b9a627da..443acf08 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -21,18 +21,22 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +from concurrent.futures import ProcessPoolExecutor +from tempfile import NamedTemporaryFile + +import h5py import numpy as np import pytest -from disent.data.groundtruth import Shapes3dData from disent.data.groundtruth import XYSquaresData from disent.data.groundtruth._xysquares import XYSquaresMinimalData +from disent.data.hdf5 import PickleH5pyFile # ========================================================================= # # TESTS # # ========================================================================= # -from disent.data.groundtruth.base import Hdf5GroundTruthData def test_xysquares_similarity(): @@ -49,52 +53,41 @@ def test_xysquares_similarity(): assert np.allclose(data_org[n-1], data_min[n-1]) - - - -@pytest.mark.parametrize("num_workers", [0, 1, 2]) -def test_hdf5_multiproc_dataset(num_workers): - from disent.dataset.random import RandomDataset - from torch.utils.data import DataLoader - - xysquares = XYSquaresData(square_size=2, image_size=4) - - - # class TestHdf5Dataset(Hdf5GroundTruthData): - # - # - # factor_names = ('floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation') - # factor_sizes = (10, 10, 10, 8, 4, 15) # TOTAL: 480000 - # observation_shape = (64, 64, 3) - # - # data_object = DlH5DataObject( - # # processed dataset file - # file_name='3dshapes.h5', - # file_hashes={'fast': 'e3a1a449b95293d4b2c25edbfcb8e804', 'full': 'b5187ee0d8b519bb33281c5ca549658c'}, - # # download file/link - # uri='https://storage.googleapis.com/3d-shapes/3dshapes.h5', - # uri_hashes={'fast': '85b20ed7cc8dc1f939f7031698d2d2ab', 'full': '099a2078d58cec4daad0702c55d06868'}, - # # hash settings - # hash_mode='fast', - # hash_type='md5', - # # h5 re-save settings - # hdf5_dataset_name='images', - # hdf5_chunk_size=(1, 64, 64, 3), - # hdf5_compression='gzip', - # hdf5_compression_lvl=4, - # ) - # - # - # - # Shapes3dData() - # dataset = RandomDataset(Shapes3dData(prepare=True)) - # - # dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=2, shuffle=True) - # - # with tqdm(total=len(dataset)) as progress: - # for batch in dataloader: - # progress.update(256) - +def _iterate_over_data(data, indices): + i = -1 + for i, idx in enumerate(indices): + img = data[i] + return i + 1 + + +def test_hdf5_pickle_dataset(): + with NamedTemporaryFile('r') as temp_file: + # create temporary dataset + with h5py.File(temp_file.name, 'w') as file: + file.create_dataset( + name='data', + shape=(64, 4, 4, 3), + dtype='uint8', + data=np.stack([img for img in XYSquaresData(square_size=2, image_size=4)], axis=0) + ) + # load the data + # - ideally we want to test this with a pytorch + # DataLoader, but that is quite slow to initialise + with PickleH5pyFile(temp_file.name, 'data') as data: + indices = list(range(len(data))) + # test locally + assert _iterate_over_data(data=data, indices=indices) == 64 + # test multiprocessing + executor = ProcessPoolExecutor(2) + future_0 = executor.submit(_iterate_over_data, data=data, indices=indices[0::2]) + future_1 = executor.submit(_iterate_over_data, data=data, indices=indices[1::2]) + assert future_0.result() == 32 + assert future_1.result() == 32 + # test multiprocessing on invalid data + with h5py.File(temp_file.name, 'r', swmr=True) as file: + with pytest.raises(TypeError, match='h5py objects cannot be pickled'): + future_2 = executor.submit(_iterate_over_data, data=file['data'], indices=indices) + future_2.result() # ========================================================================= # From c8c5e1134c8381b6686ddb972ec68d17e7990bed Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 4 Jun 2021 12:11:35 +0200 Subject: [PATCH 28/34] renamed dataobj to datafile --- disent/data/{dataobj.py => datafile.py} | 25 ++++++++++++++----------- disent/data/groundtruth/_cars3d.py | 6 +++--- disent/data/groundtruth/_dsprites.py | 6 ++++-- disent/data/groundtruth/_mpi3d.py | 10 +++++----- disent/data/groundtruth/_norb.py | 16 ++++++++-------- disent/data/groundtruth/_shapes3d.py | 4 ++-- disent/data/groundtruth/states.py | 19 +++++++++---------- disent/data/hdf5.py | 13 +++++++++---- 8 files changed, 54 insertions(+), 45 deletions(-) rename disent/data/{dataobj.py => datafile.py} (91%) diff --git a/disent/data/dataobj.py b/disent/data/datafile.py similarity index 91% rename from disent/data/dataobj.py rename to disent/data/datafile.py index 57376f55..9cc237e1 100644 --- a/disent/data/dataobj.py +++ b/disent/data/datafile.py @@ -28,6 +28,7 @@ from typing import Dict from typing import final from typing import Optional +from typing import Sequence from typing import Tuple from typing import Union @@ -46,9 +47,9 @@ # ========================================================================= # -class DataObject(object, metaclass=ABCMeta): +class DataFile(object, metaclass=ABCMeta): """ - base DataObject that does nothing, if the file does + base DataFile that does nothing, if the file does not exist or it has the incorrect hash, then that's your problem! """ @@ -65,10 +66,10 @@ def prepare(self, out_dir: str) -> str: pass -class HashedDataObject(DataObject, metaclass=ABCMeta): +class DataFileHashed(DataFile, metaclass=ABCMeta): """ Abstract Class - - Base DataObject class that guarantees a file to exist, + - Base DataFile class that guarantees a file to exist, if the file does not exist, or the hash of the file is incorrect, then the file is re-generated. """ @@ -96,7 +97,7 @@ def _prepare(self, out_dir: str, out_file: str) -> str: raise NotImplementedError -class DlDataObject(HashedDataObject): +class DataFileHashedDl(DataFileHashed): """ Download a file - uri can also be a file to perform a copy instead of download, @@ -123,7 +124,7 @@ def _prepare(self, out_dir: str, out_file: str): retrieve_file(src_uri=self._uri, dst_path=out_file, overwrite_existing=True) -class DlGenDataObject(HashedDataObject, metaclass=ABCMeta): +class DataFileHashedDlGen(DataFileHashed, metaclass=ABCMeta): """ Abstract class - download a file and perform some processing on that file. @@ -142,7 +143,7 @@ def __init__( hash_type: str = 'md5', hash_mode: str = 'fast', ): - self._dl_obj = DlDataObject( + self._dl_obj = DataFileHashedDl( uri=uri, uri_hash=uri_hash, uri_name=uri_name, @@ -164,7 +165,7 @@ def _generate(self, inp_file: str, out_file: str): raise NotImplementedError -class DlH5DataObject(DlGenDataObject): +class DataFileHashedDlH5(DataFileHashedDlGen): """ Downloads an hdf5 file and pre-processes it into the specified chunk_size. """ @@ -182,6 +183,7 @@ def __init__( hdf5_compression_lvl: Optional[int] = 4, hdf5_dtype: Optional[Union[np.dtype, str]] = None, hdf5_mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None, + hdf5_obs_shape: Optional[Sequence[int]] = None, # save paths uri_name: Optional[str] = None, file_name: Optional[str] = None, @@ -206,13 +208,14 @@ def __init__( compression_lvl=hdf5_compression_lvl, out_dtype=hdf5_dtype, out_mutator=hdf5_mutator, + obs_shape=hdf5_obs_shape, ) # save the dataset name - self._out_dataset_name = hdf5_dataset_name + self._dataset_name = hdf5_dataset_name @property - def out_dataset_name(self) -> str: - return self._out_dataset_name + def dataset_name(self) -> str: + return self._dataset_name def _generate(self, inp_file: str, out_file: str): self._hdf5_resave_file(inp_path=inp_file, out_path=out_file) diff --git a/disent/data/groundtruth/_cars3d.py b/disent/data/groundtruth/_cars3d.py index ef2d30bf..1729031a 100644 --- a/disent/data/groundtruth/_cars3d.py +++ b/disent/data/groundtruth/_cars3d.py @@ -30,7 +30,7 @@ import numpy as np from scipy.io import loadmat -from disent.data.dataobj import DlGenDataObject +from disent.data.datafile import DataFileHashedDlGen from disent.data.groundtruth.base import NumpyGroundTruthData from disent.util.in_out import AtomicSaveFile @@ -83,7 +83,7 @@ def resave_cars3d_archive(orig_zipped_file, new_save_file, overwrite=False): # ========================================================================= # -class Cars3dDataObject(DlGenDataObject): +class DataFileCars3d(DataFileHashedDlGen): """ download the cars3d dataset and convert it to a numpy file. """ @@ -112,7 +112,7 @@ class Cars3dData(NumpyGroundTruthData): observation_shape = (128, 128, 3) data_key = 'images' - data_object = Cars3dDataObject( + data_object = DataFileCars3d( uri='http://www.scottreed.info/files/nips2015-analogy-data.tar.gz', uri_hash={'fast': 'fe77d39e3fa9d77c31df2262660c2a67', 'full': '4e866a7919c1beedf53964e6f7a23686'}, file_name='cars3d.npz', diff --git a/disent/data/groundtruth/_dsprites.py b/disent/data/groundtruth/_dsprites.py index 90d2e431..e0fffcd4 100644 --- a/disent/data/groundtruth/_dsprites.py +++ b/disent/data/groundtruth/_dsprites.py @@ -24,7 +24,7 @@ import logging -from disent.data.dataobj import DlH5DataObject +from disent.data.datafile import DataFileHashedDlH5 from disent.data.groundtruth.base import Hdf5GroundTruthData @@ -33,6 +33,8 @@ # ========================================================================= # + + class DSpritesData(Hdf5GroundTruthData): """ DSprites Dataset @@ -55,7 +57,7 @@ class DSpritesData(Hdf5GroundTruthData): factor_sizes = (3, 6, 40, 32, 32) # TOTAL: 737280 observation_shape = (64, 64, 1) - data_object = DlH5DataObject( + data_object = DataFileHashedDlH5( # download file/link uri='https://raw.githubusercontent.com/deepmind/dsprites-dataset/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.hdf5', uri_hash={'fast': 'd6ee1e43db715c2f0de3c41e38863347', 'full': 'b331c4447a651c44bf5e8ae09022e230'}, diff --git a/disent/data/groundtruth/_mpi3d.py b/disent/data/groundtruth/_mpi3d.py index dbb1efbd..632293c2 100644 --- a/disent/data/groundtruth/_mpi3d.py +++ b/disent/data/groundtruth/_mpi3d.py @@ -25,7 +25,7 @@ import logging from typing import Optional -from disent.data.dataobj import DlDataObject +from disent.data.datafile import DataFileHashedDl from disent.data.groundtruth.base import NumpyGroundTruthData @@ -47,9 +47,9 @@ class Mpi3dData(NumpyGroundTruthData): name = 'mpi3d' MPI3D_DATASETS = { - 'toy': DlDataObject(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_toy.npz', uri_hash=None), - 'realistic': DlDataObject(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_realistic.npz', uri_hash=None), - 'real': DlDataObject(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_real.npz', uri_hash=None), + 'toy': DataFileHashedDl(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_toy.npz', uri_hash=None), + 'realistic': DataFileHashedDl(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_realistic.npz', uri_hash=None), + 'real': DataFileHashedDl(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_real.npz', uri_hash=None), } factor_names = ('object_color', 'object_shape', 'object_size', 'camera_height', 'background_color', 'first_dof', 'second_dof') @@ -69,7 +69,7 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, subse super().__init__(data_root=data_root, prepare=prepare) @property - def data_object(self) -> DlDataObject: + def data_object(self) -> DataFileHashedDl: return self.MPI3D_DATASETS[self._subset] diff --git a/disent/data/groundtruth/_norb.py b/disent/data/groundtruth/_norb.py index 0f1a7487..ce8d4852 100644 --- a/disent/data/groundtruth/_norb.py +++ b/disent/data/groundtruth/_norb.py @@ -31,7 +31,7 @@ import numpy as np -from disent.data.dataobj import DlDataObject +from disent.data.datafile import DataFileHashedDl from disent.data.groundtruth.base import DiskGroundTruthData @@ -144,15 +144,15 @@ class SmallNorbData(DiskGroundTruthData): observation_shape = (96, 96, 1) TRAIN_DATA_OBJECTS = { - 'dat': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz', uri_hash={'fast': '92560cccc7bcbd6512805e435448b62d', 'full': '66054832f9accfe74a0f4c36a75bc0a2'}), - 'cat': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz', uri_hash={'fast': '348fc3ccefd651d69f500611988b5dcd', 'full': '23c8b86101fbf0904a000b43d3ed2fd9'}), - 'info': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz', uri_hash={'fast': 'f1b170c16925867c05f58608eb33ba7f', 'full': '51dee1210a742582ff607dfd94e332e3'}), + 'dat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz', uri_hash={'fast': '92560cccc7bcbd6512805e435448b62d', 'full': '66054832f9accfe74a0f4c36a75bc0a2'}), + 'cat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz', uri_hash={'fast': '348fc3ccefd651d69f500611988b5dcd', 'full': '23c8b86101fbf0904a000b43d3ed2fd9'}), + 'info': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz', uri_hash={'fast': 'f1b170c16925867c05f58608eb33ba7f', 'full': '51dee1210a742582ff607dfd94e332e3'}), } TEST_DATA_OBJECTS = { - 'dat': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz', uri_hash={'fast': '9aee0b474a4fc2a2ec392b463efb8858', 'full': 'e4ad715691ed5a3a5f138751a4ceb071'}), - 'cat': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz', uri_hash={'fast': '8cfae0679f5fa2df7a0aedfce90e5673', 'full': '5aa791cd7e6016cf957ce9bdb93b8603'}), - 'info': DlDataObject(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz', uri_hash={'fast': 'd2703a3f95e7b9a970ad52e91f0aaf6a', 'full': 'a9454f3864d7fd4bb3ea7fc3eb84924e'}), + 'dat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz', uri_hash={'fast': '9aee0b474a4fc2a2ec392b463efb8858', 'full': 'e4ad715691ed5a3a5f138751a4ceb071'}), + 'cat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz', uri_hash={'fast': '8cfae0679f5fa2df7a0aedfce90e5673', 'full': '5aa791cd7e6016cf957ce9bdb93b8603'}), + 'info': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz', uri_hash={'fast': 'd2703a3f95e7b9a970ad52e91f0aaf6a', 'full': 'a9454f3864d7fd4bb3ea7fc3eb84924e'}), } def __init__(self, data_root: Optional[str] = 'data/TEMP/dataset', prepare: bool = False, is_test=False): @@ -167,7 +167,7 @@ def __getitem__(self, idx): return self._data[idx] @property - def data_objects(self) -> Sequence[DlDataObject]: + def data_objects(self) -> Sequence[DataFileHashedDl]: norb_objects = self.TEST_DATA_OBJECTS if self._is_test else self.TRAIN_DATA_OBJECTS return norb_objects['dat'], norb_objects['cat'], norb_objects['info'] diff --git a/disent/data/groundtruth/_shapes3d.py b/disent/data/groundtruth/_shapes3d.py index 2647fe7f..b13daab4 100644 --- a/disent/data/groundtruth/_shapes3d.py +++ b/disent/data/groundtruth/_shapes3d.py @@ -24,7 +24,7 @@ import logging -from disent.data.dataobj import DlH5DataObject +from disent.data.datafile import DataFileHashedDlH5 from disent.data.groundtruth.base import Hdf5GroundTruthData @@ -50,7 +50,7 @@ class Shapes3dData(Hdf5GroundTruthData): factor_sizes = (10, 10, 10, 8, 4, 15) # TOTAL: 480000 observation_shape = (64, 64, 3) - data_object = DlH5DataObject( + data_object = DataFileHashedDlH5( # download file/link uri='https://storage.googleapis.com/3d-shapes/3dshapes.h5', uri_hash={'fast': '85b20ed7cc8dc1f939f7031698d2d2ab', 'full': '099a2078d58cec4daad0702c55d06868'}, diff --git a/disent/data/groundtruth/states.py b/disent/data/groundtruth/states.py index d4cb6931..a21a8b64 100644 --- a/disent/data/groundtruth/states.py +++ b/disent/data/groundtruth/states.py @@ -41,10 +41,10 @@ class StateSpace(LengthIter): def __init__(self, factor_sizes): super().__init__() # dimension - self._factor_sizes = np.array(factor_sizes) - self._factor_sizes.flags.writeable = False + self.__factor_sizes = np.array(factor_sizes) + self.__factor_sizes.flags.writeable = False # total permutations - self._size = int(np.prod(factor_sizes)) + self.__size = int(np.prod(factor_sizes)) def __len__(self): """Same as self.size""" @@ -61,17 +61,17 @@ def __getitem__(self, idx): @property def size(self) -> int: """The number of permutations of factors handled by this state space""" - return self._size + return self.__size @property def num_factors(self) -> int: """The number of factors handled by this state space""" - return len(self._factor_sizes) + return len(self.__factor_sizes) @property def factor_sizes(self) -> np.ndarray: """A list of sizes or dimensionality of factors handled by this state space""" - return self._factor_sizes + return self.__factor_sizes # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Coordinate Transform - any dim array, only last axis counts! # @@ -84,7 +84,7 @@ def pos_to_idx(self, positions) -> np.ndarray: - indices are integers < size """ positions = np.moveaxis(positions, source=-1, destination=0) - return np.ravel_multi_index(positions, self._factor_sizes) + return np.ravel_multi_index(positions, self.__factor_sizes) def idx_to_pos(self, indices) -> np.ndarray: """ @@ -92,7 +92,7 @@ def idx_to_pos(self, indices) -> np.ndarray: - indices are integers < size - positions are lists of integers, with each element < their corresponding factor size """ - positions = np.unravel_index(indices, self._factor_sizes) + positions = np.array(np.unravel_index(indices, self.__factor_sizes)) return np.moveaxis(positions, source=0, destination=-1) # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # @@ -110,7 +110,7 @@ def sample_factors(self, size=None, factor_indices=None) -> np.ndarray: the same size as factor_indices, ie (*size, len(factor_indices)) """ # get factor sizes - sizes = self._factor_sizes if (factor_indices is None) else self._factor_sizes[factor_indices] + sizes = self.__factor_sizes if (factor_indices is None) else self.__factor_sizes[factor_indices] # get resample size if size is not None: # empty np.array(()) gets dtype float which is incompatible with len @@ -269,7 +269,6 @@ def sample_random_factor_traversal(self, f_idx: int = None, base_factors=None, n # Hidden State Space # # ========================================================================= # - # class StateSpaceRemapIndex(object): # """Mapping from incorrectly ordered factors to state space indices""" # diff --git a/disent/data/hdf5.py b/disent/data/hdf5.py index 9c30ace8..360d3e15 100644 --- a/disent/data/hdf5.py +++ b/disent/data/hdf5.py @@ -74,6 +74,10 @@ def __len__(self): def __getitem__(self, item): return self._hdf5_data[item] + @property + def shape(self): + return self._hdf5_data.shape + def __enter__(self): return self @@ -107,7 +111,7 @@ def close(self): # ========================================================================= # -def hdf5_resave_dataset(inp_h5: h5py.File, out_h5: h5py.File, dataset_name, chunk_size=None, compression=None, compression_lvl=None, batch_size=None, out_dtype=None, out_mutator=None): +def hdf5_resave_dataset(inp_h5: h5py.File, out_h5: h5py.File, dataset_name, chunk_size=None, compression=None, compression_lvl=None, batch_size=None, out_dtype=None, out_mutator=None, obs_shape=None): # check out_h5 version compatibility if (isinstance(out_h5.libver, str) and out_h5.libver != 'earliest') or (out_h5.libver[0] != 'earliest'): raise RuntimeError(f'hdf5 out file has an incompatible libver: {repr(out_h5.libver)} libver should be set to: "earliest"') @@ -115,7 +119,7 @@ def hdf5_resave_dataset(inp_h5: h5py.File, out_h5: h5py.File, dataset_name, chun inp_data = inp_h5[dataset_name] out_data = out_h5.create_dataset( name=dataset_name, - shape=inp_data.shape, + shape=inp_data.shape if (obs_shape is None) else (inp_data.shape[0], *obs_shape), dtype=out_dtype if (out_dtype is not None) else inp_data.dtype, chunks=chunk_size, compression=compression, @@ -140,11 +144,11 @@ def hdf5_resave_dataset(inp_h5: h5py.File, out_h5: h5py.File, dataset_name, chun # save data with tqdm(total=len(inp_data)) as progress: for i in range(0, len(inp_data), batch_size): - out_data[i:i + batch_size] = out_mutator(inp_data[i:i + batch_size]) + out_data[i:i + batch_size] = out_mutator(inp_data[i:i + batch_size]).reshape([-1, *obs_shape]) progress.update(batch_size) -def hdf5_resave_file(inp_path: str, out_path: str, dataset_name, chunk_size=None, compression=None, compression_lvl=None, batch_size=None, out_dtype=None, out_mutator=None): +def hdf5_resave_file(inp_path: str, out_path: str, dataset_name, chunk_size=None, compression=None, compression_lvl=None, batch_size=None, out_dtype=None, out_mutator=None, obs_shape=None): # re-save datasets with h5py.File(inp_path, 'r') as inp_h5: with AtomicSaveFile(out_path, open_mode=None, overwrite=True) as tmp_h5_path: @@ -159,6 +163,7 @@ def hdf5_resave_file(inp_path: str, out_path: str, dataset_name, chunk_size=None batch_size=batch_size, out_dtype=out_dtype, out_mutator=out_mutator, + obs_shape=obs_shape, ) # file size: log.info(f'[FILE SIZES] IN: {bytes_to_human(os.path.getsize(inp_path))} OUT: {bytes_to_human(os.path.getsize(out_path))}') From 4b39b35fdeb9667ebf7f0fa786c06d96a4ec2ba6 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 4 Jun 2021 15:11:13 +0200 Subject: [PATCH 29/34] update names fix --- disent/data/groundtruth/_cars3d.py | 3 +- disent/data/groundtruth/_dsprites.py | 9 +++--- disent/data/groundtruth/_mpi3d.py | 2 +- disent/data/groundtruth/_norb.py | 12 ++++---- disent/data/groundtruth/_shapes3d.py | 2 +- disent/data/groundtruth/base.py | 41 ++++++++++++++++------------ 6 files changed, 38 insertions(+), 31 deletions(-) diff --git a/disent/data/groundtruth/_cars3d.py b/disent/data/groundtruth/_cars3d.py index 1729031a..5efdd106 100644 --- a/disent/data/groundtruth/_cars3d.py +++ b/disent/data/groundtruth/_cars3d.py @@ -111,8 +111,7 @@ class Cars3dData(NumpyGroundTruthData): factor_sizes = (4, 24, 183) # TOTAL: 17568 observation_shape = (128, 128, 3) - data_key = 'images' - data_object = DataFileCars3d( + datafile = DataFileCars3d( uri='http://www.scottreed.info/files/nips2015-analogy-data.tar.gz', uri_hash={'fast': 'fe77d39e3fa9d77c31df2262660c2a67', 'full': '4e866a7919c1beedf53964e6f7a23686'}, file_name='cars3d.npz', diff --git a/disent/data/groundtruth/_dsprites.py b/disent/data/groundtruth/_dsprites.py index e0fffcd4..ee343add 100644 --- a/disent/data/groundtruth/_dsprites.py +++ b/disent/data/groundtruth/_dsprites.py @@ -57,17 +57,18 @@ class DSpritesData(Hdf5GroundTruthData): factor_sizes = (3, 6, 40, 32, 32) # TOTAL: 737280 observation_shape = (64, 64, 1) - data_object = DataFileHashedDlH5( + datafile = DataFileHashedDlH5( # download file/link uri='https://raw.githubusercontent.com/deepmind/dsprites-dataset/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.hdf5', uri_hash={'fast': 'd6ee1e43db715c2f0de3c41e38863347', 'full': 'b331c4447a651c44bf5e8ae09022e230'}, # processed dataset file - file_hash={'fast': '7a6e83ebf35f93a1cd9ae0210112b421', 'full': '27c674fb5170dcd6a1f9853b66c5785d'}, + file_hash={'fast': '25013c85aebbf4b1023d72564f9413f0', 'full': '4611d1a03e709cd5d0f6fdcdc221ca0e'}, # h5 re-save settings hdf5_dataset_name='imgs', - hdf5_chunk_size=(1, 64, 64), + hdf5_chunk_size=(1, 64, 64, 1), hdf5_dtype='uint8', - hdf5_mutator=lambda x: x * 255 + hdf5_mutator=lambda x: x * 255, + hdf5_obs_shape=(64, 64, 1), ) diff --git a/disent/data/groundtruth/_mpi3d.py b/disent/data/groundtruth/_mpi3d.py index 632293c2..6df08100 100644 --- a/disent/data/groundtruth/_mpi3d.py +++ b/disent/data/groundtruth/_mpi3d.py @@ -69,7 +69,7 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, subse super().__init__(data_root=data_root, prepare=prepare) @property - def data_object(self) -> DataFileHashedDl: + def datafile(self) -> DataFileHashedDl: return self.MPI3D_DATASETS[self._subset] diff --git a/disent/data/groundtruth/_norb.py b/disent/data/groundtruth/_norb.py index ce8d4852..b0dfda80 100644 --- a/disent/data/groundtruth/_norb.py +++ b/disent/data/groundtruth/_norb.py @@ -143,32 +143,32 @@ class SmallNorbData(DiskGroundTruthData): factor_sizes = (5, 5, 9, 18, 6) # TOTAL: 24300 observation_shape = (96, 96, 1) - TRAIN_DATA_OBJECTS = { + TRAIN_DATA_FILES = { 'dat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz', uri_hash={'fast': '92560cccc7bcbd6512805e435448b62d', 'full': '66054832f9accfe74a0f4c36a75bc0a2'}), 'cat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz', uri_hash={'fast': '348fc3ccefd651d69f500611988b5dcd', 'full': '23c8b86101fbf0904a000b43d3ed2fd9'}), 'info': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz', uri_hash={'fast': 'f1b170c16925867c05f58608eb33ba7f', 'full': '51dee1210a742582ff607dfd94e332e3'}), } - TEST_DATA_OBJECTS = { + TEST_DATA_FILES = { 'dat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz', uri_hash={'fast': '9aee0b474a4fc2a2ec392b463efb8858', 'full': 'e4ad715691ed5a3a5f138751a4ceb071'}), 'cat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz', uri_hash={'fast': '8cfae0679f5fa2df7a0aedfce90e5673', 'full': '5aa791cd7e6016cf957ce9bdb93b8603'}), 'info': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz', uri_hash={'fast': 'd2703a3f95e7b9a970ad52e91f0aaf6a', 'full': 'a9454f3864d7fd4bb3ea7fc3eb84924e'}), } - def __init__(self, data_root: Optional[str] = 'data/TEMP/dataset', prepare: bool = False, is_test=False): + def __init__(self, data_root: Optional[str] = None, prepare: bool = False, is_test=False): self._is_test = is_test # initialize super().__init__(data_root=data_root, prepare=prepare) # read dataset and sort by features - dat_path, cat_path, info_path = (os.path.join(self.data_dir, obj.out_name) for obj in self.data_objects) + dat_path, cat_path, info_path = (os.path.join(self.data_dir, obj.out_name) for obj in self.datafiles) self._data, _ = read_norb_dataset(dat_path=dat_path, cat_path=cat_path, info_path=info_path) def __getitem__(self, idx): return self._data[idx] @property - def data_objects(self) -> Sequence[DataFileHashedDl]: - norb_objects = self.TEST_DATA_OBJECTS if self._is_test else self.TRAIN_DATA_OBJECTS + def datafiles(self) -> Sequence[DataFileHashedDl]: + norb_objects = self.TEST_DATA_FILES if self._is_test else self.TRAIN_DATA_FILES return norb_objects['dat'], norb_objects['cat'], norb_objects['info'] diff --git a/disent/data/groundtruth/_shapes3d.py b/disent/data/groundtruth/_shapes3d.py index b13daab4..493ff4fa 100644 --- a/disent/data/groundtruth/_shapes3d.py +++ b/disent/data/groundtruth/_shapes3d.py @@ -50,7 +50,7 @@ class Shapes3dData(Hdf5GroundTruthData): factor_sizes = (10, 10, 10, 8, 4, 15) # TOTAL: 480000 observation_shape = (64, 64, 3) - data_object = DataFileHashedDlH5( + datafile = DataFileHashedDlH5( # download file/link uri='https://storage.googleapis.com/3d-shapes/3dshapes.h5', uri_hash={'fast': '85b20ed7cc8dc1f939f7031698d2d2ab', 'full': '099a2078d58cec4daad0702c55d06868'}, diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index c237fb69..f1f8a36c 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -31,10 +31,10 @@ import numpy as np -from disent.data.dataobj import DataObject -from disent.data.dataobj import DlH5DataObject +from disent.data.datafile import DataFile +from disent.data.datafile import DataFileHashedDlH5 from disent.data.groundtruth.states import StateSpace -from disent.data.hdf5 import PickleH5pyData +from disent.data.hdf5 import PickleH5pyFile from disent.util.paths import ensure_dir_exists @@ -93,7 +93,7 @@ def __getitem__(self, idx): # ========================================================================= # # disk ground truth data # -# TODO: data & data_object preparation should be split out from # +# TODO: data & datafile preparation should be split out from # # GroundTruthData, instead GroundTruthData should be a wrapper # # ========================================================================= # @@ -117,8 +117,8 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False): log.info(f'{self.name}: data_dir_share={repr(self._data_dir)}') # prepare everything if prepare: - for data_object in self.data_objects: - data_object.prepare(self.data_dir) + for datafile in self.datafiles: + datafile.prepare(self.data_dir) @property def data_dir(self) -> str: @@ -129,7 +129,7 @@ def default_data_root(self): return os.path.abspath(os.environ.get('DISENT_DATA_ROOT', 'data/dataset')) @property - def data_objects(self) -> Sequence[DataObject]: + def datafiles(self) -> Sequence[DataFile]: raise NotImplementedError @@ -142,7 +142,14 @@ class NumpyGroundTruthData(DiskGroundTruthData, metaclass=ABCMeta): def __init__(self, data_root: Optional[str] = None, prepare: bool = False): super().__init__(data_root=data_root, prepare=prepare) # load dataset - self._data = np.load(os.path.join(self.data_dir, self.data_object.out_name)) + load_path = os.path.join(self.data_dir, self.datafile.out_name) + if load_path.endswith('.gz'): + import gzip + with gzip.GzipFile(load_path, 'r') as load_file: + self._data = np.load(load_file) + else: + self._data = np.load(load_path) + # load from the key if specified if self.data_key is not None: self._data = self._data[self.data_key] @@ -150,11 +157,11 @@ def __getitem__(self, idx): return self._data[idx] @property - def data_objects(self) -> Sequence[DataObject]: - return [self.data_object] + def datafiles(self) -> Sequence[DataFile]: + return [self.datafile] @property - def data_object(self) -> DataObject: + def datafile(self) -> DataFile: raise NotImplementedError @property @@ -175,9 +182,9 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_me # variables self._in_memory = in_memory # load the h5py dataset - data = PickleH5pyData( - h5_path=os.path.join(self.data_dir, self.data_object.out_name), - h5_dataset_name=self.data_object.out_dataset_name, + data = PickleH5pyFile( + h5_path=os.path.join(self.data_dir, self.datafile.out_name), + h5_dataset_name=self.datafile.dataset_name, ) # handle different memory modes if self._in_memory: @@ -194,11 +201,11 @@ def __getitem__(self, idx): return self._data[idx] @property - def data_objects(self) -> Sequence[DlH5DataObject]: - return [self.data_object] + def datafiles(self) -> Sequence[DataFileHashedDlH5]: + return [self.datafile] @property - def data_object(self) -> DlH5DataObject: + def datafile(self) -> DataFileHashedDlH5: raise NotImplementedError From cb8be2af2367a26231e519d5c21d54d846a2d979 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 4 Jun 2021 15:23:48 +0200 Subject: [PATCH 30/34] comments --- disent/data/episodes/_base.py | 11 +++++++++++ disent/data/episodes/_option_episodes.py | 24 ++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/disent/data/episodes/_base.py b/disent/data/episodes/_base.py index 9e9e7958..6a70f174 100644 --- a/disent/data/episodes/_base.py +++ b/disent/data/episodes/_base.py @@ -25,10 +25,16 @@ from typing import List, Tuple import numpy as np + from disent.dataset.groundtruth._triplet import sample_radius from disent.util.iters import LengthIter +# ========================================================================= # +# option episodes # +# ========================================================================= # + + class BaseOptionEpisodesData(LengthIter): def __init__(self): @@ -86,3 +92,8 @@ def sample_episode_indices(episode, idx, n=1, radius=None): def _load_episode_observations(self) -> List[np.ndarray]: raise NotImplementedError + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/data/episodes/_option_episodes.py b/disent/data/episodes/_option_episodes.py index 37fc9ded..71ac926b 100644 --- a/disent/data/episodes/_option_episodes.py +++ b/disent/data/episodes/_option_episodes.py @@ -22,17 +22,26 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import logging import os -from typing import List, Tuple +from typing import List +from typing import Tuple + import numpy as np + from disent.data.episodes._base import BaseOptionEpisodesData from disent.util.in_out import download_file from disent.util.paths import filename_from_url -import logging + log = logging.getLogger(__name__) +# ========================================================================= # +# option episodes # +# ========================================================================= # + + class OptionEpisodesPickledData(BaseOptionEpisodesData): def __init__(self, required_file: str): @@ -41,6 +50,10 @@ def __init__(self, required_file: str): # load data super().__init__() + # TODO: convert this to data files? + # TODO: convert this to data files? + # TODO: convert this to data files? + def _load_episode_observations(self) -> List[np.ndarray]: import pickle # load the raw data! @@ -108,6 +121,10 @@ def _load_episode_observations(self) -> List[np.ndarray]: class OptionEpisodesDownloadZippedPickledData(OptionEpisodesPickledData): + # TODO: convert this to data files? + # TODO: convert this to data files? + # TODO: convert this to data files? + def __init__(self, required_file: str, download_url=None, force_download=False): self._download_and_extract_if_needed(download_url=download_url, required_file=required_file, force_download=force_download) super().__init__(required_file=required_file) @@ -142,3 +159,6 @@ def _download_and_extract_if_needed(self, download_url: str, required_file: str, # ~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~ +# ========================================================================= # +# END # +# ========================================================================= # From fdd2e5ffa78aa2659528f086690a92f73fe07080 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 4 Jun 2021 22:46:21 +0200 Subject: [PATCH 31/34] moved factor names into state space --- disent/data/groundtruth/base.py | 6 ++++-- disent/data/groundtruth/states.py | 14 +++++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/disent/data/groundtruth/base.py b/disent/data/groundtruth/base.py index f1f8a36c..0fd4e77f 100644 --- a/disent/data/groundtruth/base.py +++ b/disent/data/groundtruth/base.py @@ -52,8 +52,10 @@ class GroundTruthData(StateSpace): """ def __init__(self): - assert len(self.factor_names) == len(self.factor_sizes), 'Dimensionality mismatch of FACTOR_NAMES and FACTOR_DIMS' - super().__init__(factor_sizes=self.factor_sizes) + super().__init__( + factor_sizes=self.factor_sizes, + factor_names=self.factor_names, + ) @property def name(self): diff --git a/disent/data/groundtruth/states.py b/disent/data/groundtruth/states.py index a21a8b64..37f0f5dc 100644 --- a/disent/data/groundtruth/states.py +++ b/disent/data/groundtruth/states.py @@ -21,6 +21,9 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +from typing import Optional +from typing import Sequence +from typing import Tuple import numpy as np from disent.util.iters import LengthIter @@ -38,13 +41,17 @@ class StateSpace(LengthIter): ie. State space with multiple factors of variation, where each factor can be a different size. """ - def __init__(self, factor_sizes): + def __init__(self, factor_sizes: Sequence[int], factor_names: Optional[Sequence[str]] = None): super().__init__() # dimension self.__factor_sizes = np.array(factor_sizes) self.__factor_sizes.flags.writeable = False # total permutations self.__size = int(np.prod(factor_sizes)) + # factor names + self.__factor_names = tuple(f'f{i}' for i in range(self.num_factors)) if (factor_names is None) else tuple(factor_names) + if len(self.__factor_names) != len(self.__factor_sizes): + raise ValueError(f'Dimensionality mismatch of factor_names and factor_sizes: len({self.__factor_names}) != len({tuple(self.__factor_sizes)})') def __len__(self): """Same as self.size""" @@ -73,6 +80,11 @@ def factor_sizes(self) -> np.ndarray: """A list of sizes or dimensionality of factors handled by this state space""" return self.__factor_sizes + @property + def factor_names(self) -> Tuple[str, ...]: + """A list of names of factors handled by this state space""" + return self.__factor_names + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Coordinate Transform - any dim array, only last axis counts! # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # From 217602c3b33edf93b621fc9a43d3ec1865dcbb01 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 4 Jun 2021 23:27:16 +0200 Subject: [PATCH 32/34] renamed GroundTruthDatasetBatchAugment to DisentDatasetTransform + fixes --- disent/nn/transform/__init__.py | 3 +++ ...groundtruth.py => _augment_groundtruth.py} | 27 +++++++++---------- disent/util/seeds.py | 10 ++++--- experiment/config/config.yaml | 1 + experiment/run.py | 5 ++++ experiment/util/callbacks/callbacks_vae.py | 6 ++--- experiment/util/hydra_data.py | 14 +++++----- 7 files changed, 39 insertions(+), 27 deletions(-) rename disent/nn/transform/{groundtruth.py => _augment_groundtruth.py} (77%) diff --git a/disent/nn/transform/__init__.py b/disent/nn/transform/__init__.py index 47e40029..a67aea70 100644 --- a/disent/nn/transform/__init__.py +++ b/disent/nn/transform/__init__.py @@ -31,3 +31,6 @@ from ._augment import FftGaussianBlur from ._augment import FftBoxBlur from ._augment import FftKernel + +# disent dataset augment +from ._augment_groundtruth import DisentDatasetTransform diff --git a/disent/nn/transform/groundtruth.py b/disent/nn/transform/_augment_groundtruth.py similarity index 77% rename from disent/nn/transform/groundtruth.py rename to disent/nn/transform/_augment_groundtruth.py index 3d426d7e..5cc305f0 100644 --- a/disent/nn/transform/groundtruth.py +++ b/disent/nn/transform/_augment_groundtruth.py @@ -28,7 +28,7 @@ # ========================================================================= # -class GroundTruthDatasetBatchAugment(object): +class DisentDatasetTransform(object): """ Applies transforms to batches generated from dataloaders of datasets from: disent.dataset.groundtruth @@ -40,11 +40,13 @@ def __init__(self, transform=None, transform_targ=None): def __call__(self, batch): # transform inputs - if self.transform: - batch = _apply_transform_to_batch_dict(batch, 'x', self.transform) + if self.transform is not None: + if 'x' not in batch: + batch['x'] = batch['x_targ'] + batch['x'] = _apply_transform_to_batch_dict(batch['x'], self.transform) # transform targets - if self.transform_targ: - batch = _apply_transform_to_batch_dict(batch, 'x_targ', self.transform_targ) + if self.transform_targ is not None: + batch['x_targ'] = _apply_transform_to_batch_dict(batch['x_targ'], self.transform_targ) # done! return batch @@ -52,16 +54,13 @@ def __repr__(self): return f'{self.__class__.__name__}(transform={repr(self.transform)}, transform_targ={repr(self.transform_targ)})' -def _apply_transform_to_batch_dict(batch, key, transform): - observations = batch[key] - if isinstance(observations, tuple): - observations = tuple([transform(obs) for obs in observations]) - if isinstance(observations, list): - observations = [transform(obs) for obs in observations] +def _apply_transform_to_batch_dict(batch, transform): + if isinstance(batch, tuple): + return tuple(transform(obs) for obs in batch) + if isinstance(batch, list): + return list(transform(obs) for obs in batch) else: - observations = transform(observations) - batch[key] = observations - return batch + return transform(batch) # ========================================================================= # diff --git a/disent/util/seeds.py b/disent/util/seeds.py index 45338cbc..3a0b3ac8 100644 --- a/disent/util/seeds.py +++ b/disent/util/seeds.py @@ -23,6 +23,8 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging +import random + import numpy as np @@ -41,16 +43,18 @@ def seed(long=777): if long is None: log.warning(f'[SEEDING]: no seed was specified. Seeding skipped!') return + # seed python + random.seed(long) + # seed numpy + np.random.seed(long) # seed torch - it can be slow to import try: import torch - torch.manual_seed(long) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False + torch.manual_seed(long) # also calls: torch.cuda.manual_seed_all except ImportError: log.warning(f'[SEEDING]: torch is not installed. Skipped seeding torch methods!') - # seed numpy - np.random.seed(long) # done! log.info(f'[SEEDED]: {long}') diff --git a/experiment/config/config.yaml b/experiment/config/config.yaml index 155f327b..63704614 100644 --- a/experiment/config/config.yaml +++ b/experiment/config/config.yaml @@ -23,6 +23,7 @@ job: project: 'DELETE' name: '${framework.name}:${framework.module.recon_loss}|${dataset.name}:${sampling.name}|${trainer.steps}' partition: stampede + seed: NULL framework: beta: 0.003 diff --git a/experiment/run.py b/experiment/run.py index 2ec066b0..ce4300f3 100644 --- a/experiment/run.py +++ b/experiment/run.py @@ -40,6 +40,7 @@ from disent.frameworks import DisentFramework from disent.model import AutoEncoder from disent.nn.weights import init_model_weights +from disent.util.seeds import seed from disent.util.strings import make_box_str from experiment.util.callbacks import LoggerProgressCallback from experiment.util.callbacks import VaeDisentanglementLoggingCallback @@ -244,8 +245,12 @@ def hydra_create_framework(framework_cfg, datamodule, cfg): def run(cfg: DictConfig): + # allow the cfg to be edited cfg = make_non_strict(cfg) + # deterministic seed + seed(cfg.job.setdefault('seed', None)) + # -~-~-~-~-~-~-~-~-~-~-~-~- # # INITIALISE & SETDEFAULT IN CONFIG # -~-~-~-~-~-~-~-~-~-~-~-~- # diff --git a/experiment/util/callbacks/callbacks_vae.py b/experiment/util/callbacks/callbacks_vae.py index 1b7ebc53..ce544a6a 100644 --- a/experiment/util/callbacks/callbacks_vae.py +++ b/experiment/util/callbacks/callbacks_vae.py @@ -34,7 +34,7 @@ import disent.metrics import disent.util.colors as c -from disent.dataset._augment_util import AugmentableDataset +from disent.dataset import DisentDataset from disent.dataset.groundtruth import GroundTruthDataset from disent.frameworks.ae import Ae from disent.frameworks.vae import Vae @@ -59,7 +59,7 @@ # ========================================================================= # -def _get_dataset_and_vae(trainer: pl.Trainer, pl_module: pl.LightningModule) -> (AugmentableDataset, Ae): +def _get_dataset_and_vae(trainer: pl.Trainer, pl_module: pl.LightningModule) -> (DisentDataset, Ae): assert isinstance(pl_module, Ae), f'{pl_module.__class__} is not an instance of {Ae}' # get dataset if hasattr(trainer, 'datamodule') and (trainer.datamodule is not None): @@ -73,7 +73,7 @@ def _get_dataset_and_vae(trainer: pl.Trainer, pl_module: pl.LightningModule) -> else: raise RuntimeError('could not retrieve dataset! please report this...') # check dataset - assert isinstance(dataset, AugmentableDataset), f'retrieved dataset is not an {AugmentableDataset.__name__}' + assert isinstance(dataset, DisentDataset), f'retrieved dataset is not an {DisentDataset.__name__}' # done checks return dataset, pl_module diff --git a/experiment/util/hydra_data.py b/experiment/util/hydra_data.py index 598b1545..383cf105 100644 --- a/experiment/util/hydra_data.py +++ b/experiment/util/hydra_data.py @@ -27,8 +27,8 @@ import pytorch_lightning as pl from omegaconf import DictConfig -from disent.dataset._augment_util import AugmentableDataset -from disent.nn.transform import GroundTruthDatasetBatchAugment +from disent.dataset import DisentDataset +from disent.nn.transform import DisentDatasetTransform from experiment.util.hydra_utils import instantiate_recursive @@ -53,12 +53,12 @@ def __init__(self, hparams: DictConfig): # - corresponds to below in train_dataloader() if self.hparams.dataset.gpu_augment: # TODO: this is outdated! - self.batch_augment = GroundTruthDatasetBatchAugment(transform=self.input_transform) + self.batch_augment = DisentDatasetTransform(transform=self.input_transform) else: self.batch_augment = None # datasets initialised in setup() - self.dataset_train_noaug: AugmentableDataset = None - self.dataset_train_aug: AugmentableDataset = None + self.dataset_train_noaug: DisentDataset = None + self.dataset_train_aug: DisentDataset = None def prepare_data(self) -> None: # *NB* Do not set model parameters here. @@ -77,8 +77,8 @@ def setup(self, stage=None) -> None: self.dataset_train_noaug = hydra.utils.instantiate(self.hparams.data_wrapper.wrapper, data, transform=self.data_transform, augment=None) self.dataset_train_aug = hydra.utils.instantiate(self.hparams.data_wrapper.wrapper, data, transform=self.data_transform, augment=self.input_transform) # TODO: make these assertions more general with some base-class - assert isinstance(self.dataset_train_noaug, AugmentableDataset) - assert isinstance(self.dataset_train_aug, AugmentableDataset) + assert isinstance(self.dataset_train_noaug, DisentDataset) + assert isinstance(self.dataset_train_aug, DisentDataset) # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Training Dataset: From eac7c05f5b3f4cd736c42547012be2cb20cfacba Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Fri, 4 Jun 2021 23:27:49 +0200 Subject: [PATCH 33/34] fix configs + experiment configs readme additions --- README.md | 42 +++++++++++++++++++ experiment/config/dataset/cars3d.yaml | 2 +- experiment/config/dataset/dsprites.yaml | 2 +- experiment/config/dataset/monte_rollouts.yaml | 2 +- experiment/config/dataset/mpi3d_real.yaml | 2 +- .../config/dataset/mpi3d_realistic.yaml | 2 +- experiment/config/dataset/mpi3d_toy.yaml | 2 +- experiment/config/dataset/shapes3d.yaml | 2 +- experiment/config/dataset/smallnorb.yaml | 2 +- experiment/config/dataset/xyblocks.yaml | 2 +- experiment/config/dataset/xyblocks_grey.yaml | 2 +- experiment/config/dataset/xyobject.yaml | 2 +- experiment/config/dataset/xyobject_grey.yaml | 2 +- experiment/config/dataset/xysquares.yaml | 2 +- experiment/config/dataset/xysquares_grey.yaml | 2 +- experiment/config/dataset/xysquares_rgb.yaml | 2 +- experiment/config/metrics/test.yaml | 21 ++++++++++ 17 files changed, 78 insertions(+), 15 deletions(-) create mode 100644 experiment/config/metrics/test.yaml diff --git a/README.md b/README.md index bd8c52f7..5f0120e6 100644 --- a/README.md +++ b/README.md @@ -299,4 +299,46 @@ print('metrics:', metrics) Visit the [docs](https://disent.dontpanic.sh) for more examples! + +---------------------- + +### Hydra Experiment Example + +The entrypoint for basic experiments is `experiments/run.py`. + +Some configuration will be required, but basic experiments can +be adjusted by modifying the [Hydra Config 1.0](https://github.com/facebookresearch/hydra) +files in `experiment/config`. + +Modifying the main `experiment/config/config.yaml` is all you +need for most basic experiments. The main config file contains +a defaults list with entries corresponding to yaml configuration +files (config options) in the subfolders (config groups) in +`experiment/config//