From 979a716c9d79c7d365ccc56e540ff8bb41cd7a10 Mon Sep 17 00:00:00 2001 From: Mathijs de Boer Date: Fri, 1 Dec 2023 17:09:42 +0100 Subject: [PATCH] Switch from dynamo export to regular export Dynamo seems to be a bit buggy still --- nnunetv2/model_sharing/entry_points.py | 8 +++++ nnunetv2/model_sharing/onnx_export.py | 50 +++++++++++++++++++++----- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/nnunetv2/model_sharing/entry_points.py b/nnunetv2/model_sharing/entry_points.py index a7ce98cb5..8d356adc1 100644 --- a/nnunetv2/model_sharing/entry_points.py +++ b/nnunetv2/model_sharing/entry_points.py @@ -160,6 +160,13 @@ def export_pretrained_model_onnx_entry(): required=False, help="Set this to export the cross-validation predictions as well", ) + parser.add_argument( + "-v", + action="store_false", + default=False, + required=False, + help="Set this to get verbose output", + ) args = parser.parse_args() print("######################################################") @@ -191,4 +198,5 @@ def export_pretrained_model_onnx_entry(): folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk, + verbose=args.v, ) diff --git a/nnunetv2/model_sharing/onnx_export.py b/nnunetv2/model_sharing/onnx_export.py index 872e3937c..b20379393 100644 --- a/nnunetv2/model_sharing/onnx_export.py +++ b/nnunetv2/model_sharing/onnx_export.py @@ -3,6 +3,9 @@ from pathlib import Path from typing import Tuple, Union +import numpy as np +import onnx +import onnxruntime import torch from torch._dynamo import OptimizedModule @@ -27,6 +30,7 @@ def export_onnx_model( strict: bool = True, save_checkpoints: Tuple[str, ...] = ("checkpoint_final.pth",), output_names: tuple[str, ...] = None, + verbose: bool = False, ) -> None: if not output_names: output_names = (f"{checkpoint[:-4]}.onnx" for checkpoint in save_checkpoints) @@ -71,14 +75,6 @@ def export_onnx_model( network.eval() - export_options = torch.onnx.ExportOptions(dynamic_shapes=True) - rand_input = torch.rand((1, 1, *config.patch_size)) - traced_model = torch.onnx.dynamo_export( - network, - rand_input, - export_options=export_options, - ) - curr_output_dir = output_dir / c / f"fold_{fold}" if not curr_output_dir.exists(): curr_output_dir.mkdir(parents=True) @@ -88,7 +84,43 @@ def export_onnx_model( f"Output directory {curr_output_dir} is not empty" ) - traced_model.save(str(curr_output_dir / output_name)) + rand_input = torch.rand((1, 1, *config.patch_size)) + torch_output = network(rand_input) + + torch.onnx.export( + network, + rand_input, + curr_output_dir / output_name, + export_params=True, + verbose=verbose, + input_names=["input"], + output_names=["output"], + dynamic_axes={ + "input": {0: "batch_size"}, + "output": {0: "batch_size"}, + }, + ) + + onnx_model = onnx.load(curr_output_dir / output_name) + onnx.checker.check_model(onnx_model) + + ort_session = onnxruntime.InferenceSession( + curr_output_dir / output_name, providers=["CPUExecutionProvider"] + ) + ort_inputs = {ort_session.get_inputs()[0].name: rand_input.numpy()} + ort_outs = ort_session.run(None, ort_inputs) + + np.testing.assert_allclose( + torch_output.detach().cpu().numpy(), + ort_outs[0], + rtol=1e-03, + atol=1e-05, + ) + + print( + f"Successfully exported and verified {curr_output_dir / output_name}" + ) + with open(curr_output_dir / "config.json", "w") as f: json.dump( {