-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* fixup * Remove from tracing shape check * Add notebook with example * Update readme examples * Fixup
- Loading branch information
Showing
3 changed files
with
226 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# to make onnx export work\n", | ||
"!pip install onnx onnxruntime" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"See complete tutorial in Pytorch docs:\n", | ||
" - https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import onnx\n", | ||
"import onnxruntime\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"import torch\n", | ||
"import segmentation_models_pytorch as smp" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Create random model (or load your own model)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model = smp.Unet(\"resnet34\", encoder_weights=\"imagenet\", classes=1)\n", | ||
"model = model.eval()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Export the model to ONNX" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# dynamic_axes is used to specify the variable length axes. it can be just batch size\n", | ||
"dynamic_axes = {0: \"batch_size\", 2: \"height\", 3: \"width\"}\n", | ||
"\n", | ||
"onnx_model_name = \"unet_resnet34.onnx\"\n", | ||
"\n", | ||
"onnx_model = torch.onnx.export(\n", | ||
" model, # model being run\n", | ||
" torch.randn(1, 3, 224, 224), # model input\n", | ||
" onnx_model_name, # where to save the model (can be a file or file-like object)\n", | ||
" export_params=True, # store the trained parameter weights inside the model file\n", | ||
" opset_version=17, # the ONNX version to export\n", | ||
" do_constant_folding=True, # whether to execute constant folding for optimization\n", | ||
" input_names=[\"input\"], # the model's input names\n", | ||
" output_names=[\"output\"], # the model's output names\n", | ||
" dynamic_axes={ # variable length axes\n", | ||
" \"input\": dynamic_axes,\n", | ||
" \"output\": dynamic_axes,\n", | ||
" },\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# check with onnx first\n", | ||
"onnx_model = onnx.load(onnx_model_name)\n", | ||
"onnx.checker.check_model(onnx_model)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Run with onnxruntime" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[array([[[[-1.41701847e-01, -4.63768840e-03, 1.21411584e-01, ...,\n", | ||
" 5.22197843e-01, 3.40217263e-01, 8.52423906e-02],\n", | ||
" [-2.29843616e-01, 2.19401851e-01, 3.53053480e-01, ...,\n", | ||
" 2.79466838e-01, 3.20288718e-01, -2.22393833e-02],\n", | ||
" [-3.12503517e-01, -3.66358161e-02, 1.19251609e-02, ...,\n", | ||
" -5.48991561e-02, 3.71140465e-02, -1.82842150e-01],\n", | ||
" ...,\n", | ||
" [-3.02772015e-01, -4.22928065e-01, -1.49621412e-01, ...,\n", | ||
" -1.42241001e-01, -9.90390778e-02, -1.33311331e-01],\n", | ||
" [-1.08293816e-01, -1.28070369e-01, -5.43620177e-02, ...,\n", | ||
" -8.64556879e-02, -1.74177170e-01, 6.03154302e-03],\n", | ||
" [-1.29619062e-01, -2.96604559e-02, -2.86361389e-03, ...,\n", | ||
" -1.91345289e-01, -1.82653710e-01, 1.17175849e-02]]],\n", | ||
" \n", | ||
" \n", | ||
" [[[-6.16237633e-02, 1.12350248e-01, 1.59193069e-01, ...,\n", | ||
" 4.03313845e-01, 2.26862252e-01, 7.33022243e-02],\n", | ||
" [-1.60109222e-01, 1.21696621e-01, 1.84655115e-01, ...,\n", | ||
" 1.20978586e-01, 2.45723248e-01, 1.00066036e-01],\n", | ||
" [-2.11992145e-01, 1.71708465e-02, -1.57656223e-02, ...,\n", | ||
" -1.11918494e-01, -1.64519548e-01, -1.73958957e-01],\n", | ||
" ...,\n", | ||
" [-2.79706120e-01, -2.87421644e-01, -5.19880295e-01, ...,\n", | ||
" -8.30744207e-02, -3.48939300e-02, 1.26617640e-01],\n", | ||
" [-2.62198627e-01, -2.91804910e-01, -2.82318443e-01, ...,\n", | ||
" 1.81179233e-02, 2.32534595e-02, 1.85002953e-01],\n", | ||
" [-9.28771719e-02, -5.16399741e-05, -9.53909755e-03, ...,\n", | ||
" -2.28582099e-02, -5.09671569e-02, 2.05268264e-02]]]],\n", | ||
" dtype=float32)]" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# create sample with different batch size, height and width\n", | ||
"# from what we used in export above\n", | ||
"sample = torch.randn(2, 3, 512, 512)\n", | ||
"\n", | ||
"ort_session = onnxruntime.InferenceSession(\n", | ||
" onnx_model_name, providers=[\"CPUExecutionProvider\"]\n", | ||
")\n", | ||
"\n", | ||
"# compute ONNX Runtime output prediction\n", | ||
"ort_inputs = {\"input\": sample.numpy()}\n", | ||
"ort_outputs = ort_session.run(output_names=None, input_feed=ort_inputs)\n", | ||
"ort_outputs" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Verify it's the same as for pytorch model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Exported model has been tested with ONNXRuntime, and the result looks good!\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# compute PyTorch output prediction\n", | ||
"with torch.no_grad():\n", | ||
" torch_out = model(sample)\n", | ||
"\n", | ||
"# compare ONNX Runtime and PyTorch results\n", | ||
"np.testing.assert_allclose(torch_out.numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)\n", | ||
"\n", | ||
"print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters