Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate the entire dynamics for higher-order ODEs #1139

Merged
merged 14 commits into from
Nov 27, 2024
16 changes: 16 additions & 0 deletions doc/nestml_language/nestml_language_concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,22 @@ Integrating the ODEs needs to be triggered explicitly inside the ``update`` bloc

The ``integrate_odes()`` function numerically integrates the differential equations defined in the ``equations`` block. Integrating the ODEs from one timestep to the next has to be explicitly carried out in the model by calling the ``integrate_odes()`` function. If no parameters are given, all ODEs in the model are integrated. Integration can be limited to a given set of ODEs by giving their left-hand side state variables as parameters to the function, for example ``integrate_odes(V_m, I_ahp)`` if ODEs exist for the variables ``V_m`` and ``I_ahp``. In this example, these variables are integrated simultaneously (as one single system of equations). This is different from calling ``integrate_odes(V_m)`` and then ``integrate_odes(I_ahp)`` in that the second call would use the already-updated values from the first call. Variables not included in the call to ``integrate_odes()`` are assumed to remain constant (both inside the numeric solver stepping function as well as from before to after the call).

In case of higher-order ODEs of the form ``F(x'', x', x) = 0``, the solution ``x(t)`` is obtained by just providing the variable ``x`` to the ``integrate_odes`` function. For example,

.. code-block:: nestml

state:
x real = 0
x' ms**-1 = 0 * ms**-1

equations:
x'' = - 2 * x' / ms - x / ms**2

update:
integrate_odes(x)

Here, ``integrate_odes(x)`` integrates the entire dynamics of ``x(t)``, in this case, ``x`` and ``x'``.

Note that the dynamical equations that correspond to convolutions are always updated, regardless of whether ``integrate_odes()`` is called. The state variables affected by incoming events are updated at the end of each timestep, that is, within one timestep, the state as observed by statements in the ``update`` block will be those at :math:`t^-`, i.e. "just before" it has been updated due to the events. See also :ref:`Integrating spiking input` and :ref:`Integration order`.

ODEs that can be solved analytically are integrated to machine precision from one timestep to the next using the propagators obtained from `ODE-toolbox <https://ode-toolbox.readthedocs.io/>`_. In case a numerical solver is used (such as Runge-Kutta or forward Euler), the same ODEs are also evaluated numerically by the numerical solver to allow more precise values for analytically solvable ODEs *within* a timestep. In this way, the long-term dynamics obeys the analytic (more exact) equations, while the short-term (within one timestep) dynamics is evaluated to the precision of the numerical integrator.
Expand Down
3 changes: 3 additions & 0 deletions pynestml/cocos/co_co_integrate_odes_params_correct.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,6 @@ def visit_function_call(self, node):
if symbol_var is None or not symbol_var.is_state():
code, message = Messages.get_integrate_odes_wrong_arg(str(arg))
Logger.log_message(code=code, message=message, error_position=node.get_source_position(), log_level=LoggingLevel.ERROR)
elif symbol_var.is_state() and arg.get_variable().get_differential_order() > 0:
code, message = Messages.get_integrate_odes_arg_higher_order(str(arg))
Logger.log_message(code=code, message=message, error_position=node.get_source_position(), log_level=LoggingLevel.ERROR)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#}
{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
{%- if uses_analytic_solver %}
{%- for variable_name in analytic_state_variables_: %}
{%- for variable_name in analytic_state_variables_ %}
{%- set update_expr = update_expressions[variable_name] %}
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
Expand Down
14 changes: 14 additions & 0 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,21 @@ def filter_variables_list(cls, variables_list, variables_to_filter_by):
for var in variables_list:
if var in variables_to_filter_by:
ret.append(var)
# Add higher order variables of var if not already in the filter list
ret.extend(cls.get_higher_order_variables(var, variables_list, variables_to_filter_by))
return ret

@classmethod
def get_higher_order_variables(cls, var, variables_list, variables_to_filter_by) -> List[str]:
"""
Returns a list of higher order state variables of ``var`` from the ``variables_list`` that are not already present in ``variables_to_filter_by``.
"""
ret = []
for v in variables_list:
order = v.count('__d')
if order > 0:
if v.split("__d")[0] == var and v not in variables_to_filter_by:
ret.append(v)
return ret

@classmethod
Expand Down
6 changes: 6 additions & 0 deletions pynestml/utils/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class MessageCode(Enum):
EXPONENT_MUST_BE_INTEGER = 114
EMIT_SPIKE_OUTPUT_PORT_TYPE_DIFFERS = 115
CONTINUOUS_OUTPUT_PORT_MAY_NOT_HAVE_ATTRIBUTES = 116
INTEGRATE_ODES_ARG_HIGHER_ORDER = 117


class Messages:
Expand Down Expand Up @@ -1300,6 +1301,11 @@ def get_integrate_odes_wrong_arg(cls, arg: str) -> Tuple[MessageCode, str]:
message = "Parameter provided to integrate_odes() function is not a state variable: '" + arg + "'"
return MessageCode.INTEGRATE_ODES_WRONG_ARG, message

@classmethod
def get_integrate_odes_arg_higher_order(cls, arg: str) -> Tuple[MessageCode, str]:
message = "Parameter provided to integrate_odes() function is a state variable of higher order: '" + arg + "'"
return MessageCode.INTEGRATE_ODES_ARG_HIGHER_ORDER, message

@classmethod
def get_mechs_dictionary_info(cls, chan_info, syns_info, conc_info, con_in_info) -> Tuple[MessageCode, str]:
message = ""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
alpha_function_2nd_order_ode_neuron.nestml
##########################################

Tests that for a system of higher-oder ODEs of the form F(x'',x',x)=0, integrate_odes(x) includes the integration of all the higher-order variables involved of the system.

Copyright statement
+++++++++++++++++++

This file is part of NEST.

Copyright (C) 2004 The NEST Initiative

NEST is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 2 of the License, or
(at your option) any later version.

NEST is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with NEST. If not, see <http://www.gnu.org/licenses/>.
"""
model alpha_function_2nd_order_ode_neuron:
state:
x real = 0
x' ms**-1 = 0 * ms**-1
y real = 0

input:
fX <- spike

equations:
x'' = - 2 * x' / ms - x / ms**2
y' = (-y + 42) / s

update:
integrate_odes(x, y)

onReceive(fX):
x' += e*fX * s / ms
54 changes: 49 additions & 5 deletions tests/nest_tests/test_integrate_odes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def setUp(self):
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "integrate_odes_test.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "integrate_odes_nonlinear_test.nestml")))],
os.path.join("resources", "integrate_odes_nonlinear_test.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "alpha_function_2nd_order_ode_neuron.nestml"))),],
logging_level="INFO",
suffix="_nestml")

Expand Down Expand Up @@ -131,11 +133,9 @@ def test_integrate_odes(self):
pass

# create the network
spikedet = nest.Create("spike_recorder")
neuron = nest.Create("integrate_odes_test_nestml")
mm = nest.Create("multimeter", params={"record_from": ["test_1", "test_2"]})
nest.Connect(mm, neuron)
nest.Connect(neuron, spikedet)

# simulate
nest.Simulate(sim_time)
Expand Down Expand Up @@ -182,11 +182,9 @@ def test_integrate_odes_nonlinear(self):
pass

# create the network
spikedet = nest.Create("spike_recorder")
neuron = nest.Create("integrate_odes_nonlinear_test_nestml")
mm = nest.Create("multimeter", params={"record_from": ["test_1", "test_2"]})
nest.Connect(mm, neuron)
nest.Connect(neuron, spikedet)

# simulate
nest.Simulate(sim_time)
Expand Down Expand Up @@ -232,3 +230,49 @@ def test_integrate_odes_params2(self):
generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")

assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2

def test_integrate_odes_higher_order(self):
r"""
Tests for higher-order ODEs of the form F(x'',x',x)=0, integrate_odes(x) integrates the full dynamics of x.
"""
resolution = 0.1
simtime = 15.
nest.set_verbosity("M_ALL")
nest.ResetKernel()
nest.SetKernelStatus({"resolution": resolution})
try:
nest.Install("nestmlmodule")
except Exception:
# ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions
pass

n = nest.Create("alpha_function_2nd_order_ode_neuron_nestml")
sgX = nest.Create("spike_generator", params={"spike_times": [10.]})
nest.Connect(sgX, n, syn_spec={"weight": 1., "delay": resolution})

mm = nest.Create("multimeter", params={"interval": resolution, "record_from": ["x", "y"]})
nest.Connect(mm, n)

nest.Simulate(simtime)
times = mm.get()["events"]["times"]
x_actual = mm.get()["events"]["x"]
y_actual = mm.get()["events"]["y"]

if TEST_PLOTS:
fig, ax = plt.subplots(nrows=2)
ax1, ax2 = ax

ax2.plot(times, x_actual, label="x")
ax1.plot(times, y_actual, label="y")

for _ax in ax:
_ax.grid(which="major", axis="both")
_ax.grid(which="minor", axis="x", linestyle=":", alpha=.4)
_ax.set_xlim(0., simtime)
_ax.legend()

fig.savefig("/tmp/test_integrate_odes_higher_order.png", dpi=300)

# verify
np.testing.assert_allclose(x_actual[-1], 0.10737970490959549)
np.testing.assert_allclose(y_actual[-1], 0.6211608596446752)