-
Notifications
You must be signed in to change notification settings - Fork 93
/
rnn_train.py
executable file
·654 lines (558 loc) · 21.3 KB
/
rnn_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
#!/usr/bin python3
# Copyright 2021 Seonghun Noh
import argparse
import logging
import os
import io
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import h5py
import argparse
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import torch.nn.utils.rnn as rnn_utils
from collections import defaultdict
import yaml
import glob
from tqdm import tqdm
import PIL.Image
from torchvision.transforms import ToTensor
# Use matplotlib's non-interactive Agg backend: figures are only rendered into
# buffers / TensorBoard here, so no display or GUI toolkit is required.
plt.switch_backend('agg')
class CppRawListDataset(Dataset):
    """Dataset of raw float32 feature files produced by the C++ feature extractor.

    ``filelist_path`` is a text file listing one feature-file path per line
    (blank lines are ignored). Each feature file must contain exactly
    ``train_length_size`` frames of ``x_dim + y_dim`` float32 values; the first
    ``x_dim`` columns of a frame are the network input and the remaining
    ``y_dim`` columns the training target.
    """

    def __init__(self, filelist_path, train_length_size=500):
        self.train_length_size = train_length_size
        self.filelist_path = filelist_path
        self.x_dim = 70
        self.y_dim = 68
        with open(filelist_path, "r") as f:
            # Strip '\r' as well as '\n' so CRLF lists work, and drop blank
            # lines, which would otherwise turn into empty paths that only
            # fail later inside __getitem__.
            self.filelist = [line.rstrip('\r\n') for line in f if line.strip()]
        self.nb_sequences = len(self.filelist)
        print(self.nb_sequences, ' sequences')

    def __len__(self):
        return self.nb_sequences

    def __getitem__(self, index):
        """Return ``(x, y)`` arrays of shape (train_length_size, x_dim) / (train_length_size, y_dim).

        Raises:
            ValueError: if the file does not hold exactly
                ``train_length_size * (x_dim + y_dim)`` float32 values.
        """
        path = self.filelist[index]
        frame_dim = self.x_dim + self.y_dim
        with open(path, 'rb') as cpp_out:
            all_data = np.fromfile(cpp_out, np.float32)
        expected = self.train_length_size * frame_dim
        if all_data.size != expected:
            raise ValueError(
                f"{path}: expected {expected} float32 values "
                f"({self.train_length_size} frames x {frame_dim}), got {all_data.size}"
            )
        all_data = np.reshape(all_data, (self.train_length_size, frame_dim))
        # Scale the first 68 input columns by 30; per the original author these
        # are band energies and the boost compensates for their low magnitude.
        all_data[:, :68] = all_data[:, :68] * 30
        x = all_data[:, :self.x_dim]
        y = all_data[:, self.x_dim:]
        return (x, y)
class h5DirDataset(Dataset):
    """Dataset over a directory of HDF5 files, one sequence per ``*.h5`` file.

    Each file's ``data`` array is split column-wise: the first ``x_dim``
    columns are the network input, the remaining columns the target.
    """

    def __init__(self, h5_dir_path, train_length_size=500):
        # train_length_size is stored for API symmetry; it is not used to slice.
        self.train_length_size = train_length_size
        self.h5_dir_path = h5_dir_path
        self.x_dim = 70
        self.y_dim = 68
        # glob() returns files in arbitrary filesystem order; sort so that a
        # given index refers to the same file on every machine and run.
        self.h5_filelist = sorted(glob.glob(os.path.join(h5_dir_path, "*.h5")))
        self.nb_sequences = len(self.h5_filelist)
        print(self.nb_sequences, ' sequences')

    def __len__(self):
        return self.nb_sequences

    def __getitem__(self, index):
        """Return ``(x, y)`` read from the ``data`` dataset of the index-th file."""
        with h5py.File(self.h5_filelist[index], 'r') as hf:
            all_data = hf['data'][:]
        x = all_data[:, :self.x_dim]
        y = all_data[:, self.x_dim:]
        return (x, y)
