-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
157 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import os | ||
import pytorch_lightning as pl | ||
import torch | ||
from torch.utils.data import DataLoader | ||
|
||
from disent.dataset import DisentDataset | ||
from disent.dataset.data import XYObjectData | ||
from disent.dataset.sampling import SingleSampler | ||
from disent.dataset.transform import ToImgTensorF32 | ||
from disent.frameworks.vae import BetaVae | ||
from disent.metrics import metric_dci | ||
from disent.metrics import metric_mig | ||
from disent.model import AutoEncoder | ||
from disent.model.ae import DecoderConv64 | ||
from disent.model.ae import EncoderConv64 | ||
from disent.schedule import CyclicSchedule | ||
|
||
# create the dataset & dataloaders | ||
# - ToImgTensorF32 transforms images from numpy arrays to tensors and performs checks | ||
data = XYObjectData() | ||
dataset = DisentDataset(dataset=data, sampler=SingleSampler(), transform=ToImgTensorF32()) | ||
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True, num_workers=os.cpu_count()) | ||
|
||
# create the BetaVAE model | ||
# - adjusting the beta, learning rate, and representation size. | ||
module = BetaVae( | ||
model=AutoEncoder( | ||
# z_multiplier is needed to output mu & logvar when parameterising normal distribution | ||
encoder=EncoderConv64(x_shape=data.x_shape, z_size=10, z_multiplier=2), | ||
decoder=DecoderConv64(x_shape=data.x_shape, z_size=10), | ||
), | ||
cfg=BetaVae.cfg( | ||
optimizer='adam', | ||
optimizer_kwargs=dict(lr=1e-3), | ||
loss_reduction='mean_sum', | ||
beta=4, | ||
) | ||
) | ||
|
||
# cyclic schedule for target 'beta' in the config/cfg. The initial value from the | ||
# config is saved and multiplied by the ratio from the schedule on each step. | ||
# - based on: https://arxiv.org/abs/1903.10145 | ||
module.register_schedule( | ||
'beta', CyclicSchedule( | ||
period=1024, # repeat every: trainer.global_step % period | ||
) | ||
) | ||
|
||
# train model | ||
# - for 2048 batches/steps | ||
trainer = pl.Trainer( | ||
max_steps=2048, gpus=1 if torch.cuda.is_available() else None, logger=False, checkpoint_callback=False | ||
) | ||
trainer.fit(module, dataloader) | ||
|
||
# compute disentanglement metrics | ||
# - we cannot guarantee which device the representation is on | ||
# - this will take a while to run | ||
get_repr = lambda x: module.encode(x.to(module.device)) | ||
|
||
metrics = { | ||
**metric_dci(dataset, get_repr, num_train=1000, num_test=500, show_progress=True), | ||
**metric_mig(dataset, get_repr, num_train=2000), | ||
} | ||
|
||
# evaluate | ||
print('metrics:', metrics) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,7 +48,7 @@ | |
author="Nathan Juraj Michlo", | ||
author_email="[email protected]", | ||
|
||
version="0.3.0", | ||
version="0.3.1", | ||
python_requires=">=3.8", # we make use of standard library features only in 3.8 | ||
packages=setuptools.find_packages(), | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters