Merge pull request #65 from manthan3C273/master
Inference with onnx (Image matting)
Showing 5 changed files with 455 additions and 0 deletions.
# Inference with onnxruntime

Please try the MODNet image matting onnx-inference demo in this [Colab Notebook](https://colab.research.google.com/drive/1P3cWtg8fnmu9karZHYDAtmm1vj1rgA-f?usp=sharing).

Download [modnet.onnx](https://drive.google.com/file/d/1cgycTQlYXpTh26gB9FTnthE7AvruV8hd/view?usp=sharing).

### 1. Export onnx model

Run the following command:
```shell
python export_modnet_onnx.py \
    --ckpt-path=pretrained/modnet_photographic_portrait_matting.ckpt \
    --output-path=modnet.onnx
```
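
The exported model uses dynamic axes, with input shape `(batch_size, 3, height, width)` and output shape `(batch_size, 1, height, width)`, so images of different sizes can be fed at inference time.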
### 2. Inference

Run the following command:
```shell
python inference_onnx.py \
    --image-path=PATH_TO_IMAGE \
    --output-path=matte.png \
    --model-path=modnet.onnx
```
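
If you prefer to call the model from your own code instead of running the script, the core of step 2 reduces to a few onnxruntime calls. The sketch below is illustrative rather than part of this PR; it assumes the input has already been preprocessed into a float32 NCHW array normalized to [-1, 1], with height and width that are multiples of 32, exactly as inference_onnx.py (included further down) prepares it.

```python
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("modnet.onnx", None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# placeholder for a preprocessed (1, 3, H, W) float32 image in [-1, 1]
im = np.random.randn(1, 3, 512, 512).astype("float32")

matte = session.run([output_name], {input_name: im})[0]  # (1, 1, H, W), values in [0, 1]
```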
demo/image_matting/Inference_with_ONNX/export_modnet_onnx.py (55 additions, 0 deletions)
""" | ||
Export onnx model | ||
Arguments: | ||
--ckpt-path --> Path of last checkpoint to load | ||
--output-path --> path of onnx model to be saved | ||
example: | ||
python export_modnet_onnx.py \ | ||
--ckpt-path=modnet_photographic_portrait_matting.ckpt \ | ||
--output-path=modnet.onnx | ||
output: | ||
ONNX model with dynamic input shape: (batch_size, 3, height, width) & | ||
output shape: (batch_size, 1, height, width) | ||
""" | ||
import os | ||
import argparse | ||
import torch | ||
import torch.nn as nn | ||
from torch.autograd import Variable | ||
from src.models.onnx_modnet import MODNet | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
# define cmd arguments | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--ckpt-path', type=str, required=True, help='path of pre-trained MODNet') | ||
parser.add_argument('--output-path', type=str, required=True, help='path of output onnx model') | ||
args = parser.parse_args() | ||
|
||
# check input arguments | ||
if not os.path.exists(args.ckpt_path): | ||
print('Cannot find checkpoint path: {0}'.format(args.ckpt_path)) | ||
exit() | ||
|
||
# define model & load checkpoint | ||
modnet = MODNet(backbone_pretrained=False) | ||
modnet = nn.DataParallel(modnet).cuda() | ||
state_dict = torch.load(args.ckpt_path) | ||
modnet.load_state_dict(state_dict) | ||
modnet.eval() | ||
|
||
# prepare dummy_input | ||
batch_size = 1 | ||
height = 512 | ||
width = 512 | ||
dummy_input = Variable(torch.randn(batch_size, 3, height, width)).cuda() | ||
|
||
# export to onnx model | ||
torch.onnx.export(modnet.module, dummy_input, args.output_path, export_params = True, opset_version=11, | ||
input_names = ['input'], output_names = ['output'], | ||
dynamic_axes = {'input': {0:'batch_size', 2:'height', 3:'width'}, | ||
'output': {0: 'batch_size', 2: 'height', 3: 'width'}}) |
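
Not part of this PR, but a quick way to sanity-check the exported file is to load it back with onnx and onnxruntime (both pinned in the dependency list below). A minimal sketch, assuming the model was saved as modnet.onnx:

```python
import onnx
import onnxruntime

# structural check of the exported graph
onnx.checker.check_model(onnx.load("modnet.onnx"))

# confirm the dynamic axes registered during export
session = onnxruntime.InferenceSession("modnet.onnx", None)
print(session.get_inputs()[0].shape)   # expected: ['batch_size', 3, 'height', 'width']
print(session.get_outputs()[0].shape)  # expected: ['batch_size', 1, 'height', 'width']
```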
demo/image_matting/Inference_with_ONNX/inference_onnx.py (116 additions, 0 deletions)
""" | ||
Inference with onnxruntime | ||
Arguments: | ||
--image-path --> path to single input image | ||
--output-path --> paht to save generated matte | ||
--model-path --> path to onnx model file | ||
example: | ||
python inference_onnx.py \ | ||
--image-path=demo.jpg \ | ||
--output-path=matte.png \ | ||
--model-path=modnet.onnx | ||
Optional: | ||
Generate transparent image without background | ||
""" | ||
import os | ||
import argparse | ||
import cv2 | ||
import numpy as np | ||
import onnx | ||
import onnxruntime | ||
from onnx import helper | ||
from PIL import Image | ||
|
||
if __name__ == '__main__': | ||
# define cmd arguments | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--image-path', type=str, help='path of input image') | ||
parser.add_argument('--output-path', type=str, help='path of output image') | ||
parser.add_argument('--model-path', type=str, help='path of onnx model') | ||
args = parser.parse_args() | ||
|
||
# check input arguments | ||
if not os.path.exists(args.image_path): | ||
print('Cannot find input path: {0}'.format(args.image_path)) | ||
exit() | ||
if not os.path.exists(args.model_path): | ||
print('Cannot find model path: {0}'.format(args.model_path)) | ||
exit() | ||
|
||
ref_size = 512 | ||
|
||
# Get x_scale_factor & y_scale_factor to resize image | ||
def get_scale_factor(im_h, im_w, ref_size): | ||
|
||
if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size: | ||
if im_w >= im_h: | ||
im_rh = ref_size | ||
im_rw = int(im_w / im_h * ref_size) | ||
elif im_w < im_h: | ||
im_rw = ref_size | ||
im_rh = int(im_h / im_w * ref_size) | ||
else: | ||
im_rh = im_h | ||
im_rw = im_w | ||
|
||
im_rw = im_rw - im_rw % 32 | ||
im_rh = im_rh - im_rh % 32 | ||
|
||
x_scale_factor = im_rw / im_w | ||
y_scale_factor = im_rh / im_h | ||
|
||
return x_scale_factor, y_scale_factor | ||
|
||
############################################## | ||
# Main Inference part | ||
############################################## | ||
|
||
# read image | ||
im = cv2.imread(args.image_path) | ||
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) | ||
|
||
# unify image channels to 3 | ||
if len(im.shape) == 2: | ||
im = im[:, :, None] | ||
if im.shape[2] == 1: | ||
im = np.repeat(im, 3, axis=2) | ||
elif im.shape[2] == 4: | ||
im = im[:, :, 0:3] | ||
|
||
# normalize values to scale it between -1 to 1 | ||
im = (im - 127.5) / 127.5 | ||
|
||
im_h, im_w, im_c = im.shape | ||
x, y = get_scale_factor(im_h, im_w, ref_size) | ||
|
||
# resize image | ||
im = cv2.resize(im, None, fx = x, fy = y, interpolation = cv2.INTER_AREA) | ||
|
||
# prepare input shape | ||
im = np.transpose(im) | ||
im = np.swapaxes(im, 1, 2) | ||
im = np.expand_dims(im, axis = 0).astype('float32') | ||
|
||
# Initialize session and get prediction | ||
session = onnxruntime.InferenceSession(args.model_path, None) | ||
input_name = session.get_inputs()[0].name | ||
output_name = session.get_outputs()[0].name | ||
result = session.run([output_name], {input_name: im}) | ||
|
||
# refine matte | ||
matte = (np.squeeze(result[0]) * 255).astype('uint8') | ||
matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation = cv2.INTER_AREA) | ||
|
||
cv2.imwrite(args.output_path, matte) | ||
|
||
############################################## | ||
# Optional - save png image without background | ||
############################################## | ||
|
||
# im_PIL = Image.open(args.image_path) | ||
# matte = Image.fromarray(matte) | ||
# im_PIL.putalpha(matte) # add alpha channel to keep transparency | ||
# im_PIL.save('without_background.png') |
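
The commented-out optional block above can also be run as a standalone step on the files the script writes. A minimal sketch, with illustrative file names taken from the example commands:

```python
from PIL import Image

image = Image.open("demo.jpg")
matte = Image.open("matte.png").convert("L")   # single-channel alpha matte

# attach the matte as an alpha channel to get a transparent cut-out
image.putalpha(matte.resize(image.size))
image.save("without_background.png")
```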
Pinned dependencies for the demo (4 additions, 0 deletions)
onnx==1.8.1
onnxruntime==1.6.0
opencv-python==4.5.1.48
torch==1.7.1