Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bert fix + a bunch of refactoring #359

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 79 additions & 22 deletions alt_e2eshark/e2e_testing/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,34 @@
import abc
import os
from pathlib import Path
from typing import Union, TypeVar, Tuple, NamedTuple, Dict, Optional, Callable
from typing import List, Union, TypeVar, Tuple, NamedTuple, Dict, Optional, Callable, final
from e2e_testing.storage import TestTensors
from e2e_testing.onnx_utils import *

# This file contains two types of classes: framework-specific base classes for storing model info, and generic classes for testing infrastructure.

Module = TypeVar("Module")


class OnnxModelInfo:
"""Stores information about an onnx test: the filepath to model.onnx, how to construct/download it, and how to construct sample inputs for a test run."""
"""
Stores information about an onnx test: the filepath to model.onnx, how to construct/download it,
and how to construct sample inputs for a test run.

This class lazily creates and caches an onnxruntime.InferenceSession object for the model the first time the session or
the input / output nodes are accessed. The session can be closed via del onnx_model_info.ort_inference_session.
"""

# _xxx properties are meant to be accessed via @property methods
__slots__ = [
"name",
"_model",
"_ort_inference_session",
"_ort_input_nodes",
"_ort_output_nodes",
"opset_version",
"sess_options",
"dim_param_dict",
]

def __init__(
self,
Expand All @@ -27,20 +44,63 @@ def __init__(
opset_version: Optional[int] = None,
):
self.name = name
self.model = os.path.join(onnx_model_path, "model.onnx")
self._model = os.path.join(onnx_model_path, "model.onnx")
self.opset_version = opset_version
self.sess_options = ort.SessionOptions()
self.dim_param_dict = None

# model (path to onnx file)

@property
def model(self):
if not os.path.exists(self._model):
self.construct_model()
return self._model

@model.deleter
def model(self):
del self._model

# inference session

@property
def ort_inference_session(self):
"""
Getter for the onnxruntime.InferenceSession object. If it doesn't exist, it is created and stored.

Also stores the input and output nodes of the model in self._ort_input_nodes and self._ort_output_nodes,
such that they can be accessed via self.ort_input_nodes and self.ort_output_nodes even when the session is deleted.
"""
if not hasattr(self, "_ort_inference_session"):
self.update_sess_options()
self._ort_inference_session = ort.InferenceSession(self.model, self.sess_options)
self._ort_input_nodes = self._ort_inference_session.get_inputs()
self._ort_output_nodes = self._ort_inference_session.get_outputs()
return self._ort_inference_session

@ort_inference_session.deleter
def ort_inference_session(self):
del self._ort_inference_session

# input and output nodes
@property
def ort_input_nodes(self):
if not hasattr(self, "_ort_input_nodes"):
self.ort_inference_session
return self._ort_input_nodes

@property
def ort_output_nodes(self):
if not hasattr(self, "_ort_output_nodes"):
self.ort_inference_session
return self._ort_output_nodes

def forward(self, input: Optional[TestTensors] = None) -> TestTensors:
"""Applies self.model to the given input. Only override if necessary for specific models"""
input = input.to_numpy().data
if not os.path.exists(self.model):
self.construct_model()
self.update_sess_options()
session = ort.InferenceSession(self.model, self.sess_options)
session_inputs = session.get_inputs()
session_outputs = session.get_outputs()
session = self.ort_inference_session
session_inputs = self.ort_input_nodes
session_outputs = self.ort_output_nodes

model_output = session.run(
[output.name for output in session_outputs],
Expand Down Expand Up @@ -68,14 +128,12 @@ def construct_model(self):
f"Model path {self.model} does not exist and no construct_model method is defined."
)

def construct_inputs(self):
def construct_inputs(self) -> TestTensors:
"""can be overridden to generate specific inputs, but a default is provided for convenience"""
if not os.path.exists(self.model):
self.construct_model()
self.update_dim_param_dict()
# print(self.get_signature())
# print(get_op_frequency(self.model))
return get_sample_inputs_for_onnx_model(self.model, self.dim_param_dict)
return get_sample_inputs_for_onnx_model(self.ort_input_nodes, self.dim_param_dict)

def apply_postprocessing(self, output: TestTensors):
"""can be overridden to define post-processing methods for individual models"""
Expand All @@ -86,15 +144,13 @@ def save_processed_output(self, output: TestTensors, save_to: str, name: str):
pass

# the following helper methods aren't meant to be overridden

@final
def get_signature(self, *, from_inputs=True, leave_dynamic=False):
"""Returns the input or output signature of self.model"""
if not os.path.exists(self.model):
self.construct_model()
if not leave_dynamic:
self.update_dim_param_dict()
return get_signature_for_onnx_model(self.model, from_inputs=from_inputs, dim_param_dict=self.dim_param_dict, leave_dynamic=leave_dynamic)

@final
def load_inputs(self, dir_path):
"""computes the input signature of the onnx model and loads inputs from bin files"""
shapes, dtypes = self.get_signature(from_inputs=True)
Expand All @@ -105,22 +161,23 @@ def load_inputs(self, dir_path):
"\tWarning: bin files missing. Generating new inputs. Please re-run this test without --load-inputs to save input bin files."
)
return self.construct_inputs()


