-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata_loader.py
101 lines (81 loc) · 3.77 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
from utils import text_helper
class VqaDataset(data.Dataset):
    """Map-style VQA dataset producing image / encoded-question samples.

    Each entry of the loaded array is accessed with the keys ``'image_path'``,
    ``'question_tokens'`` and (when answers are available) ``'valid_answers'``.
    """

    def __init__(self, input_dir, input_vqa, max_qst_length=30, max_num_ans=10, transform=None):
        """Load the VQA entries and the question/answer vocabularies.

        Args:
            input_dir: Directory containing ``input_vqa``, ``vocab_questions.txt``
                and ``vocab_answers.txt``.
            input_vqa: File name (relative to ``input_dir``) of the ``.npy`` array
                of per-sample entries.
            max_qst_length: Length every encoded question is padded/truncated to.
            max_num_ans: Length every ``'answer_multi_choice'`` list is
                padded/truncated to.
            transform: Optional callable applied to ``'image'`` and
                ``'wrong_image'`` in each sample.
        """
        self.input_dir = input_dir
        # The file holds an object array of dicts; NumPy >= 1.16.3 refuses to load
        # those unless allow_pickle=True. NOTE(review): unpickling can execute code
        # embedded in the file, so only load trusted, locally generated files.
        self.vqa = np.load(os.path.join(input_dir, input_vqa), allow_pickle=True)
        self.qst_vocab = text_helper.VocabDict(os.path.join(input_dir, 'vocab_questions.txt'))
        self.ans_vocab = text_helper.VocabDict(os.path.join(input_dir, 'vocab_answers.txt'))
        self.max_qst_length = max_qst_length
        self.max_num_ans = max_num_ans
        # Answers are loaded only if the first entry carries a non-None
        # 'valid_answers'; an empty dataset simply has no answers.
        self.load_ans = bool(
            len(self.vqa) > 0
            and 'valid_answers' in self.vqa[0]
            and self.vqa[0]['valid_answers'] is not None
        )
        self.transform = transform

    def __getitem__(self, idx):
        """Return the sample dict for entry ``idx``.

        Always contains ``'image'`` and ``'question'`` (an int array of length
        ``max_qst_length``). When answers are loaded it also contains
        ``'answer_label'`` (one valid answer index drawn at random),
        ``'wrong_image'`` (image of an entry whose valid answers exclude that
        label) and ``'answer_multi_choice'`` (list of length ``max_num_ans``
        padded with -1).
        """
        entry = self.vqa[idx]
        sample = {
            'image': self._load_image(entry['image_path']),
            'question': self._encode_question(entry['question_tokens']),
        }
        if self.load_ans:
            ans2idc = [self.ans_vocab.word2idx(w) for w in entry['valid_answers']]
            ans2idx = np.random.choice(ans2idc)
            sample['answer_label'] = ans2idx  # for training
            sample['wrong_image'] = self.find_wrong_image(ans2idx)
            # for evaluation metric of 'multiple choice'
            sample['answer_multi_choice'] = self._pad_answers(ans2idc)
        if self.transform:
            # 'wrong_image' exists only when answers are loaded.
            for key in ('image', 'wrong_image'):
                if key in sample:
                    sample[key] = self.transform(sample[key])
        return sample

    def __len__(self):
        """Return the number of VQA entries."""
        return len(self.vqa)

    def find_wrong_image(self, ans2idx, max_tries=1000):
        """Return the image of a random entry whose valid answers exclude ``ans2idx``.

        Args:
            ans2idx: Answer index the returned image must NOT be labelled with.
            max_tries: Maximum number of random draws before giving up.

        Raises:
            RuntimeError: If no qualifying entry is drawn within ``max_tries``
                attempts (e.g. every entry contains ``ans2idx``).
        """
        for _ in range(max_tries):
            candidate = self.vqa[np.random.randint(len(self.vqa))]
            ans2idc = [self.ans_vocab.word2idx(w) for w in candidate['valid_answers']]
            if ans2idx not in ans2idc:
                return self._load_image(candidate['image_path'])
        raise RuntimeError(
            'no entry without answer index %r found in %d tries' % (ans2idx, max_tries))

    def _encode_question(self, tokens):
        """Encode question tokens as a fixed-length index array padded with '<pad>' from 'qst_vocab'."""
        qst2idc = np.array([self.qst_vocab.word2idx('<pad>')] * self.max_qst_length)
        # Truncate over-long questions: assigning more values than the slice holds raises.
        tokens = tokens[:self.max_qst_length]
        qst2idc[:len(tokens)] = [self.qst_vocab.word2idx(w) for w in tokens]
        return qst2idc

    def _pad_answers(self, ans2idc):
        """Truncate/pad answer indices to exactly ``max_num_ans`` items using -1.

        -1 is not an index in 'ans_vocab', so the model should never predict it.
        A fixed length keeps samples collatable into a batch.
        """
        kept = list(ans2idc[:self.max_num_ans])
        return kept + [-1] * (self.max_num_ans - len(kept))

    @staticmethod
    def _load_image(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert('RGB')
def get_loader(input_dir, input_vqa_train, input_vqa_valid, max_qst_length, max_num_ans, batch_size, num_workers):
    """Create shuffled train/valid DataLoaders backed by ``VqaDataset``.

    Both phases read vocabularies from ``input_dir`` and use the same image
    transform: tensor conversion followed by per-channel normalization.

    Returns:
        dict mapping ``'train'`` and ``'valid'`` to their ``DataLoader``.
    """
    vqa_file_by_phase = {'train': input_vqa_train, 'valid': input_vqa_valid}

    def make_transform():
        # Mean/std are the ImageNet channel statistics used by torchvision's
        # pretrained models.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    loaders = {}
    for phase, vqa_file in vqa_file_by_phase.items():
        phase_dataset = VqaDataset(
            input_dir=input_dir,
            input_vqa=vqa_file,
            max_qst_length=max_qst_length,
            max_num_ans=max_num_ans,
            transform=make_transform(),
        )
        loaders[phase] = torch.utils.data.DataLoader(
            dataset=phase_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )
    return loaders