Skip to content

Commit

Permalink
transform kernels and convolutions using a transformer before code ge…
Browse files Browse the repository at this point in the history
…neration
  • Loading branch information
C.A.P. Linssen committed May 9, 2024
1 parent 026e34c commit abba1b4
Show file tree
Hide file tree
Showing 19 changed files with 622 additions and 847 deletions.
1 change: 0 additions & 1 deletion pynestml/codegeneration/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def _setup_template_env(self, template_files: List[str], templates_root_dir: str
# Environment for neuron templates
env = Environment(loader=FileSystemLoader(_template_dirs))
env.globals["raise"] = self.raise_helper
env.globals["is_delta_kernel"] = ASTUtils.is_delta_kernel

# Load all the templates
_templates = list()
Expand Down
162 changes: 14 additions & 148 deletions pynestml/codegeneration/nest_code_generator.py

Large diffs are not rendered by default.

122 changes: 13 additions & 109 deletions pynestml/codegeneration/nest_compartmental_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,22 +280,16 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None:

def create_ode_indict(self,
neuron: ASTModel,
parameters_block: ASTBlockWithVariables,
kernel_buffers: Mapping[ASTKernel,
ASTInputPort]):
odetoolbox_indict = self.transform_ode_and_kernels_to_json(
neuron, parameters_block, kernel_buffers)
parameters_block: ASTBlockWithVariables):
odetoolbox_indict = self.transform_ode_and_kernels_to_json(neuron, parameters_block)
odetoolbox_indict["options"] = {}
odetoolbox_indict["options"]["output_timestep_symbol"] = "__h"
return odetoolbox_indict

def ode_solve_analytically(self,
neuron: ASTModel,
parameters_block: ASTBlockWithVariables,
kernel_buffers: Mapping[ASTKernel,
ASTInputPort]):
odetoolbox_indict = self.create_ode_indict(
neuron, parameters_block, kernel_buffers)
parameters_block: ASTBlockWithVariables):
odetoolbox_indict = self.create_ode_indict(neuron, parameters_block)

