Skip to content

Commit

Permalink
fix third-factor plasticity buffering and add third-factor plasticity tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed May 6, 2024
1 parent 19e6682 commit 146c737
Show file tree
Hide file tree
Showing 6 changed files with 966 additions and 194 deletions.

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,14 +604,14 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
namespace = self._get_model_namespace(neuron)

if "paired_synapse" in dir(neuron):
if "state_vars_that_need_post_spike_buffering" in dir(neuron):
namespace["state_vars_that_need_post_spike_buffering"] = neuron.state_vars_that_need_post_spike_buffering
if "state_vars_that_need_continuous_buffering" in dir(neuron):
namespace["state_vars_that_need_continuous_buffering"] = neuron.state_vars_that_need_continuous_buffering

codegen_and_builder_opts = FrontendConfiguration.get_codegen_opts()
xfrm = SynapsePostNeuronTransformer(codegen_and_builder_opts)
namespace["state_vars_that_need_post_spike_buffering"] = [xfrm.get_neuron_var_name_from_syn_port_name(port_name, removesuffix(neuron.unpaired_name, FrontendConfiguration.suffix), removesuffix(neuron.paired_synapse.get_name().split("__with_")[0], FrontendConfiguration.suffix)) for port_name in neuron.state_vars_that_need_post_spike_buffering]
namespace["state_vars_that_need_continuous_buffering_transformed"] = [xfrm.get_neuron_var_name_from_syn_port_name(port_name, removesuffix(neuron.unpaired_name, FrontendConfiguration.suffix), removesuffix(neuron.paired_synapse.get_name().split("__with_")[0], FrontendConfiguration.suffix)) for port_name in neuron.state_vars_that_need_continuous_buffering]
else:
namespace["state_vars_that_need_post_spike_buffering"] = []
namespace["state_vars_that_need_continuous_buffering"] = []
namespace["extra_on_emit_spike_stmts_from_synapse"] = neuron.extra_on_emit_spike_stmts_from_synapse
namespace["paired_synapse"] = neuron.paired_synapse.get_name()
namespace["post_spike_updates"] = neuron.post_spike_updates
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,22 @@ along with NEST. If not, see <http://www.gnu.org/licenses/>.
#include "lockptrdatum.h"

#include "{{neuronName}}.h"
{%- if state_vars_that_need_continuous_buffering | length > 0 %}


// Constructor for one continuous-variable history entry: records the time
// stamp ``t`` and a snapshot of each buffered state variable.
//
// NOTE: members are initialized in class declaration order (state variables
// first, then t_, then access_counter_) to avoid -Wreorder warnings; the
// initializers are independent, so behavior is unchanged.
continuous_variable_histentry_{{ neuronName }}::continuous_variable_histentry_{{ neuronName }}( double t,
{%- for state_var in state_vars_that_need_continuous_buffering %}
double {{ state_var }}{% if not loop.last %}, {% endif %}
{%- endfor %} )
{%- for state_var in state_vars_that_need_continuous_buffering %}
{% if loop.first %}: {% else %}, {% endif %}{{ state_var }}( {{ state_var }} )
{%- endfor %}
, t_( t )
, access_counter_ ( 0 )
{
}
{%- endif %}

{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2")
or nest_version.startswith("v3.3") or nest_version.startswith("v3.4") or nest_version.startswith("v3.5") or nest_version.startswith("v3.6")) %}
void
Expand Down Expand Up @@ -965,6 +979,15 @@ const double {{ variable_name }}__tmp = {{ printer.print(var_ast) }};
// voltage logging
B_.logger_.record_data(origin.get_steps() + lag);
{%- endif %}

{%- if state_vars_that_need_continuous_buffering | length > 0 %}
write_continuous_variable_history( nest::Time::step( origin.get_steps() + lag + 1 ),
{%- for state_var_name in state_vars_that_need_continuous_buffering_transformed %}
{%- set state_var = utils.get_variable_by_name(astnode, state_var_name) %}
{{ printer.print(state_var) }}{% if not loop.last %}, {% endif %}
{%- endfor %}
);
{%- endif %}
}

