diff --git a/pynestml/codegeneration/code_generator.py b/pynestml/codegeneration/code_generator.py
index abd289601..a0ed82c60 100644
--- a/pynestml/codegeneration/code_generator.py
+++ b/pynestml/codegeneration/code_generator.py
@@ -117,7 +117,6 @@ def _setup_template_env(self, template_files: List[str], templates_root_dir: str
# Environment for neuron templates
env = Environment(loader=FileSystemLoader(_template_dirs))
env.globals["raise"] = self.raise_helper
- env.globals["is_delta_kernel"] = ASTUtils.is_delta_kernel
# Load all the templates
_templates = list()
diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py
index 8fdb04566..333a39fe6 100644
--- a/pynestml/codegeneration/nest_code_generator.py
+++ b/pynestml/codegeneration/nest_code_generator.py
@@ -53,7 +53,6 @@
from pynestml.frontend.frontend_configuration import FrontendConfiguration
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_input_port import ASTInputPort
-from pynestml.meta_model.ast_kernel import ASTKernel
from pynestml.meta_model.ast_model import ASTModel
from pynestml.meta_model.ast_node_factory import ASTNodeFactory
from pynestml.meta_model.ast_ode_equation import ASTOdeEquation
@@ -228,6 +227,10 @@ def run_nest_target_specific_cocos(self, neurons: Sequence[ASTModel], synapses:
raise Exception("Error(s) occurred during code generation")
def generate_code(self, models: Sequence[ASTModel]) -> None:
+ for model in models:
+ for equations_block in model.get_equations_blocks():
+ assert len(equations_block.get_kernels()) == 0, "Kernels and convolutions should have been removed by ConvolutionsTransformer"
+
neurons, synapses = CodeGeneratorUtils.get_model_types_from_names(models, neuron_models=self.get_option("neuron_models"), synapse_models=self.get_option("synapse_models"))
self.run_nest_target_specific_cocos(neurons, synapses)
@@ -264,9 +267,7 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None:
for neuron in neurons:
code, message = Messages.get_analysing_transforming_model(neuron.get_name())
Logger.log_message(None, code, message, None, LoggingLevel.INFO)
- spike_updates, post_spike_updates, equations_with_delay_vars, equations_with_vector_vars = self.analyse_neuron(neuron)
- neuron.spike_updates = spike_updates
- neuron.post_spike_updates = post_spike_updates
+ equations_with_delay_vars, equations_with_vector_vars = self.analyse_neuron(neuron)
neuron.equations_with_delay_vars = equations_with_delay_vars
neuron.equations_with_vector_vars = equations_with_vector_vars
@@ -277,14 +278,13 @@ def analyse_transform_synapses(self, synapses: List[ASTModel]) -> None:
"""
for synapse in synapses:
Logger.log_message(None, None, "Analysing/transforming synapse {}.".format(synapse.get_name()), None, LoggingLevel.INFO)
- synapse.spike_updates = self.analyse_synapse(synapse)
+ self.analyse_synapse(synapse)
- def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment], List[ASTOdeEquation], List[ASTOdeEquation]]:
+ def analyse_neuron(self, neuron: ASTModel) -> Tuple[List[ASTOdeEquation], List[ASTOdeEquation]]:
"""
Analyse and transform a single neuron.
:param neuron: a single neuron.
- :return: see documentation for get_spike_update_expressions() for more information.
- :return: post_spike_updates: list of post-synaptic spike update expressions
:return: equations_with_delay_vars: list of equations containing delay variables
- :return: equations_with_vector_vars: list of equations containing delay variables
+ :return: equations_with_vector_vars: list of equations containing vector variables
"""
@@ -298,18 +298,15 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
ASTUtils.all_variables_defined_in_block(neuron.get_state_blocks()))
ASTUtils.add_timestep_symbol(neuron)
- return {}, {}, [], []
+ return [], []
if len(neuron.get_equations_blocks()) > 1:
raise Exception("Only one equations block per model supported for now")
equations_block = neuron.get_equations_blocks()[0]
- kernel_buffers = ASTUtils.generate_kernel_buffers(neuron, equations_block)
ASTUtils.make_inline_expressions_self_contained(equations_block.get_inline_expressions())
ASTUtils.replace_inline_expressions_through_defining_expressions(equations_block.get_ode_equations(), equations_block.get_inline_expressions())
- delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block)
- ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block)
# Collect all equations with delay variables and replace ASTFunctionCall to ASTVariable wherever necessary
equations_with_delay_vars_visitor = ASTEquationsWithDelayVarsVisitor()
@@ -321,7 +318,7 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
neuron.accept(eqns_with_vector_vars_visitor)
equations_with_vector_vars = eqns_with_vector_vars_visitor.equations
- analytic_solver, numeric_solver = self.ode_toolbox_analysis(neuron, kernel_buffers)
+ analytic_solver, numeric_solver = self.ode_toolbox_analysis(neuron)
self.analytic_solver[neuron.get_name()] = analytic_solver
self.numeric_solver[neuron.get_name()] = numeric_solver
@@ -335,23 +332,14 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
if ode_eq.get_lhs().get_name() == var.get_name():
used_in_eq = True
break
- for kern in equations_block.get_kernels():
- for kern_var in kern.get_variables():
- if kern_var.get_name() == var.get_name():
- used_in_eq = True
- break
if not used_in_eq:
self.non_equations_state_variables[neuron.get_name()].append(var)
- ASTUtils.remove_initial_values_for_kernels(neuron)
- kernels = ASTUtils.remove_kernel_definitions_from_equations_block(neuron)
ASTUtils.update_initial_values_for_odes(neuron, [analytic_solver, numeric_solver])
ASTUtils.remove_ode_definitions_from_equations_block(neuron)
- ASTUtils.create_initial_values_for_kernels(neuron, [analytic_solver, numeric_solver], kernels)
ASTUtils.create_integrate_odes_combinations(neuron)
ASTUtils.replace_variable_names_in_expressions(neuron, [analytic_solver, numeric_solver])
- ASTUtils.replace_convolution_aliasing_inlines(neuron)
ASTUtils.add_timestep_symbol(neuron)
if self.analytic_solver[neuron.get_name()] is not None:
@@ -364,9 +352,7 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
# Update the delay parameters after the symbol table update
ASTUtils.update_delay_parameter_in_state_vars(neuron, state_vars_before_update)
- spike_updates, post_spike_updates = self.get_spike_update_expressions(neuron, kernel_buffers, [analytic_solver, numeric_solver], delta_factors)
-
- return spike_updates, post_spike_updates, equations_with_delay_vars, equations_with_vector_vars
+ return equations_with_delay_vars, equations_with_vector_vars
- def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]:
+ def analyse_synapse(self, synapse: ASTModel) -> None:
"""
@@ -376,34 +362,26 @@ def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]:
code, message = Messages.get_start_processing_model(synapse.get_name())
Logger.log_message(synapse, code, message, synapse.get_source_position(), LoggingLevel.INFO)
- spike_updates = {}
if synapse.get_equations_blocks():
if len(synapse.get_equations_blocks()) > 1:
raise Exception("Only one equations block per model supported for now")
equations_block = synapse.get_equations_blocks()[0]
- kernel_buffers = ASTUtils.generate_kernel_buffers(synapse, equations_block)
ASTUtils.make_inline_expressions_self_contained(equations_block.get_inline_expressions())
ASTUtils.replace_inline_expressions_through_defining_expressions(
equations_block.get_ode_equations(), equations_block.get_inline_expressions())
- delta_factors = ASTUtils.get_delta_factors_(synapse, equations_block)
- ASTUtils.replace_convolve_calls_with_buffers_(synapse, equations_block)
- analytic_solver, numeric_solver = self.ode_toolbox_analysis(synapse, kernel_buffers)
+ analytic_solver, numeric_solver = self.ode_toolbox_analysis(synapse)
self.analytic_solver[synapse.get_name()] = analytic_solver
self.numeric_solver[synapse.get_name()] = numeric_solver
- ASTUtils.remove_initial_values_for_kernels(synapse)
- kernels = ASTUtils.remove_kernel_definitions_from_equations_block(synapse)
ASTUtils.update_initial_values_for_odes(synapse, [analytic_solver, numeric_solver])
ASTUtils.remove_ode_definitions_from_equations_block(synapse)
- ASTUtils.create_initial_values_for_kernels(synapse, [analytic_solver, numeric_solver], kernels)
ASTUtils.create_integrate_odes_combinations(synapse)
ASTUtils.replace_variable_names_in_expressions(synapse, [analytic_solver, numeric_solver])
ASTUtils.add_timestep_symbol(synapse)
self.update_symbol_table(synapse)
- spike_updates, _ = self.get_spike_update_expressions(synapse, kernel_buffers, [analytic_solver, numeric_solver], delta_factors)
if not self.analytic_solver[synapse.get_name()] is None:
synapse = ASTUtils.add_declarations_to_internals(
@@ -416,8 +394,6 @@ def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]:
ASTUtils.update_blocktype_for_common_parameters(synapse)
- return spike_updates
-
def _get_model_namespace(self, astnode: ASTModel) -> Dict:
namespace = {}
@@ -567,8 +543,6 @@ def _get_synapse_model_namespace(self, synapse: ASTModel) -> Dict:
expr_ast.accept(ASTSymbolTableVisitor())
namespace["numeric_update_expressions"][sym] = expr_ast
- namespace["spike_updates"] = synapse.spike_updates
-
return namespace
def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
@@ -582,7 +556,6 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
if "paired_synapse" in dir(neuron):
namespace["extra_on_emit_spike_stmts_from_synapse"] = neuron.extra_on_emit_spike_stmts_from_synapse
namespace["paired_synapse"] = neuron.paired_synapse.get_name()
- namespace["post_spike_updates"] = neuron.post_spike_updates
namespace["transferred_variables"] = neuron._transferred_variables
namespace["transferred_variables_syms"] = {var_name: neuron.scope.resolve_to_symbol(
var_name, SymbolKind.VARIABLE) for var_name in namespace["transferred_variables"]}
@@ -736,8 +709,6 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
namespace["numerical_state_symbols"] = numeric_state_variable_names
ASTUtils.assign_numeric_non_numeric_state_variables(neuron, numeric_state_variable_names, namespace["numeric_update_expressions"] if "numeric_update_expressions" in namespace.keys() else None, namespace["update_expressions"] if "update_expressions" in namespace.keys() else None)
- namespace["spike_updates"] = neuron.spike_updates
-
namespace["recordable_state_variables"] = []
for state_block in neuron.get_state_blocks():
for decl in state_block.get_declarations():
@@ -745,7 +716,6 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
sym = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.VARIABLE)
if isinstance(sym.get_type_symbol(), (UnitTypeSymbol, RealTypeSymbol)) \
- and not ASTUtils.is_delta_kernel(neuron.get_kernel_by_name(sym.name)) \
and sym.is_recordable:
namespace["recordable_state_variables"].append(var)
@@ -755,7 +725,7 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
for var in decl.get_variables():
sym = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.VARIABLE)
- if sym.has_declaring_expression() and (not neuron.get_kernel_by_name(sym.name)):
+ if sym.has_declaring_expression():
namespace["parameter_vars_with_iv"].append(var)
namespace["recordable_inline_expressions"] = [sym for sym in neuron.get_inline_expression_symbols()
@@ -774,7 +744,7 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
return namespace
- def ode_toolbox_analysis(self, neuron: ASTModel, kernel_buffers: Mapping[ASTKernel, ASTInputPort]):
+ def ode_toolbox_analysis(self, neuron: ASTModel):
"""
Prepare data for ODE-toolbox input format, invoke ODE-toolbox analysis via its API, and return the output.
"""
@@ -783,11 +753,11 @@ def ode_toolbox_analysis(self, neuron: ASTModel, kernel_buffers: Mapping[ASTKern
equations_block = neuron.get_equations_blocks()[0]
- if len(equations_block.get_kernels()) == 0 and len(equations_block.get_ode_equations()) == 0:
+ if len(equations_block.get_ode_equations()) == 0:
# no equations defined -> no changes to the neuron
return None, None
- odetoolbox_indict = ASTUtils.transform_ode_and_kernels_to_json(neuron, neuron.get_parameters_blocks(), kernel_buffers, printer=self._ode_toolbox_printer)
+ odetoolbox_indict = ASTUtils.transform_odes_to_json(neuron, neuron.get_parameters_blocks(), printer=self._ode_toolbox_printer)
odetoolbox_indict["options"] = {}
odetoolbox_indict["options"]["output_timestep_symbol"] = "__h"
disable_analytic_solver = self.get_option("solver") != "analytic"
@@ -831,107 +801,3 @@ def update_symbol_table(self, neuron) -> None:
symbol_table_visitor.after_ast_rewrite_ = True
neuron.accept(symbol_table_visitor)
SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope())
-
- def get_spike_update_expressions(self, neuron: ASTModel, kernel_buffers, solver_dicts, delta_factors) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment]]:
- r"""
- Generate the equations that update the dynamical variables when incoming spikes arrive. To be invoked after
- ode-toolbox.
-
- For example, a resulting `assignment_str` could be "I_kernel_in += (inh_spikes/nS) * 1". The values are taken
- from the initial values for each corresponding dynamical variable, either from ode-toolbox or directly from
- user specification in the model.
-
- Note that for kernels, `initial_values` actually contains the increment upon spike arrival, rather than the
- initial value of the corresponding ODE dimension.
- ``spike_updates`` is a mapping from input port name (as a string) to update expressions.
-
- ``post_spike_updates`` is a mapping from kernel name (as a string) to update expressions.
- """
- spike_updates = {}
- post_spike_updates = {}
-
- for kernel, spike_input_port in kernel_buffers:
- if ASTUtils.is_delta_kernel(kernel):
- continue
-
- spike_input_port_name = spike_input_port.get_variable().get_name()
-
- if not spike_input_port_name in spike_updates.keys():
- spike_updates[str(spike_input_port)] = []
-
- if "_is_post_port" in dir(spike_input_port.get_variable()) \
- and spike_input_port.get_variable()._is_post_port:
- # it's a port in the neuron ??? that receives post spikes ???
- orig_port_name = spike_input_port_name[:spike_input_port_name.index("__for_")]
- buffer_type = neuron.paired_synapse.get_scope().resolve_to_symbol(orig_port_name, SymbolKind.VARIABLE).get_type_symbol()
- else:
- buffer_type = neuron.get_scope().resolve_to_symbol(spike_input_port_name, SymbolKind.VARIABLE).get_type_symbol()
-
- assert not buffer_type is None
-
- for kernel_var in kernel.get_variables():
- for var_order in range(ASTUtils.get_kernel_var_order_from_ode_toolbox_result(kernel_var.get_name(), solver_dicts)):
- kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name(kernel_var.get_name(), spike_input_port, var_order)
- expr = ASTUtils.get_initial_value_from_ode_toolbox_result(kernel_spike_buf_name, solver_dicts)
- assert expr is not None, "Initial value not found for kernel " + kernel_var
- expr = str(expr)
- if expr in ["0", "0.", "0.0"]:
- continue # skip adding the statement if we are only adding zero
-
- assignment_str = kernel_spike_buf_name + " += "
- if "_is_post_port" in dir(spike_input_port.get_variable()) \
- and spike_input_port.get_variable()._is_post_port:
- assignment_str += "1."
- else:
- assignment_str += "(" + str(spike_input_port) + ")"
- if not expr in ["1.", "1.0", "1"]:
- assignment_str += " * (" + expr + ")"
-
- if not buffer_type.print_nestml_type() in ["1.", "1.0", "1", "real", "integer"]:
- assignment_str += " / (" + buffer_type.print_nestml_type() + ")"
-
- ast_assignment = ModelParser.parse_assignment(assignment_str)
- ast_assignment.update_scope(neuron.get_scope())
- ast_assignment.accept(ASTSymbolTableVisitor())
-
- if neuron.get_scope().resolve_to_symbol(spike_input_port_name, SymbolKind.VARIABLE) is None:
- # this case covers variables that were moved from synapse to the neuron
- post_spike_updates[kernel_var.get_name()] = ast_assignment
- elif "_is_post_port" in dir(spike_input_port.get_variable()) and spike_input_port.get_variable()._is_post_port:
- Logger.log_message(None, None, "Adding post assignment string: " + str(ast_assignment), None, LoggingLevel.INFO)
- spike_updates[str(spike_input_port)].append(ast_assignment)
- else:
- spike_updates[str(spike_input_port)].append(ast_assignment)
-
- for k, factor in delta_factors.items():
- var = k[0]
- inport = k[1]
- assignment_str = var.get_name() + "'" * (var.get_differential_order() - 1) + " += "
- if not factor in ["1.", "1.0", "1"]:
- factor_expr = ModelParser.parse_expression(factor)
- factor_expr.update_scope(neuron.get_scope())
- factor_expr.accept(ASTSymbolTableVisitor())
- assignment_str += "(" + self._printer_no_origin.print(factor_expr) + ") * "
-
- if "_is_post_port" in dir(inport) and inport._is_post_port:
- orig_port_name = inport[:inport.index("__for_")]
- buffer_type = neuron.paired_synapse.get_scope().resolve_to_symbol(orig_port_name, SymbolKind.VARIABLE).get_type_symbol()
- else:
- buffer_type = neuron.get_scope().resolve_to_symbol(inport.get_name(), SymbolKind.VARIABLE).get_type_symbol()
-
- assignment_str += str(inport)
- if not buffer_type.print_nestml_type() in ["1.", "1.0", "1"]:
- assignment_str += " / (" + buffer_type.print_nestml_type() + ")"
- ast_assignment = ModelParser.parse_assignment(assignment_str)
- ast_assignment.update_scope(neuron.get_scope())
- ast_assignment.accept(ASTSymbolTableVisitor())
-
- inport_name = inport.get_name()
- if inport.has_vector_parameter():
- inport_name += "_" + str(ASTUtils.get_numeric_vector_size(inport))
- if not inport_name in spike_updates.keys():
- spike_updates[inport_name] = []
-
- spike_updates[inport_name].append(ast_assignment)
-
- return spike_updates, post_spike_updates
diff --git a/pynestml/codegeneration/nest_compartmental_code_generator.py b/pynestml/codegeneration/nest_compartmental_code_generator.py
index 8dc48958d..cf3757481 100644
--- a/pynestml/codegeneration/nest_compartmental_code_generator.py
+++ b/pynestml/codegeneration/nest_compartmental_code_generator.py
@@ -280,22 +280,16 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None:
def create_ode_indict(self,
neuron: ASTModel,
- parameters_block: ASTBlockWithVariables,
- kernel_buffers: Mapping[ASTKernel,
- ASTInputPort]):
- odetoolbox_indict = self.transform_ode_and_kernels_to_json(
- neuron, parameters_block, kernel_buffers)
+ parameters_block: ASTBlockWithVariables):
+ odetoolbox_indict = self.transform_ode_and_kernels_to_json(neuron, parameters_block)
odetoolbox_indict["options"] = {}
odetoolbox_indict["options"]["output_timestep_symbol"] = "__h"
return odetoolbox_indict
def ode_solve_analytically(self,
neuron: ASTModel,
- parameters_block: ASTBlockWithVariables,
- kernel_buffers: Mapping[ASTKernel,
- ASTInputPort]):
- odetoolbox_indict = self.create_ode_indict(
- neuron, parameters_block, kernel_buffers)
+ parameters_block: ASTBlockWithVariables):
+ odetoolbox_indict = self.create_ode_indict(neuron, parameters_block)
full_solver_result = analysis(
odetoolbox_indict,
@@ -314,8 +308,7 @@ def ode_solve_analytically(self,
return full_solver_result, analytic_solver
- def ode_toolbox_analysis(self, neuron: ASTModel,
- kernel_buffers: Mapping[ASTKernel, ASTInputPort]):
+ def ode_toolbox_analysis(self, neuron: ASTModel):
"""
Prepare data for ODE-toolbox input format, invoke ODE-toolbox analysis via its API, and return the output.
"""
@@ -324,15 +317,13 @@ def ode_toolbox_analysis(self, neuron: ASTModel,
equations_block = neuron.get_equations_blocks()[0]
- if len(equations_block.get_kernels()) == 0 and len(
- equations_block.get_ode_equations()) == 0:
+ if len(equations_block.get_ode_equations()) == 0:
# no equations defined -> no changes to the neuron
return None, None
parameters_block = neuron.get_parameters_blocks()[0]
- solver_result, analytic_solver = self.ode_solve_analytically(
- neuron, parameters_block, kernel_buffers)
+ solver_result, analytic_solver = self.ode_solve_analytically(neuron, parameters_block)
# if numeric solver is required, generate a stepping function that
# includes each state variable
@@ -341,8 +332,7 @@ def ode_toolbox_analysis(self, neuron: ASTModel,
x for x in solver_result if x["solver"].startswith("numeric")]
if numeric_solvers:
- odetoolbox_indict = self.create_ode_indict(
- neuron, parameters_block, kernel_buffers)
+ odetoolbox_indict = self.create_ode_indict(neuron, parameters_block)
solver_result = analysis(
odetoolbox_indict,
disable_stiffness_check=True,
@@ -417,24 +407,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
return []
- # goes through all convolve() inside ode's from equations block
- # if they have delta kernels, use sympy to expand the expression, then
- # find the convolve calls and replace them with constant value 1
- # then return every subexpression that had that convolve() replaced
- delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block)
-
- # goes through all convolve() inside equations block
- # extracts what kernel is paired with what spike buffer
- # returns pairs (kernel, spike_buffer)
- kernel_buffers = ASTUtils.generate_kernel_buffers(
- neuron, equations_block)
-
- # replace convolve(g_E, spikes_exc) with g_E__X__spikes_exc[__d]
- # done by searching for every ASTSimpleExpression inside equations_block
- # which is a convolve call and substituting that call with
- # newly created ASTVariable kernel__X__spike_buffer
- ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block)
-
# substitute inline expressions with each other
# such that no inline expression references another inline expression
ASTUtils.make_inline_expressions_self_contained(
@@ -450,14 +422,13 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
# "update_expressions" key in those solvers contains a mapping
# {expression1: update_expression1, expression2: update_expression2}
- analytic_solver, numeric_solver = self.ode_toolbox_analysis(
- neuron, kernel_buffers)
+ analytic_solver, numeric_solver = self.ode_toolbox_analysis(neuron)
"""
# separate analytic solutions by kernel
# this is needed for the synaptic case
self.kernel_name_to_analytic_solver[neuron.get_name(
- )] = self.ode_toolbox_anaysis_cm_syns(neuron, kernel_buffers)
+ )] = self.ode_toolbox_anaysis_cm_syns(neuron)
"""
self.analytic_solver[neuron.get_name()] = analytic_solver
@@ -472,12 +443,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
# by odetoolbox, higher order variables don't get deleted here
ASTUtils.remove_initial_values_for_kernels(neuron)
- # delete all kernels as they are all converted into buffers
- # and corresponding update formulas calculated by odetoolbox
- # Remember them in a variable though
- kernels = ASTUtils.remove_kernel_definitions_from_equations_block(
- neuron)
-
# Every ODE variable (a variable of order > 0) is renamed according to ODE-toolbox conventions
# their initial values are replaced by expressions suggested by ODE-toolbox.
# Differential order can now be set to 0, because they can directly represent the value of the derivative now.
@@ -491,22 +456,11 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
# corresponding updates
ASTUtils.remove_ode_definitions_from_equations_block(neuron)
- # restore state variables that were referenced by kernels
- # and set their initial values by those suggested by ODE-toolbox
- ASTUtils.create_initial_values_for_kernels(
- neuron, [analytic_solver, numeric_solver], kernels)
-
# Inside all remaining expressions, translate all remaining variable names
# according to the naming conventions of ODE-toolbox.
ASTUtils.replace_variable_names_in_expressions(
neuron, [analytic_solver, numeric_solver])
- # find all inline kernels defined as ASTSimpleExpression
- # that have a single kernel convolution aliasing variable ('__X__')
- # translate all remaining variable names according to the naming
- # conventions of ODE-toolbox
- ASTUtils.replace_convolution_aliasing_inlines(neuron)
-
# add variable __h to internals block
ASTUtils.add_timestep_symbol(neuron)
@@ -677,13 +631,9 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
expr_ast.accept(ASTSymbolTableVisitor())
namespace["numeric_update_expressions"][sym] = expr_ast
- namespace["spike_updates"] = neuron.spike_updates
-
namespace["recordable_state_variables"] = [
sym for sym in neuron.get_state_symbols() if namespace["declarations"].get_domain_from_type(
- sym.get_type_symbol()) == "double" and sym.is_recordable and not ASTUtils.is_delta_kernel(
- neuron.get_kernel_by_name(
- sym.name))]
+ sym.get_type_symbol()) == "double" and sym.is_recordable]
namespace["recordable_inline_expressions"] = [
sym for sym in neuron.get_inline_expression_symbols() if namespace["declarations"].get_domain_from_type(
sym.get_type_symbol()) == "double" and sym.is_recordable]
@@ -807,7 +757,7 @@ def get_spike_update_expressions(
for var_order in range(
ASTUtils.get_kernel_var_order_from_ode_toolbox_result(
kernel_var.get_name(), solver_dicts)):
- kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name(
+ kernel_spike_buf_name = ASTUtils.construct_kernel_spike_buf_name(
kernel_var.get_name(), spike_input_port, var_order)
expr = ASTUtils.get_initial_value_from_ode_toolbox_result(
kernel_spike_buf_name, solver_dicts)
@@ -849,18 +799,9 @@ def get_spike_update_expressions(
def transform_ode_and_kernels_to_json(
self,
neuron: ASTModel,
- parameters_block,
- kernel_buffers):
+ parameters_block):
"""
Converts AST node to a JSON representation suitable for passing to ode-toolbox.
-
- Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements
-
- convolve(G, ex_spikes)
- convolve(G, in_spikes)
-
- then `kernel_buffers` will contain the pairs `(G, ex_spikes)` and `(G, in_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__ex_spikes` and `G__X__in_spikes`.
-
:param parameters_block: ASTBlockWithVariables
:return: Dict
"""
@@ -890,43 +831,6 @@ def transform_ode_and_kernels_to_json(
iv_symbol_name)] = expr
odetoolbox_indict["dynamics"].append(entry)
- # write a copy for each (kernel, spike buffer) combination
- for kernel, spike_input_port in kernel_buffers:
-
- if ASTUtils.is_delta_kernel(kernel):
- # delta function -- skip passing this to ode-toolbox
- continue
-
- for kernel_var in kernel.get_variables():
- expr = ASTUtils.get_expr_from_kernel_var(
- kernel, kernel_var.get_complete_name())
- kernel_order = kernel_var.get_differential_order()
- kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_X_spike_buf_name(
- kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'")
-
- ASTUtils.replace_rhs_variables(expr, kernel_buffers)
-
- entry = {}
- entry["expression"] = kernel_X_spike_buf_name_ticks + " = " + str(expr)
-
- # initial values need to be declared for order 1 up to kernel
- # order (e.g. none for kernel function f(t) = ...; 1 for kernel
- # ODE f'(t) = ...; 2 for f''(t) = ... and so on)
- entry["initial_values"] = {}
- for order in range(kernel_order):
- iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_X_spike_buf_name(
- kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'")
- symbol_name_ = kernel_var.get_name() + "'" * order
- symbol = equations_block.get_scope().resolve_to_symbol(
- symbol_name_, SymbolKind.VARIABLE)
- assert symbol is not None, "Could not find initial value for variable " + symbol_name_
- initial_value_expr = symbol.get_declaring_expression()
- assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_
- entry["initial_values"][iv_sym_name_ode_toolbox] = self._ode_toolbox_printer.print(
- initial_value_expr)
-
- odetoolbox_indict["dynamics"].append(entry)
-
odetoolbox_indict["parameters"] = {}
if parameters_block is not None:
for decl in parameters_block.get_declarations():
diff --git a/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2
index 5832d90f3..fedb8a96e 100644
--- a/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2
+++ b/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2
@@ -262,10 +262,8 @@ std::vector< std::tuple< int, int > > {{neuronName}}::rport_to_nestml_buffer_idx
// copy state struct S_
{%- for init in neuron.get_state_symbols() %}
-{%- if not is_delta_kernel(neuron.get_kernel_by_name(init.name)) %}
{%- set node = utils.get_state_variable_by_name(astnode, init.get_symbol_name()) %}
{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }} = __n.{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }};
-{%- endif %}
{%- endfor %}
// copy internals V_
@@ -786,14 +784,6 @@ const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}_
{%- endfor %}
{%- endif %}
-
- /**
- * spike updates due to convolutions
- **/
-{% filter indent(4) %}
-{%- include "directives_cpp/ApplySpikesFromBuffers.jinja2" %}
-{%- endfilter %}
-
/**
* Begin NESTML generated code for the onCondition block(s)
**/
@@ -1149,13 +1139,9 @@ void
{%- endfor %}
/**
- * print updates due to convolutions
+ * push back spike history
**/
-{%- for _, spike_update in post_spike_updates.items() %}
- {{ printer.print(utils.get_variable_by_name(astnode, spike_update.get_variable().get_complete_name())) }} += 1.;
-{%- endfor %}
-
last_spike_ = t_sp_ms;
history_.push_back( histentry__{{neuronName}}( last_spike_
{%- for var in purely_numeric_state_variables_moved|sort %}
diff --git a/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2
index b8aa79c41..345372e0a 100644
--- a/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2
+++ b/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2
@@ -346,13 +346,11 @@ public:
// Getters/setters for state block
// -------------------------------------------------------------------------
-{% filter indent(2, True) -%}
+{% filter indent(2, True) -%}
{%- for variable_symbol in neuron.get_state_symbols() %}
-{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
-{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
-{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
-{% endif %}
-{% endfor %}
+{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
+{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
+{% endfor %}
{%- endfilter %}
{%- endif %}
@@ -962,22 +960,20 @@ inline nest_port_t {{neuronName}}::handles_test_event(nest::DataLoggingRequest&
inline void {{neuronName}}::get_status(DictionaryDatum &__d) const
{
// parameters
-{%- for variable_symbol in neuron.get_parameter_symbols() %}
-{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
-{%- filter indent(2) %}
+{%- filter indent(2) %}
+{%- for variable_symbol in neuron.get_parameter_symbols() %}
+{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives_cpp/WriteInDictionary.jinja2" %}
-{%- endfilter %}
-{%- endfor %}
+{%- endfor %}
+{%- endfilter %}
// initial values for state variables in ODE or kernel
-{%- for variable_symbol in neuron.get_state_symbols() %}
-{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
-{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
-{%- filter indent(2) %}
-{%- include "directives_cpp/WriteInDictionary.jinja2" %}
-{%- endfilter %}
-{%- endif -%}
-{%- endfor %}
+{%- filter indent(2) %}
+{%- for variable_symbol in neuron.get_state_symbols() %}
+{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
+{%- include "directives_cpp/WriteInDictionary.jinja2" %}
+{%- endfor %}
+{%- endfilter %}
{{neuron_parent_class}}::get_status( __d );
@@ -1023,11 +1019,9 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d)
// initial values for state variables in ODE or kernel
{%- for variable_symbol in neuron.get_state_symbols() %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
-{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
-{%- filter indent(2) %}
-{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %}
-{%- endfilter %}
-{%- endif %}
+{%- filter indent(2) %}
+{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %}
+{%- endfilter %}
{%- endfor %}
// We now know that (ptmp, stmp) are consistent. We do not
@@ -1046,11 +1040,9 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d)
{%- for variable_symbol in neuron.get_state_symbols() -%}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
-{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
-{%- filter indent(2) %}
-{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %}
-{%- endfilter %}
-{%- endif %}
+{%- filter indent(2) %}
+{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %}
+{%- endfilter %}
{%- endfor %}
{% for invariant in neuron.get_parameter_invariants() %}
diff --git a/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2
index 3d727ea23..38b6d6ce4 100644
--- a/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2
+++ b/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2
@@ -814,17 +814,6 @@ public:
{%- endfilter %}
}
- /**
- * update all convolutions with pre spikes
- **/
-
-{%- for spike_updates_for_port in spike_updates.values() %}
-{%- for spike_update in spike_updates_for_port %}
- {{ printer.print(spike_update.get_variable()) }} += 1.; // XXX: TODO: increment with initial value instead of 1
-{%- endfor %}
-{%- endfor %}
-
-
/**
* in case pre and post spike time coincide and pre update takes priority
**/
@@ -989,14 +978,12 @@ void
{%- filter indent(2,True) %}
{%- for variable_symbol in synapse.get_state_symbols() %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
-{%- if not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.get_symbol_name())) %}
-{%- include "directives_cpp/WriteInDictionary.jinja2" %}
-{%- if variable_symbol.get_namespace_decorator("nest")|length > 0 %}
+{%- include "directives_cpp/WriteInDictionary.jinja2" %}
+{%- if variable_symbol.get_namespace_decorator("nest")|length > 0 %}
// special treatment for variable marked with @nest::name decorator
-{%- set nest_namespace_name = variable_symbol.get_namespace_decorator("nest") %}
-{%- if not variable_symbol.is_internals() %}
+{%- set nest_namespace_name = variable_symbol.get_namespace_decorator("nest") %}
+{%- if not variable_symbol.is_internals() %}
def<{{declarations.print_variable_type(variable_symbol)}}>(__d, names::{{nest_namespace_name}}, get_{{printer_no_origin.print(variable)}}());
-{%- endif %}
{%- endif %}
{%- endif %}
{%- endfor %}
@@ -1024,24 +1011,22 @@ void
{%- filter indent(2,True) %}
{%- for variable_symbol in synapse.get_state_symbols() %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
-{%- if not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.name)) %}
-{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %}
-{%- if variable_symbol.get_namespace_decorator("nest")|length > 0 %}
+{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %}
+{%- if variable_symbol.get_namespace_decorator("nest")|length > 0 %}
// special treatment for variables marked with @nest::name decorator
-{%- set nest_namespace_name = variable_symbol.get_namespace_decorator("nest") %}
+{%- set nest_namespace_name = variable_symbol.get_namespace_decorator("nest") %}
{#- -------- XXX: TODO: this is almost the content of directives_cpp/ReadFromDictionaryToTmp.jinja2 verbatim, refactor this ---------- #}
-{%- if not variable_symbol.is_inline_expression and not variable_symbol.is_state() %}
+{%- if not variable_symbol.is_inline_expression and not variable_symbol.is_state() %}
tmp_{{ printer_no_origin.print(variable) }} = get_{{ printer_no_origin.print(variable) }}();
updateValue<{{ declarations.print_variable_type(variable_symbol) }}>(__d, "{{ nest_namespace_name }}", tmp_{{ printer_no_origin.print(variable) }});
-{%- elif not variable_symbol.is_inline_expression and variable_symbol.is_state() %}
+{%- elif not variable_symbol.is_inline_expression and variable_symbol.is_state() %}
tmp_{{ printer_no_origin.print(variable) }} = get_{{ printer_no_origin.print(variable) }}();
updateValue<{{ declarations.print_variable_type(variable_symbol) }}>(__d, "{{ nest_namespace_name }}", tmp_{{ printer_no_origin.print(variable) }});
-{%- else %}
+{%- else %}
// ignores '{{ printer_no_origin.print(variable) }}' ({{ declarations.print_variable_type(variable_symbol) }}) since it is a function and no setter is defined
-{%- endif %}
-{#- -------------------------------------------------------------------------------------------------------------------------------- #}
{%- endif %}
-{%- endif %}
+{#- -------------------------------------------------------------------------------------------------------------------------------- #}
+{%- endif %}
{%- endfor %}
{%- endfilter %}
@@ -1069,11 +1054,9 @@ updateValue<{{ declarations.print_variable_type(variable_symbol) }}>(__d, "{{ ne
// set state
{%- for variable_symbol in synapse.get_state_symbols() %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
-{%- if not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.name)) %}
-{%- filter indent(2,True) %}
-{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %}
-{%- endfilter %}
-{%- endif %}
+{%- filter indent(2,True) %}
+{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %}
+{%- endfilter %}
{%- endfor %}
// check invariants
diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/ApplySpikesFromBuffers.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/ApplySpikesFromBuffers.jinja2
deleted file mode 100644
index 881257451..000000000
--- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/ApplySpikesFromBuffers.jinja2
+++ /dev/null
@@ -1,6 +0,0 @@
-{% if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
-{%- for spike_updates_for_port in spike_updates.values() %}
-{%- for ast in spike_updates_for_port -%}
-{%- include "directives_cpp/Assignment.jinja2" %}
-{%- endfor %}
-{%- endfor %}
diff --git a/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 b/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2
index 5fc3ae589..7022e06ff 100644
--- a/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2
+++ b/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2
@@ -220,10 +220,8 @@ class Neuron_{{neuronName}}(Neuron):
# -------------------------------------------------------------------------
{% filter indent(2, True) -%}
{%- for variable_symbol in neuron.get_state_symbols() %}
-{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.get_symbol_name())) %}
{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives_py/MemberVariableGetterSetter.jinja2" %}
-{%- endif %}
{%- endfor %}
{%- endfilter %}
@@ -314,13 +312,6 @@ class Neuron_{{neuronName}}(Neuron):
{%- include "directives_py/AnalyticIntegrationStep_end.jinja2" %}
{%- endwith %}
- # -------------------------------------------------------------------------
- # process spikes from buffers
- # -------------------------------------------------------------------------
-{%- filter indent(4, True) -%}
-{%- include "directives_py/ApplySpikesFromBuffers.jinja2" %}
-{%- endfilter %}
-
# -------------------------------------------------------------------------
# begin NESTML generated code for the onReceive block(s)
# -------------------------------------------------------------------------
diff --git a/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/ApplySpikesFromBuffers.jinja2 b/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/ApplySpikesFromBuffers.jinja2
deleted file mode 100644
index c0952b2f5..000000000
--- a/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/ApplySpikesFromBuffers.jinja2
+++ /dev/null
@@ -1,6 +0,0 @@
-{%- if tracing %}# generated by {{self._TemplateReference__context.name}}{% endif %}
-{%- for spike_updates_for_port in spike_updates.values() %}
-{%- for ast in spike_updates_for_port -%}
-{%- include "directives_py/Assignment.jinja2" %}
-{%- endfor %}
-{%- endfor %}
diff --git a/pynestml/frontend/pynestml_frontend.py b/pynestml/frontend/pynestml_frontend.py
index 53a8cbdd6..5aef28052 100644
--- a/pynestml/frontend/pynestml_frontend.py
+++ b/pynestml/frontend/pynestml_frontend.py
@@ -37,6 +37,7 @@
from pynestml.symbols.predefined_types import PredefinedTypes
from pynestml.symbols.predefined_units import PredefinedUnits
from pynestml.symbols.predefined_variables import PredefinedVariables
+from pynestml.transformers.convolutions_transformer import ConvolutionsTransformer
from pynestml.transformers.transformer import Transformer
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.messages import Messages
@@ -59,6 +60,9 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st
if options is None:
options = {}
+ # for all targets, add the convolutions transformer
+ transformers.append(ConvolutionsTransformer())
+
if target_name.upper() in ["NEST", "SPINNAKER"]:
from pynestml.transformers.illegal_variable_name_transformer import IllegalVariableNameTransformer
diff --git a/pynestml/meta_model/ast_model.py b/pynestml/meta_model/ast_model.py
index 834e56897..b38629275 100644
--- a/pynestml/meta_model/ast_model.py
+++ b/pynestml/meta_model/ast_model.py
@@ -670,8 +670,10 @@ def get_on_receive_block(self, port_name: str) -> Optional[ASTOnReceiveBlock]:
return self.get_body().get_on_receive_block(port_name)
def get_on_condition_blocks(self) -> List[ASTOnConditionBlock]:
+ r"""See ASTModelBody.get_on_condition_blocks() for the documentation for this function."""
if not self.get_body():
return []
+
return self.get_body().get_on_condition_blocks()
def get_on_condition_block(self, port_name: str) -> Optional[ASTOnConditionBlock]:
diff --git a/pynestml/meta_model/ast_model_body.py b/pynestml/meta_model/ast_model_body.py
index 6e32561ce..1908bb88d 100644
--- a/pynestml/meta_model/ast_model_body.py
+++ b/pynestml/meta_model/ast_model_body.py
@@ -182,10 +182,14 @@ def get_on_condition_block(self, port_name) -> Optional[ASTOnConditionBlock]:
return None
def get_on_condition_blocks(self) -> List[ASTOnConditionBlock]:
+ r"""
+ XXX: TODO: sorting based on priority
+ """
on_condition_blocks = []
for elem in self.get_body_elements():
if isinstance(elem, ASTOnConditionBlock):
on_condition_blocks.append(elem)
+
return on_condition_blocks
def get_equations_blocks(self) -> List[ASTEquationsBlock]:
diff --git a/pynestml/transformers/convolutions_transformer.py b/pynestml/transformers/convolutions_transformer.py
new file mode 100644
index 000000000..cb38f8f04
--- /dev/null
+++ b/pynestml/transformers/convolutions_transformer.py
@@ -0,0 +1,501 @@
+# -*- coding: utf-8 -*-
+#
+# convolutions_transformer.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from typing import Any, Dict, List, Sequence, Mapping, Optional, Tuple, Union
+
+import re
+
+import odetoolbox
+
+from pynestml.codegeneration.printers.ast_printer import ASTPrinter
+from pynestml.codegeneration.printers.constant_printer import ConstantPrinter
+from pynestml.codegeneration.printers.ode_toolbox_expression_printer import ODEToolboxExpressionPrinter
+from pynestml.codegeneration.printers.ode_toolbox_function_call_printer import ODEToolboxFunctionCallPrinter
+from pynestml.codegeneration.printers.ode_toolbox_variable_printer import ODEToolboxVariablePrinter
+from pynestml.codegeneration.printers.unitless_sympy_simple_expression_printer import UnitlessSympySimpleExpressionPrinter
+from pynestml.frontend.frontend_configuration import FrontendConfiguration
+from pynestml.meta_model.ast_assignment import ASTAssignment
+from pynestml.meta_model.ast_declaration import ASTDeclaration
+from pynestml.meta_model.ast_equations_block import ASTEquationsBlock
+from pynestml.meta_model.ast_expression import ASTExpression
+from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
+from pynestml.meta_model.ast_input_port import ASTInputPort
+from pynestml.meta_model.ast_kernel import ASTKernel
+from pynestml.meta_model.ast_model import ASTModel
+from pynestml.meta_model.ast_node import ASTNode
+from pynestml.meta_model.ast_node_factory import ASTNodeFactory
+from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
+from pynestml.meta_model.ast_variable import ASTVariable
+from pynestml.symbols.predefined_functions import PredefinedFunctions
+from pynestml.symbols.symbol import SymbolKind
+from pynestml.symbols.variable_symbol import BlockType
+from pynestml.transformers.transformer import Transformer
+from pynestml.utils.ast_source_location import ASTSourceLocation
+from pynestml.utils.ast_utils import ASTUtils
+from pynestml.utils.logger import Logger
+from pynestml.utils.logger import LoggingLevel
+from pynestml.utils.model_parser import ModelParser
+from pynestml.utils.string_utils import removesuffix
+from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
+from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
+from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor
+from pynestml.visitors.ast_visitor import ASTVisitor
+
+
+class ConvolutionsTransformer(Transformer):
+ r"""For each convolution that occurs in the model, allocate one or more needed state variables and replace the convolution() calls by these variable names."""
+
+ _default_options = {
+ "convolution_separator": "__conv__",
+ "diff_order_symbol": "__d",
+ "simplify_expression": "sympy.logcombine(sympy.powsimp(sympy.expand(expr)))"
+ }
+
+ def __init__(self, options: Optional[Mapping[str, Any]] = None):
+ super(ConvolutionsTransformer, self).__init__(options)
+
+ # ODE-toolbox printers
+ self._constant_printer = ConstantPrinter()
+ self._ode_toolbox_variable_printer = ODEToolboxVariablePrinter(None)
+ self._ode_toolbox_function_call_printer = ODEToolboxFunctionCallPrinter(None)
+ self._ode_toolbox_printer = ODEToolboxExpressionPrinter(simple_expression_printer=UnitlessSympySimpleExpressionPrinter(variable_printer=self._ode_toolbox_variable_printer,
+ constant_printer=self._constant_printer,
+ function_call_printer=self._ode_toolbox_function_call_printer))
+ self._ode_toolbox_variable_printer._expression_printer = self._ode_toolbox_printer
+ self._ode_toolbox_function_call_printer._expression_printer = self._ode_toolbox_printer
+
+ def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]:
+ r"""Transform a model or a list of models. Return an updated model or list of models."""
+ for model in models:
+ print("-------- MODEL BEFORE TRANSFORM ------------")
+ print(model)
+ kernel_buffers = self.generate_kernel_buffers(model)
+ odetoolbox_indict = self.transform_kernels_to_json(model, kernel_buffers)
+ print("odetoolbox indict: " + str(odetoolbox_indict))
+ solvers_json, shape_sys, shapes = odetoolbox._analysis(odetoolbox_indict,
+ disable_stiffness_check=True,
+ disable_analytic_solver=True,
+ preserve_expressions=True,
+ simplify_expression=self.get_option("simplify_expression"),
+ log_level=FrontendConfiguration.logging_level)
+ print("odetoolbox outdict: " + str(solvers_json))
+
+ self.remove_initial_values_for_kernels(model)
+ self.create_initial_values_for_kernels(model, solvers_json, kernel_buffers)
+ self.create_spike_update_event_handlers(model, solvers_json, kernel_buffers)
+ self.replace_convolve_calls_with_buffers_(model)
+ self.remove_kernel_definitions_from_equations_blocks(model)
+ print("-------- MODEL AFTER TRANSFORM ------------")
+ print(model)
+ print("-------------------------------------------")
+
+ return models
+
+ def construct_kernel_spike_buf_name(self, kernel_var_name: str, spike_input_port: ASTInputPort, order: int, diff_order_symbol: Optional[str] = None) -> str:
+ """
+ Construct a kernel-buffer name as ``KERNEL_NAME__conv__INPUT_PORT_NAME``
+
+ For example, if the kernel is
+
+ .. code-block::
+
+    kernel I_kernel = exp(-t / tau_x)
+
+ and the input port is
+
+ .. code-block::
+
+    pre_spikes nS <- spike
+
+ then the constructed variable will be ``I_kernel__conv__pre_spikes``.
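+
+ With ``order > 0``, the ``diff_order_symbol`` (option of this transformer; ``__d`` by default) is appended once per differential order, e.g. ``I_kernel__conv__pre_spikes__d`` for ``order = 1``.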
+ """
+ assert type(kernel_var_name) is str
+ assert type(order) is int
+
+ if isinstance(spike_input_port, ASTSimpleExpression):
+ spike_input_port = spike_input_port.get_variable()
+
+ if not isinstance(spike_input_port, str):
+ spike_input_port_name = spike_input_port.get_name()
+ else:
+ spike_input_port_name = spike_input_port
+
+ if isinstance(spike_input_port, ASTVariable):
+ if spike_input_port.has_vector_parameter():
+ spike_input_port_name += "_" + str(self.get_numeric_vector_size(spike_input_port))
+
+ if not diff_order_symbol:
+ diff_order_symbol = self.get_option("diff_order_symbol")
+
+ return kernel_var_name.replace("$", "__DOLLAR") + self.get_option("convolution_separator") + spike_input_port_name + diff_order_symbol * order
+
+ def replace_rhs_variable(self, expr: ASTExpression, variable_name_to_replace: str, kernel_var: ASTVariable,
+ spike_buf: ASTInputPort):
+ """
+ Replace variable names in definitions of kernel dynamics
+ :param expr: expression in which to replace the variables
+ :param variable_name_to_replace: variable name to replace in the expression
+ :param kernel_var: kernel variable instance
+ :param spike_buf: input port instance
+ :return:
+ """
+ def replace_kernel_var(node):
+ if type(node) is ASTSimpleExpression \
+ and node.is_variable() \
+ and node.get_variable().get_name() == variable_name_to_replace:
+ var_order = node.get_variable().get_differential_order()
+ new_variable_name = self.construct_kernel_spike_buf_name(
+     kernel_var.get_name(), spike_buf, var_order - 1, diff_order_symbol="'")
+ new_variable = ASTVariable(new_variable_name, var_order)
+ new_variable.set_source_position(node.get_variable().get_source_position())
+ node.set_variable(new_variable)
+
+ expr.accept(ASTHigherOrderVisitor(visit_funcs=replace_kernel_var))
+
+ def replace_rhs_variables(self, expr: ASTExpression, kernel_buffers: Mapping[ASTKernel, ASTInputPort]):
+ """
+ Replace variable names in definitions of kernel dynamics.
+
+ Say that the kernel is
+
+ .. code-block::
+
+ G = -G / tau
+
+ Its variable symbol might be replaced by "G__conv__spikesEx":
+
+ .. code-block::
+
+ G__conv__spikesEx = -G / tau
+
+ This function updates the right-hand side of `expr` so that it would also read (in this example):
+
+ .. code-block::
+
+ G__conv__spikesEx = -G__conv__spikesEx / tau
+
+ These equations will later on be fed to ode-toolbox, so we use the symbol "'" to indicate differential order.
+
+ Note that for kernels/systems of ODE of dimension > 1, all variable orders and all variables for this kernel will already be present in `kernel_buffers`.
+ """
+ for kernel, spike_buf in kernel_buffers:
+ for kernel_var in kernel.get_variables():
+ variable_name_to_replace = kernel_var.get_name()
+ self.replace_rhs_variable(expr, variable_name_to_replace=variable_name_to_replace,
+ kernel_var=kernel_var, spike_buf=spike_buf)
+
+ @classmethod
+ def remove_initial_values_for_kernels(cls, model: ASTModel) -> None:
+ r"""
+ Remove initial values for original declarations (e.g. g_in, g_in', V_m); these will be replaced with the initial value expressions returned from ODE-toolbox.
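+
+ For example (hypothetical model code), for a kernel ``g_in`` defined via a second-order ODE, the state-block declarations of ``g_in`` and ``g_in'`` are removed.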
+ """
+ symbols_to_remove = set()
+ for equations_block in model.get_equations_blocks():
+ for kernel in equations_block.get_kernels():
+ for kernel_var in kernel.get_variables():
+ kernel_var_order = kernel_var.get_differential_order()
+ for order in range(kernel_var_order):
+ symbol_name = kernel_var.get_name() + "'" * order
+ symbols_to_remove.add(symbol_name)
+
+ decl_to_remove = set()
+ for symbol_name in symbols_to_remove:
+ for state_block in model.get_state_blocks():
+ for decl in state_block.get_declarations():
+ if len(decl.get_variables()) == 1:
+ if decl.get_variables()[0].get_name() == symbol_name:
+ decl_to_remove.add(decl)
+ else:
+ for var in decl.get_variables():
+ if var.get_name() == symbol_name:
+ decl.variables.remove(var)
+
+ for decl in decl_to_remove:
+ for state_block in model.get_state_blocks():
+ if decl in state_block.get_declarations():
+ state_block.get_declarations().remove(decl)
+
+ def create_initial_values_for_kernels(self, model: ASTModel, solver_dicts: List[Dict], kernel_buffers: Mapping[ASTKernel, ASTInputPort]) -> None:
+ r"""
+ Add state variables for the kernels returned in the ODE-toolbox results dictionary to the state block of the NESTML AST.
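+
+ Illustrative example (hypothetical names): for the ODE-toolbox state variable ``I_kernel__conv__pre_spikes__d``, the input port name ``pre_spikes`` is recovered from the variable name, the differential order is 1, and a declaration with initial value ``0 (s**-1)`` is added to the state block.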
+ """
+ for solver_dict in solver_dicts:
+ if solver_dict is None:
+ continue
+
+ for var_name, expr in solver_dict["initial_values"].items():
+ spike_in_port_name = var_name.split(self.get_option("convolution_separator"))[1]
+ spike_in_port_name = spike_in_port_name.split("__d")[0]
+ spike_in_port = ASTUtils.get_input_port_by_name(model.get_input_blocks(), spike_in_port_name)
+ type_str = "real"
+ if spike_in_port:
+ differential_order: int = len(re.findall("__d", var_name))
+ if differential_order:
+ type_str = "(s**-" + str(differential_order) + ")"
+
+ expr = "0 " + type_str # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is 0 (property of the convolution)
+ if not ASTUtils.declaration_in_state_block(model, var_name):
+ ASTUtils.add_declaration_to_state_block(model, var_name, expr, type_str)
+
+ def is_delta_kernel(self, kernel: ASTKernel) -> bool:
+ """
+ Return True if the given kernel is a delta kernel, i.e. its defining expression is the delta function or the delta function multiplied by a factor.
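+
+ A sketch of what qualifies (hypothetical model code):
+
+ .. code-block::
+
+    kernel d_kern = delta(t)        # delta kernel
+    kernel d_kern2 = K * delta(t)   # delta kernel multiplied by a factor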
+ """
+ if not isinstance(kernel, ASTKernel):
+ return False
+
+ if len(kernel.get_variables()) != 1:
+ # delta kernel not allowed if more than one variable is defined in this kernel
+ return False
+
+ expr = kernel.get_expressions()[0]
+
+ rhs_is_delta_kernel = type(expr) is ASTSimpleExpression \
+ and expr.is_function_call() \
+ and expr.get_function_call().get_scope().resolve_to_symbol(expr.get_function_call().get_name(), SymbolKind.FUNCTION).equals(PredefinedFunctions.name2function["delta"])
+
+ rhs_is_multiplied_delta_kernel = type(expr) is ASTExpression \
+ and type(expr.get_rhs()) is ASTSimpleExpression \
+ and expr.get_rhs().is_function_call() \
+ and expr.get_rhs().get_function_call().get_scope().resolve_to_symbol(expr.get_rhs().get_function_call().get_name(), SymbolKind.FUNCTION).equals(PredefinedFunctions.name2function["delta"])
+
+ return rhs_is_delta_kernel or rhs_is_multiplied_delta_kernel
+
+ def replace_convolve_calls_with_buffers_(self, model: ASTModel) -> None:
+ r"""
+ Replace all occurrences of `convolve(kernel[']^n, spike_input_port)` with the corresponding buffer variable, e.g. ``g_E__conv__spikes_exc[__d]^n`` for a kernel named ``g_E`` and a spike input port named ``spikes_exc``.
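+
+ For example (hypothetical model code), the inline expression
+
+ .. code-block::
+
+    inline I_syn pA = convolve(I_kernel, pre_spikes) * pA
+
+ becomes
+
+ .. code-block::
+
+    inline I_syn pA = I_kernel__conv__pre_spikes * pA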
+ """
+
+ def replace_function_call_through_var(_expr=None):
+ if _expr.is_function_call() and _expr.get_function_call().get_name() == "convolve":
+ convolve = _expr.get_function_call()
+ el = (convolve.get_args()[0], convolve.get_args()[1])
+ sym = convolve.get_args()[0].get_scope().resolve_to_symbol(
+ convolve.get_args()[0].get_variable().name, SymbolKind.VARIABLE)
+ if sym.block_type == BlockType.INPUT:
+ # swap elements
+ el = (el[1], el[0])
+ var = el[0].get_variable()
+ spike_input_port = el[1].get_variable()
+ kernel = model.get_kernel_by_name(var.get_name())
+
+ _expr.set_function_call(None)
+ buffer_var = self.construct_kernel_spike_buf_name(
+ var.get_name(), spike_input_port, var.get_differential_order() - 1)
+ if self.is_delta_kernel(kernel):
+ # delta kernels are treated separately, and should be kept out of the dynamics (computing derivatives etc.) --> set to zero
+ _expr.set_variable(None)
+ _expr.set_numeric_literal(0)
+ else:
+ ast_variable = ASTVariable(buffer_var)
+ ast_variable.set_source_position(_expr.get_source_position())
+ _expr.set_variable(ast_variable)
+
+ def func(x):
+ return replace_function_call_through_var(x) if isinstance(x, ASTSimpleExpression) else True
+
+ for equations_block in model.get_equations_blocks():
+ equations_block.accept(ASTHigherOrderVisitor(func))
+
+ @classmethod
+ def replace_convolution_aliasing_inlines(cls, neuron: ASTModel) -> None:
+ """
+ Replace all occurrences of kernel names (e.g. ``I_dend`` and ``I_dend'`` for a definition involving a second-order kernel ``inline kernel I_dend = convolve(kern_name, spike_buf)``) with the ODE-toolbox generated variable ``kern_name__X__spike_buf``.
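+
+ For instance, given the inline (illustrative names, as above)
+
+ .. code-block::
+
+ inline I_dend pA = convolve(kern_name, spike_buf)
+
+ every subsequent use of ``I_dend`` is rewritten to ``kern_name__X__spike_buf``, and ``I_dend'`` to ``kern_name__X__spike_buf__d``.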
+ """
+ def replace_var(_expr, replace_var_name: str, replace_with_var_name: str):
+ if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
+ var = _expr.get_variable()
+ if var.get_name() == replace_var_name:
+ ast_variable = ASTVariable(replace_with_var_name + '__d' * var.get_differential_order(),
+ differential_order=0)
+ ast_variable.set_source_position(var.get_source_position())
+ _expr.set_variable(ast_variable)
+
+ elif isinstance(_expr, ASTVariable):
+ var = _expr
+ if var.get_name() == replace_var_name:
+ var.set_name(replace_with_var_name + '__d' * var.get_differential_order())
+ var.set_differential_order(0)
+
+ for equation_block in neuron.get_equations_blocks():
+ for decl in equation_block.get_declarations():
+ if isinstance(decl, ASTInlineExpression):
+ expr = decl.get_expression()
+ if isinstance(expr, ASTExpression):
+ expr = expr.get_lhs()
+
+ if isinstance(expr, ASTSimpleExpression) \
+ and '__X__' in str(expr) \
+ and expr.get_variable():
+ replace_with_var_name = expr.get_variable().get_name()
+ neuron.accept(ASTHigherOrderVisitor(lambda x: replace_var(
+ x, decl.get_variable_name(), replace_with_var_name)))
+
+ def generate_kernel_buffers(self, model: ASTModel) -> Set[Tuple[ASTKernel, ASTInputPort]]:
+ r"""
+ For every occurrence of a convolution of the form `convolve(var, spike_buf)`: add the element `(kernel, spike_buf)` to the set, with `kernel` being the kernel that contains variable `var`.
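+
+ For example, a convolution ``convolve(g_E, spikes_exc)`` (illustrative names) contributes the pair ``(g_E, spikes_exc)``, with ``g_E`` resolved to the ASTKernel object that defines the kernel. The argument order is normalised, so ``convolve(spikes_exc, g_E)`` yields the same pair.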
+ """
+ kernel_buffers = set()
+ for equations_block in model.get_equations_blocks():
+ convolve_calls = ASTUtils.get_convolve_function_calls(equations_block)
+ for convolve in convolve_calls:
+ el = (convolve.get_args()[0], convolve.get_args()[1])
+ sym = convolve.get_args()[0].get_scope().resolve_to_symbol(convolve.get_args()[0].get_variable().name, SymbolKind.VARIABLE)
+ if sym is None:
+ raise Exception("No initial value(s) defined for kernel with variable \""
+ + convolve.get_args()[0].get_variable().get_complete_name() + "\"")
+ if sym.block_type == BlockType.INPUT:
+ # swap the order
+ el = (el[1], el[0])
+
+ # find the corresponding kernel object
+ var = el[0].get_variable()
+ assert var is not None
+ kernel = model.get_kernel_by_name(var.get_name())
+ assert kernel is not None, "In convolution \"convolve(" + str(var.name) + ", " + str(
+ el[1]) + ")\": no kernel by name \"" + var.get_name() + "\" found in model."
+
+ el = (kernel, el[1])
+ kernel_buffers.add(el)
+
+ return kernel_buffers
+
+ def remove_kernel_definitions_from_equations_blocks(self, model: ASTModel) -> None:
+ r"""
+ Removes all kernels in equations blocks.
+ """
+ for equations_block in model.get_equations_blocks():
+ decl_to_remove = set()
+ for decl in equations_block.get_declarations():
+ if type(decl) is ASTKernel:
+ decl_to_remove.add(decl)
+
+ for decl in decl_to_remove:
+ equations_block.get_declarations().remove(decl)
+
+ def transform_kernels_to_json(self, model: ASTModel, kernel_buffers: List[Tuple[ASTKernel, ASTInputPort]]) -> Dict:
+ """
+ Converts AST node to a JSON representation suitable for passing to ode-toolbox.
+
+ A separate copy of each kernel has to be generated for each spike buffer it is convolved with, e.g. if the NESTML model code contains the statements
+
+ .. code-block::
+
+ convolve(G, exc_spikes)
+ convolve(G, inh_spikes)
+
+ then `kernel_buffers` will contain the pairs `(G, exc_spikes)` and `(G, inh_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__exc_spikes` and `G__X__inh_spikes`.
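+
+ Schematically (assuming ``G`` is defined by a first-order kernel ODE), the returned dictionary has the form:
+
+ .. code-block::
+
+ {"dynamics": [{"expression": "G__X__exc_spikes' = ...", "initial_values": {"G__X__exc_spikes": "..."}},
+ {"expression": "G__X__inh_spikes' = ...", "initial_values": {"G__X__inh_spikes": "..."}}],
+ "parameters": {...}}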
+ """
+ odetoolbox_indict = {}
+ odetoolbox_indict["dynamics"] = []
+
+ for kernel, spike_input_port in kernel_buffers:
+
+ if self.is_delta_kernel(kernel):
+ # delta function -- skip passing this to ode-toolbox
+ continue
+
+ for kernel_var in kernel.get_variables():
+ expr = ASTUtils.get_expr_from_kernel_var(kernel, kernel_var.get_complete_name())
+ kernel_order = kernel_var.get_differential_order()
+ kernel_X_spike_buf_name_ticks = self.construct_kernel_spike_buf_name(kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'")
+
+ self.replace_rhs_variables(expr, kernel_buffers)
+
+ entry = {"expression": kernel_X_spike_buf_name_ticks + " = " + str(expr), "initial_values": {}}
+
+ # initial values need to be declared for order 1 up to kernel order (e.g. none for kernel function
+ # f(t) = ...; 1 for kernel ODE f'(t) = ...; 2 for f''(t) = ... and so on)
+ for order in range(kernel_order):
+ iv_sym_name_ode_toolbox = self.construct_kernel_spike_buf_name(kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'")
+ symbol_name_ = kernel_var.get_name() + "'" * order
+ symbol = model.get_scope().resolve_to_symbol(symbol_name_, SymbolKind.VARIABLE)
+ assert symbol is not None, "Could not find initial value for variable " + symbol_name_
+ initial_value_expr = symbol.get_declaring_expression()
+ assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_
+ entry["initial_values"][iv_sym_name_ode_toolbox] = self._ode_toolbox_printer.print(initial_value_expr)
+
+ odetoolbox_indict["dynamics"].append(entry)
+
+ odetoolbox_indict["parameters"] = {}
+ for parameters_block in model.get_parameters_blocks():
+ for decl in parameters_block.get_declarations():
+ for var in decl.variables:
+ odetoolbox_indict["parameters"][var.get_complete_name()] = self._ode_toolbox_printer.print(decl.get_expression())
+
+ return odetoolbox_indict
+
+ def create_spike_update_event_handlers(self, model: ASTModel, solver_dicts, kernel_buffers: List[Tuple[ASTKernel, ASTInputPort]]) -> None:
+ r"""
+ Generate onReceive event handlers containing the statements that update the dynamical variables when incoming spikes arrive. To be invoked after ode-toolbox.
+
+ For example, a resulting `assignment_str` could be "I_kernel_in += (inh_spikes/nS) * 1". The values are taken from the initial values for each corresponding dynamical variable, either from ode-toolbox or directly from user specification in the model.
+
+ Note that for kernels, `initial_values` actually contains the increment upon spike arrival, rather than the
+ initial value of the corresponding ODE dimension.
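+
+ Schematically (illustrative names), the onReceive block generated for a spike input port ``inh_spikes`` would read:
+
+ .. code-block::
+
+ onReceive(inh_spikes):
+ I_kernel_in__X__inh_spikes += (inh_spikes) * (g_init / nS)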
+ """
+
+ spike_in_port_to_stmts = {}
+ for solver_dict in solver_dicts:
+ if solver_dict is None:
+ continue
+
+ for var, expr in solver_dict["initial_values"].items():
+ expr = str(expr)
+ if expr in ["0", "0.", "0.0"]:
+ continue # skip adding the statement if we are only adding zero
+
+ spike_in_port_name = var.split(self.get_option("convolution_separator"))[1]
+ spike_in_port_name = spike_in_port_name.split("__d")[0]
+ spike_in_port = ASTUtils.get_input_port_by_name(model.get_input_blocks(), spike_in_port_name)
+ type_str = "real"
+
+ assert spike_in_port
+ differential_order: int = len(re.findall("__d", var))
+ if differential_order:
+ type_str = "(s**-" + str(differential_order) + ")"
+
+ assignment_str = var + " += "
+ assignment_str += "(" + str(spike_in_port_name) + ")"
+ if expr not in ["1.", "1.0", "1"]:
+ assignment_str += " * (" + expr + ")"
+
+ ast_assignment = ModelParser.parse_assignment(assignment_str)
+ ast_assignment.update_scope(model.get_scope())
+ ast_assignment.accept(ASTSymbolTableVisitor())
+
+ ast_small_stmt = ASTNodeFactory.create_ast_small_stmt(assignment=ast_assignment)
+ ast_stmt = ASTNodeFactory.create_ast_stmt(small_stmt=ast_small_stmt)
+
+ if spike_in_port_name not in spike_in_port_to_stmts:
+ spike_in_port_to_stmts[spike_in_port_name] = []
+
+ spike_in_port_to_stmts[spike_in_port_name].append(ast_stmt)
+
+ # for every input port, add an onReceive block with its update statements
+ for in_port, stmts in spike_in_port_to_stmts.items():
+ stmts_block = ASTNodeFactory.create_ast_block(stmts, ASTSourceLocation.get_added_source_position())
+ on_receive_block = ASTNodeFactory.create_ast_on_receive_block(stmts_block,
+ in_port,
+ const_parameters=None, # XXX: TODO: add priority here!
+ source_position=ASTSourceLocation.get_added_source_position())
+
+ model.get_body().get_body_elements().append(on_receive_block)
+
+ model.accept(ASTParentVisitor())
diff --git a/pynestml/transformers/synapse_post_neuron_transformer.py b/pynestml/transformers/synapse_post_neuron_transformer.py
index 5dd4aa3e0..ccaefd743 100644
--- a/pynestml/transformers/synapse_post_neuron_transformer.py
+++ b/pynestml/transformers/synapse_post_neuron_transformer.py
@@ -24,13 +24,8 @@
from typing import Any, Sequence, Mapping, Optional, Union
from pynestml.frontend.frontend_configuration import FrontendConfiguration
-from pynestml.meta_model.ast_assignment import ASTAssignment
-from pynestml.meta_model.ast_equations_block import ASTEquationsBlock
-from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
from pynestml.meta_model.ast_model import ASTModel
from pynestml.meta_model.ast_node import ASTNode
-from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
-from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.symbols.symbol import SymbolKind
from pynestml.symbols.variable_symbol import BlockType
from pynestml.transformers.transformer import Transformer
@@ -40,7 +35,6 @@
from pynestml.utils.string_utils import removesuffix
from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
-from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor
from pynestml.visitors.ast_visitor import ASTVisitor
@@ -149,51 +143,6 @@ def get_neuron_var_name_from_syn_port_name(self, port_name: str, neuron_name: st
return None
- def get_convolve_with_not_post_vars(self, nodes: Union[ASTEquationsBlock, Sequence[ASTEquationsBlock]], neuron_name: str, synapse_name: str, parent_node: ASTNode):
- class ASTVariablesUsedInConvolutionVisitor(ASTVisitor):
- _variables = []
-
- def __init__(self, node: ASTNode, parent_node: ASTNode, codegen_class):
- super(ASTVariablesUsedInConvolutionVisitor, self).__init__()
- self.node = node
- self.parent_node = parent_node
- self.codegen_class = codegen_class
-
- def visit_function_call(self, node):
- func_name = node.get_name()
- if func_name == "convolve":
- symbol_buffer = node.get_scope().resolve_to_symbol(str(node.get_args()[1]),
- SymbolKind.VARIABLE)
- input_port = ASTUtils.get_input_port_by_name(
- self.parent_node.get_input_blocks(), symbol_buffer.name)
- if input_port and not self.codegen_class.is_post_port(input_port.name, neuron_name, synapse_name):
- kernel_name = node.get_args()[0].get_variable().name
- self._variables.append(kernel_name)
-
- found_parent_assignment = False
- node_ = node
- while not found_parent_assignment:
- node_ = node_.get_parent()
- # XXX TODO also needs to accept normal ASTExpression, ASTAssignment?
- if isinstance(node_, ASTInlineExpression):
- found_parent_assignment = True
- var_name = node_.get_variable_name()
- self._variables.append(var_name)
-
- if not nodes:
- return []
-
- if isinstance(nodes, ASTNode):
- nodes = [nodes]
-
- variables = []
- for node in nodes:
- visitor = ASTVariablesUsedInConvolutionVisitor(node, parent_node, self)
- node.accept(visitor)
- variables.extend(visitor._variables)
-
- return variables
-
def get_all_variables_assigned_to(self, node):
r"""Return a list of all variables that are assigned to in ``node``."""
class ASTAssignedToVariablesFinderVisitor(ASTVisitor):
@@ -257,13 +206,6 @@ def transform_neuron_synapse_pair_(self, neuron, synapse):
all_state_vars = [var.get_complete_name() for var in all_state_vars]
- # add names of convolutions
- all_state_vars += ASTUtils.get_all_variables_used_in_convolutions(synapse.get_equations_blocks(), synapse)
-
- # add names of kernels
- kernel_buffers = ASTUtils.generate_kernel_buffers(synapse, synapse.get_equations_blocks())
- all_state_vars += [var.name for k in kernel_buffers for var in k[0].variables]
-
# exclude certain variables from being moved:
# exclude any variable assigned to in any block that is not connected to a postsynaptic port
strictly_synaptic_vars = ["t"] # "seed" this with the predefined variable t
@@ -276,28 +218,25 @@ def transform_neuron_synapse_pair_(self, neuron, synapse):
for update_block in synapse.get_update_blocks():
strictly_synaptic_vars += self.get_all_variables_assigned_to(update_block)
- # exclude convolutions if they are not with a postsynaptic variable
- convolve_with_not_post_vars = self.get_convolve_with_not_post_vars(synapse.get_equations_blocks(), neuron.name, synapse.name, synapse)
-
# exclude all variables that depend on the ones that are not to be moved
strictly_synaptic_vars_dependent = ASTUtils.recursive_dependent_variables_search(strictly_synaptic_vars, synapse)
# do set subtraction
- syn_to_neuron_state_vars = list(set(all_state_vars) - (set(strictly_synaptic_vars) | set(convolve_with_not_post_vars) | set(strictly_synaptic_vars_dependent)))
+ syn_to_neuron_state_vars = list(set(all_state_vars) - (set(strictly_synaptic_vars) | set(strictly_synaptic_vars_dependent)))
#
- # collect all the variable/parameter/kernel/function/etc. names used in defining expressions of `syn_to_neuron_state_vars`
+ # collect all the variable/parameter/function/etc. names used in defining expressions of `syn_to_neuron_state_vars`
#
recursive_vars_used = ASTUtils.recursive_necessary_variables_search(syn_to_neuron_state_vars, synapse)
new_neuron.recursive_vars_used = recursive_vars_used
new_neuron._transferred_variables = [neuron_state_var + var_name_suffix
- for neuron_state_var in syn_to_neuron_state_vars if new_synapse.get_kernel_by_name(neuron_state_var) is None]
+ for neuron_state_var in syn_to_neuron_state_vars]
# all state variables that will be moved from synapse to neuron
syn_to_neuron_state_vars = []
for var_name in recursive_vars_used:
- if ASTUtils.get_state_variable_by_name(synapse, var_name) or ASTUtils.get_inline_expression_by_name(synapse, var_name) or ASTUtils.get_kernel_by_name(synapse, var_name):
+ if ASTUtils.get_state_variable_by_name(synapse, var_name) or ASTUtils.get_inline_expression_by_name(synapse, var_name):
syn_to_neuron_state_vars.append(var_name)
Logger.log_message(None, -1, "State variables that will be moved from synapse to neuron: " + str(syn_to_neuron_state_vars),
@@ -393,33 +332,6 @@ def transform_neuron_synapse_pair_(self, neuron, synapse):
block_type=BlockType.STATE,
mode="move")
- #
- # mark variables in the neuron pertaining to synapse postsynaptic ports
- #
- # convolutions with them ultimately yield variable updates when post neuron calls emit_spike()
- #
-
- def mark_post_ports(neuron, synapse, mark_node):
- post_ports = []
-
- def mark_post_port(_expr=None):
- var = None
- if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
- var = _expr.get_variable()
- elif isinstance(_expr, ASTVariable):
- var = _expr
-
- if var:
- var_base_name = var.name[:-len(var_name_suffix)] # prune the suffix
- if self.is_post_port(var_base_name, neuron.name, synapse.name):
- post_ports.append(var)
- var._is_post_port = True
-
- mark_node.accept(ASTHigherOrderVisitor(lambda x: mark_post_port(x)))
- return post_ports
-
- mark_post_ports(new_neuron, new_synapse, new_neuron)
-
#
# move statements in post receive block from synapse to new_neuron
#
@@ -561,6 +473,13 @@ def mark_post_port(_expr=None):
return new_neuron, new_synapse
def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]:
+
+ # check that there are no convolutions or kernels in the model (these should have been transformed out by the ConvolutionsTransformer)
+ for model in models:
+ for equations_block in model.get_equations_blocks():
+ assert len(equations_block.get_kernels()) == 0, "Kernels and convolutions should have been removed by ConvolutionsTransformer"
+
+ # transform each (neuron, synapse) pair
for neuron_synapse_pair in self.get_option("neuron_synapse_pairs"):
neuron_name = neuron_synapse_pair["neuron"]
synapse_name = neuron_synapse_pair["synapse"]
diff --git a/pynestml/transformers/transformer.py b/pynestml/transformers/transformer.py
index 7144a9bec..06dbb37c5 100644
--- a/pynestml/transformers/transformer.py
+++ b/pynestml/transformers/transformer.py
@@ -39,4 +39,5 @@ def __init__(self, options: Optional[Mapping[str, Any]]=None):
@abstractmethod
def transform(self, model: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]:
+ r"""Transform a model or a list of models. Return an updated model or list of models."""
assert False
diff --git a/pynestml/utils/ast_synapse_information_collector.py b/pynestml/utils/ast_synapse_information_collector.py
index f5a6763bc..e44a539a1 100644
--- a/pynestml/utils/ast_synapse_information_collector.py
+++ b/pynestml/utils/ast_synapse_information_collector.py
@@ -207,7 +207,7 @@ def get_basic_kernel_variable_names(self, synapse_inline):
kernel_name = kernel_var.get_name()
spike_input_port = self.input_port_name_to_input_port[spike_var.get_name(
)]
- kernel_variable_name = self.construct_kernel_X_spike_buf_name(
+ kernel_variable_name = self.construct_kernel_spike_buf_name(
kernel_name, spike_input_port, order)
results.append(kernel_variable_name)
@@ -338,12 +338,3 @@ def visit_expression(self, node):
def endvisit_expression(self, node):
self.inside_expression = False
self.current_expression = None
-
- # this method was copied over from ast_transformer
- # in order to avoid a circular dependency
- @staticmethod
- def construct_kernel_X_spike_buf_name(kernel_var_name: str, spike_input_port, order: int, diff_order_symbol="__d"):
- assert type(kernel_var_name) is str
- assert type(order) is int
- assert type(diff_order_symbol) is str
- return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + str(spike_input_port) + diff_order_symbol * order
diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py
index 9a67054a1..a8755bb89 100644
--- a/pynestml/utils/ast_utils.py
+++ b/pynestml/utils/ast_utils.py
@@ -458,6 +458,17 @@ def create_equations_block(cls, model: ASTModel) -> ASTModel:
model.get_body().get_body_elements().append(block)
return model
+ @classmethod
+ def create_on_receive_block(cls, model: ASTModel, block: ASTBlock, input_port_name: str) -> ASTModel:
+ """
+ Creates a single onReceive block in the handed over model.
+ :param model: a single model
+ :return: the modified model
+ """
+ # local import since otherwise circular dependency
+
+ return model
+
@classmethod
def contains_convolve_call(cls, variable: VariableSymbol) -> bool:
"""
@@ -1262,7 +1273,7 @@ def to_ode_toolbox_name(cls, name: str) -> str:
@classmethod
def get_expr_from_kernel_var(cls, kernel: ASTKernel, var_name: str) -> Union[ASTExpression, ASTSimpleExpression]:
- """
+ r"""
Get the expression using the kernel variable
"""
assert isinstance(var_name, str)
@@ -1277,122 +1288,9 @@ def all_convolution_variable_names(cls, model: ASTModel) -> List[str]:
var_names = [var.get_complete_name() for var in vars if "__X__" in var.get_complete_name()]
return var_names
- @classmethod
- def construct_kernel_X_spike_buf_name(cls, kernel_var_name: str, spike_input_port: ASTInputPort, order: int,
- diff_order_symbol="__d"):
- """
- Construct a kernel-buffer name as
-
- For example, if the kernel is
- .. code-block::
- kernel I_kernel = exp(-t / tau_x)
-
- and the input port is
- .. code-block::
- pre_spikes nS <- spike
-
- then the constructed variable will be 'I_kernel__X__pre_pikes'
- """
- assert type(kernel_var_name) is str
- assert type(order) is int
- assert type(diff_order_symbol) is str
-
- if isinstance(spike_input_port, ASTSimpleExpression):
- spike_input_port = spike_input_port.get_variable()
-
- if not isinstance(spike_input_port, str):
- spike_input_port_name = spike_input_port.get_name()
- else:
- spike_input_port_name = spike_input_port
-
- if isinstance(spike_input_port, ASTVariable):
- if spike_input_port.has_vector_parameter():
- spike_input_port_name += "_" + str(cls.get_numeric_vector_size(spike_input_port))
-
- return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + spike_input_port_name + diff_order_symbol * order
-
- @classmethod
- def replace_rhs_variable(cls, expr: ASTExpression, variable_name_to_replace: str, kernel_var: ASTVariable,
- spike_buf: ASTInputPort):
- """
- Replace variable names in definitions of kernel dynamics
- :param expr: expression in which to replace the variables
- :param variable_name_to_replace: variable name to replace in the expression
- :param kernel_var: kernel variable instance
- :param spike_buf: input port instance
- :return:
- """
- def replace_kernel_var(node):
- if type(node) is ASTSimpleExpression \
- and node.is_variable() \
- and node.get_variable().get_name() == variable_name_to_replace:
- var_order = node.get_variable().get_differential_order()
- new_variable_name = cls.construct_kernel_X_spike_buf_name(
- kernel_var.get_name(), spike_buf, var_order - 1, diff_order_symbol="'")
- new_variable = ASTVariable(new_variable_name, var_order)
- new_variable.set_source_position(node.get_variable().get_source_position())
- node.set_variable(new_variable)
-
- expr.accept(ASTHigherOrderVisitor(visit_funcs=replace_kernel_var))
-
- @classmethod
- def replace_rhs_variables(cls, expr: ASTExpression, kernel_buffers: Mapping[ASTKernel, ASTInputPort]):
- """
- Replace variable names in definitions of kernel dynamics.
-
- Say that the kernel is
-
- .. code-block::
-
- G = -G / tau
-
- Its variable symbol might be replaced by "G__X__spikesEx":
-
- .. code-block::
-
- G__X__spikesEx = -G / tau
-
- This function updates the right-hand side of `expr` so that it would also read (in this example):
-
- .. code-block::
-
- G__X__spikesEx = -G__X__spikesEx / tau
-
- These equations will later on be fed to ode-toolbox, so we use the symbol "'" to indicate differential order.
-
- Note that for kernels/systems of ODE of dimension > 1, all variable orders and all variables for this kernel will already be present in `kernel_buffers`.
- """
- for kernel, spike_buf in kernel_buffers:
- for kernel_var in kernel.get_variables():
- variable_name_to_replace = kernel_var.get_name()
- cls.replace_rhs_variable(expr, variable_name_to_replace=variable_name_to_replace,
- kernel_var=kernel_var, spike_buf=spike_buf)
-
- @classmethod
- def is_delta_kernel(cls, kernel: ASTKernel) -> bool:
- """
- Catches definition of kernel, or reference (function call or variable name) of a delta kernel function.
- """
- if type(kernel) is ASTKernel:
- if not len(kernel.get_variables()) == 1:
- # delta kernel not allowed if more than one variable is defined in this kernel
- return False
- expr = kernel.get_expressions()[0]
- else:
- expr = kernel
-
- rhs_is_delta_kernel = type(expr) is ASTSimpleExpression \
- and expr.is_function_call() \
- and expr.get_function_call().get_scope().resolve_to_symbol(expr.get_function_call().get_name(), SymbolKind.FUNCTION).equals(PredefinedFunctions.name2function["delta"])
- rhs_is_multiplied_delta_kernel = type(expr) is ASTExpression \
- and type(expr.get_rhs()) is ASTSimpleExpression \
- and expr.get_rhs().is_function_call() \
- and expr.get_rhs().get_function_call().get_scope().resolve_to_symbol(expr.get_rhs().get_function_call().get_name(), SymbolKind.FUNCTION).equals(PredefinedFunctions.name2function["delta"])
- return rhs_is_delta_kernel or rhs_is_multiplied_delta_kernel
-
@classmethod
def get_input_port_by_name(cls, input_blocks: List[ASTInputBlock], port_name: str) -> ASTInputPort:
- """
+ r"""
Get the input port given the port name
:param input_block: block to be searched
:param port_name: name of the input port
@@ -1407,8 +1305,10 @@ def get_input_port_by_name(cls, input_blocks: List[ASTInputBlock], port_name: st
port_name, port_index = port_name.split("_")
assert int(port_index) > 0
assert int(port_index) <= size_parameter
+
if input_port.name == port_name:
return input_port
+
return None
@classmethod
@@ -1424,6 +1324,7 @@ def get_parameter_by_name(cls, node: ASTModel, var_name: str) -> ASTDeclaration:
for var in decl.get_variables():
if var.get_name() == var_name:
return decl
+
return None
@classmethod
@@ -1629,37 +1530,6 @@ def recursive_necessary_variables_search(cls, vars: List[str], model: ASTModel)
return list(set(vars_used))
- @classmethod
- def remove_initial_values_for_kernels(cls, model: ASTModel) -> None:
- """
- Remove initial values for original declarations (e.g. g_in, g_in', V_m); these might conflict with the initial value expressions returned from ODE-toolbox.
- """
- symbols_to_remove = set()
- for equations_block in model.get_equations_blocks():
- for kernel in equations_block.get_kernels():
- for kernel_var in kernel.get_variables():
- kernel_var_order = kernel_var.get_differential_order()
- for order in range(kernel_var_order):
- symbol_name = kernel_var.get_name() + "'" * order
- symbols_to_remove.add(symbol_name)
-
- decl_to_remove = set()
- for symbol_name in symbols_to_remove:
- for state_block in model.get_state_blocks():
- for decl in state_block.get_declarations():
- if len(decl.get_variables()) == 1:
- if decl.get_variables()[0].get_name() == symbol_name:
- decl_to_remove.add(decl)
- else:
- for var in decl.get_variables():
- if var.get_name() == symbol_name:
- decl.variables.remove(var)
-
- for decl in decl_to_remove:
- for state_block in model.get_state_blocks():
- if decl in state_block.get_declarations():
- state_block.get_declarations().remove(decl)
-
@classmethod
def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict]) -> None:
"""
@@ -1888,53 +1758,12 @@ def _visit(self, node):
return visitor.calls
@classmethod
- def create_initial_values_for_kernels(cls, model: ASTModel, solver_dicts: List[Dict], kernels: List[ASTKernel]) -> None:
- r"""
- Add the variables used in kernels from the ode-toolbox result dictionary as ODEs in NESTML AST
- """
- for solver_dict in solver_dicts:
- if solver_dict is None:
- continue
-
- for var_name in solver_dict["initial_values"].keys():
- if cls.variable_in_kernels(var_name, kernels):
- # original initial value expressions should have been removed to make place for ode-toolbox results
- assert not cls.declaration_in_state_block(model, var_name)
-
- for solver_dict in solver_dicts:
- if solver_dict is None:
- continue
-
- for var_name, expr in solver_dict["initial_values"].items():
- # overwrite is allowed because initial values might be repeated between numeric and analytic solver
- if cls.variable_in_kernels(var_name, kernels):
- spike_in_port_name = var_name.split("__X__")[1]
- spike_in_port_name = spike_in_port_name.split("__d")[0]
- spike_in_port = ASTUtils.get_input_port_by_name(model.get_input_blocks(), spike_in_port_name)
- type_str = "real"
- if spike_in_port:
- differential_order: int = len(re.findall("__d", var_name))
- if differential_order:
- type_str = "(s**-" + str(differential_order) + ")"
-
- expr = "0 " + type_str # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is 0 (property of the convolution)
- if not cls.declaration_in_state_block(model, var_name):
- cls.add_declaration_to_state_block(model, var_name, expr, type_str)
-
- @classmethod
- def transform_ode_and_kernels_to_json(cls, model: ASTModel, parameters_blocks: Sequence[ASTBlockWithVariables],
- kernel_buffers: Mapping[ASTKernel, ASTInputPort], printer: ASTPrinter) -> Dict:
+ def transform_odes_to_json(cls,
+ model: ASTModel,
+ parameters_blocks: Sequence[ASTBlockWithVariables],
+ printer: ASTPrinter) -> Dict:
"""
- Converts AST node to a JSON representation suitable for passing to ode-toolbox.
-
- Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements
-
- .. code-block::
-
- convolve(G, exc_spikes)
- convolve(G, inh_spikes)
-
- then `kernel_buffers` will contain the pairs `(G, exc_spikes)` and `(G, inh_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__exc_spikes` and `G__X__inh_spikes`.
+ Converts to a JSON representation suitable for passing to ode-toolbox.
"""
odetoolbox_indict = {}
@@ -1959,37 +1788,6 @@ def transform_ode_and_kernels_to_json(cls, model: ASTModel, parameters_blocks: S
odetoolbox_indict["dynamics"].append(entry)
- # write a copy for each (kernel, spike buffer) combination
- for kernel, spike_input_port in kernel_buffers:
-
- if cls.is_delta_kernel(kernel):
- # delta function -- skip passing this to ode-toolbox
- continue
-
- for kernel_var in kernel.get_variables():
- expr = cls.get_expr_from_kernel_var(kernel, kernel_var.get_complete_name())
- kernel_order = kernel_var.get_differential_order()
- kernel_X_spike_buf_name_ticks = cls.construct_kernel_X_spike_buf_name(
- kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'")
-
- cls.replace_rhs_variables(expr, kernel_buffers)
-
- entry = {"expression": kernel_X_spike_buf_name_ticks + " = " + str(expr), "initial_values": {}}
-
- # initial values need to be declared for order 1 up to kernel order (e.g. none for kernel function
- # f(t) = ...; 1 for kernel ODE f'(t) = ...; 2 for f''(t) = ... and so on)
- for order in range(kernel_order):
- iv_sym_name_ode_toolbox = cls.construct_kernel_X_spike_buf_name(
- kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'")
- symbol_name_ = kernel_var.get_name() + "'" * order
- symbol = equations_block.get_scope().resolve_to_symbol(symbol_name_, SymbolKind.VARIABLE)
- assert symbol is not None, "Could not find initial value for variable " + symbol_name_
- initial_value_expr = symbol.get_declaring_expression()
- assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_
- entry["initial_values"][iv_sym_name_ode_toolbox] = printer.print(initial_value_expr)
-
- odetoolbox_indict["dynamics"].append(entry)
-
odetoolbox_indict["parameters"] = {}
for parameters_block in parameters_blocks:
for decl in parameters_block.get_declarations():
@@ -2079,52 +1877,6 @@ def log_set_source_position(node):
return definitions
- @classmethod
- def get_delta_factors_(cls, neuron: ASTModel, equations_block: ASTEquationsBlock) -> dict:
- r"""
- For every occurrence of a convolution of the form `x^(n) = a * convolve(kernel, inport) + ...` where `kernel` is a delta function, add the element `(x^(n), inport) --> a` to the set.
- """
- delta_factors = {}
-
- for ode_eq in equations_block.get_ode_equations():
- var = ode_eq.get_lhs()
- expr = ode_eq.get_rhs()
- conv_calls = ASTUtils.get_convolve_function_calls(expr)
- for conv_call in conv_calls:
- assert len(
- conv_call.args) == 2, "convolve() function call should have precisely two arguments: kernel and spike input port"
- kernel = conv_call.args[0]
- if cls.is_delta_kernel(neuron.get_kernel_by_name(kernel.get_variable().get_name())):
- inport = conv_call.args[1].get_variable()
- expr_str = str(expr)
- sympy_expr = sympy.parsing.sympy_parser.parse_expr(expr_str, global_dict=odetoolbox.Shape._sympy_globals)
- sympy_expr = sympy.expand(sympy_expr)
- sympy_conv_expr = sympy.parsing.sympy_parser.parse_expr(str(conv_call), global_dict=odetoolbox.Shape._sympy_globals)
- factor_str = []
- for term in sympy.Add.make_args(sympy_expr):
- if term.find(sympy_conv_expr):
- factor_str.append(str(term.replace(sympy_conv_expr, 1)))
- factor_str = " + ".join(factor_str)
- delta_factors[(var, inport)] = factor_str
-
- return delta_factors
-
- @classmethod
- def remove_kernel_definitions_from_equations_block(cls, model: ASTModel) -> ASTDeclaration:
- r"""
- Removes all kernels in equations blocks.
- """
- for equations_block in model.get_equations_blocks():
- decl_to_remove = set()
- for decl in equations_block.get_declarations():
- if type(decl) is ASTKernel:
- decl_to_remove.add(decl)
-
- for decl in decl_to_remove:
- equations_block.get_declarations().remove(decl)
-
- return decl_to_remove
-
@classmethod
def add_timestep_symbol(cls, model: ASTModel) -> None:
"""
@@ -2137,70 +1889,6 @@ def add_timestep_symbol(cls, model: ASTModel) -> None:
)], "\"__h\" is a reserved name, please do not use variables by this name in your NESTML file"
model.add_to_internals_block(ModelParser.parse_declaration('__h ms = resolution()'), index=0)
- @classmethod
- def generate_kernel_buffers(cls, model: ASTModel, equations_block: Union[ASTEquationsBlock, List[ASTEquationsBlock]]) -> Mapping[ASTKernel, ASTInputPort]:
- """
- For every occurrence of a convolution of the form `convolve(var, spike_buf)`: add the element `(kernel, spike_buf)` to the set, with `kernel` being the kernel that contains variable `var`.
- """
-
- kernel_buffers = set()
- convolve_calls = ASTUtils.get_convolve_function_calls(equations_block)
- for convolve in convolve_calls:
- el = (convolve.get_args()[0], convolve.get_args()[1])
- sym = convolve.get_args()[0].get_scope().resolve_to_symbol(convolve.get_args()[0].get_variable().name, SymbolKind.VARIABLE)
- if sym is None:
- raise Exception("No initial value(s) defined for kernel with variable \""
- + convolve.get_args()[0].get_variable().get_complete_name() + "\"")
- if sym.block_type == BlockType.INPUT:
- # swap the order
- el = (el[1], el[0])
-
- # find the corresponding kernel object
- var = el[0].get_variable()
- assert var is not None
- kernel = model.get_kernel_by_name(var.get_name())
- assert kernel is not None, "In convolution \"convolve(" + str(var.name) + ", " + str(
- el[1]) + ")\": no kernel by name \"" + var.get_name() + "\" found in model."
-
- el = (kernel, el[1])
- kernel_buffers.add(el)
-
- return kernel_buffers
-
- @classmethod
- def replace_convolution_aliasing_inlines(cls, neuron: ASTModel) -> None:
- """
- Replace all occurrences of kernel names (e.g. ``I_dend`` and ``I_dend'`` for a definition involving a second-order kernel ``inline kernel I_dend = convolve(kern_name, spike_buf)``) with the ODE-toolbox generated variable ``kern_name__X__spike_buf``.
- """
- def replace_var(_expr, replace_var_name: str, replace_with_var_name: str):
- if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
- var = _expr.get_variable()
- if var.get_name() == replace_var_name:
- ast_variable = ASTVariable(replace_with_var_name + '__d' * var.get_differential_order(),
- differential_order=0)
- ast_variable.set_source_position(var.get_source_position())
- _expr.set_variable(ast_variable)
-
- elif isinstance(_expr, ASTVariable):
- var = _expr
- if var.get_name() == replace_var_name:
- var.set_name(replace_with_var_name + '__d' * var.get_differential_order())
- var.set_differential_order(0)
-
- for equation_block in neuron.get_equations_blocks():
- for decl in equation_block.get_declarations():
- if isinstance(decl, ASTInlineExpression):
- expr = decl.get_expression()
- if isinstance(expr, ASTExpression):
- expr = expr.get_lhs()
-
- if isinstance(expr, ASTSimpleExpression) \
- and '__X__' in str(expr) \
- and expr.get_variable():
- replace_with_var_name = expr.get_variable().get_name()
- neuron.accept(ASTHigherOrderVisitor(lambda x: replace_var(
- x, decl.get_variable_name(), replace_with_var_name)))
-
@classmethod
def replace_variable_names_in_expressions(cls, model: ASTModel, solver_dicts: List[dict]) -> None:
"""
@@ -2229,42 +1917,6 @@ def func(x):
model.accept(ASTHigherOrderVisitor(func))
- @classmethod
- def replace_convolve_calls_with_buffers_(cls, model: ASTModel, equations_block: ASTEquationsBlock) -> None:
- r"""
- Replace all occurrences of `convolve(kernel[']^n, spike_input_port)` with the corresponding buffer variable, e.g. `g_E__X__spikes_exc[__d]^n` for a kernel named `g_E` and a spike input port named `spikes_exc`.
- """
-
- def replace_function_call_through_var(_expr=None):
- if _expr.is_function_call() and _expr.get_function_call().get_name() == "convolve":
- convolve = _expr.get_function_call()
- el = (convolve.get_args()[0], convolve.get_args()[1])
- sym = convolve.get_args()[0].get_scope().resolve_to_symbol(
- convolve.get_args()[0].get_variable().name, SymbolKind.VARIABLE)
- if sym.block_type == BlockType.INPUT:
- # swap elements
- el = (el[1], el[0])
- var = el[0].get_variable()
- spike_input_port = el[1].get_variable()
- kernel = model.get_kernel_by_name(var.get_name())
-
- _expr.set_function_call(None)
- buffer_var = cls.construct_kernel_X_spike_buf_name(
- var.get_name(), spike_input_port, var.get_differential_order() - 1)
- if cls.is_delta_kernel(kernel):
- # delta kernels are treated separately, and should be kept out of the dynamics (computing derivates etc.) --> set to zero
- _expr.set_variable(None)
- _expr.set_numeric_literal(0)
- else:
- ast_variable = ASTVariable(buffer_var)
- ast_variable.set_source_position(_expr.get_source_position())
- _expr.set_variable(ast_variable)
-
- def func(x):
- return replace_function_call_through_var(x) if isinstance(x, ASTSimpleExpression) else True
-
- equations_block.accept(ASTHigherOrderVisitor(func))
-
@classmethod
def update_blocktype_for_common_parameters(cls, node):
r"""Change the BlockType for all homogeneous parameters to BlockType.COMMON_PARAMETER"""
@@ -2313,7 +1965,7 @@ def find_model_by_name(cls, model_name: str, models: Iterable[ASTModel]) -> Opti
@classmethod
def get_convolve_function_calls(cls, nodes: Union[ASTNode, List[ASTNode]]):
"""
- Returns all sum function calls in the handed over meta_model node or one of its children.
+ Returns all convolve function calls in the handed over node.
:param nodes: a single or list of AST nodes.
"""
if isinstance(nodes, ASTNode):
@@ -2326,13 +1978,12 @@ def get_convolve_function_calls(cls, nodes: Union[ASTNode, List[ASTNode]]):
return function_calls
@classmethod
- def contains_convolve_function_call(cls, ast: ASTNode) -> bool:
+ def contains_convolve_function_call(cls, node: ASTNode) -> bool:
"""
- Indicates whether _ast or one of its child nodes contains a sum call.
- :param ast: a single meta_model
- :return: True if sum is contained, otherwise False.
+ Indicates whether the node contains any convolve function call.
+ :param node: a single AST node
+ :return: True if a convolve call is contained, otherwise False.
"""
- return len(cls.get_function_calls(ast, PredefinedFunctions.CONVOLVE)) > 0
+ return len(cls.get_function_calls(node, PredefinedFunctions.CONVOLVE)) > 0
@classmethod
def get_function_calls(cls, ast_node: ASTNode, function_list: List[str]) -> List[ASTFunctionCall]:
@@ -2457,13 +2108,6 @@ def visit_variable(self, node):
for expr in numeric_update_expressions.values():
expr.accept(visitor)
- for update_expr_list in neuron.spike_updates.values():
- for update_expr in update_expr_list:
- update_expr.accept(visitor)
-
- for update_expr in neuron.post_spike_updates.values():
- update_expr.accept(visitor)
-
for node in neuron.equations_with_delay_vars + neuron.equations_with_vector_vars:
node.accept(visitor)
diff --git a/pynestml/utils/synapse_processing.py b/pynestml/utils/synapse_processing.py
index 464abd269..298286dd9 100644
--- a/pynestml/utils/synapse_processing.py
+++ b/pynestml/utils/synapse_processing.py
@@ -73,7 +73,7 @@ def collect_additional_base_infos(cls, neuron, syns_info):
for kernel_var, spikes_var in kernel_arg_pairs:
kernel_name = kernel_var.get_name()
spikes_name = spikes_var.get_name()
- convolution_name = info_collector.construct_kernel_X_spike_buf_name(
+ convolution_name = info_collector.construct_kernel_spike_buf_name(
kernel_name, spikes_name, 0)
syns_info[synapse_name]["convolutions"][convolution_name] = {
"kernel": {
@@ -196,7 +196,7 @@ def transform_ode_and_kernels_to_json(
expr = ASTUtils.get_expr_from_kernel_var(
kernel, kernel_var.get_complete_name())
kernel_order = kernel_var.get_differential_order()
- kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_X_spike_buf_name(
+ kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_spike_buf_name(
kernel_var.get_name(), spike_input_port.get_name(), kernel_order, diff_order_symbol="'")
ASTUtils.replace_rhs_variables(expr, kernel_buffers)
@@ -207,7 +207,7 @@ def transform_ode_and_kernels_to_json(
# order (e.g. none for kernel function f(t) = ...; 1 for kernel
# ODE f'(t) = ...; 2 for f''(t) = ... and so on)
for order in range(kernel_order):
- iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_X_spike_buf_name(
+ iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_spike_buf_name(
kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'")
symbol_name_ = kernel_var.get_name() + "'" * order
symbol = equations_block.get_scope().resolve_to_symbol(
diff --git a/tests/nest_tests/stdp_nn_restr_symm_test.py b/tests/nest_tests/stdp_nn_restr_symm_test.py
index a89ad4c3e..b1d13ce10 100644
--- a/tests/nest_tests/stdp_nn_restr_symm_test.py
+++ b/tests/nest_tests/stdp_nn_restr_symm_test.py
@@ -44,17 +44,17 @@
class NestSTDPNNRestrSymmSynapseTest(unittest.TestCase):
- neuron_model_name = "iaf_psc_exp_neuron_nestml__with_stdp_nn_restr_symm_synapse_nestml"
- ref_neuron_model_name = "iaf_psc_exp_neuron_nestml_non_jit"
+ neuron_model_name = "iaf_psc_alpha_neuron_nestml__with_stdp_nn_restr_symm_synapse_nestml"
+ ref_neuron_model_name = "iaf_psc_alpha_neuron_nestml_non_jit"
- synapse_model_name = "stdp_nn_restr_symm_synapse_nestml__with_iaf_psc_exp_neuron_nestml"
+ synapse_model_name = "stdp_nn_restr_symm_synapse_nestml__with_iaf_psc_alpha_neuron_nestml"
ref_synapse_model_name = "stdp_nn_restr_synapse"
def setUp(self):
r"""Generate the neuron model code"""
# generate the "jit" model (co-generated neuron and synapse), that does not rely on ArchivingNode
- files = [os.path.join("models", "neurons", "iaf_psc_exp_neuron.nestml"),
+ files = [os.path.join("models", "neurons", "iaf_psc_alpha_neuron.nestml"),
os.path.join("models", "synapses", "stdp_nn_restr_symm_synapse.nestml")]
input_path = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join(
os.pardir, os.pardir, s))) for s in files]
@@ -65,13 +65,13 @@ def setUp(self):
suffix="_nestml",
codegen_opts={"neuron_parent_class": "StructuralPlasticityNode",
"neuron_parent_class_include": "structural_plasticity_node.h",
- "neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron",
+ "neuron_synapse_pairs": [{"neuron": "iaf_psc_alpha_neuron",
"synapse": "stdp_nn_restr_symm_synapse",
"post_ports": ["post_spikes"]}]})
# generate the "non-jit" model, that relies on ArchivingNode
generate_nest_target(input_path=os.path.realpath(os.path.join(os.path.dirname(__file__),
- os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_exp_neuron.nestml"))),
+ os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_alpha_neuron.nestml"))),
target_path="/tmp/nestml-non-jit",
logging_level="INFO",
module_name="nestml_non_jit_module",