Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clopath synapse test (initial version) and some minor code fixes #4

Open
wants to merge 10 commits into
base: jit-third-factor
Choose a base branch
from
79 changes: 79 additions & 0 deletions models/clopath_synapse.nestml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Name: clopath - Synapse type for spike-timing dependent plasticity.

Description:

"""
synapse clopath:

state:
w real = 1.
u_bar_plus mV = E_L
end

parameters:
the_delay ms = 1 ms @nest::delay # !!! cannot have a variable called "delay"
lambda real = .01
tau_tr_pre ms = 20 ms
tau_tr_post ms = 20 ms
alpha real = 1.
mu_plus real = 1.
mu_minus real = 1.
Wmax real = 100.
Wmin real = 0.

# Clopath
g_L nS = 30.0 nS
C_m pF = 281.0 pF
E_L mV = -70.6 mV
Delta_T mV = 2.0 mV
tau_w ms = 144.0 ms
tau_z ms = 40.0 ms
tau_V_th ms = 50.0 ms
V_th_max mV = 30.4 mV
V_th_rest mV = -50.4 mV
tau_plus ms = 7.0 ms
tau_minus ms = 10.0 ms
tau_bar_bar ms = 100.0 ms
a nS = 4.0 nS
b pA = 80.5 pA
I_sp pA = 400.0 pA
I_e pA = 0.0 pA
end

equations:
kernel pre_trace_kernel = exp(-t / tau_tr_pre)
inline pre_trace real = convolve(pre_trace_kernel, pre_spikes)

# all-to-all trace of postsynaptic neuron
kernel post_trace_kernel = exp(-t / tau_tr_post)
inline post_trace real = convolve(post_trace_kernel, post_spikes)

u_bar_plus' = (-u_bar_plus + v_clamp) / tau_plus
end

input:
pre_spikes nS <- spike
post_spikes nS <- spike
v_clamp mV <- continuous
end

output: spike

onReceive(post_spikes):
# potentiate synapse
w_ real = Wmax * ( w / Wmax + (lambda * ( 1. - ( w / Wmax ) )**mu_plus * pre_trace ))
w = min(Wmax, w_)
end

onReceive(pre_spikes):
# depress synapse
w_ real = Wmax * ( w / Wmax - ( alpha * lambda * ( w / Wmax )**mu_minus * post_trace ))
w = max(Wmin, w_)

# deliver spike to postsynaptic partner
deliver_spike(w, the_delay)
end

end

4 changes: 2 additions & 2 deletions models/triplet_stdp_synapse.nestml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
"""
synapse stdp_triplet:
synapse stdp_triplet_nn:

state:
w nS = 1 nS
Expand Down Expand Up @@ -37,7 +37,7 @@ synapse stdp_triplet:

input:
pre_spikes nS <- spike
post_spikes nS <- post spike
post_spikes nS <- spike
end

