diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..23b60fc --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,40 @@ +exclude: | + (?x)^( + docs/.* + )$ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: debug-statements + - id: check-merge-conflict + - repo: https://github.com/asottile/pyupgrade + rev: v3.3.1 + hooks: + - id: pyupgrade + args: ["--py38-plus"] + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + language_version: python3 + - repo: https://github.com/pycqa/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + - repo: https://github.com/pycqa/isort + rev: 5.11.5 + hooks: + - id: isort + - repo: https://github.com/humitos/mirrors-autoflake.git + rev: v1.1 + hooks: + - id: autoflake + args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable'] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.1.1 + hooks: + - id: mypy + additional_dependencies: + - types-filelock + - types-setuptools diff --git a/README.md b/README.md index d7a4cb7..4067be3 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ within Numba. The code in this repository is an implementation of the CFG restructuring algorithms in Bahmann2015, specifically those from section 4.1 and 4.2: namely "loop restructuring" and "branch restructuring". These are interesting for -Numba because they serve to clearly identify regions withing the Python +Numba because they serve to clearly identify regions within the Python bytecode. 
## dependencies @@ -42,8 +42,8 @@ numba_rvsdg ├── core │   ├── datastructures │   │   ├── basic_block.py # BasicBlock implementation -│   │   ├── block_map.py # BlockMap implementation, maps labels to blocks -│   │   ├── byte_flow.py # ByteFlow implementation, BlockMap + bytecode +│   │   ├── scfg.py # SCFG implementation, maps labels to blocks +│   │   ├── byte_flow.py # ByteFlow implementation, SCFG + bytecode │   │   ├── flow_info.py # Converts program to ByteFlow │   │   └── labels.py # Collection of Label classes │   ├── transformations.py # Algorithms @@ -51,7 +51,7 @@ numba_rvsdg ├── networkx_vendored │   └── scc.py # Strongly Connected Componets (loop detection) ├── rendering -│   └── rendering.py # Graphivz based rendering of BlockMaps +│   └── rendering.py # Graphivz based rendering of SCFGs ├── tests │   ├── simulator.py # Simulator utility for running SCFGs │   ├── test_byteflow.py # Testung ByteFlow and others diff --git a/numba_rvsdg/core/datastructures/basic_block.py b/numba_rvsdg/core/datastructures/basic_block.py index 9a92375..e62cd6e 100644 --- a/numba_rvsdg/core/datastructures/basic_block.py +++ b/numba_rvsdg/core/datastructures/basic_block.py @@ -1,47 +1,31 @@ import dis -from collections import ChainMap -from typing import Tuple, Dict, List -from dataclasses import dataclass, field, replace +from typing import Dict, List +from dataclasses import dataclass, field, InitVar -from numba_rvsdg.core.datastructures.labels import Label +from numba_rvsdg.core.datastructures.labels import Label, NameGenerator, BlockName from numba_rvsdg.core.utils import _next_inst_offset @dataclass(frozen=True) -class BasicBlock: - label: Label - """The corresponding Label for this block. """ - - _jump_targets: Tuple[Label] = tuple() - """Jump targets (branch destinations) for this block""" - - backedges: Tuple[Label] = tuple() - """Backedges for this block.""" +class Block: + name_gen: InitVar[NameGenerator] + """Block Name Generator associated with this Block. 
+ Note: This is an initialization only argument and not + a class attribute.""" - @property - def is_exiting(self) -> bool: - return not self.jump_targets + label: Label + """The corresponding Label for this block.""" - @property - def fallthrough(self) -> bool: - return len(self._jump_targets) == 1 - @property - def jump_targets(self) -> Tuple[Label]: - acc = [] - for j in self._jump_targets: - if j not in self.backedges: - acc.append(j) - return tuple(acc) +@dataclass(frozen=True) +class BasicBlock(Block): - def replace_backedge(self, target: Label) -> "BasicBlock": - if target in self.jump_targets: - assert not self.backedges - return replace(self, backedges=(target,)) - return self + block_name: BlockName = field(init=False) + """Unique name identifier for this block""" - def replace_jump_targets(self, jump_targets: Tuple) -> "BasicBlock": - return replace(self, _jump_targets=jump_targets) + def __post_init__(self, name_gen: NameGenerator): + block_name = name_gen.new_block_name(label=self.label) + object.__setattr__(self, "block_name", block_name) @dataclass(frozen=True) @@ -83,44 +67,25 @@ class BranchBlock(BasicBlock): variable: str = None branch_value_table: dict = None - def replace_jump_targets(self, jump_targets: Tuple) -> "BasicBlock": - fallthrough = len(jump_targets) == 1 - old_branch_value_table = self.branch_value_table - new_branch_value_table = {} - for target in self.jump_targets: - if target not in jump_targets: - # ASSUMPTION: only one jump_target is being updated - diff = set(jump_targets).difference(self.jump_targets) - assert len(diff) == 1 - new_target = next(iter(diff)) - for k, v in old_branch_value_table.items(): - if v == target: - new_branch_value_table[k] = new_target - else: - # copy all old values - for k, v in old_branch_value_table.items(): - if v == target: - new_branch_value_table[k] = v - - return replace( - self, - _jump_targets=jump_targets, - branch_value_table=new_branch_value_table, - ) +# Maybe we can register new blocks 
over here instead of static lists +block_types = { + "basic": BasicBlock, + "python_bytecode": PythonBytecodeBlock, + "control_variable": ControlVariableBlock, + "branch": BranchBlock, +} -@dataclass(frozen=True) -class RegionBlock(BasicBlock): - kind: str = None - headers: Dict[Label, BasicBlock] = None - """The header of the region""" - subregion: "BlockMap" = None - """The subgraph excluding the headers - """ - exit: Label = None - """The exit node. - """ - def get_full_graph(self): - graph = ChainMap(self.subregion.graph, self.headers) - return graph +def get_block_class(block_type_string: str): + if block_type_string in block_types: + return block_types[block_type_string] + else: + raise TypeError(f"Block Type {block_type_string} not recognized.") + +def get_block_class_str(basic_block: BasicBlock): + for key, value in block_types.items(): + if isinstance(basic_block, value): + return key + else: + raise TypeError(f"Block Type of {basic_block} not recognized.") diff --git a/numba_rvsdg/core/datastructures/block_map.py b/numba_rvsdg/core/datastructures/block_map.py deleted file mode 100644 index 26daf5c..0000000 --- a/numba_rvsdg/core/datastructures/block_map.py +++ /dev/null @@ -1,379 +0,0 @@ -import dis -import yaml -from textwrap import dedent -from typing import Set, Tuple, Dict, List, Iterator -from dataclasses import dataclass, field - -from numba_rvsdg.core.datastructures.basic_block import ( - BasicBlock, - ControlVariableBlock, - BranchBlock, - RegionBlock, -) -from numba_rvsdg.core.datastructures.labels import ( - Label, - ControlLabelGenerator, - SynthenticAssignment, - SyntheticExit, - SyntheticTail, - SyntheticReturn, - ControlLabel -) - - -@dataclass(frozen=True) -class BlockMap: - """Map of Labels to Blocks.""" - - graph: Dict[Label, BasicBlock] = field(default_factory=dict) - clg: ControlLabelGenerator = field( - default_factory=ControlLabelGenerator, compare=False - ) - - def __getitem__(self, index): - return self.graph[index] - - def 
__contains__(self, index): - return index in self.graph - - def __iter__(self): - """Graph Iterator""" - # initialise housekeeping datastructures - to_visit, seen = [self.find_head()], [] - while to_visit: - # get the next label on the list - label = to_visit.pop(0) - # if we have visited this, we skip it - if label in seen: - continue - else: - seen.append(label) - # get the corresponding block for the label - block = self[label] - # yield the label, block combo - yield (label, block) - # if this is a region, recursively yield everything from that region - if type(block) == RegionBlock: - for i in block.subregion: - yield i - # finally add any jump_targets to the list of labels to visit - to_visit.extend(block.jump_targets) - - def exclude_blocks(self, exclude_blocks: Set[Label]) -> Iterator[Label]: - """Iterator over all nodes not in exclude_blocks.""" - for block in self.graph: - if block not in exclude_blocks: - yield block - - def find_head(self) -> Label: - """Find the head block of the CFG. - - Assuming the CFG is closed, this will find the block - that no other blocks are pointing to. - - """ - heads = set(self.graph.keys()) - for label in self.graph.keys(): - block = self.graph[label] - for jt in block.jump_targets: - heads.discard(jt) - assert len(heads) == 1 - return next(iter(heads)) - - def compute_scc(self) -> List[Set[Label]]: - """ - Strongly-connected component for detecting loops. - """ - from numba_rvsdg.networkx_vendored.scc import scc - - class GraphWrap: - def __init__(self, graph): - self.graph = graph - - def __getitem__(self, vertex): - out = self.graph[vertex].jump_targets - # Exclude node outside of the subgraph - return [k for k in out if k in self.graph] - - def __iter__(self): - return iter(self.graph.keys()) - - return list(scc(GraphWrap(self.graph))) - - def compute_scc_subgraph(self, subgraph) -> List[Set[Label]]: - """ - Strongly-connected component for detecting loops inside a subgraph. 
- """ - from numba_rvsdg.networkx_vendored.scc import scc - - class GraphWrap: - def __init__(self, graph, subgraph): - self.graph = graph - self.subgraph = subgraph - - def __getitem__(self, vertex): - out = self.graph[vertex].jump_targets - # Exclude node outside of the subgraph - return [k for k in out if k in subgraph] - - def __iter__(self): - return iter(self.graph.keys()) - - return list(scc(GraphWrap(self.graph, subgraph))) - - def find_headers_and_entries( - self, subgraph: Set[Label] - ) -> Tuple[Set[Label], Set[Label]]: - """Find entries and headers in a given subgraph. - - Entries are blocks outside the subgraph that have an edge pointing to - the subgraph headers. Headers are blocks that are part of the strongly - connected subset and that have incoming edges from outside the - subgraph. Entries point to headers and headers are pointed to by - entries. - - """ - outside: Label - entries: Set[Label] = set() - headers: Set[Label] = set() - - for outside in self.exclude_blocks(subgraph): - nodes_jump_in_loop = subgraph.intersection(self.graph[outside].jump_targets) - headers.update(nodes_jump_in_loop) - if nodes_jump_in_loop: - entries.add(outside) - # If the loop has no headers or entries, the only header is the head of - # the CFG. - if not headers: - headers = {self.find_head()} - return headers, entries - - def find_exiting_and_exits( - self, subgraph: Set[Label] - ) -> Tuple[Set[Label], Set[Label]]: - """Find exiting and exit blocks in a given subgraph. - - Existing blocks are blocks inside the subgraph that have edges to - blocks outside of the subgraph. Exit blocks are blocks outside the - subgraph that have incoming edges from within the subgraph. Exiting - blocks point to exits and exits and pointed to by exiting blocks. 
- - """ - inside: Label - exiting: Set[Label] = set() - exits: Set[Label] = set() - for inside in subgraph: - # any node inside that points outside the loop - for jt in self.graph[inside].jump_targets: - if jt not in subgraph: - exiting.add(inside) - exits.add(jt) - # any returns - if self.graph[inside].is_exiting: - exiting.add(inside) - return exiting, exits - - def is_reachable_dfs(self, begin: Label, end: Label): # -> TypeGuard: - """Is end reachable from begin.""" - seen = set() - to_vist = list(self.graph[begin].jump_targets) - while True: - if to_vist: - block = to_vist.pop() - else: - return False - - if block in seen: - continue - elif block == end: - return True - elif block not in seen: - seen.add(block) - if block in self.graph: - to_vist.extend(self.graph[block].jump_targets) - - def add_block(self, basicblock: BasicBlock): - self.graph[basicblock.label] = basicblock - - def remove_blocks(self, labels: Set[Label]): - for label in labels: - del self.graph[label] - - def insert_block( - self, new_label: Label, predecessors: Set[Label], successors: Set[Label] - ): - # TODO: needs a diagram and documentaion - # initialize new block - new_block = BasicBlock( - label=new_label, _jump_targets=successors, backedges=set() - ) - # add block to self - self.add_block(new_block) - # Replace any arcs from any of predecessors to any of successors with - # an arc through the inserted block instead. 
- for label in predecessors: - block = self.graph.pop(label) - jt = list(block.jump_targets) - if successors: - for s in successors: - if s in jt: - if new_label not in jt: - jt[jt.index(s)] = new_label - else: - jt.pop(jt.index(s)) - else: - jt.append(new_label) - self.add_block(block.replace_jump_targets(jump_targets=tuple(jt))) - - def insert_block_and_control_blocks( - self, new_label: Label, predecessors: Set[Label], successors: Set[Label] - ): - # TODO: needs a diagram and documentaion - # name of the variable for this branching assignment - branch_variable = self.clg.new_variable() - # initial value of the assignment - branch_variable_value = 0 - # store for the mapping from variable value to label - branch_value_table = {} - # Replace any arcs from any of predecessors to any of successors with - # an arc through the to be inserted block instead. - for label in predecessors: - block = self.graph[label] - jt = list(block.jump_targets) - # Need to create synthetic assignments for each arc from a - # predecessors to a successor and insert it between the predecessor - # and the newly created block - for s in set(jt).intersection(successors): - synth_assign = SynthenticAssignment(str(self.clg.new_index())) - variable_assignment = {} - variable_assignment[branch_variable] = branch_variable_value - synth_assign_block = ControlVariableBlock( - label=synth_assign, - _jump_targets=(new_label,), - backedges=(), - variable_assignment=variable_assignment, - ) - # add block - self.add_block(synth_assign_block) - # update branching table - branch_value_table[branch_variable_value] = s - # update branching variable - branch_variable_value += 1 - # replace previous successor with synth_assign - jt[jt.index(s)] = synth_assign - # finally, replace the jump_targets - self.add_block( - self.graph.pop(label).replace_jump_targets(jump_targets=tuple(jt)) - ) - # initialize new block, which will hold the branching table - new_block = BranchBlock( - label=new_label, - 
_jump_targets=tuple(successors), - backedges=set(), - variable=branch_variable, - branch_value_table=branch_value_table, - ) - # add block to self - self.add_block(new_block) - - def join_returns(self): - """Close the CFG. - - A closed CFG is a CFG with a unique entry and exit node that have no - predescessors and no successors respectively. - """ - # for all nodes that contain a return - return_nodes = [node for node in self.graph if self.graph[node].is_exiting] - # close if more than one is found - if len(return_nodes) > 1: - return_solo_label = SyntheticReturn(str(self.clg.new_index())) - self.insert_block(return_solo_label, return_nodes, tuple()) - - def join_tails_and_exits(self, tails: Set[Label], exits: Set[Label]): - if len(tails) == 1 and len(exits) == 1: - # no-op - solo_tail_label = next(iter(tails)) - solo_exit_label = next(iter(exits)) - return solo_tail_label, solo_exit_label - - if len(tails) == 1 and len(exits) == 2: - # join only exits - solo_tail_label = next(iter(tails)) - solo_exit_label = SyntheticExit(str(self.clg.new_index())) - self.insert_block(solo_exit_label, tails, exits) - return solo_tail_label, solo_exit_label - - if len(tails) >= 2 and len(exits) == 1: - # join only tails - solo_tail_label = SyntheticTail(str(self.clg.new_index())) - solo_exit_label = next(iter(exits)) - self.insert_block(solo_tail_label, tails, exits) - return solo_tail_label, solo_exit_label - - if len(tails) >= 2 and len(exits) >= 2: - # join both tails and exits - solo_tail_label = SyntheticTail(str(self.clg.new_index())) - solo_exit_label = SyntheticExit(str(self.clg.new_index())) - self.insert_block(solo_tail_label, tails, exits) - self.insert_block(solo_exit_label, set((solo_tail_label,)), exits) - return solo_tail_label, solo_exit_label - - @staticmethod - def bcmap_from_bytecode(bc: dis.Bytecode): - return {inst.offset: inst for inst in bc} - - @staticmethod - def from_yaml(yaml_string): - data = yaml.safe_load(yaml_string) - return BlockMap.from_dict(data) 
- - @staticmethod - def from_dict(graph_dict): - block_map_graph = {} - clg = ControlLabelGenerator() - for index, attributes in graph_dict.items(): - jump_targets = attributes["jt"] - backedges = attributes.get("be", ()) - label = ControlLabel(str(clg.new_index())) - block = BasicBlock( - label=label, - backedges=wrap_id(backedges), - _jump_targets=wrap_id(jump_targets), - ) - block_map_graph[label] = block - return BlockMap(block_map_graph, clg=clg) - - def to_yaml(self): - # Convert to yaml - block_map_graph = self.graph - yaml_string = """""" - - for key, value in block_map_graph.items(): - jump_targets = [f"{i.index}" for i in value._jump_targets] - jump_targets = str(jump_targets).replace("\'", "\"") - back_edges = [f"{i.index}" for i in value.backedges] - jump_target_str= f""" - "{str(key.index)}": - jt: {jump_targets}""" - - if back_edges: - back_edges = str(back_edges).replace("\'", "\"") - jump_target_str += f""" - be: {back_edges}""" - yaml_string += dedent(jump_target_str) - - return yaml_string - - def to_dict(self): - block_map_graph = self.graph - graph_dict = {} - for key, value in block_map_graph.items(): - curr_dict = {} - curr_dict["jt"] = [f"{i.index}" for i in value._jump_targets] - if value.backedges: - curr_dict["be"] = [f"{i.index}" for i in value.backedges] - graph_dict[str(key.index)] = curr_dict - return graph_dict - -def wrap_id(indices: Set[Label]): - return tuple([ControlLabel(i) for i in indices]) diff --git a/numba_rvsdg/core/datastructures/byte_flow.py b/numba_rvsdg/core/datastructures/byte_flow.py index b77859c..0fe5ea6 100644 --- a/numba_rvsdg/core/datastructures/byte_flow.py +++ b/numba_rvsdg/core/datastructures/byte_flow.py @@ -1,19 +1,22 @@ import dis -from copy import deepcopy from dataclasses import dataclass -from numba_rvsdg.core.datastructures.block_map import BlockMap -from numba_rvsdg.core.datastructures.basic_block import RegionBlock +from numba_rvsdg.core.datastructures.scfg import SCFG from 
numba_rvsdg.core.datastructures.flow_info import FlowInfo from numba_rvsdg.core.utils import _logger, _LogWrap -from numba_rvsdg.core.transformations import restructure_loop, restructure_branch +from numba_rvsdg.core.transformations import ( + restructure_loop, + restructure_loop_recursive, + restructure_branch, + join_returns, +) @dataclass(frozen=True) class ByteFlow: bc: dis.Bytecode - bbmap: "BlockMap" + scfg: SCFG @staticmethod def from_bytecode(code) -> "ByteFlow": @@ -21,45 +24,26 @@ def from_bytecode(code) -> "ByteFlow": _logger.debug("Bytecode\n%s", _LogWrap(lambda: bc.dis())) flowinfo = FlowInfo.from_bytecode(bc) - bbmap = flowinfo.build_basicblocks() - return ByteFlow(bc=bc, bbmap=bbmap) + scfg = flowinfo.build_basicblocks() + return ByteFlow(bc=bc, scfg=scfg) def _join_returns(self): - bbmap = deepcopy(self.bbmap) - bbmap.join_returns() - return ByteFlow(bc=self.bc, bbmap=bbmap) + join_returns(self.scfg) def _restructure_loop(self): - bbmap = deepcopy(self.bbmap) - restructure_loop(bbmap) - for region in _iter_subregions(bbmap): - restructure_loop(region.subregion) - return ByteFlow(bc=self.bc, bbmap=bbmap) + restructure_loop_recursive(self.scfg) def _restructure_branch(self): - bbmap = deepcopy(self.bbmap) - restructure_branch(bbmap) - for region in _iter_subregions(bbmap): - restructure_branch(region.subregion) - return ByteFlow(bc=self.bc, bbmap=bbmap) + restructure_branch(self.scfg) def restructure(self): - bbmap = deepcopy(self.bbmap) # close - bbmap.join_returns() + join_returns(self.scfg) # handle loop - restructure_loop(bbmap) - for region in _iter_subregions(bbmap): - restructure_loop(region.subregion) + restructure_loop(self.scfg) # handle branch - restructure_branch(bbmap) - for region in _iter_subregions(bbmap): - restructure_branch(region.subregion) - return ByteFlow(bc=self.bc, bbmap=bbmap) + restructure_branch(self.scfg) - -def _iter_subregions(bbmap: "BlockMap"): - for node in bbmap.graph.values(): - if isinstance(node, RegionBlock): - 
yield node - yield from _iter_subregions(node.subregion) + @staticmethod + def bcmap_from_bytecode(bc: dis.Bytecode): + return {inst.offset: inst for inst in bc} diff --git a/numba_rvsdg/core/datastructures/flow_info.py b/numba_rvsdg/core/datastructures/flow_info.py index 3caa41b..dbc3cb9 100644 --- a/numba_rvsdg/core/datastructures/flow_info.py +++ b/numba_rvsdg/core/datastructures/flow_info.py @@ -1,13 +1,11 @@ import dis - -from typing import Set, Tuple, Dict, Sequence +from typing import Set, Tuple, Dict, Sequence, List from dataclasses import dataclass, field from numba_rvsdg.core.datastructures.basic_block import PythonBytecodeBlock -from numba_rvsdg.core.datastructures.block_map import BlockMap +from numba_rvsdg.core.datastructures.scfg import SCFG from numba_rvsdg.core.datastructures.labels import ( - Label, - ControlLabelGenerator, + BlockName, PythonBytecodeLabel, ) from numba_rvsdg.core.utils import ( @@ -35,10 +33,6 @@ class FlowInfo: """Offset of the last bytecode instruction. 
""" - clg: ControlLabelGenerator = field( - default_factory=ControlLabelGenerator, compare=False - ) - def _add_jump_inst(self, offset: int, targets: Sequence[int]): """Add jump instruction to FlowInfo.""" for off in targets: @@ -71,33 +65,36 @@ def from_bytecode(bc: dis.Bytecode) -> "FlowInfo": flowinfo.last_offset = inst.offset return flowinfo - def build_basicblocks(self: "FlowInfo", end_offset=None) -> "BlockMap": + def build_basicblocks(self: "FlowInfo", end_offset=None) -> "SCFG": """ Build a graph of basic-blocks """ offsets = sorted(self.block_offsets) - # enumerate labels - labels = dict( - (offset, PythonBytecodeLabel(self.clg.new_index())) for offset in offsets - ) + scfg = SCFG() + + names = {} if end_offset is None: end_offset = _next_inst_offset(self.last_offset) - bbmap = BlockMap(graph={}, clg=self.clg) + + for begin, end in zip(offsets, [*offsets[1:], end_offset]): + names[begin] = scfg.add_block( + block_type="python_bytecode", + block_label=PythonBytecodeLabel(), + begin=begin, + end=end, + ) + for begin, end in zip(offsets, [*offsets[1:], end_offset]): - label = labels[begin] - targets: Tuple[Label, ...] 
+ targets: List[BlockName] term_offset = _prev_inst_offset(end) if term_offset not in self.jump_insts: # implicit jump - targets = (labels[end],) + targets = [names[end],] else: - targets = tuple(labels[o] for o in self.jump_insts[term_offset]) - block = PythonBytecodeBlock( - label=label, - begin=begin, - end=end, - _jump_targets=targets, - backedges=(), - ) - bbmap.add_block(block) - return bbmap + targets = [names[o] for o in self.jump_insts[term_offset]] + + block_name = names[begin] + scfg.add_connections(block_name, targets) + + scfg.check_graph() + return scfg diff --git a/numba_rvsdg/core/datastructures/labels.py b/numba_rvsdg/core/datastructures/labels.py index 648f941..98c4393 100644 --- a/numba_rvsdg/core/datastructures/labels.py +++ b/numba_rvsdg/core/datastructures/labels.py @@ -1,9 +1,11 @@ from dataclasses import dataclass +from typing import List @dataclass(frozen=True, order=True) class Label: - index: int + info: List[str] = None + """Any Block specific information we want to add can go here""" ... 
@@ -17,6 +19,11 @@ class ControlLabel(Label): pass +@dataclass(frozen=True, order=True) +class RegionLabel(Label): + pass + + @dataclass(frozen=True, order=True) class SyntheticBranch(ControlLabel): pass @@ -57,17 +64,89 @@ class SynthenticAssignment(ControlLabel): pass -class ControlLabelGenerator: - def __init__(self, index=0, variable=97): - self.index = index - self.variable = variable +@dataclass(frozen=True, order=True) +class LoopRegionLabel(RegionLabel): + pass + + +@dataclass(frozen=True, order=True) +class MetaRegionLabel(RegionLabel): + pass + + +# Maybe we can register new labels over here instead of static lists +label_types = { + "label": Label, + "python_bytecode": PythonBytecodeLabel, + "control": ControlLabel, + "synth_branch": SyntheticBranch, + "synth_tail": SyntheticTail, + "synth_exit": SyntheticExit, + "synth_head": SyntheticHead, + "synth_return": SyntheticReturn, + "synth_latch": SyntheticLatch, + "synth_exit_latch": SyntheticExitingLatch, + "synth_assign": SynthenticAssignment, +} + + +def get_label_class(label_type_string): + if label_type_string in label_types: + return label_types[label_type_string] + else: + raise TypeError(f"Block Type {label_type_string} not recognized.") + + +@dataclass(frozen=True, order=True) +class Name: + name: str + + def __repr__(self): + return self.name + + def __str__(self): + return self.name + + +@dataclass(frozen=True, order=True) +class BlockName(Name): + pass + + +@dataclass(frozen=True, order=True) +class RegionName(Name): + pass - def new_index(self): - ret = self.index - self.index += 1 - return ret - def new_variable(self): - ret = chr(self.variable) - self.variable += 1 - return ret +@dataclass +class NameGenerator: + """Name generator for various element names. 
+ + Attributes + ---------- + + block_index : int + The starting index for blocks + variable_index: int + The starting index for control variables + region_index : int + The starting index for regions + """ + block_index: int = 0 + variable_index: int = 97 # Variables start at lowercase 'a' + region_index: int = 0 + + def new_block_name(self, label: str) -> BlockName: + ret = self.block_index + self.block_index += 1 + return BlockName(str(label).lower().split("(")[0] + "_" + str(ret)) + + def new_region_name(self, label: str) -> RegionName: + ret = self.region_index + self.region_index += 1 + return RegionName(str(label).lower().split("(")[0] + "_" + str(ret)) + + def new_var_name(self) -> str: + variable_name = chr(self.variable_index) + self.variable_index += 1 + return str(variable_name) diff --git a/numba_rvsdg/core/datastructures/region.py b/numba_rvsdg/core/datastructures/region.py new file mode 100644 index 0000000..114802b --- /dev/null +++ b/numba_rvsdg/core/datastructures/region.py @@ -0,0 +1,47 @@ +from dataclasses import dataclass, field +from numba_rvsdg.core.datastructures.basic_block import Block + +from numba_rvsdg.core.datastructures.labels import NameGenerator, RegionName, BlockName + + +@dataclass(frozen=True) +class Region(Block): + region_name: RegionName = field(init=False) + """Unique name identifier for this region""" + + def __post_init__(self, name_gen: NameGenerator): + region_name = name_gen.new_region_name(self.label) + object.__setattr__(self, "region_name", region_name) + + +@dataclass(frozen=True) +class LoopRegion(Region): + header: BlockName + exiting: BlockName + ... + + +@dataclass(frozen=True) +class MetaRegion(Region): + ... 
+ + +# TODO: Register new regions over here +region_types = { + "loop": LoopRegion +} + + +def get_region_class(region_type_string: str): + if region_type_string in region_types: + return region_types[region_type_string] + else: + raise TypeError(f"Region Type {region_type_string} not recognized.") + + +def get_region_class_str(region: Region): + for key, value in region_types.items(): + if isinstance(region, value): + return key + else: + raise TypeError(f"Region Type of {region} not recognized.") diff --git a/numba_rvsdg/core/datastructures/scfg.py b/numba_rvsdg/core/datastructures/scfg.py new file mode 100644 index 0000000..8848754 --- /dev/null +++ b/numba_rvsdg/core/datastructures/scfg.py @@ -0,0 +1,514 @@ +import yaml +import itertools + +from textwrap import dedent +from typing import Set, Tuple, Dict, List, Iterator +from dataclasses import dataclass, field +from collections import deque + +from numba_rvsdg.core.datastructures.basic_block import BasicBlock, get_block_class, get_block_class_str +from numba_rvsdg.core.datastructures.region import MetaRegion, Region, LoopRegion, get_region_class +from numba_rvsdg.core.datastructures.labels import ( + Name, + Label, + BlockName, + NameGenerator, + RegionName, + MetaRegionLabel, LoopRegionLabel, RegionLabel, + get_label_class, +) + + +@dataclass(frozen=True) +class SCFG: + """Maps of BlockNames to respective BasicBlocks. 
+ And stores the jump targets and back edges for + blocks within the graph.""" + + blocks: Dict[BlockName, BasicBlock] = field(default_factory=dict, init=False) + + out_edges: Dict[BlockName, List[BlockName]] = field(default_factory=dict, init=False) + back_edges: set[tuple[BlockName, BlockName]] = field(default_factory=set, init=False) + + regions: Dict[RegionName, Region] = field(default_factory=dict, init=False) + meta_region: RegionName = field(init=False) + region_tree: Dict[RegionName, List[RegionName]] = field(default_factory=dict, init=False) + + name_gen: NameGenerator = field(default_factory=NameGenerator, compare=False, init=False) + + def __post_init__(self): + new_region = MetaRegion(name_gen = self.name_gen, label = MetaRegionLabel()) + region_name = new_region.region_name + self.regions[region_name] = new_region + self.region_tree[region_name] = [] + object.__setattr__(self, "meta_region", region_name) + + def __getitem__(self, index: BlockName) -> BasicBlock: + return self.blocks[index] + + def __contains__(self, index: BlockName) -> bool: + return index in self.blocks + + def __iter__(self): + """Graph Iterator""" + # initialise housekeeping datastructures + to_visit, seen = [self.find_head()], set() + while to_visit: + # get the next name on the list + name = to_visit.pop(0) + # if we have visited this, we skip it + if name in seen: + continue + else: + seen.add(name) + # get the corresponding block for the name + if isinstance(name, RegionName): + name = self.regions[name].header + # yield the name, block combo + yield name + # finally add any out_edges to the list of names to visit + to_visit.extend(self.out_edges[name]) + + def exclude_blocks(self, exclude_blocks: Set[BlockName]) -> Iterator[BlockName]: + """Iterator over all nodes not in exclude_blocks.""" + for block in self.blocks: + if block not in exclude_blocks: + yield block + + def find_head(self) -> BlockName: + """Find the head block of the CFG. 
+ + Assuming the CFG is closed, this will find the block + that no other blocks are pointing to. + + """ + heads = set(self.blocks.keys()) + for name in self.blocks.keys(): + for jt in self.out_edges[name]: + heads.discard(jt) + for _, region in self.regions.items(): + if hasattr(region, "header"): + heads.discard(region.header) + assert len(heads) == 1 + return next(iter(heads)) + + def compute_scc(self) -> List[Set[BlockName]]: + """ + Strongly-connected component for detecting loops. + """ + from numba_rvsdg.networkx_vendored.scc import scc + + out_edges = self.out_edges + + class GraphWrap: + def __init__(self, graph): + self.graph = graph + + def __getitem__(self, vertex): + out = out_edges[vertex] + # Exclude node outside of the subgraph + return [k for k in out if k in self.graph] + + def __iter__(self): + return iter(self.graph.keys()) + + return list(scc(GraphWrap(self.blocks))) + + def compute_scc_subgraph(self, subgraph) -> List[Set[BlockName]]: + """ + Strongly-connected component for detecting loops inside a subgraph. + """ + from numba_rvsdg.networkx_vendored.scc import scc + + scfg = self + + class GraphWrap: + def __init__(self, graph: Dict[BlockName, BasicBlock], subgraph): + self.graph = graph + self.subgraph = subgraph + + def __getitem__(self, vertex): + + out = scfg.get_out_edges(vertex, region_view=False) + # Exclude node outside of the subgraph + return [k for k in out if k in subgraph + and not (vertex, k) in scfg.back_edges] + + def __iter__(self): + return iter(self.graph.keys()) + + return list(scc(GraphWrap(self.blocks, subgraph))) + + def find_headers_and_entries( + self, + subgraph: set[BlockName] + ) -> Tuple[list[BlockName], list[BlockName]]: + """Find entries and headers in a given subgraph. + + Entries are blocks outside the subgraph that have an edge pointing to + the subgraph headers. Headers are blocks that are part of the strongly + connected subset and that have incoming edges from outside the + subgraph. 
Entries point to headers and headers are pointed to by + entries. + + Parameters + ---------- + subgraph: set of BlockName + The subgraph for which to find the headers and entries + + Returns + ------- + headers: list of BlockName + The headers for this subgraph + entries: + The entries for this subgraph + + Notes + ----- + The returned lists of headers and entries are sorted. + """ + outside: BlockName + entries: set[BlockName] = set() + headers: set[BlockName] = set() + # Iterate over all blocks in the graph, excluding any blocks inside the + # subgraph. + for outside in self.exclude_blocks(subgraph): + # Check if the current block points to any blocks that are inside + # the subgraph. + targets_in_loop = subgraph.intersection(self.get_out_edges(outside, + region_view=False)) + # Record both headers and entries + if targets_in_loop: + headers.update(targets_in_loop) + entries.add(outside) + # If the loop has no headers or entries, the only header is the head of + # the CFG. + if not headers: + headers.add(self.find_head()) + return sorted(headers), sorted(entries) + + def find_exiting_and_exits( + self, subgraph: Set[BlockName] + ) -> Tuple[list[BlockName], list[BlockName]]: + """Find exiting and exit blocks in a given subgraph. + + Exiting blocks are blocks inside the subgraph that have edges to + blocks outside of the subgraph. Exit blocks are blocks outside the + subgraph that have incoming edges from within the subgraph. Exiting + blocks point to exits and exits are pointed to by exiting blocks. + + Parameters + ---------- + subgraph: set of BlockName + The subgraph for which to find the exiting and exit blocks. + + Returns + ------- + exiting: list of BlockName + The exiting blocks for this subgraph + exits: + The exit block for this subgraph + + Notes + ----- + The returned lists of exiting and exit blocks are sorted. 
+ + """ + inside: BlockName + # use sets internally to avoid duplicates + exiting: set[BlockName] = set() + exits: set[BlockName] = set() + for inside in subgraph: + # any node inside that points outside the loop + for out_target in self.out_edges[inside]: + if out_target not in subgraph: + exiting.add(inside) + exits.add(out_target) + # any returns + if self.is_exiting(inside): + exiting.add(inside) + # convert to sorted list before return + return sorted(exiting), sorted(exits) + + def is_reachable_dfs(self, begin: BlockName, end: BlockName): # -> TypeGuard: + """Is end reachable from begin.""" + seen = set() + to_vist = list(self.out_edges[begin]) + while True: + if to_vist: + block = to_vist.pop() + else: + return False + + if block in seen: + continue + elif block == end: + return True + elif block not in seen: + seen.add(block) + if block in self.blocks: + to_vist.extend(self.out_edges[block]) + + def is_exiting(self, block_name: BlockName): + return len(self.out_edges[block_name]) == 0 + + def is_fallthrough(self, block_name: BlockName): + return len(self.out_edges[block_name]) == 1 + + def check_graph(self): + pass + + def insert_block_between( + self, + block_name: BlockName, + predecessors: List[BlockName], + successors: List[BlockName] + ): + # Replace any arcs from any of predecessors to any of successors with + # an arc through the inserted block instead. 
+ for pred_name in predecessors: + # For every predecessor + # Add the inserted block as out edge + for idx, _out in enumerate(self.out_edges[pred_name]): + if _out in successors: + self.out_edges[pred_name][idx] = block_name + + if block_name not in self.out_edges[pred_name]: + self.out_edges[pred_name].append(block_name) + + self.out_edges[pred_name] = list(dict.fromkeys(self.out_edges[pred_name])) + + for success_name in successors: + # For every successor + # For inserted block, the successor is an out-edge + self.out_edges[block_name].append(success_name) + + self.check_graph() + + def add_block( + self, block_type: str = "basic", block_label: Label = Label(), **block_args + ) -> BlockName: + block_type = get_block_class(block_type) + new_block: BasicBlock = block_type(**block_args, label=block_label, name_gen=self.name_gen) + + name = new_block.block_name + self.blocks[name] = new_block + self.out_edges[name] = [] + + return name + + def add_region(self, kind: str, header: BlockName, exiting: BlockName, parent: Region = None, region_label = RegionLabel()): + if parent is None: + parent = self.meta_region + + region_type = get_region_class(kind) + new_region: Region = region_type(name_gen=self.name_gen, label=region_label, header=header, exiting=exiting) + region_name = new_region.region_name + self.regions[region_name] = new_region + self.region_tree[region_name] = [] + + self.region_tree[parent].append(region_name) + + for block, out_edges in self.out_edges.items(): + for idx, edge in enumerate(out_edges): + if edge == header and block is not exiting: + self.out_edges[block][idx] = region_name + + return region_name + + def add_connections(self, block_name, out_edges): + assert self.out_edges[block_name] == [] + self.out_edges[block_name] = out_edges + self.check_graph() + + + @staticmethod + def from_yaml(yaml_string): + data = yaml.safe_load(yaml_string) + return SCFG.from_dict(data) + + @staticmethod + def from_dict(graph_dict: Dict[str, Dict]): + scfg = 
SCFG() + ref_dict = {} + + for block_ref, block_attrs in graph_dict.items(): + block_class = block_attrs["type"] + block_args = block_attrs.get("block_args", {}) + label_class = get_label_class(block_attrs.get("label_type", "label")) + label_info = block_attrs.get("label_info", None) + block_label = label_class(label_info) + block_name = scfg.add_block(block_class, block_label, **block_args) + ref_dict[block_ref] = block_name + + for block_ref, block_attrs in graph_dict.items(): + out_refs = block_attrs.get("out", list()) + back_refs = block_attrs.get("back", list()) + + block_name = ref_dict[block_ref] + out_edges = list(ref_dict[out_ref] for out_ref in out_refs) + scfg.add_connections(block_name, out_edges) + for _back in back_refs: + scfg.back_edges.add((ref_dict[block_ref], ref_dict[_back])) + + + scfg.check_graph() + return scfg, ref_dict + + def to_yaml(self): + # Convert to yaml + yaml_string = """""" + + for key, value in self.blocks.items(): + out_edges = [] + back_edges = [] + for out_edge in self.out_edges[key]: + out_edges.append(f"{out_edge}") + if (key, out_edge) in self.back_edges: + back_edges.append(f"{out_edge}") + out_edges = str(out_edges).replace("'", '"') + jump_target_str = f""" + "{str(key)}": + type: "{get_block_class_str(value)}" + out: {out_edges}""" + + if back_edges: + back_edges = str(back_edges).replace("'", '"') + jump_target_str += f""" + back: {back_edges}""" + yaml_string += dedent(jump_target_str) + + return yaml_string + + def to_dict(self): + graph_dict = {} + for key, value in self.blocks.items(): + curr_dict = {} + curr_dict["type"] = get_block_class_str(value) + curr_dict["out"] = [] + back_edges = [] + for out_edge in self.out_edges[key]: + curr_dict["out"].append(f"{out_edge}") + if (key, out_edge) in self.back_edges: + back_edges.append(f"{out_edge}") + if back_edges: + curr_dict["back"] = back_edges + graph_dict[str(key)] = curr_dict + + return graph_dict + + def blocks_in_region(self, region_name: RegionName): + 
"""Generator for all blocks in a given region. + + Parameters + ---------- + region_name: RegionName + + Returns + ------- + gen: generator of BlockName + A generator that yields block names + """ + # get correct region from, scfg + region = self.regions[region_name] + # initialize housekeeping datastructures + to_visit, seen = [region.header], set() + while to_visit: + # get the next block_name on the list + name = to_visit.pop(0) + # if we have visited this, we skip it + if name in seen: + continue + # otherwise add it to the list of seen names + else: + seen.add(name) + # unless this is the exiting block of the region + if name is not region.exiting: + out_edges = [e for e in + self.get_out_edges(name, region_view=False) + if not isinstance(e, RegionName) + ] + to_visit.extend(out_edges) + yield name + + def get_out_edges(self, name: Name, region_view: bool = True): + if region_view: + if isinstance(name, RegionName): + if name is self.meta_region: + raise ValueError("Meta Region encompasses all the graph, it cannot have out edges") + name = self.regions[name].exiting + else: + if isinstance(name, RegionName): + if name is self.meta_region: + raise ValueError("Meta Region encompasses all the graph, it cannot have out edges") + name = self.regions[name].header + + return self.out_edges[name] + + def region_iterator(self) -> Iterator[RegionName]: + """A region iterator that is aware of new regions being added. + + This iterator returns `RegionName` from the `regions` dictionary. + Importantly, this iterator has an internal queue that is updated with + new keys from the `regions` dictionary on every iteration. This allows + the `regions` dictionary to be updated while this iterator is running. 
+ + Returns + ------- + iterator: Iterator[RegionName] + the regions of this scfg + + """ + # initialize housekeeping datastructures + queue, seen = deque(), set() + while True: + # extend the queue with any un-seen items + queue.extend([r for r in self.regions if r not in seen]) + # If the queue is now empty exit + if not queue: + break + # Otherwise... + else: + # get the next region name + r = queue.popleft() + if r in seen: + continue + # add it to the set of seen regions + seen.add(r) + # finally, yield the region name + yield r + + + def iterate_region(self, region_name, region_view=False): + region = self.regions[region_name] + """Region Iterator""" + region_head = region.header if region_name is not self.meta_region else self.find_head() + + # initialise housekeeping datastructures + to_visit, seen = [region_head], set() + while to_visit: + # get the next block_name on the list + block_name = to_visit.pop(0) + # if we have visited this, we skip it + if block_name in seen: + continue + else: + seen.add(block_name) + # yield the block_name + yield block_name + if region_name is not self.meta_region and block_name is region.exiting: + continue + # finally add any out_edges to the list of block_names to visit + if isinstance(block_name, RegionName): + outs = self.out_edges[self.regions[block_name].exiting] + else: + outs = self.out_edges[block_name] + + if not region_view: + for idx, _out in enumerate(outs): + if isinstance(_out, RegionName): + to_visit.append(self.regions[_out].header) + else: + to_visit.append(_out) + else: + to_visit.extend(outs) diff --git a/numba_rvsdg/core/transformations.py b/numba_rvsdg/core/transformations.py index 9f13e29..bb28e8b 100644 --- a/numba_rvsdg/core/transformations.py +++ b/numba_rvsdg/core/transformations.py @@ -7,93 +7,121 @@ SyntheticHead, SyntheticExitingLatch, SyntheticExit, + SyntheticReturn, + SyntheticTail, SynthenticAssignment, PythonBytecodeLabel, + BlockName, + RegionLabel ) -from 
numba_rvsdg.core.datastructures.block_map import BlockMap +from numba_rvsdg.core.datastructures.scfg import SCFG from numba_rvsdg.core.datastructures.basic_block import ( BasicBlock, ControlVariableBlock, BranchBlock, - RegionBlock, ) from numba_rvsdg.core.utils import _logger -def loop_restructure_helper(bbmap: BlockMap, loop: Set[Label]): +def loop_restructure_helper(scfg: SCFG, loop: Set[BlockName]): """Loop Restructuring Applies the algorithm LOOP RESTRUCTURING from section 4.1 of Bahmann2015. - Note that this will modify both the `bbmap` and the `loop` in-place. + Note that this will modify both the `scfg` and the `loop` in-place. Parameters ---------- - bbmap: BlockMap - The BlockMap containing the loop - loop: Set[Label] + scfg: SCFG + The SCFG containing the loop + loop: Set[BlockName] The loop (strongly connected components) that is to be restructured """ - headers, entries = bbmap.find_headers_and_entries(loop) - exiting_blocks, exit_blocks = bbmap.find_exiting_and_exits(loop) - #assert len(entries) == 1 + headers, entries = scfg.find_headers_and_entries(loop) + exiting_blocks, exit_blocks = scfg.find_exiting_and_exits(loop) headers_were_unified = False # If there are multiple headers, insert assignment and control blocks, # such that only a single loop header remains. if len(headers) > 1: headers_were_unified = True - solo_head_label = SyntheticHead(str(bbmap.clg.new_index())) - bbmap.insert_block_and_control_blocks(solo_head_label, entries, headers) - loop.add(solo_head_label) - loop_head: Label = solo_head_label + solo_head_label = SyntheticHead() + loop_head: BlockName = insert_block_and_control_blocks( + scfg, + entries, + headers, + block_label=solo_head_label) + loop.add(loop_head) else: - loop_head: Label = next(iter(headers)) + loop_head: BlockName = headers[0] # If there is only a single exiting latch (an exiting block that also has a # backedge to the loop header) we can exit early, since the condition for # SCFG is fulfilled. 
backedge_blocks = [ - block for block in loop if headers.intersection(bbmap[block].jump_targets) + block for block in loop + if set(headers).intersection(scfg.out_edges[block]) ] - if (len(backedge_blocks) == 1 and len(exiting_blocks) == 1 - and backedge_blocks[0] == next(iter(exiting_blocks))): - bbmap.add_block(bbmap.graph.pop(backedge_blocks[0]).replace_backedge(loop_head)) + if ( + len(backedge_blocks) == 1 + and len(exiting_blocks) == 1 + and backedge_blocks[0] == exiting_blocks[0] + ): + scfg.back_edges.add((backedge_blocks[0], loop_head)) return + doms = _doms(scfg) # The synthetic exiting latch and synthetic exit need to be created # based on the state of the cfg. If there are multiple exits, we need a # SyntheticExit, otherwise we only need a SyntheticExitingLatch - synth_exiting_latch = SyntheticExitingLatch(str(bbmap.clg.new_index())) + # Set a flag, this will determine the variable assignment and block # insertion later on needs_synth_exit = len(exit_blocks) > 1 - if needs_synth_exit: - synth_exit = SyntheticExit(str(bbmap.clg.new_index())) # This sets up the various control variables. 
# If there were multiple headers, we must re-use the variable that was used # for looping as the exit variable if headers_were_unified: - exit_variable = bbmap[solo_head_label].variable + exit_variable = scfg[loop_head].variable else: - exit_variable = bbmap.clg.new_variable() - # This variable denotes the backedge - backedge_variable = bbmap.clg.new_variable() + exit_variable = scfg.name_gen.new_var_name() + + exit_value_table = dict(((i, j) for i, j in enumerate(exit_blocks))) + if needs_synth_exit: + synth_exit_label = SyntheticExit() + synth_exit = scfg.add_block( + "branch", + block_label=synth_exit_label, + variable=exit_variable, + branch_value_table=exit_value_table) + # Now we setup the lookup tables for the various control variables, # depending on the state of the CFG and what is needed - exit_value_table = dict(((i, j) for i, j in enumerate(exit_blocks))) if needs_synth_exit: - backedge_value_table = dict((i, j) for i, j in enumerate((loop_head, synth_exit))) + backedge_value_table = dict( + (i, j) for i, j in enumerate((loop_head, synth_exit)) + ) else: - backedge_value_table = dict((i, j) for i, j in enumerate((loop_head, next(iter(exit_blocks))))) + backedge_value_table = dict( + (i, j) for i, j in enumerate((loop_head, exit_blocks[0])) + ) if headers_were_unified: - header_value_table = bbmap[solo_head_label].branch_value_table + header_value_table = scfg[loop_head].branch_value_table else: header_value_table = {} + synth_latch_label = SyntheticExitingLatch() + # This variable denotes the backedge + backedge_variable = scfg.name_gen.new_var_name() + synth_exiting_latch = scfg.add_block( + "branch", + block_label=synth_latch_label, + variable=backedge_variable, + branch_value_table=backedge_value_table) + # This does a dictionary reverse lookup, to determine the key for a given # value. 
def reverse_lookup(d, value): @@ -105,132 +133,133 @@ def reverse_lookup(d, value): # Now that everything is in place, we can start to insert blocks, depending # on what is needed - # All new blocks are recorded for later insertion into the loop set - new_blocks = set() - doms = _doms(bbmap) + # For every block in the loop: - for label in sorted(loop, key=lambda x: x.index): + for _name in sorted(loop): # If the block is an exiting block or a backedge block - if label in exiting_blocks or label in backedge_blocks: - # Copy the jump targets, these will be modified - new_jt = list(bbmap[label].jump_targets) - # For each jump_target in the blockj - for jt in bbmap[label].jump_targets: + if _name in exiting_blocks or _name in backedge_blocks: + # For each jump_target in the block + for out_target in scfg.out_edges[_name]: # If the target is an exit block - if jt in exit_blocks: - # Create a new assignment label and record it - synth_assign = SynthenticAssignment(str(bbmap.clg.new_index())) - new_blocks.add(synth_assign) + if out_target in exit_blocks: + # Create a new assignment name and record it + synth_assign = SynthenticAssignment() + # Setup the table for the variable assignment variable_assignment = {} # Setup the variables in the assignment table to point to # the correct blocks if needs_synth_exit: - variable_assignment[exit_variable] = reverse_lookup(exit_value_table, jt) - variable_assignment[backedge_variable] = reverse_lookup(backedge_value_table, - synth_exit if needs_synth_exit else next(iter(exit_blocks))) - # Create the actual control variable block - synth_assign_block = ControlVariableBlock( - label=synth_assign, - _jump_targets=(synth_exiting_latch,), - backedges=(), - variable_assignment=variable_assignment, + variable_assignment[exit_variable] = reverse_lookup( + exit_value_table, out_target + ) + variable_assignment[backedge_variable] = reverse_lookup( + backedge_value_table, + synth_exit if needs_synth_exit else exit_blocks[0] ) # Insert the 
assignment to the block map - bbmap.add_block(synth_assign_block) - # Insert the new block into the new jump_targets making - # sure, that it replaces the correct jump_target, order - # matters in this case. - new_jt[new_jt.index(jt)] = synth_assign + synth_assign = scfg.add_block( + "control_variable", + synth_assign, + variable_assignment=variable_assignment) + scfg.add_connections( + synth_assign, + [synth_exiting_latch]) + loop.add(synth_assign) + # Update the edge from the out_target to point to the new + # assignment block + out_edges = scfg.out_edges[_name] + out_edges[out_edges.index(out_target)] = synth_assign # If the target is the loop_head - elif jt in headers and label not in doms[jt]: + elif out_target in headers and _name not in doms[out_target]: # Create the assignment and record it - synth_assign = SynthenticAssignment(str(bbmap.clg.new_index())) - new_blocks.add(synth_assign) + synth_assign = SynthenticAssignment() # Setup the variables in the assignment table to point to # the correct blocks variable_assignment = {} - variable_assignment[backedge_variable] = reverse_lookup(backedge_value_table, loop_head) - if needs_synth_exit: - variable_assignment[exit_variable] = reverse_lookup(header_value_table, jt) - # Update the backedge block - remove any existing backedges - # that point to the headers, no need to add a backedge, - # since it will be contained in the SyntheticExitingLatch - # later on. - block = bbmap.graph.pop(label) - jts = list(block.jump_targets) - for h in headers: - if h in jts: - jts.remove(h) - bbmap.add_block(block.replace_jump_targets(jump_targets=tuple(jts))) - # Setup the assignment block and initialize it with the - # correct jump_targets and variable assignment. 
- synth_assign_block = ControlVariableBlock( - label=synth_assign, - _jump_targets=(synth_exiting_latch,), - backedges=(), - variable_assignment=variable_assignment, + variable_assignment[backedge_variable] = reverse_lookup( + backedge_value_table, loop_head ) - # Add the new block to the BlockMap - bbmap.add_block(synth_assign_block) - # Update the jump targets again, order matters - new_jt[new_jt.index(jt)] = synth_assign - # finally, replace the jump_targets for this block with the new ones - bbmap.add_block( - bbmap.graph.pop(label).replace_jump_targets(jump_targets=tuple(new_jt)) - ) - # Add any new blocks to the loop. - loop.update(new_blocks) - - # Insert the exiting latch, add it to the loop and to the graph. - synth_exiting_latch_block = BranchBlock( - label=synth_exiting_latch, - _jump_targets=(synth_exit if needs_synth_exit else next(iter(exit_blocks)), loop_head), - backedges=(loop_head,), - variable=backedge_variable, - branch_value_table=backedge_value_table, - ) + if needs_synth_exit: + variable_assignment[exit_variable] = reverse_lookup( + header_value_table, out_target + ) + synth_assign = scfg.add_block( + "control_variable", + synth_assign, + variable_assignment=variable_assignment) + scfg.add_connections( + synth_assign, + [synth_exiting_latch]) + loop.add(synth_assign) + + # Update the edge from the out_target to point to the new + # assignment block + out_edges = scfg.out_edges[_name] + out_edges[out_edges.index(out_target)] = synth_assign + + # Finally, add the synthetic exiting latch to loop loop.add(synth_exiting_latch) - bbmap.add_block(synth_exiting_latch_block) - # If an exit is to be created, we do so too, but only add it to the bbmap, + + # Add the back_edge + scfg.out_edges[synth_exiting_latch].append(loop_head) + scfg.back_edges.add((synth_exiting_latch, loop_head)) + + # If an exit is to be created, we do so too, but only add it to the scfg, # since it isn't part of the loop if needs_synth_exit: - synth_exit_block = BranchBlock( - 
label=synth_exit, - _jump_targets=tuple(exit_blocks), - backedges=(), - variable=exit_variable, - branch_value_table=exit_value_table, - ) - bbmap.add_block(synth_exit_block) + scfg.insert_block_between( + synth_exit, + [synth_exiting_latch], + exit_blocks) + else: + scfg.out_edges[synth_exiting_latch].append(exit_blocks[0]) -def restructure_loop(bbmap: BlockMap): +def restructure_loop(scfg: SCFG, subgraph: list[BlockName] = None): """Inplace restructuring of the given graph to extract loops using strongly-connected components """ # obtain a List of Sets of Labels, where all labels in each set are strongly # connected, i.e. all reachable from one another by traversing the subset - scc: List[Set[Label]] = bbmap.compute_scc() + if subgraph: + scc: List[Set[SCFG]] = scfg.compute_scc_subgraph(subgraph) + else: + scc: List[Set[SCFG]] = scfg.compute_scc() # loops are defined as strongly connected subsets who have more than a # single label and single label loops that point back to to themselves. 
- loops: List[Set[Label]] = [ + loops: List[Set[SCFG]] = [ nodes for nodes in scc - if len(nodes) > 1 or next(iter(nodes)) in bbmap[next(iter(nodes))].jump_targets + if len(nodes) > 1 or next(iter(nodes)) in scfg.out_edges[next(iter(nodes))] ] _logger.debug( - "restructure_loop found %d loops in %s", len(loops), bbmap.graph.keys() + "restructure_loop found %d loops in %s", len(loops), scfg.blocks.keys() ) # rotate and extract loop - for loop in loops: - loop_restructure_helper(bbmap, loop) - extract_region(bbmap, loop, "loop") + for l in loops: + loop_restructure_helper(scfg, l) + extract_region(scfg, l, "loop") + + +def restructure_loop_recursive(scfg: SCFG): + # find all top-level loops + restructure_loop(scfg) + # if any loops were found, continue down into regions + for region_name in scfg.region_iterator(): + if region_name is scfg.meta_region: + # skip the meta region, it doesn't have header or exiting + continue + region_blocks = list(scfg.blocks_in_region(region_name)) + #region_blocks = list(scfg.iterate_region(region_name)) + #breakpoint() + breakpoint() + restructure_loop(scfg, subgraph=region_blocks) -def find_head_blocks(bbmap: BlockMap, begin: Label) -> Set[Label]: - head = bbmap.find_head() +def find_head_blocks(scfg: SCFG, begin: BlockName) -> Set[BlockName]: + head = scfg.find_head() head_region_blocks = set() current_block = head # Start at the head block and traverse the graph linearly until @@ -240,23 +269,24 @@ def find_head_blocks(bbmap: BlockMap, begin: Label) -> Set[Label]: if current_block == begin: break else: - jt = bbmap.graph[current_block].jump_targets - assert len(jt) == 1 - current_block = next(iter(jt)) + out_targets = scfg.out_edges[current_block] + assert len(out_targets) == 1 + current_block = out_targets[0] return head_region_blocks -def find_branch_regions(bbmap: BlockMap, begin: Label, end: Label) -> Set[Label]: +def find_branch_regions(scfg: SCFG, begin: BlockName, end: BlockName) -> Set[BlockName]: # identify branch regions 
- doms = _doms(bbmap) - postdoms = _post_doms(bbmap) + doms = _doms(scfg) + postdoms = _post_doms(scfg) postimmdoms = _imm_doms(postdoms) immdoms = _imm_doms(doms) branch_regions = [] - jump_targets = bbmap.graph[begin].jump_targets - for bra_start in jump_targets: - for jt in jump_targets: - if jt != bra_start and bbmap.is_reachable_dfs(jt, bra_start): + out_targets = scfg.out_edges[begin] + for bra_start in out_targets: + for out_target in out_targets: + if (out_target != bra_start + and scfg.is_reachable_dfs(out_target, bra_start)): branch_regions.append(tuple()) break else: @@ -271,19 +301,17 @@ def find_branch_regions(bbmap: BlockMap, begin: Label, end: Label) -> Set[Label] return branch_regions -def _find_branch_regions(bbmap: BlockMap, begin: Label, end: Label) -> Set[Label]: +def _find_branch_regions(scfg: SCFG, begin: BlockName, end: BlockName) -> Set[BlockName]: # identify branch regions branch_regions = [] - for bra_start in bbmap[begin].jump_targets: + for bra_start in scfg.out_edges[begin]: region = [] region.append(bra_start) return branch_regions -def find_tail_blocks( - bbmap: BlockMap, begin: Set[Label], head_region_blocks, branch_regions -): - tail_subregion = set((b for b in bbmap.graph.keys())) +def find_tail_blocks(scfg: SCFG, begin: Set[BlockName], head_region_blocks, branch_regions): + tail_subregion = set((b for b in scfg.blocks.keys())) tail_subregion.difference_update(head_region_blocks) for reg in branch_regions: if not reg: @@ -297,43 +325,24 @@ def find_tail_blocks( return tail_subregion -def extract_region(bbmap, region_blocks, region_kind): - headers, entries = bbmap.find_headers_and_entries(region_blocks) - exiting_blocks, exit_blocks = bbmap.find_exiting_and_exits(region_blocks) +def extract_region(scfg: SCFG, region_blocks, region_kind, region_label = RegionLabel()): + headers, entries = scfg.find_headers_and_entries(region_blocks) + exiting_blocks, exit_blocks = scfg.find_exiting_and_exits(region_blocks) assert len(headers) == 1 
assert len(exiting_blocks) == 1 - region_header = next(iter(headers)) - region_exiting = next(iter(exiting_blocks)) + region_header = headers[0] + region_exiting = exiting_blocks[0] - head_subgraph = BlockMap( - {label: bbmap.graph[label] for label in region_blocks}, clg=bbmap.clg - ) - - if isinstance(bbmap[region_exiting], RegionBlock): - region_exit = bbmap[region_exiting].exit - else: - region_exit = region_exiting - - subregion = RegionBlock( - label=region_header, - _jump_targets=bbmap[region_exiting].jump_targets, - backedges=(), - kind=region_kind, - headers=headers, - subregion=head_subgraph, - exit=region_exit, - ) - bbmap.remove_blocks(region_blocks) - bbmap.graph[region_header] = subregion + scfg.add_region(region_kind, region_label=region_label, header = region_header, exiting = region_exiting) -def restructure_branch(bbmap: BlockMap): - print("restructure_branch", bbmap.graph) - doms = _doms(bbmap) - postdoms = _post_doms(bbmap) +def restructure_branch(scfg: SCFG): + print("restructure_branch", scfg.blocks) + doms = _doms(scfg) + postdoms = _post_doms(scfg) postimmdoms = _imm_doms(postdoms) immdoms = _imm_doms(doms) - regions = [r for r in _iter_branch_regions(bbmap, immdoms, postimmdoms)] + regions = [r for r in _iter_branch_regions(scfg, immdoms, postimmdoms)] # Early exit when no branching regions are found. # TODO: the whole graph should become a linear mono head @@ -342,23 +351,23 @@ def restructure_branch(bbmap: BlockMap): # Compute initial regions. begin, end = regions[0] - head_region_blocks = find_head_blocks(bbmap, begin) - branch_regions = find_branch_regions(bbmap, begin, end) + head_region_blocks = find_head_blocks(scfg, begin) + branch_regions = find_branch_regions(scfg, begin, end) tail_region_blocks = find_tail_blocks( - bbmap, begin, head_region_blocks, branch_regions + scfg, begin, head_region_blocks, branch_regions ) # Unify headers of tail subregion if need be. 
- headers, entries = bbmap.find_headers_and_entries(tail_region_blocks) + headers, entries = scfg.find_headers_and_entries(tail_region_blocks) if len(headers) > 1: - end = SyntheticHead(bbmap.clg.new_index()) - bbmap.insert_block_and_control_blocks(end, entries, headers) + end = SyntheticHead() + insert_block_and_control_blocks(scfg, entries, headers, end) # Recompute regions. - head_region_blocks = find_head_blocks(bbmap, begin) - branch_regions = find_branch_regions(bbmap, begin, end) + head_region_blocks = find_head_blocks(scfg, begin) + branch_regions = find_branch_regions(scfg, begin, end) tail_region_blocks = find_tail_blocks( - bbmap, begin, head_region_blocks, branch_regions + scfg, begin, head_region_blocks, branch_regions ) # Branch region processing: @@ -369,38 +378,39 @@ def restructure_branch(bbmap: BlockMap): bra_start, inner_nodes = region if inner_nodes: # Insert SyntheticTail - exiting_blocks, _ = bbmap.find_exiting_and_exits(inner_nodes) - tail_headers, _ = bbmap.find_headers_and_entries(tail_region_blocks) - _, _ = bbmap.join_tails_and_exits(exiting_blocks, tail_headers) + exiting_blocks, _ = scfg.find_exiting_and_exits(inner_nodes) + tail_headers, _ = scfg.find_headers_and_entries(tail_region_blocks) + _, _ = join_tails_and_exits(scfg, exiting_blocks, tail_headers) else: # Insert SyntheticBranch - tail_headers, _ = bbmap.find_headers_and_entries(tail_region_blocks) - synthetic_branch_block_label = SyntheticBranch(str(bbmap.clg.new_index())) - bbmap.insert_block(synthetic_branch_block_label, (begin,), tail_headers) + tail_headers, _ = scfg.find_headers_and_entries(tail_region_blocks) + synthetic_branch_block_label = SyntheticBranch() + scfg.add_block(block_label=synthetic_branch_block_label) + scfg.insert_block_between(synthetic_branch_block_label, (begin,), tail_headers) # Recompute regions. 
- head_region_blocks = find_head_blocks(bbmap, begin) - branch_regions = find_branch_regions(bbmap, begin, end) + head_region_blocks = find_head_blocks(scfg, begin) + branch_regions = find_branch_regions(scfg, begin, end) tail_region_blocks = find_tail_blocks( - bbmap, begin, head_region_blocks, branch_regions + scfg, begin, head_region_blocks, branch_regions ) # extract subregions - extract_region(bbmap, head_region_blocks, "head") + extract_region(scfg, head_region_blocks, "head") for region in branch_regions: if region: bra_start, inner_nodes = region if inner_nodes: - extract_region(bbmap, inner_nodes, "branch") - extract_region(bbmap, tail_region_blocks, "tail") + extract_region(scfg, inner_nodes, "branch") + extract_region(scfg, tail_region_blocks, "tail") def _iter_branch_regions( - bbmap: BlockMap, immdoms: Dict[Label, Label], postimmdoms: Dict[Label, Label] + scfg: SCFG, immdoms: Dict[BlockName, BlockName], postimmdoms: Dict[BlockName, BlockName] ): - for begin, node in [i for i in bbmap.graph.items()]: - if len(node.jump_targets) > 1: + for begin, node in [i for i in scfg.blocks.items()]: + if len(scfg.out_edges[begin]) > 1: # found branch if begin in postimmdoms: end = postimmdoms[begin] @@ -408,7 +418,7 @@ def _iter_branch_regions( yield begin, end -def _imm_doms(doms: Dict[Label, Set[Label]]) -> Dict[Label, Label]: +def _imm_doms(doms: Dict[BlockName, Set[BlockName]]) -> Dict[BlockName, BlockName]: idoms = {k: v - {k} for k, v in doms.items()} changed = True while changed: @@ -428,48 +438,47 @@ def _imm_doms(doms: Dict[Label, Set[Label]]) -> Dict[Label, Label]: return out -def _doms(bbmap: BlockMap): +def _doms(scfg: SCFG): # compute dom entries = set() preds_table = defaultdict(set) succs_table = defaultdict(set) node: BasicBlock - for src, node in bbmap.graph.items(): - for dst in node.jump_targets: + for src, node in scfg.blocks.items(): + for dst in scfg.out_edges[src]: # check dst is in subgraph - if dst in bbmap.graph: + if dst in scfg.blocks: 
preds_table[dst].add(src) succs_table[src].add(dst) - for k in bbmap.graph: + for k in scfg.blocks: if not preds_table[k]: entries.add(k) return _find_dominators_internal( - entries, list(bbmap.graph.keys()), preds_table, succs_table + entries, list(scfg.blocks.keys()), preds_table, succs_table ) -def _post_doms(bbmap: BlockMap): +def _post_doms(scfg: SCFG): # compute post dom entries = set() - for k, v in bbmap.graph.items(): - targets = set(v.jump_targets) & set(bbmap.graph) + for k in scfg.blocks.keys(): + targets = set(scfg.out_edges[k]) & set(scfg.blocks) if not targets: entries.add(k) preds_table = defaultdict(set) succs_table = defaultdict(set) - node: BasicBlock - for src, node in bbmap.graph.items(): - for dst in node.jump_targets: + for src in scfg.blocks.keys(): + for dst in scfg.out_edges[src]: # check dst is in subgraph - if dst in bbmap.graph: + if dst in scfg.blocks: preds_table[src].add(dst) succs_table[dst].add(src) return _find_dominators_internal( - entries, list(bbmap.graph.keys()), preds_table, succs_table + entries, list(scfg.blocks.keys()), preds_table, succs_table ) @@ -481,13 +490,13 @@ def _find_dominators_internal(entries, nodes, preds_table, succs_table): # in http://pages.cs.wisc.edu/~fischer/cs701.f08/finding.loops.html # if post: - # entries = set(self._exit_points) - # preds_table = self._succs - # succs_table = self._preds + # entries = set(scfg._exit_points) + # preds_table = scfg._succs + # succs_table = scfg._preds # else: - # entries = set([self._entry_point]) - # preds_table = self._preds - # succs_table = self._succs + # entries = set([scfg._entry_point]) + # preds_table = scfg._preds + # succs_table = scfg._succs import functools @@ -517,3 +526,110 @@ def _find_dominators_internal(entries, nodes, preds_table, succs_table): doms[n] = new_doms todo.extend(succs_table[n]) return doms + + +def insert_block_and_control_blocks( + scfg: SCFG, + predecessors: List[BlockName], + successors: List[BlockName], + block_label: Label = 
Label()
+) -> BlockName:
+    # TODO: needs a diagram and documentation
+    # name of the variable for this branching assignment
+    branch_variable = scfg.name_gen.new_var_name()
+    # initial value of the assignment
+    branch_variable_value = 0
+    # store for the mapping from variable value to blockname
+    branch_value_table = {}
+    # initialize new block, which will hold the branching table
+    branch_block_name = scfg.add_block(
+        "branch",
+        block_label,
+        variable=branch_variable,
+        branch_value_table=branch_value_table
+    )
+
+    control_blocks = {}
+    # Replace any arcs from any of predecessors to any of successors with
+    # an arc through the to be inserted block instead.
+
+    for pred_name in predecessors:
+        pred_outs = scfg.out_edges[pred_name]
+        # Need to create synthetic assignments for each arc from a
+        # predecessor to a successor and insert it between the predecessor
+        # and the newly created block
+        for s in successors:
+            if s in pred_outs:
+                synth_assign = SynthenticAssignment()
+                variable_assignment = {}
+                variable_assignment[branch_variable] = branch_variable_value
+
+                # add block
+                control_block_name = scfg.add_block(
+                    "control_variable",
+                    synth_assign,
+                    variable_assignment=variable_assignment,
+                )
+                # update branching table
+                branch_value_table[branch_variable_value] = s
+                # update branching variable
+                branch_variable_value += 1
+        control_blocks[control_block_name] = pred_name
+
+    scfg.insert_block_between(branch_block_name, predecessors, successors)
+
+    for _synth_assign, _pred in control_blocks.items():
+        scfg.insert_block_between(_synth_assign, [_pred], [branch_block_name])
+
+    return branch_block_name
+
+
+def join_returns(scfg: SCFG):
+    """Close the CFG.
+
+    A closed CFG is a CFG with a unique entry and exit node that have no
+    predecessors and no successors respectively.
+ """ + # for all nodes that contain a return + return_nodes = [node for node in scfg.blocks.keys() if scfg.is_exiting(node)] + # close if more than one is found + if len(return_nodes) > 1: + return_solo_label = SyntheticReturn() + new_block = scfg.add_block(block_label=return_solo_label) + scfg.insert_block_between(new_block, return_nodes, []) + + +def join_tails_and_exits(scfg: SCFG, tails: Set[BlockName], exits: Set[BlockName]): + if len(tails) == 1 and len(exits) == 1: + # no-op + solo_tail_name = next(iter(tails)) + solo_exit_name = next(iter(exits)) + return solo_tail_name, solo_exit_name + + if len(tails) == 1 and len(exits) == 2: + # join only exits + solo_tail_name = next(iter(tails)) + solo_exit_label = SyntheticExit() + solo_exit_name = scfg.add_block(block_label=solo_exit_label) + scfg.insert_block_between(solo_exit_name, tails, exits) + return solo_tail_name, solo_exit_name + + if len(tails) >= 2 and len(exits) == 1: + # join only tails + solo_tail_label = SyntheticTail() + solo_exit_name = next(iter(exits)) + solo_tail_name = scfg.add_block(block_label=solo_tail_label) + scfg.insert_block_between(solo_tail_name, tails, exits) + return solo_tail_name, solo_exit_name + + if len(tails) >= 2 and len(exits) >= 2: + # join both tails and exits + solo_tail_label = SyntheticTail() + solo_exit_label = SyntheticExit() + + solo_tail_name = scfg.add_block(block_label=solo_tail_label) + scfg.insert_block_between(solo_tail_name, tails, exits) + + solo_exit_name = scfg.add_block(block_label=solo_exit_label) + scfg.insert_block_between(solo_exit_name, set((solo_tail_name,)), exits) + return solo_tail_name, solo_exit_name diff --git a/numba_rvsdg/rendering/rendering.py b/numba_rvsdg/rendering/rendering.py index bc3ac02..e43dd13 100644 --- a/numba_rvsdg/rendering/rendering.py +++ b/numba_rvsdg/rendering/rendering.py @@ -1,76 +1,68 @@ import logging from numba_rvsdg.core.datastructures.basic_block import ( BasicBlock, - RegionBlock, PythonBytecodeBlock, 
ControlVariableBlock, BranchBlock, ) -from numba_rvsdg.core.datastructures.block_map import BlockMap +from numba_rvsdg.core.datastructures.scfg import SCFG from numba_rvsdg.core.datastructures.labels import ( Label, PythonBytecodeLabel, ControlLabel, + BlockName, + RegionName ) +from numba_rvsdg.core.datastructures.region import MetaRegion, LoopRegion from numba_rvsdg.core.datastructures.byte_flow import ByteFlow import dis from typing import Dict class ByteFlowRenderer(object): - def __init__(self): + def __init__(self, byte_flow: ByteFlow): from graphviz import Digraph self.g = Digraph() + self.byte_flow = byte_flow + self.scfg = byte_flow.scfg + self.bcmap_from_bytecode(byte_flow.bc) - def render_region_block( - self, digraph: "Digraph", label: Label, regionblock: RegionBlock - ): - # render subgraph - graph = regionblock.get_full_graph() - with digraph.subgraph(name=f"cluster_{label}") as subg: - color = "blue" - if regionblock.kind == "branch": - color = "green" - if regionblock.kind == "tail": - color = "purple" - if regionblock.kind == "head": - color = "red" - subg.attr(color=color, label=regionblock.kind) - for label, block in graph.items(): - self.render_block(subg, label, block) - # render edges within this region - self.render_edges(graph) - - def render_basic_block(self, digraph: "Digraph", label: Label, block: BasicBlock): - if isinstance(label, PythonBytecodeLabel): + self.rendered_blocks = set() + self.render_region(self.g, None) + self.render_edges() + + def render_basic_block(self, graph, block_name: BlockName): + block = self.scfg[block_name] + + if isinstance(block.label, PythonBytecodeLabel): instlist = block.get_instructions(self.bcmap) - body = label.__class__.__name__ + ": " + str(label.index) + "\l" + body = str(block_name) + "\l" body += "\l".join( [f"{inst.offset:3}: {inst.opname}" for inst in instlist] + [""] ) - elif isinstance(label, ControlLabel): - body = label.__class__.__name__ + ": " + str(label.index) + elif isinstance(block, 
ControlLabel): + body = str(block_name) else: - raise Exception("Unknown label type: " + label) - digraph.node(str(label), shape="rect", label=body) - - def render_control_variable_block( - self, digraph: "Digraph", label: Label, block: BasicBlock - ): - if isinstance(label, ControlLabel): - body = label.__class__.__name__ + ": " + str(label.index) + "\l" - body += "\l".join( - (f"{k} = {v}" for k, v in block.variable_assignment.items()) - ) + raise Exception("Unknown label type: " + block.label) + graph.node(str(block_name), shape="rect", label=body) + + def render_control_variable_block(self, graph, block_name: BlockName): + block = self.scfg[block_name] + + if isinstance(block.label, ControlLabel): + body = str(block_name) + "\l" + # body += "\l".join( + # (f"{k} = {v}" for k, v in block.variable_assignment.items()) + # ) else: - raise Exception("Unknown label type: " + label) - digraph.node(str(label), shape="rect", label=body) + raise Exception("Unknown label type: " + block.label) + graph.node(str(block_name), shape="rect", label=body) + + def render_branching_block(self, graph, block_name: BlockName): + block = self.scfg[block_name] - def render_branching_block( - self, digraph: "Digraph", label: Label, block: BasicBlock - ): - if isinstance(label, ControlLabel): + if isinstance(block.label, ControlLabel): def find_index(v): if hasattr(v, "offset"): @@ -78,64 +70,80 @@ def find_index(v): if hasattr(v, "index"): return v.index - body = label.__class__.__name__ + ": " + str(label.index) + "\l" - body += f"variable: {block.variable}\l" - body += "\l".join( - (f"{k}=>{find_index(v)}" for k, v in block.branch_value_table.items()) - ) + body = str(block_name) + "\l" + # body += f"variable: {block.variable}\l" + # body += "\l".join( + # (f" {k} => {find_index(v)}" for k, v in block.branch_value_table.items()) + # ) + else: + raise Exception("Unknown label type: " + block.label) + graph.node(str(block_name), shape="rect", label=body) + + def render_region(self, 
graph, region_name): + # If region name is none, we're in the 'root' region + # that is the graph itself. + if region_name is None: + region_name = self.scfg.meta_region + region = self.scfg.regions[region_name] else: - raise Exception("Unknown label type: " + label) - digraph.node(str(label), shape="rect", label=body) + region = self.scfg.regions[region_name] - def render_block(self, digraph: "Digraph", label: Label, block: BasicBlock): + all_blocks = list(self.scfg.iterate_region(region_name, region_view=True)) + + with graph.subgraph(name=f"cluster_{region_name}") as subg: + if isinstance(region, LoopRegion): + color = "blue" + else: + color = "black" + subg.attr(color=color, label=str(region.label)) + + # If there are no further subregions then we render the blocks + for block_name in all_blocks: + self.render_block(subg, block_name) + + def render_block(self, graph, block_name): + if block_name in self.rendered_blocks: + return + + if isinstance(block_name, RegionName): + self.rendered_blocks.add(block_name) + self.render_region(graph, block_name) + return + + block = self.scfg[block_name] if type(block) == BasicBlock: - self.render_basic_block(digraph, label, block) + self.render_basic_block(graph, block_name) elif type(block) == PythonBytecodeBlock: - self.render_basic_block(digraph, label, block) + self.render_basic_block(graph, block_name) elif type(block) == ControlVariableBlock: - self.render_control_variable_block(digraph, label, block) + self.render_control_variable_block(graph, block_name) elif type(block) == BranchBlock: - self.render_branching_block(digraph, label, block) - elif type(block) == RegionBlock: - self.render_region_block(digraph, label, block) + self.render_branching_block(graph, block_name) else: raise Exception("unreachable") - - def render_edges(self, blocks: Dict[Label, BasicBlock]): - for label, block in blocks.items(): - for dst in block.jump_targets: - if dst in blocks: - if type(block) in ( - PythonBytecodeBlock, - BasicBlock, - 
ControlVariableBlock, - BranchBlock, - ): - self.g.edge(str(label), str(dst)) - elif type(block) == RegionBlock: - if block.exit is not None: - self.g.edge(str(block.exit), str(dst)) - else: - self.g.edge(str(label), str(dst)) - else: - raise Exception("unreachable") - for dst in block.backedges: - # assert dst in blocks - self.g.edge( - str(label), str(dst), style="dashed", color="grey", constraint="0" - ) - - def render_byteflow(self, byteflow: ByteFlow): - self.bcmap_from_bytecode(byteflow.bc) - - # render nodes - for label, block in byteflow.bbmap.graph.items(): - self.render_block(self.g, label, block) - self.render_edges(byteflow.bbmap.graph) - return self.g + self.rendered_blocks.add(block_name) + + def render_edges(self): + for block_name, out_edges in self.scfg.out_edges.items(): + for out_edge in out_edges: + if isinstance(out_edge, RegionName): + out_edge = self.scfg.regions[out_edge].header + if (block_name, out_edge) in self.scfg.back_edges: + self.g.edge( + str(block_name), + str(out_edge), + style="dashed", + color="grey", + constraint="0", + ) + else: + self.g.edge(str(block_name), str(out_edge)) def bcmap_from_bytecode(self, bc: dis.Bytecode): - self.bcmap: Dict[int, dis.Instruction] = BlockMap.bcmap_from_bytecode(bc) + self.bcmap: Dict[int, dis.Instruction] = ByteFlow.bcmap_from_bytecode(bc) + + def view(self, *args): + self.g.view(*args) logging.basicConfig(level=logging.DEBUG) @@ -143,13 +151,18 @@ def bcmap_from_bytecode(self, bc: dis.Bytecode): def render_func(func): flow = ByteFlow.from_bytecode(func) - ByteFlowRenderer().render_byteflow(flow).view("before") + render_flow(flow) + + +def render_flow(byte_flow): + ByteFlowRenderer(byte_flow).view("before") - cflow = flow._join_returns() - ByteFlowRenderer().render_byteflow(cflow).view("closed") + byte_flow._join_returns() + ByteFlowRenderer(byte_flow).view("closed") - lflow = cflow._restructure_loop() - ByteFlowRenderer().render_byteflow(lflow).view("loop restructured") + 
byte_flow._restructure_loop() + breakpoint() + ByteFlowRenderer(byte_flow).view("loop restructured") - bflow = lflow._restructure_branch() - ByteFlowRenderer().render_byteflow(bflow).view("branch restructured") + #byte_flow._restructure_branch() + #ByteFlowRenderer(byte_flow).view("branch restructured") diff --git a/numba_rvsdg/tests/simulator.py b/numba_rvsdg/tests/simulator.py index 6d8e837..5b0d897 100644 --- a/numba_rvsdg/tests/simulator.py +++ b/numba_rvsdg/tests/simulator.py @@ -1,11 +1,12 @@ from collections import ChainMap from dis import Instruction from numba_rvsdg.core.datastructures.byte_flow import ByteFlow -from numba_rvsdg.core.datastructures.block_map import BlockMap +from numba_rvsdg.core.datastructures.scfg import SCFG from numba_rvsdg.core.datastructures.basic_block import ( BasicBlock, PythonBytecodeBlock, - RegionBlock, + ControlVariableBlock, + BranchBlock, ) from numba_rvsdg.core.datastructures.labels import ( Label, @@ -17,13 +18,14 @@ SyntheticHead, SyntheticTail, SyntheticReturn, + BlockName, ) import builtins class Simulator: - """BlockMap simulator. + """SCFG simulator. This is a simulator utility to be used for testing. @@ -50,10 +52,8 @@ class Simulator: Control variable map stack: List[Instruction] Instruction stack - region_stack: List[RegionBlocks] - Stack to hold the recusion level for regions - trace: List[Tuple(label, block)] - List of label, block combinations visisted + trace: Set[BlockName] + List of names, block combinations visisted branch: Boolean Flag to be set during execution. 
return_value: Any @@ -64,19 +64,19 @@ class Simulator: def __init__(self, flow: ByteFlow, globals: dict): self.flow = flow + self.scfg = flow.scfg self.globals = ChainMap(globals, builtins.__dict__) self.bcmap = {inst.offset: inst for inst in flow.bc} self.varmap = dict() self.ctrl_varmap = dict() self.stack = [] - self.region_stack = [] - self.trace = [] + self.trace = set() self.branch = None self.return_value = None - def get_block(self, label:Label): - """Return the BasicBlock object for a give label. + def get_block(self, name: BlockName): + """Return the BasicBlock object for a give name. This method is aware of the recusion level of the `Simulator` into the `region_stack`. That is to say, if we have recursed into regions, the @@ -87,8 +87,8 @@ def get_block(self, label:Label): Parameters ---------- - label: Label - The label for which to fetch the BasicBlock + name: BlockName + The name for which to fetch the BasicBlock Return ------ @@ -96,12 +96,7 @@ def get_block(self, label:Label): The requested block """ - # Recursed into regions, return block from region - if self.region_stack: - return self.region_stack[-1].subregion[label] - # Not recursed into regions, return block from ByteFlow - else: - return self.flow.bbmap[label] + return self.flow.scfg[name] def run(self, args): """Run the given simulator with given args. @@ -118,20 +113,20 @@ def run(self, args): """ self.varmap.update(args) - label = PythonBytecodeLabel(index=0) + name = self.flow.scfg.find_head() while True: - action = self.run_BasicBlock(label) + action = self.run_BasicBlock(name) if "return" in action: return action["return"] - label = action["jumpto"] + name = action["jumpto"] - def run_BasicBlock(self, label: Label): + def run_BasicBlock(self, name: BlockName): """Run a BasicBlock. Paramters --------- - label: Label - The Label of the BasicBlock + name: BlockName + The BlockName of the BasicBlock Returns ------- @@ -140,104 +135,51 @@ def run_BasicBlock(self, label: Label): BasicBlock. 
""" - print("AT", label) - block = self.get_block(label) - self.trace.append((label, block)) - if isinstance(block, RegionBlock): - return self.run_RegionBlock(label) - - if isinstance(label, ControlLabel): - self.run_synth_block(label) - elif isinstance(label, PythonBytecodeLabel): - self.run_PythonBytecodeBlock(label) - if block.fallthrough: - [label] = block.jump_targets - return {"jumpto": label} - elif len(block._jump_targets) == 2: - [br_false, br_true] = block._jump_targets + print("AT", name) + block = self.get_block(name) + self.trace.add(name) + + if isinstance(block.label, ControlLabel): + self.run_synth_block(name) + elif isinstance(block.label, PythonBytecodeLabel): + self.run_PythonBytecodeBlock(name) + if len(self.scfg.out_edges[name]) == 1: + [name] = self.scfg.out_edges[name] + return {"jumpto": name} + elif len(self.scfg.out_edges[name]) == 2: + [br_false, br_true] = self.scfg.out_edges[name] return {"jumpto": br_true if self.branch else br_false} else: return {"return": self.return_value} - def run_RegionBlock(self, label: Label): - """Run region. - - Execute all BasicBlocks in this region. Stay within the region, only - return the action when we jump out of the region or when we return from - within the region. - - Special attention is directed at the use of the `region_stack` here. - Since the blocks for the subregion are stored in the `region.subregion` - graph, we need to use a region aware `get_blocks` in methods such as - `run_BasicBlock` so that we get the correct `BasicBlock`. The net effect - of placing the `region` onto the `region_stack` is that `run_BasicBlock` - will be able to fetch the correct label from the `region.subregion` - graph, and thus be able to run the correct sequence of blocks. - - Parameters - ---------- - label: Label - The Label for the RegionBlock - - Returns - ------- - action: Dict[Str: Int or Boolean or Any] - The action to be taken as a result of having executed the - BasicBlock. 
- - """ - # Get the RegionBlock and place it onto the region_stack - region: RegionBlock = self.get_block(label) - self.region_stack.append(region) - while True: - # Execute the first block of the region. - action = self.run_BasicBlock(label) - # If we need to return, break and do so - if "return" in action: - break # break and return action - elif "jumpto" in action: - label = action["jumpto"] - # Otherwise check if we stay in the region and break otherwise - if label in region.subregion.graph: - continue # stay in the region - else: - break # break and return action - else: - assert False, "unreachable" # in case of coding errors - # Pop the region from the region stack again and return the final - # action for this region - popped = self.region_stack.pop() - assert(popped == region) - return action - - def run_PythonBytecodeBlock(self, label: PythonBytecodeLabel): + def run_PythonBytecodeBlock(self, name: BlockName): """Run PythonBytecodeBlock Parameters ---------- - label: PythonBytecodeLabel - The Label for the block. + name: BlockName + The BlockName for the block. """ - block: PythonBytecodeBlock = self.get_block(label) + block: PythonBytecodeBlock = self.get_block(name) assert type(block) is PythonBytecodeBlock for inst in block.get_instructions(self.bcmap): self.run_inst(inst) - def run_synth_block(self, label: ControlLabel): + def run_synth_block(self, name: BlockName): """Run a SyntheticBlock Paramaters ---------- - label: ControlLabel - The Label for the block. + name: BlockName + The BlockName for the block. 
""" - print("----", label) + print("----", name) print(f"control variable map: {self.ctrl_varmap}") - block = self.get_block(label) - handler = getattr(self, f"synth_{type(label).__name__}") - handler(label, block) + block = self.get_block(name) + handler = getattr(self, f"synth_{type(name).__name__}") + handler(name, block) def run_inst(self, inst: Instruction): """Run a bytecode Instruction @@ -257,44 +199,58 @@ def run_inst(self, inst: Instruction): print(f"stack after: {self.stack}") ### Synthetic Instructions ### - def synth_SynthenticAssignment(self, control_label, block): + def synth_SynthenticAssignment( + self, control_label: BlockName, block: ControlVariableBlock + ): self.ctrl_varmap.update(block.variable_assignment) - def _synth_branch(self, control_label, block): + def _synth_branch(self, control_label: BlockName, block: BranchBlock): jump_target = block.branch_value_table[self.ctrl_varmap[block.variable]] self.branch = bool(block._jump_targets.index(jump_target)) - def synth_SyntheticExitingLatch(self, control_label, block): + def synth_SyntheticExitingLatch( + self, control_label: BlockName, block: ControlVariableBlock + ): self._synth_branch(control_label, block) - def synth_SyntheticHead(self, control_label, block): + def synth_SyntheticHead( + self, control_label: BlockName, block: ControlVariableBlock + ): self._synth_branch(control_label, block) - def synth_SyntheticExit(self, control_label, block): + def synth_SyntheticExit( + self, control_label: BlockName, block: ControlVariableBlock + ): self._synth_branch(control_label, block) - def synth_SyntheticReturn(self, control_label, block): + def synth_SyntheticReturn( + self, control_label: BlockName, block: ControlVariableBlock + ): pass - def synth_SyntheticTail(self, control_label, block): + def synth_SyntheticTail( + self, control_label: BlockName, block: ControlVariableBlock + ): pass - def synth_SyntheticBranch(self, control_label, block): + def synth_SyntheticBranch( + self, control_label: 
BlockName, block: ControlVariableBlock + ): pass ### Bytecode Instructions ### - def op_LOAD_CONST(self, inst): + def op_LOAD_CONST(self, inst: Instruction): self.stack.append(inst.argval) - def op_COMPARE_OP(self, inst): + def op_COMPARE_OP(self, inst: Instruction): arg1 = self.stack.pop() arg2 = self.stack.pop() self.stack.append(eval(f"{arg2} {inst.argval} {arg1}")) - def op_LOAD_FAST(self, inst): + def op_LOAD_FAST(self, inst: Instruction): self.stack.append(self.varmap[inst.argval]) - def op_LOAD_GLOBAL(self, inst): + def op_LOAD_GLOBAL(self, inst: Instruction): v = self.globals[inst.argval] if inst.argrepr.startswith("NULL"): append_null = True @@ -303,22 +259,22 @@ def op_LOAD_GLOBAL(self, inst): else: raise NotImplementedError - def op_STORE_FAST(self, inst): + def op_STORE_FAST(self, inst: Instruction): val = self.stack.pop() self.varmap[inst.argval] = val - def op_CALL_FUNCTION(self, inst): + def op_CALL_FUNCTION(self, inst: Instruction): args = [self.stack.pop() for _ in range(inst.argval)][::-1] fn = self.stack.pop() res = fn(*args) self.stack.append(res) - def op_GET_ITER(self, inst): + def op_GET_ITER(self, inst: Instruction): val = self.stack.pop() res = iter(val) self.stack.append(res) - def op_FOR_ITER(self, inst): + def op_FOR_ITER(self, inst: Instruction): tos = self.stack[-1] try: ind = next(tos) @@ -329,58 +285,58 @@ def op_FOR_ITER(self, inst): self.branch = False self.stack.append(ind) - def op_INPLACE_ADD(self, inst): + def op_INPLACE_ADD(self, inst: Instruction): rhs = self.stack.pop() lhs = self.stack.pop() lhs += rhs self.stack.append(lhs) - def op_RETURN_VALUE(self, inst): + def op_RETURN_VALUE(self, inst: Instruction): v = self.stack.pop() self.return_value = v - def op_JUMP_ABSOLUTE(self, inst): + def op_JUMP_ABSOLUTE(self, inst: Instruction): pass - def op_JUMP_FORWARD(self, inst): + def op_JUMP_FORWARD(self, inst: Instruction): pass - def op_POP_JUMP_IF_FALSE(self, inst): + def op_POP_JUMP_IF_FALSE(self, inst: Instruction): 
self.branch = not self.stack.pop() - def op_POP_JUMP_IF_TRUE(self, inst): + def op_POP_JUMP_IF_TRUE(self, inst: Instruction): self.branch = bool(self.stack.pop()) - def op_JUMP_IF_TRUE_OR_POP(self, inst): + def op_JUMP_IF_TRUE_OR_POP(self, inst: Instruction): if self.stack[-1]: self.branch = True else: self.stack.pop() self.branch = False - def op_JUMP_IF_FALSE_OR_POP(self, inst): + def op_JUMP_IF_FALSE_OR_POP(self, inst: Instruction): if not self.stack[-1]: self.branch = True else: self.stack.pop() self.branch = False - def op_POP_TOP(self, inst): + def op_POP_TOP(self, inst: Instruction): self.stack.pop() - def op_RESUME(self, inst): + def op_RESUME(self, inst: Instruction): pass - def op_PRECALL(self, inst): + def op_PRECALL(self, inst: Instruction): pass - def op_CALL_FUNCTION(self, inst): + def op_CALL_FUNCTION(self, inst: Instruction): args = [self.stack.pop() for _ in range(inst.argval)][::-1] fn = self.stack.pop() res = fn(*args) self.stack.append(res) - def op_CALL(self, inst): + def op_CALL(self, inst: Instruction): args = [self.stack.pop() for _ in range(inst.argval)][::-1] first, second = self.stack.pop(), self.stack.pop() if first == None: @@ -390,42 +346,42 @@ def op_CALL(self, inst): res = func(*args) self.stack.append(res) - def op_BINARY_OP(self, inst): + def op_BINARY_OP(self, inst: Instruction): rhs, lhs, op = self.stack.pop(), self.stack.pop(), inst.argrepr op = op if len(op) == 1 else op[0] self.stack.append(eval(f"{lhs} {op} {rhs}")) - def op_JUMP_BACKWARD(self, inst): + def op_JUMP_BACKWARD(self, inst: Instruction): pass - def op_POP_JUMP_FORWARD_IF_TRUE(self, inst): + def op_POP_JUMP_FORWARD_IF_TRUE(self, inst: Instruction): self.branch = self.stack[-1] self.stack.pop() - def op_POP_JUMP_BACKWARD_IF_TRUE(self, inst): + def op_POP_JUMP_BACKWARD_IF_TRUE(self, inst: Instruction): self.branch = self.stack[-1] self.stack.pop() - def op_POP_JUMP_FORWARD_IF_FALSE(self, inst): + def op_POP_JUMP_FORWARD_IF_FALSE(self, inst: Instruction): self.branch 
= not self.stack[-1] self.stack.pop() - def op_POP_JUMP_BACKWARD_IF_FALSE(self, inst): + def op_POP_JUMP_BACKWARD_IF_FALSE(self, inst: Instruction): self.branch = not self.stack[-1] self.stack.pop() - def op_POP_JUMP_FORWARD_IF_NOT_NONE(self, inst): + def op_POP_JUMP_FORWARD_IF_NOT_NONE(self, inst: Instruction): self.branch = self.stack[-1] is not None self.stack.pop() - def op_POP_JUMP_BACKWARD_IF_NOT_NONE(self, inst): + def op_POP_JUMP_BACKWARD_IF_NOT_NONE(self, inst: Instruction): self.branch = self.stack[-1] is not None self.stack.pop() - def op_POP_JUMP_FORWARD_IF_NONE(self, inst): + def op_POP_JUMP_FORWARD_IF_NONE(self, inst: Instruction): self.branch = self.stack[-1] is None self.stack.pop() - def op_POP_JUMP_BACKWARD_IF_NONE(self, inst): + def op_POP_JUMP_BACKWARD_IF_NONE(self, inst: Instruction): self.branch = self.stack[-1] is None self.stack.pop() diff --git a/numba_rvsdg/tests/test_block_map.py b/numba_rvsdg/tests/test_block_map.py deleted file mode 100644 index ac07420..0000000 --- a/numba_rvsdg/tests/test_block_map.py +++ /dev/null @@ -1,22 +0,0 @@ - -from numba_rvsdg.tests.test_transforms import MapComparator -from numba_rvsdg.core.datastructures.basic_block import BasicBlock -from numba_rvsdg.core.datastructures.block_map import BlockMap -from numba_rvsdg.core.datastructures.labels import ControlLabel - -class TestBlockMapIterator(MapComparator): - - def test_block_map_iter(self): - expected = [ - (ControlLabel("0"), BasicBlock(label=ControlLabel("0"), - _jump_targets=(ControlLabel("1"),))), - (ControlLabel("1"), BasicBlock(label=ControlLabel("1"))), - ] - block_map = BlockMap.from_yaml(""" - "0": - jt: ["1"] - "1": - jt: [] - """) - received = list(block_map) - self.assertEqual(expected, received) diff --git a/numba_rvsdg/tests/test_blockmap.py b/numba_rvsdg/tests/test_blockmap.py deleted file mode 100644 index f9c8a63..0000000 --- a/numba_rvsdg/tests/test_blockmap.py +++ /dev/null @@ -1,104 +0,0 @@ - -from unittest import main -from textwrap 
import dedent -from numba_rvsdg.core.datastructures.block_map import BlockMap - -from numba_rvsdg.tests.test_utils import MapComparator - - -class TestBlockMapConversion(MapComparator): - - def test_yaml_conversion(self): - # Case # 1: Acyclic graph, no back-edges - cases = [""" - "0": - jt: ["1", "2"] - "1": - jt: ["3"] - "2": - jt: ["4"] - "3": - jt: ["4"] - "4": - jt: []""", - # Case # 2: Cyclic graph, no back edges - """ - "0": - jt: ["1", "2"] - "1": - jt: ["5"] - "2": - jt: ["1", "5"] - "3": - jt: ["0"] - "4": - jt: [] - "5": - jt: ["3", "4"]""", - # Case # 3: Graph with backedges - """ - "0": - jt: ["1"] - "1": - jt: ["2", "3"] - "2": - jt: ["4"] - "3": - jt: [] - "4": - jt: ["2", "3"] - be: ["2"]"""] - - for case in cases: - case = dedent(case) - block_map = BlockMap.from_yaml(case) - self.assertEqual(case, block_map.to_yaml()) - - def test_dict_conversion(self): - # Case # 1: Acyclic graph, no back-edges - cases = [{ - "0": - {"jt": ["1", "2"]}, - "1": - {"jt": ["3"]}, - "2": - {"jt": ["4"]}, - "3": - {"jt": ["4"]}, - "4": - {"jt": []}}, - # Case # 2: Cyclic graph, no back edges - { - "0": - {"jt": ["1", "2"]}, - "1": - {"jt": ["5"]}, - "2": - {"jt": ["1", "5"]}, - "3": - {"jt": ["0"]}, - "4": - {"jt": []}, - "5": - {"jt": ["3", "4"]}}, - # Case # 3: Graph with backedges - { - "0": - {"jt": ["1"]}, - "1": - {"jt": ["2", "3"]}, - "2": - {"jt": ["4"]}, - "3": - {"jt": []}, - "4": - {"jt": ["2", "3"], - "be": ["2"]}}] - - for case in cases: - block_map = BlockMap.from_dict(case) - self.assertEqual(case, block_map.to_dict()) - -if __name__ == "__main__": - main() - diff --git a/numba_rvsdg/tests/test_byteflow.py b/numba_rvsdg/tests/test_byteflow.py index 6f429a6..aff5aba 100644 --- a/numba_rvsdg/tests/test_byteflow.py +++ b/numba_rvsdg/tests/test_byteflow.py @@ -2,11 +2,16 @@ import unittest from numba_rvsdg.core.datastructures.basic_block import PythonBytecodeBlock -from numba_rvsdg.core.datastructures.labels import PythonBytecodeLabel +from 
numba_rvsdg.core.datastructures.labels import ( + PythonBytecodeLabel, + NameGenerator, + BlockName, + get_label_class, +) from numba_rvsdg.core.datastructures.byte_flow import ByteFlow -from numba_rvsdg.core.datastructures.block_map import BlockMap +from numba_rvsdg.core.datastructures.scfg import SCFG from numba_rvsdg.core.datastructures.flow_info import FlowInfo - +from numba_rvsdg.tests.test_utils import SCFGComparator def fun(): x = 1 @@ -19,7 +24,7 @@ def fun(): class TestBCMapFromBytecode(unittest.TestCase): def test(self): # If the function definition line changes, just change the variable below, rest of it will adjust as long as function remains the same - func_def_line = 11 + func_def_line = 16 expected = { 0: Instruction( opname="RESUME", @@ -102,47 +107,35 @@ def test(self): ), ), } - received = BlockMap.bcmap_from_bytecode(bytecode) + received = ByteFlow.bcmap_from_bytecode(bytecode) self.assertEqual(expected, received) class TestPythonBytecodeBlock(unittest.TestCase): def test_constructor(self): + name_gen = NameGenerator() block = PythonBytecodeBlock( - label=PythonBytecodeLabel(index=0), + label=PythonBytecodeLabel(), begin=0, end=8, - _jump_targets=(), - backedges=(), + name_gen=name_gen, ) - self.assertEqual(block.label, PythonBytecodeLabel(index=0)) + self.assertEqual(block.label, PythonBytecodeLabel()) self.assertEqual(block.begin, 0) self.assertEqual(block.end, 8) - self.assertFalse(block.fallthrough) - self.assertTrue(block.is_exiting) - self.assertEqual(block.jump_targets, ()) - self.assertEqual(block.backedges, ()) - def test_is_jump_target(self): - block = PythonBytecodeBlock( - label=PythonBytecodeLabel(index=0), - begin=0, - end=8, - _jump_targets=(PythonBytecodeLabel(index=1),), - backedges=(), - ) - self.assertEqual(block.jump_targets, (PythonBytecodeLabel(index=1),)) - self.assertFalse(block.is_exiting) + block_name = BlockName("pythonbytecodelabel_0") + self.assertEqual(block.block_name, block_name) def test_get_instructions(self): # 
If the function definition line changes, just change the variable below, rest of it will adjust as long as function remains the same - func_def_line = 11 + func_def_line = 16 + name_gen = NameGenerator() block = PythonBytecodeBlock( - label=PythonBytecodeLabel(index=0), + label=PythonBytecodeLabel(), begin=0, end=8, - _jump_targets=(), - backedges=(), + name_gen=name_gen, ) expected = [ Instruction( @@ -211,7 +204,7 @@ def test_get_instructions(self): ), ] - received = block.get_instructions(BlockMap.bcmap_from_bytecode(bytecode)) + received = block.get_instructions(ByteFlow.bcmap_from_bytecode(bytecode)) self.assertEqual(expected, received) @@ -229,42 +222,32 @@ def test_from_bytecode(self): self.assertEqual(expected, received) def test_build_basic_blocks(self): - expected = BlockMap( - graph={ - PythonBytecodeLabel(index=0): PythonBytecodeBlock( - label=PythonBytecodeLabel(index=0), - begin=0, - end=10, - _jump_targets=(), - backedges=(), - ) - } - ) + expected = SCFG() + expected.add_block("python_bytecode", get_label_class("python_bytecode")(), + begin=0, end=10) + received = FlowInfo.from_bytecode(bytecode).build_basicblocks() self.assertEqual(expected, received) -class TestByteFlow(unittest.TestCase): +class TestByteFlow(SCFGComparator): def test_constructor(self): byteflow = ByteFlow([], []) self.assertEqual(len(byteflow.bc), 0) - self.assertEqual(len(byteflow.bbmap), 0) + self.assertEqual(len(byteflow.scfg), 0) def test_from_bytecode(self): - bbmap = BlockMap( - graph={ - PythonBytecodeLabel(index=0): PythonBytecodeBlock( - label=PythonBytecodeLabel(index=0), - begin=0, - end=10, - _jump_targets=(), - backedges=(), - ) - } + scfg = SCFG() + + scfg.add_block( + block_type="python_bytecode", + block_label=get_label_class("python_bytecode")(), + begin=0, + end=10, ) - expected = ByteFlow(bc=bytecode, bbmap=bbmap) + expected = ByteFlow(bc=bytecode, scfg=scfg) received = ByteFlow.from_bytecode(fun) - self.assertEqual(expected.bbmap, received.bbmap) + 
self.assertSCFGEqual(expected.scfg, received.scfg) if __name__ == "__main__": diff --git a/numba_rvsdg/tests/test_fig3.py b/numba_rvsdg/tests/test_fig3.py index e3e2525..bdf4d3f 100644 --- a/numba_rvsdg/tests/test_fig3.py +++ b/numba_rvsdg/tests/test_fig3.py @@ -29,5 +29,5 @@ def make_flow(): dis.Instruction("RETURN_VALUE", 1, None, None, "", 20, None, False), ] flow = FlowInfo.from_bytecode(bc) - bbmap = flow.build_basicblocks() - return ByteFlow(bc=bc, bbmap=bbmap) + scfg = flow.build_basicblocks() + return ByteFlow(bc=bc, scfg=scfg) diff --git a/numba_rvsdg/tests/test_fig4.py b/numba_rvsdg/tests/test_fig4.py index d30636b..605efcc 100644 --- a/numba_rvsdg/tests/test_fig4.py +++ b/numba_rvsdg/tests/test_fig4.py @@ -28,5 +28,5 @@ def make_flow(): dis.Instruction("RETURN_VALUE", 1, None, None, "", 18, None, False), ] flow = FlowInfo.from_bytecode(bc) - bbmap = flow.build_basicblocks() - return ByteFlow(bc=bc, bbmap=bbmap) + scfg = flow.build_basicblocks() + return ByteFlow(bc=bc, scfg=scfg) diff --git a/numba_rvsdg/tests/test_scfg.py b/numba_rvsdg/tests/test_scfg.py new file mode 100644 index 0000000..104bded --- /dev/null +++ b/numba_rvsdg/tests/test_scfg.py @@ -0,0 +1,335 @@ +from unittest import main +from textwrap import dedent +from numba_rvsdg.core.datastructures.scfg import SCFG + +from numba_rvsdg.tests.test_utils import SCFGComparator +from numba_rvsdg.core.datastructures.basic_block import BasicBlock +from numba_rvsdg.core.datastructures.labels import Label, NameGenerator + + +class TestSCFGConversion(SCFGComparator): + def test_yaml_conversion(self): + # Case # 1: Acyclic graph, no back-edges + cases = [ + """ + "0": + type: "basic" + out: ["1", "2"] + "1": + type: "basic" + out: ["3"] + "2": + type: "basic" + out: ["4"] + "3": + type: "basic" + out: ["4"] + "4": + type: "basic" + out: []""", + # Case # 2: Cyclic graph, no back edges + """ + "0": + type: "basic" + out: ["1", "2"] + "1": + type: "basic" + out: ["5"] + "2": + type: "basic" + out: ["1", 
"5"] + "3": + type: "basic" + out: ["0"] + "4": + type: "basic" + out: [] + "5": + type: "basic" + out: ["3", "4"]""", + # Case # 3: Graph with backedges + """ + "0": + type: "basic" + out: ["1"] + "1": + type: "basic" + out: ["2", "3"] + "2": + type: "basic" + out: ["4"] + "3": + type: "basic" + out: [] + "4": + type: "basic" + out: ["2", "3"] + back: ["2"]""", + ] + + for case in cases: + case = dedent(case) + scfg, ref_dict = SCFG.from_yaml(case) + yaml = scfg.to_yaml() + self.assertYAMLEquals(case, yaml, ref_dict) + + def test_dict_conversion(self): + # Case # 1: Acyclic graph, no back-edges + cases = [ + { + "0": {"type": "basic", "out": ["1", "2"]}, + "1": {"type": "basic", "out": ["3"]}, + "2": {"type": "basic", "out": ["4"]}, + "3": {"type": "basic", "out": ["4"]}, + "4": {"type": "basic", "out": []}, + }, + # Case # 2: Cyclic graph, no back edges + { + "0": {"type": "basic", "out": ["1", "2"]}, + "1": {"type": "basic", "out": ["5"]}, + "2": {"type": "basic", "out": ["1", "5"]}, + "3": {"type": "basic", "out": ["0"]}, + "4": {"type": "basic", "out": []}, + "5": {"type": "basic", "out": ["3", "4"]}, + }, + # Case # 3: Graph with backedges + { + "0": {"type": "basic", "out": ["1"]}, + "1": {"type": "basic", "out": ["2", "3"]}, + "2": {"type": "basic", "out": ["4"]}, + "3": {"type": "basic", "out": []}, + "4": {"type": "basic", "out": ["2", "3"], "back": ["2"]}, + }, + ] + + for case in cases: + scfg = SCFG.from_dict(case) + scfg, ref_dict = SCFG.from_dict(case) + generated_dict = scfg.to_dict() + self.assertDictEquals(case, generated_dict, ref_dict) + + +class TestSCFGIterator(SCFGComparator): + def test_scfg_iter(self): + name_generator = NameGenerator() + block_0 = BasicBlock(name_generator, Label()) + block_1 = BasicBlock(name_generator, Label()) + expected = [ + block_0.block_name, + block_1.block_name, + ] + scfg, ref_dict = SCFG.from_yaml( + """ + "0": + type: "basic" + out: ["1"] + "1": + type: "basic" + """ + ) + received = list(scfg) + 
self.assertEqual(expected, received) + + +class TestInsertBlock(SCFGComparator): + def test_linear(self): + original = """ + "0": + type: "basic" + out: ["1"] + "1": + type: "basic" + out: [] + """ + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected = """ + "0": + type: "basic" + out: ["2"] + "1": + type: "basic" + out: [] + "2": + type: "basic" + out: ["1"] + """ + expected_scfg, _ = SCFG.from_yaml(expected) + + preds = list((block_ref_orig["0"],)) + succs = list((block_ref_orig["1"],)) + new_block = original_scfg.add_block() + original_scfg.insert_block_between(new_block, preds, succs) + + self.assertSCFGEqual(expected_scfg, original_scfg) + + def test_dual_predecessor(self): + original = """ + "0": + type: "basic" + out: ["2"] + "1": + type: "basic" + out: ["2"] + "2": + type: "basic" + out: [] + """ + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected = """ + "0": + type: "basic" + out: ["3"] + "1": + type: "basic" + out: ["3"] + "2": + type: "basic" + out: [] + "3": + type: "basic" + out: ["2"] + """ + expected_scfg, block_ref_exp = SCFG.from_yaml(expected) + + preds = list((block_ref_orig["0"], block_ref_orig["1"])) + succs = list((block_ref_orig["2"],)) + new_block = original_scfg.add_block() + original_scfg.insert_block_between(new_block, preds, succs) + + self.assertSCFGEqual(expected_scfg, original_scfg) + + def test_dual_successor(self): + original = """ + "0": + type: "basic" + out: ["1", "2"] + "1": + type: "basic" + out: [] + "2": + type: "basic" + out: [] + """ + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected = """ + "0": + type: "basic" + out: ["3"] + "1": + type: "basic" + out: [] + "2": + type: "basic" + out: [] + "3": + type: "basic" + out: ["1", "2"] + """ + expected_scfg, block_ref_exp = SCFG.from_yaml(expected) + + preds = list((block_ref_orig["0"],)) + succs = list((block_ref_orig["1"], block_ref_orig["2"])) + new_block = original_scfg.add_block() + 
original_scfg.insert_block_between(new_block, preds, succs) + + self.assertSCFGEqual(expected_scfg, original_scfg) + + def test_dual_predecessor_and_dual_successor(self): + original = """ + "0": + type: "basic" + out: ["1", "2"] + "1": + type: "basic" + out: ["3"] + "2": + type: "basic" + out: ["4"] + "3": + type: "basic" + out: [] + "4": + type: "basic" + out: [] + """ + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected = """ + "0": + type: "basic" + out: ["1", "2"] + "1": + type: "basic" + out: ["5"] + "2": + type: "basic" + out: ["5"] + "3": + type: "basic" + out: [] + "4": + type: "basic" + out: [] + "5": + type: "basic" + out: ["3", "4"] + """ + expected_scfg, block_ref_exp = SCFG.from_yaml(expected) + + preds = list((block_ref_orig["1"], block_ref_orig["2"])) + succs = list((block_ref_orig["3"], block_ref_orig["4"])) + new_block = original_scfg.add_block() + original_scfg.insert_block_between(new_block, preds, succs) + + self.assertSCFGEqual(expected_scfg, original_scfg) + + def test_dual_predecessor_and_dual_successor_with_additional_arcs(self): + original = """ + "0": + type: "basic" + out: ["1", "2"] + "1": + type: "basic" + out: ["3"] + "2": + type: "basic" + out: ["1", "4"] + "3": + type: "basic" + out: ["0"] + "4": + type: "basic" + out: [] + """ + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected = """ + "0": + type: "basic" + out: ["1", "2"] + "1": + type: "basic" + out: ["5"] + "2": + type: "basic" + out: ["1", "5"] + "3": + type: "basic" + out: ["0"] + "4": + type: "basic" + out: [] + "5": + type: "basic" + out: ["3", "4"] + """ + expected_scfg, block_ref_exp = SCFG.from_yaml(expected) + + preds = list((block_ref_orig["1"], block_ref_orig["2"])) + succs = list((block_ref_orig["3"], block_ref_orig["4"])) + new_block = original_scfg.add_block() + original_scfg.insert_block_between(new_block, preds, succs) + + self.assertSCFGEqual(expected_scfg, original_scfg) + + +if __name__ == "__main__": + main() diff --git 
a/numba_rvsdg/tests/test_simulate.py b/numba_rvsdg/tests/test_simulate.py index f97a37b..b13219b 100644 --- a/numba_rvsdg/tests/test_simulate.py +++ b/numba_rvsdg/tests/test_simulate.py @@ -2,35 +2,24 @@ from numba_rvsdg.tests.simulator import Simulator import unittest -# flow = ByteFlow.from_bytecode(foo) -# #pprint(flow.bbmap) -# flow = flow.restructure() -# #pprint(flow.bbmap) -# # pprint(rtsflow.bbmap) -# ByteFlowRenderer().render_byteflow(flow).view() -# print(dis(foo)) -# -# sim = Simulator(flow, foo.__globals__) -# ret = sim.run(dict(x=1)) -# assert ret == foo(x=1) -# -# #sim = Simulator(flow, foo.__globals__) -# #ret = sim.run(dict(x=100)) -# #assert ret == foo(x=100) - -# You can use the following snipppet to visually debug the restructured -# byteflow: -# -# ByteFlowRenderer().render_byteflow(flow).view() -# -# - class SimulatorTest(unittest.TestCase): + + def setUp(self): + """Initialize simulator. """ + self.sim = None + def _run(self, func, flow, kwargs): + """Run function func. 
""" + # lazily initialize the simulator + if self.sim is None: + self.sim = Simulator(flow, func.__globals__) + with self.subTest(): - sim = Simulator(flow, func.__globals__) - self.assertEqual(sim.run(kwargs), func(**kwargs)) + self.assertEqual(self.sim.run(kwargs), func(**kwargs)) + + def _check_trace(self, flow): + self.assertEqual(self.sim.trace, set(flow.scfg.blocks.keys())) def test_simple_branch(self): def foo(x): @@ -42,13 +31,16 @@ def foo(x): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + # flow = flow.restructure() # if case self._run(foo, flow, {"x": 1}) # else case self._run(foo, flow, {"x": 0}) + # check the trace + self._check_trace(flow) + def test_simple_for_loop(self): def foo(x): c = 0 @@ -57,7 +49,7 @@ def foo(x): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + # flow = flow.restructure() # loop bypass case self._run(foo, flow, {"x": 0}) @@ -66,6 +58,9 @@ def foo(x): # extended loop case self._run(foo, flow, {"x": 100}) + # check the trace + self._check_trace(flow) + def test_simple_while_loop(self): def foo(x): c = 0 @@ -76,7 +71,7 @@ def foo(x): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + # flow = flow.restructure() # loop bypass case self._run(foo, flow, {"x": 0}) @@ -85,6 +80,9 @@ def foo(x): # extended loop case self._run(foo, flow, {"x": 100}) + # check the trace + self._check_trace(flow) + def test_for_loop_with_exit(self): def foo(x): c = 0 @@ -95,14 +93,17 @@ def foo(x): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + # flow = flow.restructure() # loop bypass case self._run(foo, flow, {"x": 0}) # loop case self._run(foo, flow, {"x": 2}) # break case - self._run(foo, flow, {"x": 15}) + self._run(foo, flow, {"x": 101}) + + # check the trace + self._check_trace(flow) def test_nested_for_loop_with_break_and_continue(self): def foo(x): @@ -119,7 +120,7 @@ def foo(x): return c flow = ByteFlow.from_bytecode(foo) - flow = 
flow.restructure() + # flow = flow.restructure() # no loop self._run(foo, flow, {"x": 0}) @@ -130,6 +131,9 @@ def foo(x): # will break self._run(foo, flow, {"x": 5}) + # check the trace + self._check_trace(flow) + def test_for_loop_with_multiple_backedges(self): def foo(x): c = 0 @@ -143,7 +147,7 @@ def foo(x): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + # flow = flow.restructure() # loop bypass self._run(foo, flow, {"x": 0}) @@ -154,12 +158,15 @@ def foo(x): # adding 1000, via the elif clause self._run(foo, flow, {"x": 7}) + # check the trace + self._check_trace(flow) + def test_andor(self): def foo(x, y): return (x > 0 and x < 10) or (y > 0 and y < 10) flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + # flow = flow.restructure() self._run(foo, flow, {"x": 5, "y": 5}) @@ -173,7 +180,7 @@ def foo(s, e): return c flow = ByteFlow.from_bytecode(foo) - flow = flow.restructure() + # flow = flow.restructure() # no looping self._run(foo, flow, {"s": 0, "e": 0}) @@ -189,6 +196,8 @@ def foo(s, e): # mutiple iterations self._run(foo, flow, {"s": 23, "e": 28}) + # check the trace + self._check_trace(flow) if __name__ == "__main__": unittest.main() diff --git a/numba_rvsdg/tests/test_transforms.py b/numba_rvsdg/tests/test_transforms.py index 7c1d621..859de1b 100644 --- a/numba_rvsdg/tests/test_transforms.py +++ b/numba_rvsdg/tests/test_transforms.py @@ -1,4 +1,3 @@ - from unittest import main from numba_rvsdg.core.datastructures.labels import ( @@ -6,467 +5,391 @@ SyntheticTail, SyntheticExit, ) -from numba_rvsdg.core.datastructures.block_map import BlockMap, wrap_id -from numba_rvsdg.core.transformations import loop_restructure_helper -from numba_rvsdg.tests.test_utils import MapComparator - - -class TestInsertBlock(MapComparator): - def test_linear(self): - original = """ - "0": - jt: ["1"] - "1": - jt: [] - """ - original_block_map = BlockMap.from_yaml(original) - expected = """ - "0": - jt: ["2"] - "1": - jt: [] - "2": - jt: 
["1"] - """ - expected_block_map = BlockMap.from_yaml(expected) - original_block_map.insert_block( - ControlLabel("2"), wrap_id(("0",)), wrap_id(("1",)) - ) - self.assertMapEqual(expected_block_map, original_block_map) - - def test_dual_predecessor(self): - original = """ - "0": - jt: ["2"] - "1": - jt: ["2"] - "2": - jt: [] - """ - original_block_map = BlockMap.from_yaml(original) - expected = """ - "0": - jt: ["3"] - "1": - jt: ["3"] - "2": - jt: [] - "3": - jt: ["2"] - """ - expected_block_map = BlockMap.from_yaml(expected) - original_block_map.insert_block( - ControlLabel("3"), wrap_id(("0", "1")), wrap_id(("2",)) - ) - self.assertMapEqual(expected_block_map, original_block_map) - - def test_dual_successor(self): - original = """ - "0": - jt: ["1", "2"] - "1": - jt: [] - "2": - jt: [] - """ - original_block_map = BlockMap.from_yaml(original) - expected = """ - "0": - jt: ["3"] - "1": - jt: [] - "2": - jt: [] - "3": - jt: ["1", "2"] - """ - expected_block_map = BlockMap.from_yaml(expected) - original_block_map.insert_block( - ControlLabel("3"), - wrap_id(("0",)), - wrap_id(("1", "2")), - ) - self.assertMapEqual(expected_block_map, original_block_map) +from numba_rvsdg.core.datastructures.scfg import SCFG +from numba_rvsdg.core.transformations import ( + loop_restructure_helper, + join_returns, + join_tails_and_exits, +) +from numba_rvsdg.tests.test_utils import SCFGComparator - def test_dual_predecessor_and_dual_successor(self): - original = """ - "0": - jt: ["1", "2"] - "1": - jt: ["3"] - "2": - jt: ["4"] - "3": - jt: [] - "4": - jt: [] - """ - original_block_map = BlockMap.from_yaml(original) - expected = """ - "0": - jt: ["1", "2"] - "1": - jt: ["5"] - "2": - jt: ["5"] - "3": - jt: [] - "4": - jt: [] - "5": - jt: ["3", "4"] - """ - expected_block_map = BlockMap.from_yaml(expected) - original_block_map.insert_block( - ControlLabel("5"), - wrap_id(("1", "2")), - wrap_id(("3", "4")), - ) - self.assertMapEqual(expected_block_map, original_block_map) - def 
test_dual_predecessor_and_dual_successor_with_additional_arcs(self): +class TestJoinReturns(SCFGComparator): + def test_two_returns(self): original = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: ["3"] + type: "basic" + out: [] "2": - jt: ["1", "4"] - "3": - jt: ["0"] - "4": - jt: [] + type: "basic" + out: [] """ - original_block_map = BlockMap.from_yaml(original) + original_scfg, block_ref_orig = SCFG.from_yaml(original) expected = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: ["5"] + type: "basic" + out: ["3"] "2": - jt: ["1", "5"] + type: "basic" + out: ["3"] "3": - jt: ["0"] - "4": - jt: [] - "5": - jt: ["3", "4"] + type: "basic" + label_type: "synth_return" + out: [] """ - expected_block_map = BlockMap.from_yaml(expected) - original_block_map.insert_block( - ControlLabel("5"), - wrap_id(("1", "2")), - wrap_id(("3", "4")), - ) - self.assertMapEqual(expected_block_map, original_block_map) + expected_scfg, _ = SCFG.from_yaml(expected) + join_returns(original_scfg) -class TestJoinReturns(MapComparator): - def test_two_returns(self): - original = """ - "0": - jt: ["1", "2"] - "1": - jt: [] - "2": - jt: [] - """ - original_block_map = BlockMap.from_yaml(original) - expected = """ - "0": - jt: ["1", "2"] - "1": - jt: ["3"] - "2": - jt: ["3"] - "3": - jt: [] - """ - expected_block_map = BlockMap.from_yaml(expected) - original_block_map.join_returns() - self.assertMapEqual(expected_block_map, original_block_map) + self.assertSCFGEqual(expected_scfg, original_scfg) -class TestJoinTailsAndExits(MapComparator): +class TestJoinTailsAndExits(SCFGComparator): def test_join_tails_and_exits_case_00(self): original = """ "0": - jt: ["1"] + type: "basic" + out: ["1"] "1": - jt: [] + type: "basic" + out: [] """ - original_block_map = BlockMap.from_yaml(original) + original_scfg, block_ref_orig = SCFG.from_yaml(original) expected = """ "0": - jt: ["1"] + type: "basic" + out: ["1"] "1": - jt: [] + type: "basic" + out: [] """ - 
expected_block_map = BlockMap.from_yaml(expected) + expected_scfg, _ = SCFG.from_yaml(expected) - tails = wrap_id(("0",)) - exits = wrap_id(("1",)) - solo_tail_label, solo_exit_label = original_block_map.join_tails_and_exits( - tails, exits - ) + tails = list((block_ref_orig["0"],)) + exits = list((block_ref_orig["1"],)) + join_tails_and_exits(original_scfg, tails, exits) - self.assertMapEqual(expected_block_map, original_block_map) - self.assertEqual(ControlLabel("0"), solo_tail_label) - self.assertEqual(ControlLabel("1"), solo_exit_label) + self.assertSCFGEqual(expected_scfg, original_scfg) def test_join_tails_and_exits_case_01(self): original = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: ["3"] + type: "basic" + out: ["3"] "2": - jt: ["3"] + type: "basic" + out: ["3"] "3": - jt: [] + type: "basic" + out: [] """ - original_block_map = BlockMap.from_yaml(original) + original_scfg, block_ref_orig = SCFG.from_yaml(original) expected = """ "0": - jt: ["4"] + type: "basic" + out: ["4"] "1": - jt: ["3"] + type: "basic" + out: ["3"] "2": - jt: ["3"] + type: "basic" + out: ["3"] "3": - jt: [] + type: "basic" + out: [] "4": - jt: ["1", "2"] + type: "basic" + label_type: "synth_exit" + out: ["1", "2"] """ - expected_block_map = BlockMap.from_yaml(expected) + expected_scfg, _ = SCFG.from_yaml(expected) - tails = wrap_id(("0",)) - exits = wrap_id(("1", "2")) - solo_tail_label, solo_exit_label = original_block_map.join_tails_and_exits( - tails, exits - ) + tails = list((block_ref_orig["0"],)) + exits = list((block_ref_orig["1"], block_ref_orig["2"])) + join_tails_and_exits(original_scfg, tails, exits) - self.assertMapEqual(expected_block_map, original_block_map) - self.assertEqual(ControlLabel("0"), solo_tail_label) - self.assertEqual(SyntheticExit("4"), solo_exit_label) + self.assertSCFGEqual(expected_scfg, original_scfg) def test_join_tails_and_exits_case_02_01(self): original = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: 
["3"] + type: "basic" + out: ["3"] "2": - jt: ["3"] + type: "basic" + out: ["3"] "3": - jt: [] + type: "basic" + out: [] """ - original_block_map = BlockMap.from_yaml(original) + original_scfg, block_ref_orig = SCFG.from_yaml(original) expected = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: ["4"] + type: "basic" + out: ["4"] "2": - jt: ["4"] + type: "basic" + out: ["4"] "3": - jt: [] + type: "basic" + out: [] "4": - jt: ["3"] + type: "basic" + label_type: "synth_tail" + out: ["3"] """ - expected_block_map = BlockMap.from_yaml(expected) + expected_scfg, _ = SCFG.from_yaml(expected) - tails = wrap_id(("1", "2")) - exits = wrap_id(("3",)) - solo_tail_label, solo_exit_label = original_block_map.join_tails_and_exits( - tails, exits - ) + tails = list((block_ref_orig["1"], block_ref_orig["2"])) + exits = list((block_ref_orig["3"],)) + join_tails_and_exits(original_scfg, tails, exits) - self.assertMapEqual(expected_block_map, original_block_map) - self.assertEqual(SyntheticTail("4"), solo_tail_label) - self.assertEqual(ControlLabel("3"), solo_exit_label) + self.assertSCFGEqual(expected_scfg, original_scfg) def test_join_tails_and_exits_case_02_02(self): original = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: ["3"] + type: "basic" + out: ["3"] "2": - jt: ["1", "3"] + type: "basic" + out: ["1", "3"] "3": - jt: [] + type: "basic" + out: [] """ - original_block_map = BlockMap.from_yaml(original) + original_scfg, block_ref_orig = SCFG.from_yaml(original) expected = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: ["4"] + type: "basic" + out: ["4"] "2": - jt: ["1", "4"] + type: "basic" + out: ["1", "4"] "3": - jt: [] + type: "basic" + out: [] "4": - jt: ["3"] + type: "basic" + label_type: "synth_tail" + out: ["3"] """ - expected_block_map = BlockMap.from_yaml(expected) + expected_scfg, _ = SCFG.from_yaml(expected) - tails = wrap_id(("1", "2")) - exits = wrap_id(("3",)) + tails = list((block_ref_orig["1"], 
block_ref_orig["2"])) + exits = list((block_ref_orig["3"],)) - solo_tail_label, solo_exit_label = original_block_map.join_tails_and_exits( - tails, exits - ) - self.assertMapEqual(expected_block_map, original_block_map) - self.assertEqual(SyntheticTail("4"), solo_tail_label) - self.assertEqual(ControlLabel("3"), solo_exit_label) + join_tails_and_exits(original_scfg, tails, exits) + self.assertSCFGEqual(expected_scfg, original_scfg) def test_join_tails_and_exits_case_03_01(self): original = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: ["3"] + type: "basic" + out: ["3"] "2": - jt: ["4"] + type: "basic" + out: ["4"] "3": - jt: ["5"] + type: "basic" + out: ["5"] "4": - jt: ["5"] + type: "basic" + out: ["5"] "5": - jt: [] + type: "basic" + out: [] """ - original_block_map = BlockMap.from_yaml(original) + original_scfg, block_ref_orig = SCFG.from_yaml(original) expected = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: ["6"] + type: "basic" + out: ["6"] "2": - jt: ["6"] + type: "basic" + out: ["6"] "3": - jt: ["5"] + type: "basic" + out: ["5"] "4": - jt: ["5"] + type: "basic" + out: ["5"] "5": - jt: [] + type: "basic" + out: [] "6": - jt: ["7"] + type: "basic" + label_type: "synth_tail" + out: ["7"] "7": - jt: ["3", "4"] + type: "basic" + label_type: "synth_exit" + out: ["3", "4"] """ - expected_block_map = BlockMap.from_yaml(expected) + expected_scfg, _ = SCFG.from_yaml(expected) + + tails = list((block_ref_orig["1"], block_ref_orig["2"])) + exits = list((block_ref_orig["3"], block_ref_orig["4"])) + join_tails_and_exits(original_scfg, tails, exits) - tails = wrap_id(("1", "2")) - exits = wrap_id(("3", "4")) - solo_tail_label, solo_exit_label = original_block_map.join_tails_and_exits( - tails, exits - ) - self.assertMapEqual(expected_block_map, original_block_map) - self.assertEqual(SyntheticTail("6"), solo_tail_label) - self.assertEqual(SyntheticExit("7"), solo_exit_label) + self.assertSCFGEqual(expected_scfg, 
original_scfg) def test_join_tails_and_exits_case_03_02(self): original = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: ["3"] + type: "basic" + out: ["3"] "2": - jt: ["1", "4"] + type: "basic" + out: ["1", "4"] "3": - jt: ["5"] + type: "basic" + out: ["5"] "4": - jt: ["5"] + type: "basic" + out: ["5"] "5": - jt: [] + type: "basic" + out: [] """ - original_block_map = BlockMap.from_yaml(original) + original_scfg, block_ref_orig = SCFG.from_yaml(original) expected = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: ["6"] + type: "basic" + out: ["6"] "2": - jt: ["1", "6"] + type: "basic" + out: ["1", "6"] "3": - jt: ["5"] + type: "basic" + out: ["5"] "4": - jt: ["5"] + type: "basic" + out: ["5"] "5": - jt: [] + type: "basic" + out: [] "6": - jt: ["7"] + type: "basic" + label_type: "synth_tail" + out: ["7"] "7": - jt: ["3", "4"] + type: "basic" + label_type: "synth_exit" + out: ["3", "4"] """ - expected_block_map = BlockMap.from_yaml(expected) - tails = wrap_id(("1", "2")) - exits = wrap_id(("3", "4")) - solo_tail_label, solo_exit_label = original_block_map.join_tails_and_exits( - tails, exits - ) - self.assertMapEqual(expected_block_map, original_block_map) - self.assertEqual(SyntheticTail("6"), solo_tail_label) - self.assertEqual(SyntheticExit("7"), solo_exit_label) + expected_scfg, _ = SCFG.from_yaml(expected) + tails = list((block_ref_orig["1"], block_ref_orig["2"])) + exits = list((block_ref_orig["3"], block_ref_orig["4"])) + join_tails_and_exits(original_scfg, tails, exits) -class TestLoopRestructure(MapComparator): + self.assertSCFGEqual(expected_scfg, original_scfg) + +class TestLoopRestructure(SCFGComparator): def test_no_op_mono(self): """Loop consists of a single Block.""" original = """ "0": - jt: ["1"] + type: "basic" + out: ["1"] "1": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "2": - jt: [] + type: "basic" + out: [] """ expected = """ "0": - jt: ["1"] + type: "basic" + out: ["1"] "1": - jt: ["1", "2"] 
- be: ["1"] + type: "basic" + out: ["1", "2"] + back: ["1"] "2": - jt: [] + type: "basic" + out: [] """ - original_block_map = BlockMap.from_yaml(original) - expected_block_map = BlockMap.from_yaml(expected) - loop_restructure_helper(original_block_map, set(wrap_id({"1"}))) - self.assertMapEqual(expected_block_map, original_block_map) + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected_scfg, _ = SCFG.from_yaml(expected) + + loop_restructure_helper(original_scfg, set((block_ref_orig["1"],))) + + self.assertSCFGEqual(expected_scfg, original_scfg) def test_no_op(self): """Loop consists of two blocks, but it's in form.""" - original =""" + original = """ "0": - jt: ["1"] + type: "basic" + out: ["1"] "1": - jt: ["2"] + type: "basic" + out: ["2"] "2": - jt: ["1", "3"] + type: "basic" + out: ["1", "3"] "3": - jt: [] + type: "basic" + out: [] """ - expected =""" + expected = """ "0": - jt: ["1"] + type: "basic" + out: ["1"] "1": - jt: ["2"] + type: "basic" + out: ["2"] "2": - jt: ["1", "3"] - be: ["1"] + type: "basic" + out: ["1", "3"] + back: ["1"] "3": - jt: [] + type: "basic" + out: [] """ - original_block_map = BlockMap.from_yaml(original) - expected_block_map = BlockMap.from_yaml(expected) - loop_restructure_helper(original_block_map, set(wrap_id({"1", "2"}))) - self.assertMapEqual(expected_block_map, original_block_map) + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected_scfg, _ = SCFG.from_yaml(expected) + loop_restructure_helper(original_scfg, set((block_ref_orig["1"], block_ref_orig["2"]))) + self.assertSCFGEqual(expected_scfg, original_scfg) def test_backedge_not_exiting(self): """Loop has a backedge not coming from the exiting block. 
@@ -475,35 +398,49 @@ def test_backedge_not_exiting(self): """ original = """ "0": - jt: ["1"] + type: "basic" + out: ["1"] "1": - jt: ["2", "3"] + type: "basic" + out: ["2", "3"] "2": - jt: ["1"] + type: "basic" + out: ["1"] "3": - jt: [] + type: "basic" + out: [] """ expected = """ "0": - jt: ["1"] + type: "basic" + out: ["1"] "1": - jt: ["2", "5"] + type: "basic" + out: ["2", "5"] "2": - jt: ["6"] + type: "basic" + out: ["6"] "3": - jt: [] + type: "basic" + out: [] "4": - jt: ["1", "3"] - be: ["1"] + type: "basic" + label_type: "synth_exit_latch" + out: ["1", "3"] + back: ["1"] "5": - jt: ["4"] + type: "basic" + label_type: "synth_assign" + out: ["4"] "6": - jt: ["4"] + type: "basic" + label_type: "synth_assign" + out: ["4"] """ - original_block_map = BlockMap.from_yaml(original) - expected_block_map = BlockMap.from_yaml(expected) - loop_restructure_helper(original_block_map, set(wrap_id({"1", "2"}))) - self.assertMapEqual(expected_block_map, original_block_map) + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected_scfg, _ = SCFG.from_yaml(expected) + loop_restructure_helper(original_scfg, set((block_ref_orig["1"], block_ref_orig["2"]))) + self.assertSCFGEqual(expected_scfg, original_scfg) def test_double_exit(self): """Loop has two exiting blocks. 
@@ -513,159 +450,441 @@ def test_double_exit(self): """ original = """ "0": - jt: ["1"] + type: "basic" + out: ["1"] "1": - jt: ["2"] + type: "basic" + out: ["2"] "2": - jt: ["3", "4"] + type: "basic" + out: ["3", "4"] "3": - jt: ["1", "4"] + type: "basic" + out: ["1", "4"] "4": - jt: [] + type: "basic" + out: [] """ expected = """ "0": - jt: ["1"] + type: "basic" + out: ["1"] "1": - jt: ["2"] + type: "basic" + out: ["2"] "2": - jt: ["3", "6"] + type: "basic" + out: ["3", "6"] "3": - jt: ["7", "8"] + type: "basic" + out: ["7", "8"] "4": - jt: [] + type: "basic" + out: [] "5": - jt: ["1", "4"] - be: ["1"] + type: "basic" + label_type: "synth_exit_latch" + out: ["1", "4"] + back: ["1"] "6": - jt: ["5"] + type: "basic" + label_type: "synth_assign" + out: ["5"] "7": - jt: ["5"] + type: "basic" + label_type: "synth_assign" + out: ["5"] "8": - jt: ["5"] + type: "basic" + label_type: "synth_assign" + out: ["5"] """ - original_block_map = BlockMap.from_yaml(original) - expected_block_map = BlockMap.from_yaml(expected) - loop_restructure_helper(original_block_map, set(wrap_id({"1", "2", "3"}))) - self.assertMapEqual(expected_block_map, original_block_map) + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected_scfg, _ = SCFG.from_yaml(expected) + loop_restructure_helper(original_scfg, set((block_ref_orig["1"], block_ref_orig["2"], block_ref_orig["3"]))) + self.assertSCFGEqual(expected_scfg, original_scfg) def test_double_header(self): - """ This is like the example from Bahman2015 fig. 3 -- + """This is like the example from Bahman2015 fig. 
3 -- but with one exiting block removed.""" original = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: ["3"] + type: "basic" + out: ["3"] "2": - jt: ["4"] + type: "basic" + out: ["4"] "3": - jt: ["2", "5"] + type: "basic" + out: ["2", "5"] "4": - jt: ["1"] + type: "basic" + out: ["1"] "5": - jt: [] + type: "basic" + out: [] """ expected = """ "0": - jt: ["7", "8"] + type: "basic" + out: ["7", "8"] "1": - jt: ["3"] + type: "basic" + out: ["3"] "2": - jt: ["4"] + type: "basic" + out: ["4"] "3": - jt: ["10", "11"] + type: "basic" + out: ["10", "11"] "4": - jt: ["12"] + type: "basic" + out: ["12"] "5": - jt: [] + type: "basic" + out: [] "6": - jt: ["1", "2"] + type: "basic" + label_type: "synth_head" + out: ["1", "2"] "7": - jt: ["6"] + type: "basic" + label_type: "synth_assign" + out: ["6"] "8": - jt: ["6"] + type: "basic" + label_type: "synth_assign" + out: ["6"] "9": - jt: ["5", "6"] - be: ["6"] + type: "basic" + label_type: "synth_exit_latch" + out: ["6", "5"] + back: ["6"] "10": - jt: ["9"] + type: "basic" + label_type: "synth_assign" + out: ["9"] "11": - jt: ["9"] + type: "basic" + label_type: "synth_assign" + out: ["9"] "12": - jt: ["9"] + type: "basic" + label_type: "synth_assign" + out: ["9"] """ - original_block_map = BlockMap.from_yaml(original) - expected_block_map = BlockMap.from_yaml(expected) - loop_restructure_helper(original_block_map, set(wrap_id({"1", "2", "3", "4"}))) - self.assertMapEqual(expected_block_map, original_block_map) + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected_scfg, _ = SCFG.from_yaml(expected) + loop_restructure_helper(original_scfg, set((block_ref_orig["1"], block_ref_orig["2"], block_ref_orig["3"], block_ref_orig["4"]))) + self.assertSCFGEqual(expected_scfg, original_scfg) def test_double_header_double_exiting(self): - """ This is like the example from Bahman2015 fig. 3. + """This is like the example from Bahman2015 fig. 3. 
Two headers that need to be multiplexed to, on additional branch that becomes the exiting latch and one branch that becomes the exit. - + """ original = """ "0": - jt: ["1", "2"] + type: "basic" + out: ["1", "2"] "1": - jt: ["3"] + type: "basic" + out: ["3"] "2": - jt: ["4"] + type: "basic" + out: ["4"] "3": - jt: ["2", "5"] + type: "basic" + out: ["2", "5"] "4": - jt: ["1", "6"] + type: "basic" + out: ["1", "6"] "5": - jt: ["7"] + type: "basic" + out: ["7"] "6": - jt: ["7"] + type: "basic" + out: ["7"] "7": - jt: [] + type: "basic" + out: [] """ expected = """ "0": - jt: ["10", "9"] + type: "basic" + out: ["9", "10"] "1": - jt: ["3"] + type: "basic" + out: ["3"] "2": - jt: ["4"] + type: "basic" + out: ["4"] "3": - jt: ["13", "14"] + type: "basic" + out: ["13", "14"] "4": - jt: ["15", "16"] + type: "basic" + out: ["15", "16"] "5": - jt: ["7"] + type: "basic" + out: ["7"] "6": - jt: ["7"] + type: "basic" + out: ["7"] "7": - jt: [] + type: "basic" + out: [] "8": - jt: ["1", "2"] + type: "basic" + label_type: "synth_head" + out: ["1", "2"] "9": - jt: ["8"] + type: "basic" + label_type: "synth_assign" + out: ["8"] "10": - jt: ["8"] + type: "basic" + label_type: "synth_assign" + out: ["8"] "11": - jt: ["12", "8"] - be: ["8"] + type: "basic" + label_type: "synth_exit" + out: ["5", "6"] "12": - jt: ["5", "6"] + type: "basic" + label_type: "synth_exit_latch" + out: ["8", "11"] + back: ["8"] "13": - jt: ["11"] + type: "basic" + label_type: "synth_assign" + out: ["12"] "14": - jt: ["11"] + type: "basic" + label_type: "synth_assign" + out: ["12"] "15": - jt: ["11"] + type: "basic" + label_type: "synth_assign" + out: ["12"] "16": - jt: ["11"] + type: "basic" + label_type: "synth_assign" + out: ["12"] + """ + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected_scfg, _ = SCFG.from_yaml(expected) + loop_restructure_helper(original_scfg, set((block_ref_orig["1"], block_ref_orig["2"], block_ref_orig["3"], block_ref_orig["4"]))) + self.assertSCFGEqual(expected_scfg, 
original_scfg) + + +class TestLoops(SCFGComparator): + def test_basic_for_loop(self): + + original = """ + "0": + type: "basic" + out: ["1"] + "1": + type: "basic" + out: ["2", "3"] + "2": + type: "basic" + out: ["1"] + "3": + type: "basic" + out: [] """ - original_block_map = BlockMap.from_yaml(original) - expected_block_map = BlockMap.from_yaml(expected) - loop_restructure_helper(original_block_map, set(wrap_id({"1", "2", "3", "4"}))) - self.assertMapEqual(expected_block_map, original_block_map) + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected = """ + "0": + type: "basic" + out: ["1"] + "1": + type: "basic" + out: ["2", "5"] + "2": + type: "basic" + out: ["6"] + "3": + type: "basic" + out: [] + "4": + type: "basic" + label_type: "synth_exit_latch" + out: ["1", "3"] + back: ["1"] + "5": + type: "basic" + label_type: "synth_assign" + out: ["4"] + "6": + type: "basic" + label_type: "synth_assign" + out: ["4"] + """ + expected_scfg, _ = SCFG.from_yaml(expected) + + loop_restructure_helper(original_scfg, set((block_ref_orig["1"], block_ref_orig["2"]))) + print(original_scfg.compute_scc()) + self.assertSCFGEqual(expected_scfg, original_scfg) + + def test_basic_while_loop(self): + original = """ + "0": + type: "basic" + out: ["1", "2"] + "1": + type: "basic" + out: ["1", "2"] + "2": + type: "basic" + out: [] + """ + original_scfg, block_ref_orig = SCFG.from_yaml(original) + expected = """ + "0": + type: "basic" + out: ["1", "2"] + "1": + type: "basic" + out: ["1", "2"] + back: ["1"] + "2": + type: "basic" + out: [] + """ + expected_scfg, _ = SCFG.from_yaml(expected) + + loop_restructure_helper(original_scfg, set((block_ref_orig["1"],))) + print(original_scfg.compute_scc()) + self.assertSCFGEqual(expected_scfg, original_scfg) + + def test_mixed_for_while_loop_with_branch(self): + original = """ + "0": + type: "basic" + out: ["1"] + "1": + type: "basic" + out: ["2", "7"] + "2": + type: "basic" + out: ["3", "6"] + "3": + type: "basic" + out: ["4", "5"] 
+ "4": + type: "basic" + out: ["5"] + "5": + type: "basic" + out: ["3", "6"] + "6": + type: "basic" + out: ["1"] + "7": + type: "basic" + out: [] + """ + original_scfg, block_ref_orig = SCFG.from_yaml(original) + # this has two loops, so we need to attempt to rotate twice, first for + # the header controlled loop, inserting an additional block + expected01 = """ + "0": + type: "basic" + out: ["1"] + "1": + type: "basic" + out: ["2", "9"] + "2": + type: "basic" + out: ["3", "6"] + "3": + type: "basic" + out: ["4", "5"] + "4": + type: "basic" + out: ["5"] + "5": + type: "basic" + out: ["3", "6"] + "6": + type: "basic" + out: ["10"] + "7": + type: "basic" + out: [] + "8": + type: "basic" + label_type: "synth_exit_latch" + out: ["1", "7"] + back: ["1"] + "9": + type: "basic" + label_type: "synth_assign" + out: ["8"] + "10": + type: "basic" + label_type: "synth_assign" + out: ["8"] + """ + expected01_block_map, _ = SCFG.from_yaml(expected01) + loop_restructure_helper(original_scfg, + set((block_ref_orig["1"], block_ref_orig["2"], block_ref_orig["3"], block_ref_orig["4"], block_ref_orig["5"], block_ref_orig["6"]))) + self.assertSCFGEqual(expected01_block_map, original_scfg) + # And then, we make sure that the inner-loop remains unchanged, and the + # loop rotation will only detect the aditional backedge, from 5 to 3 + expected02 = """ + "0": + type: "basic" + out: ["1"] + "1": + type: "basic" + out: ["2", "9"] + "2": + type: "basic" + out: ["3", "6"] + "3": + type: "basic" + out: ["4", "5"] + "4": + type: "basic" + out: ["5"] + "5": + type: "basic" + out: ["3", "6"] + back: ["3"] + "6": + type: "basic" + out: ["10"] + "7": + type: "basic" + out: [] + "8": + type: "basic" + label_type: "synth_exit_latch" + out: ["1", "7"] + back: ["1"] + "9": + type: "basic" + label_type: "synth_assign" + out: ["8"] + "10": + type: "basic" + label_type: "synth_assign" + out: ["8"] + """ + expected02_block_map, _ = SCFG.from_yaml(expected02) + loop_restructure_helper(original_scfg, + 
set((block_ref_orig["3"], block_ref_orig["4"], block_ref_orig["5"],))) + self.assertSCFGEqual(expected02_block_map, original_scfg) + if __name__ == "__main__": main() diff --git a/numba_rvsdg/tests/test_utils.py b/numba_rvsdg/tests/test_utils.py index c820e68..dc73e2f 100644 --- a/numba_rvsdg/tests/test_utils.py +++ b/numba_rvsdg/tests/test_utils.py @@ -1,21 +1,46 @@ from unittest import TestCase -class MapComparator(TestCase): - def assertMapEqual(self, first_map, second_map): +from typing import Dict, List +from numba_rvsdg.core.datastructures.scfg import SCFG + + +class SCFGComparator(TestCase): + def assertSCFGEqual(self, first_scfg: SCFG, second_scfg: SCFG): for key1, key2 in zip( - sorted(first_map.graph.keys(), key=lambda x: x.index), - sorted(second_map.graph.keys(), key=lambda x: x.index), + sorted(first_scfg.blocks.keys(), key=lambda x: x.name), + sorted(second_scfg.blocks.keys(), key=lambda x: x.name), ): - # compare indices of labels - self.assertEqual(key1.index, key2.index) - # compare indices of jump_targets - self.assertEqual( - sorted([j.index for j in first_map[key1]._jump_targets]), - sorted([j.index for j in second_map[key2]._jump_targets]), - ) - # compare indices of backedges - self.assertEqual( - sorted([j.index for j in first_map[key1].backedges]), - sorted([j.index for j in second_map[key2].backedges]), - ) + block_1 = first_scfg[key1] + block_2 = second_scfg[key2] + + # compare labels + self.assertEqual(type(block_1.label), type(block_2.label)) + # compare edges + self.assertEqual(first_scfg.out_edges[key1], second_scfg.out_edges[key2]) + self.assertEqual(first_scfg.back_edges, second_scfg.back_edges) + + def assertYAMLEquals(self, first_yaml: str, second_yaml: str, ref_dict: Dict): + for key, value in ref_dict.items(): + first_yaml = first_yaml.replace(key, value.name) + + self.assertEqual(first_yaml, second_yaml) + + def assertDictEquals(self, first_dict: dict, second_dict: dict, ref_dict: dict): + + def replace_with_refs(scfg_dict: 
dict): + new_dict = {} + for key, value in scfg_dict.items(): + key = str(ref_dict[key]) + _new_dict = {} + for _key, _value in value.items(): + if isinstance(_value, list): + for i in range(len(_value)): + _value[i] = str(ref_dict[_value[i]]) + _new_dict[_key] = _value + new_dict[key] = _new_dict + + return new_dict + + first_dict = replace_with_refs(first_dict) + self.assertEqual(first_dict, second_dict)