{%- if use_gap_junctions %}
Expand Down Expand Up @@ -992,6 +1015,108 @@ const double {{ variable_name }}__tmp = {{ printer.print(var_ast) }};
{%- endif %}
}

{%- if state_vars_that_need_continuous_buffering | length > 0 %}

/**
 * Look up the continuous-variable history entry whose time stamp matches
 * ``t`` within the kernel's STDP epsilon.
 *
 * Entries scanned before the match have their access counter incremented so
 * that write_continuous_variable_history() can prune entries every incoming
 * connection has read past.  If the history is empty, ``t`` is negative, or
 * no entry matches, a zero-filled entry is returned.
 * XXX: TODO: the fallback should return the state variables' initial values
 * rather than zeros.
 */
continuous_variable_histentry_{{ neuronName }} {{ neuronName }}::get_continuous_variable_history( double t )
{
std::deque< continuous_variable_histentry_{{ neuronName }} >::iterator runner;
if ( continuous_variable_history_.empty() or t < 0.0 )
{
// no history recorded yet (or query precedes simulation start)
return continuous_variable_histentry_{{ neuronName }}(0.,
{%- for state_var in state_vars_that_need_continuous_buffering %}
0.{% if not loop.last %},{% endif %}
{%- endfor %}); // XXX: TODO: return initial value
}
else
{
runner = continuous_variable_history_.begin();
while ( runner != continuous_variable_history_.end() )
{
if ( fabs( t - runner->t_ ) < nest::kernel().connection_manager.get_stdp_eps() )
{
return *runner;
}
// mark the skipped (older) entry as read so it can eventually be pruned
( runner->access_counter_ )++;
++runner;
}
}

// if we get here, there is no entry at time t; only report this in debug
// builds -- an unconditional message on stdout would flood production runs
#ifdef DEBUG
std::cerr << "{{ neuronName }}::get_continuous_variable_history(): no entry at time t = " << t << "\n";
#endif
return continuous_variable_histentry_{{ neuronName }}(0.,
{%- for state_var in state_vars_that_need_continuous_buffering %}
0.{% if not loop.last %}, {% endif %}
{%- endfor %}); // XXX: TODO: return initial value
}

/**
 * Return iterators [*start, *finish) delimiting the continuous-variable
 * history entries in the half-open time interval (t1, t2].
 *
 * Entries inside the interval have their access counter incremented so that
 * write_continuous_variable_history() can prune them once every incoming
 * connection has read them.  If the history is empty, *start == *finish ==
 * continuous_variable_history_.end() and the range is empty.
 */
void {{neuronName}}::get_continuous_variable_history( double t1,
double t2,
std::deque< continuous_variable_histentry_{{ neuronName }} >::iterator* start,
std::deque< continuous_variable_histentry_{{ neuronName }} >::iterator* finish )
{
#ifdef DEBUG
std::cout << "{{neuronName}}::get_continuous_variable_history()" << std::endl;
#endif
*finish = continuous_variable_history_.end();
if ( continuous_variable_history_.empty() )
{
// no entries at all: return an empty range
*start = *finish;
return;
}
else
{
std::deque< continuous_variable_histentry_{{ neuronName }} >::iterator runner = continuous_variable_history_.begin();

// To have a well defined discretization of the integral, we make sure
// that we exclude the entry at t1 but include the one at t2 by subtracting
// a small number so that runner->t_ is never equal to t1 or t2.
while ( ( runner != continuous_variable_history_.end() ) and runner->t_ - 1.0e-6 < t1 )
{
++runner;
}
*start = runner;
// advance to the first entry past t2, counting each returned entry as read
while ( ( runner != continuous_variable_history_.end() ) and runner->t_ - 1.0e-6 < t2 )
{
( runner->access_counter_ )++;
++runner;
}
*finish = runner;
}
}

