Skip to content

Commit

Permalink
fix special handling of NEST delay and weight variables in synapse
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Apr 29, 2024
1 parent 267bbf1 commit abe6577
Show file tree
Hide file tree
Showing 29 changed files with 107 additions and 53 deletions.
2 changes: 1 addition & 1 deletion pynestml/cocos/co_co_all_variables_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False):
code, message = Messages.get_variable_not_defined(var.get_complete_name())
Logger.log_message(code=code, message=message, error_position=node.get_source_position(),
log_level=LoggingLevel.ERROR, node=node)
raise Exception("Error(s) occurred during code generation")
continue

# check if it is part of an invariant
# if it is the case, there is no "recursive" declaration
Expand Down
9 changes: 6 additions & 3 deletions pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]:
# special case for NEST delay variable (state or parameter)

ASTUtils.update_blocktype_for_common_parameters(synapse)
assert synapse_name_stripped in self.get_option("delay_variable").keys(), "Please specify a delay variable for synapse '" + synapse_name_stripped + "' in the code generator options"
assert ASTUtils.get_variable_by_name(synapse, self.get_option("delay_variable")[synapse_name_stripped]), "Delay variable '" + self.get_option("delay_variable")[synapse_name_stripped] + "' not found in synapse '" + synapse_name_stripped + "'"
NESTCodeGeneratorUtils.set_nest_alternate_name(synapse, {ASTUtils.get_variable_by_name(synapse, self.get_option("delay_variable")[synapse_name_stripped]).get_name(): "get_delay()"})

Expand Down Expand Up @@ -577,9 +578,11 @@ def _get_synapse_model_namespace(self, synapse: ASTModel) -> Dict:
assert synapse_name_stripped in self.get_option("delay_variable").keys() and ASTUtils.get_variable_by_name(synapse, self.get_option("delay_variable")[synapse_name_stripped]), "For synapse '" + synapse_name_stripped + "', a delay variable or parameter has to be specified for the NEST target; see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay"
namespace["nest_codegen_opt_delay_variable"] = self.get_option("delay_variable")[synapse_name_stripped]

# special case for NEST weigth variable (state or parameter)
assert synapse_name_stripped in self.get_option("weight_variable").keys() and ASTUtils.get_variable_by_name(synapse, self.get_option("weight_variable")[synapse_name_stripped]), "For synapse '" + synapse_name_stripped + "', a weight variable or parameter has to be specified for the NEST target; see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight"
namespace["nest_codegen_opt_weight_variable"] = self.get_option("weight_variable")[synapse_name_stripped]
# special case for NEST weight variable (state or parameter)
if synapse_name_stripped in self.get_option("weight_variable").keys() and ASTUtils.get_variable_by_name(synapse, self.get_option("weight_variable")[synapse_name_stripped]):
namespace["nest_codegen_opt_weight_variable"] = self.get_option("weight_variable")[synapse_name_stripped]
else:
namespace["nest_codegen_opt_weight_variable"] = ""

return namespace

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,6 @@ public:
{%- for parameter in synapse.get_parameter_symbols() %}
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in parameter.get_decorators() %}
{%- if isHomogeneous %}
{%- set namespaceName = parameter.get_namespace_decorator("nest") %}
{%- if namespaceName == '' %}
{{ raise('nest::names decorator is required for parameter "%s" when used in a common properties class' % printer.print(utils.get_parameter_variable_by_name(astnode, parameter.get_symbol_name()))) }}
{%- endif %}
{%- set variable_symbol = parameter %}
{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives_cpp/CommonPropertiesDictionaryWriter.jinja2" %}
Expand Down Expand Up @@ -372,7 +368,7 @@ private:
Parameters_ P_; //!< Free parameters.
State_ S_; //!< Dynamic state.
Variables_ V_; //!< Internal Variables
{%- if synapse.get_state_symbols()|length > 0 %}
{%- if synapse.get_state_symbols()|length > 0 or synapse.get_parameter_symbols()|length > 0 %}
// -------------------------------------------------------------------------
// Getters/setters for parameters and state variables
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -596,23 +592,34 @@ public:
{%- endif %} }
};
{%- if nest_codegen_opt_weight_variable | length > 0 and nest_codegen_opt_weight_variable != "weight" %}
{%- set variable = utils.get_variable_by_name(astnode, nest_codegen_opt_weight_variable) %}
{%- set variable_symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE) %}
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %}

