From 28a9eea2842fbe9e2a21cfc10a86819cf680686a Mon Sep 17 00:00:00 2001 From: inisis Date: Fri, 21 Jun 2024 17:27:09 +0000 Subject: [PATCH] update test --- tests/test_onnx_nets.py | 2 +- tests/test_pattern_generator.py | 5 ++++- tests/test_pattern_matcher.py | 36 ++++++++++++++++++++++++++------- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/tests/test_onnx_nets.py b/tests/test_onnx_nets.py index f45d4d5..9daf160 100644 --- a/tests/test_onnx_nets.py +++ b/tests/test_onnx_nets.py @@ -82,4 +82,4 @@ def test_timm(self, request, model_name): if __name__ == "__main__": warnings.filterwarnings("ignore") - pytest.main(["-p", "no:warnings", "-n", "10", "-v", "tests/test_onnx_nets.py"]) + pytest.main(["-p", "no:warnings", "-v", "tests/test_onnx_nets.py"]) diff --git a/tests/test_pattern_generator.py b/tests/test_pattern_generator.py index 9fd61c6..fc43169 100644 --- a/tests/test_pattern_generator.py +++ b/tests/test_pattern_generator.py @@ -72,7 +72,10 @@ def rewrite(self): raise Exception("Pattern Matched") register_fusion_pattern(GeluMatcher(pattern, 1)) - slim(model_filename, f"{directory}/{request.node.name}_slim.onnx") + with pytest.raises(Exception) as excinfo: + slim(model_filename, f"{directory}/{request.node.name}_slim.onnx") + + assert str(excinfo.value) == "Pattern Matched" if __name__ == "__main__": diff --git a/tests/test_pattern_matcher.py b/tests/test_pattern_matcher.py index 9e58618..19787db 100644 --- a/tests/test_pattern_matcher.py +++ b/tests/test_pattern_matcher.py @@ -5,6 +5,7 @@ import torch.nn as nn from onnxslim import slim +from onnxslim.utils import summarize_model, print_model_info_as_table class TestPatternMatcher: @@ -35,7 +36,9 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename) - slim(filename, filename) + summary = summarize_model(slim(filename)) + print_model_info_as_table(request.node.name, [summary]) + def test_pad_conv(self, request): """Test padding followed by 2D convolution within a neural network module.""" @@ -67,7 +70,12 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename) - slim(filename, filename) + summary = summarize_model(slim(filename)) + print_model_info_as_table(request.node.name, [summary]) + + assert summary["op_type_counts"]['Conv'] == 2 + assert summary["op_type_counts"]['Add'] == 1 + def test_conv_bn(self, request): """Test the convolutional layer followed by batch normalization export and re-import via ONNX.""" @@ -92,7 +100,10 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename, do_constant_folding=False) - slim(filename, filename) + summary = summarize_model(slim(filename)) + print_model_info_as_table(request.node.name, [summary]) + assert summary["op_type_counts"]['Conv'] == 1 + def test_consecutive_slice(self, request): """Tests consecutive slicing operations on a model by exporting it to ONNX format and then slimming the ONNX @@ -117,7 +128,10 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename) - slim(filename, filename) + summary = summarize_model(slim(filename)) + print_model_info_as_table(request.node.name, [summary]) + assert summary["op_type_counts"]['Slice'] == 1 + def test_consecutive_reshape(self, request): """Test the functionality of consecutive reshape operations in a model and export it to ONNX format.""" @@ -138,7 +152,10 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename) - slim(filename, filename) + summary = summarize_model(slim(filename)) + print_model_info_as_table(request.node.name, [summary]) + assert summary["op_type_counts"]['Reshape'] == 1 + def test_matmul_add(self, request): """Tests matrix multiplication followed by an addition operation within a neural network model.""" @@ -162,7 +179,10 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename) - slim(filename, filename) + summary = summarize_model(slim(filename)) + print_model_info_as_table(request.node.name, [summary]) + assert summary["op_type_counts"]['Gemm'] == 1 + def test_reduce(self, request): """Tests model reduction by exporting a PyTorch model to ONNX format, slimming it, and saving to a specified @@ -189,7 +209,9 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename, opset_version=11) - slim(filename, filename) + summary = summarize_model(slim(filename)) + print_model_info_as_table(request.node.name, [summary]) + assert summary["op_type_counts"]['ReduceSum'] == 1 if __name__ == "__main__":