diff --git a/models/wgan_gp.py b/models/wgan_gp.py
index 4018a90..35a4303 100644
--- a/models/wgan_gp.py
+++ b/models/wgan_gp.py
@@ -5,6 +5,7 @@
 tensorboard --logdir default
 """
 import os
+
 from argparse import ArgumentParser, Namespace
 from collections import OrderedDict
 
@@ -14,8 +15,9 @@
 import torch.nn.functional as F
 import torchvision
 import torchvision.transforms as transforms
-from torch.utils.data import DataLoader
+
 from torchvision.datasets import MNIST
+from torch.utils.data import DataLoader
 
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.trainer import Trainer
@@ -124,11 +126,10 @@ def training_step(self, batch, batch_idx, optimizer_idx):
         z = torch.randn(imgs.shape[0], self.latent_dim)
         z = z.type_as(imgs)
 
-        n_critic = 5
         lambda_gp = 10
 
         # train generator
-        if optimizer_idx % n_critic == 0:
+        if optimizer_idx == 0:
 
             # generate images
             self.generated_imgs = self(z)
@@ -155,33 +156,44 @@ def training_step(self, batch, batch_idx, optimizer_idx):
 
         # train discriminator
         # Measure discriminator's ability to classify real from generated samples
-        fake_imgs = self(z)
-
-        # Real images
-        real_validity = self.discriminator(imgs)
-        # Fake images
-        fake_validity = self.discriminator(fake_imgs)
-        # Gradient penalty
-        gradient_penalty = self.compute_gradient_penalty(imgs.data, fake_imgs.data)
-        # Adversarial loss
-        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
-
-        tqdm_dict = {'d_loss': d_loss}
-        output = OrderedDict({
-            'loss': d_loss,
-            'progress_bar': tqdm_dict,
-            'log': tqdm_dict
-        })
-        return output
+        elif optimizer_idx == 1:
+            fake_imgs = self(z)
+
+            # Real images
+            real_validity = self.discriminator(imgs)
+            # Fake images
+            fake_validity = self.discriminator(fake_imgs)
+            # Gradient penalty
+            gradient_penalty = self.compute_gradient_penalty(imgs.data, fake_imgs.data)
+            # Adversarial loss
+            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
+
+            tqdm_dict = {'d_loss': d_loss}
+            output = OrderedDict({
+                'loss': d_loss,
+                'progress_bar': tqdm_dict,
+                'log': tqdm_dict
+            })
+            return output
 
     def configure_optimizers(self):
+        # WGAN-GP updates the critic (discriminator) n_critic times for
+        # every single generator update.
+        n_critic = 5
+
         lr = self.lr
         b1 = self.b1
         b2 = self.b2
 
         opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
         opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
-        return [opt_g, opt_d], []
+        # Keep this order: training_step treats optimizer_idx 0 as the
+        # generator and optimizer_idx 1 as the critic. 'frequency' is how many
+        # consecutive batches an optimizer is used for before switching.
+        return (
+            {'optimizer': opt_g, 'frequency': 1},
+            {'optimizer': opt_d, 'frequency': n_critic}
+        )
 
     def train_dataloader(self):
         transform = transforms.Compose([