/**
* special case for weights in NEST: only in case a NESTML state variable was specified in code generation options as ``weight_variable``
**/
inline void set_weight(double w)
{
{%- set variable = utils.get_variable_by_name(astnode, nest_codegen_opt_weight_variable) %}
{%- if isHomogeneous %}
throw BadProperty(
"Setting of individual weights is not possible! The common weights can "
"be changed via "
"CopyModel()." );
{%- else %}
{{ printer.print(variable) }} = w;
{%- endif %}
}

{%- if not isHomogeneous %}
/**
* special case for weights in NEST: only in case a NESTML state variable was specified in code generation options as ``weight_variable``
**/
inline double get_weight() const
{
return {{ printer.print(variable) }};
}
{%- endif %}
{%- endif %}

{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4")) %}
Expand Down Expand Up @@ -990,7 +997,7 @@ void
{{synapseName}}< targetidentifierT >::set_status( const DictionaryDatum& __d,
ConnectorModel& cm )
{
{%- if nest_codegen_opt_weight_variable != "weight" %}
{%- if nest_codegen_opt_weight_variable|length > 0 and nest_codegen_opt_weight_variable != "weight" %}
if (__d->known(nest::{{ synapseName }}_names::_{{ nest_codegen_opt_weight_variable }}) and __d->known(nest::names::weight))
{
throw BadProperty( "To prevent inconsistencies, please set either 'weight' or '{{ nest_codegen_opt_weight_variable }}' variable; not both at the same time." );
Expand Down Expand Up @@ -1047,7 +1054,6 @@ if (__d->known(nest::names::weight))
{%- filter indent(2,True) %}
{%- for variable_symbol in synapse.get_state_symbols() + synapse.get_parameter_symbols() %}
{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- set namespaceName = variable_symbol.get_namespace_decorator('nest') %}
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %}
{%- if not isHomogeneous and not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.name)) %}
{%- if variable.get_name() == nest_codegen_opt_delay_variable %}
Expand Down Expand Up @@ -1112,12 +1118,13 @@ template < typename targetidentifierT >
{
const double __resolution = nest::Time::get_resolution().get_ms(); // do not remove, this is necessary for the resolution() function

// initial values for parameters
{%- filter indent(2, True) %}
{%- for variable_symbol in synapse.get_parameter_symbols() %}
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %}
{%- if not isHomogeneous %}
{%- if variable.get_name() != nest_codegen_opt_weight_variable and variable.get_name() != nest_codegen_opt_delay_variable %}
{%- if variable.get_name() != nest_codegen_opt_delay_variable %}
{%- include "directives_cpp/MemberInitialization.jinja2" %}
{%- endif %}
{%- endif %}
Expand Down Expand Up @@ -1175,9 +1182,12 @@ template < typename targetidentifierT >
// special treatment of NEST delay
set_delay(rhs.get_delay());
{%- if nest_codegen_opt_weight_variable | length > 0 %}

