From 5937addc9f32b7cae9253d79a0af3fa37e4e9504 Mon Sep 17 00:00:00 2001 From: "C.A.P. Linssen" Date: Wed, 8 May 2024 11:54:40 +0200 Subject: [PATCH] refactor ASTNode.get_parent() to improve runtime performance --- pynestml/meta_model/ast_model.py | 4 +-- pynestml/meta_model/ast_on_condition_block.py | 9 +++++- pynestml/meta_model/ast_simple_expression.py | 28 ++++++++++++------- .../synapse_post_neuron_transformer.py | 10 ++++--- pynestml/utils/ast_utils.py | 10 +++++++ pynestml/visitors/ast_parent_visitor.py | 9 ------ tests/test_symbol_table_builder.py | 3 ++ 7 files changed, 47 insertions(+), 26 deletions(-) diff --git a/pynestml/meta_model/ast_model.py b/pynestml/meta_model/ast_model.py index 207592ca9..834e56897 100644 --- a/pynestml/meta_model/ast_model.py +++ b/pynestml/meta_model/ast_model.py @@ -475,7 +475,7 @@ def add_to_internals_block(self, declaration: ASTDeclaration, index: int = -1) - symtable_vistor = ASTSymbolTableVisitor() symtable_vistor.block_type_stack.push(BlockType.INTERNALS) declaration.accept(symtable_vistor) - self.get_internals_blocks().accept(ASTParentVisitor()) + self.get_internals_blocks()[0].accept(ASTParentVisitor()) symtable_vistor.block_type_stack.pop() def add_to_state_block(self, declaration: ASTDeclaration) -> None: @@ -494,7 +494,7 @@ def add_to_state_block(self, declaration: ASTDeclaration) -> None: symtable_vistor = ASTSymbolTableVisitor() symtable_vistor.block_type_stack.push(BlockType.STATE) declaration.accept(symtable_vistor) - self.get_state_blocks().accept(ASTParentVisitor()) + self.get_state_blocks()[0].accept(ASTParentVisitor()) symtable_vistor.block_type_stack.pop() from pynestml.symbols.symbol import SymbolKind assert declaration.get_variables()[0].get_scope().resolve_to_symbol( diff --git a/pynestml/meta_model/ast_on_condition_block.py b/pynestml/meta_model/ast_on_condition_block.py index 3e1aaf745..d8e1ac4cd 100644 --- a/pynestml/meta_model/ast_on_condition_block.py +++ b/pynestml/meta_model/ast_on_condition_block.py @@ -87,7 +87,14 @@ def get_children(self) -> List[ASTNode]: Returns the children of this node, if any. :return: List of children of this node. """ - return [self.get_block()] + children = [] + if self.cond_expr: + children.append(self.cond_expr) + + if self.get_block(): + children.append(self.get_block()) + + return children def equals(self, other: ASTNode) -> bool: r""" diff --git a/pynestml/meta_model/ast_simple_expression.py b/pynestml/meta_model/ast_simple_expression.py index fcb12b86d..8514f76d2 100644 --- a/pynestml/meta_model/ast_simple_expression.py +++ b/pynestml/meta_model/ast_simple_expression.py @@ -286,9 +286,11 @@ def get_children(self) -> List[ASTNode]: return [] - def equals(self, other: ASTNode) -> bool: - r""" - The equality method. + def set_variable(self, variable): + """ + Updates the variable of this node. + :param variable: a single variable + :type variable: ASTVariable """ assert (variable is None or isinstance(variable, ASTVariable)), \ '(PyNestML.AST.SimpleExpression) No or wrong type of variable provided (%s)!' % type(variable) @@ -304,33 +306,39 @@ def set_function_call(self, function_call): '(PyNestML.AST.SimpleExpression) No or wrong type of function call provided (%s)!' % type(function_call) self.function_call = function_call - def equals(self, other): - """ - The equals method. - :param other: a different object. - :type other: object - :return:True if equal, otherwise False. - :rtype: bool + def equals(self, other: ASTNode) -> bool: + r""" + The equality method. """ if not isinstance(other, ASTSimpleExpression): return False + if self.is_function_call() + other.is_function_call() == 1: return False + if self.is_function_call() and other.is_function_call() and not self.get_function_call().equals( other.get_function_call()): return False + if self.get_numeric_literal() != other.get_numeric_literal(): return False + if self.is_boolean_false != other.is_boolean_false or self.is_boolean_true != other.is_boolean_true: return False + if self.is_variable() + other.is_variable() == 1: return False + if self.is_variable() and other.is_variable() and not self.get_variable().equals(other.get_variable()): return False + if self.is_inf_literal != other.is_inf_literal: return False + if self.is_string() + other.is_string() == 1: return False + if self.get_string() != other.get_string(): return False + return True diff --git a/pynestml/transformers/synapse_post_neuron_transformer.py b/pynestml/transformers/synapse_post_neuron_transformer.py index e97d30259..5dd4aa3e0 100644 --- a/pynestml/transformers/synapse_post_neuron_transformer.py +++ b/pynestml/transformers/synapse_post_neuron_transformer.py @@ -226,10 +226,12 @@ def transform_neuron_synapse_pair_(self, neuron, synapse): new_neuron = neuron.clone() new_synapse = synapse.clone() - new_neuron.accept(ASTSymbolTableVisitor()) + new_neuron.parent_ = None # set root element new_neuron.accept(ASTParentVisitor()) - new_synapse.accept(ASTSymbolTableVisitor()) + new_synapse.parent_ = None # set root element new_synapse.accept(ASTParentVisitor()) + new_neuron.accept(ASTSymbolTableVisitor()) + new_synapse.accept(ASTSymbolTableVisitor()) assert len(new_neuron.get_equations_blocks()) <= 1, "Only one equations block per neuron supported for now." assert len(new_synapse.get_equations_blocks()) <= 1, "Only one equations block per synapse supported for now." @@ -544,12 +546,12 @@ def mark_post_port(_expr=None): # add modified versions of neuron and synapse to list # + new_neuron.accept(ASTParentVisitor()) + new_synapse.accept(ASTParentVisitor()) ast_symbol_table_visitor = ASTSymbolTableVisitor() ast_symbol_table_visitor.after_ast_rewrite_ = True new_neuron.accept(ast_symbol_table_visitor) new_synapse.accept(ast_symbol_table_visitor) - new_neuron.accept(ASTParentVisitor()) - new_synapse.accept(ASTParentVisitor()) ASTUtils.update_blocktype_for_common_parameters(new_synapse) diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index 35f92287d..9a67054a1 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -418,6 +418,10 @@ def create_internal_block(cls, model: ASTModel): ASTSourceLocation.get_added_source_position()) internal.update_scope(model.get_scope()) model.get_body().get_body_elements().append(internal) + + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + model.accept(ASTParentVisitor()) + return model @classmethod @@ -433,6 +437,10 @@ def create_state_block(cls, model: ASTModel): state = ASTNodeFactory.create_ast_block_with_variables(True, False, False, list(), ASTSourceLocation.get_added_source_position()) model.get_body().get_body_elements().append(state) + + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + model.accept(ASTParentVisitor()) + return model @classmethod @@ -638,6 +646,8 @@ def replace_var(_expr=None): if alternate_name: ast_ext_var.set_alternate_name(alternate_name) + ast_ext_var.parent_ = _expr + ast_ext_var.update_alt_scope(new_scope) from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor ast_ext_var.accept(ASTSymbolTableVisitor()) diff --git a/pynestml/visitors/ast_parent_visitor.py b/pynestml/visitors/ast_parent_visitor.py index 8384fac3f..f67e5b866 100644 --- a/pynestml/visitors/ast_parent_visitor.py +++ b/pynestml/visitors/ast_parent_visitor.py @@ -55,12 +55,3 @@ def visit(self, node: ASTNode): children = node.get_children() for child in children: child.parent_ = node - # queue = [node] - # while queue: - # node = queue.pop(0) # pop from the front of the queue -- breadth first search - - # children = node.get_children() - # for child in children: - # child.parent_ = node - - # queue.extend(children) diff --git a/tests/test_symbol_table_builder.py b/tests/test_symbol_table_builder.py index ba8082212..718b09ec3 100644 --- a/tests/test_symbol_table_builder.py +++ b/tests/test_symbol_table_builder.py @@ -36,6 +36,7 @@ from pynestml.symbols.predefined_variables import PredefinedVariables from pynestml.utils.logger import Logger, LoggingLevel from pynestml.visitors.ast_builder_visitor import ASTBuilderVisitor +from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor @@ -78,6 +79,8 @@ def test_symbol_table_builder(self): SymbolTable.initialize_symbol_table(ast.get_source_position()) symbol_table_visitor = ASTSymbolTableVisitor() for model in ast.get_model_list(): + model.parent_ = None # set root element + model.accept(ASTParentVisitor()) model.accept(symbol_table_visitor) SymbolTable.add_model_scope(name=model.get_name(), scope=model.get_scope())