diff --git a/tests/test_pattern_matcher.py b/tests/test_pattern_matcher.py index dacda48..49cb31f 100644 --- a/tests/test_pattern_matcher.py +++ b/tests/test_pattern_matcher.py @@ -36,7 +36,7 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename) - summary = summarize_model(slim(filename)) + summary = summarize_model(slim(filename, model_check=True)) print_model_info_as_table(request.node.name, [summary]) def test_pad_conv(self, request): @@ -69,7 +69,7 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename) - summary = summarize_model(slim(filename)) + summary = summarize_model(slim(filename, model_check=True)) print_model_info_as_table(request.node.name, [summary]) assert summary["op_type_counts"]["Conv"] == 2 @@ -98,7 +98,7 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename, do_constant_folding=False) - summary = summarize_model(slim(filename)) + summary = summarize_model(slim(filename, model_check=True)) print_model_info_as_table(request.node.name, [summary]) assert summary["op_type_counts"]["Conv"] == 1 @@ -125,7 +125,7 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename) - summary = summarize_model(slim(filename)) + summary = summarize_model(slim(filename, model_check=True)) print_model_info_as_table(request.node.name, [summary]) assert summary["op_type_counts"]["Slice"] == 1 @@ -148,7 +148,7 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename) - summary = summarize_model(slim(filename)) + summary = summarize_model(slim(filename, model_check=True)) print_model_info_as_table(request.node.name, [summary]) assert summary["op_type_counts"]["Reshape"] == 1 @@ -174,7 +174,7 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename) - summary = summarize_model(slim(filename)) + summary = summarize_model(slim(filename, model_check=True)) print_model_info_as_table(request.node.name, [summary]) assert summary["op_type_counts"]["Gemm"] == 1 @@ -203,7 +203,7 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename, opset_version=11) - summary = summarize_model(slim(filename)) + summary = summarize_model(slim(filename, model_check=True)) print_model_info_as_table(request.node.name, [summary]) assert summary["op_type_counts"]["ReduceSum"] == 1 @@ -234,7 +234,7 @@ def forward(self, x): filename = f"{directory}/{request.node.name}.onnx" torch.onnx.export(m, input, filename, opset_version=opset) - summary = summarize_model(slim(filename)) + summary = summarize_model(slim(filename, model_check=True)) print_model_info_as_table(request.node.name, [summary]) assert summary["op_type_counts"]["Unsqueeze"] == 1