# SPDX-License-Identifier: EUPL-1.2
# Copyright (c) 2020, Martynas Janonis
# Licensed under the EUPL-1.2-or-later
import torch
import sys
import numpy as np
from torchvision.models import resnext101_32x8d, densenet201
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, cohen_kappa_score
from networks import EmbeddingNet, TripletNet
from datasets import XRayParcels
from joblib import dump, load
cuda = torch.cuda.is_available()
device = torch.device("cuda") if cuda else torch.device("cpu")
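# Assumption: the CSV manifests below list X-ray parcel image paths with
# binary labels (0 / 1); see datasets.XRayParcels for the exact schema and
# for the transforms applied when transform=True.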
xray_train_dataset = XRayParcels("svm_train.csv", train=True, transform=True)
xray_test_dataset = XRayParcels("svm_test.csv", train=False, transform=False)
batch_size = 64
kwargs = {"num_workers": 1, "pin_memory": True} if cuda else {}
xray_train_loader = torch.utils.data.DataLoader(
    xray_train_dataset, batch_size=batch_size, shuffle=True, **kwargs
)
xray_test_loader = torch.utils.data.DataLoader(
    xray_test_dataset, batch_size=batch_size, shuffle=False, **kwargs
)
# Initialize and load the embedding network
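# The DenseNet-201 backbone wrapped in EmbeddingNet/TripletNet was trained
# with a triplet loss (checkpoint "triplet_densenet201_m2.pth"); here it is
# used only as a frozen feature extractor, and eval() disables dropout and
# batch-norm updates.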
embedding_net = EmbeddingNet(densenet201())
model = TripletNet(embedding_net)
model.load_state_dict(torch.load("triplet_densenet201_m2.pth", map_location=device))
model.to(device)
model.eval()
# Initialize the SVM
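# With loss="hinge", SGDClassifier is a linear SVM fitted by stochastic
# gradient descent. class_weight={0: 1, 1: 50} penalises errors on class 1
# fifty times more heavily (presumably to counter class imbalance),
# average=True keeps averaged-SGD weights, and warm_start=True reuses the
# current coefficients whenever fit is called again.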
svm = SGDClassifier(
    loss="hinge", verbose=0, class_weight={0: 1, 1: 50}, warm_start=True, average=True
)
n_epochs = 50
highest_f1 = 0
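# Each epoch streams the training set through the frozen embedding network
# and feeds the resulting vectors to the SVM one mini-batch at a time via
# partial_fit; classes=[0, 1] declares the full label set up front because
# a single mini-batch may contain only one of the classes.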
for epoch in range(n_epochs):
    print("Starting epoch {}".format(epoch))

    # Train stage
    # Generate #batch_size vectors to pass as a dataset to the SVM
    for batch_idx, (data, target) in enumerate(xray_train_loader):
        target = target if len(target) > 0 else None
        if not type(data) in (tuple, list):
            data = (data,)
        data = tuple(d.to(device) for d in data)
        if target is not None:
            target = target.to(device)

        with torch.no_grad():
            vectors = model.embedding_net(*data)

        # Convert from PyTorch tensors to NumPy arrays
        vectors = vectors.detach().cpu().numpy()
        target = target.detach().cpu().numpy()

        # Do one epoch of SGD for the SVM
        svm.partial_fit(vectors, target, classes=[0, 1])

        message = "Train: [{}/{} ({:.0f}%)]".format(
            batch_idx * len(data[0]),
            len(xray_train_loader.dataset),
            100.0 * batch_idx / len(xray_train_loader),
        )
        sys.stdout.write("\x1b[2K")  # Clear to the end of line
        sys.stdout.write("\r" + message)
        sys.stdout.flush()

    print()
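    # Validation: embed the held-out set with the same frozen network and
    # score the current SVM on it. Note that roc_auc_score receives the
    # hard 0/1 predictions rather than decision_function scores.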
print("Starting validation")
# Test stage
# Generate #batch_size vectors to pass as a dataset to the SVM
y_pred = []
y_true = []
for batch_idx, (data, target) in enumerate(xray_test_loader):
target = target if len(target) > 0 else None
if not type(data) in (tuple, list):
data = (data,)
data = tuple(d.to(device) for d in data)
if target is not None:
target = target.to(device)
with torch.no_grad():
vectors = model.embedding_net(*data)
# Convert from PyTorch tensors to NumPy arrays
vectors = vectors.detach().cpu().numpy()
y_true = np.append(y_true, target.detach().cpu().numpy())
y_pred = np.append(y_pred, svm.predict(vectors))
print(
"Epoch {}/{}. Validation set: Avg. accuracy: {:.4f}, avg. F1 score: {:.4f}, avg. AUC: {:.4f}, avg. Kappa {:.4f}".format(
epoch,
n_epochs,
accuracy_score(y_true, y_pred),
f1_score(y_true, y_pred),
roc_auc_score(y_true, y_pred),
cohen_kappa_score(y_true, y_pred),
)
)
# Save the model if F1 is larger
if f1_score(y_true, y_pred) > highest_f1:
print("F1 score increased, saving model")
dump(svm, "svm.joblib")
highest_f1 = f1_score(y_true, y_pred)
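# A minimal sketch of how the saved classifier could be reused at inference
# time (assuming feature vectors produced by the same embedding network):
#
#     clf = load("svm.joblib")
#     scores = clf.decision_function(vectors)  # signed distance to the margin
#     preds = clf.predict(vectors)             # hard 0/1 labels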