min_max_quantization.py
import math

import torch
from torch.autograd import Function


class RoundFunction(Function):
    """Rounding with a straight-through estimator: the backward pass treats
    rounding as the identity so gradients can flow through quantization."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


def min_max_quantize(x, k):
    """Uniformly quantize x onto 2**k levels spanning [min(x), max(x)]."""
    n = 2 ** k
    a = torch.min(x)
    b = torch.max(x)
    s = (b - a) / (n - 1)       # step size between adjacent levels
    x = torch.clamp(x, float(a), float(b))
    x = (x - a) / s             # rescale to [0, n - 1]
    x = RoundFunction.apply(x)  # snap to the nearest integer level
    x = x * s + a               # map back to the original range
    return x


def min_max_quantize2(input, bits):
    """Same min-max quantization written with floor(v + 0.5), which rounds
    ties up, whereas torch.round rounds ties to even."""
    assert bits >= 1, bits
    if bits == 1:
        # 1-bit branch: binarize via sign; note the -1 shift yields
        # levels {-2, 0} (and -1 for exact zeros) rather than {-1, 1}.
        return torch.sign(input) - 1
    min_val, max_val = input.min(), input.max()
    input_rescale = (input - min_val) / (max_val - min_val)
    n = math.pow(2.0, bits) - 1
    v = torch.floor(input_rescale * n + 0.5) / n
    v = v * (max_val - min_val) + min_val
    return v


if __name__ == "__main__":
    x = torch.Tensor([-4, 0.222, 0.5, 0.489, 11, 1])
    b = min_max_quantize(x, 2)
    # With k=2 the levels are {-4, 1, 6, 11}; prints tensor([-4., 1., 1., 1., 11., 1.])
    print(b)
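
As a quick sanity check (an addition, not part of the original file, assuming it is importable as min_max_quantization), the two functions can be compared on the same tensor; they produce the same levels here and differ only when a rescaled value lands exactly on a rounding midpoint:

import torch
from min_max_quantization import min_max_quantize, min_max_quantize2

x = torch.Tensor([-4, 0.222, 0.5, 0.489, 11, 1])
print(min_max_quantize(x, 2))   # tensor([-4., 1., 1., 1., 11., 1.])
print(min_max_quantize2(x, 2))  # same levels {-4, 1, 6, 11}, up to float rounding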