-
Notifications
You must be signed in to change notification settings - Fork 2
/
collate_fn.py
116 lines (95 loc) · 4.59 KB
/
collate_fn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import math
import random
import numpy as np
from utils import get_msg_mgr
class CollateFn(object):
    """Collate dataset items into a batch, sampling frames from each sequence.

    Each item of a batch is indexed as ``item[0]`` (a list with one frame
    sequence per input feature) and ``item[1]`` (whose first three entries are
    label, type and view).

    ``sample_config['sample_type']`` has the form ``'<sampler>_<order>'``:

    * sampler ``fixed``   - draw ``frames_num_fixed`` frames per sequence.
    * sampler ``unfixed`` - draw a random number of frames between
      ``frames_num_min`` and ``frames_num_max`` (inclusive).
    * sampler ``all``     - keep every frame, optionally truncated to the
      first ``frames_all_limit`` frames.
    * order ``ordered``   - frames are taken in temporal order from a random
      window of ``frames_num + frames_skip_num`` consecutive frames.
    * order ``unordered`` - frames are drawn at random positions.
    """

    _SAMPLERS = ('fixed', 'unfixed', 'all')
    _ORDERINGS = ('ordered', 'unordered')

    def __init__(self, label_set, sample_config):
        """Parse and validate the sampling configuration.

        Args:
            label_set: sequence of labels; a label's position in it becomes
                the label index returned in the batch.
            sample_config: mapping holding ``sample_type`` plus the keys
                required by the chosen sampler (see the class docstring).

        Raises:
            ValueError: if ``sample_type`` is not ``'<sampler>_<order>'`` with
                a known sampler and order.
            KeyError: if a key required by the chosen sampler is missing.
        """
        self.label_set = label_set
        sample_type = sample_config['sample_type'].split('_')
        if len(sample_type) < 2:
            raise ValueError(
                "sample_type must look like '<sampler>_<order>', got %r"
                % (sample_config['sample_type'],))
        self.sampler, ordering = sample_type[0], sample_type[1]
        if self.sampler not in self._SAMPLERS:
            raise ValueError(
                'Unknown sampler %r, expected one of %s'
                % (self.sampler, ', '.join(self._SAMPLERS)))
        if ordering not in self._ORDERINGS:
            raise ValueError(
                'Unknown order %r, expected one of %s'
                % (ordering, ', '.join(self._ORDERINGS)))
        self.ordered = ordering == 'ordered'

        # fixed cases
        if self.sampler == 'fixed':
            self.frames_num_fixed = sample_config['frames_num_fixed']

        # unfixed cases
        if self.sampler == 'unfixed':
            self.frames_num_max = sample_config['frames_num_max']
            self.frames_num_min = sample_config['frames_num_min']

        if self.sampler != 'all' and self.ordered:
            self.frames_skip_num = sample_config['frames_skip_num']

        # -1 means "no limit"; only the 'all' sampler can set a limit.
        self.frames_all_limit = -1
        if self.sampler == 'all' and 'frames_all_limit' in sample_config:
            self.frames_all_limit = sample_config['frames_all_limit']

    def _sample_frames(self, seqs, feature_num, seq_info):
        """Pick frames from one item's sequences.

        Args:
            seqs: list of per-feature frame sequences; the same frame indices
                are applied to each of the first ``feature_num`` entries.
            feature_num: number of features to sample.
            seq_info: (label index, type, view), used only for the
                empty-sequence debug message.

        Returns:
            A list of ``feature_num`` lists holding the selected frames.
        """
        seq_len = len(seqs[0])
        indices = list(range(seq_len))

        if self.sampler in ('fixed', 'unfixed'):
            if seq_len == 0:
                # Sampling below will fail; record which sequence is empty.
                get_msg_mgr().log_debug(
                    'Find no frames in the sequence %s-%s-%s.'
                    % tuple(str(part) for part in seq_info))

            if self.sampler == 'fixed':
                frames_num = self.frames_num_fixed
            else:
                frames_num = random.choice(
                    range(self.frames_num_min, self.frames_num_max + 1))

            if self.ordered:
                # The window holds the frames to keep plus those that may be
                # skipped.
                window = frames_num + self.frames_skip_num
                if seq_len < window:
                    # Too short: repeat the sequence until the window fits.
                    repeats = math.ceil(window / seq_len)
                    seq_len = seq_len * repeats
                    indices = indices * repeats
                start = random.choice(range(seq_len - window + 1))
                picked = sorted(np.random.choice(
                    list(range(start, start + window)),
                    frames_num, replace=False))
                indices = [indices[i] for i in picked]
            else:
                # Sample with replacement only when the sequence is too short.
                indices = np.random.choice(
                    indices, frames_num, replace=seq_len < frames_num)

        if self.frames_all_limit > -1:
            indices = indices[:self.frames_all_limit]
        return [[seqs[i][j] for j in indices] for i in range(feature_num)]

    def __call__(self, batch):
        """Build a batch from a list of dataset items.

        Returns:
            ``[frames, labels, types, views, seq_lens]`` where ``labels`` are
            indices into ``label_set``. For the ``fixed`` sampler ``frames``
            is laid out as [feature][item] arrays and ``seq_lens`` is None.
            Otherwise ``frames`` is [feature][0], a single array with all
            items concatenated along axis 0, and ``seq_lens`` is a
            ``(1, batch_size)`` array of per-item frame counts.
        """
        # Multiple input sources per item (e.g. silhouette and skeleton) are
        # not fully supported yet; feature_num is expected to be 1 for now.
        feature_num = len(batch[0][0])
        seqs_batch = [item[0] for item in batch]
        labs_batch = [self.label_set.index(item[1][0]) for item in batch]
        typs_batch = [item[1][1] for item in batch]
        vies_batch = [item[1][2] for item in batch]

        # f: feature_num
        # b: batch_size
        # p: batch_size_per_gpu
        # g: gpus_num
        fras_batch = [
            self._sample_frames(seqs, feature_num, (lab, typ, vie))
            for seqs, lab, typ, vie in zip(
                seqs_batch, labs_batch, typs_batch, vies_batch)
        ]  # [b, f]
        batch = [fras_batch, labs_batch, typs_batch, vies_batch, None]

        if self.sampler == "fixed":
            fras_batch = [[np.asarray(sample[j]) for sample in fras_batch]
                          for j in range(feature_num)]  # [f, b]
        else:
            seqL_batch = [[len(sample[0]) for sample in fras_batch]]  # [1, p]
            fras_batch = [
                [np.concatenate([sample[k] for sample in fras_batch], 0)]
                for k in range(feature_num)
            ]  # [f, g]
            batch[-1] = np.asarray(seqL_batch)

        batch[0] = fras_batch
        return batch