@final
def load_outputs(self, dir_path):
"""computes the output signature of the onnx model and loads outputs from bin files"""
shapes, dtypes = self.get_signature(from_inputs=False)
return TestTensors.load_from(shapes, dtypes, dir_path, "output")


@final
def load_golden_outputs(self, dir_path):
"""computes the output signature of the onnx model and loads golden outputs from bin files"""
shapes, dtypes = self.get_signature(from_inputs=False)
return TestTensors.load_from(shapes, dtypes, dir_path, "golden_output")

@final
def update_opset_version_and_overwrite(self):
if not self.opset_version:
return
if not os.path.exists(self.model):
self.construct_model()
og_model = onnx.load(self.model)
if og_model.opset_import[0].version >= self.opset_version:
return
Expand Down
29 changes: 19 additions & 10 deletions alt_e2eshark/e2e_testing/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ def dtype_from_ort_node(node):
return torch.bool
raise NotImplementedError(f"Unhandled dtype string found: {dtypestr}")


def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg, dim_param_dict: Optional[dict[str, int]] = None):
"""A convenience function for generating sample inputs for an onnxruntime node"""
def get_node_shape_from_dim_param_dict(node: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg, dim_param_dict: Optional[dict[str, int]] = None):
"""Get the shape of a node, replacing any string dims with values from a dim_param_dict"""
int_dims = []
for dim in node.shape:
if isinstance(dim, str) and dim_param_dict:
Expand All @@ -46,7 +45,21 @@ def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.N
raise ValueError(
f"input node '{node.name}' has a non-positive dim: {dim}. Consider setting cutsom inputs for this test."
)
int_dims.append(dim)
return int_dims


def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg, dim_param_dict: Optional[dict[str, int]] = None):
"""
Generate a random input tensor for a given node

Args:
node: an onnx node
dim_param_dict: a dictionary mapping onnx string dims to int values
"""


int_dims = get_node_shape_from_dim_param_dict(node, dim_param_dict)

rng = numpy.random.default_rng(19)
if node.type == "tensor(float)":
return rng.random(int_dims).astype(numpy.float32)
Expand All @@ -61,14 +74,10 @@ def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.N
raise NotImplementedError(f"Found an unhandled dtype: {node.type}.")


def get_sample_inputs_for_onnx_model(model_path, dim_param_dict = None):
def get_sample_inputs_for_onnx_model(input_nodes, dim_param_dict = None) -> TestTensors:
"""A convenience function for generating sample inputs for an onnx model"""
opt = onnxruntime.SessionOptions()
opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
s = onnxruntime.InferenceSession(model_path, opt)
inputs = s.get_inputs()
sample_inputs = TestTensors(
tuple([generate_input_from_node(node, dim_param_dict) for node in inputs])
tuple([generate_input_from_node(node, dim_param_dict) for node in input_nodes])
)
return sample_inputs

Expand Down
28 changes: 14 additions & 14 deletions alt_e2eshark/onnx_tests/helper_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@
"""This file contains several helpful child classes of OnnxModelInfo."""

class AzureDownloadableModel(OnnxModelInfo):

# slots
__slots__ = [
"cache_dir"
]


"""This class can be used for models in our azure storage (both private and public)."""
def __init__(self, name: str, onnx_model_path: str):
# TODO: Extract opset version from onnx.version.opset
Expand All @@ -37,7 +44,7 @@ def construct_model(self):
# if that fails, try to download and setup from azure, then search again for a .onnx file

# TODO: make the zip file structure more uniform so we don't need to search for extracted files
model_dir = str(Path(self.model).parent)
model_dir = str(Path(self._model).parent)

def find_models(model_dir):
# search for a .onnx file in the ./test-run/testname/ dir
Expand All @@ -52,16 +59,16 @@ def find_models(model_dir):

