From c492ffd624f2e498b751522f27fddde714b7e226 Mon Sep 17 00:00:00 2001 From: inisis Date: Fri, 15 Nov 2024 00:27:58 +0800 Subject: [PATCH] fix argparser --- onnxslim/argparser.py | 9 +++++++-- onnxslim/cli/_main.py | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/onnxslim/argparser.py b/onnxslim/argparser.py index b5a3a1a..769b909 100644 --- a/onnxslim/argparser.py +++ b/onnxslim/argparser.py @@ -1,4 +1,6 @@ +import sys import argparse +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter import dataclasses from dataclasses import dataclass, field from typing import List, Optional, Type @@ -109,8 +111,11 @@ class CheckerArguments: verbose: bool = field(default=False, metadata={"help": "verbose mode, default False."}) -class ArgumentParser: - def __init__(self, *argument_dataclasses: Type): +class OnnxSlimArgumentParser(ArgumentParser): + def __init__(self, *argument_dataclasses: Type, **kwargs): + if "formatter_class" not in kwargs: + kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter + super().__init__(**kwargs) self.argument_dataclasses = argument_dataclasses self.parser = argparse.ArgumentParser( description="OnnxSlim: A Toolkit to Help Optimizer Onnx Model", diff --git a/onnxslim/cli/_main.py b/onnxslim/cli/_main.py index 6bcbbd2..f02460f 100644 --- a/onnxslim/cli/_main.py +++ b/onnxslim/cli/_main.py @@ -123,14 +123,14 @@ def slim(model: Union[str, onnx.ModelProto], *args, **kwargs): def main(): """Entry point for the OnnxSlim toolkit, processes command-line arguments and passes them to the slim function.""" from onnxslim.argparser import ( - ArgumentParser, + OnnxSlimArgumentParser, CheckerArguments, ModelArguments, ModificationArguments, OptimizationArguments, ) - argument_parser = ArgumentParser(ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments) + argument_parser = OnnxSlimArgumentParser(ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments) model_args, optimization_args, modification_args, checker_args = argument_parser.parse_args_into_dataclasses() if checker_args.inspect and model_args.output_model: