Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TorchScript version #78

Merged
merged 2 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions TorchScript/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
## Usage:

```shell

python export_torchscript.py \
--ckpt-path pretrained/modnet_photographic_portrait_matting.ckpt \
--out-dir scripted_model
```

## Official TorchScript model:

[BaiduCloudDisk](https://pan.baidu.com/s/1kOmmmbG7lSZiSmDdE7CaRw), extract_code=dm9e
Empty file added TorchScript/__init__.py
Empty file.
42 changes: 42 additions & 0 deletions TorchScript/export_torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from . import modnet_torchscript

if __name__ == '__main__':
    # Script entry point: load a pre-trained MODNet checkpoint, compile the
    # network with TorchScript and save it as ``<out-dir>/modnet.pt``.

    parser = argparse.ArgumentParser()
    # Both paths are mandatory; without --ckpt-path there is nothing to export.
    parser.add_argument('--ckpt-path', type=str, required=True, help='path of pre-trained MODNet')
    parser.add_argument('--out-dir', type=str, required=True, help='path for saving the TorchScript model')
    args = parser.parse_args()

    # check input arguments
    if not os.path.exists(args.ckpt_path):
        print('Cannot find checkpoint path: {0}'.format(args.ckpt_path))
        # Exit with a non-zero status so callers can detect the failure.
        raise SystemExit(1)

    # Create the output directory (including missing parents) if needed.
    os.makedirs(args.out_dir, exist_ok=True)

    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # export also works on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create MODNet and load the pre-trained ckpt
    # NOTE(review): backbone_pretrained=True makes MODNet call
    # backbone.load_pretrained_ckpt() before the full checkpoint is loaded
    # below -- confirm whether that extra load is actually required here.
    modnet = modnet_torchscript.MODNet(backbone_pretrained=True)
    modnet = modnet.to(device)
    ckpt = torch.load(args.ckpt_path, map_location=device)

    # Checkpoints saved from an nn.DataParallel wrapper carry a 'module.'
    # prefix on every key; strip it so the keys match the bare model.
    prefix = 'module.'
    if any(key.startswith(prefix) for key in ckpt):
        ckpt = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in ckpt.items()
        )

    modnet.load_state_dict(ckpt)
    modnet.eval()

    scripted_model = torch.jit.script(modnet)
    torch.jit.save(scripted_model, os.path.join(args.out_dir, 'modnet.pt'))

275 changes: 275 additions & 0 deletions TorchScript/modnet_torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# from .backbones import SUPPORTED_BACKBONES
from .backbones import SUPPORTED_BACKBONES


#------------------------------------------------------------------------------
# MODNet Basic Modules
#------------------------------------------------------------------------------

class IBNorm(nn.Module):
    """Normalization layer that mixes batch norm and instance norm.

    The first ``int(in_channels / 2)`` channels of the input are normalized
    with ``nn.BatchNorm2d`` (affine) and the remaining channels with
    ``nn.InstanceNorm2d`` (non-affine). The two halves are concatenated back
    along the channel axis, so the output has the same channel count as the
    input.
    """

    def __init__(self, in_channels):
        super().__init__()
        half = int(in_channels / 2)
        self.bnorm_channels = half
        self.inorm_channels = in_channels - half

        self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
        self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)

    def forward(self, x):
        split = self.bnorm_channels
        # Slice along dim 1 and make each half contiguous before normalizing.
        batch_part = x[:, :split, ...].contiguous()
        instance_part = x[:, split:, ...].contiguous()

        normalized = (self.bnorm(batch_part), self.inorm(instance_part))
        return torch.cat(normalized, 1)


