-
Notifications
You must be signed in to change notification settings - Fork 5
/
model.py
68 lines (54 loc) · 2.97 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.nn as nn
from stylegan import Generator, GeneratorNOutputs, StyledGenerator, StyledGenerators
def get_masks(fg_masks):
    """Pull the mask out of every foreground/mask pair.

    Args:
        fg_masks: Iterable whose items are indexable; element ``[1]`` of each
            item is taken. Presumably these are the ``(fg, mask)`` tuples
            returned by the models' ``forward`` — confirm against callers.

    Returns:
        Tuple of the second element of each item, in input order.
    """
    return tuple(pair[1] for pair in fg_masks)
class SimpleBgFgMask(nn.Module):
    """One styled generator that emits background, foreground and mask.

    A single ``GeneratorNOutputs`` with three output heads is wrapped in a
    ``StyledGenerator``. ``forward`` unpacks the heads as
    ``(bg, fg, mask_logits)`` and applies a sigmoid to the mask.
    """

    def __init__(self, code_dim=512, n_mlp=8):
        """Build the generator.

        Args:
            code_dim: Dimensionality passed to both the generator and the
                ``StyledGenerator`` wrapper.
            n_mlp: Forwarded unchanged to ``StyledGenerator`` (presumably the
                depth of its style MLP — confirm in ``stylegan``).
        """
        super(SimpleBgFgMask, self).__init__()
        # Per-head output sizes in the order bg, fg, mask (presumably channel
        # counts: RGB, RGB, single-channel mask).
        output_dims = [3, 3, 1]
        # Keep the raw generator in a local so ``self.generator`` is only ever
        # bound to the styled wrapper that owns it.
        base_generator = GeneratorNOutputs(code_dim, output_dims=output_dims)
        self.generator = StyledGenerator(base_generator, code_dim, n_mlp, code_dim)

    def forward(self, x, noise=None, step=0, alpha=-1, mean_style=None,
                style_weight=0, mixing_range=(-1, -1)):
        """Generate background, foreground and a soft mask.

        All arguments are forwarded positionally to the wrapped
        ``StyledGenerator``.

        Returns:
            ``(bg, (fg, mask))`` where ``mask`` is the sigmoid of the third
            generator output, so its values lie in ``(0, 1)``.
        """
        bg, fg, mask = self.generator(x, noise, step, alpha, mean_style, style_weight, mixing_range)
        mask = torch.sigmoid(mask)
        return bg, (fg, mask)

    def parameter_groups(self):
        """Split parameters into style-network and generator groups.

        Returns:
            Dict with keys ``'style'`` and ``'generator'``. Each maps to a
            list of parameters. Lists are used rather than the one-shot
            iterators from ``Module.parameters()`` so each group can be
            iterated more than once. This matches
            ``BgFgMask.parameter_groups``.
        """
        groups = {}
        groups['style'] = list(self.generator.style.parameters())
        groups['generator'] = list(self.generator.generator.parameters())
        return groups
class BgFgMask(nn.Module):
    """Background and foreground/mask from two independent styled generators.

    ``generator_bg`` wraps a plain ``Generator``. ``generator_objects`` wraps
    a two-headed ``GeneratorNOutputs`` whose outputs are unpacked as
    ``(fg, mask_logits)``. Each wrapper has its own ``style`` submodule. Both
    receive the same forward arguments.
    """

    def __init__(self, code_dim=512, n_mlp=8):
        """Build both styled generators.

        Args:
            code_dim: Dimensionality passed to each generator and its
                ``StyledGenerator`` wrapper.
            n_mlp: Forwarded unchanged to each ``StyledGenerator`` (presumably
                the depth of its style MLP — confirm in ``stylegan``).
        """
        super(BgFgMask, self).__init__()
        # Construction order is kept identical (bg first) so parameter
        # initialisation consumes the RNG in the same sequence.
        bg_backbone = Generator(code_dim)
        self.generator_bg = StyledGenerator(bg_backbone, code_dim, n_mlp, code_dim)
        # Two heads, in the order fg, mask (presumably RGB and one channel).
        object_backbone = GeneratorNOutputs(code_dim, output_dims=[3, 1])
        self.generator_objects = StyledGenerator(object_backbone, code_dim, n_mlp, code_dim)

    def forward(self, x, noise=None, step=0, alpha=-1, mean_style=None,
                style_weight=0, mixing_range=(-1, -1)):
        """Run both generators on the same inputs.

        Returns:
            ``(bg, (fg, mask))`` where ``mask`` is the sigmoid of the second
            object-generator output.
        """
        call_args = (x, noise, step, alpha, mean_style, style_weight, mixing_range)
        bg = self.generator_bg(*call_args)
        fg, mask_logits = self.generator_objects(*call_args)
        return bg, (fg, torch.sigmoid(mask_logits))

    def parameter_groups(self):
        """Collect style and generator parameters across both wrappers.

        Returns:
            Dict with keys ``'style'`` and ``'generator'``, each a list of
            parameters. Background parameters come first, then object
            parameters.
        """
        styled = (self.generator_bg, self.generator_objects)
        return {
            'style': [p for wrapper in styled for p in wrapper.style.parameters()],
            'generator': [p for wrapper in styled for p in wrapper.generator.parameters()],
        }
class BgFgMaskSharedStyle(nn.Module):
    """Background and foreground/mask generators driven by one shared style net.

    A plain ``Generator`` (background) and a two-headed ``GeneratorNOutputs``
    (foreground, mask logits) are handed together to ``StyledGenerators``.
    That wrapper exposes a single ``style`` submodule and the individual
    ``generators``.
    """

    def __init__(self, code_dim=512, n_mlp=8):
        """Build the shared-style generator pair.

        Args:
            code_dim: Dimensionality passed to each generator and to the
                ``StyledGenerators`` wrapper.
            n_mlp: Forwarded unchanged to ``StyledGenerators`` (presumably the
                depth of the shared style MLP — confirm in ``stylegan``).
        """
        super(BgFgMaskSharedStyle, self).__init__()
        # Object-generator heads, in the order fg, mask (presumably RGB and
        # one channel).
        output_dims = [3, 1]
        generator_bg = Generator(code_dim)
        generator_objects = GeneratorNOutputs(code_dim, output_dims=output_dims)
        self.generator = StyledGenerators((generator_bg, generator_objects), code_dim, n_mlp, code_dim)

    def forward(self, x, noise=None, step=0, alpha=-1, mean_style=None,
                style_weight=0, mixing_range=(-1, -1)):
        """Generate background, foreground and a soft mask.

        Returns:
            ``(bg, (fg, mask))``. ``bg`` comes from the background generator.
            ``fg`` and ``mask`` are the two object-generator outputs, with a
            sigmoid applied to ``mask``.
        """
        bg, out = self.generator(x, noise, step, alpha, mean_style, style_weight, mixing_range=mixing_range)
        fg, mask = out[0], torch.sigmoid(out[1])
        return bg, (fg, mask)

    def parameter_groups(self):
        """Split parameters into the shared style net and the generators.

        Returns:
            Dict with keys ``'style'`` and ``'generator'``. Each maps to a
            list of parameters, so both groups can be iterated repeatedly.
            The original returned a one-shot iterator for ``'style'``, which
            yielded nothing on a second pass.
        """
        groups = {}
        groups['style'] = list(self.generator.style.parameters())
        groups['generator'] = [
            param
            for gen in self.generator.generators
            for param in gen.parameters()
        ]
        return groups