{%- set variable_symbol = synapse.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE) %}
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %}
{%- if not isHomogeneous %}
// special treatment of NEST weight
set_weight(rhs.get_weight());
{%- endif %}
{%- endif %}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
{%- if variable_symbol.has_vector_parameter() %}
{{ raise('Vector parameters not supported in common properties dictionary.') }}
{%- endif %}
updateValue< {{ declarations.print_variable_type(variable_symbol) }} >(d, names::{{namespaceName}}, this->{{ printer_no_origin.print(variable) }} );
updateValue< {{ declarations.print_variable_type(variable_symbol) }} >(d, nest::{{ synapseName }}_names::_{{ printer_no_origin.print(variable) }}, this->{{ printer_no_origin.print(variable) }} );

Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
{%- if variable_symbol.has_vector_parameter() %}
{{ raise('Vector parameters not supported in common properties dictionary.') }}
{%- endif %}
def< {{ declarations.print_variable_type(variable_symbol) }} >(d, names::{{ namespaceName }}, this->{{ printer_no_origin.print(variable) }} );
def< {{ declarations.print_variable_type(variable_symbol) }} >(d, nest::{{ synapseName }}_names::_{{ printer_no_origin.print(variable) }}, this->{{ printer_no_origin.print(variable) }} );
8 changes: 0 additions & 8 deletions pynestml/symbols/variable_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,6 @@ def get_decorators(self):
"""
return self.decorators

def get_namespace_decorators(self):
return self.namespace_decorators

def get_namespace_decorator(self, namespace):
if namespace in self.namespace_decorators.keys():
return self.namespace_decorators[namespace]
return ''

def has_vector_parameter(self):
"""
Returns whether this variable symbol has a vector parameter.
Expand Down
4 changes: 2 additions & 2 deletions pynestml/transformers/synapse_post_neuron_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,10 @@ def transform_neuron_synapse_pair_(self, neuron, synapse):
# exclude certain variables from being moved:
# exclude any variable assigned to in any block that is not connected to a postsynaptic port
strictly_synaptic_vars = ["t"] # "seed" this with the predefined variable t
if self.option_exists("delay_variable") and self.get_option("delay_variable")[removesuffix(synapse.get_name(), FrontendConfiguration.suffix)]:
if self.option_exists("delay_variable") and removesuffix(synapse.get_name(), FrontendConfiguration.suffix) in self.get_option("delay_variable").keys() and self.get_option("delay_variable")[removesuffix(synapse.get_name(), FrontendConfiguration.suffix)]:
strictly_synaptic_vars.append(self.get_option("delay_variable")[removesuffix(synapse.get_name(), FrontendConfiguration.suffix)])

if self.option_exists("weight_variable") and self.get_option("weight_variable")[removesuffix(synapse.get_name(), FrontendConfiguration.suffix)]:
if self.option_exists("weight_variable") and removesuffix(synapse.get_name(), FrontendConfiguration.suffix) in self.get_option("weight_variable").keys() and self.get_option("weight_variable")[removesuffix(synapse.get_name(), FrontendConfiguration.suffix)]:
strictly_synaptic_vars.append(self.get_option("weight_variable")[removesuffix(synapse.get_name(), FrontendConfiguration.suffix)])

for input_block in new_synapse.get_input_blocks():
Expand Down
2 changes: 1 addition & 1 deletion tests/cocos_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_invalid_element_not_defined_in_scope(self):
os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
'CoCoVariableNotDefined.nestml'))
self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
LoggingLevel.ERROR)), 4)
LoggingLevel.ERROR)), 5)

def test_valid_element_not_defined_in_scope(self):
Logger.set_logging_level(LoggingLevel.INFO)
Expand Down
6 changes: 4 additions & 2 deletions tests/invalid/CoCoResolutionLegallyUsed.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@ model resolution_legally_used_synapse:
x' = -x / tau + (resolution() / ms**2)

function test(tau ms) real:
w ms = resolution()
return w
z ms = resolution()
return z

parameters:
tau ms = 10 ms
w real = 1 # dummy weight variable
d ms = 1 ms # dummy delay variable

update:
integrate_odes()
2 changes: 2 additions & 0 deletions tests/nest_tests/nest_custom_templates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def test_custom_templates_with_synapse(self):
"neuron_synapse_pairs": [{"neuron": "iaf_psc_delta_neuron",
"synapse": "stdp_triplet_synapse",
"post_ports": ["post_spikes"]}],
"delay_variable": {"stdp_triplet_synapse": "d"},
"weight_variable": {"stdp_triplet_synapse": "w"},
"templates": {
"path": "resources_nest/point_neuron",
"model_templates": {
Expand Down
4 changes: 3 additions & 1 deletion tests/nest_tests/nest_multithreading_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def nestml_generate_target(self) -> None:
"neuron_parent_class_include": "structural_plasticity_node.h",
"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron",
"synapse": "stdp_synapse",
"post_ports": ["post_spikes"]}]})
"post_ports": ["post_spikes"]}],
"delay_variable": {"stdp_synapse": "d"},
"weight_variable": {"stdp_synapse": "w"}})

# Neuron model
generate_nest_target(input_path=neuron_path,
Expand Down
4 changes: 3 additions & 1 deletion tests/nest_tests/nest_resolution_builtin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def setUp(self):
codegen_opts={"neuron_parent_class": "StructuralPlasticityNode",
"neuron_parent_class_include": "structural_plasticity_node.h",
"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_resolution_test_neuron",
"synapse": "resolution_legally_used_synapse"}]})
"synapse": "resolution_legally_used_synapse"}],
"delay_variable": {"resolution_legally_used_synapse": "d"},
"weight_variable": {"resolution_legally_used_synapse": "w"}})

