diff --git a/examples/common_subexpression_elimination/cse_demo.py b/examples/common_subexpression_elimination/cse_demo.py
index 0e97a5f..3805f00 100644
--- a/examples/common_subexpression_elimination/cse_demo.py
+++ b/examples/common_subexpression_elimination/cse_demo.py
@@ -4,9 +4,11 @@


 class Model(torch.nn.Module):
+    """A PyTorch model applying LayerNorm to input tensors for normalization in neural network layers."""
+
     def __init__(self):
         """Initializes the Model class with a single LayerNorm layer of embedding dimension 10."""
-        super(Model, self).__init__()
+        super().__init__()
         embedding_dim = 10
         self.layer_norm = nn.LayerNorm(embedding_dim)
diff --git a/onnxslim/argparser.py b/onnxslim/argparser.py
index b5a3a1a..59d6a4b 100644
--- a/onnxslim/argparser.py
+++ b/onnxslim/argparser.py
@@ -110,7 +110,10 @@ class CheckerArguments:


 class ArgumentParser:
+    """Parses command-line arguments into specified dataclasses for ONNX model optimization and modification tasks."""
+
     def __init__(self, *argument_dataclasses: Type):
+        """Initializes the ArgumentParser with dataclass types for parsing ONNX model optimization arguments."""
         self.argument_dataclasses = argument_dataclasses
         self.parser = argparse.ArgumentParser(
             description="OnnxSlim: A Toolkit to Help Optimizer Onnx Model",
@@ -119,6 +122,7 @@ def __init__(self, *argument_dataclasses: Type):
         self._add_arguments()

     def _add_arguments(self):
+        """Adds command-line arguments to the parser based on provided dataclass fields and their metadata."""
         for dataclass_type in self.argument_dataclasses:
             for field_name, field_def in dataclass_type.__dataclass_fields__.items():
                 arg_type = field_def.type
@@ -150,6 +154,7 @@ def _add_arguments(self):
         self.parser.add_argument("-v", "--version", action="version", version=onnxslim.__version__)

     def parse_args_into_dataclasses(self):
+        """Parses command-line arguments into specified dataclass instances for structured configuration."""
         args = self.parser.parse_args()
         args_dict = vars(args)
diff --git a/onnxslim/cli/_main.py b/onnxslim/cli/_main.py
index 4c9c3d7..c39ecfd 100644
--- a/onnxslim/cli/_main.py
+++ b/onnxslim/cli/_main.py
@@ -4,6 +4,7 @@


 def slim(model: Union[str, onnx.ModelProto], *args, **kwargs):
+    """Slims an ONNX model by optimizing and modifying its structure, inputs, and outputs for improved performance."""
     import os
     import time
     from pathlib import Path
@@ -11,9 +12,9 @@ def slim(model: Union[str, onnx.ModelProto], *args, **kwargs):
     from onnxslim.core import (
         convert_data_format,
         freeze,
+        input_modification,
         input_shape_modification,
         optimize,
-        input_modification,
         output_modification,
         shape_infer,
     )
@@ -29,20 +30,20 @@ def slim(model: Union[str, onnx.ModelProto], *args, **kwargs):
         summarize_model,
     )

-    output_model = args[0] if len(args) > 0 else kwargs.get('output_model', None)
-    model_check = kwargs.get('model_check', False)
-    input_shapes = kwargs.get('input_shapes', None)
-    inputs = kwargs.get('inputs', None)
-    outputs = kwargs.get('outputs', None)
-    no_shape_infer = kwargs.get('no_shape_infer', False)
-    no_constant_folding = kwargs.get('no_constant_folding', False)
-    dtype = kwargs.get('dtype', None)
-    skip_fusion_patterns = kwargs.get('skip_fusion_patterns', None)
-    inspect = kwargs.get('inspect', False)
-    dump_to_disk = kwargs.get('dump_to_disk', False)
-    save_as_external_data = kwargs.get('save_as_external_data', False)
-    model_check_inputs = kwargs.get('model_check_inputs', None)
-    verbose = kwargs.get('verbose', False)
+    output_model = args[0] if args else kwargs.get("output_model", None)
+    model_check = kwargs.get("model_check", False)
+    input_shapes = kwargs.get("input_shapes", None)
+    inputs = kwargs.get("inputs", None)
+    outputs = kwargs.get("outputs", None)
+    no_shape_infer = kwargs.get("no_shape_infer", False)
+    no_constant_folding = kwargs.get("no_constant_folding", False)
+    dtype = kwargs.get("dtype", None)
+    skip_fusion_patterns = kwargs.get("skip_fusion_patterns", None)
+    inspect = kwargs.get("inspect", False)
+    dump_to_disk = kwargs.get("dump_to_disk", False)
+    save_as_external_data = kwargs.get("save_as_external_data", False)
+    model_check_inputs = kwargs.get("model_check_inputs", None)
+    verbose = kwargs.get("verbose", False)

     logger = init_logging(verbose)
diff --git a/onnxslim/core/optimization/weight_tying.py b/onnxslim/core/optimization/weight_tying.py
index 20c8f1a..24aed85 100644
--- a/onnxslim/core/optimization/weight_tying.py
+++ b/onnxslim/core/optimization/weight_tying.py
@@ -31,10 +31,7 @@ def replace_constant_references(existing_constant, to_be_removed_constant):
     for i, constant_tensor in enumerate(constant_tensors):
         if keep_constants[i]:
             for j in range(i + 1, len(constant_tensors)):
-                if keep_constants[j]:
-                    if constant_tensor == constant_tensors[j]:
-                        keep_constants[j] = False
-                        replace_constant_references(constant_tensor, constant_tensors[j])
-                        logger.debug(
-                            f"Constant {constant_tensors[j].name} can be replaced by {constant_tensor.name}"
-                        )
+                if keep_constants[j] and constant_tensor == constant_tensors[j]:
+                    keep_constants[j] = False
+                    replace_constant_references(constant_tensor, constant_tensors[j])
+                    logger.debug(f"Constant {constant_tensors[j].name} can be replaced by {constant_tensor.name}")
diff --git a/onnxslim/core/pattern/__init__.py b/onnxslim/core/pattern/__init__.py
index 32bf0d5..fa9991e 100644
--- a/onnxslim/core/pattern/__init__.py
+++ b/onnxslim/core/pattern/__init__.py
@@ -44,6 +44,8 @@ def get_name(name):


 class NodeDescriptor:
+    """Represents a node in a computational graph, detailing its operation type, inputs, and outputs."""
+
     def __init__(self, node_spec):
         """Initialize NodeDescriptor with node_spec list requiring at least 4 elements."""
         if not isinstance(node_spec, list):
@@ -87,6 +89,8 @@ def __dict__(self):


 class Pattern:
+    """Parses and matches ONNX graph patterns into NodeDescriptor objects for model optimization tasks."""
+
     def __init__(self, pattern):
         """Initialize the Pattern class with a given pattern and parse its nodes."""
         self.pattern = pattern
@@ -109,6 +113,8 @@ def __repr__(self):


 class PatternMatcher:
