diff --git a/onnxslim/utils.py b/onnxslim/utils.py index 8ca6f23..409fe2d 100644 --- a/onnxslim/utils.py +++ b/onnxslim/utils.py @@ -185,8 +185,8 @@ def format_model_info(model_name: str, model_info_list: List[Dict], elapsed_time ) final_op_info.extend( ( - [SEPARATING_LINE], - ["Model Info", "Original Model"] + ["Slimmed Model"] * (len(model_info_list) - 1), + [SEPARATING_LINE] * (len(model_info_list) + 1), + ["Model Info"] + [model_info_list[0].get("tag", "Original Model")] + [item.get("tag", "Slimmed Model") for item in model_info_list[1:]], [SEPARATING_LINE] * (len(model_info_list) + 1), ) ) @@ -318,13 +318,15 @@ def get_opset(model: onnx.ModelProto) -> int: return None -def summarize_model(model: Union[str, onnx.ModelProto]) -> Dict: +def summarize_model(model: Union[str, onnx.ModelProto], tag=None) -> Dict: """Generates a summary of the ONNX model, including model size, operations, and tensor shapes.""" if isinstance(model, str): model = onnx.load(model) logger.debug("Start summarizing model.") model_info = {} + if tag != None: + model_info["tag"] = tag model_size = model.ByteSize() model_info["model_size"] = model_size