diff --git a/demo/image_matting/Inference_with_ONNX/README.md b/demo/image_matting/Inference_with_ONNX/README.md new file mode 100644 index 0000000..0403fb9 --- /dev/null +++ b/demo/image_matting/Inference_with_ONNX/README.md @@ -0,0 +1,26 @@ +# Inference with onnxruntime + +Please try MODNet image matting onnx-inference demo with [Colab Notebook](https://colab.research.google.com/drive/1P3cWtg8fnmu9karZHYDAtmm1vj1rgA-f?usp=sharing) + +Download [modnet.onnx](https://drive.google.com/file/d/1cgycTQlYXpTh26gB9FTnthE7AvruV8hd/view?usp=sharing) + +### 1. Export onnx model + +Run the following command: +```shell +python export_modnet_onnx.py \ + --ckpt-path=pretrained/modnet_photographic_portrait_matting.ckpt \ + --output-path=modnet.onnx +``` + + +### 2. Inference + +Run the following command: +```shell +python inference_onnx.py \ + --image-path=PATH_TO_IMAGE \ + --output-path=matte.png \ + --model-path=modnet.onnx +``` + diff --git a/demo/image_matting/Inference_with_ONNX/export_modnet_onnx.py b/demo/image_matting/Inference_with_ONNX/export_modnet_onnx.py new file mode 100644 index 0000000..fe5ced7 --- /dev/null +++ b/demo/image_matting/Inference_with_ONNX/export_modnet_onnx.py @@ -0,0 +1,55 @@ +""" +Export onnx model + +Arguments: + --ckpt-path --> Path of last checkpoint to load + --output-path --> path of onnx model to be saved + +example: +python export_modnet_onnx.py \ + --ckpt-path=modnet_photographic_portrait_matting.ckpt \ + --output-path=modnet.onnx + +output: +ONNX model with dynamic input shape: (batch_size, 3, height, width) & + output shape: (batch_size, 1, height, width) +""" +import os +import argparse +import torch +import torch.nn as nn +from torch.autograd import Variable +from src.models.onnx_modnet import MODNet + + + +if __name__ == '__main__': + # define cmd arguments + parser = argparse.ArgumentParser() + parser.add_argument('--ckpt-path', type=str, required=True, help='path of pre-trained MODNet') + parser.add_argument('--output-path', type=str, required=True, help='path of output onnx model') + args = parser.parse_args() + + # check input arguments + if not os.path.exists(args.ckpt_path): + print('Cannot find checkpoint path: {0}'.format(args.ckpt_path)) + exit() + + # define model & load checkpoint + modnet = MODNet(backbone_pretrained=False) + modnet = nn.DataParallel(modnet).cuda() + state_dict = torch.load(args.ckpt_path) + modnet.load_state_dict(state_dict) + modnet.eval() + + # prepare dummy_input + batch_size = 1 + height = 512 + width = 512 + dummy_input = Variable(torch.randn(batch_size, 3, height, width)).cuda() + + # export to onnx model + torch.onnx.export(modnet.module, dummy_input, args.output_path, export_params = True, opset_version=11, + input_names = ['input'], output_names = ['output'], + dynamic_axes = {'input': {0:'batch_size', 2:'height', 3:'width'}, + 'output': {0: 'batch_size', 2: 'height', 3: 'width'}}) diff --git a/demo/image_matting/Inference_with_ONNX/inference_onnx.py b/demo/image_matting/Inference_with_ONNX/inference_onnx.py new file mode 100644 index 0000000..cccfa23 --- /dev/null +++ b/demo/image_matting/Inference_with_ONNX/inference_onnx.py @@ -0,0 +1,116 @@ +""" +Inference with onnxruntime + +Arguments: + --image-path --> path to single input image + --output-path --> paht to save generated matte + --model-path --> path to onnx model file + +example: +python inference_onnx.py \ + --image-path=demo.jpg \ + --output-path=matte.png \ + --model-path=modnet.onnx + +Optional: +Generate transparent image without background +""" +import os +import argparse +import cv2 +import numpy as np +import onnx +import onnxruntime +from onnx import helper +from PIL import Image + +if __name__ == '__main__': + # define cmd arguments + parser = argparse.ArgumentParser() + parser.add_argument('--image-path', type=str, help='path of input image') + parser.add_argument('--output-path', type=str, help='path of output image') + parser.add_argument('--model-path', type=str, help='path of onnx model') + args = parser.parse_args() + + # check input arguments + if not os.path.exists(args.image_path): + print('Cannot find input path: {0}'.format(args.image_path)) + exit() + if not os.path.exists(args.model_path): + print('Cannot find model path: {0}'.format(args.model_path)) + exit() + + ref_size = 512 + + # Get x_scale_factor & y_scale_factor to resize image + def get_scale_factor(im_h, im_w, ref_size): + + if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size: + if im_w >= im_h: + im_rh = ref_size + im_rw = int(im_w / im_h * ref_size) + elif im_w < im_h: + im_rw = ref_size + im_rh = int(im_h / im_w * ref_size) + else: + im_rh = im_h + im_rw = im_w + + im_rw = im_rw - im_rw % 32 + im_rh = im_rh - im_rh % 32 + + x_scale_factor = im_rw / im_w + y_scale_factor = im_rh / im_h + + return x_scale_factor, y_scale_factor + + ############################################## + # Main Inference part + ############################################## + + # read image + im = cv2.imread(args.image_path) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + + # unify image channels to 3 + if len(im.shape) == 2: + im = im[:, :, None] + if im.shape[2] == 1: + im = np.repeat(im, 3, axis=2) + elif im.shape[2] == 4: + im = im[:, :, 0:3] + + # normalize values to scale it between -1 to 1 + im = (im - 127.5) / 127.5 + + im_h, im_w, im_c = im.shape + x, y = get_scale_factor(im_h, im_w, ref_size) + + # resize image + im = cv2.resize(im, None, fx = x, fy = y, interpolation = cv2.INTER_AREA) + + # prepare input shape + im = np.transpose(im) + im = np.swapaxes(im, 1, 2) + im = np.expand_dims(im, axis = 0).astype('float32') + + # Initialize session and get prediction + session = onnxruntime.InferenceSession(args.model_path, None) + input_name = session.get_inputs()[0].name + output_name = session.get_outputs()[0].name + result = session.run([output_name], {input_name: im}) + + # refine matte + matte = (np.squeeze(result[0]) * 255).astype('uint8') + matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation = cv2.INTER_AREA) + + cv2.imwrite(args.output_path, matte) + + ############################################## + # Optional - save png image without background + ############################################## + + # im_PIL = Image.open(args.image_path) + # matte = Image.fromarray(matte) + # im_PIL.putalpha(matte) # add alpha channel to keep transparency + # im_PIL.save('without_background.png') \ No newline at end of file diff --git a/demo/image_matting/Inference_with_ONNX/requirements.txt b/demo/image_matting/Inference_with_ONNX/requirements.txt new file mode 100644 index 0000000..3dfd20d --- /dev/null +++ b/demo/image_matting/Inference_with_ONNX/requirements.txt @@ -0,0 +1,4 @@ +onnx==1.8.1 +onnxruntime==1.6.0 +opencv-python==4.5.1.48 +torch==1.7.1 \ No newline at end of file diff --git a/src/models/onnx_modnet.py b/src/models/onnx_modnet.py new file mode 100644 index 0000000..6fa5a41 --- /dev/null +++ b/src/models/onnx_modnet.py @@ -0,0 +1,254 @@ +""" +This file is a modified version of the original file modnet.py without +"pred_semantic" and "pred_details" as these both returns None when "inference = True" + +And it does not contain "inference" argument which will make it easier to +convert checkpoint into onnx model. + +Refer: 'demo/image_matting/inference_with_ONNX/export_modnet_onnx.py' to export model. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbones import SUPPORTED_BACKBONES + + +#------------------------------------------------------------------------------ +# MODNet Basic Modules +#------------------------------------------------------------------------------ + +class IBNorm(nn.Module): + """ Combine Instance Norm and Batch Norm into One Layer + """ + + def __init__(self, in_channels): + super(IBNorm, self).__init__() + in_channels = in_channels + self.bnorm_channels = int(in_channels / 2) + self.inorm_channels = in_channels - self.bnorm_channels + + self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True) + self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) + + def forward(self, x): + bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) + in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) + + return torch.cat((bn_x, in_x), 1) + + +class Conv2dIBNormRelu(nn.Module): + """ Convolution + IBNorm + ReLu + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True, + with_ibn=True, with_relu=True): + super(Conv2dIBNormRelu, self).__init__() + + layers = [ + nn.Conv2d(in_channels, out_channels, kernel_size, + stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + ] + + if with_ibn: + layers.append(IBNorm(out_channels)) + if with_relu: + layers.append(nn.ReLU(inplace=True)) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class SEBlock(nn.Module): + """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf + """ + + def __init__(self, in_channels, out_channels, reduction=1): + super(SEBlock, self).__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(in_channels, int(in_channels // reduction), bias=False), + nn.ReLU(inplace=True), + nn.Linear(int(in_channels // reduction), out_channels, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + w = self.pool(x).view(b, c) + w = self.fc(w).view(b, c, 1, 1) + + return x * w.expand_as(x) + + +#------------------------------------------------------------------------------ +# MODNet Branches +#------------------------------------------------------------------------------ + +class LRBranch(nn.Module): + """ Low Resolution Branch of MODNet + """ + + def __init__(self, backbone): + super(LRBranch, self).__init__() + + enc_channels = backbone.enc_channels + + self.backbone = backbone + self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4) + self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2) + self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2) + self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False) + + def forward(self, img): + enc_features = self.backbone.forward(img) + enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4] + + enc32x = self.se_block(enc32x) + lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False) + lr16x = self.conv_lr16x(lr16x) + lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False) + lr8x = self.conv_lr8x(lr8x) + + return lr8x, [enc2x, enc4x] + + +class HRBranch(nn.Module): + """ High Resolution Branch of MODNet + """ + + def __init__(self, hr_channels, enc_channels): + super(HRBranch, self).__init__() + + self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0) + self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1) + + self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0) + self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) + + self.conv_hr4x = nn.Sequential( + Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), + ) + + self.conv_hr2x = nn.Sequential( + Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), + ) + + self.conv_hr = nn.Sequential( + Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False), + ) + + def forward(self, img, enc2x, enc4x, lr8x): + img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False) + img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False) + + enc2x = self.tohr_enc2x(enc2x) + hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1)) + + enc4x = self.tohr_enc4x(enc4x) + hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1)) + + lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) + hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1)) + + hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False) + hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1)) + + return hr2x + + +class FusionBranch(nn.Module): + """ Fusion Branch of MODNet + """ + + def __init__(self, hr_channels, enc_channels): + super(FusionBranch, self).__init__() + self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) + + self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) + self.conv_f = nn.Sequential( + Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), + Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False), + ) + + def forward(self, img, lr8x, hr2x): + lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) + lr4x = self.conv_lr4x(lr4x) + lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False) + + f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1)) + f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False) + f = self.conv_f(torch.cat((f, img), dim=1)) + pred_matte = torch.sigmoid(f) + + return pred_matte + + +#------------------------------------------------------------------------------ +# MODNet +#------------------------------------------------------------------------------ + +class MODNet(nn.Module): + """ Architecture of MODNet + """ + + def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True): + super(MODNet, self).__init__() + + self.in_channels = in_channels + self.hr_channels = hr_channels + self.backbone_arch = backbone_arch + self.backbone_pretrained = backbone_pretrained + + self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels) + + self.lr_branch = LRBranch(self.backbone) + self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels) + self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + self._init_conv(m) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): + self._init_norm(m) + + if self.backbone_pretrained: + self.backbone.load_pretrained_ckpt() + + def forward(self, img): + lr8x, [enc2x, enc4x] = self.lr_branch(img) + hr2x = self.hr_branch(img, enc2x, enc4x, lr8x) + pred_matte = self.f_branch(img, lr8x, hr2x) + + return pred_matte + + def freeze_norm(self): + norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d] + for m in self.modules(): + for n in norm_types: + if isinstance(m, n): + m.eval() + continue + + def _init_conv(self, conv): + nn.init.kaiming_uniform_( + conv.weight, a=0, mode='fan_in', nonlinearity='relu') + if conv.bias is not None: + nn.init.constant_(conv.bias, 0) + + def _init_norm(self, norm): + if norm.weight is not None: + nn.init.constant_(norm.weight, 1) + nn.init.constant_(norm.bias, 0)