+    """Matches computational graph nodes to predefined patterns for optimization and transformation tasks."""
+
     def __init__(self, pattern, priority):
         """Initialize the PatternMatcher with a given pattern and priority, and prepare node references and output
         names.
@@ -184,6 +190,8 @@ def parameter_check(self):


 class PatternGenerator:
+    """Generates pattern templates from an ONNX model by processing its graph structure and node connections."""
+
     def __init__(self, onnx_model):
         """Initialize the PatternGenerator class with an ONNX model and process its graph."""
         self.graph = gs.import_onnx(onnx_model)
diff --git a/onnxslim/core/pattern/elimination/reshape.py b/onnxslim/core/pattern/elimination/reshape.py
index 90eeeae..20bb1d6 100644
--- a/onnxslim/core/pattern/elimination/reshape.py
+++ b/onnxslim/core/pattern/elimination/reshape.py
@@ -6,6 +6,8 @@


 class ReshapePatternMatcher(PatternMatcher):
+    """Matches and optimizes nested reshape operations in computational graphs to eliminate redundancy."""
+
     def __init__(self, priority):
         """Initializes the ReshapePatternMatcher with a priority and a specific pattern for detecting nested reshape
         operations.
diff --git a/onnxslim/core/pattern/elimination/slice.py b/onnxslim/core/pattern/elimination/slice.py
index 14bec9c..e303cad 100644
--- a/onnxslim/core/pattern/elimination/slice.py
+++ b/onnxslim/core/pattern/elimination/slice.py
@@ -6,6 +6,8 @@


 class SlicePatternMatcher(PatternMatcher):
