diff --git a/onnxslim/core/pattern/fusion/__init__.py b/onnxslim/core/pattern/fusion/__init__.py index 15b9d17..411c211 100644 --- a/onnxslim/core/pattern/fusion/__init__.py +++ b/onnxslim/core/pattern/fusion/__init__.py @@ -3,3 +3,4 @@ from .gemm import * from .padconv import * from .reduce import * +from .convadd import * diff --git a/onnxslim/core/pattern/fusion/convadd.py b/onnxslim/core/pattern/fusion/convadd.py new file mode 100644 index 0000000..94234ca --- /dev/null +++ b/onnxslim/core/pattern/fusion/convadd.py @@ -0,0 +1,66 @@ +import numpy as np + +import onnxslim.third_party.onnx_graphsurgeon as gs +from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users +from onnxslim.core.pattern.registry import register_fusion_pattern + + +class ConvAddMatcher(PatternMatcher): + def __init__(self, priority): + """Initializes the ConvAddMatcher for fusing Conv and Add layers in an ONNX graph.""" + pattern = Pattern( + """ + input input 0 1 conv_0 + Conv conv_0 1+ 1 input bn_0 + Add add_0 2 1 conv_0 ? output + output output 1 0 add_0 + """ + ) + super().__init__(pattern, priority) + + @property + def name(self): + """Returns the name of the FusionConvAdd pattern.""" + return "FusionConvAdd" + + def rewrite(self, opset=11): + match_case = {} + conv_node = self.conv_0 + conv_weight = list(conv_node.inputs)[1] + conv_node_users = get_node_users(conv_node) + node = self.add_0 + if len(conv_node_users) == 1 and isinstance(node.inputs[1], gs.Constant) and isinstance(conv_weight, gs.Constant) and node.inputs[1].values.squeeze().ndim == 1 and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[0]: + add_node = node + if len(conv_node.inputs) == 2: + conv_bias = node.inputs[1].values.squeeze() + else: + conv_bias = conv_node.inputs[2].values + node.inputs[1].values.squeeze() + + inputs = [] + inputs.append(list(conv_node.inputs)[0]) + inputs.append(conv_weight) + weight_name = list(conv_node.inputs)[1].name + if weight_name.endswith("weight"): + bias_name = f"{weight_name[:-6]}bias" + else: + bias_name = f"{weight_name}_bias" + inputs.append(gs.Constant(bias_name, values=conv_bias)) + outputs = list(add_node.outputs) + + conv_node.outputs.clear() + add_node.inputs.clear() + add_node.outputs.clear() + + match_case[conv_node.name] = { + "op": conv_node.op, + "inputs": inputs, + "outputs": outputs, + "name": conv_node.name, + "attrs": conv_node.attrs, + "domain": None, + } + + return match_case + + +register_fusion_pattern(ConvAddMatcher(1))