mixstyle.py
import random
from contextlib import contextmanager

import torch
import torch.nn as nn


def deactivate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(False)


def activate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(True)


def random_mixstyle(m):
    if type(m) == MixStyle:
        m.update_mix_method("random")


def crossdomain_mixstyle(m):
    if type(m) == MixStyle:
        m.update_mix_method("crossdomain")


@contextmanager
def run_without_mixstyle(model):
    # Assume MixStyle was initially activated
    try:
        model.apply(deactivate_mixstyle)
        yield
    finally:
        model.apply(activate_mixstyle)


@contextmanager
def run_with_mixstyle(model, mix=None):
    # Assume MixStyle was initially deactivated
    if mix == "random":
        model.apply(random_mixstyle)

    elif mix == "crossdomain":
        model.apply(crossdomain_mixstyle)

    try:
        model.apply(activate_mixstyle)
        yield
    finally:
        model.apply(deactivate_mixstyle)
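
# Illustrative usage of the two context managers above (a sketch; `model` is assumed
# to be any nn.Module that contains MixStyle layers, `x` an input batch):
#
#     with run_without_mixstyle(model):                  # layers start activated
#         out = model(x)                                 # forward pass, MixStyle off
#
#     with run_with_mixstyle(model, mix="crossdomain"):  # layers start deactivated
#         out = model(x)                                 # forward pass, MixStyle on
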
class MixStyle(nn.Module):
    """MixStyle.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
          mix (str): how to mix.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return (
            f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix
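
    # The forward pass below applies the MixStyle update: given the per-instance
    # channel statistics (mu, sig), a permuted copy (mu', sig') taken from other
    # instances in the batch, and lmda ~ Beta(alpha, alpha), it computes
    #     mu_mix  = lmda * mu  + (1 - lmda) * mu'
    #     sig_mix = lmda * sig + (1 - lmda) * sig'
    #     output  = sig_mix * (x - mu) / sig + mu_mix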
    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        # per-instance channel statistics over the spatial dimensions
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x - mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)

        else:
            raise NotImplementedError

        # mix each instance's statistics with those of its permuted partner
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)

        return x_normed * sig_mix + mu_mix
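

# The block below is an illustrative, self-contained sketch (not part of the original
# file): TinyCNN and its hyper-parameters are made up purely to show where a MixStyle
# layer is typically placed, i.e. after an early conv block, during training.
if __name__ == "__main__":

    class TinyCNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
            self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU())
            self.mixstyle = MixStyle(p=0.5, alpha=0.1, mix="random")
            self.head = nn.Linear(32, 10)

        def forward(self, x):
            x = self.conv1(x)
            x = self.mixstyle(x)    # mix feature statistics across the batch
            x = self.conv2(x)
            x = x.mean(dim=[2, 3])  # global average pooling
            return self.head(x)

    model = TinyCNN().train()
    images = torch.randn(8, 3, 32, 32)  # dummy batch of 8 RGB images
    print(model(images).shape)          # expected: torch.Size([8, 10])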