+    """Matches and optimizes nested slice operations in ONNX graphs to improve computational efficiency."""
+
     def __init__(self, priority):
         """Initializes the SlicePatternMatcher with a specified priority using a predefined graph pattern."""
         pattern = Pattern(
diff --git a/onnxslim/core/pattern/elimination/unsqueeze.py b/onnxslim/core/pattern/elimination/unsqueeze.py
index c4fcf8d..e69962d 100644
--- a/onnxslim/core/pattern/elimination/unsqueeze.py
+++ b/onnxslim/core/pattern/elimination/unsqueeze.py
@@ -6,6 +6,8 @@


 class UnsqueezePatternMatcher(PatternMatcher):
+    """Matches and optimizes nested unsqueeze patterns in ONNX graphs to improve computational efficiency."""
+
     def __init__(self, priority):
         """Initializes the UnsqueezePatternMatcher with a specified priority using a predefined graph pattern."""
         pattern = Pattern(
@@ -29,53 +31,60 @@ def rewrite(self, opset=11):
         node_unsqueeze_0 = self.unsqueeze_0
         users_node_unsqueeze_0 = get_node_users(node_unsqueeze_0)
         node_unsqueeze_1 = self.unsqueeze_1
-        if len(users_node_unsqueeze_0) == 1 and node_unsqueeze_0.inputs[0].shape and node_unsqueeze_1.inputs[0].shape:
-            if opset < 13 or (
-                isinstance(node_unsqueeze_0.inputs[1], gs.Constant)
-                and isinstance(node_unsqueeze_1.inputs[1], gs.Constant)
-            ):
+        if (
+            len(users_node_unsqueeze_0) == 1
+            and node_unsqueeze_0.inputs[0].shape
+            and node_unsqueeze_1.inputs[0].shape
+            and (
+                opset < 13
+                or (
+                    isinstance(node_unsqueeze_0.inputs[1], gs.Constant)
+                    and isinstance(node_unsqueeze_1.inputs[1], gs.Constant)
+                )
+            )
+        ):

-                def get_unsqueeze_axes(unsqueeze_node, opset):
-                    dim = len(unsqueeze_node.inputs[0].shape)
-                    if opset < 13:
-                        axes = unsqueeze_node.attrs["axes"]
-                    else:
-                        axes = unsqueeze_node.inputs[1].values
-                    return [axis + dim + len(axes) if axis < 0 else axis for axis in axes]
+            def get_unsqueeze_axes(unsqueeze_node, opset):
+                dim = len(unsqueeze_node.inputs[0].shape)
+                if opset < 13:
+                    axes = unsqueeze_node.attrs["axes"]
+                else:
+                    axes = unsqueeze_node.inputs[1].values
+                return [axis + dim + len(axes) if axis < 0 else axis for axis in axes]

-                axes_node_unsqueeze_0 = get_unsqueeze_axes(node_unsqueeze_0, opset)
-                axes_node_unsqueeze_1 = get_unsqueeze_axes(node_unsqueeze_1, opset)
+            axes_node_unsqueeze_0 = get_unsqueeze_axes(node_unsqueeze_0, opset)
+            axes_node_unsqueeze_1 = get_unsqueeze_axes(node_unsqueeze_1, opset)

-                axes_node_unsqueeze_0 = [
-                    axis + sum(1 for axis_ in axes_node_unsqueeze_1 if axis_ <= axis) for axis in axes_node_unsqueeze_0
-                ]
+            axes_node_unsqueeze_0 = [
+                axis + sum(bool(axis_ <= axis) for axis_ in axes_node_unsqueeze_1) for axis in axes_node_unsqueeze_0
+            ]

-                inputs = [node_unsqueeze_0.inputs[0]]
-                outputs = list(node_unsqueeze_1.outputs)
-                node_unsqueeze_0.inputs.clear()
-                node_unsqueeze_0.outputs.clear()
-                node_unsqueeze_1.inputs.clear()
-                node_unsqueeze_1.outputs.clear()
+            inputs = [node_unsqueeze_0.inputs[0]]
+            outputs = list(node_unsqueeze_1.outputs)
+            node_unsqueeze_0.inputs.clear()
+            node_unsqueeze_0.outputs.clear()
+            node_unsqueeze_1.inputs.clear()
+            node_unsqueeze_1.outputs.clear()

-                if opset < 13:
-                    attrs = {"axes": axes_node_unsqueeze_0 + axes_node_unsqueeze_1}
-                else:
-                    attrs = None
-                    inputs.append(
-                        gs.Constant(
-                            name=f"{node_unsqueeze_0.name}_axes",
-                            values=np.array(axes_node_unsqueeze_0 + axes_node_unsqueeze_1, dtype=np.int64),
-                        )
+            if opset < 13:
+                attrs = {"axes": axes_node_unsqueeze_0 + axes_node_unsqueeze_1}
+            else:
+                attrs = None
+                inputs.append(
+                    gs.Constant(
+                        name=f"{node_unsqueeze_0.name}_axes",
+                        values=np.array(axes_node_unsqueeze_0 + axes_node_unsqueeze_1, dtype=np.int64),
                     )
+                )

-                match_case[node_unsqueeze_0.name] = {
-                    "op": "Unsqueeze",
-                    "inputs": inputs,
-                    "outputs": outputs,
-                    "name": node_unsqueeze_0.name,
-                    "attrs": attrs,
-                    "domain": None,
-                }
+            match_case[node_unsqueeze_0.name] = {
+                "op": "Unsqueeze",
+                "inputs": inputs,
+                "outputs": outputs,
+                "name": node_unsqueeze_0.name,
+                "attrs": attrs,
+                "domain": None,
+            }

         return match_case
diff --git a/onnxslim/core/pattern/fusion/convbn.py b/onnxslim/core/pattern/fusion/convbn.py
index f4aa0e0..d6bba3a 100644
--- a/onnxslim/core/pattern/fusion/convbn.py
+++ b/onnxslim/core/pattern/fusion/convbn.py
@@ -6,6 +6,8 @@


 class ConvBatchNormMatcher(PatternMatcher):
+    """Fuses Conv and BatchNormalization layers in an ONNX graph to optimize model performance and inference speed."""
+
     def __init__(self, priority):
         """Initializes the ConvBatchNormMatcher for fusing Conv and BatchNormalization layers in an ONNX graph."""
         pattern = Pattern(
diff --git a/onnxslim/core/pattern/fusion/gelu.py b/onnxslim/core/pattern/fusion/gelu.py
index 5efc87e..2157941 100644
--- a/onnxslim/core/pattern/fusion/gelu.py
+++ b/onnxslim/core/pattern/fusion/gelu.py
@@ -2,6 +2,8 @@


 class GeluPatternMatcher(PatternMatcher):
+    """Matches and fuses GELU patterns in computational graphs for optimization purposes."""
+
     def __init__(self, priority):
         """Initializes a `GeluPatternMatcher` to identify and fuse GELU patterns in a computational graph."""
         pattern = Pattern(
diff --git a/onnxslim/core/pattern/fusion/gemm.py b/onnxslim/core/pattern/fusion/gemm.py
index 9a80912..f6eec15 100644
--- a/onnxslim/core/pattern/fusion/gemm.py
+++ b/onnxslim/core/pattern/fusion/gemm.py
@@ -5,6 +5,8 @@


 class MatMulAddPatternMatcher(PatternMatcher):
+    """Matches and fuses MatMul and Add operations in ONNX graphs to optimize computational efficiency."""
+
     def __init__(self, priority):
         """Initializes a matcher for fusing MatMul and Add operations in ONNX graph optimization."""
         pattern = Pattern(
diff --git a/onnxslim/core/pattern/fusion/padconv.py b/onnxslim/core/pattern/fusion/padconv.py
index ef304cc..344a768 100644
--- a/onnxslim/core/pattern/fusion/padconv.py
+++ b/onnxslim/core/pattern/fusion/padconv.py
@@ -4,6 +4,8 @@


 class PadConvMatcher(PatternMatcher):
+    """Matches and optimizes Pad-Conv patterns in ONNX graphs by ensuring padding parameters are constants."""
+
     def __init__(self, priority):
         """Initializes the PadConvMatcher with a specified priority and defines its matching pattern."""
         pattern = Pattern(
diff --git a/onnxslim/core/pattern/fusion/reduce.py b/onnxslim/core/pattern/fusion/reduce.py
index 29f31d0..9b4d606 100644
--- a/onnxslim/core/pattern/fusion/reduce.py
+++ b/onnxslim/core/pattern/fusion/reduce.py
@@ -3,6 +3,8 @@


 class ReducePatternMatcher(PatternMatcher):
+    """Optimizes ONNX graph patterns with ReduceSum and Unsqueeze operations for improved model performance."""
+
     def __init__(self, priority):
         """Initializes the ReducePatternMatcher with a specified pattern matching priority level."""
         pattern = Pattern(
diff --git a/onnxslim/misc/tabulate.py b/onnxslim/misc/tabulate.py
index 55aa9fc..514f385 100644
--- a/onnxslim/misc/tabulate.py
+++ b/onnxslim/misc/tabulate.py
@@ -222,7 +222,7 @@ def make_header_line(is_header, colwidths, colaligns):
     alignment = {"left": "<", "right": ">", "center": "^", "decimal": ">"}
     # use the column widths generated by tabulate for the asciidoc column width specifiers
     asciidoc_alignments = zip(colwidths, [alignment[colalign] for colalign in colaligns])
-    asciidoc_column_specifiers = ["{:d}{}".format(width, align) for width, align in asciidoc_alignments]
+    asciidoc_column_specifiers = [f"{width:d}{align}" for width, align in asciidoc_alignments]
     header_list = ['cols="' + (",".join(asciidoc_column_specifiers)) + '"']

     # generate the list of options (currently only "header")
@@ -2484,7 +2484,7 @@ def _wrap_chunks(self, chunks):
         """
         lines = []
         if self.width <= 0:
-            raise ValueError("invalid width %r (must be > 0)" % self.width)
+            raise ValueError(f"invalid width {self.width!r} (must be > 0)")
         if self.max_lines is not None:
             indent = self.subsequent_indent if self.max_lines > 1 else self.initial_indent
             if self._len(indent) + self._len(self.placeholder.lstrip()) > self.width:
diff --git a/onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py b/onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py
index 17b7563..7b75664 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py
@@ -18,7 +18,9 @@
 from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph


-class BaseExporter(object):
+class BaseExporter:
+    """BaseExporter provides a static method to export ONNX graphs to a specified destination format."""
+
     @staticmethod
     def export_graph(graph: Graph):
         """
diff --git a/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py b/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py
index 20c4ffb..3591a71 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py
@@ -52,12 +52,7 @@ def check_duplicate_node_names(nodes: Sequence[Node], level=G_LOGGER.WARNING):
         if not node.name:
             continue
         if node.name in name_map:
-            msg = "Found distinct Nodes that share the same name:\n[id: {:}]:\n {:}---\n[id: {:}]:\n {:}\n".format(
-                id(name_map[node.name]),
-                name_map[node.name],
-                id(node),
-                node,
-            )
+            msg = f"Found distinct Nodes that share the same name:\n[id: {id(name_map[node.name])}]:\n {name_map[node.name]}---\n[id: {id(node)}]:\n {node}\n"
             G_LOGGER.log(msg, level)
         else:
             name_map[node.name] = node
@@ -110,6 +105,8 @@ def np_float32_to_bf16_as_uint16(arr):


 class OnnxExporter(BaseExporter):
+    """Exports internal graph structures to ONNX format for model interoperability."""
+
     @staticmethod
     def export_tensor_proto(tensor: Constant) -> onnx.TensorProto:
         # Do *not* load LazyValues into an intermediate numpy array - instead, use
@@ -146,9 +143,7 @@ def export_value_info_proto(tensor: Tensor, do_type_check: bool) -> onnx.ValueIn
         """Creates an ONNX ValueInfoProto from a Tensor, optionally checking for dtype information."""
         if do_type_check and tensor.dtype is None:
             G_LOGGER.critical(
-                "Graph input and output tensors must include dtype information. Please set the dtype attribute for: {:}".format(
-                    tensor
-                )
+                f"Graph input and output tensors must include dtype information. Please set the dtype attribute for: {tensor}"
             )

         if tensor.dtype is None:
diff --git a/onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py b/onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py
index 32e11ec..414b53c 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py
@@ -104,7 +104,7 @@ def __init__(self) -> None:
         self.op = None  # op (str)
         self.check_func = None  # callback function for single node
         # pattern node name -> GraphPattern nodes(single or subpattern)
-        self.nodes: Dict[str, "GraphPattern"] = {}
+        self.nodes: Dict[str, GraphPattern] = {}
         # pattern node name -> input tensors
         self.node_inputs: Dict[str, List[int]] = {}
         # pattern node name -> output tensors
@@ -119,6 +119,7 @@ def __init__(self) -> None:
         """Assigns a unique tensor ID, tracks its input node if provided, and initializes output node tracking."""

     def _add_tensor(self, input_node=None) -> int:
+        """Assigns a unique tensor ID, tracks its input node if provided, and initializes output node tracking."""
         tensor_id = self.num_tensors
         self.tensor_inputs[tensor_id] = []
         if input_node is not None:
@@ -239,15 +240,13 @@ def _single_node_match(self, onnx_node: Node) -> bool:
         with G_LOGGER.indent():
             if self.op != onnx_node.op:
                 G_LOGGER.info(
-                    "No match because: Op did not match. Node op was: {:} but pattern op was: {:}.".format(
-                        onnx_node.op, self.op
-                    )
+                    f"No match because: Op did not match. Node op was: {onnx_node.op} but pattern op was: {self.op}."
                 )
                 return False
             if self.check_func is not None and not self.check_func(onnx_node):
                 G_LOGGER.info("No match because: check_func returned false.")
                 return False
-            G_LOGGER.info("Single node is matched: {:}, {:}".format(self.op, onnx_node.name))
+            G_LOGGER.info(f"Single node is matched: {self.op}, {onnx_node.name}")
             return True

     def _get_tensor_index_for_node(self, node: str, tensor_id: int, is_node_input: bool):
@@ -330,7 +329,7 @@ def _match_node(
     ) -> bool:
         """Matches ONNX nodes to the graph pattern starting from a specific node and tensor context."""
         with G_LOGGER.indent():
-            G_LOGGER.info("Checking node: {:} against pattern node: {:}.".format(onnx_node.name, node_name))
+            G_LOGGER.info(f"Checking node: {onnx_node.name} against pattern node: {node_name}.")
             tensor_index_for_node = self._get_tensor_index_for_node(node_name, from_tensor, is_node_input=from_inbound)
             subgraph_mapping = self.nodes[node_name].match(
                 onnx_node,
diff --git a/onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py b/onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py
index 50279ac..2b9b59c 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py
@@ -18,7 +18,9 @@
 from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph


-class BaseImporter(object):
+class BaseImporter:
+    """BaseImporter provides functionality to import and convert source graphs into onnx-graphsurgeon Graph objects."""
+
     @staticmethod
     def import_graph(graph) -> Graph:
         """
diff --git a/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py b/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py
index 9cbb15e..3d634ac 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py
@@ -189,6 +189,8 @@ def get_onnx_tensor_type(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProt


 class OnnxImporter(BaseImporter):
+    """Imports ONNX models, functions, and tensors into internal representations for further processing."""
+
     @staticmethod
     def get_opset(model_or_func: Union[onnx.ModelProto, onnx.FunctionProto]):
         """Return the ONNX opset version for the given ONNX model or function, or None if the information is
@@ -277,14 +279,10 @@ def process_attr(attr_str: str):
                 if attr_str in ONNX_PYTHON_ATTR_MAPPING:
                     attr_dict[attr.name] = process_attr(attr_str)
                 else:
-                    G_LOGGER.warning(
-                        "Attribute of type {:} is currently unsupported. Skipping attribute.".format(attr_str)
-                    )
+                    G_LOGGER.warning(f"Attribute of type {attr_str} is currently unsupported. Skipping attribute.")
             else:
                 G_LOGGER.warning(
-                    "Attribute type: {:} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute.".format(
-                        attr.type
-                    )
+                    f"Attribute type: {attr.type} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute."
                 )
         return attr_dict

@@ -315,9 +313,7 @@ def get_tensor(name: str, check_outer_graph=True):
                     return Variable.empty()

                 G_LOGGER.verbose(
-                    "Tensor: {:} was not generated during shape inference, or shape inference was not run on this model. Creating a new Tensor.".format(
-                        name
-                    )
+                    f"Tensor: {name} was not generated during shape inference, or shape inference was not run on this model. Creating a new Tensor."
                 )
                 subgraph_tensor_map[name] = Variable(name)
             return subgraph_tensor_map[name]
diff --git a/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py b/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py
index 355a031..68ca230 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py
@@ -28,7 +28,9 @@
 from onnxslim.third_party.onnx_graphsurgeon.util import misc


-class NodeIDAdder(object):
+class NodeIDAdder:
+    """Assigns unique IDs to graph nodes on entry and removes them on exit for context management."""
+
     def __init__(self, graph):
         """Initializes NodeIDAdder with a specified graph."""
         self.graph = graph
@@ -45,7 +47,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             del node.id


-class Graph(object):
+class Graph:
     """Represents a graph containing nodes and tensors."""

     DEFAULT_OPSET = 11
@@ -83,8 +85,8 @@ def register_func(func):
             """
             if hasattr(Graph, func.__name__):
                 G_LOGGER.warning(
-                    "Registered function: {:} is hidden by a Graph attribute or function with the same name. "
-                    "This function will never be called!".format(func.__name__)
+                    f"Registered function: {func.__name__} is hidden by a Graph attribute or function with the same name. "
+                    "This function will never be called!"
                 )

             # Default behavior is to register functions for all opsets.
@@ -143,7 +145,7 @@ def __init__(
         self._merge_subgraph_functions()

         # Printing graphs can be very expensive
-        G_LOGGER.ultra_verbose(lambda: "Created Graph: {:}".format(self))
+        G_LOGGER.ultra_verbose(lambda: f"Created Graph: {self}")

     def __getattr__(self, name):
         """Dynamically handles attribute access, falling back to superclass attribute retrieval if not found."""
@@ -248,8 +250,8 @@ def _get_node_id(self, node):
             return node.id
         except AttributeError:
             G_LOGGER.critical(
-                "Encountered a node not in the graph:\n{:}.\n\n"
-                "To fix this, please append the node to this graph's `nodes` attribute.".format(node)
+                f"Encountered a node not in the graph:\n{node}.\n\n"
+                "To fix this, please append the node to this graph's `nodes` attribute."
             )

     # A tensor is local if it is produced in this graph, or is explicitly a graph input.
@@ -290,7 +292,7 @@ def _get_used_node_ids(self):
         """Returns a dictionary of tensors that are used by node IDs in the current subgraph."""
         local_tensors = self._local_tensors()

-        class IgnoreDupAndForeign(object):
+        class IgnoreDupAndForeign:
             def __init__(self, initial_tensors=None):
                 """Initialize IgnoreDupAndForeign with an optional list of initial tensors."""
                 tensors = misc.default_value(initial_tensors, [])
@@ -421,7 +423,7 @@ def cleanup_subgraphs():
             recurse_functions=False,  # No infinite recursion
         )

-        G_LOGGER.verbose("Cleaning up {:}".format(self.name))
+        G_LOGGER.verbose(f"Cleaning up {self.name}")

         with self.node_ids():
             # Graph input producers must be removed first so used_node_ids is correct.
@@ -435,7 +437,7 @@ def cleanup_subgraphs():
                 if inp in used_tensors or not remove_unused_graph_inputs:
                     inputs.append(inp)
                 else:
-                    G_LOGGER.ultra_verbose("Removing unused input: {:}".format(inp))
+                    G_LOGGER.ultra_verbose(f"Removing unused input: {inp}")
             self.inputs = inputs

             nodes = []
@@ -446,7 +448,7 @@ def cleanup_subgraphs():
                 else:
                     node.inputs.clear()
                     node.outputs.clear()
-                    G_LOGGER.ultra_verbose("Removing unused node: {:}".format(node))
+                    G_LOGGER.ultra_verbose(f"Removing unused node: {node}")

             # Remove any hanging tensors - tensors without outputs
             if remove_unused_node_outputs:
@@ -514,11 +516,11 @@ def toposort(
             for subgraph in self.subgraphs():
                 subgraph.toposort(recurse_subgraphs=True, recurse_functions=False, mode="nodes")

-        G_LOGGER.debug("Topologically sorting {:}".format(self.name))
+        G_LOGGER.debug(f"Topologically sorting {self.name}")

         # Keeps track of a node and its level in the graph hierarchy.
         # 0 corresponds to an input node, N corresponds to a node with N layers of inputs.
-        class HierarchyDescriptor(object):
+        class HierarchyDescriptor:
             def __init__(self, node_or_func, level=None):
                 """Initializes a HierarchyDescriptor with a node or function and an optional level in the graph
                 hierarchy.
@@ -638,18 +640,8 @@ def add_to_tensor_map(tensor):
             """Add a tensor to the tensor_map if it is not empty and ensure no duplicate tensor names exist."""
             if not tensor.is_empty():
                 if tensor.name in tensor_map and tensor_map[tensor.name] is not tensor:
-                    msg = "Found distinct tensors that share the same name:\n[id: {:}] {:}\n[id: {:}] {:}\n".format(
-                        id(tensor_map[tensor.name]),
-                        tensor_map[tensor.name],
-                        id(tensor),
-                        tensor,
-                    )
-                    msg += (
-                        "Note: Producer node(s) of first tensor:\n{:}\nProducer node(s) of second tensor:\n{:}".format(
-                            tensor_map[tensor.name].inputs,
-                            tensor.inputs,
-                        )
-                    )
+                    msg = f"Found distinct tensors that share the same name:\n[id: {id(tensor_map[tensor.name])}] {tensor_map[tensor.name]}\n[id: {id(tensor)}] {tensor}\n"
+                    msg += f"Note: Producer node(s) of first tensor:\n{tensor_map[tensor.name].inputs}\nProducer node(s) of second tensor:\n{tensor.inputs}"

                     if check_duplicates:
                         G_LOGGER.critical(msg)
@@ -756,10 +748,10 @@ def should_exclude_node(node):

         PARTITIONING_MODES = [None, "basic", "recursive"]
         if partitioning not in PARTITIONING_MODES:
-            G_LOGGER.critical("Argument for parameter 'partitioning' must be one of: {:}".format(PARTITIONING_MODES))
+            G_LOGGER.critical(f"Argument for parameter 'partitioning' must be one of: {PARTITIONING_MODES}")
         ORT_PROVIDERS = ["CPUExecutionProvider"]

-        G_LOGGER.debug("Folding constants in {:}".format(self.name))
+        G_LOGGER.debug(f"Folding constants in {self.name}")

         # We apply constant folding in 5 passes:
         # Pass 1 lowers 'Constant' nodes into Constant tensors.
@@ -892,7 +884,7 @@ def run_cast_elision(node):

         if fold_shapes:
             # Perform shape tensor cast elision prior to most other folding
-            G_LOGGER.debug("Performing shape tensor cast elision in {:}".format(self.name))
+            G_LOGGER.debug(f"Performing shape tensor cast elision in {self.name}")
             try:
                 with self.node_ids():
                     for node in self.nodes:
@@ -1077,13 +1069,13 @@ def fold_shape_slice(tensor):
                    shape_of = shape_fold_func(tensor)

                    if shape_of is not None:
-                        G_LOGGER.ultra_verbose("Folding shape tensor: {:} to: {:}".format(tensor.name, shape_of))
+                        G_LOGGER.ultra_verbose(f"Folding shape tensor: {tensor.name} to: {shape_of}")
                        graph_constants[tensor.name] = tensor.to_constant(shape_of)
                        graph_constants[tensor.name].inputs.clear()
                except Exception as err:
                    if not error_ok:
                        raise err
-                    G_LOGGER.warning("'{:}' routine failed with:\n{:}".format(shape_fold_func.__name__, err))
+                    G_LOGGER.warning(f"'{shape_fold_func.__name__}' routine failed with:\n{err}")
            else:
                graph_constants = update_foldable_outputs(graph_constants)
@@ -1112,7 +1104,7 @@ def get_out_node_ids():
                 part = subgraph.copy()
                 out_node = part.nodes[index]
                 part.outputs = out_node.outputs
-                part.name = "Folding: {:}".format([out.name for out in part.outputs])
+                part.name = f"Folding: {[out.name for out in part.outputs]}"
                 part.cleanup(remove_unused_graph_inputs=True)

                 names = [out.name for out in part.outputs]
@@ -1126,7 +1118,7 @@ def get_out_node_ids():
                     )
                     values = sess.run(names, {})
                 except Exception as err:
-                    G_LOGGER.warning("Inference failed for subgraph: {:}. Note: Error was:\n{:}".format(part.name, err))
+                    G_LOGGER.warning(f"Inference failed for subgraph: {part.name}. Note: Error was:\n{err}")
                     if partitioning == "recursive":
                         G_LOGGER.verbose("Attempting to recursively partition subgraph")
                         # Partition failed, peel off last node.
@@ -1168,7 +1160,7 @@ def should_eval_foldable(tensor):
             return non_const and (is_graph_output or has_non_foldable_outputs) and not exceeds_size_threshold

         graph_clone.outputs = [t for t in graph_constants.values() if should_eval_foldable(t)]
-        G_LOGGER.debug("Folding tensors: {:}".format(graph_clone.outputs))
+        G_LOGGER.debug(f"Folding tensors: {graph_clone.outputs}")
         graph_clone.cleanup(remove_unused_graph_inputs=True, recurse_functions=False)

         # Using ._values avoids a deep copy of the values.
@@ -1214,16 +1206,16 @@ def should_eval_foldable(tensor):
             except Exception as err:
                 G_LOGGER.warning(
                     "Inference failed. You may want to try enabling partitioning to see better results. "
-                    "Note: Error was:\n{:}".format(err)
+                    f"Note: Error was:\n{err}"
                 )
-                G_LOGGER.verbose("Note: Graph was:\n{:}".format(graph_clone))
+                G_LOGGER.verbose(f"Note: Graph was:\n{graph_clone}")
                 if not error_ok:
                     raise
         elif not constant_values:
             G_LOGGER.debug(
-                "Could not find any nodes in this graph ({:}) that can be folded. "
+                f"Could not find any nodes in this graph ({self.name}) that can be folded. "
                 "This could mean that constant folding has already been run on this graph. "
-                "Skipping.".format(self.name)
+                "Skipping."
             )

         # Finally, replace the Variables in the original graph with constants.
@@ -1238,9 +1230,7 @@ def should_eval_foldable(tensor):

             if size_threshold is not None and values.nbytes > size_threshold:
                 G_LOGGER.debug(
-                    "Will not fold: '{:}' since its size in bytes ({:}) exceeds the size threshold ({:})".format(
-                        name, values.nbytes, size_threshold
-                    )
+                    f"Will not fold: '{name}' since its size in bytes ({values.nbytes}) exceeds the size threshold ({size_threshold})"
                 )
                 continue
             elif size_threshold is None and values.nbytes > (1 << 20):
@@ -1251,12 +1241,12 @@ def should_eval_foldable(tensor):

         if large_tensors:
             large_tensors_mib = {
-                tensor_name: "{:} MiB".format(value // (1 << 20)) for tensor_name, value in large_tensors.items()
+                tensor_name: f"{value // (1 << 20)} MiB" for tensor_name, value in large_tensors.items()
             }
             G_LOGGER.warning(
                 "It looks like this model contains foldable nodes that produce large outputs.\n"
                 "In order to avoid bloating the model, you may want to set a constant-folding size threshold.\n"
-                "Note: Large tensors and their corresponding sizes were: {:}".format(large_tensors_mib),
+                f"Note: Large tensors and their corresponding sizes were: {large_tensors_mib}",
                 mode=LogMode.ONCE,
             )

@@ -1283,12 +1273,12 @@ def fold_subgraphs():
         while index < len(self.nodes):
             node = self.nodes[index]
             if node.op == "If" and isinstance(node.inputs[0], Constant):
-                G_LOGGER.debug("Flattening conditional: {:}".format(node.name))
+                G_LOGGER.debug(f"Flattening conditional: {node.name}")
                 cond = get_scalar_value(node.inputs[0])
                 subgraph = node.attrs["then_branch"] if cond else node.attrs["else_branch"]
                 # Need to add a suffix to subgraph tensors so they don't collide with outer graph tensors
                 for tensor in subgraph._local_tensors().values():
-                    tensor.name += "_subg_{:}_{:}".format(index, subgraph.name)
+                    tensor.name += f"_subg_{index}_{subgraph.name}"

                 # The subgraph outputs correspond to the If node outputs. Only the latter are visible
                 # in the parent graph, so we rebind the producer nodes of the subgraph outputs to point
@@ -1397,9 +1387,9 @@ def process_io(io, existing_names):
                 new_io.append(Constant(name=name, values=arr))
             else:
                 G_LOGGER.critical(
-                    "Unrecognized type passed to Graph.layer: {:}.\n"
+                    f"Unrecognized type passed to Graph.layer: {elem}.\n"
                     "\tHint: Did you forget to unpack a list with `*`?\n"
-                    "\tPlease use Tensors, strings, or NumPy arrays.".format(elem)
+                    "\tPlease use Tensors, strings, or NumPy arrays."
                 )
             if new_io[-1].name:
                 existing_names.add(new_io[-1].name)
diff --git a/onnxslim/third_party/onnx_graphsurgeon/ir/node.py b/onnxslim/third_party/onnx_graphsurgeon/ir/node.py
index 88437ed..4fc7419 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/ir/node.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/ir/node.py
@@ -24,7 +24,9 @@
 from onnxslim.third_party.onnx_graphsurgeon.util import misc


-class Node(object):
+class Node:
+    """Represents an operation node in a computational graph, managing inputs, outputs, and attributes."""
+
     @dataclass
     class AttributeRef:
         """
@@ -177,24 +179,24 @@ def copy(

     def __str__(self):
         """Return a string representation of the object showing its name and operation."""
-        ret = "{:} ({:})".format(self.name, self.op)
+        ret = f"{self.name} ({self.op})"

         def add_io(name, io):
             """Add the input or output operations and their names to the string representation of the object."""
             nonlocal ret
-            ret += "\n\t{:}: [".format(name)
+            ret += f"\n\t{name}: ["
             for elem in io:
-                ret += "\n\t\t{:}".format(elem)
+                ret += f"\n\t\t{elem}"
             ret += "\n\t]"

         add_io("Inputs", self.inputs)
         add_io("Outputs", self.outputs)

         if self.attrs:
-            ret += "\nAttributes: {:}".format(self.attrs)
+            ret += f"\nAttributes: {self.attrs}"

         if self.domain:
-            ret += "\nDomain: {:}".format(self.domain)
+            ret += f"\nDomain: {self.domain}"

         return ret

@@ -204,7 +206,7 @@ def __repr__(self):

     def __eq__(self, other):
         """Check whether two nodes are equal by comparing name, attributes, op, inputs, and outputs."""
-        G_LOGGER.verbose("Comparing node: {:} with {:}".format(self.name, other.name))
+        G_LOGGER.verbose(f"Comparing node: {self.name} with {other.name}")
         attrs_match = self.name == other.name and self.op == other.op and self.attrs == other.attrs
         if not attrs_match:
             return False
diff --git a/onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py b/onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py
index d68270a..51b3477 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py
@@ -23,7 +23,7 @@
 from onnxslim.third_party.onnx_graphsurgeon.util import misc


-class Tensor(object):
+class Tensor:
     """Abstract base class for tensors in a graph."""

     DYNAMIC = -1
@@ -155,7 +155,7 @@ def o(self, consumer_idx=0, tensor_idx=0):

     def __str__(self):
         """Returns a string representation of the object including its type, name, shape, and data type."""
-        return "{:} ({:}): (shape={:}, dtype={:})".format(type(self).__name__, self.name, self.shape, self.dtype)
+        return f"{type(self).__name__} ({self.name}): (shape={self.shape}, dtype={self.dtype})"

     def __repr__(self):  # Hack to make logging output pretty.
"""Returns a string representation of the object for logging output.""" @@ -191,6 +191,8 @@ def is_output(self, is_output: bool = False): class Variable(Tensor): + """Represents a tensor with unknown values until inference-time, supporting dynamic shapes and data types.""" + @staticmethod def empty(): """Create and return an empty Variable tensor with an empty name.""" @@ -258,7 +260,7 @@ def __eq__(self, other): return name_match and inputs_match and outputs_match and dtype_match and shape_match and type_match -class LazyValues(object): +class LazyValues: """A special object that represents constant tensor values that should be lazily loaded.""" def __init__(self, tensor): @@ -304,7 +306,7 @@ def load(self): def __str__(self): """Returns a formatted string representation of the LazyValues object indicating its shape and dtype.""" - return "LazyValues (shape={:}, dtype={:})".format(self.shape, self.dtype) + return f"LazyValues (shape={self.shape}, dtype={self.dtype})" def __repr__(self): # Hack to make logging output pretty. """Returns an unambiguous string representation of the LazyValues object for logging purposes.""" @@ -369,10 +371,12 @@ def load(self): def __str__(self): """Return a string representation of the SparseValues object with its shape and data type.""" - return "SparseValues (shape={:}, dtype={:})".format(self.shape, self.dtype) + return f"SparseValues (shape={self.shape}, dtype={self.dtype})" class Constant(Tensor): + """Represents a tensor with known constant values, supporting lazy loading and export data type specification.""" + def __init__( self, name: str, @@ -407,7 +411,7 @@ def __init__( G_LOGGER.critical( "Provided `values` argument is not a NumPy array, a LazyValues instance or a" "SparseValues instance. Please provide a NumPy array or LazyValues instance " - "to construct a Constant. Note: Provided `values` parameter was: {:}".format(values) + f"to construct a Constant. Note: Provided `values` parameter was: {values}" ) self._values = values self.data_location = data_location @@ -470,7 +474,7 @@ def export_dtype(self, export_dtype): def __repr__(self): # Hack to make logging output pretty. 
"""Return a string representation of the object, including its values, for improved logging readability.""" ret = self.__str__() - ret += "\n{:}".format(self._values) + ret += f"\n{self._values}" return ret def __eq__(self, other): diff --git a/onnxslim/third_party/onnx_graphsurgeon/logger/logger.py b/onnxslim/third_party/onnx_graphsurgeon/logger/logger.py index 03d0d62..f4f735b 100644 --- a/onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +++ b/onnxslim/third_party/onnx_graphsurgeon/logger/logger.py @@ -27,7 +27,9 @@ # Context manager to apply indentation to messages -class LoggerIndent(object): +class LoggerIndent: + """Context manager for temporarily setting indentation levels in logger messages.""" + def __init__(self, logger, indent): """Initialize the LoggerIndent context manager with the specified logger and indentation level.""" self.logger = logger @@ -45,7 +47,9 @@ def __exit__(self, exc_type, exc_value, traceback): # Context manager to suppress messages -class LoggerSuppress(object): +class LoggerSuppress: + """Suppress logger messages below a specified severity level within a context.""" + def __init__(self, logger, severity): """Initialize a LoggerSuppress object with a logger and severity level.""" self.logger = logger @@ -63,11 +67,15 @@ def __exit__(self, exc_type, exc_value, traceback): class LogMode(enum.IntEnum): + """Enumerates logging modes for controlling message frequency in the Ultralytics library.""" + EACH = 0 # Log the message each time ONCE = 1 # Log the message only once. The same message will not be logged again. -class Logger(object): +class Logger: + """Manages logging with configurable severity, indentation, and formatting for debugging and monitoring.""" + ULTRA_VERBOSE = -10 VERBOSE = 0 DEBUG = 10 @@ -173,7 +181,7 @@ def get_line_info(): # If the file is not located in trt_smeagol, use its basename instead. if os.pardir in filename: filename = os.path.basename(filename) - return "[{:}:{:}] ".format(filename, sys._getframe(stack_depth).f_lineno) + return f"[{filename}:{sys._getframe(stack_depth).f_lineno}] " prefix = "" if self.letter: @@ -207,7 +215,7 @@ def apply_color(message): prefix = get_prefix() message = apply_indentation(message) - return apply_color("{:}{:}".format(prefix, message)) + return apply_color(f"{prefix}{message}") def should_log(message): """Determines if a message should be logged based on the severity level and logging mode.""" diff --git a/onnxslim/third_party/onnx_graphsurgeon/util/misc.py b/onnxslim/third_party/onnx_graphsurgeon/util/misc.py index 1ac98cd..af03b45 100644 --- a/onnxslim/third_party/onnx_graphsurgeon/util/misc.py +++ b/onnxslim/third_party/onnx_graphsurgeon/util/misc.py @@ -158,6 +158,8 @@ def convert_to_onnx_attr_type(any_type): # So, in the example above, we can make n.inputs a synchronized list whose field_name is set to "outputs". 
 # See test_ir.TestNodeIO for functional tests
 class SynchronizedList(list):
+    """Synchronizes list operations with a specified attribute of elements to maintain bidirectional consistency."""
+
     def __init__(self, parent_obj, field_name, initial):
         """Initialize a SynchronizedList with a parent object, a field name, and an initial set of elements."""
         self.parent_obj = parent_obj
diff --git a/onnxslim/third_party/symbolic_shape_infer.py b/onnxslim/third_party/symbolic_shape_infer.py
index aff28c7..6aace36 100644
--- a/onnxslim/third_party/symbolic_shape_infer.py
+++ b/onnxslim/third_party/symbolic_shape_infer.py
@@ -138,6 +138,8 @@ def sympy_reduce_product(x):


 class SymbolicShapeInference:
+    """Performs symbolic shape inference on ONNX models to deduce tensor shapes using symbolic computation."""
+
     def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
         """Initializes the SymbolicShapeInference class with configuration parameters for symbolic shape inference."""
         self.dispatcher_ = {
@@ -2397,6 +2399,7 @@ def _infer_PackedMultiHeadAttention(self, node):  # noqa: N802
         vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))

     def _infer_MultiScaleDeformableAttnTRT(self, node):
+        """Infers output shape and type for MultiScaleDeformableAttnTRT node using input shapes."""
         shape_value = self._try_get_shape(node, 0)
         sampling_locations = self._try_get_shape(node, 3)
         output_shape = shape_value
@@ -2835,8 +2838,8 @@ def get_prereq(node):
                 ):
                     sorted_known_vi.update(node.output)
                     sorted_nodes.append(node)
-            if old_sorted_nodes_len == len(sorted_nodes) and not all(
-                o.name in sorted_known_vi for o in self.out_mp_.graph.output
+            if old_sorted_nodes_len == len(sorted_nodes) and any(
+                o.name not in sorted_known_vi for o in self.out_mp_.graph.output
             ):
                 raise Exception("Invalid model with cyclic graph")

@@ -2936,11 +2939,7 @@ def get_prereq(node):
                     out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
                     if self.verbose_ > 2:
                         logger.debug(
-                            "  {}: {} {}".format(
-                                node.output[i_o],
-                                str(out_shape),
-                                onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type),
-                            )
+                            f"  {node.output[i_o]}: {str(out_shape)} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}"
                         )
                         if node.output[i_o] in self.sympy_data_:
                             logger.debug("  Sympy Data: " + str(self.sympy_data_[node.output[i_o]]))
@@ -3043,17 +3042,11 @@ def get_prereq(node):
                         if self.verbose_ > 0:
                             if is_unknown_op:
                                 logger.debug(
-                                    "Possible unknown op: {} node: {}, guessing {} shape".format(
-                                        node.op_type, node.name, vi.name
-                                    )
+                                    f"Possible unknown op: {node.op_type} node: {node.name}, guessing {vi.name} shape"
                                 )
                             if self.verbose_ > 2:
                                 logger.debug(
-                                    "  {}: {} {}".format(
-                                        node.output[i_o],
-                                        str(new_shape),
-                                        vi.type.tensor_type.elem_type,
-                                    )
+                                    f"  {node.output[i_o]}: {str(new_shape)} {vi.type.tensor_type.elem_type}"
                                 )
                         self.run_ = True
                         continue  # continue the inference after guess, no need to stop as no merge is needed
diff --git a/onnxslim/utils.py b/onnxslim/utils.py
index ffcdead..b02ff60 100644
--- a/onnxslim/utils.py
+++ b/onnxslim/utils.py
@@ -56,7 +56,7 @@ def format_bytes(size: Union[int, Tuple[int, ...]]) -> str:
             size_in_bytes /= 1024
             unit_index += 1

-        formatted_size = "{:.2f} {}".format(size_in_bytes, units[unit_index])
+        formatted_size = f"{size_in_bytes:.2f} {units[unit_index]}"
         formatted_sizes.append(formatted_size)

     if len(formatted_sizes) == 1:
@@ -579,6 +579,7 @@ def check_onnx_compatibility():


 def get_max_tensor(model, topk=5):
+    """Identify and print the top-k largest constant tensors in an ONNX model based on their size."""
     graph = gs.import_onnx(model)

     tensor_map = graph.tensors()
diff --git a/setup.py b/setup.py
index b760e5f..43f0454 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
 from setuptools import find_packages, setup

-with open("VERSION", "r") as f:
+with open("VERSION") as f:
     version = f.read().strip()

 with open("onnxslim/version.py", "w") as f:
@@ -10,7 +10,7 @@
     name="onnxslim",
     version=version,
     description="OnnxSlim: A Toolkit to Help Optimize Large Onnx Model",
-    long_description=open("README.md", "r", encoding="utf-8").read(),
+    long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
     url="https://github.com/inisis/OnnxSlim",
     author="inisis",
diff --git a/tests/test_modelzoo.py b/tests/test_modelzoo.py
index f5d5933..c479f15 100644
--- a/tests/test_modelzoo.py
+++ b/tests/test_modelzoo.py
@@ -12,7 +12,10 @@


 class TestModelZoo:
+    """Tests ONNX models from the model zoo using slimming techniques for validation."""
+
     def test_silero_vad(self, request):
+        """Test the Silero VAD model by slimming its ONNX file and running inference with dummy input data."""
         name = request.node.originalname[len("test_") :]
         filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

@@ -27,6 +30,7 @@ def test_silero_vad(self, request):
         ort_sess.run(None, {"input": input, "sr": sr, "state": state})

     def test_decoder_with_past_model(self, request):
+        """Test the ONNX model decoder with past states using a slimmed model and validate inference execution."""
         name = request.node.originalname[len("test_") :]
         filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

@@ -40,6 +44,7 @@ def test_decoder_with_past_model(self, request):
         ort_sess.run(None, {"input_ids": input_ids, "encoder_hidden_states": encoder_hidden_states})

     def test_tiny_en_decoder(self, request):
+        """Tests the functionality of a slimmed tiny English encoder-decoder model using ONNX Runtime for inference."""
         name = request.node.originalname[len("test_") :]
         filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

@@ -47,6 +52,7 @@ def test_tiny_en_decoder(self, request):
         slim(filename, os.path.join(tempdir, f"{name}_slim.onnx"), model_check=True)

     def test_transformer_encoder(self, request):
+        """Tests the transformer encoder model from the model zoo by verifying the operation count after slimming."""
         name = request.node.originalname[len("test_") :]
         filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
         summary = summarize_model(slim(filename))
@@ -55,6 +61,7 @@ def test_transformer_encoder(self, request):
         assert summary["op_type_counts"]["Div"] == 53

     def test_uiex(self, request):
+        """Summarize the UIEX model and verify absence of 'Range' and 'Floor' operators."""
         name = request.node.originalname[len("test_") :]
         filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
         summary = summarize_model(slim(filename))
diff --git a/tests/test_onnx_nets.py b/tests/test_onnx_nets.py
index 9daf160..907ae10 100644
--- a/tests/test_onnx_nets.py
+++ b/tests/test_onnx_nets.py
@@ -24,6 +24,7 @@ class TestTorchVisionClass:
             models.googlenet,
         ),
     )
+    # Tests TorchVision models by exporting them to ONNX format and verifying the process with random input tensors.
     def test_torchvision(self, request, model, shape=(1, 3, 224, 224)):
         """Test various TorchVision models with random input tensors of a specified shape."""
         model = model(pretrained=PRETRAINED)
@@ -47,6 +48,7 @@ def test_torchvision(self, request, model, shape=(1, 3, 224, 224)):

 class TestTimmClass:
+    """Tests TIMM models for successful ONNX export and slimming using random input tensors."""
     @pytest.fixture(params=timm.list_models())
     def model_name(self, request):
         """Yields names of models available in TIMM (https://github.com/rwightman/pytorch-image-models) for pytest fixture parameterization."""
diff --git a/tests/test_onnxslim.py b/tests/test_onnxslim.py
index 8af9738..3d977cd 100644
--- a/tests/test_onnxslim.py
+++ b/tests/test_onnxslim.py
@@ -12,6 +12,8 @@


 class TestFunctional:
+    """Tests the functionality of the 'slim' function for optimizing ONNX models using temporary directories."""
+
     def test_basic(self, request):
         """Test the basic functionality of the slim function."""
         with tempfile.TemporaryDirectory() as tempdir:
@@ -30,6 +32,8 @@ def test_basic(self, request):


 class TestFeature:
+    """Tests ONNX model modifications like input shape, precision conversion, and input/output adjustments."""
+
     def test_input_shape_modification(self, request):
         """Test the modification of input shapes."""
         summary = summarize_model(slim(FILENAME, input_shapes=["input:1,3,224,224"]))
@@ -76,7 +80,9 @@ def test_output_modification(self, request):

     def test_input_modification(self, request):
         """Tests input modification."""
-        summary = summarize_model(slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]))
+        summary = summarize_model(
+            slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"])
+        )
         print_model_info_as_table(request.node.name, summary)
         assert "/maxpool/MaxPool_output_0" in summary["op_input_info"]
         assert "/layer1/layer1.0/relu/Relu_output_0" in summary["op_input_info"]
diff --git a/tests/test_pattern_generator.py b/tests/test_pattern_generator.py
index ae0aca9..a3c9f63 100644
--- a/tests/test_pattern_generator.py
+++ b/tests/test_pattern_generator.py
@@ -10,12 +10,14 @@


 class TestPatternGenerator:
+    """Generates and tests ONNX fusion patterns for neural network models using the GELU activation function."""
+
     def test_gelu(self, request):
         """Test the GELU activation function within the PatternModel class."""

         class PatternModel(nn.Module):
             def __init__(self):
-                super(PatternModel, self).__init__()
+                super().__init__()
                 self.gelu = nn.GELU()

             def forward(self, x):
@@ -26,7 +28,7 @@ def forward(self, x):
         class Model(nn.Module):
             def __init__(self):
                 """Initializes the Model class with ReLU and PatternModel components."""
-                super(Model, self).__init__()
+                super().__init__()
                 self.relu0 = nn.ReLU()
                 self.pattern = PatternModel()
                 self.relu1 = nn.ReLU()
diff --git a/tests/test_pattern_matcher.py b/tests/test_pattern_matcher.py
index 70dc2bf..dd22fa2 100644
--- a/tests/test_pattern_matcher.py
+++ b/tests/test_pattern_matcher.py
@@ -9,12 +9,14 @@


 class TestPatternMatcher:
+    """Tests various neural network operations by exporting PyTorch models to ONNX and analyzing them with onnxslim."""
+
     def test_gelu(self, request):
         """Test the GELU activation function in a neural network model using an instance of nn.Module."""

         class Model(nn.Module):
             def __init__(self):
-                super(Model, self).__init__()
+                super().__init__()
                 self.relu0 = nn.ReLU()
                 self.gelu = nn.GELU()
                 self.relu1 = nn.ReLU()
@@ -44,7 +46,7 @@ def test_pad_conv(self, request):

         class Model(nn.Module):
             def __init__(self):
-                super(Model, self).__init__()
+                super().__init__()
                 self.pad_0 = nn.ConstantPad2d(3, 0)
                 self.conv_0 = nn.Conv2d(1, 1, 3)

@@ -80,7 +82,7 @@ def test_conv_bn(self, request):

         class Model(nn.Module):
             def __init__(self):
-                super(Model, self).__init__()
+                super().__init__()
                 self.conv = nn.Conv2d(1, 1, 3)
                 self.bn = nn.BatchNorm2d(1)

@@ -109,7 +111,7 @@ def test_consecutive_slice(self, request):

         class Model(nn.Module):
             def __init__(self):
-                super(Model, self).__init__()
+                super().__init__()
                 self.conv = nn.Conv2d(1, 1, 3)
                 self.bn = nn.BatchNorm2d(1)

@@ -134,7 +136,7 @@ def test_consecutive_reshape(self, request):

         class Model(nn.Module):
             def __init__(self):
-                super(Model, self).__init__()
+                super().__init__()

             def forward(self, x):
                 """Reshape tensor sequentially to (2, 6) and then to (12, 1)."""
@@ -157,7 +159,7 @@ def test_matmul_add(self, request):

         class Model(nn.Module):
             def __init__(self):
-                super(Model, self).__init__()
+                super().__init__()
                 self.data = torch.randn(4, 3)

             def forward(self, x):
@@ -185,7 +187,7 @@ def test_reduce(self, request):

         class Model(nn.Module):
             def __init__(self):
-                super(Model, self).__init__()
+                super().__init__()

             def forward(self, x):
                 """Performs a reduction summing over the last dimension of the input tensor and then unsqueezes the
@@ -215,9 +217,11 @@ def forward(self, x):
         ),
     )
     def test_consecutive_unsqueeze(self, request, opset):
+        """Tests consecutive unsqueeze operations in a model by exporting to ONNX and summarizing the slimmed model."""
+
         class Model(nn.Module):
             def __init__(self):
-                super(Model, self).__init__()
+                super().__init__()

             def forward(self, x):
                 x = x.unsqueeze(-1)
diff --git a/tests/utils.py b/tests/utils.py
index f885345..86fdffe 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -23,7 +23,7 @@
     from tqdm import tqdm
 except ImportError:
     # fake tqdm if it's not installed
-    class tqdm(object):  # type: ignore[no-redef]
+    class tqdm:  # type: ignore[no-redef]
         def __init__(
             self,
             total=None,
@@ -44,9 +44,9 @@ def update(self, n):
             self.n += n

             if self.total is None:
-                sys.stderr.write("\r{0:.1f} bytes".format(self.n))
+                sys.stderr.write(f"\r{self.n:.1f} bytes")
             else:
-                sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
+                sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%")
             sys.stderr.flush()

         def close(self):