-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_all.py
86 lines (61 loc) · 2.48 KB
/
run_all.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import time
from matplotlib import pyplot as plt
from train import Model
# Every configuration this script trains, as (plot label, keyword arguments
# passed to train.Model), in the order they are trained and plotted.
# The meaning of the integer `model` / `ablate` codes is defined by
# train.Model, not here; the labels are this script's names for them.
TRAINING_RUNS = (
    ("BertResnetWithConcat", {"model": 0}),
    ("BertResnetWithAttention", {"model": 1}),
    ("BertDensenetWithConcat", {"model": 2}),
    ("BertDensenetWithAttention", {"model": 3}),
    ("Resnet Only", {"model": 0, "ablate": 1}),
    ("Bert Only", {"model": 0, "ablate": 2}),
    ("Densenet Only", {"model": 2, "ablate": 1}),
)


def train_timed(model):
    """Call ``model.train()`` and return the elapsed time in seconds.

    Uses ``time.perf_counter()``, a monotonic high-resolution clock intended
    for measuring durations. Unlike ``time.time()``, it is unaffected by
    system clock adjustments.
    """
    start = time.perf_counter()
    model.train()
    return time.perf_counter() - start


def plot_metric(trained, attribute, title):
    """Draw one curve per trained model on a new figure and show it.

    Args:
        trained: iterable of ``(label, model)`` pairs; ``label`` becomes the
            legend entry for that model's curve.
        attribute: name of the sequence attribute read from each model and
            passed to ``plt.plot`` (e.g. ``"train_loss"``).
        title: title of the figure.
    """
    # Start a fresh figure explicitly. With some backends (e.g. the
    # non-interactive Agg backend) plt.show() returns without closing the
    # current figure, so successive metrics would otherwise share one set
    # of axes.
    plt.figure()
    for label, model in trained:
        plt.plot(getattr(model, attribute), label=label)
    plt.title(title)
    plt.legend()
    plt.show()


def main():
    """Train every configuration in TRAINING_RUNS, then plot their curves.

    Models are built and trained one at a time, in table order. Each
    training duration is printed with its label. Training loss and
    validation accuracy are then shown in two separate figures.
    """
    trained = []
    for label, kwargs in TRAINING_RUNS:
        # Build the model before starting the clock so only train() is timed.
        model = Model(**kwargs)
        elapsed = train_timed(model)
        print(f"Training time ({label}):", elapsed)
        trained.append((label, model))

    plot_metric(trained, "train_loss", "Train Loss")
    plot_metric(trained, "val_accuracy", "Validation Accuracy")


if __name__ == "__main__":
    main()