-
Notifications
You must be signed in to change notification settings - Fork 0
/
doc4_Metrics_in_epoch_end_functions.py
70 lines (49 loc) · 1.68 KB
/
doc4_Metrics_in_epoch_end_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
The basic LightningModule structure must be as follows:

class LightningNet(pl.LightningModule):
    def __init__(...):
        ...
    def forward(...):
        ...
    def training_step(...):
        ...
    def configure_optimizers(...):
        ...
    def training_epoch_end(...):
        ...
    def validation_epoch_end(...):
        ...
"""
"""
METHOD-1 ==> This method calculates the metric values of the current training epoch only
"""
def training_epoch_end(self, outputs):
    """Print a table of metrics averaged over the entries of ``outputs``.

    Each entry of ``outputs`` is indexed with the keys ``'scores'`` and
    ``'y'``. That pair is passed to ``self.accuracy``, ``self.precision_``,
    ``self.recall`` and ``self.f1`` (in that order) and every result is
    converted to ``float``. The reported value of each metric is the plain
    mean of its per-entry values, i.e. every entry has equal weight
    regardless of how many samples it holds.

    Args:
        outputs: Sized iterable of per-step results; each item must support
            ``item['scores']`` and ``item['y']``. An empty ``outputs`` is
            reported as NaN for every metric instead of raising
            ``ZeroDivisionError``.

    Side effects:
        Prints the table followed by a blank line and increments
        ``self.epoch`` by one.
    """
    # NOTE(review): if these are torchmetrics objects, calling them may also
    # accumulate internal state - confirm whether a reset is needed elsewhere.
    metric_fns = (self.accuracy, self.precision_, self.recall, self.f1)
    totals = [0.0] * len(metric_fns)

    for pred in outputs:
        scores, target = pred['scores'], pred['y']
        for idx, metric in enumerate(metric_fns):
            totals[idx] += float(metric(scores, target))

    n_batches = len(outputs)
    # Guard the division: with no entries there is nothing to average.
    averages = [
        total / n_batches if n_batches else float("nan")
        for total in totals
    ]

    data = [[f"Train[Epoch: {self.epoch}]", *averages]]
    headers = ["Type", 'Accuracy', 'Precision', 'Recall', 'F1 Score']
    self.epoch += 1
    print(tabulate(data, headers=headers))
    print("\n")
"""
METHOD-2 ==> This method calculates the average metric values over the training epochs that have passed so far
"""
def training_epoch_end(self, outputs):
    """Print a one-row table built from the metric objects' ``compute()`` results.

    The row holds the epoch label followed by the values returned by
    ``self.accuracy.compute()``, ``self.precision_.compute()``,
    ``self.recall.compute()`` and ``self.f1.compute()``.

    Args:
        outputs: Not used by this method; accepted to match the hook signature.

    Side effects:
        Prints the table followed by a blank line and increments
        ``self.epoch`` by one.
    """
    # No metric is reset here, so - presumably - compute() keeps reflecting
    # everything accumulated so far (confirm against the metric library).
    label = f"Train[Epoch: {self.epoch}]"
    accuracy = self.accuracy.compute()
    precision = self.precision_.compute()
    recall = self.recall.compute()
    f1_score = self.f1.compute()

    row = [label, accuracy, precision, recall, f1_score]
    columns = ["Type", "Accuracy", "Precision", "Recall", "F1 Score"]

    self.epoch += 1
    table = tabulate([row], headers=columns)
    print(table)
    print("\n")