Skip to content

Commit

Permalink
fix argparser
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Nov 14, 2024
1 parent 9ad1d37 commit c492ffd
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
9 changes: 7 additions & 2 deletions onnxslim/argparser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import sys
import argparse
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
import dataclasses
from dataclasses import dataclass, field
from typing import List, Optional, Type
Expand Down Expand Up @@ -109,8 +111,11 @@ class CheckerArguments:
verbose: bool = field(default=False, metadata={"help": "verbose mode, default False."})


class ArgumentParser:
def __init__(self, *argument_dataclasses: Type):
class OnnxSlimArgumentParser(ArgumentParser):
def __init__(self, *argument_dataclasses: Type, **kwargs):
if "formatter_class" not in kwargs:
kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
super().__init__(**kwargs)
self.argument_dataclasses = argument_dataclasses
self.parser = argparse.ArgumentParser(
description="OnnxSlim: A Toolkit to Help Optimizer Onnx Model",
Expand Down
4 changes: 2 additions & 2 deletions onnxslim/cli/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ def slim(model: Union[str, onnx.ModelProto], *args, **kwargs):
def main():
"""Entry point for the OnnxSlim toolkit, processes command-line arguments and passes them to the slim function."""
from onnxslim.argparser import (
ArgumentParser,
OnnxSlimArgumentParser,
CheckerArguments,
ModelArguments,
ModificationArguments,
OptimizationArguments,
)

argument_parser = ArgumentParser(ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments)
argument_parser = OnnxSlimArgumentParser(ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments)
model_args, optimization_args, modification_args, checker_args = argument_parser.parse_args_into_dataclasses()

if checker_args.inspect and model_args.output_model:
Expand Down

0 comments on commit c492ffd

Please sign in to comment.