From 7e11bd7296a6caf2f3e422a6873718e4ac244cbd Mon Sep 17 00:00:00 2001
From: Xida Ren
Date: Tue, 1 Oct 2024 17:26:55 +0000
Subject: [PATCH 1/3] nlp: override sample tensor generation for token_type_ids

---
 alt_e2eshark/e2e_testing/framework.py  |  2 +-
 alt_e2eshark/e2e_testing/onnx_utils.py | 23 ++++++++++++++++++-----
 alt_e2eshark/onnx_tests/models/nlp.py  | 23 +++++++++++++++++++++++
 3 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/alt_e2eshark/e2e_testing/framework.py b/alt_e2eshark/e2e_testing/framework.py
index c8329a67..981c4098 100644
--- a/alt_e2eshark/e2e_testing/framework.py
+++ b/alt_e2eshark/e2e_testing/framework.py
@@ -68,7 +68,7 @@ def construct_model(self):
                 f"Model path {self.model} does not exist and no construct_model method is defined."
             )
 
-    def construct_inputs(self):
+    def construct_inputs(self) -> TestTensors:
         """can be overridden to generate specific inputs, but a default is provided for convenience"""
         if not os.path.exists(self.model):
             self.construct_model()
diff --git a/alt_e2eshark/e2e_testing/onnx_utils.py b/alt_e2eshark/e2e_testing/onnx_utils.py
index 122f5526..57ac2831 100644
--- a/alt_e2eshark/e2e_testing/onnx_utils.py
+++ b/alt_e2eshark/e2e_testing/onnx_utils.py
@@ -27,9 +27,8 @@ def dtype_from_ort_node(node):
         return torch.bool
     raise NotImplementedError(f"Unhandled dtype string found: {dtypestr}")
 
-
-def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg, dim_param_dict: Optional[dict[str, int]] = None):
-    """A convenience function for generating sample inputs for an onnxruntime node"""
+def get_node_shape_from_dim_param_dict(node: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg, dim_param_dict: Optional[dict[str, int]] = None):
+    """Get the shape of a node, replacing any string dims with values from a dim_param_dict"""
     int_dims = []
     for dim in node.shape:
         if isinstance(dim, str) and dim_param_dict:
@@ -46,7 +45,21 @@ def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.N
             raise ValueError(
                 f"input node '{node.name}' has a non-positive dim: {dim}. Consider setting custom inputs for this test."
             )
-        int_dims.append(dim)
+        int_dims.append(dim)
+    return int_dims
+
+
+def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg, dim_param_dict: Optional[dict[str, int]] = None):
+    """
+    Generate a random input tensor for a given node.
+
+    Args:
+        node: an onnx node
+        dim_param_dict: a dictionary mapping onnx string dims to int values
+    """
+
+    int_dims = get_node_shape_from_dim_param_dict(node, dim_param_dict)
+
     rng = numpy.random.default_rng(19)
     if node.type == "tensor(float)":
         return rng.random(int_dims).astype(numpy.float32)
@@ -61,7 +74,7 @@ def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.N
     raise NotImplementedError(f"Found an unhandled dtype: {node.type}.")
 
 
-def get_sample_inputs_for_onnx_model(model_path, dim_param_dict = None):
+def get_sample_inputs_for_onnx_model(model_path, dim_param_dict = None) -> TestTensors:
     """A convenience function for generating sample inputs for an onnx model"""
     opt = onnxruntime.SessionOptions()
     opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
diff --git a/alt_e2eshark/onnx_tests/models/nlp.py b/alt_e2eshark/onnx_tests/models/nlp.py
index c1325608..f63d2576 100644
--- a/alt_e2eshark/onnx_tests/models/nlp.py
+++ b/alt_e2eshark/onnx_tests/models/nlp.py
@@ -8,6 +8,12 @@
 from ..helper_classes import AzureDownloadableModel
 from e2e_testing.registry import register_test
 from e2e_testing.storage import load_test_txt_file
+from e2e_testing.onnx_utils import get_node_shape_from_dim_param_dict
+from e2e_testing.storage import TestTensors
+import onnxruntime as ort
+import numpy
+from typing import Optional
+import os
 
 this_file = Path(__file__)
 lists_dir = (this_file.parent).joinpath("external_lists")
@@ -28,6 +34,23 @@ def __init__(self, *args, **kwargs):
         def update_dim_param_dict(self):
             self.dim_param_dict = dim_param_dict
 
+        def construct_inputs(self):
+            """Overrides the parent class method to construct sample inputs with the correct dimensions."""
+            default_inputs = super().construct_inputs()
+
+            tensors = list(default_inputs.data)
+
+            self.update_sess_options()
+            session = ort.InferenceSession(self.model, self.sess_options)
+
+            # nlp specific overrides
+            for i, node in enumerate(session.get_inputs()):
+                if node.name == "token_type_ids":
+                    rng = numpy.random.default_rng(19)
+                    int_dims = get_node_shape_from_dim_param_dict(node, self.dim_param_dict)
+                    tensors[i] = rng.integers(0, 2, size=int_dims, dtype=numpy.int64)
+            default_sample_inputs = TestTensors(tuple(tensors))
+            return default_sample_inputs
     return AzureWithDimParams
 
 # Default dimension parameters for NLP models
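
[Reviewer note on PATCH 1] The token_type_ids override exists because the
default generator picks values by dtype alone, while token_type_ids must be a
0/1 segment-id tensor; out-of-range ids can index past an embedding table.
Below is a minimal standalone sketch of the resolve-then-sample flow this
patch implements. The model path and the dim_param values are hypothetical,
and the shape resolution is inlined rather than calling the patch's helper:

    import numpy
    import onnxruntime

    dim_param_dict = {"batch_size": 1, "sequence_length": 128}  # hypothetical values
    session = onnxruntime.InferenceSession("model.onnx")  # hypothetical path
    rng = numpy.random.default_rng(19)  # same fixed seed the patch uses
    for node in session.get_inputs():
        # node.shape mixes ints with dim_param strings, e.g. ["batch_size", 128]
        dims = [dim_param_dict[d] if isinstance(d, str) else d for d in node.shape]
        if node.name == "token_type_ids":
            # segment ids must be 0 or 1, so sample integers in [0, 2)
            sample = rng.integers(0, 2, size=dims, dtype=numpy.int64)

Because super().construct_inputs() already generated a full input tuple, the
override only swaps out the offending tensor instead of rebuilding every input.
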
From 99ec3299c8deba44deb9314199f04fa39ab3c652 Mon Sep 17 00:00:00 2001
From: Xida Ren
Date: Wed, 2 Oct 2024 17:34:16 +0000
Subject: [PATCH 2/3] Refactor OnnxModelInfo for performance and maintainability

- Implement lazy loading for ONNX models and inference sessions
- Add properties for model, session, and I/O nodes with caching
- Optimize helper classes to reduce redundant operations
- Update related functions for compatibility with refactored classes
- Improve -t/--test-filter argument description
---
 alt_e2eshark/e2e_testing/framework.py        | 99 +++++++++++++++----
 alt_e2eshark/e2e_testing/onnx_utils.py       |  8 +-
 alt_e2eshark/onnx_tests/helper_classes.py    | 28 +++---
 alt_e2eshark/onnx_tests/models/nlp.py        |  3 +-
 alt_e2eshark/onnx_tests/models/opt_models.py |  2 +-
 .../onnx_tests/operators/generate_node.py    |  3 +-
 alt_e2eshark/run.py                          |  2 +-
 7 files changed, 98 insertions(+), 47 deletions(-)

diff --git a/alt_e2eshark/e2e_testing/framework.py b/alt_e2eshark/e2e_testing/framework.py
index 981c4098..ee420f05 100644
--- a/alt_e2eshark/e2e_testing/framework.py
+++ b/alt_e2eshark/e2e_testing/framework.py
@@ -8,7 +8,7 @@
 import abc
 import os
 from pathlib import Path
-from typing import Union, TypeVar, Tuple, NamedTuple, Dict, Optional, Callable
+from typing import List, Union, TypeVar, Tuple, NamedTuple, Dict, Optional, Callable, final
 from e2e_testing.storage import TestTensors
 from e2e_testing.onnx_utils import *
 
@@ -16,9 +16,26 @@
 
 Module = TypeVar("Module")
 
-
 class OnnxModelInfo:
-    """Stores information about an onnx test: the filepath to model.onnx, how to construct/download it, and how to construct sample inputs for a test run."""
+    """
+    Stores information about an onnx test: the filepath to model.onnx, how to construct/download it,
+    and how to construct sample inputs for a test run.
+
+    This class lazily creates and caches an onnxruntime.InferenceSession for the model the first time
+    the session or the input/output nodes are accessed. The session can be closed via
+    del onnx_model_info.ort_inference_session.
+    """
+
+    # _xxx attributes are meant to be accessed via the corresponding @property methods
+    __slots__ = [
+        "name",
+        "_model",
+        "_ort_inference_session",
+        "_ort_input_nodes",
+        "_ort_output_nodes",
+        "opset_version",
+        "sess_options",
+        "dim_param_dict",
+    ]
 
     def __init__(
         self,
@@ -27,20 +44,63 @@ def __init__(
         opset_version: Optional[int] = None,
     ):
         self.name = name
-        self.model = os.path.join(onnx_model_path, "model.onnx")
+        self._model = os.path.join(onnx_model_path, "model.onnx")
         self.opset_version = opset_version
         self.sess_options = ort.SessionOptions()
         self.dim_param_dict = None
 
+    # model (path to the onnx file)
+
+    @property
+    def model(self):
+        if not os.path.exists(self._model):
+            self.construct_model()
+        return self._model
+
+    @model.deleter
+    def model(self):
+        del self._model
+
+    # inference session
+
+    @property
+    def ort_inference_session(self):
+        """
+        Getter for the onnxruntime.InferenceSession object. If it doesn't exist, it is created and cached.
+
+        Also stores the input and output nodes of the model in self._ort_input_nodes and self._ort_output_nodes,
+        so that they can be accessed via self.ort_input_nodes and self.ort_output_nodes even after the session is deleted.
+        """
+        if not hasattr(self, "_ort_inference_session"):
+            self.update_sess_options()
+            self._ort_inference_session = ort.InferenceSession(self.model, self.sess_options)
+            self._ort_input_nodes = self._ort_inference_session.get_inputs()
+            self._ort_output_nodes = self._ort_inference_session.get_outputs()
+        return self._ort_inference_session
+
+    @ort_inference_session.deleter
+    def ort_inference_session(self):
+        del self._ort_inference_session
+
+    # input and output nodes
+    @property
+    def ort_input_nodes(self):
+        if not hasattr(self, "_ort_input_nodes"):
+            self.ort_inference_session
+        return self._ort_input_nodes
+
+    @property
+    def ort_output_nodes(self):
+        if not hasattr(self, "_ort_output_nodes"):
+            self.ort_inference_session
+        return self._ort_output_nodes
+
     def forward(self, input: Optional[TestTensors] = None) -> TestTensors:
         """Applies self.model to self.input. Only override if necessary for specific models"""
         input = input.to_numpy().data
-        if not os.path.exists(self.model):
-            self.construct_model()
-        self.update_sess_options()
-        session = ort.InferenceSession(self.model, self.sess_options)
-        session_inputs = session.get_inputs()
-        session_outputs = session.get_outputs()
+        session = self.ort_inference_session
+        session_inputs = self.ort_input_nodes
+        session_outputs = self.ort_output_nodes
 
         model_output = session.run(
             [output.name for output in session_outputs],
@@ -70,12 +130,10 @@ def construct_model(self):
 
     def construct_inputs(self) -> TestTensors:
         """can be overridden to generate specific inputs, but a default is provided for convenience"""
-        if not os.path.exists(self.model):
-            self.construct_model()
         self.update_dim_param_dict()
         # print(self.get_signature())
         # print(get_op_frequency(self.model))
-        return get_sample_inputs_for_onnx_model(self.model, self.dim_param_dict)
+        return get_sample_inputs_for_onnx_model(self.ort_input_nodes, self.dim_param_dict)
 
     def apply_postprocessing(self, output: TestTensors):
         """can be overridden to define post-processing methods for individual models"""
@@ -86,15 +144,13 @@ def save_processed_output(self, output: TestTensors, save_to: str, name: str):
         pass
 
     # the following helper methods aren't meant to be overridden
-
+    @final
     def get_signature(self, *, from_inputs=True, leave_dynamic=False):
         """Returns the input or output signature of self.model"""
-        if not os.path.exists(self.model):
-            self.construct_model()
         if not leave_dynamic:
             self.update_dim_param_dict()
         return get_signature_for_onnx_model(self.model, from_inputs=from_inputs, dim_param_dict=self.dim_param_dict, leave_dynamic=leave_dynamic)
-
+    @final
     def load_inputs(self, dir_path):
         """computes the input signature of the onnx model and loads inputs from bin files"""
         shapes, dtypes = self.get_signature(from_inputs=True)
@@ -105,22 +161,23 @@ def load_inputs(self, dir_path):
             "\tWarning: bin files missing. Generating new inputs. Please re-run this test without --load-inputs to save input bin files."
         )
         return self.construct_inputs()
-
+
+    @final
     def load_outputs(self, dir_path):
         """computes the output signature of the onnx model and loads outputs from bin files"""
         shapes, dtypes = self.get_signature(from_inputs=False)
         return TestTensors.load_from(shapes, dtypes, dir_path, "output")
-
+
+    @final
     def load_golden_outputs(self, dir_path):
         """computes the output signature of the onnx model and loads golden outputs from bin files"""
         shapes, dtypes = self.get_signature(from_inputs=False)
         return TestTensors.load_from(shapes, dtypes, dir_path, "golden_output")
 
+    @final
     def update_opset_version_and_overwrite(self):
         if not self.opset_version:
             return
-        if not os.path.exists(self.model):
-            self.construct_model()
         og_model = onnx.load(self.model)
         if og_model.opset_import[0].version >= self.opset_version:
             return
diff --git a/alt_e2eshark/e2e_testing/onnx_utils.py b/alt_e2eshark/e2e_testing/onnx_utils.py
index 57ac2831..e57b48b0 100644
--- a/alt_e2eshark/e2e_testing/onnx_utils.py
+++ b/alt_e2eshark/e2e_testing/onnx_utils.py
@@ -74,14 +74,10 @@ def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.N
     raise NotImplementedError(f"Found an unhandled dtype: {node.type}.")
 
 
-def get_sample_inputs_for_onnx_model(model_path, dim_param_dict = None) -> TestTensors:
+def get_sample_inputs_for_onnx_model(input_nodes, dim_param_dict = None) -> TestTensors:
     """A convenience function for generating sample inputs for an onnx model"""
-    opt = onnxruntime.SessionOptions()
-    opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
-    s = onnxruntime.InferenceSession(model_path, opt)
-    inputs = s.get_inputs()
     sample_inputs = TestTensors(
-        tuple([generate_input_from_node(node, dim_param_dict) for node in inputs])
+        tuple([generate_input_from_node(node, dim_param_dict) for node in input_nodes])
     )
     return sample_inputs
 
diff --git a/alt_e2eshark/onnx_tests/helper_classes.py b/alt_e2eshark/onnx_tests/helper_classes.py
index 13e82c28..d52c4c90 100644
--- a/alt_e2eshark/onnx_tests/helper_classes.py
+++ b/alt_e2eshark/onnx_tests/helper_classes.py
@@ -18,6 +18,13 @@
 """This file contains several helpful child classes of OnnxModelInfo."""
 
 class AzureDownloadableModel(OnnxModelInfo):
     """This class can be used for models in our azure storage (both private and public)."""
+
+    # adds one slot to those defined on OnnxModelInfo
+    __slots__ = [
+        "cache_dir",
+    ]
+
+
     def __init__(self, name: str, onnx_model_path: str):
         # TODO: Extract opset version from onnx.version.opset
@@ -37,7 +44,7 @@ def construct_model(self):
         # if that fails, try to download and setup from azure, then search again for a .onnx file
 
         # TODO: make the zip file structure more uniform so we don't need to search for extracted files
-        model_dir = str(Path(self.model).parent)
+        model_dir = str(Path(self._model).parent)
 
         def find_models(model_dir):
             # search for a .onnx file in the ./test-run/testname/ dir
@@ -52,16 +59,16 @@ def find_models(model_dir):
 
         if len(found_models) == 0:
             azutils.pre_test_onnx_model_azure_download(
-                self.name, self.cache_dir, self.model
+                self.name, self.cache_dir, self._model
             )
             found_models = find_models(model_dir)
         if len(found_models) == 1:
-            self.model = found_models[0]
+            self._model = found_models[0]
             return
         if len(found_models) > 1:
             print(f'Found multiple model.onnx files: {found_models}')
             print(f'Picking the first model found to use: {found_models[0]}')
-            self.model = found_models[0]
+            self._model = found_models[0]
             return
         raise OSError(f"No onnx model could be found, downloaded, or extracted to {model_dir}")
 
@@ -75,11 +82,6 @@ def __init__(self, og_model_info_class: type, og_name: str, *args, **kwargs):
         run_dir = Path(self.model).parents[1]
         og_model_path = os.path.join(run_dir, og_name)
         self.sibling_inst = og_model_info_class(og_name, og_model_path)
-
-    def construct_model(self):
-        if not os.path.exists(self.sibling_inst.model):
-            self.sibling_inst.construct_model()
-        self.model = self.sibling_inst.model
 
     def update_dim_param_dict(self):
         self.sibling_inst.update_dim_param_dict()
@@ -125,8 +127,6 @@ def __init__(self, n: int, op_type: str, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def construct_model(self):
-        if not os.path.exists(self.sibling_inst.model):
-            self.sibling_inst.construct_model()
         og_model = onnx.load(self.sibling_inst.model)
         inf_model = onnx.shape_inference.infer_shapes(og_model, data_prop=True)
         output_node = (
@@ -135,9 +135,9 @@ def construct_model(self):
             else find_node(inf_model, self.n, self.op_type)
         )
         new_model = modify_model_output(inf_model, output_node)
-        onnx.save(new_model, self.model)
+        onnx.save(new_model, self._model)
         from e2e_testing.onnx_utils import get_op_frequency
-        print(get_op_frequency(self.model))
+        print(get_op_frequency(self._model))
 
 
 def get_trucated_constructor(truncated_class, og_constructor, og_name):
@@ -231,4 +231,4 @@ def construct_model(self):
         self.construct_initializers()
         graph = make_graph(self.node_list, "main", self.input_vi, self.output_vi, self.initializers)
         model = make_model(graph)
-        onnx.save(model, self.model)
+        onnx.save(model, self._model)
diff --git a/alt_e2eshark/onnx_tests/models/nlp.py b/alt_e2eshark/onnx_tests/models/nlp.py
index f63d2576..217d16f8 100644
--- a/alt_e2eshark/onnx_tests/models/nlp.py
+++ b/alt_e2eshark/onnx_tests/models/nlp.py
@@ -41,10 +41,9 @@ def construct_inputs(self):
 
             tensors = list(default_inputs.data)
 
             self.update_sess_options()
-            session = ort.InferenceSession(self.model, self.sess_options)
 
             # nlp specific overrides
-            for i, node in enumerate(session.get_inputs()):
+            for i, node in enumerate(self.ort_input_nodes):
                 if node.name == "token_type_ids":
                     rng = numpy.random.default_rng(19)
                     int_dims = get_node_shape_from_dim_param_dict(node, self.dim_param_dict)
diff --git a/alt_e2eshark/onnx_tests/models/opt_models.py b/alt_e2eshark/onnx_tests/models/opt_models.py
index a151d45e..ad76d9d7 100644
--- a/alt_e2eshark/onnx_tests/models/opt_models.py
+++ b/alt_e2eshark/onnx_tests/models/opt_models.py
@@ -83,7 +83,7 @@ def construct_model(self):
         torch.onnx.export(
             self.pytorch_model,
             (self.encoding["input_ids"], self.encoding["attention_mask"]),
-            self.model,
+            self._model,
             opset_version=20,
         )
 
diff --git a/alt_e2eshark/onnx_tests/operators/generate_node.py b/alt_e2eshark/onnx_tests/operators/generate_node.py
index 6f4c6fdb..02d6f646 100644
--- a/alt_e2eshark/onnx_tests/operators/generate_node.py
+++ b/alt_e2eshark/onnx_tests/operators/generate_node.py
@@ -40,8 +40,7 @@ def __init__(self, *args, **kwargs):
-        self.model = onnx_node_tests_dir + self.name + "/model.onnx"
+        self._model = onnx_node_tests_dir + self.name + "/model.onnx"
 
     def construct_inputs(self):
-        model = onnx.load(self.model)
-        inputs = model.graph.input
+        inputs = self.ort_input_nodes
         num_inputs = len(inputs)
         input_list = []
         for i in range(num_inputs):
diff --git a/alt_e2eshark/run.py b/alt_e2eshark/run.py
index 37cae113..0b266649 100644
--- a/alt_e2eshark/run.py
+++ b/alt_e2eshark/run.py
@@ -392,7 +392,7 @@ def _get_argparse():
     parser.add_argument(
         "-t",
         "--test-filter",
-        help="Run given specific test(s) only",
+        help="Run only tests matching the given regex filter.",
     )
     parser.add_argument(
         "--testsfile",
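
[Reviewer note on PATCH 2] The refactor's core is a create-on-first-access,
cache-afterwards property. A stripped-down sketch of the same pattern follows;
the class and attribute names here are illustrative, not the patch's API:

    import onnxruntime as ort

    class LazyModel:
        # with __slots__, an attribute that was never assigned simply does not
        # exist, so hasattr() doubles as the "not yet cached" check
        __slots__ = ["_path", "_session", "_input_nodes"]

        def __init__(self, path: str):
            self._path = path

        @property
        def session(self):
            if not hasattr(self, "_session"):
                self._session = ort.InferenceSession(self._path)
                self._input_nodes = self._session.get_inputs()  # cache I/O metadata
            return self._session

        @session.deleter
        def session(self):
            del self._session  # frees the heavyweight session object

        @property
        def input_nodes(self):
            if not hasattr(self, "_input_nodes"):
                self.session  # touching the property builds the cache
            return self._input_nodes

Caching the node lists separately from the session is deliberate: a caller can
run del model.session to release session memory and still read input shapes.
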
From faa45839ac46dfb735122525a0b93af6fe0d6fdd Mon Sep 17 00:00:00 2001
From: Xida Ren
Date: Wed, 2 Oct 2024 20:41:39 +0000
Subject: [PATCH 3/3] add option to run one test of each class

---
 alt_e2eshark/run.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/alt_e2eshark/run.py b/alt_e2eshark/run.py
index 0b266649..2b07efb9 100644
--- a/alt_e2eshark/run.py
+++ b/alt_e2eshark/run.py
@@ -11,6 +11,7 @@
 import argparse
 import re
 from typing import List, Literal, Optional
+import random
 
 # append alt_e2eshark dir to path to allow importing without explicit pythonpath management
 TEST_DIR = str(Path(__file__).parent)
@@ -88,6 +89,24 @@ def get_tests(groups: Literal["all", "combinations", "operators"], test_filter:
     return test_list
 
 
+def select_tests_for_infrastructure(test_list, category):
+    if not category:
+        return test_list
+
+    category_dict = {}
+    for test in test_list:
+        if category == 'class':
+            key = test.model_constructor.__name__
+        else:
+            raise ValueError(f"Unknown category: {category}")
+
+        if key not in category_dict:
+            category_dict[key] = []
+        category_dict[key].append(test)
+    random.seed(a=19, version=2)
+    selected_tests = [random.choice(tests) for tests in category_dict.values()]
+    return selected_tests
+
 def main(args):
     """Sets up config and test list based on CL args, then runs the tests"""
 
@@ -114,6 +133,9 @@ def main(args):
     test_list = get_tests(args.groups, args.test_filter, args.testsfile)
     test_list.sort()
 
+    if args.test_infrastructure_by_category:
+        test_list = select_tests_for_infrastructure(test_list, args.test_infrastructure_by_category)
+
     #setup test stages
     stages = ALL_STAGES if args.benchmark else DEFAULT_STAGES
 
@@ -444,6 +466,12 @@ def _get_argparse():
         default="report.md",
         help="output filename for the report summary.",
     )
+    parser.add_argument(
+        "--test-infrastructure-by-category",
+        choices=['class'],  # TODO: support more categories (each needs a unique key)
+        help="Run one random test from each specified category to exercise the test infrastructure.",
+    )
+
     # parser.add_argument(
    #     "-d",
    #     "--todtype",
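
[Reviewer note on PATCH 3] Usage, given the flag added above:
python run.py --test-infrastructure-by-category class. The selection logic,
condensed to its core (assuming, as run.py does, that each test object
exposes the constructor class it was registered with):

    import random

    def one_per_class(test_list):
        # bucket tests by the name of the class that constructs their model
        by_class = {}
        for test in test_list:
            by_class.setdefault(test.model_constructor.__name__, []).append(test)
        random.seed(a=19, version=2)  # fixed seed: the same subset every run
        return [random.choice(tests) for tests in by_class.values()]

Seeding immediately before the draw keeps the smoke-test subset reproducible
across runs while still covering every helper class exactly once.
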