Skip to content

Commit

Permalink
allow neuron models without state and parameters block; fix dependent…
Browse files Browse the repository at this point in the history
… variables search for co-generation transformer
  • Loading branch information
C.A.P. Linssen committed Jun 4, 2024
1 parent 1edef95 commit c5de9c1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
12 changes: 4 additions & 8 deletions pynestml/transformers/synapse_post_neuron_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,9 @@ def transform_neuron_synapse_pair_(self, neuron, synapse):

if not new_synapse.get_equations_blocks():
ASTUtils.create_equations_block(new_synapse)
new_synapse.accept(ASTSymbolTableVisitor())

if not new_neuron.get_equations_blocks():
ASTUtils.create_equations_block(new_neuron)
new_synapse.accept(ASTSymbolTableVisitor())

post_port_names = []
for input_block in new_synapse.get_input_blocks():
Expand All @@ -375,7 +373,6 @@ def transform_neuron_synapse_pair_(self, neuron, synapse):

if syn_to_neuron_state_vars and not new_neuron.get_state_blocks():
ASTUtils.create_state_block(new_neuron)
new_neuron.accept(ASTSymbolTableVisitor())

for state_var in syn_to_neuron_state_vars:
Logger.log_message(None, -1, "Moving state variables for equation(s) " + str(state_var),
Expand Down Expand Up @@ -438,10 +435,10 @@ def mark_post_port(_expr=None):

collected_on_post_stmts.append(stmt)

stmt.scope = new_neuron.get_update_blocks()[0].scope
stmt.small_stmt.scope = new_neuron.get_update_blocks()[0].scope
stmt.small_stmt.get_assignment().scope = new_neuron.get_update_blocks()[0].scope
stmt.small_stmt.get_assignment().get_variable().scope = new_neuron.get_update_blocks()[0].scope
stmt.scope = new_neuron.scope
stmt.small_stmt.scope = new_neuron.scope
stmt.small_stmt.get_assignment().scope = new_neuron.scope
stmt.small_stmt.get_assignment().get_variable().scope = new_neuron.scope

for stmt in collected_on_post_stmts:
stmts.pop(stmts.index(stmt))
Expand Down Expand Up @@ -480,7 +477,6 @@ def mark_post_port(_expr=None):

if not new_neuron.get_parameters_blocks():
ASTUtils.create_parameters_block(new_neuron)
new_neuron.accept(ASTSymbolTableVisitor())

Logger.log_message(None, -1, "Copying parameters from synapse to neuron...", None, LoggingLevel.INFO)
for param_var in syn_to_neuron_params:
Expand Down
21 changes: 12 additions & 9 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,10 +432,10 @@ def create_internal_block(cls, model: ASTModel):
"""
from pynestml.meta_model.ast_node_factory import ASTNodeFactory
if not model.get_internals_blocks():
internal = ASTNodeFactory.create_ast_block_with_variables(False, False, True, list(),
ASTSourceLocation.get_added_source_position())
internal.update_scope(model.get_scope())
model.get_body().get_body_elements().append(internal)
block = ASTNodeFactory.create_ast_block_with_variables(False, False, True, list(),
ASTSourceLocation.get_added_source_position())
block.update_scope(model.get_scope())
model.get_body().get_body_elements().append(block)

from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
model.accept(ASTParentVisitor())
Expand All @@ -451,10 +451,11 @@ def create_state_block(cls, model: ASTModel):
"""
# local import since otherwise circular dependency
from pynestml.meta_model.ast_node_factory import ASTNodeFactory
if not model.get_internals_blocks():
state = ASTNodeFactory.create_ast_block_with_variables(True, False, False, list(),
if not model.get_state_blocks():
block = ASTNodeFactory.create_ast_block_with_variables(True, False, False, list(),
ASTSourceLocation.get_added_source_position())
model.get_body().get_body_elements().append(state)
block.update_scope(model.get_scope())
model.get_body().get_body_elements().append(block)

from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
model.accept(ASTParentVisitor())
Expand All @@ -471,9 +472,10 @@ def create_parameters_block(cls, model: ASTModel):
# local import since otherwise circular dependency
from pynestml.meta_model.ast_node_factory import ASTNodeFactory
if not model.get_parameters_blocks():
state = ASTNodeFactory.create_ast_block_with_variables(False, True, False, list(),
block = ASTNodeFactory.create_ast_block_with_variables(False, True, False, list(),
ASTSourceLocation.get_added_source_position())
model.get_body().get_body_elements().append(state)
block.update_scope(model.get_scope())
model.get_body().get_body_elements().append(block)

from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
model.accept(ASTParentVisitor())
Expand All @@ -492,6 +494,7 @@ def create_equations_block(cls, model: ASTModel) -> ASTModel:
if not model.get_equations_blocks():
block = ASTNodeFactory.create_ast_equations_block(list(),
ASTSourceLocation.get_added_source_position())
block.update_scope(model.get_scope())
model.get_body().get_body_elements().append(block)

from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
Expand Down

0 comments on commit c5de9c1

Please sign in to comment.