diff --git a/README.md b/README.md
index 5af8c4b..233c5ad 100644
--- a/README.md
+++ b/README.md
@@ -46,15 +46,15 @@ python test.py [imagenet-folder with train and val folders] train RepVGG-B2-trai
 You may convert a trained model into the inference-time structure with
 ```
-python convert.py [weights file of the training-time model to load] [path to save] -a [model name]
-```
-For example,
-```
-python convert.py RepVGG-B2-train.pth RepVGG-B2-deploy.pth -a RepVGG-B2
+model = get_RepVGG_func_by_name("RepVGG-A0")()
+model.eval()
+model.deploy(mode=True, show_error=False)
+
+# Optional ONNX export (test_in is an example input tensor): torch.onnx.export(model, test_in, "model.onnx")
 ```
-Then you may test the inference-time model by
+You may test the inference-time model by
 ```
-python test.py [imagenet-folder with train and val folders] deploy RepVGG-B2-deploy.pth -a RepVGG-B2
+python test.py [imagenet-folder with train and val folders] deploy RepVGG-B2-train.pth -a RepVGG-B2
 ```
 Note that the argument "deploy" builds an inference-time model.
diff --git a/repvgg.py b/repvgg.py
index 9cd2668..9de08c9 100644
--- a/repvgg.py
+++ b/repvgg.py
@@ -2,6 +2,16 @@ import numpy as np
 import torch
 
+
+def deploy(self, mode=False, show_error=False):
+    self.deploying = mode
+    for module in self.children():
+        if hasattr(module, 'deploying'):
+            module.deploy(mode, show_error)
+nn.Sequential.deploying = False
+nn.Sequential.show_error = False
+nn.Sequential.deploy = deploy
+
 def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
     result = nn.Sequential()
     result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
@@ -12,9 +22,11 @@ def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
 class RepVGGBlock(nn.Module):
 
     def __init__(self, in_channels, out_channels, kernel_size,
-                 stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False):
+                 stride=1, padding=0, groups=1):
         super(RepVGGBlock, self).__init__()
-        self.deploy = deploy
+        self.deploying = False
+        self.show_error = False
+        self.groups = groups
         self.in_channels = in_channels
@@ -25,29 +37,44 @@ def __init__(self, in_channels, out_channels, kernel_size,
         self.nonlinearity = nn.ReLU()
 
-        if deploy:
-            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
-                                         padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
+        self.in_channels = in_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.groups = groups
+
+        self.register_buffer('fused_weight', None)
+        self.register_buffer('fused_bias', None)
-        else:
-            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
-            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
-            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
-            print('RepVGG Block, identity = ', self.rbr_identity)
+        self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
+        self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
+        self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
+        print('RepVGG Block, identity = ', self.rbr_identity)
 
     def forward(self, inputs):
-        if hasattr(self, 'rbr_reparam'):
-            return self.nonlinearity(self.rbr_reparam(inputs))
+        if self.deploying:
+            if self.fused_weight is None or self.fused_bias is None:
+                self.fused_weight, self.fused_bias = self.get_equivalent_kernel_bias()
+
+            fused_out = self.nonlinearity(torch.nn.functional.conv2d(
+                inputs, self.fused_weight, self.fused_bias, self.stride, self.padding, 1, self.groups))
+
+            if not self.show_error:
+                return fused_out
 
         if self.rbr_identity is None:
             id_out = 0
         else:
             id_out = self.rbr_identity(inputs)
+        out = self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
 
-        return self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
-
+        if self.deploying and self.show_error:
+            print(torch.max(torch.abs(fused_out - out)).item())
+            return fused_out
+        return out
 
     # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
     # You can get the equivalent kernel and bias at any time and do whatever you want,
@@ -93,27 +120,28 @@ def _fuse_bn_tensor(self, branch):
         t = (gamma / std).reshape(-1, 1, 1, 1)
         return kernel * t, beta - running_mean * gamma / std
 
-    def repvgg_convert(self):
-        kernel, bias = self.get_equivalent_kernel_bias()
-        return kernel.detach().cpu().numpy(), bias.detach().cpu().numpy(),
-
+    def deploy(self, mode=False, show_error=False):
+        self.deploying = mode
+        self.show_error = show_error
 
 class RepVGG(nn.Module):
 
-    def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False):
+    def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None):
         super(RepVGG, self).__init__()
 
         assert len(width_multiplier) == 4
 
-        self.deploy = deploy
+        self.deploying = False
+        self.show_error = False
+
         self.override_groups_map = override_groups_map or dict()
 
         assert 0 not in self.override_groups_map
 
         self.in_planes = min(64, int(64 * width_multiplier[0]))
 
-        self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy)
+        self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1)
         self.cur_layer_idx = 1
         self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=2)
         self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2)
@@ -129,11 +157,18 @@ def _make_stage(self, planes, num_blocks, stride):
         for stride in strides:
             cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
             blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
-                                      stride=stride, padding=1, groups=cur_groups, deploy=self.deploy))
+                                      stride=stride, padding=1, groups=cur_groups))
             self.in_planes = planes
             self.cur_layer_idx += 1
         return nn.Sequential(*blocks)
 
+    def deploy(self, mode=False, show_error=False):
+        self.deploying = mode
+        for module in self.children():
+            if hasattr(module, 'deploying'):
+                module.deploy(mode, show_error)
+
+
     def forward(self, x):
         out = self.stage0(x)
         out = self.stage1(out)
@@ -146,63 +181,65 @@ def forward(self, x):
         return out
 
+
+
 optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
 g2_map = {l: 2 for l in optional_groupwise_layers}
 g4_map = {l: 4 for l in optional_groupwise_layers}
 
-def create_RepVGG_A0(deploy=False):
+def create_RepVGG_A0():
     return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
-                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None)
 
-def create_RepVGG_A1(deploy=False):
+def create_RepVGG_A1():
     return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
-                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None)
 
-def create_RepVGG_A2(deploy=False):
+def create_RepVGG_A2():
     return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
-                  width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None)
 
-def create_RepVGG_B0(deploy=False):
+def create_RepVGG_B0():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None)
 
-def create_RepVGG_B1(deploy=False):
+def create_RepVGG_B1():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[2, 2, 2, 4], override_groups_map=None)
 
-def create_RepVGG_B1g2(deploy=False):
+def create_RepVGG_B1g2():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy)
+                  width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map)
 
-def create_RepVGG_B1g4(deploy=False):
+def create_RepVGG_B1g4():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy)
+                  width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map)
 
-def create_RepVGG_B2(deploy=False):
+def create_RepVGG_B2():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None)
 
-def create_RepVGG_B2g2(deploy=False):
+def create_RepVGG_B2g2():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy)
+                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map)
 
-def create_RepVGG_B2g4(deploy=False):
+def create_RepVGG_B2g4():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy)
+                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map)
 
-def create_RepVGG_B3(deploy=False):
+def create_RepVGG_B3():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[3, 3, 3, 5], override_groups_map=None)
 
-def create_RepVGG_B3g2(deploy=False):
+def create_RepVGG_B3g2():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy)
+                  width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map)
 
-def create_RepVGG_B3g4(deploy=False):
+def create_RepVGG_B3g4():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy)
+                  width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map)
 
 
 func_dict = {
@@ -224,69 +261,18 @@ def get_RepVGG_func_by_name(name):
     return func_dict[name]
 
+if __name__ == '__main__':
+    model = get_RepVGG_func_by_name("RepVGG-A0")()
+    model.eval()
+
+    test_in = torch.rand([1, 3, 224, 224])
+
+    y = model(test_in)
+    model.deploy(mode=True, show_error=False)
+
+    fused_y = model(test_in)
+
+    print("final error :", torch.max(torch.abs(fused_y - y)).item())
+
+    torch.onnx.export(model, test_in, "model.onnx")
-# Use this for converting a customized model with RepVGG as one of its components (e.g., the backbone of a semantic segmentation model)
-# The use case will be like
-# 1. Build train_model. For example, build a PSPNet with a training-time RepVGG as backbone
-# 2. Train train_model or do whatever you want
-# 3. Build deploy_model. In the above example, that will be a PSPNet with an inference-time RepVGG as backbone
-# 4. Call this func
-# ====================== the pseudo code will be like
-# train_backbone = create_RepVGG_B2(deploy=False)
-# train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
-# train_pspnet = build_pspnet(backbone=train_backbone)
-# segmentation_train(train_pspnet)
-# deploy_backbone = create_RepVGG_B2(deploy=True)
-# deploy_pspnet = build_pspnet(backbone=deploy_backbone)
-# whole_model_convert(train_pspnet, deploy_pspnet)
-# segmentation_test(deploy_pspnet)
-def whole_model_convert(train_model:torch.nn.Module, deploy_model:torch.nn.Module, save_path=None):
-    all_weights = {}
-    for name, module in train_model.named_modules():
-        if hasattr(module, 'repvgg_convert'):
-            kernel, bias = module.repvgg_convert()
-            all_weights[name + '.rbr_reparam.weight'] = kernel
-            all_weights[name + '.rbr_reparam.bias'] = bias
-            print('convert RepVGG block')
-        else:
-            for p_name, p_tensor in module.named_parameters():
-                full_name = name + '.' + p_name
-                if full_name not in all_weights:
-                    all_weights[full_name] = p_tensor.detach().cpu().numpy()
-            for p_name, p_tensor in module.named_buffers():
-                full_name = name + '.' + p_name
-                if full_name not in all_weights:
-                    all_weights[full_name] = p_tensor.cpu().numpy()
-
-    deploy_model.load_state_dict(all_weights)
-    if save_path is not None:
-        torch.save(deploy_model.state_dict(), save_path)
-
-    return deploy_model
-
-
-# Use this when converting a RepVGG without customized structures.
-# train_model = create_RepVGG_A0(deploy=False)
-# train train_model
-# deploy_model = repvgg_convert(train_model, create_RepVGG_A0, save_path='repvgg_deploy.pth')
-def repvgg_model_convert(model:torch.nn.Module, build_func, save_path=None):
-    converted_weights = {}
-    for name, module in model.named_modules():
-        if hasattr(module, 'repvgg_convert'):
-            kernel, bias = module.repvgg_convert()
-            converted_weights[name + '.rbr_reparam.weight'] = kernel
-            converted_weights[name + '.rbr_reparam.bias'] = bias
-        elif isinstance(module, torch.nn.Linear):
-            converted_weights[name + '.weight'] = module.weight.detach().cpu().numpy()
-            converted_weights[name + '.bias'] = module.bias.detach().cpu().numpy()
-    del model
-
-    deploy_model = build_func(deploy=True)
-    for name, param in deploy_model.named_parameters():
-        print('deploy param: ', name, param.size(), np.mean(converted_weights[name]))
-        param.data = torch.from_numpy(converted_weights[name]).float()
-
-    if save_path is not None:
-        torch.save(deploy_model.state_dict(), save_path)
-
-    return deploy_model
\ No newline at end of file
diff --git a/test.py b/test.py
index 90ec7be..0dea949 100644
--- a/test.py
+++ b/test.py
@@ -30,7 +30,10 @@ def test():
     repvgg_build_func = get_RepVGG_func_by_name(args.arch)
 
-    model = repvgg_build_func(deploy=args.mode=='deploy')
+    model = repvgg_build_func()
+
+    model.deploy(args.mode == 'deploy')
+
     if not torch.cuda.is_available():
         print('using CPU, this will be slow')
diff --git a/train.py b/train.py
index 458ac5a..256cfd0 100644
--- a/train.py
+++ b/train.py
@@ -145,7 +145,8 @@ def main_worker(gpu, ngpus_per_node, args):
     repvgg_build_func = get_RepVGG_func_by_name(args.arch)
 
-    model = repvgg_build_func(deploy=False)
+    model = repvgg_build_func()
+    model.deploy(mode=False)
 
     if not torch.cuda.is_available():
         print('using CPU, this will be slow')
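
Usage note: the sketch below shows how the runtime `deploy()` call introduced by this patch replaces the old `convert.py` / `repvgg_model_convert` workflow. It is a minimal sketch assuming the patched `repvgg.py` above; the checkpoint name `RepVGG-B2-train.pth`, the input shape, and the ONNX output path are illustrative, not part of the patch.

```python
import torch
from repvgg import get_RepVGG_func_by_name

# Build the single (training-time) model definition and load trained weights.
# NOTE: the checkpoint name is illustrative; a checkpoint saved by train.py may be
# wrapped (e.g. {'state_dict': ...} with 'module.' prefixes) and need unwrapping first.
model = get_RepVGG_func_by_name("RepVGG-B2")()
model.load_state_dict(torch.load("RepVGG-B2-train.pth", map_location="cpu"))
model.eval()

x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    y_train = model(x)                          # multi-branch (training-time) forward
    model.deploy(mode=True, show_error=True)    # switch every block to its fused form
    y_deploy = model(x)                         # fused forward; also prints per-block error

print("max abs difference:", torch.max(torch.abs(y_deploy - y_train)).item())

# For export, turn the per-block error check off so only the fused path is traced.
model.deploy(mode=True, show_error=False)
torch.onnx.export(model, x, "repvgg_b2_deploy.onnx")
```

With `show_error=True`, each block additionally prints the maximum absolute difference between its fused and multi-branch outputs, which is a convenient sanity check before exporting.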
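For the customized-model use case that the removed `whole_model_convert` helper covered (a RepVGG used as a component, e.g. the backbone of a PSPNet), the same in-place switch applies: there is no longer a second deploy-time model to build or any weight copying. A pseudo-code sketch mirroring the removed comments, where `build_pspnet`, `segmentation_train`, and `segmentation_test` are placeholders and not functions in this repo:

```python
# train_backbone = create_RepVGG_B2()
# train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
# train_pspnet = build_pspnet(backbone=train_backbone)
# segmentation_train(train_pspnet)
# # instead of building a deploy-time PSPNet and calling whole_model_convert:
# train_backbone.deploy(mode=True)
# segmentation_test(train_pspnet)
```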