-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
60 lines (40 loc) · 2.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from config.model_config import Device_config, Image_config, Model_config
from source.dataloader import AmosDataLoader
from source.model import UNet3D
from source.train import train_model, validate_model
class Run_Segmentation():
    """Wire ``AmosDataLoader``, ``UNet3D`` and the train/validate helpers together.

    Hyper-parameters (batch size, number of classes, input channels, output
    channels, epoch count) are read from ``Model_config`` and the target
    device from ``Device_config`` when the object is constructed.
    """

    def __init__(self, input_paths, target_paths, input_paths_validation=None,
                 target_paths_validation=None, learning_rate=1e-4) -> None:
        """Store dataset paths and read hyper-parameters from the config.

        Args:
            input_paths: Training image paths, passed to ``AmosDataLoader``.
            target_paths: Training label paths, passed to ``AmosDataLoader``.
            input_paths_validation: Validation image paths. Optional so the
                object can be built for training only; ``run_validation``
                raises ``ValueError`` if it is missing.
            target_paths_validation: Validation label paths (same rules as
                ``input_paths_validation``).
            learning_rate: Adam learning rate used by ``run_train_model``
                (defaults to the previously hard-coded ``1e-4``).
        """
        self.input_paths = input_paths
        self.target_paths = target_paths
        self.input_paths_validation = input_paths_validation
        self.target_paths_validation = target_paths_validation
        self.batch_size = Model_config['BATCH_SIZE']
        self.num_class = Model_config['NUM_CLASS']
        self.input_chan = Model_config['INPUT_DIM']
        self.output_dim = Model_config['OUTPUT_CHANNEL']
        self.num_epochs = Model_config['EPOCHS']
        self.device = Device_config['device']
        self.learning_rate = learning_rate
        # Set by run_train_model so run_validation evaluates the trained weights.
        self.model = None

    def _build_model(self):
        """Return a new UNet3D for the configured channels/classes, moved to the configured device."""
        return UNet3D(self.input_chan, self.num_class).to(self.device)

    def run_train_model(self):
        """Train a freshly initialised UNet3D on the training paths.

        Returns:
            The trained model. It is also stored on ``self.model`` so that
            ``run_validation`` evaluates these weights instead of a new,
            randomly initialised network.
        """
        data = AmosDataLoader(self.input_paths, self.target_paths)
        # drop_last discards an incomplete final batch; the dataset supplies its own collate_fn.
        train_loader = DataLoader(data, batch_size=self.batch_size, drop_last=True,
                                  collate_fn=data.collate_fn)
        model = self._build_model()
        # NOTE(review): MSELoss is unusual for multi-class segmentation; confirm it
        # matches the output/target format of UNet3D and AmosDataLoader.
        criteria = nn.MSELoss()
        # The optimizer holds model's own parameters, so training updates `model` in place.
        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        trainer = train_model(model, optimizer, criteria, self.num_epochs, train_loader)
        trainer.train_model()
        self.model = model
        return model

    def run_validation(self):
        """Evaluate on the validation paths, then print and return the loss.

        Uses the model trained by ``run_train_model`` when one exists. Otherwise
        it falls back to a freshly initialised UNet3D (the previous behaviour),
        which only gives an untrained baseline, and emits a warning.

        Returns:
            Whatever ``validate_model`` returns (also printed).

        Raises:
            ValueError: if validation image or label paths were not supplied.
        """
        if self.input_paths_validation is None or self.target_paths_validation is None:
            raise ValueError(
                "run_validation requires input_paths_validation and "
                "target_paths_validation to be provided"
            )
        # getattr tolerates instances created without __init__ (e.g. older pickles).
        model = getattr(self, "model", None)
        if model is None:
            import warnings
            warnings.warn(
                "run_validation called before run_train_model; "
                "validating an untrained, randomly initialised model",
                RuntimeWarning,
            )
            model = self._build_model()
        data = AmosDataLoader(self.input_paths_validation, self.target_paths_validation)
        # NOTE(review): validate_model receives the dataset itself rather than a
        # DataLoader — confirm that is the argument it expects.
        loss = validate_model(data, model)
        print(loss)
        return loss
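# Example usage. The constructor takes training AND validation path lists.
# (This assumes the standard AMOS22 layout: imagesTr/labelsTr for training,
# imagesVa/labelsVa for validation.)
# import os
# from glob import glob
# path = '/Users/arshad_221b/Downloads/Projects/create_shorts-main/MedSeg/AMOS/amos22/'
# input_paths = sorted(glob(os.path.join(path, "imagesTr", "*.nii.gz")))
# target_paths = sorted(glob(os.path.join(path, "labelsTr", "*.nii.gz")))
# input_paths_validation = sorted(glob(os.path.join(path, "imagesVa", "*.nii.gz")))
# target_paths_validation = sorted(glob(os.path.join(path, "labelsVa", "*.nii.gz")))
# r = Run_Segmentation(input_paths, target_paths, input_paths_validation, target_paths_validation)
# r.run_train_model()
# r.run_validation()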