diff --git a/onnxslim/argparser.py b/onnxslim/argparser.py index b5a3a1a..769b909 100644 --- a/onnxslim/argparser.py +++ b/onnxslim/argparser.py @@ -1,4 +1,6 @@ +import sys import argparse +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter import dataclasses from dataclasses import dataclass, field from typing import List, Optional, Type @@ -109,8 +111,11 @@ class CheckerArguments: verbose: bool = field(default=False, metadata={"help": "verbose mode, default False."}) -class ArgumentParser: - def __init__(self, *argument_dataclasses: Type): +class OnnxSlimArgumentParser(ArgumentParser): + def __init__(self, *argument_dataclasses: Type, **kwargs): + if "formatter_class" not in kwargs: + kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter + super().__init__(**kwargs) self.argument_dataclasses = argument_dataclasses self.parser = argparse.ArgumentParser( description="OnnxSlim: A Toolkit to Help Optimizer Onnx Model", diff --git a/onnxslim/cli/_main.py b/onnxslim/cli/_main.py index 6bcbbd2..f02460f 100644 --- a/onnxslim/cli/_main.py +++ b/onnxslim/cli/_main.py @@ -123,14 +123,14 @@ def slim(model: Union[str, onnx.ModelProto], *args, **kwargs): def main(): """Entry point for the OnnxSlim toolkit, processes command-line arguments and passes them to the slim function.""" from onnxslim.argparser import ( - ArgumentParser, + OnnxSlimArgumentParser, CheckerArguments, ModelArguments, ModificationArguments, OptimizationArguments, ) - argument_parser = ArgumentParser(ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments) + argument_parser = OnnxSlimArgumentParser(ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments) model_args, optimization_args, modification_args, checker_args = argument_parser.parse_args_into_dataclasses() if checker_args.inspect and model_args.output_model: