####################################################
##### Focal loss class for multi-class targets #####
####################################################
import torch
import torch.nn as nn
import torch.nn.functional as F

# Adapted from https://github.com/c0nn3r/RetinaNet/blob/master/focal_loss.py
class FocalLoss2d(nn.modules.loss._WeightedLoss):

    def __init__(self, gamma=2, weight=None, size_average=None, ignore_index=-100,
                 reduce=None, reduction='mean', balance_param=0.25):
        super(FocalLoss2d, self).__init__(weight, size_average, reduce, reduction)
        self.gamma = gamma
        self.weight = weight
        self.size_average = size_average
        self.ignore_index = ignore_index
        self.balance_param = balance_param
    def forward(self, input, target):
        # input and target are assumed to be (batch, classes); target holds 0/1 labels
        assert len(input.shape) == len(target.shape)
        assert input.size(0) == target.size(0)
        assert input.size(1) == target.size(1)

        # per-element negative log-likelihood: logpt = log(pt)
        # (self.weight may be None, which binary_cross_entropy_with_logits accepts)
        logpt = -F.binary_cross_entropy_with_logits(
            input, target, pos_weight=self.weight, reduction='none')
        pt = torch.exp(logpt)

        # focal term (1 - pt)**gamma down-weights well-classified examples;
        # balance_param plays the role of alpha in the focal loss paper
        focal_loss = -((1 - pt) ** self.gamma) * logpt
        balanced_focal_loss = self.balance_param * focal_loss

        # apply the reduction requested at construction time
        if self.reduction == 'mean':
            return balanced_focal_loss.mean()
        if self.reduction == 'sum':
            return balanced_focal_loss.sum()
        return balanced_focal_loss
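

# Minimal usage sketch: assumes raw logits and multi-hot 0/1 targets of shape
# (batch, classes); the sizes and values below are illustrative only and simply
# check that the loss runs and reduces as expected.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, num_classes = 4, 5

    logits = torch.randn(batch_size, num_classes)                       # raw scores
    targets = torch.randint(0, 2, (batch_size, num_classes)).float()    # 0/1 labels

    criterion = FocalLoss2d(gamma=2, balance_param=0.25)
    loss = criterion(logits, targets)
    print("focal loss:", loss.item())

    # per-element losses are available with reduction='none'
    per_element = FocalLoss2d(reduction='none')(logits, targets)
    print("per-element shape:", tuple(per_element.shape))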