class h5Dataset(Dataset):
    """Fixed-length windows cut from a single HDF5 feature file.

    The file's ``data`` array is split into ``len(data) // window_size``
    consecutive windows of ``window_size`` rows; leftover rows that do not
    fill a complete window are discarded. Item ``i`` is the pair
    ``(x, y)`` of window ``i``: its first ``x_dim`` columns and the next
    ``y_dim`` columns.
    """

    def __init__(self, h5_filename="training.h5", window_size=500):
        self.window_size = window_size
        self.h5_filename = h5_filename
        self.x_dim = 70
        self.y_dim = 68
        with h5py.File(self.h5_filename, 'r') as h5_file:
            frames = h5_file['data'][:]
        self.nb_sequences = len(frames) // window_size
        print(self.nb_sequences, ' sequences')

        n_rows = self.nb_sequences * self.window_size
        x_shape = (self.nb_sequences, self.window_size, self.x_dim)
        y_shape = (self.nb_sequences, self.window_size, self.y_dim)
        # NOTE: zero-padding each window by 3 frames on both ends was
        # considered by the original author but is disabled.
        self.x_train = np.reshape(frames[:n_rows, :self.x_dim], x_shape)
        y_block = np.copy(frames[:n_rows, self.x_dim:self.x_dim + self.y_dim])
        self.y_train = np.reshape(y_block, y_shape)

    def __len__(self):
        return self.nb_sequences

    def __getitem__(self, index):
        x_window = self.x_train[index]
        y_window = self.y_train[index]
        return (x_window, y_window)
class PercepNet(nn.Module):
    """PercepNet gain-estimation network.

    Input is assumed to be (batch, time, input_dim) features. The output has
    68 channels per frame: 34 sigmoid ``gb`` gains followed by 34 sigmoid
    ``rb`` values. Both convolutions are padded and then trimmed on the right
    so every output frame depends only on current and past frames, which
    keeps the model aligned with the C++ inference implementation.
    """

    def __init__(self, input_dim=70):
        super(PercepNet, self).__init__()
        # Submodule names and creation order are significant: they define the
        # state_dict keys used by checkpoints and the order in which the
        # random initial weights are drawn.
        self.fc = nn.Sequential(nn.Linear(input_dim, 128), nn.ReLU())
        self.conv1 = nn.Sequential(
            nn.Conv1d(128, 512, kernel_size=5, stride=1, padding=4), nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=2), nn.Tanh()
        )
        for gru_name in ("gru1", "gru2", "gru3", "gru_gb"):
            setattr(self, gru_name, nn.GRU(512, 512, 1, batch_first=True))
        # rb branch consumes [gru3 output, conv features] -> 512 + 512 inputs.
        self.gru_rb = nn.GRU(1024, 128, 1, batch_first=True)
        # gb head sees conv features plus the four 512-wide GRU outputs.
        self.fc_gb = nn.Sequential(nn.Linear(512 * 5, 34), nn.Sigmoid())
        self.fc_rb = nn.Sequential(nn.Linear(128, 34), nn.Sigmoid())

    def forward(self, x):
        hidden = self.fc(x).transpose(1, 2)           # -> (B, D, T) for Conv1d
        hidden = self.conv1(hidden)[..., :-4]         # drop look-ahead frames
        features = self.conv2(hidden)[..., :-2]       # align with c++ dnn
        features = features.transpose(1, 2)           # -> (B, T, D)

        seq1, _ = self.gru1(features)
        seq2, _ = self.gru2(seq1)
        seq3, _ = self.gru3(seq2)
        seq_gb, _ = self.gru_gb(seq3)

        gb = self.fc_gb(torch.cat((features, seq1, seq2, seq3, seq_gb), dim=-1))
        # TODO (from original author): the rb-branch concatenation "need fix".
        seq_rb, _ = self.gru_rb(torch.cat((seq3, features), dim=-1))
        rb = self.fc_rb(seq_rb)
        return torch.cat((gb, rb), dim=-1)
def test():
    """Smoke test: run PercepNet on a random (20, 8, 70) batch and print the output shape."""
    net = PercepNet()
    sample = torch.randn(20, 8, 70)
    result = net(sample)
    print(result.shape)