/**
 * Append a new continuous-variable history entry at time ``t`` holding a
 * snapshot of each buffered state variable, pruning stale entries first.
 */
void {{neuronName}}::write_continuous_variable_history(nest::Time const &t,
{%- for state_var in state_vars_that_need_continuous_buffering %}
const double {{ state_var }}{% if not loop.last %}, {% endif %}
{%- endfor %})
{
#ifdef DEBUG
std::cout << "{{neuronName}}::write_continuous_variable_history()" << std::endl;
#endif
const double t_ms = t.get_ms();

// prune entries from the front of the history that every incoming connection
// has already read (access_counter_ >= n_incoming_); always keep at least
// the most recent entry, since it may still be needed.
while ( continuous_variable_history_.size() > 1 )
{
if ( continuous_variable_history_.front().access_counter_ >= n_incoming_ )
{
continuous_variable_history_.pop_front();
}
else
{
// oldest remaining entry still has unread consumers; stop pruning
break;
}
}

continuous_variable_history_.push_back( continuous_variable_histentry_{{ neuronName }}( t_ms,
{%- for state_var in state_vars_that_need_continuous_buffering %}
{{ state_var }}{% if not loop.last %}, {% endif %}
{%- endfor %}) );
#ifdef DEBUG
std::cout << "\thistory size = " << continuous_variable_history_.size() << std::endl;
#endif
}
{%- endif %}

// Do not move this function as inline to h-file. It depends on
// universal_data_logger_impl.h being included here.
void {{neuronName}}::handle(nest::DataLoggingRequest& e)
Expand Down Expand Up @@ -1241,9 +1366,6 @@ void
{%- endfor %}
{%- for var in analytic_state_variables_moved|sort %}
, get_{{ var }}()
{%- endfor %}
{%- for var in state_vars_that_need_post_spike_buffering|sort %}
, get_{{ var }}()
{%- endfor %}
, 0 // access counter
) );
Expand Down Expand Up @@ -1271,7 +1393,7 @@ void
generate getter functions for the transferred variables
#}

{%- for var in transferred_variables + state_vars_that_need_post_spike_buffering %}
{%- for var in transferred_variables %}
{%- with variable_symbol = transferred_variables_syms[var] %}

{%- if not variable_symbol %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,6 @@ public:
{%- endfor %}
{%- for var in analytic_state_variables_moved|sort %}
double {{ var }},
{%- endfor %}
{%- for var in state_vars_that_need_post_spike_buffering|sort %}
double {{ var }},
{%- endfor %}
size_t access_counter )
: t_( t )
Expand All @@ -183,9 +180,6 @@ public:
{%- endfor %}
{%- for var in analytic_state_variables_moved|sort %}
, {{ var }}_( {{ var }} )
{%- endfor %}
{%- for var in state_vars_that_need_post_spike_buffering|sort %}
, {{ var }}_( {{ var }} )
{%- endfor %}
, access_counter_( access_counter )
{
Expand All @@ -197,13 +191,27 @@ public:
{%- endfor %}
{%- for var in analytic_state_variables_moved|sort %}
double {{ var }}_;
{%- endfor %}
{%- for var in state_vars_that_need_post_spike_buffering|sort %}
double {{ var }}_;
{%- endfor %}
size_t access_counter_; //!< access counter to enable removal of the entry, once all neurons read it
};
{%- if state_vars_that_need_continuous_buffering | length > 0 %}

/**
 * One time-stamped snapshot of the continuously buffered state variables,
 * stored in the postsynaptic neuron's continuous-variable history so that
 * synapses can read past values of these variables.
 */
class continuous_variable_histentry_{{ neuronName }}
{
public:
continuous_variable_histentry_{{ neuronName }}( double t,
{%- for state_var in state_vars_that_need_continuous_buffering %}
double {{ state_var }}{% if not loop.last %}, {% endif %}
{%- endfor %} );

{%- for state_var in state_vars_that_need_continuous_buffering %}
double {{ state_var }}; //!< buffered value of the state variable at time t_
{%- endfor %}

double t_; //!< point in time for history entry
size_t access_counter_; //!< number of reads; enables pruning once all incoming connections have read the entry
};
{%- endif %}
{%- endif %}

/* BeginDocumentation
Expand Down Expand Up @@ -346,6 +354,25 @@ public:
* with t > t_first_read.
*/
void register_stdp_connection( double t_first_read, double delay );
{%- if state_vars_that_need_continuous_buffering | length > 0 %}

/**
* write_continuous_variable_history
*/
void write_continuous_variable_history(nest::Time const &t,
{%- for state_var in state_vars_that_need_continuous_buffering %}
const double {{ state_var }}{% if not loop.last %}, {% endif %}
{%- endfor %});

void get_continuous_variable_history( double t1,
double t2,
std::deque< continuous_variable_histentry_{{ neuronName }} >::iterator* start,
std::deque< continuous_variable_histentry_{{ neuronName }} >::iterator* finish );

continuous_variable_histentry_{{ neuronName }} get_continuous_variable_history( double t );

std::deque< continuous_variable_histentry_{{ neuronName }} > continuous_variable_history_;
{%- endif %}
{%- endif %}
{%- if neuron.get_state_symbols()|length > 0 %}
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -394,7 +421,7 @@ public:

/* getters/setters for variables transferred from synapse */

{%- for var in transferred_variables + state_vars_that_need_post_spike_buffering %}
{%- for var in transferred_variables + state_vars_that_need_continuous_buffering %}
double get_{{var}}( double t, const bool before_increment = true );
{%- endfor %}
{%- endif %}
Expand Down Expand Up @@ -457,7 +484,7 @@ private:
std::deque< histentry__{{neuronName}} > history_;

