From 0596e8ea7ea4ea0dcaa6fe4761de8b25561feff2 Mon Sep 17 00:00:00 2001 From: Pooja Babu <75320801+pnbabu@users.noreply.github.com> Date: Mon, 16 Dec 2024 17:15:28 +0100 Subject: [PATCH] Fix bug with `integrate_odes()` for numeric solver (#1147) --- .../GSLDifferentiationFunction.jinja2 | 9 +- .../aeif_cond_alpha_alt_neuron.nestml | 120 ++++++++++++++++++ tests/nest_tests/test_integrate_odes.py | 85 +++++++++++-- 3 files changed, 199 insertions(+), 15 deletions(-) create mode 100644 tests/nest_tests/resources/aeif_cond_alpha_alt_neuron.nestml diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLDifferentiationFunction.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLDifferentiationFunction.jinja2 index c40f71cf6..e2495a676 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLDifferentiationFunction.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLDifferentiationFunction.jinja2 @@ -51,14 +51,17 @@ extern "C" inline int {{neuronName}}_dynamics{% if ast.get_args() | length > 0 % } {%- endif %} - +{% set numeric_state_variables_to_be_integrated = numeric_state_variables + purely_numeric_state_variables_moved %} +{%- if ast.get_args() | length > 0 %} +{%- set numeric_state_variables_to_be_integrated = utils.filter_variables_list(numeric_state_variables_to_be_integrated, ast.get_args()) %} +{%- endif %} {%- for variable_name in numeric_state_variables + numeric_state_variables_moved %} {%- set update_expr = numeric_update_expressions[variable_name] %} {%- set variable_symbol = variable_symbols[variable_name] %} {%- if use_gap_junctions %} - f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in utils.integrate_odes_args_strs_from_function_call(ast) + utils.all_convolution_variable_names(astnode) %}{{ gsl_printer.print(update_expr)|replace("node.B_." + gap_junction_port + "_grid_sum_", "(node.B_." + gap_junction_port + "_grid_sum_ + __I_gap)") }}{% else %}0{% endif %}{% else %}{{ gsl_printer.print(update_expr) }}{% endif %}; + f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in numeric_state_variables_to_be_integrated + utils.all_convolution_variable_names(astnode) %}{{ gsl_printer.print(update_expr)|replace("node.B_." + gap_junction_port + "_grid_sum_", "(node.B_." + gap_junction_port + "_grid_sum_ + __I_gap)") }}{% else %}0{% endif %}{% else %}{{ gsl_printer.print(update_expr) }}{% endif %}; {%- else %} - f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in utils.integrate_odes_args_strs_from_function_call(ast) + utils.all_convolution_variable_names(astnode) %}{{ gsl_printer.print(update_expr) }}{% else %}0{% endif %}{% else %}{{ gsl_printer.print(update_expr) }}{% endif %}; + f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in numeric_state_variables_to_be_integrated + utils.all_convolution_variable_names(astnode) %}{{ gsl_printer.print(update_expr) }}{% else %}0{% endif %}{% else %}{{ gsl_printer.print(update_expr) }}{% endif %}; {%- endif %} {%- endfor %} diff --git a/tests/nest_tests/resources/aeif_cond_alpha_alt_neuron.nestml b/tests/nest_tests/resources/aeif_cond_alpha_alt_neuron.nestml new file mode 100644 index 000000000..9121ef2b9 --- /dev/null +++ b/tests/nest_tests/resources/aeif_cond_alpha_alt_neuron.nestml @@ -0,0 +1,120 @@ +""" +aeif_cond_alpha - Conductance based exponential integrate-and-fire neuron model +############################################################################### + +Description ++++++++++++ + +aeif_psc_alpha is the adaptive exponential integrate and fire neuron according to Brette and Gerstner (2005), with post-synaptic conductances in the form of a bi-exponential ("alpha") function. + +The membrane potential is given by the following differential equation: + +.. math:: + + C_m \frac{dV_m}{dt} = + -g_L(V_m-E_L)+g_L\Delta_T\exp\left(\frac{V_m-V_{th}}{\Delta_T}\right) - + g_e(t)(V_m-E_e) \\ + -g_i(t)(V_m-E_i)-w + I_e + +and + +.. math:: + + \tau_w \frac{dw}{dt} = a(V_m-E_L) - w + +Note that the membrane potential can diverge to positive infinity due to the exponential term. To avoid numerical instabilities, instead of :math:`V_m`, the value :math:`\min(V_m,V_{peak})` is used in the dynamical equations. + + +References +++++++++++ + +.. [1] Brette R and Gerstner W (2005). Adaptive exponential + integrate-and-fire model as an effective description of neuronal + activity. Journal of Neurophysiology. 943637-3642 + DOI: https://doi.org/10.1152/jn.00686.2005 + + +See also +++++++++ + +iaf_psc_alpha, aeif_psc_exp +""" +model aeif_cond_alpha_alt_neuron: + + state: + V_m mV = E_L # Membrane potential + w pA = 0 pA # Spike-adaptation current + refr_t ms = 0 ms # Refractory period timer + g_exc nS = 0 nS # AHP conductance + g_exc' nS/ms = 0 nS/ms # AHP conductance + g_inh nS = 0 nS # AHP conductance + g_inh' nS/ms = 0 nS/ms # AHP conductance + + equations: + inline V_bounded mV = min(V_m, V_peak) # prevent exponential divergence + + g_exc'' = -2 * g_exc' / tau_syn_exc - g_exc / tau_syn_exc**2 + g_inh'' = -2 * g_inh' / tau_syn_inh - g_inh / tau_syn_inh**2 + + # Add inlines to simplify the equation definition of V_m + inline exp_arg real = (V_bounded - V_th) / Delta_T + inline I_spike pA = g_L * Delta_T * exp(exp_arg) + + V_m' = (-g_L * (V_bounded - E_L) + I_spike - g_exc * (V_bounded - E_exc) - g_inh * (V_bounded - E_inh) - w + I_e + I_stim) / C_m + w' = (a * (V_bounded - E_L) - w) / tau_w + + refr_t' = -1e3 * ms/s # refractoriness is implemented as an ODE, representing a timer counting back down to zero. XXX: TODO: This should simply read ``refr_t' = -1 / s`` (see https://github.com/nest/nestml/issues/984) + + parameters: + # membrane parameters + C_m pF = 281.0 pF # Membrane Capacitance + refr_T ms = 2 ms # Duration of refractory period + V_reset mV = -60.0 mV # Reset Potential + g_L nS = 30.0 nS # Leak Conductance + E_L mV = -70.6 mV # Leak reversal Potential (aka resting potential) + + # spike adaptation parameters + a nS = 4 nS # Subthreshold adaptation + b pA = 80.5 pA # Spike-triggered adaptation + Delta_T mV = 2.0 mV # Slope factor + tau_w ms = 144.0 ms # Adaptation time constant + V_th mV = -50.4 mV # Threshold Potential + V_peak mV = 0 mV # Spike detection threshold + + # synaptic parameters + tau_syn_exc ms = 0.2 ms # Synaptic Time Constant Excitatory Synapse + tau_syn_inh ms = 2.0 ms # Synaptic Time Constant for Inhibitory Synapse + E_exc mV = 0 mV # Excitatory reversal Potential + E_inh mV = -85.0 mV # Inhibitory reversal Potential + + # constant external input current + I_e pA = 0 pA + + input: + exc_spikes <- excitatory spike + inh_spikes <- inhibitory spike + I_stim pA <- continuous + + output: + spike + + update: + if refr_t > 0 ms: + # neuron is absolute refractory, do not evolve V_m + integrate_odes(g_exc, g_inh, w, refr_t) + else: + # neuron not refractory + integrate_odes(g_exc, g_inh, V_m, w) + + onReceive(exc_spikes): + g_exc' += exc_spikes * (e / tau_syn_exc) * nS * s + + onReceive(inh_spikes): + g_inh' += inh_spikes * (e / tau_syn_inh) * nS * s + + onCondition(refr_t <= 0 ms and V_m >= V_th): + # threshold crossing + refr_t = refr_T # start of the refractory period + V_m = V_reset + w += b + emit_spike() diff --git a/tests/nest_tests/test_integrate_odes.py b/tests/nest_tests/test_integrate_odes.py index ae4002062..dd880ad0e 100644 --- a/tests/nest_tests/test_integrate_odes.py +++ b/tests/nest_tests/test_integrate_odes.py @@ -33,9 +33,11 @@ try: import matplotlib + matplotlib.use("Agg") import matplotlib.ticker import matplotlib.pyplot as plt + TEST_PLOTS = True except Exception: TEST_PLOTS = False @@ -50,22 +52,24 @@ def setUp(self): r"""Generate the model code""" generate_nest_target(input_path=[os.path.realpath(os.path.join(os.path.dirname(__file__), - os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_exp_neuron.nestml"))), + os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_exp_neuron.nestml"))), os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test.nestml"))), os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_nonlinear_test.nestml"))), os.path.realpath(os.path.join(os.path.dirname(__file__), - os.path.join("resources", "alpha_function_2nd_order_ode_neuron.nestml"))),], + os.path.join("resources", "alpha_function_2nd_order_ode_neuron.nestml"))), + os.path.realpath(os.path.join(os.path.dirname(__file__), + os.path.join("resources", "aeif_cond_alpha_alt_neuron.nestml")))], logging_level="INFO", suffix="_nestml") def test_convolutions_always_integrated(self): r"""Test that synaptic integration continues for iaf_psc_exp, even when neuron is refractory.""" - sim_time: float = 100. # [ms] - resolution: float = .1 # [ms] - spike_interval = 5. # [ms] + sim_time: float = 100. # [ms] + resolution: float = .1 # [ms] + spike_interval = 5. # [ms] nest.set_verbosity("M_ALL") nest.ResetKernel() @@ -120,8 +124,8 @@ def test_convolutions_always_integrated(self): def test_integrate_odes(self): r"""Test the integrate_odes() function, in particular when not all the ODEs are being integrated.""" - sim_time: float = 100. # [ms] - resolution: float = .1 # [ms] + sim_time: float = 100. # [ms] + resolution: float = .1 # [ms] nest.set_verbosity("M_ALL") nest.ResetKernel() @@ -169,8 +173,8 @@ def test_integrate_odes(self): def test_integrate_odes_nonlinear(self): r"""Test the integrate_odes() function, in particular when not all the ODEs are being integrated, for nonlinear ODEs.""" - sim_time: float = 100. # [ms] - resolution: float = .1 # [ms] + sim_time: float = 100. # [ms] + resolution: float = .1 # [ms] nest.set_verbosity("M_ALL") nest.ResetKernel() @@ -205,7 +209,6 @@ def test_integrate_odes_nonlinear(self): for _ax in ax: _ax.grid(which="major", axis="both") _ax.grid(which="minor", axis="x", linestyle=":", alpha=.4) - # _ax.minorticks_on() _ax.set_xlim(0., sim_time) _ax.legend() @@ -218,7 +221,8 @@ def test_integrate_odes_nonlinear(self): def test_integrate_odes_params(self): r"""Test the integrate_odes() function, in particular with respect to the parameter types.""" - fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml"))) + fname = os.path.realpath( + os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml"))) generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2 @@ -226,7 +230,8 @@ def test_integrate_odes_params(self): def test_integrate_odes_params2(self): r"""Test the integrate_odes() function, in particular with respect to non-existent parameter variables.""" - fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params2.nestml"))) + fname = os.path.realpath( + os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params2.nestml"))) generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2 @@ -276,3 +281,59 @@ def test_integrate_odes_higher_order(self): # verify np.testing.assert_allclose(x_actual[-1], 0.10737970490959549) np.testing.assert_allclose(y_actual[-1], 0.6211608596446752) + + def test_integrate_odes_numeric_higher_order(self): + r""" + Tests for higher-order ODEs of the form F(x'',x',x)=0, integrate_odes(x) integrates the full dynamics of x with a numeric solver. + """ + resolution = 0.1 + simtime = 800. + params_nestml = {"V_peak": 0.0, "a": 4.0, "b": 80.5, "E_L": -70.6, + "g_L": 300.0, 'E_exc': 20.0, 'E_inh': -85.0, + 'tau_syn_exc': 40.0, 'tau_syn_inh': 20.0} + + params_nest = {"V_peak": 0.0, "a": 4.0, "b": 80.5, "E_L": -70.6, + "g_L": 300.0, 'E_ex': 20.0, 'E_in': -85.0, + 'tau_syn_ex': 40.0, 'tau_syn_in': 20.0} + + for model in ["aeif_cond_alpha_alt_neuron_nestml", "aeif_cond_alpha"]: + nest.set_verbosity("M_ALL") + nest.ResetKernel() + nest.SetKernelStatus({"resolution": resolution}) + try: + nest.Install("nestmlmodule") + except Exception: + # ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions + pass + n = nest.Create(model) + if "_nestml" in model: + nest.SetStatus(n, params_nestml) + else: + nest.SetStatus(n, params_nest) + + spike = nest.Create("spike_generator") + spike_times = [10.0, 400.0] + nest.SetStatus(spike, {"spike_times": spike_times}) + nest.Connect(spike, n, syn_spec={"weight": 0.1, "delay": 1.0}) + nest.Connect(spike, n, syn_spec={"weight": -0.2, "delay": 100.}) + + mm = nest.Create("multimeter", params={"record_from": ["V_m"]}) + nest.Connect(mm, n) + + nest.Simulate(simtime) + times = mm.get()["events"]["times"] + if "_nestml" in model: + v_m_nestml = mm.get()["events"]["V_m"] + else: + v_m_nest = mm.get()["events"]["V_m"] + + if TEST_PLOTS: + fig, ax = plt.subplots(nrows=1) + + ax.plot(times, v_m_nestml, label="NESTML") + ax.plot(times, v_m_nest, label="NEST") + ax.legend() + + fig.savefig("/tmp/test_integrate_odes_numeric_higher_order.png") + + np.testing.assert_allclose(v_m_nestml, v_m_nest)