Skip to content

Commit

Permalink
Fixed training script
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniojkim committed Aug 23, 2020
1 parent 98d8745 commit 17e3fb3
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 16 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
*.tar
*ipynb_checkpoint*
*pycache*
data
data/
model/Pro-GAN_scraps.txt
model/Pro-GAN_scraps.txt
1 change: 0 additions & 1 deletion data

This file was deleted.

4 changes: 2 additions & 2 deletions model/Pro-GAN/params256.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
resolution: 256
batch_size: 32
learning_rate: 0.0005
learning_rate: 0.00025
save_interval: 200
save_model_path: "../../data/checkpoints"
log_path: "../../data/images"
log_path: "../../data/images"
6 changes: 3 additions & 3 deletions model/Pro-GAN/params512.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
resolution: 512
batch_size: 32
learning_rate: 0.0005
save_interval: 200
learning_rate: 0.0001
save_interval: 250
save_model_path: "../../data/checkpoints"
log_path: "../../data/images"
log_path: "../../data/images"
12 changes: 6 additions & 6 deletions model/Pro-GAN/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def FromTensor(tensor):
return np.moveaxis(tensor.detach().cpu().numpy(), 1, -1)


def generate_fake_images(generator, params, suffix=""):
def generate_fake_images(generator, alpha, params, suffix=""):
res = params.resolution

noise = generate_noise(generator, 4)
fake_output = generator.forward(noise, alpha=0)
fake_output = generator.forward(noise, alpha=alpha)
fake_output = FromTensor(fake_output)

fig, ax = plt.subplots(2, 2, figsize=(10, 11))
Expand Down Expand Up @@ -117,7 +117,7 @@ def train(args):

print("\n"+"-"*80+"\n")

generate_fake_images(generator, params, "Untrained")
generate_fake_images(generator, 0, params, "Untrained")

dataloader = torch.utils.data.DataLoader(
Dataset(resolution=resolution, size=800000),
Expand Down Expand Up @@ -180,7 +180,7 @@ def train(fade_in: bool):
)
progress.update()

if abs(discriminator_err.item()) > 50:
if abs(discriminator_err.item()) > 100:
raise Exception("Training has diverged.")

generator_losses.append(generator_err.item())
Expand All @@ -204,7 +204,7 @@ def train(fade_in: bool):
)
)

generate_fake_images(generator, params, f"Pro-GAN Iteration {iteration}")
generate_fake_images(generator, alpha, params, f"Pro-GAN Iteration {iteration}")

pd.DataFrame(
{
Expand Down Expand Up @@ -246,7 +246,7 @@ def train(fade_in: bool):

train(fade_in=False)

generate_fake_images(generator, params, "Trained")
generate_fake_images(generator, 1, params, "Trained")

generator.save(
os.path.join(
Expand Down
6 changes: 3 additions & 3 deletions model/Pro-GAN/train.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash


python -u train.py --param-file params128.yml
python -u train.py --param-file params256.yml
# python -u train.py --param-file params512.yml
# python -u train.py --param-file params128.yml
# python -u train.py --param-file params256.yml
python -u train.py --param-file params512.yml
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
git+https://github.com/antoniojkim/TorchX.git
matplotlib
numpy
pandas
pillow
torch
torchvision
Expand Down

0 comments on commit 17e3fb3

Please sign in to comment.