Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add backend.prepare Handle Large Models #1074

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion onnx_tf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
from __future__ import print_function
from __future__ import unicode_literals

import os

try:
from itertools import izip as zip
except ImportError: # will be 3.x series
pass

from onnx import defs
from onnx import numpy_helper
from onnx import load
from onnx.backend.base import Backend
from onnx.backend.base import namedtupledict
from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt
Expand Down Expand Up @@ -51,7 +54,9 @@ def prepare(cls,
of the computational graph called TensorflowRep and returns
the converted representation.

:param model: The ONNX model to be converted.
:param model: The ONNX model or the file path to the ONNX model to be converted.
If a file path is provided for a model larger than 2GB, the model will be loaded directly
from the path to handle large file sizes.
:param device: The device to execute this model on. It can be either CPU (default) or CUDA.
:param strict: Whether to enforce semantic equivalence between the original model
and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
Expand All @@ -70,6 +75,9 @@ def prepare(cls,
common.sys_config.auto_cast = auto_cast
common.sys_config.device = device

if isinstance(model, (str, os.PathLike)):
model = load(model)

return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs)

@classmethod
Expand Down