forked from Algolzw/WeatherClassification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cutmix.py
52 lines (35 loc) · 1.28 KB
/
cutmix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import torch
import torch.nn as nn
def cutmix(batch, alpha):
    """Apply CutMix to a collated ``(data, targets)`` batch.

    A random permutation of the batch supplies a "donor" image for every
    sample. A box covering roughly ``(1 - lam)`` of the image area, with
    ``lam ~ Beta(alpha, alpha)``, is cut from each donor and pasted onto the
    corresponding image at the same location.

    Args:
        batch: ``(data, targets)`` pair. ``data`` is an image tensor whose last
            two dimensions are height and width (e.g. ``(N, C, H, W)``).
            ``targets`` is indexed with the same permutation as ``data`` along
            the first dimension.
        alpha: Beta-distribution concentration parameter. ``np.random.beta``
            raises ``ValueError`` unless it is > 0.

    Returns:
        ``(data, (targets, shuffled_targets, lam))``. ``lam`` is the fraction
        of each image's pixels that still come from the original image, which
        is the weight for ``targets``. ``data`` is modified in place and also
        returned.
    """
    data, targets = batch

    indices = torch.randperm(data.size(0))
    # Advanced indexing makes a copy, so the in-place paste below cannot
    # corrupt the donor images.
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]

    lam = np.random.beta(alpha, alpha)

    image_h, image_w = data.shape[-2:]
    cx = np.random.uniform(0, image_w)
    cy = np.random.uniform(0, image_h)
    w = image_w * np.sqrt(1 - lam)
    h = image_h * np.sqrt(1 - lam)
    x0 = int(np.round(max(cx - w / 2, 0)))
    x1 = int(np.round(min(cx + w / 2, image_w)))
    y0 = int(np.round(max(cy - h / 2, 0)))
    y1 = int(np.round(min(cy + h / 2, image_h)))

    data[..., y0:y1, x0:x1] = shuffled_data[..., y0:y1, x0:x1]

    # The box was clipped at the image border and its corners rounded, so the
    # pasted area generally differs from (1 - lam) * H * W. Recompute lam from
    # the area actually pasted so the label mix matches the pixel mix.
    lam = 1.0 - ((x1 - x0) * (y1 - y0)) / float(image_w * image_h)

    targets = (targets, shuffled_targets, lam)
    return data, targets
class CutMixCollator:
    """Collate a list of samples, then apply ``cutmix`` to the stacked batch.

    The samples are first combined with PyTorch's default collation. The
    resulting batch is then passed to ``cutmix`` together with the ``alpha``
    stored at construction time.
    """

    def __init__(self, alpha):
        # Beta-distribution parameter forwarded to ``cutmix`` on every call.
        self.alpha = alpha

    def __call__(self, batch):
        collated = torch.utils.data.dataloader.default_collate(batch)
        mixed = cutmix(collated, self.alpha)
        return mixed
class CutMixCriterion:
    """Blend a base loss over the two target sets produced by ``cutmix``.

    ``targets`` is unpacked as ``(original, donor, lam)``. The result is
    ``lam * criterion(preds, original) + (1 - lam) * criterion(preds, donor)``.
    """

    def __init__(self, criterion):
        # Base loss callable, invoked as ``criterion(preds, targets)``.
        self.criterion = criterion

    def __call__(self, preds, targets):
        original, donor, lam = targets
        loss_original = self.criterion(preds, original)
        loss_donor = self.criterion(preds, donor)
        return lam * loss_original + (1 - lam) * loss_donor