
Commit

transform kernels and convolutions using a transformer before code generation
C.A.P. Linssen committed May 11, 2024
1 parent abba1b4 commit d88aa15
Showing 5 changed files with 148 additions and 86 deletions.
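For orientation, the diffs below suggest the following intent: before this commit, the code-generation templates special-cased the convolution state variables (whose generated names contain "__X__") by updating them around the user's update block; after this commit, an AST-level transformer (pynestml/transformers/convolutions_transformer.py, further down) rewrites kernels and convolutions into ordinary state variables and integrate_odes() arguments before code generation. The plain-Python sketch below mirrors the per-timestep pattern that the removed template blocks used to emit directly; the variable name, the propagator value P and the lambda update block are illustrative stand-ins, not taken from a generated model.

# Illustrative sketch only: g_ex__X__spikes_ex, the propagator P and the
# run_update_block callable are hypothetical stand-ins for what the removed
# template code emitted for each state variable whose name contains "__X__".

def update_timestep(state: dict, P: float, run_update_block) -> dict:
    """One timestep, following the pattern of the removed template blocks.

    step 1: regardless of whether and how integrate_odes() is called, compute
            the convolution update into a temporary;
    then run the user's update block;
    step 2: set the convolution variable to the updated value at the end of
            the timestep.
    """
    # step 1: subthreshold update of the convolution variable (temporary)
    g_ex__X__spikes_ex__tmp = P * state["g_ex__X__spikes_ex"]

    # NESTML update block, generated separately
    state = run_update_block(state)

    # step 2: commit the temporary
    state["g_ex__X__spikes_ex"] = g_ex__X__spikes_ex__tmp
    return state


if __name__ == "__main__":
    state = {"g_ex__X__spikes_ex": 1.0, "V_m": -70.0}
    state = update_timestep(state, P=0.95, run_update_block=lambda s: s)
    print(state)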
@@ -717,28 +717,6 @@ void {{neuronName}}::update(nest::Time const & origin,const long from, const lon
update_delay_variables();
{%- endif %}

/**
* subthreshold updates of the convolution variables
*
* step 1: regardless of whether and how integrate_odes() will be called, update variables due to convolutions
**/

{%- if uses_analytic_solver %}
{%- for variable_name in analytic_state_variables: %}
{%- if "__X__" in variable_name %}
{%- set update_expr = update_expressions[variable_name] %}
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
{%- if use_gap_junctions %}
const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}__tmp_ = {{ printer.print(update_expr) | replace("B_." + gap_junction_port + "_grid_sum_", "(B_." + gap_junction_port + "_grid_sum_ + __I_gap)") }};
{%- else %}
const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}__tmp_ = {{ printer.print(update_expr) }};
{%- endif %}
{%- endif %}
{%- endfor %}
{%- endif %}


/**
* Begin NESTML generated code for the update block(s)
**/
@@ -768,22 +746,6 @@ const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}_
}
{%- endfor %}

/**
* subthreshold updates of the convolution variables
*
* step 2: regardless of whether and how integrate_odes() was called, update variables due to convolutions. Set to the updated values at the end of the timestep.
**/
{% if uses_analytic_solver %}
{%- for variable_name in analytic_state_variables: %}
{%- if "__X__" in variable_name %}
{%- set update_expr = update_expressions[variable_name] %}
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
{{ printer.print(var_ast) }} = {{variable_name}}__tmp_;
{%- endif %}
{%- endfor %}
{%- endif %}

/**
* Begin NESTML generated code for the onCondition block(s)
**/
@@ -191,6 +191,7 @@ class Neuron_{{neuronName}}(Neuron):
{%- endif %}
{%- endfor %}
{%- endfilter %}
pass
else:
# internals V_
{%- filter indent(6) %}
@@ -262,13 +263,6 @@ class Neuron_{{neuronName}}(Neuron):
{%- set analytic_state_variables_ = utils.filter_variables_list(analytic_state_variables_, ast.get_args()) %}
{%- endif %}

{#- always integrate convolutions in time #}
{%- for var in analytic_state_variables %}
{%- if "__X__" in var %}
{%- set tmp = analytic_state_variables_.append(var) %}
{%- endif %}
{%- endfor %}

{%- include "directives_py/AnalyticIntegrationStep_begin.jinja2" %}

{%- if uses_numeric_solver %}
@@ -283,14 +277,6 @@ class Neuron_{{neuronName}}(Neuron):
def step(self, origin: float, timestep: float) -> None:
__resolution: float = timestep # do not remove, this is necessary for the resolution() function

# -------------------------------------------------------------------------
# integrate variables related to convolutions
# -------------------------------------------------------------------------

{%- with analytic_state_variables_ = analytic_state_variables_from_convolutions %}
{%- include "directives_py/AnalyticIntegrationStep_begin.jinja2" %}
{%- endwith %}

# -------------------------------------------------------------------------
# NESTML generated code for the update block
# -------------------------------------------------------------------------
@@ -304,14 +290,6 @@ class Neuron_{{neuronName}}(Neuron):
{%- endfilter %}
{%- endif %}

# -------------------------------------------------------------------------
# integrate variables related to convolutions
# -------------------------------------------------------------------------

{%- with analytic_state_variables_ = analytic_state_variables_from_convolutions %}
{%- include "directives_py/AnalyticIntegrationStep_end.jinja2" %}
{%- endwith %}

# -------------------------------------------------------------------------
# begin NESTML generated code for the onReceive block(s)
# -------------------------------------------------------------------------
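The blocks deleted from the Python template above performed an analytic ("exact") integration step for the convolution variables at the start and end of step(). As a reminder of what that integration computes, here is a minimal, self-contained worked example for an exponential kernel, independent of NESTML: between spikes the convolution state obeys x' = -x/tau, so the exact update over a timestep h is multiplication by the propagator exp(-h/tau), while each incoming spike increments x by its weight (in generated code that increment lives in the spike event handler, not in step()). All names and numbers below are illustrative.

# Exact integration of a convolution with kernel K(t) = exp(-t / tau).
import math

tau = 2.0                    # kernel time constant (ms), illustrative
h = 0.1                      # simulation timestep (ms), illustrative
P = math.exp(-h / tau)       # exact propagator over one timestep

x = 0.0                      # convolution state, e.g. g_ex__X__spikes_ex
spikes = {5: 1.0, 23: 0.5}   # hypothetical spike arrival steps -> weights

for step_idx in range(50):
    x *= P                   # subthreshold (analytic) update
    if step_idx in spikes:
        x += spikes[step_idx]    # spike handling: jump by the spike weight

print(f"x after 50 steps: {x:.6f}")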
4 changes: 2 additions & 2 deletions pynestml/meta_model/ast_node_factory.py
@@ -360,8 +360,8 @@ def create_ast_update_block(cls, block, source_position):
return ASTUpdateBlock(block, source_position=source_position)

@classmethod
def create_ast_variable(cls, name: str, differential_order: int = 0, vector_parameter=None, is_homogeneous=False, source_position: Optional[ASTSourceLocation] = None, scope: Optional[Scope] = None) -> ASTVariable:
var = ASTVariable(name, differential_order, vector_parameter=vector_parameter, is_homogeneous=is_homogeneous, source_position=source_position)
def create_ast_variable(cls, name: str, differential_order: int = 0, vector_parameter=None, is_homogeneous=False, type_symbol: Optional[str] = None, source_position: Optional[ASTSourceLocation] = None, scope: Optional[Scope] = None) -> ASTVariable:
var = ASTVariable(name, differential_order, type_symbol=type_symbol, vector_parameter=vector_parameter, is_homogeneous=is_homogeneous, source_position=source_position)
if scope:
var.scope = scope

62 changes: 62 additions & 0 deletions pynestml/transformers/convolutions_transformer.py
@@ -35,6 +35,7 @@
from pynestml.codegeneration.printers.unitless_sympy_simple_expression_printer import UnitlessSympySimpleExpressionPrinter
from pynestml.frontend.frontend_configuration import FrontendConfiguration
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_data_type import ASTDataType
from pynestml.meta_model.ast_declaration import ASTDeclaration
from pynestml.meta_model.ast_equations_block import ASTEquationsBlock
from pynestml.meta_model.ast_expression import ASTExpression
@@ -47,6 +48,7 @@
from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.symbols.predefined_functions import PredefinedFunctions
from pynestml.symbols.real_type_symbol import RealTypeSymbol
from pynestml.symbols.symbol import SymbolKind
from pynestml.symbols.variable_symbol import BlockType
from pynestml.transformers.transformer import Transformer
@@ -84,6 +86,61 @@ def __init__(self, options: Optional[Mapping[str, Any]] = None):
self._ode_toolbox_variable_printer._expression_printer = self._ode_toolbox_printer
self._ode_toolbox_function_call_printer._expression_printer = self._ode_toolbox_printer

def add_kernel_variables_to_integrate_odes_calls(self, model, solvers_json):
for solver_dict in solvers_json:
if solver_dict is None:
continue

for var_name, expr in solver_dict["initial_values"].items():
var = ASTUtils.get_variable_by_name(model, var_name)
ASTUtils.add_state_var_to_integrate_odes_calls(model, var)

model.accept(ASTParentVisitor())


def add_integrate_odes_call_for_kernel_variables(self, model, solvers_json):
var_names = []
for solver_dict in solvers_json:
if solver_dict is None:
continue

for var_name, expr in solver_dict["initial_values"].items():
var_names.append(var_name)

args = ASTUtils.resolve_variables_to_simple_expressions(model, var_names)
ast_function_call = ASTNodeFactory.create_ast_function_call("integrate_odes", args)
ASTUtils.add_function_call_to_update_block(ast_function_call, model)
model.accept(ASTParentVisitor())

def add_temporary_kernel_variables_copy(self, model, solvers_json):
var_names = []
for solver_dict in solvers_json:
if solver_dict is None:
continue

for var_name, expr in solver_dict["initial_values"].items():
var_names.append(var_name)

scope = model.get_update_blocks()[0].scope

for var_name in var_names:
var = ASTNodeFactory.create_ast_variable(var_name + "__tmp", type_symbol=RealTypeSymbol)
var.scope = scope
expr = ASTNodeFactory.create_ast_simple_expression(variable=ASTUtils.get_variable_by_name(model, var_name))
ast_declaration = ASTNodeFactory.create_ast_declaration(variables=[var],
data_type=ASTDataType(is_real=True),
expression=expr, source_position=ASTSourceLocation.get_added_source_position())
ast_declaration.update_scope(scope)
ast_small_stmt = ASTNodeFactory.create_ast_small_stmt(declaration=ast_declaration)
ast_small_stmt.update_scope(scope)
ast_stmt = ASTNodeFactory.create_ast_stmt(small_stmt=ast_small_stmt)
ast_stmt.update_scope(scope)

model.get_update_blocks()[0].get_block().stmts.insert(0, ast_stmt)

model.accept(ASTParentVisitor())
model.accept(ASTSymbolTableVisitor())

def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]:
r"""Transform a model or a list of models. Return an updated model or list of models."""
for model in models:
@@ -105,6 +162,11 @@ def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode,
self.create_spike_update_event_handlers(model, solvers_json, kernel_buffers)
self.replace_convolve_calls_with_buffers_(model)
self.remove_kernel_definitions_from_equations_blocks(model)
self.add_kernel_variables_to_integrate_odes_calls(model, solvers_json)
self.add_temporary_kernel_variables_copy(model, solvers_json)
self.add_integrate_odes_call_for_kernel_variables(model, solvers_json)
self.add_kernel_equations(model, solvers_json)

print("-------- MODEL AFTER TRANSFORM ------------")
print(model)
print("-------------------------------------------")
106 changes: 83 additions & 23 deletions pynestml/utils/ast_utils.py
@@ -556,6 +556,18 @@ def visit_function_call(self, node: ASTFunctionCall):
remove_state_var_from_integrate_odes_calls_visitor = RemoveStateVarFromIntegrateODEsCallsVisitor()
model.accept(remove_state_var_from_integrate_odes_calls_visitor)

@classmethod
def add_state_var_to_integrate_odes_calls(cls, model: ASTModel, var: ASTVariable):
r"""Add a state variable to the arguments of each integrate_odes() call in the model."""

class AddStateVarToIntegrateODEsCallsVisitor(ASTVisitor):
def visit_function_call(self, node: ASTFunctionCall):
if node.get_name() == PredefinedFunctions.INTEGRATE_ODES:
expr = ASTNodeFactory.create_ast_simple_expression(variable=var.clone())
node.args.append(expr)

model.accept(AddStateVarToIntegrateODEsCallsVisitor())

@classmethod
def resolve_variables_to_expressions(cls, astnode, analytic_state_variables_moved):
"""receives a list of variable names (as strings) and returns a list of ASTExpressions containing each ASTVariable"""
@@ -564,7 +576,19 @@ def resolve_variables_to_expressions(cls, astnode, analytic_state_variables_move
for var_name in analytic_state_variables_moved:
node = ASTUtils.get_variable_by_name(astnode, var_name)
assert node is not None
expressions.append(ASTNodeFactory.create_ast_expression(False, None, False, ASTNodeFactory.create_ast_simple_expression(variable=node)))
expressions.append(ASTNodeFactory.create_ast_expression(expression=ASTNodeFactory.create_ast_simple_expression(variable=node)))

return expressions

@classmethod
def resolve_variables_to_simple_expressions(cls, model, vars):
"""receives a list of variable names (as strings) and returns a list of ASTSimpleExpressions containing each ASTVariable"""
expressions = []

for var_name in vars:
node = ASTUtils.get_variable_by_name(model, var_name)
assert node is not None
expressions.append(ASTNodeFactory.create_ast_simple_expression(variable=node))

return expressions

@@ -1113,45 +1137,80 @@ def declaration_in_state_block(cls, neuron: ASTModel, variable_name: str) -> boo
return False

@classmethod
def add_assignment_to_update_block(cls, assignment: ASTAssignment, neuron: ASTModel) -> ASTModel:
def add_assignment_to_update_block(cls, assignment: ASTAssignment, model: ASTModel) -> ASTModel:
"""
Adds a single assignment to the end of the update block of the handed over neuron. At most one update block should be present.
Adds a single assignment to the end of the update block of the handed over model. At most one update block should be present.
:param assignment: a single assignment
:param neuron: a single neuron instance
:return: the modified neuron
:param model: a single model instance
:return: the modified model
"""
assert len(neuron.get_update_blocks()) <= 1, "At most one update block should be present"
assert len(model.get_update_blocks()) <= 1, "At most one update block should be present"
small_stmt = ASTNodeFactory.create_ast_small_stmt(assignment=assignment,
source_position=ASTSourceLocation.get_added_source_position())
stmt = ASTNodeFactory.create_ast_stmt(small_stmt=small_stmt,
source_position=ASTSourceLocation.get_added_source_position())
if not neuron.get_update_blocks():
neuron.create_empty_update_block()
neuron.get_update_blocks()[0].get_block().get_stmts().append(stmt)
small_stmt.update_scope(neuron.get_update_blocks()[0].get_block().get_scope())
stmt.update_scope(neuron.get_update_blocks()[0].get_block().get_scope())
return neuron
if not model.get_update_blocks():
model.create_empty_update_block()
model.get_update_blocks()[0].get_block().get_stmts().append(stmt)
small_stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope())
stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope())

from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
model.accept(ASTParentVisitor())

return model

@classmethod
def add_function_call_to_update_block(cls, function_call: ASTFunctionCall, model: ASTModel) -> ASTModel:
"""
Adds a single function call to the end of the update block of the handed over model.
:param function_call: a single function call
:param model: a single model instance
:return: the modified model
"""
assert len(model.get_update_blocks()) <= 1, "At most one update block should be present"

if not model.get_update_blocks():
model.create_empty_update_block()

small_stmt = ASTNodeFactory.create_ast_small_stmt(function_call=function_call,
source_position=ASTSourceLocation.get_added_source_position())
stmt = ASTNodeFactory.create_ast_stmt(small_stmt=small_stmt,
source_position=ASTSourceLocation.get_added_source_position())
model.get_update_blocks()[0].get_block().get_stmts().append(stmt)
small_stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope())
stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope())

from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
model.accept(ASTParentVisitor())

return model

@classmethod
def add_declaration_to_update_block(cls, declaration: ASTDeclaration, neuron: ASTModel) -> ASTModel:
def add_declaration_to_update_block(cls, declaration: ASTDeclaration, model: ASTModel) -> ASTModel:
"""
Adds a single declaration to the end of the update block of the handed over neuron.
Adds a single declaration to the end of the update block of the handed over model.
:param declaration: ASTDeclaration node to add
:param neuron: a single neuron instance
:return: a modified neuron
:param model: a single model instance
:return: a modified model
"""
assert len(neuron.get_update_blocks()) <= 1, "At most one update block should be present"
assert len(model.get_update_blocks()) <= 1, "At most one update block should be present"
small_stmt = ASTNodeFactory.create_ast_small_stmt(declaration=declaration,
source_position=ASTSourceLocation.get_added_source_position())
stmt = ASTNodeFactory.create_ast_stmt(small_stmt=small_stmt,
source_position=ASTSourceLocation.get_added_source_position())
if not neuron.get_update_blocks():
neuron.create_empty_update_block()
neuron.get_update_blocks()[0].get_block().get_stmts().append(stmt)
small_stmt.update_scope(neuron.get_update_blocks()[0].get_block().get_scope())
stmt.update_scope(neuron.get_update_blocks()[0].get_block().get_scope())
return neuron
if not model.get_update_blocks():
model.create_empty_update_block()
model.get_update_blocks()[0].get_block().get_stmts().append(stmt)
small_stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope())
stmt.update_scope(model.get_update_blocks()[0].get_block().get_scope())

from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
model.accept(ASTParentVisitor())

return model

@classmethod
def add_state_updates(cls, neuron: ASTModel, update_expressions: Mapping[str, str]) -> ASTModel:
@@ -1165,6 +1224,7 @@ def add_state_updates(cls, neuron: ASTModel, update_expressions: Mapping[str, st
for variable, update_expression in update_expressions.items():
declaration_statement = variable + '__tmp real = ' + update_expression
cls.add_declaration_to_update_block(ModelParser.parse_declaration(declaration_statement), neuron)

for variable, update_expression in update_expressions.items():
cls.add_assignment_to_update_block(ModelParser.parse_assignment(variable + ' = ' + variable + '__tmp'),
neuron)
