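"""Data loading for GolfDB video clips.

Each clip's frames carry per-frame labels: one of 8 golf-swing event
classes (0-7), or class 8 for frames with no event (see GolfDB.__getitem__).
"""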
import os.path as osp
import cv2
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class GolfDB(Dataset):
    """GolfDB video clips with per-frame swing-event labels."""

    def __init__(self, data_file, vid_dir, seq_length, transform=None, train=True):
        self.df = pd.read_pickle(data_file)
        self.vid_dir = vid_dir
        self.seq_length = seq_length
        self.transform = transform
        self.train = train

    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        a = self.df.loc[idx, :]  # annotation info
        events = a['events']
        events -= events[0]  # now frame #s correspond to frames in preprocessed video clips

        images, labels = [], []
        cap = cv2.VideoCapture(osp.join(self.vid_dir, '{}.mp4'.format(a['id'])))

        if self.train:
            # random starting position, sample 'seq_length' frames
            start_frame = np.random.randint(events[-1] + 1)
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            pos = start_frame
            while len(images) < self.seq_length:
                ret, img = cap.read()
                if ret:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    images.append(img)
                    if pos in events[1:-1]:
                        # label is the event's index (0-7)
                        labels.append(np.where(events[1:-1] == pos)[0][0])
                    else:
                        labels.append(8)  # no-event class
                    pos += 1
                else:
                    # clip exhausted: wrap around to the first frame
                    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    pos = 0
            cap.release()
        else:
            # evaluation: return the full clip
            for pos in range(int(cap.get(cv2.CAP_PROP_FRAME_COUNT))):
                _, img = cap.read()
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                images.append(img)
                if pos in events[1:-1]:
                    labels.append(np.where(events[1:-1] == pos)[0][0])
                else:
                    labels.append(8)
            cap.release()

        sample = {'images': np.asarray(images), 'labels': np.asarray(labels)}
        if self.transform:
            sample = self.transform(sample)
        return sample


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        images, labels = sample['images'], sample['labels']
        images = images.transpose((0, 3, 1, 2))  # (N, H, W, C) -> (N, C, H, W)
        return {'images': torch.from_numpy(images).float().div(255.),
                'labels': torch.from_numpy(labels).long()}


class Normalize(object):
    """Normalize image channels in-place with the given mean and std."""

    def __init__(self, mean, std):
        self.mean = torch.tensor(mean, dtype=torch.float32)
        self.std = torch.tensor(std, dtype=torch.float32)

    def __call__(self, sample):
        images, labels = sample['images'], sample['labels']
        images.sub_(self.mean[None, :, None, None]).div_(self.std[None, :, None, None])
        return {'images': images, 'labels': labels}


if __name__ == '__main__':
    norm = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean and std (RGB)
    dataset = GolfDB(data_file='data/train_split_1.pkl',
                     vid_dir='data/videos_160/',
                     seq_length=64,
                     transform=transforms.Compose([ToTensor(), norm]),
                     train=False)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=6, drop_last=False)
    for i, sample in enumerate(data_loader):
        images, labels = sample['images'], sample['labels']
        events = np.where(labels.squeeze() < 8)[0]  # frame indices labeled as events
        print('{} events: {}'.format(len(events), events))
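
    # A minimal training-mode sketch (hypothetical batch size; same data files
    # assumed): with train=True, each item is a random seq_length-frame window,
    # so shuffling and drop_last are appropriate.
    # train_set = GolfDB(data_file='data/train_split_1.pkl',
    #                    vid_dir='data/videos_160/',
    #                    seq_length=64,
    #                    transform=transforms.Compose([ToTensor(), norm]),
    #                    train=True)
    # train_loader = DataLoader(train_set, batch_size=8, shuffle=True,
    #                           num_workers=6, drop_last=True)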