-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconfig.py
97 lines (73 loc) · 2.35 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
'''
data: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
'''
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# from IPython.display import HTML
# 设置随机数种子
manualSeed = 999
#manualSeed = random.randint(1, 10000) # 如果你想要新的结果就是要这段代码
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
# 数据集根目录
# dataroot = "data/celeba"
dataroot = "data"
# 加载数据的工作线程数
workers = 2
# batch size
batch_size = 128
# 训练图像空间大下
image_size = 64
# 训练图像的通道数
nc = 3
# 潜变量 Z的大下(生成器输入的大小)
nz = 100
# 生成器中特征图的大小
ngf = 64
# 判别器中特征映射的大小
ndf = 64
# epoch
num_epoch = 100
# lr
lr = 0.0002
# optimizor Adam
beta1 = 0.5
# nym gpus
ngpu = 1
# 我们可以按照设置的方式使用图像文件夹数据集。
# 用ImageFolder数据集类,它要求在数据集的根文件夹中有子目录
# 创建数据集
dataset = dset.ImageFolder(root=dataroot,
transform=transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
# 创建加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True, num_workers=workers)
# 选择我们运行在上面的设备
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
if __name__ == "__main__":
# 绘制部分我们的输入图像
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))