Skip to content

Commit

Permalink
Merge branch 'dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
nmichlo committed Aug 5, 2022
2 parents fe8b66d + e57f4e2 commit c2a4126
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 6 deletions.
3 changes: 3 additions & 0 deletions disent/dataset/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ def dataset_sample_elems(self, num_samples: int, mode: str, return_indices: bool
# Batches -- Ground Truth Only #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

# TODO: batches should be obtained from indices
# - the wrapped gt datasets should handle generating these indices, eg. factor traversals etc.

@groundtruth_only
def dataset_batch_from_factors(self, factors: np.ndarray, mode: str, collate: bool = True):
"""Get a batch of observations X from a batch of factors Y."""
Expand Down
File renamed without changes.
3 changes: 1 addition & 2 deletions disent/frameworks/ae/_unsupervised__ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~

import logging
from dataclasses import dataclass
from numbers import Number
from typing import Any
Expand All @@ -34,7 +33,7 @@

import torch

from disent.frameworks.ae._ae_mixin import _AeAndVaeMixin
from disent.frameworks._ae_mixin import _AeAndVaeMixin
from disent.frameworks.helper.util import detach_all
from disent.model import AutoEncoder
from disent.util.iters import map_all
Expand Down
19 changes: 19 additions & 0 deletions disent/frameworks/vae/_unsupervised__dfcvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from disent.frameworks.helper.util import compute_ave_loss
from disent.frameworks.vae._unsupervised__betavae import BetaVae
from disent.nn.loss.reduction import batch_loss_reduction
from disent.nn.loss.reduction import get_mean_loss_scale
from disent.dataset.transform.functional import check_tensor

Expand Down Expand Up @@ -132,6 +133,24 @@ def __init__(self, feature_layers: Optional[List[Union[str, int]]] = None, input
assert input_mode in {'none', 'clamp', 'assert'}
self.input_mode = input_mode

def compute_pairwise_loss(self, x_recon, x_targ, reduction='mean'):
"""
THIS DOES NOT HAVE LOSS SCALING, LIKE `compute_loss`
x_recon and x_targ data should be an unnormalized RGB batch of
data [B x C x H x W] in the range [0, 1].
"""
features_recon = self._extract_features(x_recon)
features_targ = self._extract_features(x_targ)
# compute losses
feature_loss = 0.0
for (f_recon, f_targ) in zip(features_recon, features_targ):
loss = F.mse_loss(f_recon, f_targ, reduction='none')
feature_loss += batch_loss_reduction(loss, reduction=reduction)
# checks
assert (feature_loss.ndim == 1) and (len(feature_loss) == len(x_recon))
return feature_loss

def compute_loss(self, x_recon, x_targ, reduction='mean'):
"""
x_recon and x_targ data should be an unnormalized RGB batch of
Expand Down
2 changes: 1 addition & 1 deletion disent/frameworks/vae/_unsupervised__dotvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from disent.frameworks.helper.reconstructions import make_reconstruction_loss
from disent.frameworks.helper.reconstructions import ReconLossHandler
from disent.frameworks.vae import AdaNegTripletVae
from disent.frameworks.vae._supervised__adaneg_tvae import AdaNegTripletVae
from disent.nn.loss.triplet_mining import configured_idx_mine


Expand Down
2 changes: 1 addition & 1 deletion disent/frameworks/vae/_unsupervised__vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import torch
from torch.distributions import Distribution

from disent.frameworks.ae._ae_mixin import _AeAndVaeMixin
from disent.frameworks._ae_mixin import _AeAndVaeMixin
from disent.frameworks.helper.latent_distributions import LatentDistsHandler
from disent.frameworks.helper.latent_distributions import make_latent_distribution
from disent.frameworks.helper.util import detach_all
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pip>=21.0
numpy>=1.19.0
torch>=1.9.0
torchvision>=0.10.0
pytorch-lightning>=1.4.0
pytorch-lightning>=1.4.0,<1.7
torch_optimizer>=0.1.0
scipy>=1.7.0
scikit-learn>=0.24.2
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
author="Nathan Juraj Michlo",
author_email="[email protected]",

version="0.6.0",
version="0.6.1",
python_requires=">=3.8", # we make use of standard library features only in 3.8
packages=setuptools.find_packages(),

Expand Down
31 changes: 31 additions & 0 deletions tests/test_000_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
# MIT License
#
# Copyright (c) 2022 Nathan Juraj Michlo
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~


# THIS TEST FILE SHOULD ALWAYS BE LOADED AND RUN FIRST
from disent.frameworks.vae import BetaVae


def test_000_import():
assert BetaVae

0 comments on commit c2a4126

Please sign in to comment.