-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_qat.py
140 lines (115 loc) · 4.96 KB
/
test_qat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#!/usr/bin/env python3
###################################################################################################
#
# Copyright (C) 2020-2021 Maxim Integrated Products, Inc. All Rights Reserved.
#
# Maxim Integrated Products, Inc. Default Copyright Notice:
# https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
#
###################################################################################################
"""
Test routine for QAT
"""
import copy
import torch
import ai8x
def create_input_data(num_channels):
    '''
    Build one random 8x8 input with `num_channels` channels, in two matching forms.

    Returns a tuple `(inp, inp_int)`:
      * `inp_int` - integer-valued float tensor of shape (1, num_channels, 8, 8),
        obtained by scaling uniform samples from [-1, 1) by 128, rounding, and
        clamping to [-128, 127].
      * `inp` - the same data divided by 128.
    '''
    uniform = torch.rand(1, num_channels, 8, 8)  # pylint: disable=no-member
    scaled = 128 * (2.0 * uniform - 1.0)
    # Clamping matters for samples that round up to +128, which is outside the range.
    inp_int = torch.round(scaled).clamp(min=-128, max=127.)  # pylint: disable=no-member
    return inp_int / 128., inp_int
def create_conv2d_layer(in_channels, out_channels, kernel_size, wide, activation):
    '''
    Build a bias-free `ai8x.Conv2d` layer with uniform random weights in [-1, 1).

    The device is set to 85 with simulation disabled before the layer is constructed.
    `kernel_size` is used for both trailing dimensions of the weight tensor, whose
    shape is (out_channels, in_channels, kernel_size, kernel_size).
    '''
    ai8x.set_device(device=85, simulate=False, round_avg=False, verbose=False)

    layer_args = {
        'in_channels': in_channels,
        'out_channels': out_channels,
        'kernel_size': kernel_size,
        'bias': False,
        'wide': wide,
        'activation': activation,
    }
    fp_layer = ai8x.Conv2d(**layer_args)

    # Replace whatever weights the constructor produced with fresh random ones.
    weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
    random_weight = 2.0 * torch.rand(*weight_shape) - 1.0  # pylint: disable=no-member
    fp_layer.op.weight = torch.nn.Parameter(random_weight)
    return fp_layer
def quantize_fp_layer(fp_layer, wide, activation, num_bits):
    '''
    Build a quantization-configured copy of a floating-point layer.

    A new bias-free `ai8x.Conv2d` is created with the channel counts and kernel size
    read from the weight shape of `fp_layer`, using `num_bits` weight bits, 8 bias bits
    and `quantize_activation=True`. Its weight is a deep copy of the weight of
    `fp_layer`, so the source layer is not shared with the returned one.

    The device is set to 85 with simulation disabled before the layer is constructed.
    '''
    ai8x.set_device(device=85, simulate=False, round_avg=False, verbose=False)

    src_weight = fp_layer.op.weight
    out_channels, in_channels = src_weight.shape[:2]

    q_fp_layer = ai8x.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=src_weight.shape[2:],
        bias=False,
        wide=wide,
        activation=activation,
        weight_bits=num_bits,
        bias_bits=8,
        quantize_activation=True,
    )
    q_fp_layer.op.weight = copy.deepcopy(src_weight)
    return q_fp_layer
def quantize_layer(q_fp_layer, wide, activation, num_bits):
    '''
    Build the integer-weight counterpart of `q_fp_layer`.

    Steps, in the order they are executed:
      1. With the device in simulation mode, construct a new bias-free `ai8x.Conv2d`
         with the same shape and quantization settings as `q_fp_layer`.
      2. Ask `q_fp_layer` for its output shift and the corresponding weight scale.
      3. With simulation disabled, scale, quantize and clamp the weights of
         `q_fp_layer`, then multiply by 2**(num_bits - 1).
      4. Store `-log2(weight_scale)` as the new layer's output shift and the scaled
         weights as its weight.

    NOTE(review): the `set_device` calls change global ai8x state, so the ordering of
    the statements below is presumably significant - do not reorder without checking.
    '''
    ai8x.set_device(device=85, simulate=True, round_avg=False, verbose=False)

    src_shape = q_fp_layer.op.weight.shape
    in_channels = src_shape[1]
    out_channels = src_shape[0]
    kernel_size = src_shape[2:]

    q_int_layer = ai8x.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        bias=False,
        wide=wide,
        activation=activation,
        weight_bits=num_bits,
        bias_bits=8,
        quantize_activation=True,
    )

    detached_weight = q_fp_layer.op.weight.detach()
    detached_shift = q_fp_layer.output_shift.detach()
    out_shift = q_fp_layer.calc_out_shift(detached_weight, detached_shift)
    weight_scale = q_fp_layer.calc_weight_scale(out_shift)

    ai8x.set_device(device=85, simulate=False, round_avg=False, verbose=False)

    scaled_weight = weight_scale * q_fp_layer.op.weight
    quantized_weight = q_fp_layer.clamp_weight(q_fp_layer.quantize_weight(scaled_weight))
    levels = 2 ** (num_bits - 1)
    q_int_weight = levels * quantized_weight

    neg_log_scale = -torch.log2(weight_scale)  # pylint: disable=no-member
    q_int_layer.output_shift = torch.nn.Parameter(neg_log_scale)
    q_int_layer.op.weight = torch.nn.Parameter(q_int_weight)
    return q_int_layer
def test():
    '''
    Compare the quantized floating-point layer against its integer counterpart.

    A single random 512-channel input is generated, then for every combination of
    weight bits (8, 4, 2, 1), activation (None, 'ReLU') and wide mode (False, True)
    a 512 -> 16 channel 3x3 layer is created and quantized both ways. Combinations
    with wide mode and an activation together are skipped. The run fails with an
    AssertionError unless 128 times the floating-point output equals the simulated
    integer output element for element.
    '''
    inp, inp_int = create_input_data(512)

    # Same visiting order as nested bit / activation / wide loops.
    configs = [
        (bit, act, wide)
        for bit in [8, 4, 2, 1]
        for act in [None, 'ReLU']
        for wide in [False, True]
        if not wide or act is None
    ]

    for bit, act, wide in configs:
        print(f'Testing for bits:{bit}, wide:{wide}, activation:{act} ...', end=' ')

        fp_layer = create_conv2d_layer(512, 16, 3, wide, act)
        q_fp_layer = quantize_fp_layer(fp_layer, wide, act, bit)
        q_int_layer = quantize_layer(q_fp_layer, wide, act, bit)

        # Floating-point path runs with simulation off, integer path with it on.
        ai8x.set_device(device=85, simulate=False, round_avg=False, verbose=False)
        q_fp_out = q_fp_layer(inp)

        ai8x.set_device(device=85, simulate=True, round_avg=False, verbose=False)
        q_int_out = q_int_layer(inp_int)

        outputs_match = ((128. * q_fp_out) == q_int_out).all()
        assert outputs_match, 'FAIL!!'
        print('PASS')

    print('\nSUCCESS!!')
# Run the full QAT comparison when this file is executed as a script.
if __name__ == "__main__":
    test()