if len(found_models) == 0:
azutils.pre_test_onnx_model_azure_download(
self.name, self.cache_dir, self.model
self.name, self.cache_dir, self._model
)
found_models = find_models(model_dir)
if len(found_models) == 1:
self.model = found_models[0]
self._model = found_models[0]
return
if len(found_models) > 1:
print(f'Found multiple model.onnx files: {found_models}')
print(f'Picking the first model found to use: {found_models[0]}')
self.model = found_models[0]
self._model = found_models[0]
return
raise OSError(f"No onnx model could be found, downloaded, or extracted to {model_dir}")

Expand All @@ -75,11 +82,6 @@ def __init__(self, og_model_info_class: type, og_name: str, *args, **kwargs):
run_dir = Path(self.model).parents[1]
og_model_path = os.path.join(run_dir, og_name)
self.sibling_inst = og_model_info_class(og_name, og_model_path)

def construct_model(self):
if not os.path.exists(self.sibling_inst.model):
self.sibling_inst.construct_model()
self.model = self.sibling_inst.model

def update_dim_param_dict(self):
self.sibling_inst.update_dim_param_dict()
Expand Down Expand Up @@ -125,8 +127,6 @@ def __init__(self, n: int, op_type: str, *args, **kwargs):
super().__init__(*args, **kwargs)

def construct_model(self):
if not os.path.exists(self.sibling_inst.model):
self.sibling_inst.construct_model()
og_model = onnx.load(self.sibling_inst.model)
inf_model = onnx.shape_inference.infer_shapes(og_model, data_prop=True)
output_node = (
Expand All @@ -135,9 +135,9 @@ def construct_model(self):
else find_node(inf_model, self.n, self.op_type)
)
new_model = modify_model_output(inf_model, output_node)
onnx.save(new_model, self.model)
onnx.save(new_model, self._model)
from e2e_testing.onnx_utils import get_op_frequency
print(get_op_frequency(self.model))
print(get_op_frequency(self._model))


def get_trucated_constructor(truncated_class, og_constructor, og_name):
Expand Down Expand Up @@ -231,4 +231,4 @@ def construct_model(self):
self.construct_initializers()
graph = make_graph(self.node_list, "main", self.input_vi, self.output_vi, self.initializers)
model = make_model(graph)
onnx.save(model, self.model)
onnx.save(model, self._model)
22 changes: 22 additions & 0 deletions alt_e2eshark/onnx_tests/models/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from ..helper_classes import AzureDownloadableModel
from e2e_testing.registry import register_test
from e2e_testing.storage import load_test_txt_file
from e2e_testing.onnx_utils import get_node_shape_from_dim_param_dict
from e2e_testing.storage import TestTensors
import onnxruntime as ort
import numpy
from typing import Optional
import os

this_file = Path(__file__)
lists_dir = (this_file.parent).joinpath("external_lists")
Expand All @@ -28,6 +34,22 @@ def __init__(self, *args, **kwargs):
def update_dim_param_dict(self):
self.dim_param_dict = dim_param_dict

def construct_inputs(self):
"""Overrides the parent class method to construct sample inputs with the correct dimensions."""
default_inputs = super().construct_inputs()

tensors = list(default_inputs.data)

self.update_sess_options()

# nlp specific overrides
for i, node in enumerate(self.ort_input_nodes):
if node.name == "token_type_ids":
rng = numpy.random.default_rng(19)
int_dims = get_node_shape_from_dim_param_dict(node, self.dim_param_dict)
tensors[i] = rng.integers(0, 2, size=int_dims, dtype=numpy.int64)
default_sample_inputs = TestTensors(tuple(tensors))
return default_sample_inputs
return AzureWithDimParams

# Default dimension parameters for NLP models
Expand Down
2 changes: 1 addition & 1 deletion alt_e2eshark/onnx_tests/models/opt_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def construct_model(self):
torch.onnx.export(
self.pytorch_model,
(self.encoding["input_ids"], self.encoding["attention_mask"]),
self.model,
self._model,
opset_version=20,
)

Expand Down
3 changes: 1 addition & 2 deletions alt_e2eshark/onnx_tests/operators/generate_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ def __init__(self, *args, **kwargs):
self.model = onnx_node_tests_dir + self.name + "/model.onnx"

def construct_inputs(self):
model = onnx.load(self.model)
inputs = model.graph.input
inputs = len(self.ort_input_nodes)
num_inputs = len(inputs)
input_list = []
for i in range(num_inputs):
Expand Down
Loading