-
Notifications
You must be signed in to change notification settings - Fork 171
/
adda.py
85 lines (66 loc) · 2.78 KB
/
adda.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import copy
import torch
import torch.nn as nn
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import check_isfile, count_num_param, open_specified_layers
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling import build_head
@TRAINER_REGISTRY.register()
class ADDA(TrainerXU):
    """Adversarial Discriminative Domain Adaptation.

    https://arxiv.org/abs/1702.05464.

    A frozen copy of the initial model (``source_model``) encodes the
    ``batch_x`` inputs; the trainable ``model`` encodes the ``batch_u``
    inputs. A critic is trained to separate the two feature sets (label 1 for
    ``batch_x`` features, 0 for ``batch_u`` features), and ``model`` is then
    trained so that its features are scored as label 1.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        # Layers of self.model that are re-opened for training every step via
        # open_specified_layers(); presumably the helper freezes all other
        # layers (e.g. the classifier) — TODO confirm in dassl.utils.
        self.open_layers = ["backbone"]
        if isinstance(self.model.head, nn.Module):
            self.open_layers.append("head")

        # Snapshot of the model as constructed by the base trainer, kept fixed
        # as the source encoder. NOTE(review): presumably the base trainer has
        # already loaded cfg.MODEL.INIT_WEIGHTS (required by check_cfg) —
        # confirm.
        self.source_model = copy.deepcopy(self.model)
        self.source_model.eval()
        for param in self.source_model.parameters():
            param.requires_grad_(False)

        self.build_critic()

        # Critic outputs raw logits, so use the logits-based BCE.
        self.bce = nn.BCEWithLogitsLoss()

    def check_cfg(self, cfg):
        """Require a source-model checkpoint at ``cfg.MODEL.INIT_WEIGHTS``.

        Raises:
            AssertionError: if ``check_isfile`` rejects
                ``cfg.MODEL.INIT_WEIGHTS``. Raised explicitly instead of via
                an ``assert`` statement so the check is not stripped when
                Python runs with ``-O``.
        """
        if not check_isfile(cfg.MODEL.INIT_WEIGHTS):
            raise AssertionError("The weights of source model must be provided")

    def build_critic(self):
        """Build the critic network and register it with its own optimizer.

        The critic is an ``"mlp"`` head (hidden sizes ``[fdim, fdim // 2]``,
        leaky-ReLU activation) followed by a linear layer that maps the
        ``fdim // 2`` features to a single logit per sample. It is registered
        under the name ``"critic"`` together with its optimizer and LR
        scheduler built from ``cfg.OPTIM``.
        """
        cfg = self.cfg

        print("Building critic network")
        fdim = self.model.fdim
        critic_body = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=[fdim, fdim // 2],
            activation="leaky_relu",
        )
        self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
        # Registered under its own name so it can be stepped independently of
        # "model" in forward_backward().
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)

    def forward_backward(self, batch_x, batch_u):
        """Run one training step: update the critic, then the target encoder.

        Args:
            batch_x: batch encoded by the frozen ``source_model``; its labels
                are parsed but not used.
            batch_u: batch encoded by the trainable ``model``.

        Returns:
            dict with the Python floats ``"loss_critic"`` and ``"loss_model"``.
        """
        open_specified_layers(self.model, self.open_layers)
        input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)

        # Domain labels (1 for batch_x, 0 for batch_u), allocated directly on
        # the training device instead of being created on the CPU and copied.
        domain_x = torch.ones(input_x.shape[0], 1, device=self.device)
        domain_u = torch.zeros(input_u.shape[0], 1, device=self.device)

        # source_model is frozen; no autograd graph is needed for its output.
        with torch.no_grad():
            _, feat_x = self.source_model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        # Critic step. feat_u is detached so this loss cannot propagate into
        # the trainable encoder.
        logit_xd = self.critic(feat_x)
        logit_ud = self.critic(feat_u.detach())
        loss_critic = self.bce(logit_xd, domain_x)
        loss_critic += self.bce(logit_ud, domain_u)
        self.model_backward_and_update(loss_critic, "critic")

        # Encoder step: score target features against the flipped label (1)
        # so the encoder learns to make them look like source features. Only
        # "model" is updated here. NOTE(review): this backward also writes
        # gradients into the critic's parameters; presumably they are cleared
        # by model_backward_and_update before the next critic update — confirm.
        logit_ud = self.critic(feat_u)
        loss_model = self.bce(logit_ud, 1 - domain_u)
        self.model_backward_and_update(loss_model, "model")

        loss_summary = {
            "loss_critic": loss_critic.item(),
            "loss_model": loss_model.item(),
        }

        # Advance the LR schedulers after the last batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary