Skip to content

Commit

Permalink
change third factor STDP plasticity unit test into a notebook tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Jan 11, 2024
1 parent 63ed5ce commit 067aa0a
Show file tree
Hide file tree
Showing 8 changed files with 4,756 additions and 3,681 deletions.

This file was deleted.

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions models/synapses/third_factor_stdp_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ model third_factor_stdp_synapse:
parameters:
d ms = 1 ms @nest::delay # Synaptic transmission delay
lambda real = .01
tau_tr_pre ms = 20 ms
tau_tr_post ms = 20 ms
tau_tr_pre ms = 10 ms
tau_tr_post ms = 10 ms
alpha real = 1.
mu_plus real = 1.
mu_minus real = 1.
Expand Down
11 changes: 9 additions & 2 deletions pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,9 +488,16 @@ def _get_synapse_model_namespace(self, synapse: ASTModel) -> Dict:

if "paired_neuron" in dir(synapse):
# synapse is being co-generated with neuron
namespace["paired_neuron"] = synapse.paired_neuron.get_name()
namespace["paired_neuron"] = synapse.paired_neuron
namespace["paired_neuron_name"] = synapse.paired_neuron.get_name()
namespace["post_ports"] = synapse.post_port_names
namespace["spiking_post_ports"] = synapse.spiking_post_port_names

namespace["continuous_post_ports"] = []
if "neuron_synapse_pairs" in FrontendConfiguration.get_codegen_opts().keys():
post_ports = ASTUtils.get_post_ports_of_neuron_synapse_pair(synapse.paired_neuron, synapse, FrontendConfiguration.get_codegen_opts()["neuron_synapse_pairs"])
namespace["continuous_post_ports"] = [v for v in post_ports if isinstance(v, tuple) or isinstance(v, list)]

namespace["vt_ports"] = synapse.vt_port_names
namespace["pre_ports"] = list(set(all_input_port_names)
- set(namespace["post_ports"]) - set(namespace["vt_ports"]))
Expand Down Expand Up @@ -587,7 +594,7 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:

codegen_and_builder_opts = FrontendConfiguration.get_codegen_opts()
xfrm = SynapsePostNeuronTransformer(codegen_and_builder_opts)
namespace["state_vars_that_need_continuous_buffering_transformed"] = [xfrm.get_neuron_var_name_from_syn_port_name(port_name, neuron.unpaired_name.removesuffix(FrontendConfiguration.suffix), neuron.paired_synapse.get_name().split("__with_")[0].removesuffix("_nestml")) for port_name in neuron.state_vars_that_need_continuous_buffering]
namespace["state_vars_that_need_continuous_buffering_transformed"] = [xfrm.get_neuron_var_name_from_syn_port_name(port_name, neuron.unpaired_name.removesuffix(FrontendConfiguration.suffix), neuron.paired_synapse.get_name().split("__with_")[0].removesuffix(FrontendConfiguration.suffix)) for port_name in neuron.state_vars_that_need_continuous_buffering]
else:
namespace["state_vars_that_need_continuous_buffering"] = []
namespace["extra_on_emit_spike_stmts_from_synapse"] = neuron.extra_on_emit_spike_stmts_from_synapse
Expand Down
3 changes: 2 additions & 1 deletion pynestml/codegeneration/printers/nest_variable_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def print_variable(self, variable: ASTVariable) -> str:
_name = str(variable)
if variable.get_alternate_name():
# the disadvantage of this approach is that the time the value is to be obtained is not explicitly specified, so we will actually get the value at the end of the min_delay timestep
return "((post_neuron_t*)(__target))->get_" + variable.get_alternate_name() + "()"
return "__" + variable.get_alternate_name()
# return "((post_neuron_t*)(__target))->get_" + variable.get_alternate_name() + "()"

return "((post_neuron_t*)(__target))->get_" + _name + "(_tr_t)"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ public:
template < typename targetidentifierT >
class {{synapseName}} : public Connection< targetidentifierT >
{
{%- if paired_neuron | length > 0 %}
typedef {{ paired_neuron }} post_neuron_t;
{%- if paired_neuron_name | length > 0 %}
typedef {{ paired_neuron_name }} post_neuron_t;

{% endif %}
{%- if vt_ports is defined and vt_ports|length > 0 %}
Expand Down Expand Up @@ -636,10 +636,10 @@ public:

{%- if paired_neuron is defined %}
try {
dynamic_cast<{{paired_neuron}}&>(t);
dynamic_cast<{{paired_neuron_name}}&>(t);
}
catch (std::bad_cast &exp) {
std::cout << "wrong type of neuron connected! Synapse '{{synapseName}}' will only work with neuron '{{paired_neuron}}'.\n";
std::cout << "wrong type of neuron connected! Synapse '{{synapseName}}' will only work with neuron '{{paired_neuron_name}}'.\n";
exit(1);
}
{%- endif %}
Expand Down Expand Up @@ -669,7 +669,7 @@ public:
return tid;
};

const double __t_spike = e.get_stamp().get_ms();
const double __t_spike = e.get_stamp().get_ms(); // time of the presynaptic spike [ms]
#ifdef DEBUG
std::cout << "{{synapseName}}::send(): handling pre spike at t = " << __t_spike << std::endl;
#endif
Expand All @@ -680,8 +680,8 @@ public:

{%- endif %}
// use accessor functions (inherited from Connection< >) to obtain delay and target
{%- if paired_neuron is not none and paired_neuron|length > 0 %}
{{paired_neuron}}* __target = static_cast<{{paired_neuron}}*>(get_target(tid));
{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 %}
{{paired_neuron_name}}* __target = static_cast<{{paired_neuron_name}}*>(get_target(tid));
assert(__target);
{%- else %}
Node* __target = get_target( tid );
Expand All @@ -696,7 +696,7 @@ public:
t_lastspike_ = 0.;
}

{%- if paired_neuron is not none and paired_neuron|length > 0 %}
{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 %}
double timestep = 0;
{
/**
Expand All @@ -706,8 +706,8 @@ public:
*
* Note that this also increases the access counter for these entries which is used to prune the history.
**/
std::deque< histentry__{{paired_neuron}} >::iterator start;
std::deque< histentry__{{paired_neuron}} >::iterator finish;
std::deque< histentry__{{paired_neuron_name}} >::iterator start;
std::deque< histentry__{{paired_neuron_name}} >::iterator finish;
{%- if vt_ports is defined and vt_ports|length > 0 %}
double t0 = t_last_update_;
{%- endif %}
Expand All @@ -717,23 +717,25 @@ public:
&finish );
while ( start != finish )
{
{%- if paired_neuron is not none and paired_neuron|length > 0 %}

{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 %}
/**
* grab the postsynaptic continuous variable values from the postsynaptic neuron
**/
{%- if paired_neuron.state_vars_that_need_continuous_buffering | length > 0 %}
auto histentry = ((post_neuron_t*)(__target))->get_continuous_variable_history(start->t_ - __dendritic_delay);
auto histentry = ((post_neuron_t*)(__target))->get_continuous_variable_history(start->t_ - __dendritic_delay);
{%- endif %}

{%- for var_name in paired_neuron.state_vars_that_need_continuous_buffering %}
{%- set var = utils.get_parameter_variable_by_name(astnode, var_name) %}
const double {{ var_name }}_at_post_spike_time = histentry.{{ var_name }};
const double __{{ var_name }} = histentry.{{ var_name }};
{%- endfor %}
{%- endif %}

{%- if vt_ports is defined and vt_ports|length > 0 %}
{%- set vt_port = vt_ports[0] %}
{% if vt_ports is defined and vt_ports|length > 0 %}
{%- set vt_port = vt_ports[0] %}
process_{{vt_port}}_spikes_( vt_spikes, t0, start->t_ + __dendritic_delay, cp );
t0 = start->t_ + __dendritic_delay;
{%- endif %}
{%- endif %}
const double minus_dt = t_lastspike_ - ( start->t_ + __dendritic_delay );
// get_history() should make sure that
// start->t_ > t_lastspike_ - dendritic_delay, i.e. minus_dt < 0
Expand Down Expand Up @@ -798,7 +800,18 @@ public:
const double _tr_t = __t_spike - __dendritic_delay;

{
auto get_t = [__t_spike](){ return __t_spike; }; // do not remove, this is in case the predefined time variable ``t`` is used in the NESTML model
auto get_t = [__t_spike](){ return __t_spike; }; // do not remove, this is in case the predefined time variable ``t`` is used in the NESTML model // XXX: TODO: is this correct or should this be equal to ``_tr_t``, i.e. with dendritic delay subtracted?

{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 and paired_neuron. state_vars_that_need_continuous_buffering | length > 0 %}
/**
* grab the postsynaptic continuous variable values from the postsynaptic neuron
**/
{%- for var_name in paired_neuron.state_vars_that_need_continuous_buffering %}

{%- set var_name_post = utils.get_var_name_tuples_of_neuron_synapse_pair(continuous_post_ports, var_name) %}
const double __{{ var_name }} = ((post_neuron_t*)(__target))->get_{{ var_name_post }}();
{%- endfor %}
{%- endif %}

{%- filter indent(8, True) %}
{%- for pre_port in pre_ports %}
Expand Down Expand Up @@ -833,6 +846,17 @@ public:
{
auto get_t = [__t_spike](){ return __t_spike; }; // do not remove, this is in case the predefined time variable ``t`` is used in the NESTML model

{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 and paired_neuron. state_vars_that_need_continuous_buffering | length > 0 %}
/**
* grab the postsynaptic continuous variable values from the postsynaptic neuron
**/
{%- for var_name in paired_neuron.state_vars_that_need_continuous_buffering %}

{%- set var_name_post = utils.get_var_name_tuples_of_neuron_synapse_pair(continuous_post_ports, var_name) %}
const double __{{ var_name }} = ((post_neuron_t*)(__target))->get_{{ var_name_post }}();
{%- endfor %}
{%- endif %}

{%- filter indent(6, True) %}
{%- if post_ports is defined %}
{%- for post_port in spiking_post_ports %}
Expand Down Expand Up @@ -1249,9 +1273,9 @@ inline void
double dendritic_delay = get_delay();

// get spike history in relevant range (t_last_update, t_trig] from postsyn. neuron
std::deque< histentry__{{paired_neuron}} >::iterator start;
std::deque< histentry__{{paired_neuron}} >::iterator finish;
static_cast<{{paired_neuron}}*>(get_target(t))->get_history__( t_last_update_ - dendritic_delay, t_trig - dendritic_delay, &start, &finish );
std::deque< histentry__{{paired_neuron_name}} >::iterator start;
std::deque< histentry__{{paired_neuron_name}} >::iterator finish;
static_cast<{{paired_neuron_name}}*>(get_target(t))->get_history__( t_last_update_ - dendritic_delay, t_trig - dendritic_delay, &start, &finish );

// facilitation due to postsyn. spikes since last update
double t0 = t_last_update_;
Expand Down
2 changes: 1 addition & 1 deletion pynestml/transformers/synapse_post_neuron_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def mark_post_port(_expr=None):
for state_var, alternate_name in zip(post_connected_continuous_input_ports, post_variable_names):
Logger.log_message(None, -1, "\t• Replacing variable " + str(state_var), None, LoggingLevel.INFO)
ASTUtils.replace_with_external_variable(state_var, new_synapse, "",
new_synapse.get_equations_blocks()[0], alternate_name)
new_synapse.get_equations_blocks()[0], state_var)

#
# copy parameters
Expand Down
26 changes: 26 additions & 0 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pynestml.codegeneration.printers.ast_printer import ASTPrinter
from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter
from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter
from pynestml.frontend.frontend_configuration import FrontendConfiguration
from pynestml.generated.PyNestMLLexer import PyNestMLLexer
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_block import ASTBlock
Expand Down Expand Up @@ -614,6 +615,31 @@ def get_kernel_by_name(cls, node, name: str) -> Optional[ASTKernel]:

return None

@classmethod
def print_alternate_var_name(cls, var_name, continuous_post_ports):
    r"""Return the alternate (postsynaptic neuron side) variable name registered for ``var_name``.

    :param var_name: synaptic-side variable name to look up
    :param continuous_post_ports: iterable of ``(synapse_port_name, neuron_var_name)`` pairs
    :return: the neuron-side name paired with ``var_name``
    :raises ValueError: if ``var_name`` has no registered alternate name
    """
    for pair in continuous_post_ports:
        if pair[0] == var_name:
            return pair[1]

    # was ``assert False``: a bare assert is stripped under ``python -O`` and the
    # function would then silently return None; raise explicitly instead
    raise ValueError("No alternate variable name found for \"" + var_name + "\"")

@classmethod
def get_post_ports_of_neuron_synapse_pair(cls, neuron, synapse, codegen_opts_pairs):
    r"""Return the ``"post_ports"`` entry of the code-generator options pair that matches the given neuron/synapse combination.

    :param neuron: neuron model; its name may carry a ``"__with_..."`` decoration and the code-generator suffix
    :param synapse: synapse model, whose name is decorated the same way
    :param codegen_opts_pairs: list of dicts with ``"neuron"``, ``"synapse"`` and ``"post_ports"`` keys
    :return: the ``"post_ports"`` value of the matching pair, or ``None`` if no pair matches
    """
    # strip the "__with_..." decoration and the code-generator suffix once,
    # outside the loop, to recover the user-visible model names
    neuron_name = neuron.get_name().split("__with_")[0].removesuffix(FrontendConfiguration.suffix)
    synapse_name = synapse.get_name().split("__with_")[0].removesuffix(FrontendConfiguration.suffix)
    for pair in codegen_opts_pairs:
        if pair["neuron"] == neuron_name and pair["synapse"] == synapse_name:
            return pair["post_ports"]

    return None

@classmethod
def get_var_name_tuples_of_neuron_synapse_pair(cls, post_port_names, post_port):
    r"""Look up the partner name paired with ``post_port`` in a list of port-name pairs.

    :param post_port_names: iterable of ``(synapse_port_name, neuron_var_name)`` pairs
    :param post_port: synapse-side port name to look up
    :return: the name paired with ``post_port``, or ``None`` if it does not appear
    """
    # note: removed leftover debug ``print`` statements that wrote to stdout
    for pair in post_port_names:
        if pair[0] == post_port:
            return pair[1]

    return None

@classmethod
def replace_with_external_variable(cls, var_name, node: ASTNode, suffix, new_scope, alternate_name=None):
"""
Expand Down

0 comments on commit 067aa0a

Please sign in to comment.