import numpy as np
import torch
import torch.nn as nn
from typing import Optional
######## UTILS
class EarlyStopping:
    """Early stops the training if the monitored validation metric doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, mode='higher'):
        """
        Args:
            patience (int): How long to wait after the last time the validation metric improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation metric improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                           Default: 0
            mode (str): 'higher' if a higher metric is better, 'lower' if a lower one is.
                        Default: 'higher'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_val_metric = np.inf  # placeholder; overwritten on the first valid call
        self.delta = delta
        self.save_checkpoint = False
        if mode == 'higher':
            self.mode = 1
        elif mode == 'lower':
            self.mode = -1
        else:
            raise ValueError("Bad mode type, please choose between 'higher' and 'lower'")

    def __call__(self, train_metric, val_metric, model=None):
        # Flip the sign so that, internally, a higher score is always better.
        val_score = self.mode * val_metric
        train_score = self.mode * train_metric
        if not torch.isnan(torch.tensor(val_score)):
            if self.best_score is None:
                self.best_score = val_score
                self.save_checkpoint = True
                self.best_val_metric = val_metric
            elif val_score < self.best_score + self.delta and train_score > val_score + self.delta:
                # Apply patience only if the train score is better than the val score
                self.counter += 1
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
                if self.counter >= self.patience:
                    self.early_stop = True
                self.save_checkpoint = False
            elif val_score > self.best_score:
                self.best_score = val_score
                self.save_checkpoint = True
                self.best_val_metric = val_metric
                self.counter = 0
                if self.verbose:
                    print(f'Validation metric improved to {val_metric}. Flagging checkpoint save.')
            else:
                self.save_checkpoint = False
        else:
            self.save_checkpoint = False
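# Minimal usage sketch (illustrative only): `model`, `compute_metric`,
# `train_loader`, and `val_loader` below are placeholders, not part of
# this module.
#
#     early_stopping = EarlyStopping(patience=7, mode='higher')
#     for epoch in range(100):
#         train_acc = compute_metric(model, train_loader)
#         val_acc = compute_metric(model, val_loader)
#         early_stopping(train_acc, val_acc)
#         if early_stopping.save_checkpoint:
#             torch.save(model.state_dict(), 'checkpoint.pt')
#         if early_stopping.early_stop:
#             break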
######## ACTIVATIONS
class Mish(torch.nn.Module):
    """
    Applies the mish activation function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    Shape:
        - Input: (N, *) where * means any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    def __init__(self):
        """
        Init method.
        """
        super().__init__()

    @staticmethod
    def mish(input):
        '''
        Applies the mish function element-wise:
        mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
        See additional documentation for the Mish class.
        '''
        return input * torch.tanh(torch.nn.functional.softplus(input))

    def forward(self, input):
        """
        Forward pass of the function.
        """
        return self.mish(input)
######### CUSTOM LAYERS
class Flatten(torch.nn.Module):
    """
    The flatten layer to build sequential models
    """
    def forward(self, input):
        return input.view(input.size(0), -1)
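# Usage sketch (illustrative): Flatten lets a conv backbone feed a linear
# head inside nn.Sequential; the layer sizes below are placeholders.
#
#     head = nn.Sequential(nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(64, 10))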
class SiameseBlock(torch.nn.Module):
    """
    Applies the same module to each element along dim 1 of the input
    (weight sharing) and concatenates the outputs along dim 1.
    """
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        # x: (batch, n_siblings, ...) -> run the shared module on each sibling
        x = torch.cat([self.module(x_i) for x_i in x.transpose(1, 0)], 1)
        return x
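# Shape sketch (illustrative): with x of shape (B, S, F) and a shared module
# mapping (B, F) -> (B, H), SiameseBlock(module)(x) has shape (B, S * H).
#
#     block = SiameseBlock(nn.Linear(16, 8))
#     out = block(torch.randn(4, 3, 16))  # -> (4, 24)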
class ChiaBlock(torch.nn.Module):
    """
    Applies the same module to each element along dim 1 of the input,
    stacks the outputs along `axis`, and averages them over that axis.
    """
    def __init__(self, module, axis=-1):
        super().__init__()
        self.module = module
        self.axis = axis

    def forward(self, x):
        x = torch.stack([self.module(x_i) for x_i in x.transpose(1, 0)], self.axis)
        x = torch.mean(x, self.axis)
        # x = torch.logsumexp(x, self.axis) / x.size(self.axis)
        return x
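# Shape sketch (illustrative): with x of shape (B, S, F) and a shared module
# mapping (B, F) -> (B, H), ChiaBlock(module)(x) averages the S outputs
# and has shape (B, H).
#
#     block = ChiaBlock(nn.Linear(16, 8))
#     out = block(torch.randn(4, 3, 16))  # -> (4, 8)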
class AdaptiveConcatPool2d(torch.nn.Module):
    "Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`."
    def __init__(self, sz: Optional[int] = None):
        "Output will be 2*sz or 2 if sz is None"
        super().__init__()
        self.output_size = sz or 1
        self.ap = nn.AdaptiveAvgPool2d(self.output_size)
        self.mp = nn.AdaptiveMaxPool2d(self.output_size)

    def forward(self, x):
        return torch.cat([self.mp(x), self.ap(x)], 1)
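# Shape sketch (illustrative): concatenating max- and avg-pooled features
# doubles the channel dimension.
#
#     pool = AdaptiveConcatPool2d()
#     out = pool(torch.randn(4, 64, 7, 7))  # -> (4, 128, 1, 1)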
############# LOSSES
class BinnedBCE(nn.Module):
    """
    BCE-with-logits loss on a binned (ordinal) encoding of an integer target:
    a target t in [0, 5] is expanded to a 5-dim vector of t ones followed by
    5 - t zeros, so the model predicts one logit per bin.
    """
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss(*args, **kwargs)
        # t -> [1] * t + [0] * (5 - t); the number of bins (5) is hardcoded
        self.binning = lambda x: torch.cat([torch.ones(x), torch.zeros(5 - x)])

    def forward(self, output, target: torch.Tensor):
        bin_target = torch.stack([self.binning(t) for t in target.tolist()], dim=0).to(target.device)
        return self.loss(output, bin_target)
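# Usage sketch (illustrative): a 5-logit head trained against integer
# targets in [0, 5].
#
#     criterion = BinnedBCE()
#     logits = torch.randn(4, 5)
#     target = torch.tensor([0, 2, 5, 3])
#     loss = criterion(logits, target)  # target 3 is binned as [1, 1, 1, 0, 0]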