Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate the entire dynamics for higher-order ODEs #1139

Merged
merged 14 commits into from
Nov 27, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#}
{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
{%- if uses_analytic_solver %}
{%- for variable_name in analytic_state_variables_: %}
{%- for variable_name in analytic_state_variables_ %}
{%- set update_expr = update_expressions[variable_name] %}
{%- set var_ast = utils.get_variable_by_name(astnode, variable_name)%}
{%- set var_symbol = var_ast.get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE)%}
Expand Down
14 changes: 14 additions & 0 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,21 @@ def filter_variables_list(cls, variables_list, variables_to_filter_by):
for var in variables_list:
if var in variables_to_filter_by:
ret.append(var)
# Add higher order variables of var if not already in the filter list
ret.extend(cls.get_higher_order_variables(var, variables_list, variables_to_filter_by))
return ret

@classmethod
def get_higher_order_variables(cls, var, variables_list, variables_to_filter_by) -> List[str]:
"""
Returns a list of higher order state variables of ``var`` from the ``variables_list`` that are not already present in ``variables_to_filter_by``.
"""
ret = []
for v in variables_list:
order = v.count('__d')
if order > 0:
if v.split("__d")[0] == var and v not in variables_to_filter_by:
ret.append(v)
return ret

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
alpha_function_2nd_order_ode_neuron.nestml
##########################################

Tests that for a system of higher-oder ODEs of the form F(x'',x',x)=0, integrate_odes(x) includes the integration of all the higher-order variables involved of the system.

Copyright statement
+++++++++++++++++++

This file is part of NEST.

Copyright (C) 2004 The NEST Initiative

NEST is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 2 of the License, or
(at your option) any later version.

NEST is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with NEST. If not, see <http://www.gnu.org/licenses/>.
"""
model alpha_function_2nd_order_ode_neuron:
state:
x real = 0
x' ms**-1 = 0 * ms**-1
y real = 0

input:
fX <- spike

equations:
x'' = - 2 * x' / ms - x / ms**2
y' = (-y + 42) / s

update:
integrate_odes(x, y)

onReceive(fX):
x' += e*fX * s / ms
55 changes: 50 additions & 5 deletions tests/nest_tests/test_integrate_odes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def setUp(self):
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "integrate_odes_test.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "integrate_odes_nonlinear_test.nestml")))],
os.path.join("resources", "integrate_odes_nonlinear_test.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "alpha_function_2nd_order_ode_neuron.nestml"))),],
logging_level="INFO",
suffix="_nestml")

Expand Down Expand Up @@ -131,11 +133,9 @@ def test_integrate_odes(self):
pass

# create the network
spikedet = nest.Create("spike_recorder")
neuron = nest.Create("integrate_odes_test_nestml")
mm = nest.Create("multimeter", params={"record_from": ["test_1", "test_2"]})
nest.Connect(mm, neuron)
nest.Connect(neuron, spikedet)

# simulate
nest.Simulate(sim_time)
Expand Down Expand Up @@ -182,11 +182,9 @@ def test_integrate_odes_nonlinear(self):
pass

# create the network
spikedet = nest.Create("spike_recorder")
neuron = nest.Create("integrate_odes_nonlinear_test_nestml")
mm = nest.Create("multimeter", params={"record_from": ["test_1", "test_2"]})
nest.Connect(mm, neuron)
nest.Connect(neuron, spikedet)

# simulate
nest.Simulate(sim_time)
Expand Down Expand Up @@ -232,3 +230,50 @@ def test_integrate_odes_params2(self):
generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")

assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2

def test_integrate_odes_higher_order(self):
r"""
Tests for higher-order ODEs of the form F(x'',x',x)=0, integrate_odes(x) integrates the full dynamics of x.
"""
resolution = 0.1
simtime = 15.
nest.set_verbosity("M_ALL")
nest.ResetKernel()
nest.SetKernelStatus({"resolution": resolution})
try:
nest.Install("nestmlmodule")
except Exception:
# ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions
pass

n = nest.Create("alpha_function_2nd_order_ode_neuron_nestml")
sgX = nest.Create("spike_generator", params={"spike_times": [10.]})
nest.Connect(sgX, n, syn_spec={"weight": 1., "delay": resolution})

mm = nest.Create("multimeter", params={"interval": resolution, "record_from": ["x", "y"]})
nest.Connect(mm, n)

nest.Simulate(simtime)
x = mm.get()["events"]["x"]
y = mm.get()["events"]["y"]
times = mm.get()["events"]["times"]
print(x[-1], y[-1])

if TEST_PLOTS:
fig, ax = plt.subplots(nrows=2)
ax1, ax2 = ax

ax2.plot(times, x, label="x")
ax1.plot(times, y, label="y", alpha=.7, linestyle=":")

for _ax in ax:
_ax.grid(which="major", axis="both")
_ax.grid(which="minor", axis="x", linestyle=":", alpha=.4)
_ax.set_xlim(0., simtime)
_ax.legend()

fig.savefig("/tmp/test_integrate_odes_higher_order.png", dpi=300)

# verify
np.testing.assert_allclose(x[-1], 0.10737970490959549)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we know these values are correct (are they from just running the NESTML model once, or from an independent implementation of the ODEs)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These values are from repeated simulations of the NESTML-generated code, equivalent to providing integrate_odes(x',x) in the NESTML model. I have changed the test now to include calls to both integrate_odes and compare their output.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the latest changes, integrate_odes(x, x') is not supported and hence I changed the test again to compare it against numerical values. Do you have any other ideas to check this?

np.testing.assert_allclose(y[-1], 0.6211608596446752)