full_solver_result = analysis(
odetoolbox_indict,
Expand All @@ -314,8 +308,7 @@ def ode_solve_analytically(self,

return full_solver_result, analytic_solver

def ode_toolbox_analysis(self, neuron: ASTModel,
kernel_buffers: Mapping[ASTKernel, ASTInputPort]):
def ode_toolbox_analysis(self, neuron: ASTModel):
"""
Prepare data for ODE-toolbox input format, invoke ODE-toolbox analysis via its API, and return the output.
"""
Expand All @@ -324,15 +317,13 @@ def ode_toolbox_analysis(self, neuron: ASTModel,

equations_block = neuron.get_equations_blocks()[0]

if len(equations_block.get_kernels()) == 0 and len(
equations_block.get_ode_equations()) == 0:
if len(equations_block.get_ode_equations()) == 0:
# no equations defined -> no changes to the neuron
return None, None

parameters_block = neuron.get_parameters_blocks()[0]

solver_result, analytic_solver = self.ode_solve_analytically(
neuron, parameters_block, kernel_buffers)
solver_result, analytic_solver = self.ode_solve_analytically(neuron, parameters_block)

# if numeric solver is required, generate a stepping function that
# includes each state variable
Expand All @@ -341,8 +332,7 @@ def ode_toolbox_analysis(self, neuron: ASTModel,
x for x in solver_result if x["solver"].startswith("numeric")]

if numeric_solvers:
odetoolbox_indict = self.create_ode_indict(
neuron, parameters_block, kernel_buffers)
odetoolbox_indict = self.create_ode_indict(neuron, parameters_block)
solver_result = analysis(
odetoolbox_indict,
disable_stiffness_check=True,
Expand Down Expand Up @@ -417,24 +407,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:

return []

# goes through all convolve() inside ode's from equations block
# if they have delta kernels, use sympy to expand the expression, then
# find the convolve calls and replace them with constant value 1
# then return every subexpression that had that convolve() replaced
delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block)

# goes through all convolve() inside equations block
# extracts what kernel is paired with what spike buffer
# returns pairs (kernel, spike_buffer)
kernel_buffers = ASTUtils.generate_kernel_buffers(
neuron, equations_block)

# replace convolve(g_E, spikes_exc) with g_E__X__spikes_exc[__d]
# done by searching for every ASTSimpleExpression inside equations_block
# which is a convolve call and substituting that call with
# newly created ASTVariable kernel__X__spike_buffer
ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block)

# substitute inline expressions with each other
# such that no inline expression references another inline expression
ASTUtils.make_inline_expressions_self_contained(
Expand All @@ -450,14 +422,13 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
# "update_expressions" key in those solvers contains a mapping
# {expression1: update_expression1, expression2: update_expression2}

analytic_solver, numeric_solver = self.ode_toolbox_analysis(
neuron, kernel_buffers)
analytic_solver, numeric_solver = self.ode_toolbox_analysis(neuron)

"""
# separate analytic solutions by kernel
# this is is needed for the synaptic case
self.kernel_name_to_analytic_solver[neuron.get_name(
)] = self.ode_toolbox_anaysis_cm_syns(neuron, kernel_buffers)
)] = self.ode_toolbox_anaysis_cm_syns(neuron)
"""

self.analytic_solver[neuron.get_name()] = analytic_solver
Expand All @@ -472,12 +443,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
# by odetoolbox, higher order variables don't get deleted here
ASTUtils.remove_initial_values_for_kernels(neuron)

# delete all kernels as they are all converted into buffers
# and corresponding update formulas calculated by odetoolbox
# Remember them in a variable though
kernels = ASTUtils.remove_kernel_definitions_from_equations_block(
neuron)

# Every ODE variable (a variable of order > 0) is renamed according to ODE-toolbox conventions
# their initial values are replaced by expressions suggested by ODE-toolbox.
# Differential order can now be set to 0, becase they can directly represent the value of the derivative now.
Expand All @@ -491,22 +456,11 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
# corresponding updates
ASTUtils.remove_ode_definitions_from_equations_block(neuron)

# restore state variables that were referenced by kernels
# and set their initial values by those suggested by ODE-toolbox
ASTUtils.create_initial_values_for_kernels(
neuron, [analytic_solver, numeric_solver], kernels)

# Inside all remaining expressions, translate all remaining variable names
# according to the naming conventions of ODE-toolbox.
ASTUtils.replace_variable_names_in_expressions(
neuron, [analytic_solver, numeric_solver])

# find all inline kernels defined as ASTSimpleExpression
# that have a single kernel convolution aliasing variable ('__X__')
# translate all remaining variable names according to the naming
# conventions of ODE-toolbox
ASTUtils.replace_convolution_aliasing_inlines(neuron)

# add variable __h to internals block
ASTUtils.add_timestep_symbol(neuron)

Expand Down Expand Up @@ -677,13 +631,9 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
expr_ast.accept(ASTSymbolTableVisitor())
namespace["numeric_update_expressions"][sym] = expr_ast

namespace["spike_updates"] = neuron.spike_updates

namespace["recordable_state_variables"] = [
sym for sym in neuron.get_state_symbols() if namespace["declarations"].get_domain_from_type(
sym.get_type_symbol()) == "double" and sym.is_recordable and not ASTUtils.is_delta_kernel(
neuron.get_kernel_by_name(
sym.name))]
sym.get_type_symbol()) == "double" and sym.is_recordable]
namespace["recordable_inline_expressions"] = [
sym for sym in neuron.get_inline_expression_symbols() if namespace["declarations"].get_domain_from_type(
sym.get_type_symbol()) == "double" and sym.is_recordable]
Expand Down Expand Up @@ -807,7 +757,7 @@ def get_spike_update_expressions(
for var_order in range(
ASTUtils.get_kernel_var_order_from_ode_toolbox_result(
kernel_var.get_name(), solver_dicts)):
kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name(
kernel_spike_buf_name = ASTUtils.construct_kernel_spike_buf_name(
kernel_var.get_name(), spike_input_port, var_order)
expr = ASTUtils.get_initial_value_from_ode_toolbox_result(
kernel_spike_buf_name, solver_dicts)
Expand Down Expand Up @@ -849,18 +799,9 @@ def get_spike_update_expressions(
def transform_ode_and_kernels_to_json(
self,
neuron: ASTModel,
parameters_block,
kernel_buffers):
parameters_block):
"""
Converts AST node to a JSON representation suitable for passing to ode-toolbox.
Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements
convolve(G, ex_spikes)
convolve(G, in_spikes)
then `kernel_buffers` will contain the pairs `(G, ex_spikes)` and `(G, in_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__ex_spikes` and `G__X__in_spikes`.
:param parameters_block: ASTBlockWithVariables
:return: Dict
"""
Expand Down Expand Up @@ -890,43 +831,6 @@ def transform_ode_and_kernels_to_json(
iv_symbol_name)] = expr
odetoolbox_indict["dynamics"].append(entry)

# write a copy for each (kernel, spike buffer) combination
for kernel, spike_input_port in kernel_buffers:

if ASTUtils.is_delta_kernel(kernel):
# delta function -- skip passing this to ode-toolbox
continue

for kernel_var in kernel.get_variables():
expr = ASTUtils.get_expr_from_kernel_var(
kernel, kernel_var.get_complete_name())
kernel_order = kernel_var.get_differential_order()
kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_X_spike_buf_name(
kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'")

ASTUtils.replace_rhs_variables(expr, kernel_buffers)

entry = {}
entry["expression"] = kernel_X_spike_buf_name_ticks + " = " + str(expr)

# initial values need to be declared for order 1 up to kernel
# order (e.g. none for kernel function f(t) = ...; 1 for kernel
# ODE f'(t) = ...; 2 for f''(t) = ... and so on)
entry["initial_values"] = {}
for order in range(kernel_order):
iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_X_spike_buf_name(
kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'")
symbol_name_ = kernel_var.get_name() + "'" * order
symbol = equations_block.get_scope().resolve_to_symbol(
symbol_name_, SymbolKind.VARIABLE)
assert symbol is not None, "Could not find initial value for variable " + symbol_name_
initial_value_expr = symbol.get_declaring_expression()
assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_
entry["initial_values"][iv_sym_name_ode_toolbox] = self._ode_toolbox_printer.print(
initial_value_expr)

odetoolbox_indict["dynamics"].append(entry)

odetoolbox_indict["parameters"] = {}
if parameters_block is not None:
for decl in parameters_block.get_declarations():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,8 @@ std::vector< std::tuple< int, int > > {{neuronName}}::rport_to_nestml_buffer_idx

// copy state struct S_
{%- for init in neuron.get_state_symbols() %}
{%- if not is_delta_kernel(neuron.get_kernel_by_name(init.name)) %}
{%- set node = utils.get_state_variable_by_name(astnode, init.get_symbol_name()) %}
{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }} = __n.{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }};
{%- endif %}
{%- endfor %}

// copy internals V_
Expand Down Expand Up @@ -786,14 +784,6 @@ const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}_
{%- endfor %}
{%- endif %}


/**
* spike updates due to convolutions
**/
{% filter indent(4) %}
{%- include "directives_cpp/ApplySpikesFromBuffers.jinja2" %}
{%- endfilter %}

/**
* Begin NESTML generated code for the onCondition block(s)
**/
Expand Down Expand Up @@ -1149,13 +1139,9 @@ void
{%- endfor %}

/**
* print updates due to convolutions
* push back spike history
**/

{%- for _, spike_update in post_spike_updates.items() %}
{{ printer.print(utils.get_variable_by_name(astnode, spike_update.get_variable().get_complete_name())) }} += 1.;
{%- endfor %}

last_spike_ = t_sp_ms;
history_.push_back( histentry__{{neuronName}}( last_spike_
{%- for var in purely_numeric_state_variables_moved|sort %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,11 @@ public:
// Getters/setters for state block
// -------------------------------------------------------------------------

{% filter indent(2, True) -%}
{% filter indent(2, True) -%}
{%- for variable_symbol in neuron.get_state_symbols() %}
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
{% endif %}
{% endfor %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %}
{% endfor %}
{%- endfilter %}
{%- endif %}

Expand Down Expand Up @@ -962,22 +960,20 @@ inline nest_port_t {{neuronName}}::handles_test_event(nest::DataLoggingRequest&
inline void {{neuronName}}::get_status(DictionaryDatum &__d) const
{
// parameters
{%- for variable_symbol in neuron.get_parameter_symbols() %}
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- filter indent(2) %}
{%- filter indent(2) %}
{%- for variable_symbol in neuron.get_parameter_symbols() %}
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives_cpp/WriteInDictionary.jinja2" %}
{%- endfilter %}
{%- endfor %}
{%- endfor %}
{%- endfilter %}

// initial values for state variables in ODE or kernel
{%- for variable_symbol in neuron.get_state_symbols() %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
{%- filter indent(2) %}
{%- include "directives_cpp/WriteInDictionary.jinja2" %}
{%- endfilter %}
{%- endif -%}
{%- endfor %}
{%- filter indent(2) %}
{%- for variable_symbol in neuron.get_state_symbols() %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives_cpp/WriteInDictionary.jinja2" %}
{%- endfor %}
{%- endfilter %}

{{neuron_parent_class}}::get_status( __d );

Expand Down Expand Up @@ -1023,11 +1019,9 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d)
// initial values for state variables in ODE or kernel
{%- for variable_symbol in neuron.get_state_symbols() %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
{%- filter indent(2) %}
{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %}
{%- endfilter %}
{%- endif %}
{%- filter indent(2) %}
{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %}
{%- endfilter %}
{%- endfor %}

// We now know that (ptmp, stmp) are consistent. We do not
Expand All @@ -1046,11 +1040,9 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d)

{%- for variable_symbol in neuron.get_state_symbols() -%}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %}
{%- filter indent(2) %}
{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %}
{%- endfilter %}
{%- endif %}
{%- filter indent(2) %}
{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %}
{%- endfilter %}
{%- endfor %}

{% for invariant in neuron.get_parameter_invariants() %}
Expand Down
Loading

0 comments on commit abba1b4

Please sign in to comment.