#!/usr/bin/env python3
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.

Based on: https://github.com/MCG-NJU/VideoMAE/blob/e97e551a678d522b055d23e8a9dc15ccc4c14a9b/ssv2.py
"""
import torch
from torchvision import transforms

import videoMAE_video_transforms as video_transforms
from random_erasing import RandomErasing


class aug(object):
    """RandAugment plus optional random erasing for video clips, following
    the VideoMAE SSv2 training recipe."""

    def __init__(self, crop_size, aa=None, reprob=0.0, remode='const', recount=1):
        self.crop_size = crop_size
        # Frame-level RandAugment; `aa` is a timm-style auto-augment policy
        # string (e.g. 'rand-m7-n4-mstd0.5-inc1').
        self.aug_transform = video_transforms.create_random_augment(
            input_size=(self.crop_size, self.crop_size),
            auto_augment=aa,
            interpolation='bicubic',
        )
        if reprob > 0:
            # `remode` must be 'const', 'rand', or 'pixel'; `recount` caps
            # the number of erased regions per clip.
            self.erase_transform = RandomErasing(
                reprob,
                mode=remode,
                max_count=recount,
                num_splits=recount,
                device="cpu",
            )
        else:
            self.erase_transform = None
    def __call__(self, buffer):
        # Convert raw frames (numpy arrays or tensors) to PIL images; frames
        # that are already PIL images make ToPILImage raise TypeError, in
        # which case they are used as-is.
        try:
            buffer = [transforms.ToPILImage()(frame) for frame in buffer]
        except TypeError:
            pass

        buffer = self.aug_transform(buffer)

        buffer = [transforms.ToTensor()(img) for img in buffer]
        buffer = torch.stack(buffer)  # T C H W
        buffer = buffer.permute(0, 2, 3, 1)  # T H W C

        # Normalize with ImageNet mean/std (still T H W C).
        buffer = tensor_normalize(
            buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )
        # T H W C -> C T H W.
        buffer = buffer.permute(3, 0, 1, 2)

        # Random resized crop with scale and aspect-ratio jitter.
        scl, asp = (
            [0.08, 1.0],
            [0.75, 1.3333],
        )
        buffer = spatial_sampling(
            buffer,
            spatial_idx=-1,
            min_scale=256,
            max_scale=320,
            crop_size=self.crop_size,
            random_horizontal_flip=False,
            inverse_uniform_sampling=False,
            aspect_ratio=asp,
            scale=scl,
            motion_shift=False,
        )
        # C T H W -> T C H W.
        buffer = buffer.permute(1, 0, 2, 3)

        if self.erase_transform is not None:
            buffer = self.erase_transform(buffer)
        return buffer
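
# Typical usage (illustrative): `aug(crop_size=224, aa='rand-m7-n4-mstd0.5-inc1')`
# maps a list of H x W x C uint8 frames to a normalized T x C x H x W float
# tensor; see the __main__ demo at the bottom of the file.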


def tensor_normalize(tensor, mean, std):
    """
    Normalize a given tensor by subtracting the mean and dividing by the std.

    Args:
        tensor (tensor): tensor to normalize.
        mean (tensor or list): mean value to subtract.
        std (tensor or list): std to divide by.
    """
    if tensor.dtype == torch.uint8:
        # Scale uint8 pixel values to floats in [0, 1] before normalizing.
        tensor = tensor.float()
        tensor = tensor / 255.0
    if isinstance(mean, list):
        mean = torch.tensor(mean)
    if isinstance(std, list):
        std = torch.tensor(std)
    tensor = tensor - mean
    tensor = tensor / std
    return tensor
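
# Illustrative example: for a zeroed uint8 clip in T H W C layout,
#   x = torch.zeros(8, 224, 224, 3, dtype=torch.uint8)
#   y = tensor_normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# every value in y[..., 0] is (0 - 0.485) / 0.229; mean and std broadcast
# over the trailing channel dimension, which is why channels must be last.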


def spatial_sampling(
    frames,
    spatial_idx=-1,
    min_scale=256,
    max_scale=320,
    crop_size=224,
    random_horizontal_flip=True,
    inverse_uniform_sampling=False,
    aspect_ratio=None,
    scale=None,
    motion_shift=False,
):
    """
    Perform spatial sampling on the given video frames. If spatial_idx is
    -1, perform random scale, random crop, and random flip on the given
    frames. If spatial_idx is 0, 1, or 2, perform uniform spatial sampling
    at the given spatial_idx.

    Args:
        frames (tensor): frames of images sampled from the video. The
            dimension is `channel` x `num frames` x `height` x `width`.
        spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
            or 2, perform left, center, right crop if width is larger than
            height, and perform top, center, bottom crop if height is larger
            than width.
        min_scale (int): the minimal size of scaling.
        max_scale (int): the maximal size of scaling.
        crop_size (int): the size of height and width used to crop the
            frames.
        random_horizontal_flip (bool): whether to randomly flip the frames
            horizontally (only used when spatial_idx is -1).
        inverse_uniform_sampling (bool): if True, sample uniformly in
            [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
            scale. If False, take a uniform sample from [min_scale,
            max_scale].
        aspect_ratio (list): aspect ratio range for resizing.
        scale (list): scale range for resizing.
        motion_shift (bool): whether to apply motion shift for resizing.

    Returns:
        frames (tensor): spatially sampled frames.
    """
    assert spatial_idx in [-1, 0, 1, 2]
    if spatial_idx == -1:
        if aspect_ratio is None and scale is None:
            frames, _ = video_transforms.random_short_side_scale_jitter(
                images=frames,
                min_size=min_scale,
                max_size=max_scale,
                inverse_uniform_sampling=inverse_uniform_sampling,
            )
            frames, _ = video_transforms.random_crop(frames, crop_size)
        else:
            transform_func = (
                video_transforms.random_resized_crop_with_shift
                if motion_shift
                else video_transforms.random_resized_crop
            )
            frames = transform_func(
                images=frames,
                target_height=crop_size,
                target_width=crop_size,
                scale=scale,
                ratio=aspect_ratio,
            )
        if random_horizontal_flip:
            frames, _ = video_transforms.horizontal_flip(0.5, frames)
    else:
        # Testing is deterministic and no jitter should be performed;
        # min_scale, max_scale, and crop_size are expected to be the same.
        assert len({min_scale, max_scale, crop_size}) == 1
        frames, _ = video_transforms.random_short_side_scale_jitter(
            frames, min_scale, max_scale
        )
        frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)
    return frames
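

if __name__ == "__main__":
    # Minimal smoke test (illustrative; not part of the original pipeline).
    # It assumes the local videoMAE_video_transforms and random_erasing
    # modules are importable and that the rand-augment policy string below
    # parses as a timm-style policy.
    clip = [
        torch.randint(0, 256, (240, 320, 3), dtype=torch.uint8).numpy()
        for _ in range(16)
    ]  # 16 frames, H x W x C uint8

    transform = aug(
        crop_size=224,
        aa='rand-m7-n4-mstd0.5-inc1',
        reprob=0.25,
        remode='pixel',
        recount=1,
    )
    out = transform(clip)
    print(out.shape)  # expected: torch.Size([16, 3, 224, 224])

    # Deterministic center crop for evaluation (spatial_idx == 1); the
    # assertion in spatial_sampling requires min_scale == max_scale ==
    # crop_size in this mode.
    eval_frames = spatial_sampling(
        out.permute(1, 0, 2, 3),  # T C H W -> C T H W
        spatial_idx=1,
        min_scale=224,
        max_scale=224,
        crop_size=224,
        random_horizontal_flip=False,
    )
    print(eval_frames.shape)  # expected: torch.Size([3, 16, 224, 224])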