-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
52 lines (45 loc) · 1.54 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import time
import torch
import numpy as np
import random
from utils.func import print_options
from arch.spectral_upsample import Spectral_upsample
from config import args
from Data_loader import Dataset
from model import sr_model
# Fix every random number generator the script relies on so runs are repeatable.
def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed passed to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select a different convolution algorithm from
    # one run to the next; keep it off so the deterministic flag above is not
    # undermined if benchmarking was enabled elsewhere.
    torch.backends.cudnn.benchmark = False
def main():
    """Run the two training stages, then save the networks and the estimate.

    Steps, in order:
      1. Seed the RNGs and build the spectral up-sampling network.
      2. Build the dataset and the ``DownModel``, call it (stage one) and
         save its spectral and spatial down-sampling networks.
      3. Build the ``UpModel`` from the up-sampling network and the stage-one
         model, call it (stage two) and save the up-sampling network.
      4. Feed the first dataset sample's ``"hmsi"`` entry through the
         up-sampling network and hand the result to ``save_hhsi``.
    """
    # utils.func is already a dependency of this file (print_options); the
    # two savers are imported once here instead of in the middle of the flow.
    from utils.func import save_hhsi, save_net

    # Seed before anything is constructed so that any random initialisation
    # done while building the network is covered by the seed as well.
    setup_seed(2)  # seed is set to 2

    spectral_up_net = Spectral_upsample(
        args,
        args.msi_channel,
        args.hsi_channel,
        init_type='normal',
        init_gain=0.02,
        initializer=False,
    )

    # Per the original note: stores the training configuration in opt.txt.
    print_options(args)

    train_dataset = Dataset(args)

    # Stage one: calling the model object runs its training.
    down_model = sr_model.DownModel(train_dataset)
    down_model()
    save_net(args, down_model.Spectral_down_net)
    save_net(args, down_model.Spatial_down_net)

    # Stage two: train the up-sampling network on top of the stage-one model.
    up_model = sr_model.UpModel(spectral_up_net, down_model)
    up_model()
    save_net(args, up_model.Spectral_up_net)

    # Final estimate: pure inference, so no autograd graph is needed.
    with torch.no_grad():
        hmsi = train_dataset[0]["hmsi"].unsqueeze(0).to(args.device)
        est_hhsi = up_model.Spectral_up_net(hmsi)

    # NOTE(review): the original also built a NumPy copy of est_hhsi (named
    # hrhsi) that was never used; save_hhsi receives the tensor, as before.
    save_hhsi(args, est_hhsi)

    print(args)
    print('all done')
    print("end")


if __name__ == '__main__':
    main()