-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathpatch_swd.py
64 lines (52 loc) · 2.15 KB
/
patch_swd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn.functional as F
class PatchSWDLoss(torch.nn.Module):
    """Sliced Wasserstein Distance between the patch distributions of two images.

    Patches are projected onto random unit directions implicitly, via a conv2d
    whose filters are the random directions; the resulting 1D projected
    distributions are then compared by sorted (1D optimal transport) matching.

    :param patch_size: side length of the square patches
    :param stride: stride used when sliding the patch window over the image
    :param num_proj: number of random 1D projection directions
    :param c: number of input channels expected in x and y
    :param l2: if True, compare sorted projections with mean squared error;
               otherwise (default) use mean absolute error
    """

    def __init__(self, patch_size=7, stride=1, num_proj=256, c=3, l2=False):
        super(PatchSWDLoss, self).__init__()
        self.patch_size = patch_size
        self.stride = stride
        self.num_proj = num_proj
        self.l2 = l2
        self.c = c
        self.sample_projections()

    def sample_projections(self):
        # Sample random directions on the unit sphere in patch space, then
        # reshape them into conv filters of shape (num_proj, c, ps, ps).
        rand = torch.randn(self.num_proj, self.c * self.patch_size ** 2)
        rand = rand / torch.norm(rand, dim=1, keepdim=True)  # normalize to unit directions
        self.rand = rand.reshape(self.num_proj, self.c, self.patch_size, self.patch_size)

    def forward(self, x, y, reset_projections=True):
        """Compute the patch-SWD between image batches x and y (N, c, H, W).

        :param reset_projections: resample the random directions each call
                                  (gives an unbiased SWD estimate over calls)
        """
        if reset_projections:
            self.sample_projections()
        self.rand = self.rand.to(x.device)

        # Project every patch of each image onto the random directions.
        # BUG FIX: self.stride was stored but never used — pass it to conv2d
        # (default stride=1 keeps previous behavior identical).
        projx = F.conv2d(x, self.rand, stride=self.stride).transpose(1, 0).reshape(self.num_proj, -1)
        projy = F.conv2d(y, self.rand, stride=self.stride).transpose(1, 0).reshape(self.num_proj, -1)

        # Images of different sizes yield different patch counts; pad the
        # smaller set by duplicating randomly chosen patches.
        projx, projy = duplicate_to_match_lengths(projx, projy)

        # 1D optimal transport: sort each projected distribution and match.
        projx, _ = torch.sort(projx, dim=1)
        projy, _ = torch.sort(projy, dim=1)

        if self.l2:
            loss = ((projx - projy) ** 2).mean()
        else:
            loss = torch.abs(projx - projy).mean()
        return loss
def duplicate_to_match_lengths(arr1, arr2):
    """Pad the narrower of two (r, *) tensors with duplicated columns so both
    have the same number of columns.

    The smaller tensor is first tiled whole as many times as fits, then topped
    up with randomly chosen columns of itself.

    BUG FIX: the original implementation swapped the two arrays when arr1 was
    the narrower one and returned them in swapped order; the return order now
    always matches the argument order.

    :param arr1: (r, n) torch tensor
    :param arr2: (r, m) torch tensor
    :return: (arr1', arr2'), each of shape (r, max(n, m)), in argument order
    """
    if arr1.shape[1] == arr2.shape[1]:
        return arr1, arr2

    # Work on (big, small) internally, but remember the original order.
    swapped = arr1.shape[1] < arr2.shape[1]
    big, small = (arr2, arr1) if swapped else (arr1, arr2)

    # Tile the small tensor wholesale as many times as it fits.
    reps = big.shape[1] // small.shape[1]
    small = torch.cat([small] * reps, dim=1)

    # Top up the remainder with randomly selected columns of the tiled tensor.
    deficit = big.shape[1] - small.shape[1]
    if deficit > 0:
        indices = torch.randperm(small.shape[1])[:deficit]
        small = torch.cat([small, small[:, indices]], dim=1)

    return (small, big) if swapped else (big, small)