-
Notifications
You must be signed in to change notification settings - Fork 21
/
losses.py
94 lines (74 loc) · 3.01 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from __future__ import print_function
import torch
import torch.nn as nn
class SupConLoss(nn.Module):
    """Supervised contrastive loss over one batch of embeddings.

    For each anchor ``i``, every other sample with the same label is a
    positive, and every sample except ``i`` itself appears in the softmax
    denominator. The loss is the batch mean of the negative average
    log-probability that this softmax assigns to the anchor's positives.

    Every anchor needs at least one positive in the batch. An anchor whose
    label appears only once causes a 0/0 division, and the loss becomes NaN.
    ``SupConLoss_clear`` skips such anchors instead.

    Args:
        temperature: Divisor applied to the pairwise dot products before the
            softmax (default 0.01). Smaller values sharpen the distribution.
    """

    def __init__(self, temperature=0.01):
        super(SupConLoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Return the scalar supervised contrastive loss.

        Args:
            features: ``(batch_size, dim)`` tensor. Similarities are raw dot
                products, so rows are presumably L2-normalised by the caller
                (TODO confirm against callers).
            labels: Tensor holding one label per row of ``features``; it is
                flattened to a column, and samples with equal labels are
                treated as positives of each other.

        Returns:
            0-dim tensor holding the loss.
        """
        # Use the features' own device, including the GPU index. A bare
        # torch.device('cuda') would put the masks on cuda:0 and fail for
        # features living on any other GPU.
        device = features.device
        batch_size = features.shape[0]

        labels = labels.contiguous().view(-1, 1)
        # mask[i, j] = 1 where samples i and j share a label.
        mask = torch.eq(labels, labels.T).float().to(device)

        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.temperature)

        # An anchor is neither its own positive nor part of its own softmax.
        self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
        mask = mask.masked_fill(self_mask, 0.0)

        # Log-softmax over the off-diagonal entries of each row. logsumexp
        # subtracts the row max internally. Unlike exp().sum().log(), it
        # cannot underflow to log(0) = -inf when every other logit is far
        # below the (excluded) self-similarity.
        contrast = anchor_dot_contrast.masked_fill(self_mask, float('-inf'))
        log_prob = anchor_dot_contrast - torch.logsumexp(
            contrast, dim=1, keepdim=True)

        # Mean log-likelihood over each anchor's positives.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
        loss = -mean_log_prob_pos
        return loss.mean()
# Variant of SupConLoss that ignores anchors with no positive (same-label) sample in the batch, so such anchors cannot turn the loss into NaN.
class SupConLoss_clear(nn.Module):
    """Supervised contrastive loss that ignores anchors without positives.

    Same objective as a standard supervised contrastive loss: for each
    anchor, other same-label samples are positives, and every sample except
    the anchor itself is in the softmax denominator. Anchors whose label
    occurs only once in the batch have no positives. They are excluded from
    both the sum and the averaging count instead of producing a 0/0 division.
    If no anchor in the batch has a positive, the loss is 0 (still attached
    to the autograd graph) rather than NaN.

    Args:
        temperature: Divisor applied to the pairwise dot products before the
            softmax (default 0.07). Smaller values sharpen the distribution.
    """

    def __init__(self, temperature=0.07):
        super(SupConLoss_clear, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Return the scalar loss averaged over anchors that have positives.

        Args:
            features: ``(batch_size, dim)`` tensor. Similarities are raw dot
                products, so rows are presumably L2-normalised by the caller
                (TODO confirm against callers).
            labels: Tensor holding one label per row of ``features``; it is
                flattened to a column, and samples with equal labels are
                treated as positives of each other.

        Returns:
            0-dim tensor holding the loss.
        """
        # Use the features' own device, including the GPU index. A bare
        # torch.device('cuda') would put the masks on cuda:0 and fail for
        # features living on any other GPU.
        device = features.device
        batch_size = features.shape[0]

        labels = labels.contiguous().view(-1, 1)
        # mask[i, j] = 1 where samples i and j share a label.
        mask = torch.eq(labels, labels.T).float().to(device)

        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.temperature)

        # An anchor is neither its own positive nor part of its own softmax.
        self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
        mask = mask.masked_fill(self_mask, 0.0)

        # Log-softmax over the off-diagonal entries of each row. logsumexp
        # subtracts the row max internally. Unlike exp().sum().log(), it
        # cannot underflow to log(0) = -inf.
        contrast = anchor_dot_contrast.masked_fill(self_mask, float('-inf'))
        log_prob = anchor_dot_contrast - torch.logsumexp(
            contrast, dim=1, keepdim=True)

        num_pos = mask.sum(1)
        # masked_fill (rather than mask * log_prob) keeps non-positive entries
        # at exactly 0 even where log_prob is infinite (e.g. a batch of one).
        pos_log_prob = log_prob.masked_fill(mask == 0, 0.0).sum(1)
        # Anchors without positives have a zero numerator. Clamping their
        # denominator to 1 avoids 0/0 and makes their contribution exactly 0.
        loss = -(pos_log_prob / num_pos.clamp(min=1))

        # Average only over anchors that have positives. Clamping the count
        # returns 0 instead of NaN when the batch has none.
        num_valid = (num_pos > 0).sum().clamp(min=1)
        return loss.sum() / num_valid