Skip to content

Commit

Permalink
refactor ASTNode.get_parent() to improve runtime performance
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed May 8, 2024
1 parent 2c00b03 commit 5937add
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 26 deletions.
4 changes: 2 additions & 2 deletions pynestml/meta_model/ast_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def add_to_internals_block(self, declaration: ASTDeclaration, index: int = -1) -
symtable_vistor = ASTSymbolTableVisitor()
symtable_vistor.block_type_stack.push(BlockType.INTERNALS)
declaration.accept(symtable_vistor)
self.get_internals_blocks().accept(ASTParentVisitor())
self.get_internals_blocks()[0].accept(ASTParentVisitor())
symtable_vistor.block_type_stack.pop()

def add_to_state_block(self, declaration: ASTDeclaration) -> None:
Expand All @@ -494,7 +494,7 @@ def add_to_state_block(self, declaration: ASTDeclaration) -> None:
symtable_vistor = ASTSymbolTableVisitor()
symtable_vistor.block_type_stack.push(BlockType.STATE)
declaration.accept(symtable_vistor)
self.get_state_blocks().accept(ASTParentVisitor())
self.get_state_blocks()[0].accept(ASTParentVisitor())
symtable_vistor.block_type_stack.pop()
from pynestml.symbols.symbol import SymbolKind
assert declaration.get_variables()[0].get_scope().resolve_to_symbol(
Expand Down
9 changes: 8 additions & 1 deletion pynestml/meta_model/ast_on_condition_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,14 @@ def get_children(self) -> List[ASTNode]:
Returns the children of this node, if any.
:return: List of children of this node.
"""
return [self.get_block()]
children = []
if self.cond_expr:
children.append(self.cond_expr)

if self.get_block():
children.append(self.get_block())

return children

def equals(self, other: ASTNode) -> bool:
r"""
Expand Down
28 changes: 18 additions & 10 deletions pynestml/meta_model/ast_simple_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,11 @@ def get_children(self) -> List[ASTNode]:

return []

def equals(self, other: ASTNode) -> bool:
r"""
The equality method.
def set_variable(self, variable):
"""
Updates the variable of this node.
:param variable: a single variable
:type variable: ASTVariable
"""
assert (variable is None or isinstance(variable, ASTVariable)), \
'(PyNestML.AST.SimpleExpression) No or wrong type of variable provided (%s)!' % type(variable)
Expand All @@ -304,33 +306,39 @@ def set_function_call(self, function_call):
'(PyNestML.AST.SimpleExpression) No or wrong type of function call provided (%s)!' % type(function_call)
self.function_call = function_call

def equals(self, other):
"""
The equals method.
:param other: a different object.
:type other: object
:return:True if equal, otherwise False.
:rtype: bool
def equals(self, other: ASTNode) -> bool:
r"""
The equality method.
"""
if not isinstance(other, ASTSimpleExpression):
return False

if self.is_function_call() + other.is_function_call() == 1:
return False

if self.is_function_call() and other.is_function_call() and not self.get_function_call().equals(
other.get_function_call()):
return False

if self.get_numeric_literal() != other.get_numeric_literal():
return False

if self.is_boolean_false != other.is_boolean_false or self.is_boolean_true != other.is_boolean_true:
return False

if self.is_variable() + other.is_variable() == 1:
return False

if self.is_variable() and other.is_variable() and not self.get_variable().equals(other.get_variable()):
return False

if self.is_inf_literal != other.is_inf_literal:
return False

if self.is_string() + other.is_string() == 1:
return False

if self.get_string() != other.get_string():
return False

return True
10 changes: 6 additions & 4 deletions pynestml/transformers/synapse_post_neuron_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,12 @@ def transform_neuron_synapse_pair_(self, neuron, synapse):
new_neuron = neuron.clone()
new_synapse = synapse.clone()

new_neuron.accept(ASTSymbolTableVisitor())
new_neuron.parent_ = None # set root element
new_neuron.accept(ASTParentVisitor())
new_synapse.accept(ASTSymbolTableVisitor())
new_synapse.parent_ = None # set root element
new_synapse.accept(ASTParentVisitor())
new_neuron.accept(ASTSymbolTableVisitor())
new_synapse.accept(ASTSymbolTableVisitor())

assert len(new_neuron.get_equations_blocks()) <= 1, "Only one equations block per neuron supported for now."
assert len(new_synapse.get_equations_blocks()) <= 1, "Only one equations block per synapse supported for now."
Expand Down Expand Up @@ -544,12 +546,12 @@ def mark_post_port(_expr=None):
# add modified versions of neuron and synapse to list
#

new_neuron.accept(ASTParentVisitor())
new_synapse.accept(ASTParentVisitor())
ast_symbol_table_visitor = ASTSymbolTableVisitor()
ast_symbol_table_visitor.after_ast_rewrite_ = True
new_neuron.accept(ast_symbol_table_visitor)
new_synapse.accept(ast_symbol_table_visitor)
new_neuron.accept(ASTParentVisitor())
new_synapse.accept(ASTParentVisitor())

ASTUtils.update_blocktype_for_common_parameters(new_synapse)

Expand Down
10 changes: 10 additions & 0 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ def create_internal_block(cls, model: ASTModel):
ASTSourceLocation.get_added_source_position())
internal.update_scope(model.get_scope())
model.get_body().get_body_elements().append(internal)

from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
model.accept(ASTParentVisitor())

return model

@classmethod
Expand All @@ -433,6 +437,10 @@ def create_state_block(cls, model: ASTModel):
state = ASTNodeFactory.create_ast_block_with_variables(True, False, False, list(),
ASTSourceLocation.get_added_source_position())
model.get_body().get_body_elements().append(state)

from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
model.accept(ASTParentVisitor())

return model

@classmethod
Expand Down Expand Up @@ -638,6 +646,8 @@ def replace_var(_expr=None):
if alternate_name:
ast_ext_var.set_alternate_name(alternate_name)

ast_ext_var.parent_ = _expr

ast_ext_var.update_alt_scope(new_scope)
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
ast_ext_var.accept(ASTSymbolTableVisitor())
Expand Down
9 changes: 0 additions & 9 deletions pynestml/visitors/ast_parent_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,3 @@ def visit(self, node: ASTNode):
children = node.get_children()
for child in children:
child.parent_ = node
# queue = [node]
# while queue:
# node = queue.pop(0) # pop from the front of the queue -- breadth first search

# children = node.get_children()
# for child in children:
# child.parent_ = node

# queue.extend(children)
3 changes: 3 additions & 0 deletions tests/test_symbol_table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from pynestml.symbols.predefined_variables import PredefinedVariables
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.visitors.ast_builder_visitor import ASTBuilderVisitor
from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor


Expand Down Expand Up @@ -78,6 +79,8 @@ def test_symbol_table_builder(self):
SymbolTable.initialize_symbol_table(ast.get_source_position())
symbol_table_visitor = ASTSymbolTableVisitor()
for model in ast.get_model_list():
model.parent_ = None # set root element
model.accept(ASTParentVisitor())
model.accept(symbol_table_visitor)
SymbolTable.add_model_scope(name=model.get_name(), scope=model.get_scope())

Expand Down

0 comments on commit 5937add

Please sign in to comment.