Commit

bugfix
nocotan committed Dec 28, 2020
1 parent ca38215 commit 6c3b151
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions models/wgan.py
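
In short: the parent code tried to implement the WGAN n_critic schedule inside training_step with `if optimizer_idx % n_critic == 0:`, but with two optimizers `optimizer_idx` only ever takes the values 0 (generator) and 1 (discriminator), so the modulo never yields the intended update ratio. The fix deletes that check, wraps the discriminator update in an explicit `elif optimizer_idx == 1:`, and moves the schedule into `configure_optimizers` via PyTorch Lightning's per-optimizer `frequency` option.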
@@ -5,6 +5,7 @@
 tensorboard --logdir default
 """
 import os
+
 from argparse import ArgumentParser, Namespace
 from collections import OrderedDict
 
@@ -14,8 +15,9 @@
 import torch.nn.functional as F
 import torchvision
 import torchvision.transforms as transforms
-from torch.utils.data import DataLoader
 
 from torchvision.datasets import MNIST
+from torch.utils.data import DataLoader
+
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.trainer import Trainer
@@ -102,11 +104,10 @@ def training_step(self, batch, batch_idx, optimizer_idx):
         z = torch.randn(imgs.shape[0], self.latent_dim)
         z = z.type_as(imgs)
 
-        n_critic = 5
         clip_value = 0.01
 
         # train generator
-        if optimizer_idx % n_critic == 0:
+        if optimizer_idx == 0:
 
             # generate images
             self.generated_imgs = self(z)
@@ -135,27 +136,33 @@ def training_step(self, batch, batch_idx, optimizer_idx):
         # Measure discriminator's ability to classify real from generated samples
 
         # discriminator loss is the average of these
-        d_loss = -torch.mean(self.discriminator(imgs)) + torch.mean(self.discriminator(self(z)))
+        elif optimizer_idx == 1:
+            d_loss = -torch.mean(self.discriminator(imgs)) + torch.mean(self.discriminator(self(z)))
 
-        for p in self.discriminator.parameters():
-            p.data.clamp_(-clip_value, clip_value)
+            for p in self.discriminator.parameters():
+                p.data.clamp_(-clip_value, clip_value)
 
-        tqdm_dict = {'d_loss': d_loss}
-        output = OrderedDict({
-            'loss': d_loss,
-            'progress_bar': tqdm_dict,
-            'log': tqdm_dict
-        })
-        return output
+            tqdm_dict = {'d_loss': d_loss}
+            output = OrderedDict({
+                'loss': d_loss,
+                'progress_bar': tqdm_dict,
+                'log': tqdm_dict
+            })
+            return output
 
     def configure_optimizers(self):
+        n_critic = 5
+
         lr = self.lr
         b1 = self.b1
         b2 = self.b2
 
         opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
         opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
-        return [opt_g, opt_d], []
+        return (
+            {'optimizer': opt_g, 'frequency': n_critic},
+            {'optimizer': opt_d, 'frequency': 1}
+        )
 
     def train_dataloader(self):
         transform = transforms.Compose([
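
For context, a minimal sketch of the `frequency` mechanism the fix relies on; the `ToyGAN` module and its losses below are hypothetical stand-ins, not code from this repository. When `configure_optimizers` returns dicts with a `frequency` key, Lightning calls `training_step` with the same `optimizer_idx` for that many consecutive batches before switching to the next optimizer, so the values in this commit step the generator optimizer for five batches, then the discriminator optimizer for one. (Note that in the WGAN paper it is the critic that takes n_critic steps per generator step, so the two frequencies may be intended the other way around.)

import torch
from torch import nn
from pytorch_lightning.core import LightningModule

class ToyGAN(LightningModule):
    """Hypothetical minimal module illustrating per-optimizer 'frequency'."""

    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(2, 2)
        self.discriminator = nn.Linear(2, 1)

    def forward(self, z):
        return self.generator(z)

    def training_step(self, batch, batch_idx, optimizer_idx):
        x = batch[0]
        # With the frequencies below, Lightning calls this hook with
        # optimizer_idx == 0 for n_critic consecutive batches, then
        # optimizer_idx == 1 for one batch, and repeats.
        if optimizer_idx == 0:
            # generator loss: -E[D(G(z))]
            return -torch.mean(self.discriminator(self(x)))
        elif optimizer_idx == 1:
            # critic loss: -E[D(x)] + E[D(G(z))]
            return (-torch.mean(self.discriminator(x))
                    + torch.mean(self.discriminator(self(x))))

    def configure_optimizers(self):
        n_critic = 5
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        # 'frequency' = number of consecutive batches the optimizer is used for
        return (
            {'optimizer': opt_g, 'frequency': n_critic},
            {'optimizer': opt_d, 'frequency': 1},
        )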
