Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Jun 21, 2024
1 parent d59f900 commit 28a9eea
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 9 deletions.
2 changes: 1 addition & 1 deletion tests/test_onnx_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,4 @@ def test_timm(self, request, model_name):

if __name__ == "__main__":
warnings.filterwarnings("ignore")
pytest.main(["-p", "no:warnings", "-n", "10", "-v", "tests/test_onnx_nets.py"])
pytest.main(["-p", "no:warnings", "-v", "tests/test_onnx_nets.py"])
5 changes: 4 additions & 1 deletion tests/test_pattern_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def rewrite(self):
raise Exception("Pattern Matched")

register_fusion_pattern(GeluMatcher(pattern, 1))
slim(model_filename, f"{directory}/{request.node.name}_slim.onnx")
with pytest.raises(Exception) as excinfo:
slim(model_filename, f"{directory}/{request.node.name}_slim.onnx")

assert str(excinfo.value) == "Pattern Matched"


if __name__ == "__main__":
Expand Down
36 changes: 29 additions & 7 deletions tests/test_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.nn as nn

from onnxslim import slim
from onnxslim.utils import summarize_model, print_model_info_as_table


class TestPatternMatcher:
Expand Down Expand Up @@ -35,7 +36,9 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename)

slim(filename, filename)
summary = summarize_model(slim(filename))
print_model_info_as_table(request.node.name, [summary])


def test_pad_conv(self, request):
"""Test padding followed by 2D convolution within a neural network module."""
Expand Down Expand Up @@ -67,7 +70,12 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename)

slim(filename, filename)
summary = summarize_model(slim(filename))
print_model_info_as_table(request.node.name, [summary])

assert summary["op_type_counts"]['Conv'] == 2
assert summary["op_type_counts"]['Add'] == 1


def test_conv_bn(self, request):
"""Test the convolutional layer followed by batch normalization export and re-import via ONNX."""
Expand All @@ -92,7 +100,10 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename, do_constant_folding=False)

slim(filename, filename)
summary = summarize_model(slim(filename))
print_model_info_as_table(request.node.name, [summary])
assert summary["op_type_counts"]['Conv'] == 1


def test_consecutive_slice(self, request):
"""Tests consecutive slicing operations on a model by exporting it to ONNX format and then slimming the ONNX
Expand All @@ -117,7 +128,10 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename)

slim(filename, filename)
summary = summarize_model(slim(filename))
print_model_info_as_table(request.node.name, [summary])
assert summary["op_type_counts"]['Slice'] == 1


def test_consecutive_reshape(self, request):
"""Test the functionality of consecutive reshape operations in a model and export it to ONNX format."""
Expand All @@ -138,7 +152,10 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename)

slim(filename, filename)
summary = summarize_model(slim(filename))
print_model_info_as_table(request.node.name, [summary])
assert summary["op_type_counts"]['Reshape'] == 1


def test_matmul_add(self, request):
"""Tests matrix multiplication followed by an addition operation within a neural network model."""
Expand All @@ -162,7 +179,10 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename)

slim(filename, filename)
summary = summarize_model(slim(filename))
print_model_info_as_table(request.node.name, [summary])
assert summary["op_type_counts"]['Gemm'] == 1


def test_reduce(self, request):
"""Tests model reduction by exporting a PyTorch model to ONNX format, slimming it, and saving to a specified
Expand All @@ -189,7 +209,9 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename, opset_version=11)

slim(filename, filename)
summary = summarize_model(slim(filename))
print_model_info_as_table(request.node.name, [summary])
assert summary["op_type_counts"]['ReduceSum'] == 1


if __name__ == "__main__":
Expand Down

0 comments on commit 28a9eea

Please sign in to comment.