Skip to content

Commit

Permalink
transform kernels and convolutions using a transformer before code ge…
Browse files Browse the repository at this point in the history
…neration [noci]
  • Loading branch information
C.A.P. Linssen committed May 12, 2024
1 parent d88aa15 commit de22d17
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 1 deletion.
78 changes: 77 additions & 1 deletion pynestml/transformers/convolutions_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from pynestml.codegeneration.printers.unitless_sympy_simple_expression_printer import UnitlessSympySimpleExpressionPrinter
from pynestml.frontend.frontend_configuration import FrontendConfiguration
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_block import ASTBlock
from pynestml.meta_model.ast_data_type import ASTDataType
from pynestml.meta_model.ast_declaration import ASTDeclaration
from pynestml.meta_model.ast_equations_block import ASTEquationsBlock
Expand All @@ -46,6 +47,7 @@
from pynestml.meta_model.ast_node import ASTNode
from pynestml.meta_model.ast_node_factory import ASTNodeFactory
from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
from pynestml.meta_model.ast_small_stmt import ASTSmallStmt
from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.symbols.predefined_functions import PredefinedFunctions
from pynestml.symbols.real_type_symbol import RealTypeSymbol
Expand Down Expand Up @@ -86,6 +88,53 @@ def __init__(self, options: Optional[Mapping[str, Any]] = None):
self._ode_toolbox_variable_printer._expression_printer = self._ode_toolbox_printer
self._ode_toolbox_function_call_printer._expression_printer = self._ode_toolbox_printer

def add_restore_kernel_variables_to_start_of_timestep(self, model, solvers_json):
r"""For each integrate_odes() call in the model, append statements restoring the kernel variables to the values at the start of the timestep"""

var_names = []
for solver_dict in solvers_json:
if solver_dict is None:
continue

for var_name, expr in solver_dict["initial_values"].items():
var_names.append(var_name)

class IntegrateODEsFunctionCallVisitor(ASTVisitor):
all_args = None

def __init__(self):
super().__init__()

def visit_small_stmt(self, node: ASTSmallStmt):
self._visit(node)

def visit_simple_expression(self, node: ASTSimpleExpression):
self._visit(node)

def _visit(self, node):
if node.is_function_call() and node.get_function_call().get_name() == PredefinedFunctions.INTEGRATE_ODES:
parent_stmt = node.get_parent()
parent_block = parent_stmt.get_parent()
assert isinstance(parent_block, ASTBlock)
idx = parent_block.stmts.index(parent_stmt)

for i, var_name in enumerate(var_names):
var = ASTNodeFactory.create_ast_variable(var_name + "__at_start_of_timestep", type_symbol=RealTypeSymbol)
var.update_scope(parent_block.get_scope())
expr = ASTNodeFactory.create_ast_simple_expression(variable=var)
ast_assignment = ASTNodeFactory.create_ast_assignment(lhs=ASTUtils.get_variable_by_name(model, var_name),
is_direct_assignment=True,
expression=expr, source_position=ASTSourceLocation.get_added_source_position())
ast_assignment.update_scope(parent_block.get_scope())
ast_small_stmt = ASTNodeFactory.create_ast_small_stmt(assignment=ast_assignment)
ast_small_stmt.update_scope(parent_block.get_scope())
ast_stmt = ASTNodeFactory.create_ast_stmt(small_stmt=ast_small_stmt)
ast_stmt.update_scope(parent_block.get_scope())

parent_block.stmts.insert(idx + i + 1, ast_stmt)

model.accept(IntegrateODEsFunctionCallVisitor())

def add_kernel_variables_to_integrate_odes_calls(self, model, solvers_json):
for solver_dict in solvers_json:
if solver_dict is None:
Expand Down Expand Up @@ -124,7 +173,7 @@ def add_temporary_kernel_variables_copy(self, model, solvers_json):
scope = model.get_update_blocks()[0].scope

for var_name in var_names:
var = ASTNodeFactory.create_ast_variable(var_name + "__tmp", type_symbol=RealTypeSymbol)
var = ASTNodeFactory.create_ast_variable(var_name + "__at_start_of_timestep", type_symbol=RealTypeSymbol)
var.scope = scope
expr = ASTNodeFactory.create_ast_simple_expression(variable=ASTUtils.get_variable_by_name(model, var_name))
ast_declaration = ASTNodeFactory.create_ast_declaration(variables=[var],
Expand Down Expand Up @@ -163,6 +212,7 @@ def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode,
self.replace_convolve_calls_with_buffers_(model)
self.remove_kernel_definitions_from_equations_blocks(model)
self.add_kernel_variables_to_integrate_odes_calls(model, solvers_json)
self.add_restore_kernel_variables_to_start_of_timestep(model, solvers_json)
self.add_temporary_kernel_variables_copy(model, solvers_json)
self.add_integrate_odes_call_for_kernel_variables(model, solvers_json)
self.add_kernel_equations(model, solvers_json)
Expand Down Expand Up @@ -438,6 +488,32 @@ def generate_kernel_buffers(self, model: ASTModel) -> Mapping[ASTKernel, ASTInpu

return kernel_buffers

def add_kernel_equations(self, model, solver_dicts):
if not model.get_equations_blocks():
ASTUtils.create_equations_block()

assert len(model.get_equations_blocks()) <= 1

equations_block = model.get_equations_blocks()[0]

for solver_dict in solver_dicts:
if solver_dict is None:
continue

for var_name, expr_str in solver_dict["update_expressions"].items():
expr = ModelParser.parse_expression(expr_str)
expr.update_scope(model.get_scope())
expr.accept(ASTSymbolTableVisitor())

var = ASTNodeFactory.create_ast_variable(var_name, differential_order=1, source_position=ASTSourceLocation.get_added_source_position())
var.update_scope(equations_block.get_scope())
ast_ode_equation = ASTNodeFactory.create_ast_ode_equation(lhs=var, rhs=expr, source_position=ASTSourceLocation.get_added_source_position())
ast_ode_equation.update_scope(equations_block.get_scope())
equations_block.declarations.append(ast_ode_equation)

model.accept(ASTParentVisitor())
model.accept(ASTSymbolTableVisitor())

def remove_kernel_definitions_from_equations_blocks(self, model: ASTModel) -> ASTDeclaration:
r"""
Removes all kernels in equations blocks.
Expand Down
3 changes: 3 additions & 0 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,9 @@ def create_equations_block(cls, model: ASTModel) -> ASTModel:
block = ASTNodeFactory.create_ast_equations_block(list(),
ASTSourceLocation.get_added_source_position())
model.get_body().get_body_elements().append(block)

model.accept(ASTParentVisitor())

return model

@classmethod
Expand Down

0 comments on commit de22d17

Please sign in to comment.