From a3798ecdcb4971a7b31e89fcfadf6189d9e979e2 Mon Sep 17 00:00:00 2001 From: Judyxujj Date: Fri, 22 Dec 2023 15:54:04 +0100 Subject: [PATCH] extern data size configurable --- tools/torch_export_to_onnx.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tools/torch_export_to_onnx.py b/tools/torch_export_to_onnx.py index 5e54b58d8..9e106c44f 100644 --- a/tools/torch_export_to_onnx.py +++ b/tools/torch_export_to_onnx.py @@ -182,6 +182,13 @@ def main(): parser.add_argument("out_onnx_filename", type=str, help="Filename of the final ONNX model.") parser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)") parser.add_argument("--device", type=str, default="cpu", help="'cpu' (default) or 'gpu'.") + parser.add_argument( + "--dyn_dim_min_sizes", type=dict, default=None, help="Specify min sizes for dim tags with dynamic sizes" + ) + parser.add_argument( + "--dyn_dim_max_sizes", type=dict, default=None, help="Specify max sizes for dim tags with dynamic sizes" + ) + args = parser.parse_args() init(config_filename=args.config, checkpoint=args.checkpoint, log_verbosity=args.verbosity, device=args.device) @@ -223,7 +230,9 @@ def main(): if not v.available_for_inference: del extern_data.data[k] - tensor_dict_fill_random_numpy_(extern_data) + tensor_dict_fill_random_numpy_( + extern_data, dyn_dim_max_sizes=args.dyn_dim_max_sizes, dyn_dim_min_sizes=args.dyn_dim_min_sizes + ) tensor_dict_numpy_to_torch_(extern_data) extern_data_raw = extern_data.as_raw_tensor_dict(include_scalar_dyn_sizes=False, exclude_duplicate_dims=True) model_outputs_raw_keys = _get_model_outputs_raw_keys()