An easy way to deploy #26

Open · wants to merge 1 commit into main
14 changes: 7 additions & 7 deletions README.md
@@ -46,15 +46,15 @@ python test.py [imagenet-folder with train and val folders] train RepVGG-B2-train.pth -a RepVGG-B2

 You may convert a trained model into the inference-time structure with
 ```
-python convert.py [weights file of the training-time model to load] [path to save] -a [model name]
-```
-For example,
-```
-python convert.py RepVGG-B2-train.pth RepVGG-B2-deploy.pth -a RepVGG-B2
+model = get_RepVGG_func_by_name("RepVGG-A0")()
+model.eval()
+model.deploy(mode=True, show_error=False)
+
+### torch.onnx.export(model, test_in, "model.onnx")
 ```
-Then you may test the inference-time model by
+You may test the inference-time model by
 ```
-python test.py [imagenet-folder with train and val folders] deploy RepVGG-B2-deploy.pth -a RepVGG-B2
+python test.py [imagenet-folder with train and val folders] deploy RepVGG-B2-train.pth -a RepVGG-B2
 ```
 Note that the argument "deploy" builds an inference-time model.
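With this change, the separate `convert.py` step becomes a single in-place `deploy()` call. A minimal end-to-end sketch of the new workflow (assuming the `repvgg.py` from this PR is importable; the checkpoint name is illustrative):

```
import torch
from repvgg import get_RepVGG_func_by_name

# Build the training-time model and load trained weights (illustrative checkpoint).
model = get_RepVGG_func_by_name("RepVGG-B2")()
model.load_state_dict(torch.load("RepVGG-B2-train.pth"))
model.eval()

test_in = torch.rand(1, 3, 224, 224)
y = model(test_in)                         # multi-branch (training-time) output

# Switch every block to the fused single-conv path; the fused weights are
# computed lazily on the first forward pass.
model.deploy(mode=True, show_error=False)
fused_y = model(test_in)

print("max abs error:", torch.max(torch.abs(fused_y - y)).item())
torch.onnx.export(model, test_in, "RepVGG-B2-deploy.onnx")
```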

212 changes: 99 additions & 113 deletions repvgg.py
@@ -2,6 +2,16 @@
 import numpy as np
 import torch
 
+
+def deploy(self, mode=False, show_error=False):
+    self.deploying = mode
+    for module in self.children():
+        if hasattr(module, 'deploying'):
+            module.deploy(mode, show_error)
+
+nn.Sequential.deploying = False
+nn.Sequential.show_error = False
+nn.Sequential.deploy = deploy
 
 def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
     result = nn.Sequential()
     result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
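The patch added at the top of this hunk gives every `nn.Sequential` a recursive `deploy()`, so a single call on the top-level network reaches the `RepVGGBlock`s inside the stage containers. Note the patch is process-wide: every `nn.Sequential` in the program gains these attributes, which is worth keeping in mind if RepVGG shares a process with unrelated models. A self-contained sketch of how the propagation bottoms out (`Leaf` is a hypothetical stand-in for `RepVGGBlock`):

```
import torch.nn as nn

def deploy(self, mode=False, show_error=False):
    self.deploying = mode
    for module in self.children():
        if hasattr(module, 'deploying'):
            module.deploy(mode, show_error)

nn.Sequential.deploying = False
nn.Sequential.show_error = False
nn.Sequential.deploy = deploy

# Hypothetical leaf standing in for RepVGGBlock.
class Leaf(nn.Module):
    def __init__(self):
        super().__init__()
        self.deploying = False
        self.show_error = False

    def deploy(self, mode=False, show_error=False):
        self.deploying = mode
        self.show_error = show_error

stage = nn.Sequential(Leaf(), nn.Sequential(Leaf(), Leaf()))
stage.deploy(True)
print([type(m).__name__ for m in stage.modules() if getattr(m, 'deploying', False)])
# -> ['Sequential', 'Leaf', 'Sequential', 'Leaf', 'Leaf']
```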
@@ -12,9 +22,11 @@ def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
 class RepVGGBlock(nn.Module):
 
     def __init__(self, in_channels, out_channels, kernel_size,
-                 stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False):
+                 stride=1, padding=0, groups=1):
         super(RepVGGBlock, self).__init__()
-        self.deploy = deploy
+        self.deploying = False
+        self.show_error = False
+
         self.groups = groups
         self.in_channels = in_channels

@@ -25,29 +37,44 @@ def __init__(self, in_channels, out_channels, kernel_size,
 
         self.nonlinearity = nn.ReLU()
 
-        if deploy:
-            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
-                                         padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
+        self.in_channels = in_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.groups = groups
+
+        self.register_buffer('fused_weight', None)
+        self.register_buffer('fused_bias', None)
 
-        else:
-            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
-            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
-            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
-            print('RepVGG Block, identity = ', self.rbr_identity)
+        self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
+        self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
+        self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
+        print('RepVGG Block, identity = ', self.rbr_identity)


     def forward(self, inputs):
-        if hasattr(self, 'rbr_reparam'):
-            return self.nonlinearity(self.rbr_reparam(inputs))
+        if self.deploying:
+            if self.fused_weight is None or self.fused_bias is None:
+                self.fused_weight, self.fused_bias = self.get_equivalent_kernel_bias()
+
+            fused_out = self.nonlinearity(torch.nn.functional.conv2d(
+                inputs, self.fused_weight, self.fused_bias, self.stride, self.padding, 1, self.groups))
+
+            if not self.show_error:
+                return fused_out
 
         if self.rbr_identity is None:
             id_out = 0
         else:
             id_out = self.rbr_identity(inputs)
+        out = self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
 
-        return self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
+        if self.deploying and self.show_error:
+            print(torch.max(torch.abs(fused_out - out)).item())
+            return fused_out
+
+        return out
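With `show_error=True`, a deploying block runs both the fused path and the original multi-branch path and prints their maximum absolute difference, giving a per-block sanity check of the fusion. A short usage sketch (assuming this `repvgg.py` is importable):

```
import torch
from repvgg import get_RepVGG_func_by_name

model = get_RepVGG_func_by_name("RepVGG-A0")()
model.eval()                               # eval mode so BatchNorm uses running statistics
model.deploy(mode=True, show_error=True)   # every block prints its own fusion error
_ = model(torch.rand(1, 3, 224, 224))
```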

     # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
     # You can get the equivalent kernel and bias at any time and do whatever you want,
@@ -93,27 +120,28 @@ def _fuse_bn_tensor(self, branch):
         t = (gamma / std).reshape(-1, 1, 1, 1)
         return kernel * t, beta - running_mean * gamma / std
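`_fuse_bn_tensor` folds a BatchNorm into its preceding convolution: with std = sqrt(running_var + eps), the fused kernel is kernel * (gamma / std) and the fused bias is beta - running_mean * gamma / std. A standalone numeric check of that identity (shapes chosen arbitrarily, independent of this file):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(4, 4, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(4).eval()
# Give BN non-trivial statistics so the check is meaningful.
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1, 1)

std = (bn.running_var + bn.eps).sqrt()
t = (bn.weight / std).reshape(-1, 1, 1, 1)
fused_weight = conv.weight * t
fused_bias = bn.bias - bn.running_mean * bn.weight / std

x = torch.rand(1, 4, 8, 8)
ref = bn(conv(x))
fused = F.conv2d(x, fused_weight, fused_bias, padding=1)
print(torch.max(torch.abs(ref - fused)).item())  # ~1e-7, float32 rounding
```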

-    def repvgg_convert(self):
-        kernel, bias = self.get_equivalent_kernel_bias()
-        return kernel.detach().cpu().numpy(), bias.detach().cpu().numpy(),
+    def deploy(self, mode=False, show_error=False):
+        self.deploying = mode
+        self.show_error = show_error


 class RepVGG(nn.Module):
 
-    def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False):
+    def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None):
         super(RepVGG, self).__init__()
 
         assert len(width_multiplier) == 4
 
-        self.deploy = deploy
+        self.deploying = False
+        self.show_error = False
 
         self.override_groups_map = override_groups_map or dict()
 
         assert 0 not in self.override_groups_map
 
         self.in_planes = min(64, int(64 * width_multiplier[0]))
 
-        self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy)
+        self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1)
         self.cur_layer_idx = 1
         self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=2)
         self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2)
@@ -129,11 +157,18 @@ def _make_stage(self, planes, num_blocks, stride):
         for stride in strides:
             cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
             blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
-                                      stride=stride, padding=1, groups=cur_groups, deploy=self.deploy))
+                                      stride=stride, padding=1, groups=cur_groups))
             self.in_planes = planes
             self.cur_layer_idx += 1
         return nn.Sequential(*blocks)
 
+    def deploy(self, mode=False, show_error=False):
+        self.deploying = mode
+        for module in self.children():
+            if hasattr(module, 'deploying'):
+                module.deploy(mode, show_error)
+
+
     def forward(self, x):
         out = self.stage0(x)
         out = self.stage1(out)
@@ -146,63 +181,65 @@ def forward(self, x):
         return out
 
 
 optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
 g2_map = {l: 2 for l in optional_groupwise_layers}
 g4_map = {l: 4 for l in optional_groupwise_layers}
 
-def create_RepVGG_A0(deploy=False):
+def create_RepVGG_A0():
     return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
-                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None)
 
-def create_RepVGG_A1(deploy=False):
+def create_RepVGG_A1():
     return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
-                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None)
 
-def create_RepVGG_A2(deploy=False):
+def create_RepVGG_A2():
     return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
-                  width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None)
 
-def create_RepVGG_B0(deploy=False):
+def create_RepVGG_B0():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None)
 
-def create_RepVGG_B1(deploy=False):
+def create_RepVGG_B1():
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[2, 2, 2, 4], override_groups_map=None)
 
-def create_RepVGG_B1g2(deploy=False):
+def create_RepVGG_B1g2():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy)
+                  width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map)
 
-def create_RepVGG_B1g4(deploy=False):
+def create_RepVGG_B1g4():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy)
+                  width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map)
 
 
-def create_RepVGG_B2(deploy=False):
+def create_RepVGG_B2():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None)
 
-def create_RepVGG_B2g2(deploy=False):
+def create_RepVGG_B2g2():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy)
+                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map)
 
-def create_RepVGG_B2g4(deploy=False):
+def create_RepVGG_B2g4():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy)
+                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map)
 
 
-def create_RepVGG_B3(deploy=False):
+def create_RepVGG_B3():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy)
+                  width_multiplier=[3, 3, 3, 5], override_groups_map=None)
 
-def create_RepVGG_B3g2(deploy=False):
+def create_RepVGG_B3g2():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy)
+                  width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map)
 
-def create_RepVGG_B3g4(deploy=False):
+def create_RepVGG_B3g4():
     return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
-                  width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy)
+                  width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map)


func_dict = {
@@ -224,69 +261,18 @@ def get_RepVGG_func_by_name(name):
     return func_dict[name]


+if __name__ == '__main__':
+    model = get_RepVGG_func_by_name("RepVGG-A0")()
+    model.eval()
+
+    test_in = torch.rand([1, 3, 224, 224])
+
+    y = model(test_in)
+    model.deploy(mode=True, show_error=False)
+
+    fused_y = model(test_in)
+
+    print("final error :", torch.max(torch.abs(fused_y - y)).item())
+
+    torch.onnx.export(model, test_in, "model.onnx")
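The `__main__` block above doubles as a smoke test: it compares fused and unfused outputs and exports the fused model to ONNX. One quick way to double-check the exported file (a sketch assuming `onnxruntime` is installed and the script has already produced `model.onnx`):

```
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
x = np.random.rand(1, 3, 224, 224).astype(np.float32)
(onnx_y,) = sess.run(None, {sess.get_inputs()[0].name: x})
print(onnx_y.shape)  # expected: (1, 1000)
```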

-# Use this for converting a customized model with RepVGG as one of its components (e.g., the backbone of a semantic segmentation model)
-# The use case will be like
-# 1. Build train_model. For example, build a PSPNet with a training-time RepVGG as backbone
-# 2. Train train_model or do whatever you want
-# 3. Build deploy_model. In the above example, that will be a PSPNet with an inference-time RepVGG as backbone
-# 4. Call this func
-# ====================== the pseudo code will be like
-# train_backbone = create_RepVGG_B2(deploy=False)
-# train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
-# train_pspnet = build_pspnet(backbone=train_backbone)
-# segmentation_train(train_pspnet)
-# deploy_backbone = create_RepVGG_B2(deploy=True)
-# deploy_pspnet = build_pspnet(backbone=deploy_backbone)
-# whole_model_convert(train_pspnet, deploy_pspnet)
-# segmentation_test(deploy_pspnet)
-def whole_model_convert(train_model:torch.nn.Module, deploy_model:torch.nn.Module, save_path=None):
-    all_weights = {}
-    for name, module in train_model.named_modules():
-        if hasattr(module, 'repvgg_convert'):
-            kernel, bias = module.repvgg_convert()
-            all_weights[name + '.rbr_reparam.weight'] = kernel
-            all_weights[name + '.rbr_reparam.bias'] = bias
-            print('convert RepVGG block')
-        else:
-            for p_name, p_tensor in module.named_parameters():
-                full_name = name + '.' + p_name
-                if full_name not in all_weights:
-                    all_weights[full_name] = p_tensor.detach().cpu().numpy()
-            for p_name, p_tensor in module.named_buffers():
-                full_name = name + '.' + p_name
-                if full_name not in all_weights:
-                    all_weights[full_name] = p_tensor.cpu().numpy()
-
-    deploy_model.load_state_dict(all_weights)
-    if save_path is not None:
-        torch.save(deploy_model.state_dict(), save_path)
-
-    return deploy_model
-
-
-# Use this when converting a RepVGG without customized structures.
-#   train_model = create_RepVGG_A0(deploy=False)
-#   train train_model
-#   deploy_model = repvgg_convert(train_model, create_RepVGG_A0, save_path='repvgg_deploy.pth')
-def repvgg_model_convert(model:torch.nn.Module, build_func, save_path=None):
-    converted_weights = {}
-    for name, module in model.named_modules():
-        if hasattr(module, 'repvgg_convert'):
-            kernel, bias = module.repvgg_convert()
-            converted_weights[name + '.rbr_reparam.weight'] = kernel
-            converted_weights[name + '.rbr_reparam.bias'] = bias
-        elif isinstance(module, torch.nn.Linear):
-            converted_weights[name + '.weight'] = module.weight.detach().cpu().numpy()
-            converted_weights[name + '.bias'] = module.bias.detach().cpu().numpy()
-    del model
-
-    deploy_model = build_func(deploy=True)
-    for name, param in deploy_model.named_parameters():
-        print('deploy param: ', name, param.size(), np.mean(converted_weights[name]))
-        param.data = torch.from_numpy(converted_weights[name]).float()
-
-    if save_path is not None:
-        torch.save(deploy_model.state_dict(), save_path)
-
-    return deploy_model
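The deleted helpers covered converting a model that embeds RepVGG as a component, by building a second deploy-time network and copying fused weights into it. Under this PR that use case reduces to calling `deploy()` on the live backbone, since fusion happens in place. A hypothetical sketch of the segmentation scenario from the deleted comments (`build_pspnet`, `segmentation_train`, and `segmentation_test` are placeholders, exactly as they were in the original pseudocode):

```
import torch
from repvgg import create_RepVGG_B2

backbone = create_RepVGG_B2()
backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))

pspnet = build_pspnet(backbone=backbone)   # hypothetical builder
segmentation_train(pspnet)                 # hypothetical training loop

pspnet.eval()
backbone.deploy(mode=True)                 # fuse branches in place; no second model needed
segmentation_test(pspnet)                  # hypothetical evaluation
```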
5 changes: 4 additions & 1 deletion test.py
@@ -30,7 +30,10 @@ def test():

     repvgg_build_func = get_RepVGG_func_by_name(args.arch)
 
-    model = repvgg_build_func(deploy=args.mode=='deploy')
+    model = repvgg_build_func()
+
+    model.deploy(args.mode == 'deploy')
+
 
     if not torch.cuda.is_available():
         print('using CPU, this will be slow')
3 changes: 2 additions & 1 deletion train.py
@@ -145,7 +145,8 @@ def main_worker(gpu, ngpus_per_node, args):

     repvgg_build_func = get_RepVGG_func_by_name(args.arch)
 
-    model = repvgg_build_func(deploy=False)
+    model = repvgg_build_func()
+    model.deploy(mode=False)
 
     if not torch.cuda.is_available():
         print('using CPU, this will be slow')