forked from singh-hrituraj/PixelCNN-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path MaskedCNN.py
37 lines (26 loc) · 859 Bytes
/
MaskedCNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
'''
Code by Hrituraj Singh
Indian Institute of Technology Roorkee
'''
from torch import nn
class MaskedCNN(nn.Conv2d):
    """
    Masked 2-D convolution as described in van den Oord et al. (PixelCNN).

    Taken from https://github.com/jzbontar/pixelcnn-pytorch

    Every output channel / input channel kernel slice is multiplied by a
    fixed binary mask that hides the kernel taps to the right of the centre
    tap in the centre row, and every tap in the rows below it:

    * mask type ``'A'`` also hides the centre tap itself;
    * mask type ``'B'`` keeps the centre tap.

    When the kernel is centred on the output position (odd kernel size with
    matching padding — assumed, not enforced here), each output therefore
    only sees inputs that precede it in raster order.

    Args:
        mask_type: ``'A'`` or ``'B'``.
        *args, **kwargs: forwarded unchanged to :class:`torch.nn.Conv2d`.

    Raises:
        ValueError: if ``mask_type`` is not ``'A'`` or ``'B'``.
    """

    def __init__(self, mask_type: str, *args, **kwargs):
        # Validate with an explicit raise rather than ``assert`` so the check
        # survives ``python -O``; otherwise an unknown type would silently
        # fall through to the type-B mask.
        if mask_type not in ('A', 'B'):
            raise ValueError(f"Unknown Mask Type: {mask_type!r}")
        super().__init__(*args, **kwargs)
        self.mask_type = mask_type

        # nn.Conv2d stores its kernel as (out_channels, in_channels // groups,
        # kH, kW); only the two trailing spatial dims are masked.
        _, _, height, width = self.weight.size()
        # Same dtype/device as the weight, detached so it is never a Parameter.
        mask = self.weight.detach().new_ones(self.weight.shape)
        center_row, center_col = height // 2, width // 2
        # Type 'A' hides the centre tap and everything right of it;
        # type 'B' keeps the centre tap and hides only what is right of it.
        first_hidden_col = center_col if mask_type == 'A' else center_col + 1
        mask[:, :, center_row, first_hidden_col:] = 0
        # Every row below the centre row is hidden for both mask types.
        mask[:, :, center_row + 1:, :] = 0
        # Registered as a buffer so it follows .to()/.cuda()/.half() and is
        # saved in the state_dict alongside the weight.
        self.register_buffer('mask', mask)

        # Apply the mask immediately so ``self.weight`` already satisfies it
        # before the first forward pass (e.g. when inspected or saved).
        self.weight.data *= self.mask

    def forward(self, x):
        """Zero the hidden kernel taps in place, then run the convolution."""
        # The mask must be re-applied on every call: the convolution itself is
        # unaware of the mask, so the hidden taps still receive gradients and
        # an optimizer step moves them off zero.
        self.weight.data *= self.mask
        return super().forward(x)