Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
nocotan committed Dec 28, 2020
1 parent 6c3b151 commit bf674a2
Showing 1 changed file with 29 additions and 22 deletions.
51 changes: 29 additions & 22 deletions models/wgan_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
tensorboard --logdir default
"""
import os

from argparse import ArgumentParser, Namespace
from collections import OrderedDict

Expand All @@ -14,8 +15,9 @@
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer import Trainer
Expand Down Expand Up @@ -124,11 +126,10 @@ def training_step(self, batch, batch_idx, optimizer_idx):
z = torch.randn(imgs.shape[0], self.latent_dim)
z = z.type_as(imgs)

n_critic = 5
lambda_gp = 10

# train generator
if optimizer_idx % n_critic == 0:
if optimizer_idx == 0:

# generate images
self.generated_imgs = self(z)
Expand All @@ -155,33 +156,39 @@ def training_step(self, batch, batch_idx, optimizer_idx):

# train discriminator
# Measure discriminator's ability to classify real from generated samples
fake_imgs = self(z)

# Real images
real_validity = self.discriminator(imgs)
# Fake images
fake_validity = self.discriminator(fake_imgs)
# Gradient penalty
gradient_penalty = self.compute_gradient_penalty(imgs.data, fake_imgs.data)
# Adversarial loss
d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

tqdm_dict = {'d_loss': d_loss}
output = OrderedDict({
'loss': d_loss,
'progress_bar': tqdm_dict,
'log': tqdm_dict
})
return output
elif optimizer_idx == 1:
fake_imgs = self(z)

# Real images
real_validity = self.discriminator(imgs)
# Fake images
fake_validity = self.discriminator(fake_imgs)
# Gradient penalty
gradient_penalty = self.compute_gradient_penalty(imgs.data, fake_imgs.data)
# Adversarial loss
d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

tqdm_dict = {'d_loss': d_loss}
output = OrderedDict({
'loss': d_loss,
'progress_bar': tqdm_dict,
'log': tqdm_dict
})
return output

def configure_optimizers(self):
    """Build the WGAN-GP optimizers for Lightning.

    Returns:
        A pair of optimizer config dicts ``(gen_cfg, disc_cfg)``. The
        ``frequency`` keys tell Lightning how many consecutive steps to run
        each optimizer for before switching to the other.

    NOTE(review): ``frequency=n_critic`` is attached to the *generator*
    optimizer here (5 generator steps per 1 discriminator step). Standard
    WGAN-GP trains the critic more often than the generator — confirm this
    ordering matches the intended optimizer indices in ``training_step``.
    """
    # Steps of the first optimizer per single step of the second.
    n_critic = 5

    lr = self.lr
    b1 = self.b1
    b2 = self.b2

    opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
    # Bug fix: a stale unconditional `return [opt_g, opt_d], []` preceded this
    # return, making the frequency-based configuration below unreachable.
    return (
        {'optimizer': opt_g, 'frequency': n_critic},
        {'optimizer': opt_d, 'frequency': 1},
    )

def train_dataloader(self):
transform = transforms.Compose([
Expand Down

0 comments on commit bf674a2

Please sign in to comment.