-
Notifications
You must be signed in to change notification settings - Fork 1
/
lidc.py
executable file
·129 lines (117 loc) · 5.5 KB
/
lidc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import numpy as np
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from unet import Unet
from LIDC.load_LIDC import LIDC
# from cityscape.city_dataset import CityDataset
import torch.nn.functional as F
from prob import ProbabilisticUnet
from utils import l2_regularisation
def iou_1class(outputs: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the foreground intersection-over-union of two masks.

    A pixel counts as foreground wherever the tensor is non-zero. Both label
    maps (e.g. argmax indices) and 0/1 float masks are accepted. The score is
    pooled over every element of the tensors, not averaged per sample.

    Args:
        outputs: predicted mask, same shape as ``labels``.
        labels: ground-truth mask, same shape as ``outputs``.

    Returns:
        ``(|A & B| + eps) / (|A | B| + eps)`` with ``eps = 1e-6``.
        Two empty masks therefore score 1.0, and an empty vs. non-empty
        pair scores ~0.
    """
    smooth = 1e-6
    # Binarise directly. Normalising by max() first could not change which
    # entries are non-zero, but it turns an all-zero tensor into 0/0 = NaN,
    # and NaN != 0 would mark every pixel as foreground.
    pred_fg = outputs != 0
    true_fg = labels != 0
    intersection = (pred_fg & true_fg).sum().item()
    union = pred_fg.sum().item() + true_fg.sum().item() - intersection
    return (intersection + smooth) / (union + smooth)
# Seed torch's global RNG once, at import time, so torch-driven randomness
# (e.g. default weight initialisation) is repeatable from run to run.
torch.manual_seed(42)
def train(batch_size, epochs, gpu, val_after=100, lr=1e-4):
    """Train a ProbabilisticUnet on LIDC and checkpoint the best model.

    Every ``val_after`` optimisation steps, the training loss is recorded and
    the mean validation loss is computed. Whenever the validation loss beats
    the best seen so far, the weights are saved to ``./probmodel_lidc``. When
    training finishes, the recorded losses are written with ``np.save`` to
    ``loss_train_lidc(.npy)`` and ``loss_val_lidc(.npy)``.

    Args:
        batch_size: mini-batch size for both the training and validation loaders.
        epochs: number of passes over the training set.
        gpu: use CUDA when True and a CUDA device is available.
        val_after: validate (and record the training loss) every this many steps.
        lr: Adam learning rate.
    """
    device = torch.device('cuda' if (torch.cuda.is_available() and gpu) else 'cpu')
    train_dataset = LIDC('train', transform=torchvision.transforms.ToTensor())
    val_dataset = LIDC('val', transform=torchvision.transforms.ToTensor())
    # NOTE(review): DataLoader defaults to shuffle=False, so training batches
    # arrive in the same order every epoch. Confirm this is intended.
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    print("Number of training/validation patches:", (len(train_dataset), len(val_dataset)))

    net = ProbabilisticUnet(input_channels=1, num_classes=1, latent_dim=2)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=0)

    # Lowest validation loss seen so far. Starting at +inf guarantees the
    # first validation pass produces a checkpoint.
    best_val_loss = float('inf')
    current_step = 0
    # Store plain Python floats: np.save cannot convert CUDA tensors.
    training_losses = []
    val_losses = []
    total_step = len(train_loader) * epochs
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}")
        for patch, mask in train_loader:
            current_step += 1
            log_now = current_step % 25 == 0
            validate_now = current_step % val_after == 0
            if log_now:
                print(f"Step {current_step} of {total_step}")
            # Validation switches the net to eval mode, so restore train mode.
            net.train()
            patch = patch.to(device)
            mask = mask.to(device)
            net.forward(patch, mask, training=True)
            elbo = net.elbo(mask)
            reg_loss = l2_regularisation(net.posterior) + l2_regularisation(net.prior) + l2_regularisation(net.fcomb.layers)
            loss = -elbo + 1e-5 * reg_loss
            if log_now:
                print(f"Training Loss = {loss.item()}")
            if validate_now:
                print("Recording Loss")
                training_losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if validate_now:
                print("------------------")
                print("Start validating")
                val_loss = _validation_loss(net, val_loader, device, reg_loss.item())
                print(f"Validation Loss = {val_loss}")
                val_losses.append(val_loss)
                if val_loss < best_val_loss:
                    print("Saving model")
                    best_val_loss = val_loss
                    torch.save(net.state_dict(), "./probmodel_lidc")
                print("-------------------")
    np.save('loss_train_lidc', training_losses)
    np.save('loss_val_lidc', val_losses)


def _validation_loss(net, val_loader, device, reg_value):
    """Return the mean per-batch validation loss as a float.

    Each batch contributes ``-ELBO + 1e-5 * reg_value``. Here ``reg_value`` is
    the L2 penalty from the most recent training step, added in the same way
    as during training so the two losses are comparable.
    """
    net.eval()
    total = 0.0
    with torch.no_grad():
        for val_patch, val_mask in val_loader:
            val_patch = val_patch.to(device)
            val_mask = val_mask.to(device)
            # Mirror the training-step call (patch, mask, training=True) so the
            # state elbo() reads is computed from this batch and its mask.
            # NOTE(review): assumes forward(..., training=True) builds the
            # posterior that elbo() samples from. Confirm in prob.py.
            net.forward(val_patch, val_mask, training=True)
            val_elbo = net.elbo(val_mask, analytic_kl=False)
            total += (-val_elbo + 1e-5 * reg_value).item()
    return total / len(val_loader)
def test(gpu):
    """Evaluate the saved checkpoint on the LIDC test split and report mean IoU.

    Loads ``./probmodel_lidc`` and draws one prediction per test patch with
    ``net.sample(testing=True)``. It then prints the mean foreground IoU over
    the test patches whose ground-truth mask is non-empty. Patches with an
    empty mask are counted but excluded from both the sum and the
    denominator.

    Args:
        gpu: use CUDA when True and a CUDA device is available.

    Returns:
        The mean IoU as a fraction, or NaN when every test mask is empty.
    """
    device = torch.device('cuda' if (torch.cuda.is_available() and gpu) else 'cpu')
    test_dataset = LIDC('test', transform=torchvision.transforms.ToTensor())
    test_loader = DataLoader(test_dataset, batch_size=1)
    net = ProbabilisticUnet(input_channels=1, num_classes=1, latent_dim=2)
    # map_location lets a checkpoint saved from a CUDA run load on a CPU-only host.
    net.load_state_dict(torch.load("./probmodel_lidc", map_location=device))
    net.to(device)
    # Inference mode, matching what train() does before validating.
    net.eval()
    iou_sum = 0.0
    empty = 0
    cnt = 0
    print("Start testing")
    print("------------------")
    with torch.no_grad():
        for patch, mask in test_loader:
            patch = patch.to(device)
            mask = mask.to(device)
            cnt += 1
            if mask.sum() == 0:
                # The mean is taken over non-empty masks only, so empty masks
                # must not contribute to the numerator either.
                empty += 1
                continue
            net.forward(patch, None, training=False)
            out = net.sample(testing=True)
            if out.shape[1] == 1:
                # Expand to two channels so argmax yields a label map.
                # Channel 1 is `1 - out`, so a pixel gets label 1 (foreground
                # for iou_1class) exactly where out < 0.5.
                # NOTE(review): if sample() returns foreground logits or
                # probabilities, this marks foreground where the score is LOW.
                # Confirm against ProbabilisticUnet.sample.
                out = torch.cat((out, 1 - out), dim=1)
            out = torch.max(out, 1, True).indices
            iou_sum += iou_1class(out, mask)
    scored = cnt - empty
    if scored == 0:
        print("Mean IoU = undefined (no test patch has a non-empty mask)")
        return float('nan')
    mean_iou = iou_sum / scored
    print(f"Mean IoU = {mean_iou * 100}%")
    return mean_iou