class CustomLoss(nn.Module):
    """PercepNet loss over band gains (``gb``) and ``rb`` values.

    ``outputs`` and ``targets`` are sliced as ``[:, :, :34]`` for ``gb`` and
    ``[:, :, 34:68]`` for ``rb``. With ``g = gb ** gamma`` and
    ``r = (1 - rb) ** gamma`` the loss is::

        mean((g - g_hat) ** 2) + C4 * mean((g - g_hat) ** 4) + mean((r - r_hat) ** 2)

    Every base is clamped to at least ``eps`` before being raised to
    ``gamma``. The derivative of ``x ** 0.5`` is infinite at 0, a value
    sigmoid outputs can reach exactly in float32, and negative bases give NaN.
    Either case can turn the gradients into inf/NaN. For bases >= ``eps`` the
    value is unchanged.

    Args:
        weight, size_average: accepted for API compatibility; unused.
        gamma: compression exponent (default 0.5).
        c4: weight of the 4th-power gain term (default 10).
        eps: lower clamp applied to every base before exponentiation.
    """

    def __init__(self, weight=None, size_average=True, gamma=0.5, c4=10, eps=1e-10):
        super(CustomLoss, self).__init__()
        self.gamma = gamma
        self.c4 = c4
        self.eps = eps

    def _compress(self, base):
        """Return ``max(base, eps) ** gamma`` (finite value and gradient)."""
        return torch.pow(torch.clamp(base, min=self.eps), self.gamma)

    def forward(self, outputs, targets):
        gb_hat = outputs[:, :, :34]
        rb_hat = outputs[:, :, 34:68]
        gb = targets[:, :, :34]
        rb = targets[:, :, 34:68]

        gb_diff = self._compress(gb) - self._compress(gb_hat)
        rb_diff = self._compress(1 - rb) - self._compress(1 - rb_hat)
        return (torch.mean(torch.pow(gb_diff, 2))
                + self.c4 * torch.mean(torch.pow(gb_diff, 4))
                + torch.mean(torch.pow(rb_diff, 2)))
def train():
    """Legacy single-process training loop on ``training.h5`` (superseded by ``main``).

    Trains PercepNet with Adam (lr=1e-4, batch size 10) for 10000 epochs.
    After each epoch it logs a preview plot of the model output and target for
    ``dataset[0]``, plus the last batch loss, to TensorBoard. Finally it saves
    the weights to ``model.pt``.
    """
    writer = SummaryWriter()
    use_custom_loss = True
    dataset = h5Dataset("training.h5")
    trainset_ratio = 1  # 1 - validation-set ratio
    train_size = int(trainset_ratio * len(dataset))
    test_size = len(dataset) - train_size
    batch_size = 10
    train_dataset, _test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    model = PercepNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    if use_custom_loss:
        criterion = CustomLoss()
    else:
        criterion = nn.MSELoss()
    num_epochs = 10000
    loss = None  # stays None if the loader yields no batches
    for epoch in range(num_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # per-step print, kept for testing
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, loss.item()))
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
        # Preview inference only: no autograd graph is needed.
        model.eval()
        with torch.no_grad():
            tmp_output = model(torch.tensor(dataset[0][0]).unsqueeze(0))
        model.train()
        fig = plt.figure()
        plt.plot(tmp_output[0].squeeze(0).T.detach().numpy())
        writer.add_figure('output gb', fig, global_step=epoch)
        fig = plt.figure()
        plt.plot(dataset[0][1][:, :].T)
        writer.add_figure('target gb', fig, global_step=epoch)
        if loss is not None:
            writer.add_scalar('loss', loss.item(), global_step=epoch)
    print('Finished Training')
    print('save model')
    writer.close()
    torch.save(model.state_dict(), 'model.pt')
def gen_plot(y, y_hat):
    """Render prediction and target side by side and return them as a PNG.

    Args:
        y: target array, drawn transposed (``y.T``) in the right panel.
        y_hat: predicted array, drawn transposed in the left panel.

    Returns:
        io.BytesIO: PNG-encoded figure, rewound to position 0.
    """
    fig = plt.figure(figsize=(10, 5))
    # Left panel: prediction; right panel: target.
    plt.subplot(1, 2, 1)
    plt.imshow(y_hat.T, interpolation='none', cmap=plt.cm.jet, origin='lower', aspect='auto')
    plt.subplot(1, 2, 2)
    plt.imshow(y.T, interpolation='none', cmap=plt.cm.jet, origin='lower', aspect='auto')
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    # pyplot keeps every figure alive until it is explicitly closed; this
    # function can be called repeatedly during training, so close it here.
    plt.close(fig)
    buf.seek(0)
    return buf
