diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index da25eb80..9146ba21 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -16,7 +16,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] # [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.8", "3.9"] + python-version: ["3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v2 diff --git a/disent/frameworks/vae/_weaklysupervised__adavae.py b/disent/frameworks/vae/_weaklysupervised__adavae.py index 4434711d..97586c77 100644 --- a/disent/frameworks/vae/_weaklysupervised__adavae.py +++ b/disent/frameworks/vae/_weaklysupervised__adavae.py @@ -218,7 +218,7 @@ def compute_average_gvae_std(d0_posterior: Normal, d1_posterior: Normal) -> Norm assert isinstance(d1_posterior, Normal), f'posterior distributions must be {Normal.__name__} distributions, got: {type(d1_posterior)}' # averages ave_std = 0.5 * (d0_posterior.stddev + d1_posterior.stddev) - ave_mean = 0.5 * (d1_posterior.mean + d1_posterior.mean) + ave_mean = 0.5 * (d0_posterior.mean + d1_posterior.mean) # done! return Normal(loc=ave_mean, scale=ave_std) @@ -235,7 +235,7 @@ def compute_average_gvae(d0_posterior: Normal, d1_posterior: Normal) -> Normal: assert isinstance(d1_posterior, Normal), f'posterior distributions must be {Normal.__name__} distributions, got: {type(d1_posterior)}' # averages ave_var = 0.5 * (d0_posterior.variance + d1_posterior.variance) - ave_mean = 0.5 * (d1_posterior.mean + d1_posterior.mean) + ave_mean = 0.5 * (d0_posterior.mean + d1_posterior.mean) # done! return Normal(loc=ave_mean, scale=torch.sqrt(ave_var)) @@ -323,10 +323,10 @@ def hook_intercept_ds(self, ds_posterior: Sequence[Distribution], ds_prior: Sequ ave_std = (0.5 * d0_posterior.variance + 0.5 * d1_posterior.variance) ** 0.5 # [4.b] select shared or original values based on mask - z0_mean = torch.where(share_mask, d0_posterior.loc, ave_mean) - z1_mean = torch.where(share_mask, d1_posterior.loc, ave_mean) - z0_std = torch.where(share_mask, d0_posterior.scale, ave_std) - z1_std = torch.where(share_mask, d1_posterior.scale, ave_std) + z0_mean = torch.where(share_mask, ave_mean, d0_posterior.loc) + z1_mean = torch.where(share_mask, ave_mean, d1_posterior.loc) + z0_std = torch.where(share_mask, ave_std, d0_posterior.scale) + z1_std = torch.where(share_mask, ave_std, d1_posterior.scale) # construct distributions ave_d0_posterior = Normal(loc=z0_mean, scale=z0_std) diff --git a/disent/util/seeds.py b/disent/util/seeds.py index 9bb29ae2..912af68d 100644 --- a/disent/util/seeds.py +++ b/disent/util/seeds.py @@ -24,8 +24,6 @@ import contextlib import logging -import random -import numpy as np log = logging.getLogger(__name__) @@ -44,8 +42,10 @@ def seed(long=777): log.warning(f'[SEEDING]: no seed was specified. Seeding skipped!') return # seed python + import random random.seed(long) # seed numpy + import numpy as np np.random.seed(long) # seed torch - it can be slow to import try: @@ -60,27 +60,26 @@ def seed(long=777): class TempNumpySeed(contextlib.ContextDecorator): - def __init__(self, seed=None, offset=0): + def __init__(self, seed: int = None): # check and normalize seed if seed is not None: try: seed = int(seed) except: - raise ValueError(f'{seed=} is not int-like!') - # offset seed - if seed is not None: - seed += offset + raise ValueError(f'seed={seed} is not int-like!') # save values self._seed = seed self._state = None def __enter__(self): if self._seed is not None: + import numpy as np self._state = np.random.get_state() np.random.seed(self._seed) def __exit__(self, *args, **kwargs): if self._seed is not None: + import numpy as np np.random.set_state(self._state) self._state = None @@ -88,6 +87,7 @@ def _recreate_cm(self): # TODO: do we need to override this? return self + # ========================================================================= # # END # # ========================================================================= # diff --git a/setup.py b/setup.py index 29aad750..93d87a4e 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ author="Nathan Juraj Michlo", author_email="NathanJMichlo@gmail.com", - version="0.5.0", + version="0.5.1", python_requires=">=3.8", # we make use of standard library features only in 3.8 packages=setuptools.find_packages(), @@ -64,6 +64,7 @@ "Operating System :: OS Independent", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "Intended Audience :: Science/Research", ], ) diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index f1488b9a..7e3c62d3 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -28,6 +28,7 @@ import pytest import pytorch_lightning as pl +import torch from torch.utils.data import DataLoader from disent.dataset import DisentDataset @@ -35,6 +36,7 @@ from disent.dataset.sampling import GroundTruthSingleSampler from disent.dataset.sampling import GroundTruthPairSampler from disent.dataset.sampling import GroundTruthTripleSampler +from disent.dataset.sampling import RandomSampler from disent.frameworks.ae import * from disent.frameworks.vae import * from disent.model import AutoEncoder @@ -46,6 +48,8 @@ # ========================================================================= # # TEST FRAMEWORKS # # ========================================================================= # +from disent.util.seeds import seed +from disent.util.seeds import TempNumpySeed from docs.examples.extend_experiment.code.weaklysupervised__si_adavae import SwappedInputAdaVae from docs.examples.extend_experiment.code.weaklysupervised__si_betavae import SwappedInputBetaVae @@ -166,6 +170,53 @@ def test_framework_config_defaults(): ) +def test_ada_vae_similarity(): + + seed(42) + + data = XYObjectData() + dataset = DisentDataset(data, sampler=RandomSampler(num_samples=2), transform=ToImgTensorF32()) + dataloader = DataLoader(dataset, num_workers=0, batch_size=3) + + model = AutoEncoder( + encoder=EncoderLinear(x_shape=data.x_shape, z_size=25, z_multiplier=2), + decoder=DecoderLinear(x_shape=data.x_shape, z_size=25, z_multiplier=1), + ) + + adavae0 = AdaGVaeMinimal(model=model, cfg=AdaGVaeMinimal.cfg()) + adavae1 = AdaVae(model=model, cfg=AdaVae.cfg()) + adavae2 = AdaVae(model=model, cfg=AdaVae.cfg( + ada_average_mode='gvae', + ada_thresh_mode='symmetric_kl', + ada_thresh_ratio=0.5, + )) + + batch = next(iter(dataloader)) + + # TODO: add a TempNumpySeed equivalent for torch + seed(777) + result0a = adavae0.do_training_step(batch, 0) + seed(777) + result0b = adavae0.do_training_step(batch, 0) + assert torch.allclose(result0a, result0b), f'{result0a} does not match {result0b}' + + seed(777) + result1a = adavae1.do_training_step(batch, 0) + seed(777) + result1b = adavae1.do_training_step(batch, 0) + assert torch.allclose(result1a, result1b), f'{result1a} does not match {result1b}' + + seed(777) + result2a = adavae2.do_training_step(batch, 0) + seed(777) + result2b = adavae2.do_training_step(batch, 0) + assert torch.allclose(result2a, result2b), f'{result2a} does not match {result2b}' + + # check similar + assert torch.allclose(result0a, result1a), f'{result0a} does not match {result1a}' + assert torch.allclose(result1a, result2a), f'{result1a} does not match {result2a}' + + # ========================================================================= # # END # # ========================================================================= #