# bn_restore.py

import typing
import torch
import torch.nn as nn
import copy
from tinynn.util.util import get_logger
log = get_logger(__name__)
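
# Parallel tuples: each supported conv type is paired (by index) with the
# BatchNorm type used to restore it (Conv1d -> BatchNorm1d, Conv2d -> BatchNorm2d).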
support_conv_cls = (torch.nn.Conv1d, torch.nn.Conv2d)
bn_restore_cls = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)


class ConvBnForward(nn.Module):
def __init__(self, conv):
super().__init__()
assert isinstance(conv, support_conv_cls), "not a supported conv type"
self.conv = conv
self.bn = bn_restore_cls[support_conv_cls.index(type(conv))](conv.out_channels)

    def forward(self, input_0):
        # Run BN in train mode but discard its output: the BN only accumulates
        # the running statistics of the conv output over the training set.
conv = self.conv(input_0)
self.bn(conv)
return conv


class ConvBnTrain(nn.Module):
def __init__(self, convbn):
super().__init__()
assert isinstance(convbn, ConvBnForward), "not a ConvBnForward"
self.conv = convbn.conv
self.bn = convbn.bn
        # Initialize the restored BN from the running statistics collected by
        # ConvBnForward; the weight calculation ignores eps.
weight = self.bn.running_var**0.5
bias = self.bn.running_mean
self.bn.weight.data.copy_(weight)
self.bn.bias.data.copy_(bias)
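        # With weight = sqrt(running_var) and bias = running_mean, this BN is
        # (up to eps) the identity at initialization:
        #   bn(x) = sqrt(running_var) * (x - running_mean) / sqrt(running_var + eps) + running_mean ~= x
        # so the restored model starts out numerically close to the fused one.
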
def forward(self, input_0):
conv = self.conv(input_0)
bn = self.bn(conv)
return bn


def restore_bn(origin_model, layer_fused_bn):
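    """Wraps each conv named in `layer_fused_bn` in a ConvBnForward (on a deep
    copy of the model) so that BN running statistics can be collected."""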
model = copy.deepcopy(origin_model)
for name, mod in model.named_modules():
if isinstance(mod, support_conv_cls) and name in layer_fused_bn:
mod, mod_parent, name_part = get_submodule_with_parent_from_name(model, name)
setattr(mod_parent, name_part, ConvBnForward(mod))
return model


def restore_bn_set_param(origin_model):
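    """Replaces each ConvBnForward (on a deep copy of the model) with a
    ConvBnTrain whose BN is initialized from the collected statistics."""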
model = copy.deepcopy(origin_model)
for name, mod in model.named_modules():
if isinstance(mod, ConvBnForward):
mod, mod_parent, name_part = get_submodule_with_parent_from_name(model, name)
setattr(mod_parent, name_part, ConvBnTrain(mod))
return model


def model_restore_bn(
    model: nn.Module, device, calibrate_func, *params, layers_fused_bn: typing.Optional[typing.List[str]] = None
) -> nn.Module:
r"""High API to restore BN for a bn_fused model(e.g.MobileOne),
the `calibrate_func` should be done in full train_set in train mode.
Args:
model (nn.Module): The model which need restore bn.
device (torch.device): Specifies the device of the model.
calibrate_func: The function used to do train_set calibrate in training mode.
params: The params of calibrate function except model.
layers_fused_bn: The name list of the conv which need to be bn_restored. Defaults to all conv of model.
Return:
The bn_restored model.
"""
log.info("start to do BN Restore.")
if layers_fused_bn is None:
layers_fused_bn = [name for name, mod in model.named_modules() if isinstance(mod, support_conv_cls)]
restore_bn_model = restore_bn(model, layers_fused_bn)
restore_bn_model = restore_bn_model.train()
restore_bn_model.to(device)
calibrate_func(restore_bn_model, *params)
bn_model = restore_bn_set_param(restore_bn_model)
return bn_model


def get_submodule_with_parent_from_name(model, module_name):
"""Gets the submodule with its parent and sub_name using the name given"""
module_name_parts = module_name.split('.')
cur_obj = model
last_obj = None
for ns in module_name_parts:
last_obj = cur_obj
if type(cur_obj) is nn.ModuleList:
cur_obj = cur_obj[int(ns)]
elif type(cur_obj) is nn.ModuleDict:
cur_obj = cur_obj[ns]
else:
cur_obj = getattr(cur_obj, ns)
return cur_obj, last_obj, module_name_parts[-1]
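

# A minimal usage sketch (not part of the original module): the toy model and
# the in-memory calibration "loader" below are illustrative placeholders only.
if __name__ == '__main__':
    toy_model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())

    def calibrate(model, loader):
        # Run the wrapped model over the calibration data in train mode so the
        # inserted BN layers accumulate running statistics (these buffers
        # update even under no_grad).
        with torch.no_grad():
            for batch in loader:
                model(batch)

    fake_loader = [torch.randn(4, 3, 32, 32) for _ in range(8)]
    restored = model_restore_bn(toy_model, torch.device('cpu'), calibrate, fake_loader)
    print(restored)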