# utils.py
# Forked from znxlwm/pytorch-CartoonGAN.
import itertools, imageio, torch, random
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from torchvision import datasets
from scipy.misc import imresize
from torch.autograd import Variable
def data_load(path, subfolder, transform, batch_size, shuffle=False, drop_last=True):
    """Build a DataLoader restricted to one class sub-folder of ``path``.

    ``path`` is opened with ``datasets.ImageFolder``; the class index of
    ``subfolder`` is looked up in ``class_to_idx`` and every sample carrying
    a different class index is removed from the dataset before it is wrapped
    in a ``DataLoader``.

    Args:
        path: Root directory handed to ``datasets.ImageFolder``.
        subfolder: Class name (key of ``class_to_idx``) whose samples are kept.
        transform: Transform passed through to ``datasets.ImageFolder``.
        batch_size: Batch size of the returned loader.
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the filtered dataset.

    Raises:
        KeyError: If ``subfolder`` is not a key of ``class_to_idx``.
    """
    dset = datasets.ImageFolder(path, transform)
    ind = dset.class_to_idx[subfolder]

    # Filter in a single pass instead of deleting elements one by one (each
    # ``del`` on a list shifts the tail, making the old loop quadratic).
    # Slice assignment mutates the existing list object rather than rebinding
    # the attribute, so any other reference the dataset keeps to this same
    # list observes the filtered content, exactly as the element-wise ``del``
    # did.
    dset.imgs[:] = [item for item in dset.imgs if item[1] == ind]

    return torch.utils.data.DataLoader(
        dset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
def print_network(net):
    """Print ``net`` followed by its total parameter count.

    The count is the sum of ``numel()`` over everything yielded by
    ``net.parameters()``.

    Args:
        net: Object exposing ``parameters()``; it is also passed to ``print``.
    """
    total = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
def initialize_weights(net):
    """Re-initialise, in place, the parameters of the supported layers of ``net``.

    Every module yielded by ``net.modules()`` is inspected:

    * ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.Linear``: weights are
      filled via ``normal_(0, 0.02)`` and the bias is zeroed.
    * ``nn.BatchNorm2d``: weights are filled with 1 and the bias is zeroed.

    Modules of any other type are left untouched. A parameter that is
    ``None`` (for example a layer constructed with ``bias=False`` or a batch
    norm with ``affine=False``) is skipped instead of raising
    ``AttributeError``.

    Args:
        net: Module whose sub-modules are initialised.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            # Bias is None when the layer was built with bias=False.
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            # Both parameters are None when affine=False.
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()