Skip to content

Commit

Permalink
enable model_check
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Jul 31, 2024
1 parent 34efce2 commit 1ab514e
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/test_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename)

summary = summarize_model(slim(filename))
summary = summarize_model(slim(filename, model_check=True))
print_model_info_as_table(request.node.name, [summary])

def test_pad_conv(self, request):
Expand Down Expand Up @@ -69,7 +69,7 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename)

summary = summarize_model(slim(filename))
summary = summarize_model(slim(filename, model_check=True))
print_model_info_as_table(request.node.name, [summary])

assert summary["op_type_counts"]["Conv"] == 2
Expand Down Expand Up @@ -98,7 +98,7 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename, do_constant_folding=False)

summary = summarize_model(slim(filename))
summary = summarize_model(slim(filename, model_check=True))
print_model_info_as_table(request.node.name, [summary])
assert summary["op_type_counts"]["Conv"] == 1

Expand All @@ -125,7 +125,7 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename)

summary = summarize_model(slim(filename))
summary = summarize_model(slim(filename, model_check=True))
print_model_info_as_table(request.node.name, [summary])
assert summary["op_type_counts"]["Slice"] == 1

Expand All @@ -148,7 +148,7 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename)

summary = summarize_model(slim(filename))
summary = summarize_model(slim(filename, model_check=True))
print_model_info_as_table(request.node.name, [summary])
assert summary["op_type_counts"]["Reshape"] == 1

Expand All @@ -174,7 +174,7 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename)

summary = summarize_model(slim(filename))
summary = summarize_model(slim(filename, model_check=True))
print_model_info_as_table(request.node.name, [summary])
assert summary["op_type_counts"]["Gemm"] == 1

Expand Down Expand Up @@ -203,7 +203,7 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename, opset_version=11)

summary = summarize_model(slim(filename))
summary = summarize_model(slim(filename, model_check=True))
print_model_info_as_table(request.node.name, [summary])
assert summary["op_type_counts"]["ReduceSum"] == 1

Expand Down Expand Up @@ -234,7 +234,7 @@ def forward(self, x):
filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, filename, opset_version=opset)

summary = summarize_model(slim(filename))
summary = summarize_model(slim(filename, model_check=True))
print_model_info_as_table(request.node.name, [summary])
assert summary["op_type_counts"]["Unsqueeze"] == 1

Expand Down

0 comments on commit 1ab514e

Please sign in to comment.