diff --git a/onnx_tf/backend.py b/onnx_tf/backend.py index 06574425..ce795c98 100644 --- a/onnx_tf/backend.py +++ b/onnx_tf/backend.py @@ -7,6 +7,8 @@ from __future__ import print_function from __future__ import unicode_literals +import os + try: from itertools import izip as zip except ImportError: # will be 3.x series @@ -14,6 +16,7 @@ from onnx import defs from onnx import numpy_helper +from onnx import load from onnx.backend.base import Backend from onnx.backend.base import namedtupledict from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt @@ -51,7 +54,9 @@ def prepare(cls, of the computational graph called TensorflowRep and returns the converted representation. - :param model: The ONNX model to be converted. + :param model: The ONNX model or the file path to the ONNX model to be converted. + If a file path is provided for a model larger than 2GB, the model will be loaded directly + from the path to handle large file sizes. :param device: The device to execute this model on. It can be either CPU (default) or CUDA. :param strict: Whether to enforce semantic equivalence between the original model and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence). @@ -70,6 +75,9 @@ def prepare(cls, common.sys_config.auto_cast = auto_cast common.sys_config.device = device + if isinstance(model, (str, os.PathLike)): + model = load(model) + return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs) @classmethod