-
Notifications
You must be signed in to change notification settings - Fork 37
/
data_utils_torch.py
40 lines (36 loc) · 1.43 KB
/
data_utils_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
from torch.utils.data import TensorDataset,DataLoader
import torch
import os,pickle
def load_dataset(batch_size, dir='data', n_workers=0, test_size=16384, total_size=None):
    """Load the captcha arrays and vocabulary and build train/test DataLoaders.

    Reads ``captcha.npz`` (arrays ``img`` and ``text``) and the pickled
    ``captcha.vocab_dict`` from ``dir``. Images are rescaled with
    ``x / 127.5 - 1`` and their axes are permuted to ``(0, 3, 1, 2)``
    (presumably NHWC -> NCHW; confirm against the data producer). The last
    ``test_size`` samples form the test split, the rest the training split.

    Args:
        batch_size: Batch size used by both loaders.
        dir: Directory containing the two data files. The name shadows the
            builtin ``dir`` but is kept so keyword callers keep working.
        n_workers: ``num_workers`` for the training loader only; the test
            loader keeps the DataLoader default.
        test_size: Number of trailing samples held out for testing. ``0``
            gives an empty test split; a value of at least the number of
            samples gives an empty training split.
        total_size: If not ``None``, only the first ``total_size`` samples
            are used (applied before the train/test split).

    Returns:
        A tuple ``(dataloader_train, dataloader_test, vocab)`` where the
        training loader shuffles and the test loader does not.

    Raises:
        ValueError: If ``test_size`` is negative.
    """
    if test_size < 0:
        raise ValueError("test_size must be non-negative, got %r" % (test_size,))

    print("Loading data...")
    # NpzFile keeps the archive open until closed; the arrays are fully read
    # on item access, so they stay valid after the ``with`` block.
    with np.load(os.path.join(dir, 'captcha.npz')) as data:
        raw_image = data['img']
        text = data['text']

    # Truncate before the float conversion so the float32 copy is only as
    # large as the subset that is actually used.
    if total_size is not None:
        raw_image = raw_image[:total_size]
        text = text[:total_size]
    image = raw_image.astype(np.float32) / 127.5 - 1

    print("Loading dictionary...")
    # NOTE: pickle can execute arbitrary code while loading; only point
    # ``dir`` at vocabulary files from a trusted source.
    with open(os.path.join(dir, 'captcha.vocab_dict'), 'rb') as vocab_file:
        vocab = pickle.load(vocab_file, encoding='utf8')

    print("Convert to tensor...")
    image = torch.Tensor(image).permute(0, 3, 1, 2)
    text = torch.LongTensor(text)

    # Use an explicit split index: negative slicing breaks for test_size == 0
    # because ``x[:-0]`` is empty and ``x[-0:]`` is everything.
    split = max(image.size(0) - test_size, 0)
    image_train, image_test = image[:split], image[split:]
    text_train, text_test = text[:split], text[split:]

    print("Build dataset...")
    dataset_train = TensorDataset(image_train, text_train)
    dataset_test = TensorDataset(image_test, text_test)

    # Page-locked host memory only pays off when batches are copied to a GPU.
    pin_memory = torch.cuda.is_available()

    print("Build dataloader...")
    dataloader_train = DataLoader(
        dataset_train,
        batch_size,
        shuffle=True,
        num_workers=n_workers,
        pin_memory=pin_memory,
    )
    dataloader_test = DataLoader(
        dataset_test,
        batch_size,
        shuffle=False,
        pin_memory=pin_memory,
    )
    print("data ready!")
    return dataloader_train, dataloader_test, vocab
if __name__ == '__main__':
    # Smoke run: build both loaders from the default data directory with a
    # batch size of 32.
    dl_train, dl_test, vocab = load_dataset(32)