class Trainer(object):
    """Customized trainer module for PercepNet training."""

    def __init__(
        self,
        steps,
        epochs,
        data_loader,
        sampler,
        model,
        criterion,
        optimizer,
        args,
        config,
        device=torch.device("cpu"),
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            data_loader (dict): Dict of data loaders. It must contain "train" and "dev" loaders.
            sampler (dict): Dict of samplers keyed like ``data_loader`` (values may be None).
                Stored on the trainer only.
            model (nn.Module): Model. Instance of nn.Module
            criterion (nn.Module): criterions.
            optimizer (torch.optim): optimizers.
            args (parser.parse_args()): Instance of argparse parse_args()
            config (dict): Training configuration; must provide "out_dir",
                "train_max_steps", "log_interval_steps", "eval_interval_steps"
                and "save_interval_steps".
            device (torch.device): Pytorch device instance.
        """
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.sampler = sampler
        self.model = model
        self.criterion = criterion
        self.args = args
        self.optimizer = optimizer
        self.device = device
        self.config = config
        self.writer = SummaryWriter(config["out_dir"])
        self.finish_train = False
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)

    def run(self):
        """Run training epochs until ``_check_train_finish`` sets ``finish_train``."""
        self.tqdm = tqdm(
            initial=self.steps, total=self.args.train_max_steps, desc="[train]"
        )
        try:
            while True:
                # train one epoch
                self._train_epoch()
                # check whether training is finished
                if self.finish_train:
                    break
        finally:
            # close the progress bar even if training raises or is interrupted
            self.tqdm.close()
        logging.info("Finished training.")

    def _unwrapped_model(self):
        """Return the underlying model, unwrapping a ``.module`` wrapper (e.g. DDP) if present."""
        return getattr(self.model, "module", self.model)

    def save_checkpoint(self, checkpoint_path):
        """Save the (unwrapped) model's state_dict to ``checkpoint_path``.

        Saving the unwrapped model keeps the keys free of a ``module.`` prefix,
        matching what ``load_checkpoint`` loads into.
        """
        torch.save(self._unwrapped_model().state_dict(), checkpoint_path)

    def load_checkpoint(self, checkpoint_path):
        """Load checkpoint.

        Args:
            checkpoint_path (str): Checkpoint path to be loaded.
        """
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        # Load into the inner model when wrapped (DDP / DataParallel) so that
        # keys match checkpoints written by save_checkpoint.
        self._unwrapped_model().load_state_dict(state_dict)

    def _train_step(self, batch):
        """Train model one step."""
        # batch is an (inputs, targets) pair
        inputs, targets = batch
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        # zero the parameter gradients
        self.optimizer.zero_grad()
        # forward + backward + optimize
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        loss.backward()
        self.optimizer.step()
        self.total_train_loss["train/total_loss"] += loss.item()
        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _train_epoch(self):
        """Train model one epoch.

        Raises:
            ValueError: if the "train" loader yields no batches; ``run`` would
                otherwise never make progress.
        """
        train_steps_per_epoch = 0
        for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
            # train one step
            self._train_step(batch)
            # logging / evaluation / checkpointing only on the main process
            if self.args.rank == 0:
                self._check_log_interval()
                self._check_eval_interval()
                self._check_save_interval()
            # check whether training is finished
            if self.finish_train:
                return
        if train_steps_per_epoch == 0:
            raise ValueError('The "train" data loader yielded no batches.')
        # update
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        logging.info(
            f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
            f"({self.train_steps_per_epoch} steps per epoch)."
        )

    @torch.no_grad()
    def _eval_step(self, batch):
        """Evaluate model one step (forward pass and loss only)."""
        inputs, targets = batch
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        self.total_eval_loss["eval/total_loss"] += loss.item()

    def _eval_epoch(self):
        """Evaluate model one epoch over the "dev" loader and log the average loss."""
        logging.info(f"(Steps: {self.steps}) Start evaluation.")
        # change mode
        self.model.eval()
        eval_steps_per_epoch = 0
        # calculate loss for each batch
        for eval_steps_per_epoch, batch in enumerate(
            tqdm(self.data_loader["dev"], desc="[eval]"), 1
        ):
            # eval one step
            self._eval_step(batch)
            # save intermediate result for the first batch only
            if eval_steps_per_epoch == 1:
                self._genearete_and_save_intermediate_result(batch)
        if eval_steps_per_epoch == 0:
            # Nothing to average; previously this crashed with UnboundLocalError.
            logging.warning(
                f"(Steps: {self.steps}) Skipped evaluation: "
                "the dev data loader yielded no batches."
            )
            self.model.train()
            return
        logging.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({eval_steps_per_epoch} steps per epoch)."
        )
        # average loss
        for key in self.total_eval_loss.keys():
            self.total_eval_loss[key] /= eval_steps_per_epoch
            logging.info(
                f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
            )
        # record
        self._write_to_tensorboard(self.total_eval_loss)
        # reset
        self.total_eval_loss = defaultdict(float)
        # restore mode
        self.model.train()

    @torch.no_grad()
    def _genearete_and_save_intermediate_result(self, batch):
        """Plot prediction vs. target for the first sample of ``batch`` to TensorBoard."""
        x_batch, y_batch = batch
        # Only the first sample is plotted, so only that sample is run through the model.
        y_hat = self.model(x_batch[:1].to(self.device))[0].cpu().numpy()
        y = y_batch[0].cpu().numpy()
        plot_buf = gen_plot(y, y_hat)
        image = ToTensor()(PIL.Image.open(plot_buf))
        self.writer.add_image('rb,gb hat', image, self.steps)
        logging.debug(f"(Steps: {self.steps}) Wrote intermediate result image.")

    def _write_to_tensorboard(self, loss):
        """Write every ``key -> value`` of ``loss`` as a scalar at the current step."""
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    def _check_save_interval(self):
        if self.steps % self.config["save_interval_steps"] == 0:
            self.save_checkpoint(
                os.path.join(self.config["out_dir"], f"checkpoint-{self.steps}steps.pkl")
            )
            logging.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def _check_eval_interval(self):
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_log_interval(self):
        if self.steps % self.config["log_interval_steps"] == 0:
            # average the accumulated training loss over the log interval
            for key in self.total_train_loss.keys():
                self.total_train_loss[key] /= self.config["log_interval_steps"]
                logging.info(
                    f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
                )
            self._write_to_tensorboard(self.total_train_loss)
            # reset
            self.total_train_loss = defaultdict(float)

    def _check_train_finish(self):
        if self.steps >= self.config["train_max_steps"]:
            self.finish_train = True
def main():
    """Run training process.

    Parses command-line arguments and merges them into the YAML config, with
    CLI values taking precedence. Writes the merged config to
    ``<out_dir>/config.yml``, then builds the model, datasets, data loaders and
    :class:`Trainer` and runs training. A final checkpoint is written to
    ``out_dir`` when training stops, including when it stops on an exception.
    """
    parser = argparse.ArgumentParser(
        description="Train PercepNet (See detail in rnn_train.py)."
    )
    parser.add_argument(
        "--train_length_size",
        default=2000,
        type=int,
        help="RNN network train length size.",
    )
    parser.add_argument(
        "--train_max_steps",
        default=100000,
        type=int,
        help="max train steps.",
    )
    parser.add_argument(
        "--train_filelist_path",
        type=str,
        required=True,
        help="cpp generated feature train filelist path.",
    )
    parser.add_argument(
        "--dev_filelist_path",
        type=str,
        required=True,
        help="cpp generated feature dev filelist path",
    )
    parser.add_argument(
        "--pretrain",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to load pretrained params. (default="")',
    )
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explictly specify.",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        required=True,
        help="directory to save checkpoints.",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    args = parser.parse_args()

    # Without a handler the root logger defaults to WARNING and every
    # logging.info() below would be dropped. Non-zero ranks log warnings only.
    logging.basicConfig(
        level=logging.INFO if args.rank == 0 else logging.WARNING,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # config.yml (and later checkpoints) are written into out_dir.
    os.makedirs(args.out_dir, exist_ok=True)

    # load and save config
    with open(args.config) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only use trusted config files (yaml.SafeLoader is the safe option).
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    with open(os.path.join(args.out_dir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # NOTE(review): in distributed mode the model is not wrapped in
    # DistributedDataParallel and the samplers below are not passed to the
    # DataLoaders, so ranks do not share gradients or shard data. TODO confirm.
    model = PercepNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    criterion = CustomLoss()

    train_dataset = CppRawListDataset(
        args.train_filelist_path, train_length_size=args.train_length_size)
    dev_dataset = CppRawListDataset(
        args.dev_filelist_path, train_length_size=args.train_length_size)
    logging.info(f"The number of training files = {len(train_dataset)}.")
    logging.info(f"The number of development files = {len(dev_dataset)}.")
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    }
    sampler = {"train": None, "dev": None}
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        sampler["train"] = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        sampler["dev"] = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    data_loader = {
        "train": torch.utils.data.DataLoader(
            dataset["train"],
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            shuffle=True
        ),
        "dev": torch.utils.data.DataLoader(
            dataset["dev"],
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            shuffle=False
        )
    }

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        model=model,
        data_loader=data_loader,
        criterion=criterion,
        optimizer=optimizer,
        config=config,
        args=args,
        sampler=sampler,
        device=device,
    )

    # load pretrained parameters from checkpoint; `--pretrain` given without a
    # value yields None (argparse nargs="?"), so test truthiness, not len().
    if args.pretrain:
        trainer.load_checkpoint(args.pretrain)
        logging.info(f"Successfully load parameters from {args.pretrain}.")

    # run training loop; always save a final checkpoint, even on error/interrupt
    try:
        trainer.run()
    finally:
        trainer.save_checkpoint(
            os.path.join(config["out_dir"], f"checkpoint-{trainer.steps}steps.pt")
        )
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
# Script entry point: run the argparse/YAML-driven training pipeline (main).
if __name__ == '__main__':
    main()
    # Legacy single-process loop on training.h5; enable instead of main() if needed.
    #train()