output: spike
Expand Down
40 changes: 39 additions & 1 deletion pynestml/codegeneration/nest_codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,6 @@ def move_updates_syn_neuron(state_var, synapse_block, neuron_block, var_name_suf
Logger.log_message(None, -1, "Moving onPost updates for " + str(state_var), None, LoggingLevel.INFO)
post_port_names = self.get_post_port_names(synapse, neuron.name, synapse.name)

assert len(post_port_names) <= 1, "Can only handle one \"post\" port"
if len(post_port_names) == 0:
continue
post_port_name = post_port_names[0]
Expand Down Expand Up @@ -670,6 +669,44 @@ def move_updates_syn_neuron(state_var, synapse_block, neuron_block, var_name_suf
Logger.log_message(None, -1, "\t• Copying variable " + str(state_var), None, LoggingLevel.INFO)
equations_from_syn_to_neuron(state_var, new_synapse.get_equations_block(), new_neuron.get_equations_block(), var_name_suffix, mode="move")

# Replace the post_port in the neuron to its equivalent post port
def replace_post_port_in_neurons(var_name, neuron, new_var_name, new_scope):
"""
Replace occurrences of moved post_port in the neuron with the neuron equivalent of post_port
"""
def replace_post_port(_expr=None):
if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
var = _expr.get_variable()
elif isinstance(_expr, ASTVariable):
var = _expr
else:
return

if var.get_name() != var_name:
return

new_var = ASTVariable(new_var_name, differential_order=var.get_differential_order(),
source_position=var.get_source_position())
new_var.update_scope(new_scope)
new_var.accept(ASTSymbolTableVisitor())

if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
_expr.set_variable(new_var)
elif isinstance(_expr, ASTVariable):
if isinstance(neuron.get_parent(_expr), ASTAssignment):
neuron.get_parent(_expr).lhs = new_var
Logger.log_message(None, -1, "ASTVariable replacement made in expression: " + str(
neuron.get_parent(_expr)), None, LoggingLevel.INFO)
elif isinstance(neuron.get_parent(_expr), ASTSimpleExpression) and neuron.get_parent(
_expr).is_variable():
neuron.get_parent(_expr).set_variable(new_var)
else:
Logger.log_message(None, -1,
"Error: instance of the variable not supported",
None, LoggingLevel.ERROR)
raise

neuron.accept(ASTHigherOrderVisitor(lambda x: replace_post_port(x)))

#
# replace occurrences of the variables in expressions in the original synapse with calls to the corresponding neuron getters
Expand Down Expand Up @@ -736,6 +773,7 @@ def replace_var(_expr=None):
for state_var, alternate_name in zip(post_connected_continuous_input_ports, post_variable_names):
Logger.log_message(None, -1, "\t• Replacing variable " + str(state_var), None, LoggingLevel.INFO)
replace_variable_name_in_expressions(state_var, new_synapse, "", new_synapse.get_equations_blocks().get_scope(), alternate_name)
replace_post_port_in_neurons(state_var, new_neuron, alternate_name, new_neuron.get_equations_blocks().get_scope())

# -------------- add dummy variable to state variable (and initial value) declarations so that type of the ASTExternalVariable can be resolved
"""
Expand Down
4 changes: 2 additions & 2 deletions pynestml/codegeneration/resources_nest/NeuronClass.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -746,9 +746,9 @@ double
{%- endif %}

{%- if purely_numeric_state_variables_moved|length > 0 %}
double ode_state_tmp[STATE_VEC_SIZE];
double ode_state_tmp[State_::STATE_VEC_SIZE];

for (int i = 0; i < STATE_VEC_SIZE; ++i) {
for (int i = 0; i < State_::STATE_VEC_SIZE; ++i) {
ode_state_tmp[i] = S_.ode_state[i];
}

Expand Down
16 changes: 10 additions & 6 deletions pynestml/codegeneration/resources_nest/SynapseHeader.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -633,9 +633,11 @@ inline void set_weight(double w) {
**/

{%- set dynamics = synapse.get_on_receive_block(post_port) %}
{%- with ast = dynamics.get_block() %}
{%- include "directives/Block.jinja2" %}
{%- endwith %}
{%- if dynamics is not none %}
{%- with ast = dynamics.get_block() %}
{%- include "directives/Block.jinja2" %}
{%- endwith %}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- endfilter %}
Expand Down Expand Up @@ -722,9 +724,11 @@ inline void set_weight(double w) {
**/

{%- set dynamics = synapse.get_on_receive_block(post_port) %}
{%- with ast = dynamics.get_block() %}
{%- include "directives/Block.jinja2" %}
{%- endwith %}
{%- if dynamics is not none %}
{%- with ast = dynamics.get_block() %}
{%- include "directives/Block.jinja2" %}
{%- endwith %}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- endfilter %}
Expand Down
71 changes: 71 additions & 0 deletions tests/nest_tests/clopath_synapse_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import nest
import numpy as np

from pynestml.frontend.pynestml_frontend import to_nest
from pynestml.utils.model_installer import install_nest


def run_simulation(neuron_model_name, synapse_model_name, module_name, _pre_spike_times, _post_spike_times, delay=1.0):
nest.set_verbosity("M_ALL")
nest.ResetKernel()
try:
print('Installing module: ', module_name)
nest.Install(module_name)
except BaseException:
pass # pass if the module is already loaded

sim_time = max(np.amax(pre_spike_times), np.amax(post_spike_times)) + 5 * delay

neurons = nest.Create(neuron_model_name, 2)
pre_sg = nest.Create('spike_generator', params={'spike_times': _pre_spike_times})
post_sg = nest.Create('spike_generator', params={'spike_times': _post_spike_times})

spikes = nest.Create('spike_recorder')
wr = nest.Create('weight_recorder')
wr_ref = nest.Create('weight_recorder')
nest.CopyModel(synapse_model_name, synapse_model_name + "_rec",
{"weight_recorder": wr[0], "w": 1., "the_delay": 1., "receptor_type": 0, "lambda": .001})

multimeter_pre = nest.Create('multimeter')
multimeter_post = nest.Create('multimeter',
params={"record_from": "u_bar_plus__for_clopath_nestml"})

nest.Connect(neurons[0], neurons[1], syn_spec={'synapse_model': synapse_model_name + "_rec"})
nest.Connect(pre_sg, neurons[0], "one_to_one", syn_spec={"delay": 1.})
nest.Connect(post_sg, neurons[1], "one_to_one", syn_spec={"delay": 1., "weight": 9999.})
nest.Connect(multimeter_pre, neurons[0])
nest.Connect(multimeter_post, neurons[1])
nest.Connect(neurons, spikes)

# Simulate
nest.Simulate(sim_time)

# Record u_bar_plus
events_post = nest.GetStatus(multimeter_post, 'events')[0]
times_post = events_post['times']
u_bar_plus = events_post['u_bar_plus__for_clopath_nestml']

print(u_bar_plus)


nest_install_path = nest.ll_api.sli_func("statusdict/prefix ::")
module_name = "nestml_clopath_synapse_module"
to_nest(input_path=["models/iaf_psc_delta.nestml", "models/clopath_synapse.nestml"],
target_path="/tmp/nestml-clopath",
logging_level="INFO",
module_name=module_name,
suffix="_nestml",
codegen_opts={"neuron_parent_class": "StructuralPlasticityNode",
"neuron_parent_class_include": "structural_plasticity_node.h",
"neuron_synapse_pairs": [{"neuron": "iaf_psc_delta",
"synapse": "clopath",
"post_ports": ["post_spikes", ["v_clamp", "V_abs"]]}]})
install_nest("/tmp/nestml-clopath", nest_install_path)

neuron_model_name = "iaf_psc_delta_nestml__with_clopath_nestml"
synapse_model_name = "clopath_nestml__with_iaf_psc_delta_nestml"

post_spike_times = np.sort(np.unique(1 + np.round(10 * np.sort(np.abs(np.random.randn(10))))))
pre_spike_times = np.sort(np.unique(1 + np.round(10 * np.sort(np.abs(np.random.randn(10))))))

run_simulation(neuron_model_name, synapse_model_name, module_name, pre_spike_times, post_spike_times)
6 changes: 4 additions & 2 deletions tests/nest_tests/third_factor_stdp_synapse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def setUp(self):
nest_path = nest.ll_api.sli_func("statusdict/prefix ::")

# generate the "jit" model (co-generated neuron and synapse), that does not rely on ArchivingNode
to_nest(input_path=["models/iaf_psc_exp_dend.nestml", "models/third_factor_stdp_synapse.nestml"],
to_nest(input_path=["../../models/iaf_psc_exp_dend.nestml", "../../models/third_factor_stdp_synapse.nestml"],
target_path="/tmp/nestml-jit",
logging_level="INFO",
module_name="nestml_jit_module",
Expand Down Expand Up @@ -112,7 +112,9 @@ def run_synapse_test(self, neuron_model_name,
nest.set_verbosity("M_ALL")
nest.ResetKernel()
nest.Install("nestml_jit_module")
nest.Install("nestml_non_jit_module")

if sim_ref:
nest.Install("nestml_non_jit_module")

print("Pre spike times: " + str(pre_spike_times))
print("Post spike times: " + str(post_spike_times))
Expand Down