binarized_modules.py
import torch
import pdb
import torch.nn as nn
import math
from torch.autograd import Variable
from torch.autograd.function import Function, InplaceFunction
import numpy as np


class Binarize(InplaceFunction):
    """Binarize a tensor to {-1, +1}, either deterministically or stochastically."""

    @staticmethod
    def forward(ctx, input, quant_mode='det', allow_scale=False, inplace=False):
        ctx.inplace = inplace
        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()
        scale = output.abs().max() if allow_scale else 1
        if quant_mode == 'det':
            # Deterministic mode: take the sign of the (optionally scaled) input.
            return output.div(scale).sign().mul(scale)
        else:
            # Stochastic mode: map values into [0, 1], add uniform noise in
            # [-0.5, 0.5), round, then map the result back to {-1, +1}.
            return output.div(scale).add_(1).div_(2).add_(torch.rand_like(output).add(-0.5)).clamp_(0, 1).round().mul_(2).add_(-1).mul(scale)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator (STE): pass the gradient through unchanged.
        grad_input = grad_output
        return grad_input, None, None, None


class Quantize(InplaceFunction):
    """Quantize a tensor to numBits bits, either deterministically or with added noise."""

    @staticmethod
    def forward(ctx, input, quant_mode='det', numBits=4, inplace=False):
        ctx.inplace = inplace
        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()
        # Scale the dynamic range onto the integer grid, then clamp and round.
        scale = (2**numBits - 1) / (output.max() - output.min())
        output = output.mul(scale).clamp(-2**(numBits - 1) + 1, 2**(numBits - 1))
        if quant_mode == 'det':
            output = output.round().div(scale)
        else:
            # Noisy variant: add uniform noise in [-0.5, 0.5) after rounding.
            output = output.round().add(torch.rand_like(output).add(-0.5)).div(scale)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator (STE): pass the gradient through unchanged.
        grad_input = grad_output
        return grad_input, None, None, None


def binarized(input, quant_mode='det'):
    return Binarize.apply(input, quant_mode)


def quantize(input, quant_mode, numBits):
    return Quantize.apply(input, quant_mode, numBits)


class HingeLoss(nn.Module):
    """Hinge loss for targets in {-1, +1}, averaged over all elements."""

    def __init__(self):
        super(HingeLoss, self).__init__()
        self.margin = 1.0

    def hinge_loss(self, input, target):
        # max(0, margin - input * target), averaged elementwise
        output = self.margin - input.mul(target)
        output[output.le(0)] = 0
        return output.mean()

    def forward(self, input, target):
        return self.hinge_loss(input, target)


class SqrtHingeLossFunction(Function):
    """Squared hinge loss, written against the legacy (instance-based) autograd Function API."""

    def __init__(self):
        super(SqrtHingeLossFunction, self).__init__()
        self.margin = 1.0

    def forward(self, input, target):
        output = self.margin - input.mul(target)
        output[output.le(0)] = 0
        self.save_for_backward(input, target)
        loss = output.mul(output).sum(0).sum(1).div(target.numel())
        return loss

    def backward(self, grad_output):
        input, target = self.saved_tensors
        output = self.margin - input.mul(target)
        output[output.le(0)] = 0
        # d/d(input) of (margin - input * target)^2 is -2 * target * (margin - input * target)
        grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)
        grad_output.mul_(output.ne(0).float())
        grad_output.div_(input.numel())
        return grad_output, grad_output


class BinarizeLinear(nn.Linear):
    """Linear layer with binarized weights and, when quant_in is set, binarized inputs."""

    def __init__(self, input_width, layer_width, quant_in):
        super(BinarizeLinear, self).__init__(input_width, layer_width)
        self.quant_in = quant_in

    def forward(self, input):
        if self.quant_in:
            input = binarized(input)
        weight_b = binarized(self.weight)
        out = nn.functional.linear(input, weight_b)
        if self.bias is not None:
            # Keep a copy of the bias under .org; binarized-training loops
            # typically restore parameters from this copy after the update.
            self.bias.org = self.bias.data.clone()
            out += self.bias.view(1, -1).expand_as(out)
        return out


class BinarizeConv2d(nn.Conv2d):
    """Conv2d layer with binarized weights; activations are binarized except for 3-channel inputs."""

    def __init__(self, *kargs, **kwargs):
        super(BinarizeConv2d, self).__init__(*kargs, **kwargs)

    def forward(self, input):
        if input.size(1) != 3:
            # Binarize activations, but leave the raw 3-channel (image) input alone.
            input_b = binarized(input)
        else:
            input_b = input
        weight_b = binarized(self.weight)
        out = nn.functional.conv2d(input_b, weight_b, None, self.stride,
                                   self.padding, self.dilation, self.groups)
        if self.bias is not None:
            # Keep a copy of the bias under .org, as in BinarizeLinear above.
            self.bias.org = self.bias.data.clone()
            out += self.bias.view(1, -1, 1, 1).expand_as(out)
        return out
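

# A minimal usage sketch, kept under a __main__ guard so it only runs when this
# file is executed directly. The layer sizes, batch size, and random +/-1 targets
# below are illustrative assumptions, not values taken from this module; the sketch
# just shows how a binarized layer and the hinge loss fit together, with gradients
# flowing through the straight-through estimator in Binarize.
if __name__ == "__main__":
    torch.manual_seed(0)

    x = torch.randn(8, 16)                        # batch of 8 full-precision inputs
    layer = BinarizeLinear(16, 4, quant_in=True)  # binarizes both inputs and weights
    out = layer(x)

    target = torch.randint(0, 2, (8, 4)).float().mul_(2).sub_(1)  # random targets in {-1, +1}
    loss = HingeLoss()(out, target)
    loss.backward()
    print(out.shape, loss.item())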