// cache for initial values
{%- for var in transferred_variables + state_vars_that_need_post_spike_buffering %}
{%- for var in transferred_variables + state_vars_that_need_continuous_buffering %}
double {{var}}__iv;
{%- endfor %}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#define DEBUG
{#-
SynapseHeader.h.jinja2
Expand Down Expand Up @@ -78,8 +79,6 @@ along with NEST. If not, see <http://www.gnu.org/licenses/>.
{{ synapse.print_comment() }}
**/

//#define DEBUG

namespace nest
{
{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2")
Expand Down Expand Up @@ -727,18 +726,17 @@ public:
&finish );
while ( start != finish )
{
{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 %}
{%- if paired_neuron.state_vars_that_need_post_spike_buffering | length > 0 %}
{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 and paired_neuron.state_vars_that_need_continuous_buffering | length > 0 %}
/**
* grab state variables from the postsynaptic neuron
* grab state variables from the postsynaptic neuron at the time of the post spike
**/

{%- for var_name in paired_neuron.state_vars_that_need_post_spike_buffering %}
auto histentry = ((post_neuron_t*)(__target))->get_continuous_variable_history(start->t_ + __dendritic_delay);

{%- for var_name in paired_neuron.state_vars_that_need_continuous_buffering %}
{%- set var = utils.get_parameter_variable_by_name(astnode, var_name) %}
{%- set var_name_post = utils.get_var_name_tuples_of_neuron_synapse_pair(continuous_post_ports, var_name) %}
const double __{{ var_name }} = start->{{ var_name_post }}_;
const double __{{ var_name }} = histentry.{{ var_name }};
{%- endfor %}
{%- endif %}
{%- endif %}

{% if vt_ports is defined and vt_ports|length > 0 %}
Expand Down Expand Up @@ -812,17 +810,15 @@ public:
{
auto get_t = [__t_spike](){ return __t_spike; }; // do not remove, this is in case the predefined time variable ``t`` is used in the NESTML model

{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 %}
{%- if paired_neuron.state_vars_that_need_post_spike_buffering | length > 0 %}
{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 and paired_neuron.state_vars_that_need_continuous_buffering | length > 0 %}
/**
* grab state variables from the postsynaptic neuron
**/

{%- for var_name in paired_neuron.state_vars_that_need_post_spike_buffering %}
{%- for var_name in paired_neuron.state_vars_that_need_continuous_buffering %}
{%- set var_name_post = utils.get_var_name_tuples_of_neuron_synapse_pair(continuous_post_ports, var_name) %}
const double __{{ var_name }} = ((post_neuron_t*)(__target))->get_{{ var_name_post }}(_tr_t);
const double __{{ var_name }} = ((post_neuron_t*)(__target))->get_{{ var_name_post }}();
{%- endfor %}
{%- endif %}
{%- endif %}

{%- filter indent(8, True) %}
Expand Down Expand Up @@ -859,14 +855,14 @@ public:
auto get_t = [__t_spike](){ return __t_spike; }; // do not remove, this is in case the predefined time variable ``t`` is used in the NESTML model

{%- if paired_neuron_name is not none and paired_neuron_name|length > 0 %}
{%- if paired_neuron.state_vars_that_need_post_spike_buffering | length > 0 %}
{%- if paired_neuron.state_vars_that_need_continuous_buffering | length > 0 %}
/**
* grab state variables from the postsynaptic neuron
**/

{%- for var_name in paired_neuron.state_vars_that_need_post_spike_buffering %}
{%- for var_name in paired_neuron.state_vars_that_need_continuous_buffering %}
{%- set var_name_post = utils.get_var_name_tuples_of_neuron_synapse_pair(continuous_post_ports, var_name) %}
const double __{{ var_name }} = ((post_neuron_t*)(__target))->get_{{ var_name_post }}(_tr_t);
const double __{{ var_name }} = ((post_neuron_t*)(__target))->get_{{ var_name_post }}();
{%- endfor %}
{%- endif %}
{%- endif %}
Expand Down
10 changes: 5 additions & 5 deletions pynestml/transformers/synapse_post_neuron_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,20 +377,20 @@ def transform_neuron_synapse_pair_(self, neuron: ASTModel, synapse: ASTModel):
# collect all ``continuous`` type input ports, the value of which is used in event handlers -- these have to be buffered in the hist_entry for each post spike in the postsynaptic history
#

state_vars_that_need_post_spike_buffering = []
state_vars_that_need_continuous_buffering = []
for input_block in new_synapse.get_input_blocks():
for port in input_block.get_input_ports():
if self.is_continuous_port(port.name, new_synapse):
state_vars_that_need_post_spike_buffering.append(port.name)
state_vars_that_need_continuous_buffering.append(port.name)

# check that they are not used in the update block
update_block_var_names = []
for update_block in synapse.get_update_blocks():
update_block_var_names.extend([var.get_complete_name() for var in ASTUtils.collect_variable_names_in_expression(update_block)])

assert all([var not in update_block_var_names for var in state_vars_that_need_post_spike_buffering])
assert all([var not in update_block_var_names for var in state_vars_that_need_continuous_buffering])

Logger.log_message(None, -1, "Synaptic state variables moved to neuron that will need buffering: " + str(state_vars_that_need_post_spike_buffering), None, LoggingLevel.INFO)
Logger.log_message(None, -1, "Synaptic state variables moved to neuron that will need buffering: " + str(state_vars_that_need_continuous_buffering), None, LoggingLevel.INFO)

#
# move state variable declarations from synapse to neuron
Expand Down Expand Up @@ -580,7 +580,7 @@ def mark_post_port(_expr=None):
new_neuron.unpaired_name = neuron.get_name()
new_neuron.set_name(new_neuron_name)
new_neuron.paired_synapse = new_synapse
new_neuron.state_vars_that_need_post_spike_buffering = state_vars_that_need_post_spike_buffering
new_neuron.state_vars_that_need_continuous_buffering = state_vars_that_need_continuous_buffering

#
# rename synapse
Expand Down

0 comments on commit 146c737

Please sign in to comment.