class Conv2dIBNormRelu(nn.Module):
    """Convolution optionally followed by IBNorm and ReLU.

    The sub-modules are stored in ``self.layers`` (an ``nn.Sequential``) in
    the order: ``nn.Conv2d``, then ``IBNorm`` when ``with_ibn`` is true, then
    an in-place ``nn.ReLU`` when ``with_relu`` is true. All convolution
    arguments are forwarded to ``nn.Conv2d`` unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 with_ibn=True, with_relu=True):
        super().__init__()

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        stages = [conv]
        if with_ibn:
            stages += [IBNorm(out_channels)]
        if with_relu:
            stages += [nn.ReLU(inplace=True)]

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        out = self.layers(x)
        return out


class SEBlock(nn.Module):
    """SE block proposed in https://arxiv.org/pdf/1709.01507.pdf

    Channel attention: the input is average-pooled to one value per channel,
    passed through two bias-free linear layers (ReLU between, Sigmoid after),
    and the resulting per-channel weights rescale the input.

    ``forward`` reshapes the weights to the input's channel count, so
    ``out_channels`` has to equal the number of input channels for the
    reshape to succeed.
    """

    def __init__(self, in_channels, out_channels, reduction=1):
        super().__init__()
        hidden = int(in_channels // reduction)

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()

        # Squeeze: one scalar per (sample, channel).
        squeezed = self.pool(x).view(batch, channels)
        # Excite: per-channel gates, shaped for broadcasting over H and W.
        gates = self.fc(squeezed).view(batch, channels, 1, 1)

        return x * gates.expand_as(x)


#------------------------------------------------------------------------------
# MODNet Branches
#------------------------------------------------------------------------------

class LRBranch(nn.Module):
    """Low-resolution branch of MODNet.

    Runs the backbone, applies channel attention to the deepest feature map,
    upsamples it twice (2x each) with a convolution after every step, and
    optionally predicts a semantic map.

    ``forward`` returns ``(pred_semantic, lr8x, [enc2x, enc4x])``. When
    ``inference`` is truthy, ``pred_semantic`` is an empty tensor standing in
    for ``None``.
    """

    def __init__(self, backbone):
        super().__init__()

        channels = backbone.enc_channels
        # e.g. channels == [16, 24, 32, 96, 1280]
        deepest, mid, shallow = channels[4], channels[3], channels[2]

        self.backbone = backbone
        self.se_block = SEBlock(deepest, deepest, reduction=4)
        self.conv_lr16x = Conv2dIBNormRelu(deepest, mid, 5, stride=1, padding=2)
        self.conv_lr8x = Conv2dIBNormRelu(mid, shallow, 5, stride=1, padding=2)
        self.conv_lr = Conv2dIBNormRelu(
            shallow, 1, kernel_size=3, stride=2, padding=1,
            with_ibn=False, with_relu=False)

    def forward(self, img, inference):
        feats = self.backbone.forward(img)
        enc2x = feats[0]
        enc4x = feats[1]

        # Apply channel attention to the last (deepest) backbone feature map.
        enc32x = self.se_block(feats[4])

        # Then upsample 4x in total, as two 2x steps each followed by a conv.
        up16 = F.interpolate(enc32x, scale_factor=2.0, mode='bilinear', align_corners=False)
        lr16x = self.conv_lr16x(up16)
        up8 = F.interpolate(lr16x, scale_factor=2.0, mode='bilinear', align_corners=False)
        lr8x = self.conv_lr8x(up8)

        # Empty tensor is the placeholder returned instead of None.
        pred_semantic = torch.tensor([])
        if not inference:
            pred_semantic = torch.sigmoid(self.conv_lr(lr8x))

        return pred_semantic, lr8x, [enc2x, enc4x]


class HRBranch(nn.Module):
    """High-resolution branch of MODNet.

    Combines the two shallowest encoder feature maps, downsampled copies of
    the input image and the low-resolution branch output into a detail
    feature map ``hr2x``, and optionally a detail prediction.

    ``forward`` returns ``(pred_detail, hr2x)``. When ``inference`` is truthy,
    ``pred_detail`` is an empty tensor standing in for ``None``.

    The ``+ 3`` in several channel counts matches the image tensor that is
    concatenated at those points (presumably a 3-channel image -- confirm
    against the caller).
    """

    def __init__(self, hr_channels, enc_channels):
        super().__init__()
        c = hr_channels

        self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], c, 1, stride=1, padding=0)
        self.conv_enc2x = Conv2dIBNormRelu(c + 3, c, 3, stride=2, padding=1)

        self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], c, 1, stride=1, padding=0)
        self.conv_enc4x = Conv2dIBNormRelu(2 * c, 2 * c, 3, stride=1, padding=1)

        hr4x_stack = [
            Conv2dIBNormRelu(3 * c + 3, 2 * c, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * c, 2 * c, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * c, c, 3, stride=1, padding=1),
        ]
        self.conv_hr4x = nn.Sequential(*hr4x_stack)

        hr2x_stack = [
            Conv2dIBNormRelu(2 * c, 2 * c, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * c, c, 3, stride=1, padding=1),
            Conv2dIBNormRelu(c, c, 3, stride=1, padding=1),
            Conv2dIBNormRelu(c, c, 3, stride=1, padding=1),
        ]
        self.conv_hr2x = nn.Sequential(*hr2x_stack)

        hr_stack = [
            Conv2dIBNormRelu(c + 3, c, 3, stride=1, padding=1),
            Conv2dIBNormRelu(c, 1, kernel_size=1, stride=1, padding=0,
                             with_ibn=False, with_relu=False),
        ]
        self.conv_hr = nn.Sequential(*hr_stack)

    def forward(self, img, enc2x, enc4x, lr8x, inference):
        # Half- and quarter-resolution copies of the input image.
        img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)
        img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)

        enc2x = self.tohr_enc2x(enc2x)
        # Stack the downsampled image onto the feature channels.
        stacked2x = torch.cat((img2x, enc2x), dim=1)
        hr4x = self.conv_enc2x(stacked2x)

        # Concatenate the two feature maps.
        enc4x = self.tohr_enc4x(enc4x)
        stacked4x = torch.cat((hr4x, enc4x), dim=1)
        hr4x = self.conv_enc4x(stacked4x)

        # Bring the low-resolution features up to the same scale and fuse.
        lr4x = F.interpolate(lr8x, scale_factor=2.0, mode='bilinear', align_corners=False)
        fused4x = torch.cat((hr4x, lr4x, img4x), dim=1)
        hr4x = self.conv_hr4x(fused4x)

        hr2x = F.interpolate(hr4x, scale_factor=2.0, mode='bilinear', align_corners=False)
        fused2x = torch.cat((hr2x, enc2x), dim=1)
        hr2x = self.conv_hr2x(fused2x)

        # Empty tensor is the placeholder returned instead of None.
        pred_detail = torch.tensor([])
        if not inference:
            full = F.interpolate(hr2x, scale_factor=2.0, mode='bilinear', align_corners=False)
            full = self.conv_hr(torch.cat((full, img), dim=1))
            pred_detail = torch.sigmoid(full)

        return pred_detail, hr2x


class FusionBranch(nn.Module):
    """Fusion branch of MODNet.

    Upsamples the low-resolution features, merges them with the
    high-resolution features and the input image, and returns the
    sigmoid-activated matte prediction.
    """

    def __init__(self, hr_channels, enc_channels):
        super().__init__()
        half = int(hr_channels / 2)

        self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)

        self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
        head = [
            Conv2dIBNormRelu(hr_channels + 3, half, 3, stride=1, padding=1),
            Conv2dIBNormRelu(half, 1, 1, stride=1, padding=0,
                             with_ibn=False, with_relu=False),
        ]
        self.conv_f = nn.Sequential(*head)

    def forward(self, img, lr8x, hr2x):
        # Two 2x upsampling steps on the low-resolution features, with a
        # convolution after the first one.
        up4 = F.interpolate(lr8x, scale_factor=2.0, mode='bilinear', align_corners=False)
        lr4x = self.conv_lr4x(up4)
        lr2x = F.interpolate(lr4x, scale_factor=2.0, mode='bilinear', align_corners=False)

        # Merge with the high-resolution features.
        merged = torch.cat((lr2x, hr2x), dim=1)
        f2x = self.conv_f2x(merged)

        # One more 2x upsampling, then append the image on the channel axis.
        full = F.interpolate(f2x, scale_factor=2.0, mode='bilinear', align_corners=False)
        full = self.conv_f(torch.cat((full, img), dim=1))

        return torch.sigmoid(full)


#------------------------------------------------------------------------------
# MODNet
#------------------------------------------------------------------------------

class MODNet(nn.Module):
    """Architecture of MODNet.

    Wires a backbone into three branches: a low-resolution branch, a
    high-resolution branch and a fusion branch.

    Args:
        in_channels: channel count passed to the backbone constructor.
        hr_channels: base channel width of the high-resolution and fusion
            branches.
        backbone_arch: key into ``SUPPORTED_BACKBONES``.
        backbone_pretrained: when true, ``backbone.load_pretrained_ckpt()``
            is called after weight initialization.

    ``forward(img, inference)`` returns
    ``(pred_semantic, pred_detail, pred_matte)``.
    """

    def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
        super(MODNet, self).__init__()

        self.in_channels = in_channels
        self.hr_channels = hr_channels
        self.backbone_arch = backbone_arch
        self.backbone_pretrained = backbone_pretrained

        self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)

        self.lr_branch = LRBranch(self.backbone)
        self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
        self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)

        # Initialize every conv and norm layer (backbone included); the
        # pretrained backbone weights, if requested, are loaded afterwards
        # and therefore take precedence.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                self._init_conv(m)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                self._init_norm(m)

        if self.backbone_pretrained:
            self.backbone.load_pretrained_ckpt()

    def forward(self, img, inference):
        # Run each branch exactly once and index into the returned tuple.
        # Calling a branch once per output would repeat the whole backbone
        # pass and, in training mode, update BatchNorm running statistics
        # several times per step. Plain integer indexing is used instead of
        # nested unpacking to keep the method simple for TorchScript.
        lr_out = self.lr_branch(img, inference)
        pred_semantic = lr_out[0]
        lr8x = lr_out[1]
        enc2x = lr_out[2][0]
        enc4x = lr_out[2][1]

        hr_out = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
        pred_detail = hr_out[0]
        hr2x = hr_out[1]

        pred_matte = self.f_branch(img, lr8x, hr2x)

        return pred_semantic, pred_detail, pred_matte

    def freeze_norm(self):
        """Put every BatchNorm2d / InstanceNorm2d sub-module into eval mode."""
        norm_types = (nn.BatchNorm2d, nn.InstanceNorm2d)
        for m in self.modules():
            if isinstance(m, norm_types):
                m.eval()

    def _init_conv(self, conv):
        """Kaiming-uniform init for the weight, zero for the bias if present."""
        nn.init.kaiming_uniform_(
            conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        if conv.bias is not None:
            nn.init.constant_(conv.bias, 0)

    def _init_norm(self, norm):
        """Set affine norm parameters to weight=1, bias=0.

        Norm layers without affine parameters have ``weight is None`` and
        are left untouched.
        """
        if norm.weight is not None:
            nn.init.constant_(norm.weight, 1)
            nn.init.constant_(norm.bias, 0)


if __name__ == "__main__":
    # Smoke test for IBNorm: the input must carry as many channels as the
    # layer was built for, otherwise the BatchNorm half receives the wrong
    # channel count and raises.
    channels = 20
    ib_norm = IBNorm(channels)
    out = ib_norm(torch.randn((1, channels, 224, 224)))
    print(out.shape)
26 changes: 20 additions & 6 deletions src/models/backbones/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,31 @@ def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
# Initialize weights
self._init_weights()

def forward(self, x, feature_names=None):
def forward(self, x):
# Stage1
x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x)
x = self.features[0](x)
x = self.features[1](x)
# Stage2
x = reduce(lambda x, n: self.features[n](x), list(range(2,4)), x)
x = self.features[2](x)
x = self.features[3](x)
# Stage3
x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x)
x = self.features[4](x)
x = self.features[5](x)
x = self.features[6](x)
# Stage4
x = reduce(lambda x, n: self.features[n](x), list(range(7,14)), x)
x = self.features[7](x)
x = self.features[8](x)
x = self.features[9](x)
x = self.features[10](x)
x = self.features[11](x)
x = self.features[12](x)
x = self.features[13](x)
# Stage5
x = reduce(lambda x, n: self.features[n](x), list(range(14,19)), x)
x = self.features[14](x)
x = self.features[15](x)
x = self.features[16](x)
x = self.features[17](x)
x = self.features[18](x)

# Classification
if self.num_classes is not None:
Expand Down
Loading