Commit

Add onnx tutorial (#990)
* fixup

* Remove from tracing shape check

* Add notebook with example

* Update readme examples

* Fixup
qubvel authored Nov 29, 2024
1 parent 076f684 commit 737b24f
Showing 3 changed files with 226 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -81,6 +81,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
- Training model for cars segmentation on CamVid dataset [here](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/cars%20segmentation%20(camvid).ipynb).
- Training SMP model with [Catalyst](https://github.com/catalyst-team/catalyst) (high-level framework for PyTorch), [TTAch](https://github.com/qubvel/ttach) (TTA library for PyTorch) and [Albumentations](https://github.com/albu/albumentations) (fast image augmentation library) - [here](https://github.com/catalyst-team/catalyst/blob/v21.02rc0/examples/notebooks/segmentation-tutorial.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/v21.02rc0/examples/notebooks/segmentation-tutorial.ipynb)
- Training SMP model with [Pytorch-Lightning](https://pytorch-lightning.readthedocs.io) framework - [here](https://github.com/ternaus/cloths_segmentation) (clothes binary segmentation by [@ternaus](https://github.com/ternaus)).
- Export trained model to ONNX - [notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb)

### 📦 Models <a name="models"></a>

223 changes: 223 additions & 0 deletions examples/convert_to_onnx.ipynb
@@ -0,0 +1,223 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# to make onnx export work\n",
"!pip install onnx onnxruntime"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See complete tutorial in Pytorch docs:\n",
" - https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import onnx\n",
"import onnxruntime\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import segmentation_models_pytorch as smp"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create random model (or load your own model)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"model = smp.Unet(\"resnet34\", encoder_weights=\"imagenet\", classes=1)\n",
"model = model.eval()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Export the model to ONNX"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# dynamic_axes is used to specify the variable length axes. it can be just batch size\n",
"dynamic_axes = {0: \"batch_size\", 2: \"height\", 3: \"width\"}\n",
"\n",
"onnx_model_name = \"unet_resnet34.onnx\"\n",
"\n",
"onnx_model = torch.onnx.export(\n",
" model, # model being run\n",
" torch.randn(1, 3, 224, 224), # model input\n",
" onnx_model_name, # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=17, # the ONNX version to export\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=[\"input\"], # the model's input names\n",
" output_names=[\"output\"], # the model's output names\n",
" dynamic_axes={ # variable length axes\n",
" \"input\": dynamic_axes,\n",
" \"output\": dynamic_axes,\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# check with onnx first\n",
"onnx_model = onnx.load(onnx_model_name)\n",
"onnx.checker.check_model(onnx_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run with onnxruntime"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([[[[-1.41701847e-01, -4.63768840e-03, 1.21411584e-01, ...,\n",
" 5.22197843e-01, 3.40217263e-01, 8.52423906e-02],\n",
" [-2.29843616e-01, 2.19401851e-01, 3.53053480e-01, ...,\n",
" 2.79466838e-01, 3.20288718e-01, -2.22393833e-02],\n",
" [-3.12503517e-01, -3.66358161e-02, 1.19251609e-02, ...,\n",
" -5.48991561e-02, 3.71140465e-02, -1.82842150e-01],\n",
" ...,\n",
" [-3.02772015e-01, -4.22928065e-01, -1.49621412e-01, ...,\n",
" -1.42241001e-01, -9.90390778e-02, -1.33311331e-01],\n",
" [-1.08293816e-01, -1.28070369e-01, -5.43620177e-02, ...,\n",
" -8.64556879e-02, -1.74177170e-01, 6.03154302e-03],\n",
" [-1.29619062e-01, -2.96604559e-02, -2.86361389e-03, ...,\n",
" -1.91345289e-01, -1.82653710e-01, 1.17175849e-02]]],\n",
" \n",
" \n",
" [[[-6.16237633e-02, 1.12350248e-01, 1.59193069e-01, ...,\n",
" 4.03313845e-01, 2.26862252e-01, 7.33022243e-02],\n",
" [-1.60109222e-01, 1.21696621e-01, 1.84655115e-01, ...,\n",
" 1.20978586e-01, 2.45723248e-01, 1.00066036e-01],\n",
" [-2.11992145e-01, 1.71708465e-02, -1.57656223e-02, ...,\n",
" -1.11918494e-01, -1.64519548e-01, -1.73958957e-01],\n",
" ...,\n",
" [-2.79706120e-01, -2.87421644e-01, -5.19880295e-01, ...,\n",
" -8.30744207e-02, -3.48939300e-02, 1.26617640e-01],\n",
" [-2.62198627e-01, -2.91804910e-01, -2.82318443e-01, ...,\n",
" 1.81179233e-02, 2.32534595e-02, 1.85002953e-01],\n",
" [-9.28771719e-02, -5.16399741e-05, -9.53909755e-03, ...,\n",
" -2.28582099e-02, -5.09671569e-02, 2.05268264e-02]]]],\n",
" dtype=float32)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# create sample with different batch size, height and width\n",
"# from what we used in export above\n",
"sample = torch.randn(2, 3, 512, 512)\n",
"\n",
"ort_session = onnxruntime.InferenceSession(\n",
" onnx_model_name, providers=[\"CPUExecutionProvider\"]\n",
")\n",
"\n",
"# compute ONNX Runtime output prediction\n",
"ort_inputs = {\"input\": sample.numpy()}\n",
"ort_outputs = ort_session.run(output_names=None, input_feed=ort_inputs)\n",
"ort_outputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Verify it's the same as for pytorch model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exported model has been tested with ONNXRuntime, and the result looks good!\n"
]
}
],
"source": [
"# compute PyTorch output prediction\n",
"with torch.no_grad():\n",
" torch_out = model(sample)\n",
"\n",
"# compare ONNX Runtime and PyTorch results\n",
"np.testing.assert_allclose(torch_out.numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)\n",
"\n",
"print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
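A quick way to confirm that the `dynamic_axes` above were applied is to inspect the exported graph's input and output signatures with the `onnx` package. Axes listed in `dynamic_axes` should appear as symbolic names rather than the concrete sizes of the dummy input used for export. This is a minimal sketch, assuming the `unet_resnet34.onnx` file written by the notebook above:

```python
import onnx

onnx_model = onnx.load("unet_resnet34.onnx")

# dynamic axes show up as symbolic names (dim_param) instead of fixed sizes (dim_value)
for tensor in list(onnx_model.graph.input) + list(onnx_model.graph.output):
    dims = [d.dim_param or d.dim_value for d in tensor.type.tensor_type.shape.dim]
    print(tensor.name, dims)

# expected to print something along the lines of:
#   input  ['batch_size', 3, 'height', 'width']
#   output ['batch_size', 1, 'height', 'width']
```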
3 changes: 2 additions & 1 deletion segmentation_models_pytorch/base/model.py
@@ -33,7 +33,8 @@ def check_input_shape(self, x):
def forward(self, x):
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""

self.check_input_shape(x)
if not torch.jit.is_tracing():
self.check_input_shape(x)

features = self.encoder(x)
decoder_output = self.decoder(*features)
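The guard above skips the Python-side shape check while the model is being traced (as `torch.onnx.export` does), so the data-dependent check does not interfere with tracing for export. Below is a minimal sketch of the same guard pattern; `check_input_shape` and `output_stride` here are simplified stand-ins, not the library's exact implementation:

```python
import torch
import torch.nn as nn


class TracingAwareModel(nn.Module):
    # hypothetical constraint: spatial dims must be divisible by the encoder's output stride
    output_stride = 32

    def check_input_shape(self, x):
        # simplified stand-in for the library's check: fail early on unsupported shapes
        h, w = x.shape[-2:]
        if h % self.output_stride != 0 or w % self.output_stride != 0:
            raise RuntimeError(
                f"Input size {h}x{w} must be divisible by {self.output_stride}"
            )

    def forward(self, x):
        # skip the Python-side check while torch.jit / torch.onnx records the graph
        if not torch.jit.is_tracing():
            self.check_input_shape(x)
        return x
```

With this guard, tracing-based export such as `torch.onnx.export(TracingAwareModel(), torch.randn(1, 3, 224, 224), "model.onnx")` proceeds cleanly, while the eager-mode check still protects regular inference.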
