-
Notifications
You must be signed in to change notification settings - Fork 0
/
process.py
105 lines (98 loc) · 3.95 KB
/
process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os, pathlib
import numpy as np
import medpy.io
from monai.inferers import sliding_window_inference
from monai.networks.layers import Norm
from monai.networks.nets import UNet
import monai.transforms as tf
from monai.utils import set_determinism
import torch
from settings import loader_settings
class Seg:
    """Run a MONAI 3D U-Net over every image in the configured input folder.

    Paths are read from ``loader_settings["InputPath"]`` and
    ``loader_settings["OutputPath"]``. Each input file is loaded with
    ``medpy.io.load``, segmented, and written under the same base name to the
    output folder with its original header.
    """

    # Checkpoint path, resolved relative to the current working directory.
    WEIGHTS_PATH = "best_metric_model.pth"  # TODO copy .pth file to main working dir
    # Spatial window size and window batch size for sliding-window inference.
    ROI_SIZE = (96,) * 3
    SW_BATCH_SIZE = 4
    # Mixed precision is only ever enabled when running on CUDA.
    USE_AMP = True  # TODO toggle if AMP causes problems

    def __init__(self):
        # super().__init__(
        #     validators=dict(
        #         input_image=(
        #             UniqueImagesValidator(),
        #             UniquePathIndicesValidator(),
        #         )
        #     ),
        # )
        return

    def process(self):
        """Segment every regular file in the input folder and save the masks.

        The model, transforms and post-processing are built once and reused
        for all files. Returns ``None``; when the input folder contains no
        files, nothing is loaded and the method returns immediately.
        """
        inp_path = loader_settings["InputPath"]  # Path for the input
        out_path = loader_settings["OutputPath"]  # Path for the output

        # Sorted for a reproducible processing order; sub-directories are
        # skipped because they cannot be loaded as images.
        candidates = sorted(
            os.path.join(inp_path, name) for name in os.listdir(inp_path)
        )
        file_list = [path for path in candidates if os.path.isfile(path)]
        if not file_list:
            return

        # Everything below is independent of the individual image, so it is
        # set up once instead of once per file.
        device = self._select_device()
        transform = self._build_transform()
        model = self._load_model(device)
        post_pred = tf.AsDiscrete(argmax=True, to_onehot=2)

        for fil in file_list:
            dat, hdr = medpy.io.load(fil)  # dat is a numpy array
            dat = self._segment(dat, transform, model, post_pred, device)

            out_name = os.path.basename(fil)
            out_filepath = os.path.join(out_path, out_name)
            print(f"=== saving {out_filepath} from {fil} ===")
            medpy.io.save(dat, out_filepath, hdr=hdr)
        return

    @staticmethod
    def _select_device():
        """Return the first CUDA device when available, otherwise the CPU."""
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _build_transform():
        """Build the pre-processing pipeline applied to ``{"image": array}``.

        The image is intensity-normalised per channel, a copy flipped along
        spatial axis 0 is concatenated to it (yielding the two input channels
        the network is configured for), and the result is turned into a
        tensor.
        """
        return tf.Compose(
            [
                tf.NormalizeIntensityd(keys=["image"], channel_wise=True),
                tf.CopyItemsd(keys=["image"], times=1, names=["flipped_image"]),
                tf.Flipd(keys=["flipped_image"], spatial_axis=0),
                tf.ConcatItemsd(keys=["image", "flipped_image"], name="image"),
                tf.ToTensord(keys=["image"]),
            ]
        )

    def _load_model(self, device):
        """Create the U-Net on ``device``, load the checkpoint, set eval mode."""
        # NOTE(review): `dimensions` is kept exactly as in the original call;
        # newer MONAI releases name this argument `spatial_dims` -- confirm
        # against the installed MONAI version before renaming.
        model = UNet(
            dimensions=3,
            in_channels=2,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
            dropout=0.2,
        ).to(device)
        set_determinism(seed=0)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        state = torch.load(self.WEIGHTS_PATH, map_location=device)
        model.load_state_dict(state)
        model.eval()
        return model

    def _segment(self, dat, transform, model, post_pred, device):
        """Return channel 1 of the discretised prediction as a numpy array.

        ``dat`` is the raw image array as returned by ``medpy.io.load``.
        """
        # Add a leading channel axis so the dictionary transforms see a
        # channel-first image.
        img = transform({"image": dat.reshape(1, *dat.shape)})["image"]
        img = img.unsqueeze(0).to(device)  # add batch dimension, move to device

        use_amp = self.USE_AMP and device.type == "cuda"
        # no_grad: inference only, so no autograd bookkeeping is kept.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp):
            prediction = sliding_window_inference(
                img, self.ROI_SIZE, self.SW_BATCH_SIZE, model
            )  # prediction has batch dim

        # Drop the batch dim, argmax + one-hot, then keep channel 1
        # (presumably the lesion/foreground class -- verify against training).
        mask = post_pred(prediction[0])[1]
        return mask.cpu().detach().numpy()
if __name__ == "__main__":
    # Make sure the output folder exists before any segmentation is written.
    output_dir = pathlib.Path("/output/images/stroke-lesion-segmentation/")
    output_dir.mkdir(parents=True, exist_ok=True)

    segmenter = Seg()
    segmenter.process()