Skip to content

Commit

Permalink
[#69025] pipeline_manager: Add support for RuntimeBuilder
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Torhan <[email protected]>
  • Loading branch information
m-torhan committed Nov 25, 2024
1 parent 54609e9 commit 250adf5
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 6 deletions.
2 changes: 1 addition & 1 deletion kenning/pipeline_manager/flow_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def flush_graph(self):
to_ = self.nodes[to_id]
if local_to in to_["inputs"]:
raise RuntimeError(
f"Input {local_to} has more than one " f"connection"
f"Input {local_to} has more than one connection"
)
if local_from in from_["outputs"]:
conn_id = from_["outputs"][local_from]
Expand Down
25 changes: 21 additions & 4 deletions kenning/pipeline_manager/pipeline_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pipeline_manager import specification_builder

from kenning.core.model import ModelWrapper
from kenning.core.protocol import Protocol
from kenning.core.runtimebuilder import RuntimeBuilder
from kenning.pipeline_manager.core import (
SPECIFICATION_VERSION,
BaseDataflowHandler,
Expand Down Expand Up @@ -105,7 +105,13 @@ def add_block(kenning_block: Dict) -> Dict:

node_ids = {}

block_names = ["dataset", "model_wrapper", "runtime", "protocol"]
block_names = [
"dataset",
"model_wrapper",
"runtime",
"runtime_builder",
"protocol",
]
supported_blocks = block_names + ["optimizers"]
for name, block in pipeline.items():
if name not in supported_blocks:
Expand Down Expand Up @@ -198,6 +204,7 @@ def get_nodes(
"kenning.datasets",
"kenning.modelwrappers",
"kenning.protocols",
"kenning.runtimebuilders",
"kenning.runtimes",
"kenning.optimizers",
]
Expand All @@ -213,7 +220,7 @@ def get_nodes(
for _, base_type in base_classes
}
base_type_names[ModelWrapper] = "model_wrapper"
base_type_names[Protocol] = "protocol"
base_type_names[RuntimeBuilder] = "runtime_builder"
for base_module, base_type in base_classes:
classes = get_all_subclasses(base_module, base_type)
for kenning_class in classes:
Expand Down Expand Up @@ -290,6 +297,10 @@ def get_nodes(
],
"outputs": [],
},
"runtime_builder": {
"inputs": [],
"outputs": [],
},
"protocol": {
"inputs": [],
"outputs": [
Expand Down Expand Up @@ -367,7 +378,13 @@ def flush_graph(self) -> Dict:
)

pipeline = {}
types = ["model_wrapper", "runtime", "dataset", "protocol"]
types = [
"model_wrapper",
"runtime",
"runtime_builder",
"dataset",
"protocol",
]
for type_ in types:
if type_ in self.type_to_id:
pipeline[type_] = self.nodes[self.type_to_id[type_]]
Expand Down
7 changes: 7 additions & 0 deletions kenning/runtimebuilders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) 2020-2023 Antmicro <www.antmicro.com>
#
# SPDX-License-Identifier: Apache-2.0

"""
Contains implementations for runtimes' builders.
"""
5 changes: 4 additions & 1 deletion kenning/utils/class_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from kenning.core.protocol import Protocol
from kenning.core.runner import Runner
from kenning.core.runtime import Runtime
from kenning.core.runtimebuilder import RuntimeBuilder
from kenning.utils.logger import KLogger

OPTIMIZERS = "optimizers"
Expand All @@ -34,6 +35,7 @@
MODEL_WRAPPERS = "modelwrappers"
ONNX_CONVERSIONS = "onnxconversions"
OUTPUT_COLLECTORS = "outputcollectors"
RUNTIME_BUILDERS = "runtimebuilders"
RUNTIME_PROTOCOLS = "protocols"
RUNTIMES = "runtimes"

Expand All @@ -57,6 +59,7 @@ def get_base_classes_dict() -> Dict[str, Tuple[str, Type]]:
MODEL_WRAPPERS: ("kenning.modelwrappers", ModelWrapper),
ONNX_CONVERSIONS: ("kenning.onnxconverters", ONNXConversion),
OUTPUT_COLLECTORS: ("kenning.outputcollectors", OutputCollector),
RUNTIME_BUILDERS: ("kenning.runtimebuilders", RuntimeBuilder),
RUNTIME_PROTOCOLS: ("kenning.protocols", Protocol),
RUNTIMES: ("kenning.runtimes", Runtime),
}
Expand Down Expand Up @@ -321,7 +324,7 @@ def get_command(argv: List[str] = None, with_slash: bool = True) -> List[str]:
result = [f"python -m {modulename}"]
first_flag = 1
else:
result = [f'kenning {" ".join(command[1:first_flag])}']
result = [f"kenning {' '.join(command[1:first_flag])}"]

if len(command) > 1:
result[0] = f"{result[0]} " + ("\\" if with_slash else "")
Expand Down

0 comments on commit 250adf5

Please sign in to comment.