Skip to content

Commit

Permalink
split variable and operators name generation
Browse files Browse the repository at this point in the history
This allows for a compatibility with sklearn2onnx.
Moreover we now explicitely set names to operators.
  • Loading branch information
MainRo committed Oct 25, 2024
1 parent 68ec987 commit 8e5039b
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 92 deletions.
35 changes: 35 additions & 0 deletions ebm2onnx/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import NamedTuple, Callable


class Context(NamedTuple):
generate_variable_name: Callable[[], str]
generate_operator_name: Callable[[], str]


def create_name_generator() -> Callable[[str], str]:
state = {}

def _generate_unique_name(name: str) -> str:
""" Generates a new globaly unique name in the graph
"""
if name in state:
state[name] += 1
else:
state[name] = 0

return "{}_{}".format(name, state[name])

return _generate_unique_name


def create(
generate_variable_name=None,
generate_operator_name=None,
) -> Context:
generate_variable_name = generate_variable_name or create_name_generator()
generate_operator_name = generate_operator_name or create_name_generator()

return Context(
generate_variable_name=generate_variable_name,
generate_operator_name=generate_operator_name,
)
3 changes: 2 additions & 1 deletion ebm2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def to_onnx(model, dtype, name="ebm",
prediction_name="prediction",
probabilities_name="probabilities",
explain_name="scores",
context=None,
):
"""Converts an EBM model to ONNX.
Expand All @@ -93,7 +94,7 @@ def to_onnx(model, dtype, name="ebm",
An ONNX model.
"""
target_opset = target_opset or get_latest_opset_version()
root = graph.create_graph()
root = graph.create_graph(context=context)

inputs = [None for _ in model.feature_names_in_]
parts = []
Expand Down
2 changes: 1 addition & 1 deletion ebm2onnx/ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def get_bin_index_on_continuous_value(bin_edges):
def _get_bin_index_on_continuous_value(g):
bin_count = len(bin_edges)
index_range = list(range(bin_count))

init_bin_index_range = graph.create_initializer(g, "bin_index_range", onnx.TensorProto.FLOAT, [bin_count], index_range)
init_bin_edges = graph.create_initializer(g, "bin_edges", onnx.TensorProto.DOUBLE, [bin_count], bin_edges)

Expand Down
63 changes: 30 additions & 33 deletions ebm2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,18 @@
from ebm2onnx import __version__
from .utils import get_latest_opset_version

from . import context as _context


class Graph(NamedTuple):
generate_name: Callable[[], str]
context: _context.Context
inputs: List[onnx.ValueInfoProto] = []
outputs: List[onnx.ValueInfoProto] = []
transients: List[onnx.ValueInfoProto] = []
nodes: List[onnx.NodeProto] = []
initializers: List[onnx.TensorProto] = []


def create_name_generator() -> Callable[[str], str]:
state = {}

def _generate_unique_name(name: str) -> str:
""" Generates a new globaly unique name in the graph
"""
if name in state:
state[name] += 1
else:
state[name] = 0

return "{}_{}".format(name, state[name])

return _generate_unique_name


def extend(i, val):
"""Extends a list as a copy
"""
Expand All @@ -42,14 +28,16 @@ def pipe(*args):
pass


def create_graph() -> Graph:
def create_graph(context=None) -> Graph:
"""Creates a new graph object.
Returns:
A Graph object.
"""
if context is None:
context = _context.create()
return Graph(
generate_name=create_name_generator()
context=context
)


Expand All @@ -65,7 +53,7 @@ def from_onnx(model) -> Graph:
A Graph object.
"""
return Graph(
generate_name=create_name_generator(),
context=_context.create(),
inputs=[n for n in model.graph.input],
outputs=[n for n in model.graph.output],
nodes=[n for n in model.graph.node],
Expand Down Expand Up @@ -97,7 +85,7 @@ def to_onnx(
graph = onnx.helper.make_graph(
nodes=graph.nodes,
name=name,
inputs=graph.inputs,
inputs=graph.inputs,
outputs=graph.outputs,
initializer=graph.initializers,
)
Expand Down Expand Up @@ -134,7 +122,7 @@ def to_onnx(
def create_input(graph, name, type, shape):
input = onnx.helper.make_tensor_value_info(name , type, shape)
return Graph(
generate_name=graph.generate_name,
context=graph.context,
inputs=[input],
transients=[input],
)
Expand All @@ -148,40 +136,49 @@ def add_output(graph, name, type, shape):


def create_initializer(graph, name, type, shape, value):
initializer = onnx.helper.make_tensor(graph.generate_name(name) , type, shape, value)
initializer = onnx.helper.make_tensor(graph.context.generate_variable_name(name) , type, shape, value)
return Graph(
generate_name=graph.generate_name,
context=graph.context,
initializers=[initializer],
transients=[initializer],
)


def create_transient_by_name(g, name, type, shape):
def create_transient_by_name(graph, name, type, shape):
input = onnx.helper.make_tensor_value_info(name, type, shape)
return Graph(
generate_name=g.generate_name,
context=graph.context,
transients=[input],
)


def add_transient_by_name(g, name, type=onnx.TensorProto.UNDEFINED, shape=[]):
def add_transient_by_name(graph, name, type=onnx.TensorProto.UNDEFINED, shape=[]):
tname = [
o
for n in g.nodes
for n in graph.nodes
for o in n.output
if o == name
][0]
]

if len(tname) == 0:
tname = [
name
for n in graph.initializers
if n.name == name
]

tname = tname[0]
t = onnx.helper.make_tensor_value_info(tname, type, shape)
return g._replace(
transients=extend(g.transients, [t])
)
return graph._replace(
transients=extend(graph.transients, [t])
)


def strip_to_transients(graph):
""" Returns only the transients of a graph
"""
return Graph(
generate_name=graph.generate_name,
context=graph.context,
transients=graph.transients,
)

Expand Down
Loading

0 comments on commit 8e5039b

Please sign in to comment.