@pytest.mark.skipif(NESTTools.detect_nest_version().startswith("v2"),
reason="This test does not support NEST 2")
Expand Down
4 changes: 3 additions & 1 deletion tests/nest_tests/nest_set_with_distribution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def setUp(self):

codegen_opts = {"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron",
"synapse": "stdp_synapse",
"post_ports": ["post_spikes"]}]}
"post_ports": ["post_spikes"]}],
"delay_variable": {"stdp_synapse": "d"},
"weight_variable": {"stdp_synapse": "w"}}

# generate the "jit" model (co-generated neuron and synapse), that does not rely on ArchivingNode
files = [os.path.join("models", "neurons", "iaf_psc_exp_neuron.nestml"),
Expand Down
4 changes: 3 additions & 1 deletion tests/nest_tests/noisy_synapse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def setUp(self):
target_path="/tmp/nestml-noisy-synapse",
logging_level="INFO",
module_name="nestml_noisy_synapse_module",
suffix="_nestml")
suffix="_nestml",
codegen_opts={"delay_variable": {"noisy_synapse": "d"},
"weight_variable": {"noisy_synapse": "w"}})

@pytest.mark.skipif(NESTTools.detect_nest_version().startswith("v2"),
reason="This test does not support NEST 2")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ model dopa_second_order_synapse:

parameters:
tau_dopa ms = 100 ms
w real = 1
d ms = 1 ms

equations:
Expand Down
4 changes: 3 additions & 1 deletion tests/nest_tests/stdp_neuromod_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def setUp(self):
"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron",
"synapse": "neuromodulated_stdp_synapse",
"post_ports": ["post_spikes"],
"vt_ports": ["mod_spikes"]}]})
"vt_ports": ["mod_spikes"]}],
"delay_variable": {"neuromodulated_stdp_synapse": "d"},
"weight_variable": {"neuromodulated_stdp_synapse": "w"}})

generate_nest_target(input_path=os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_exp_neuron.nestml"))),
Expand Down
4 changes: 3 additions & 1 deletion tests/nest_tests/stdp_nn_pre_centered_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def setUp(self):
"neuron_parent_class_include": "structural_plasticity_node.h",
"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron",
"synapse": "stdp_nn_pre_centered_synapse",
"post_ports": ["post_spikes"]}]})
"post_ports": ["post_spikes"]}],
"delay_variable": {"stdp_nn_pre_centered_synapse": "d"},
"weight_variable": {"stdp_nn_pre_centered_synapse": "w"}})

# generate the "non-jit" model, that relies on ArchivingNode
generate_nest_target(input_path=os.path.realpath(os.path.join(os.path.dirname(__file__),
Expand Down
4 changes: 3 additions & 1 deletion tests/nest_tests/stdp_nn_restr_symm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def setUp(self):
"neuron_parent_class_include": "structural_plasticity_node.h",
"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron",
"synapse": "stdp_nn_restr_symm_synapse",
"post_ports": ["post_spikes"]}]})
"post_ports": ["post_spikes"]}],
"delay_variable": {"stdp_nn_restr_symm_synapse": "d"},
"weight_variable": {"stdp_nn_restr_symm_synapse": "w"}})

# generate the "non-jit" model, that relies on ArchivingNode
generate_nest_target(input_path=os.path.realpath(os.path.join(os.path.dirname(__file__),
Expand Down
4 changes: 3 additions & 1 deletion tests/nest_tests/stdp_nn_synapse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def setUp(self):
"neuron_parent_class_include": "structural_plasticity_node.h",
"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron",
"synapse": "stdp_nn_symm_synapse",
"post_ports": ["post_spikes"]}]})
"post_ports": ["post_spikes"]}],
"delay_variable": {"stdp_nn_symm_synapse": "d"},
"weight_variable": {"stdp_nn_symm_synapse": "w"}})

# generate the "non-jit" model, that relies on ArchivingNode
generate_nest_target(input_path=os.path.realpath(os.path.join(os.path.dirname(__file__),
Expand Down
Loading

0 comments on commit abe6577

Please sign in to comment.