diff --git a/doc/fig/performance_ratio_nonVec_vs_vec_compartmental.png b/doc/fig/performance_ratio_nonVec_vs_vec_compartmental.png new file mode 100644 index 000000000..7d53db935 Binary files /dev/null and b/doc/fig/performance_ratio_nonVec_vs_vec_compartmental.png differ diff --git a/doc/running/running_nest_compartmental.rst b/doc/running/running_nest_compartmental.rst index 712fe263e..ddddc8603 100644 --- a/doc/running/running_nest_compartmental.rst +++ b/doc/running/running_nest_compartmental.rst @@ -79,7 +79,7 @@ As an example for a HH-type channel: All of the currents within a compartment (marked by ``@mechanism::channel``) are added up within a compartment. -For a complete example, please see `cm_default.nestml `_ and its associated unit test, `compartmental_model_test.py `_. +For a complete example, please see `cm_default.nestml `_ and its associated unit test, `test__compartmental_model.py `_. Concentration description @@ -111,12 +111,14 @@ As an example a description of a calcium concentration model where we pretend th The only difference here is that the equation that is marked with the ``@mechanism::concentration`` descriptor is not an inline equation but an ODE. This is because in case of the ion-channel what we want to simulate is the current which relies on the evolution of some state variables like gating variables in case of the HH-models, and the compartment voltage. The concentration though can be more simply described by an evolving state directly. -For a complete example, please see `concmech.nestml `_ and its associated unit test, `compartmental_model_test.py `_. +For a complete example, please see `concmech.nestml `_ and its associated unit test, `test__concmech_model.py `_. Synapse description ------------------- -Here synapse models are based on convolutions over a buffer of incoming spikes. 
This means that the equation for the current-contribution must contain a convolve() call and a description of the kernel used for that convolution is needed. The descriptor for synapses is ``@mechanism::receptor``. +Here synapse models are based on convolutions over a buffer of incoming spikes. This means that the equation for the +current-contribution must contain a convolve() call and a description of the kernel used for that convolution is needed. +The descriptor for synapses is ``@mechanism::receptor``. .. code-block:: nestml @@ -133,13 +135,64 @@ Here synapse models are based on convolutions over a buffer of incoming spikes. input: <- spike -For a complete example, please see `concmech.nestml `_ and its associated unit test, `compartmental_model_test.py `_. +For a complete example, please see `concmech.nestml `_ and its associated unit test, `test__concmech_model.py `_. + +Continuous input description +---------------------------- + +The continuous inputs are defined by an inline with the descriptor @mechanism::continuous_input. This inline needs to +include one input of type continuous and may include any states, parameters and functions. + +.. code-block:: nestml + + model : + equations: + inline real = \ + > \ + @mechanism::continuous_input + + input: + real <- continuous + +For a complete example, please see `continuous_test.nestml `_ and its associated unit test, `test__continuous_input.py `_. Mechanism interdependence ------------------------- Above examples of explicit interdependence inbetween concentration and channel models where already described. Note that it is not necessary to describe the basic interaction inherent through the contribution to the overall current of the compartment. During a simulation step all currents of channels and synapses are added up and contribute to the change of the membrane potential (v_comp) in the next timestep. 
Thereby one must only express a dependence explicitly if the mechanism depends on the activity of a specific channel- or synapse-type amongst multiple in a given compartment or some concentration. +General compartment scripting +----------------------------- +Update block: +~~~~~~~~~~~~~ +Even though the intended focus of the compartmental feature is the usage of the above mechanisms whose influence on the compartment and overall neurons is implicit, it is still possible to use the update block. The update block does not control the overall neurons behaviour but its computation may support the behaviour of mechanisms within a compartment. All states occuring inside of the update block or within the called functions and inlines etc. become part of a generall computation block that is always executed once for each compartment. These states may be used by the mechanisms. The ODEs owned by the update block are still integrated automatically and integrate_odes() shall not be used in this context. + +OnReceive(self_spikes): +~~~~~~~~~~~~~~~~~~~~~~~ +We also introduce a new special case of the OnReceive block with the CM feature specific variable self_spikes which is just a boolean that is true if and only if the neuron has spiked in the last timestep. The code within this block is only executed if the neuron spikes (somatic). Otherwise, the same rules as for the update block apply. + +Application: +~~~~~~~~~~~~ +This feature has been implemented with the implementation of IAF behaviour or backpropagation in mind. For examples see these model files: +`cm_iaf_psc_exp_dend_neuron.nestml `_ + +Technical Notes +--------------- + +We have put an emphasis on delivering good performance for neurons with high spatial complexity. We utilize vectorization, therefore, you should compile NEST with the OpenMP flag enabled. This, of course, can only be utilized if your hardware supports SIMD instructions. 
In that case, you can expect a performance improvement of about 3/4th of the theoretical maximum. + +Let's say you have an AVX2 SIMD instruction set available, which can fit 4 doubles (4*64-bit) into its vector register. In this case you can expect about a 3x performance improvement as long as your neuron has enough compartments. We vectorize the simulation steps of all instances of the same mechanism you have defined in your NESTML model, meaning that you will get a better complexity/performance ratio the more instances of the same mechanism are used. + +Here is a small benchmark example that shows the performance ratio (y-axis) as the number of compartments per neuron (x-axis) increases. + +.. figure:: https://raw.githubusercontent.com/nest/nestml/master/doc/fig/performance_ratio_nonVec_vs_vec_compartmental.png + :width: 326px + :height: 203px + :align: left + :target: # + +Be aware that we are using the -ffast-math flag when compiling the model by default. This can potentially lead to precision problems and inconsistencies across different systems. If you encounter unexpected results or want to be on the safe side, you can disable this by removing the flag from the CMakeLists.txt, which is part of the generated code. Note, however, that this may inhibit the compiler's ability to vectorize parts of the code in some cases. 
See also -------- diff --git a/extras/convert_cm_default_to_template.py b/extras/convert_cm_default_to_template.py index 92873d628..821955c0d 100644 --- a/extras/convert_cm_default_to_template.py +++ b/extras/convert_cm_default_to_template.py @@ -37,7 +37,7 @@ def get_replacement_patterns(): # file names 'cm_default' : '{{neuronSpecificFileNamesCmSyns[\"main\"]}}', 'cm_tree' : '{{neuronSpecificFileNamesCmSyns[\"tree\"]}}', - 'cm_compartmentcurrents': '{{neuronSpecificFileNamesCmSyns[\"compartmentcurrents\"]}}', + 'cm_neuroncurrents': '{{neuronSpecificFileNamesCmSyns[\"neuroncurrents\"]}}', # class names 'CompTree' : 'CompTree{{cm_unique_suffix}}', 'Compartment' : 'Compartment{{cm_unique_suffix}}', diff --git a/pynestml/cocos/co_co_cm_channel_model.py b/pynestml/cocos/co_co_cm_channel_model.py index bc556d9b2..968c09e64 100644 --- a/pynestml/cocos/co_co_cm_channel_model.py +++ b/pynestml/cocos/co_co_cm_channel_model.py @@ -26,10 +26,10 @@ class CoCoCmChannelModel(CoCo): @classmethod - def check_co_co(cls, model: ASTModel): + def check_co_co(cls, model: ASTModel, global_info): """ Checks if this compartmental condition applies to the handed over neuron. If yes, it checks the presence of expected functions and declarations. :param model: a single neuron instance. """ - return ChannelProcessing.check_co_co(model) + return ChannelProcessing.check_co_co(model, global_info) diff --git a/pynestml/cocos/co_co_cm_concentration_model.py b/pynestml/cocos/co_co_cm_concentration_model.py index 88eeea042..fce1cf9dd 100644 --- a/pynestml/cocos/co_co_cm_concentration_model.py +++ b/pynestml/cocos/co_co_cm_concentration_model.py @@ -27,10 +27,10 @@ class CoCoCmConcentrationModel(CoCo): @classmethod - def check_co_co(cls, model: ASTModel): + def check_co_co(cls, model: ASTModel, global_info): """ Check if this compartmental condition applies to the handed over neuron. If yes, it checks the presence of expected functions and declarations. :param model: a single neuron instance. 
""" - return ConcentrationProcessing.check_co_co(model) + return ConcentrationProcessing.check_co_co(model, global_info) diff --git a/pynestml/cocos/co_co_cm_continuous_input_model.py b/pynestml/cocos/co_co_cm_continuous_input_model.py new file mode 100644 index 000000000..38cd47d6e --- /dev/null +++ b/pynestml/cocos/co_co_cm_continuous_input_model.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# +# co_co_cm_continuous_input_model.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from pynestml.cocos.co_co import CoCo +from pynestml.meta_model.ast_model import ASTModel +from pynestml.utils.continuous_input_processing import ContinuousInputProcessing + + +class CoCoCmContinuousInputModel(CoCo): + @classmethod + def check_co_co(cls, neuron: ASTModel, global_info): + """ + Checks if this compartmental condition applies to the handed over neuron. + If yes, it checks the presence of expected functions and declarations. + :param neuron: a single neuron instance. + :type neuron: ast_neuron + """ + return ContinuousInputProcessing.check_co_co(neuron, global_info) diff --git a/pynestml/cocos/co_co_cm_global.py b/pynestml/cocos/co_co_cm_global.py new file mode 100644 index 000000000..a48e15364 --- /dev/null +++ b/pynestml/cocos/co_co_cm_global.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# +# co_co_cm_global.py +# +# This file is part of NEST. 
+# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from pynestml.cocos.co_co import CoCo +from pynestml.meta_model.ast_model import ASTModel +from pynestml.utils.global_processing import GlobalProcessing + + +class CoCoCmGlobal(CoCo): + @classmethod + def check_co_co(cls, neuron: ASTModel): + """ + Checks if this compartmental condition applies to the handed over neuron. + If yes, it checks the presence of expected functions and declarations. + :param neuron: a single neuron instance. + :type neuron: ast_neuron + """ + return GlobalProcessing.check_co_co(neuron) diff --git a/pynestml/cocos/co_co_cm_receptor_model.py b/pynestml/cocos/co_co_cm_receptor_model.py new file mode 100644 index 000000000..d7a0afc56 --- /dev/null +++ b/pynestml/cocos/co_co_cm_receptor_model.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# +# co_co_cm_receptor_model.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from pynestml.cocos.co_co import CoCo +from pynestml.meta_model.ast_model import ASTModel +from pynestml.utils.receptor_processing import ReceptorProcessing + + +class CoCoCmReceptorModel(CoCo): + + @classmethod + def check_co_co(cls, model: ASTModel, global_info): + """ + Checks if this compartmental condition applies to the handed over neuron. + If yes, it checks the presence of expected functions and declarations. + :param model: a single neuron instance. + """ + return ReceptorProcessing.check_co_co(model, global_info) diff --git a/pynestml/cocos/co_co_cm_synapse_model.py b/pynestml/cocos/co_co_cm_synapse_model.py index 5359e15cf..ff6da19d6 100644 --- a/pynestml/cocos/co_co_cm_synapse_model.py +++ b/pynestml/cocos/co_co_cm_synapse_model.py @@ -31,6 +31,6 @@ def check_co_co(cls, model: ASTModel): """ Checks if this compartmental condition applies to the handed over neuron. If yes, it checks the presence of expected functions and declarations. - :param model: a single neuron instance. + :param model: a single synapse instance. """ - return SynapseProcessing.check_co_co(model) + return SynapseProcessing.check_co_co(model) \ No newline at end of file diff --git a/pynestml/cocos/co_co_v_comp_exists.py b/pynestml/cocos/co_co_v_comp_exists.py index 4ef08c0ec..efbf4e364 100644 --- a/pynestml/cocos/co_co_v_comp_exists.py +++ b/pynestml/cocos/co_co_v_comp_exists.py @@ -26,6 +26,8 @@ from pynestml.utils.messages import Messages from pynestml.utils.logger import Logger, LoggingLevel +import traceback + class CoCoVCompDefined(CoCo): """ @@ -48,10 +50,14 @@ def check_co_co(cls, neuron: ASTModel): If True, checks are not as rigorous. Use False where possible. 
""" from pynestml.codegeneration.nest_compartmental_code_generator import NESTCompartmentalCodeGenerator + from pynestml.codegeneration.code_generator_utils import CodeGeneratorUtils if not FrontendConfiguration.get_target_platform().upper() == 'NEST_COMPARTMENTAL': return + if not isinstance(neuron, ASTModel): + return + enforced_variable_name = NESTCompartmentalCodeGenerator._default_options["compartmental_variable_name"] state_blocks = neuron.get_state_blocks() @@ -62,6 +68,18 @@ def check_co_co(cls, neuron: ASTModel): if isinstance(state_blocks, ASTBlockWithVariables): state_blocks = [state_blocks] + stack = traceback.extract_stack() + formatted = ''.join(traceback.format_list(stack)) + print("Traceback: \n" + formatted) + print("Neuron name: " + neuron.name) + print("State names: ") + for state_block in state_blocks: + declarations = state_block.get_declarations() + for declaration in declarations: + variables = declaration.get_variables() + for variable in variables: + variable_name = variable.get_name().lower().strip() + print(variable_name) for state_block in state_blocks: declarations = state_block.get_declarations() for declaration in declarations: @@ -77,4 +95,5 @@ def check_co_co(cls, neuron: ASTModel): @classmethod def log_error(cls, neuron: ASTModel, error_position, missing_variable_name): code, message = Messages.get_v_comp_variable_value_missing(neuron.get_name(), missing_variable_name) - Logger.log_message(error_position=error_position, node=neuron, log_level=LoggingLevel.ERROR, code=code, message=message) + Logger.log_message(error_position=error_position, node=neuron, log_level=LoggingLevel.ERROR, code=code, + message=message) diff --git a/pynestml/cocos/co_cos_manager.py b/pynestml/cocos/co_cos_manager.py index 83900539b..18250d7d0 100644 --- a/pynestml/cocos/co_cos_manager.py +++ b/pynestml/cocos/co_cos_manager.py @@ -22,9 +22,12 @@ from typing import Union from pynestml.cocos.co_co_all_variables_defined import CoCoAllVariablesDefined +from 
pynestml.cocos.co_co_cm_global import CoCoCmGlobal +from pynestml.cocos.co_co_cm_synapse_model import CoCoCmSynapseModel from pynestml.cocos.co_co_inline_expression_not_assigned_to import CoCoInlineExpressionNotAssignedTo from pynestml.cocos.co_co_input_port_not_assigned_to import CoCoInputPortNotAssignedTo from pynestml.cocos.co_co_cm_channel_model import CoCoCmChannelModel +from pynestml.cocos.co_co_cm_continuous_input_model import CoCoCmContinuousInputModel from pynestml.cocos.co_co_convolve_cond_correctly_built import CoCoConvolveCondCorrectlyBuilt from pynestml.cocos.co_co_convolve_has_correct_parameter import CoCoConvolveHasCorrectParameter from pynestml.cocos.co_co_input_port_not_assigned_to import CoCoInputPortNotAssignedTo @@ -54,7 +57,7 @@ from pynestml.cocos.co_co_simple_delta_function import CoCoSimpleDeltaFunction from pynestml.cocos.co_co_state_variables_initialized import CoCoStateVariablesInitialized from pynestml.cocos.co_co_convolve_has_correct_parameter import CoCoConvolveHasCorrectParameter -from pynestml.cocos.co_co_cm_synapse_model import CoCoCmSynapseModel +from pynestml.cocos.co_co_cm_receptor_model import CoCoCmReceptorModel from pynestml.cocos.co_co_cm_concentration_model import CoCoCmConcentrationModel from pynestml.cocos.co_co_input_port_qualifier_unique import CoCoInputPortQualifierUnique from pynestml.cocos.co_co_user_defined_function_correctly_defined import CoCoUserDefinedFunctionCorrectlyDefined @@ -68,6 +71,7 @@ from pynestml.cocos.co_co_priorities_correctly_specified import CoCoPrioritiesCorrectlySpecified from pynestml.meta_model.ast_model import ASTModel from pynestml.frontend.frontend_configuration import FrontendConfiguration +from pynestml.utils.global_processing import GlobalProcessing class CoCosManager: @@ -138,16 +142,23 @@ def check_v_comp_requirement(cls, neuron: ASTModel): CoCoVCompDefined.check_co_co(neuron) @classmethod - def check_compartmental_model(cls, neuron: ASTModel) -> None: + def 
check_compartmental_neuron_model(cls, neuron: ASTModel) -> None: """ collects all relevant information for the different compartmental mechanism classes for later code-generation searches for inlines or odes with decorator @mechanism:: and performs a base and, depending on type, specific information collection process. See nestml documentation on compartmental code generation. """ - CoCoCmChannelModel.check_co_co(neuron) - CoCoCmConcentrationModel.check_co_co(neuron) - CoCoCmSynapseModel.check_co_co(neuron) + CoCoCmGlobal.check_co_co(neuron) + global_info = GlobalProcessing.get_global_info(neuron) + CoCoCmChannelModel.check_co_co(neuron, global_info) + CoCoCmConcentrationModel.check_co_co(neuron, global_info) + CoCoCmReceptorModel.check_co_co(neuron, global_info) + CoCoCmContinuousInputModel.check_co_co(neuron, global_info) + + @classmethod + def check_compartmental_synapse_model(cls, synapse: ASTModel) -> None: + CoCoCmSynapseModel.check_co_co(synapse) @classmethod def check_inline_expressions_have_rhs(cls, model: ASTModel): @@ -400,7 +411,7 @@ def check_input_port_size_type(cls, model: ASTModel): CoCoVectorInputPortsCorrectSizeType.check_co_co(model) @classmethod - def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bool = False): + def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bool = False, syn_model: bool = False): """ Checks all context conditions. :param model: a single model object. @@ -413,8 +424,11 @@ def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bo cls.check_variables_defined_before_usage(model, after_ast_rewrite) if FrontendConfiguration.get_target_platform().upper() == 'NEST_COMPARTMENTAL': # XXX: TODO: refactor this out; define a ``cocos_from_target_name()`` in the frontend instead. 
- cls.check_v_comp_requirement(model) - cls.check_compartmental_model(model) + # cls.check_v_comp_requirement(model) + if syn_model: + cls.check_compartmental_synapse_model(model) + else: + cls.check_compartmental_neuron_model(model) cls.check_inline_expressions_have_rhs(model) cls.check_inline_has_max_one_lhs(model) cls.check_input_ports_not_assigned_to(model) diff --git a/pynestml/codegeneration/builder.py b/pynestml/codegeneration/builder.py index a5e3c566c..864e9a6bb 100644 --- a/pynestml/codegeneration/builder.py +++ b/pynestml/codegeneration/builder.py @@ -91,4 +91,5 @@ def set_options(self, options: Mapping[str, Any]) -> Mapping[str, Any]: ret = super().set_options(options) ret.pop("redirect_build_output", None) ret.pop("build_output_dir", None) + ret.pop("fastexp", None) return ret diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py index 0471d0041..ea06b698f 100644 --- a/pynestml/codegeneration/nest_code_generator.py +++ b/pynestml/codegeneration/nest_code_generator.py @@ -612,7 +612,6 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: :return: a map from name to functionality. """ namespace = self._get_model_namespace(neuron) - if "paired_synapse" in dir(neuron): namespace["extra_on_emit_spike_stmts_from_synapse"] = neuron.extra_on_emit_spike_stmts_from_synapse namespace["paired_synapse"] = neuron.paired_synapse.get_name() diff --git a/pynestml/codegeneration/nest_compartmental_code_generator.py b/pynestml/codegeneration/nest_compartmental_code_generator.py index 2e0fc37b6..a3ecef4af 100644 --- a/pynestml/codegeneration/nest_compartmental_code_generator.py +++ b/pynestml/codegeneration/nest_compartmental_code_generator.py @@ -18,8 +18,9 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . 
+import shutil -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union import datetime import os @@ -27,6 +28,8 @@ from jinja2 import TemplateRuntimeError import pynestml from pynestml.codegeneration.code_generator import CodeGenerator +from pynestml.codegeneration.code_generator_utils import CodeGeneratorUtils +from pynestml.codegeneration.nest_code_generator import NESTCodeGenerator from pynestml.codegeneration.nest_assignments_helper import NestAssignmentsHelper from pynestml.codegeneration.nest_declarations_helper import NestDeclarationsHelper from pynestml.codegeneration.printers.constant_printer import ConstantPrinter @@ -53,10 +56,16 @@ from pynestml.meta_model.ast_variable import ASTVariable from pynestml.symbol_table.symbol_table import SymbolTable from pynestml.symbols.symbol import SymbolKind +from pynestml.utils.ast_vector_parameter_setter_and_printer import ASTVectorParameterSetterAndPrinter +from pynestml.utils.global_info_enricher import GlobalInfoEnricher +from pynestml.utils.global_processing import GlobalProcessing +from pynestml.utils.ast_vector_parameter_setter_and_printer_factory import ASTVectorParameterSetterAndPrinterFactory from pynestml.transformers.inline_expression_expansion_transformer import InlineExpressionExpansionTransformer from pynestml.utils.mechanism_processing import MechanismProcessing from pynestml.utils.channel_processing import ChannelProcessing from pynestml.utils.concentration_processing import ConcentrationProcessing +from pynestml.utils.continuous_input_processing import ContinuousInputProcessing +from pynestml.utils.con_in_info_enricher import ConInInfoEnricher from pynestml.utils.conc_info_enricher import ConcInfoEnricher from pynestml.utils.ast_utils import ASTUtils from pynestml.utils.chan_info_enricher import ChanInfoEnricher @@ -64,12 +73,18 @@ from pynestml.utils.logger import LoggingLevel from pynestml.utils.messages import Messages 
from pynestml.utils.model_parser import ModelParser -from pynestml.utils.syns_info_enricher import SynsInfoEnricher +from pynestml.utils.string_utils import removesuffix from pynestml.utils.synapse_processing import SynapseProcessing +from pynestml.utils.syns_info_enricher import SynsInfoEnricher +from pynestml.utils.recs_info_enricher import RecsInfoEnricher +from pynestml.utils.receptor_processing import ReceptorProcessing from pynestml.visitors.ast_random_number_generator_visitor import ASTRandomNumberGeneratorVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor from odetoolbox import analysis +#DEBUGGING +from pynestml.cocos.co_co_v_comp_exists import CoCoVCompDefined + class NESTCompartmentalCodeGenerator(CodeGenerator): r""" @@ -88,6 +103,9 @@ class NESTCompartmentalCodeGenerator(CodeGenerator): """ _default_options = { + "neuron_synapse_pairs": [], + "neuron_models": [], + "synapse_models": [], "neuron_parent_class": "ArchivingNode", "neuron_parent_class_include": "archiving_node.h", "preserve_expressions": True, @@ -96,15 +114,18 @@ class NESTCompartmentalCodeGenerator(CodeGenerator): "path": "resources_nest_compartmental/cm_neuron", "model_templates": { "neuron": [ - "cm_compartmentcurrents_@NEURON_NAME@.cpp.jinja2", - "cm_compartmentcurrents_@NEURON_NAME@.h.jinja2", + "cm_neuroncurrents_@NEURON_NAME@.cpp.jinja2", + "cm_neuroncurrents_@NEURON_NAME@.h.jinja2", "@NEURON_NAME@.cpp.jinja2", "@NEURON_NAME@.h.jinja2", "cm_tree_@NEURON_NAME@.cpp.jinja2", "cm_tree_@NEURON_NAME@.h.jinja2"]}, "module_templates": ["setup"]}, "nest_version": "", - "compartmental_variable_name": "v_comp"} + "compartmental_variable_name": "v_comp", + "delay_variable": {}, + "weight_variable": {} + } _variable_matching_template = r"(\b)({})(\b)" _model_templates = dict() @@ -113,6 +134,8 @@ class NESTCompartmentalCodeGenerator(CodeGenerator): def __init__(self, options: Optional[Mapping[str, Any]] = None): super().__init__("NEST_COMPARTMENTAL", options) + 
self._nest_code_generator = NESTCodeGenerator(options) + # auto-detect NEST Simulator installed version if not self.option_exists("nest_version") or not self.get_option("nest_version"): from pynestml.codegeneration.nest_tools import NESTTools @@ -151,7 +174,7 @@ def setup_printers(self): self._nest_printer = CppPrinter(expression_printer=self._printer) self._nest_variable_printer_no_origin = NESTVariablePrinter(None, with_origin=False, - with_vector_parameter=False, + with_vector_parameter=True, enforce_getter=False) self._printer_no_origin = CppExpressionPrinter( simple_expression_printer=CppSimpleExpressionPrinter(variable_printer=self._nest_variable_printer_no_origin, @@ -185,15 +208,20 @@ def raise_helper(self, msg): raise TemplateRuntimeError(msg) def set_options(self, options: Mapping[str, Any]) -> Mapping[str, Any]: + self._nest_code_generator.set_options(options) ret = super().set_options(options) self.setup_template_env() return ret def generate_code(self, models: List[ASTModel]) -> None: - self.analyse_transform_neurons(models) - self.generate_neurons(models) - self.generate_module_code(models) + neurons, synapses = CodeGeneratorUtils.get_model_types_from_names(models, neuron_models=self.get_option( + "neuron_models"), synapse_models=self.get_option("synapse_models")) + synapses_per_neuron = self.arrange_synapses_per_neuron(neurons, synapses) + self.analyse_transform_neurons(neurons) + self.analyse_transform_synapses(synapses) + self.generate_compartmental_neurons(neurons, synapses_per_neuron) + self.generate_module_code(neurons) def generate_module_code(self, neurons: List[ASTModel]) -> None: """t @@ -245,7 +273,7 @@ def _get_module_namespace(self, neurons: List[ASTModel]) -> Dict: neuron_name_to_filename = dict() for neuron in neurons: neuron_name_to_filename[neuron.get_name()] = { - "compartmentcurrents": self.get_cm_syns_compartmentcurrents_file_prefix(neuron), + "neuroncurrents": self.get_cm_syns_neuroncurrents_file_prefix(neuron), "main": 
self.get_cm_syns_main_file_prefix(neuron), "tree": self.get_cm_syns_tree_file_prefix(neuron) } @@ -261,12 +289,18 @@ def _get_module_namespace(self, neurons: List[ASTModel]) -> Dict: def get_cm_syns_compartmentcurrents_file_prefix(self, neuron): return "cm_compartmentcurrents_" + neuron.get_name() + def get_cm_syns_neuroncurrents_file_prefix(self, neuron): + return "cm_neuroncurrents_" + neuron.get_name() + def get_cm_syns_main_file_prefix(self, neuron): return neuron.get_name() def get_cm_syns_tree_file_prefix(self, neuron): return "cm_tree_" + neuron.get_name() + def get_stdp_synapse_main_file_prefix(self, synapse): + return synapse.get_name() + def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None: """ Analyse and transform a list of neurons. @@ -279,6 +313,97 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None: spike_updates = self.analyse_neuron(neuron) neuron.spike_updates = spike_updates + equations_block = neuron.get_equations_blocks()[0] + kernel_buffers = ASTUtils.generate_kernel_buffers(neuron, equations_block) + + analytic_solver, numeric_solver = self._nest_code_generator.ode_toolbox_analysis(neuron, kernel_buffers) + + delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block) + + spike_updates, post_spike_updates = self._nest_code_generator.get_spike_update_expressions(neuron, kernel_buffers, + [analytic_solver, numeric_solver], + delta_factors) + + neuron.spike_updates = spike_updates + neuron.post_spike_updates = post_spike_updates + + def analyse_transform_synapses(self, synapses: List[ASTModel]) -> None: + """ + Analyse and transform a list of synapses. + :param synapses: a list of synapses. 
+ """ + for synapse in synapses: + Logger.log_message(None, None, "Analysing/transforming synapse {}.".format(synapse.get_name()), None, LoggingLevel.INFO) + SynapseProcessing.process(synapse, self.get_option("neuron_synapse_pairs")) + self.analyse_synapse(synapse) + + def analyse_synapse(self, synapse: ASTModel):# -> Dict[str, ASTAssignment]: + """ + Analyse and transform a single synapse. + :param synapse: a single synapse. + """ + """ + equations_block = synapse.get_equations_blocks()[0] + ASTUtils.replace_convolve_calls_with_buffers_(synapse, equations_block) + ASTUtils.add_timestep_symbol(synapse) + self.update_symbol_table(synapse) + """ + + + code, message = Messages.get_start_processing_model(synapse.get_name()) + Logger.log_message(synapse, code, message, synapse.get_source_position(), LoggingLevel.INFO) + + spike_updates = {} + if synapse.get_equations_blocks(): + if len(synapse.get_equations_blocks()) > 1: + raise Exception("Only one equations block per model supported for now") + + equations_block = synapse.get_equations_blocks()[0] + + kernel_buffers = ASTUtils.generate_kernel_buffers(synapse, equations_block) + + # substitute inline expressions with each other + # such that no inline expression references another inline expression; + # deference inline_expressions inside ode_equations + InlineExpressionExpansionTransformer().transform(synapse) + + delta_factors = ASTUtils.get_delta_factors_(synapse, equations_block) + ASTUtils.replace_convolve_calls_with_buffers_(synapse, equations_block) + + analytic_solver, numeric_solver = self.ode_toolbox_analysis(synapse, kernel_buffers) + self.analytic_solver[synapse.get_name()] = analytic_solver + self.numeric_solver[synapse.get_name()] = numeric_solver + + ASTUtils.remove_initial_values_for_kernels(synapse) + kernels = ASTUtils.remove_kernel_definitions_from_equations_block(synapse) + ASTUtils.update_initial_values_for_odes(synapse, [analytic_solver, numeric_solver]) + 
ASTUtils.remove_ode_definitions_from_equations_block(synapse) + ASTUtils.create_initial_values_for_kernels(synapse, [analytic_solver, numeric_solver], kernels) + ASTUtils.create_integrate_odes_combinations(synapse) + ASTUtils.replace_variable_names_in_expressions(synapse, [analytic_solver, numeric_solver]) + ASTUtils.add_timestep_symbol(synapse) + self.update_symbol_table(synapse) + spike_updates, _ = self.get_spike_update_expressions(synapse, kernel_buffers, [analytic_solver, numeric_solver], delta_factors) + + if not self.analytic_solver[synapse.get_name()] is None: + synapse = ASTUtils.add_declarations_to_internals( + synapse, self.analytic_solver[synapse.get_name()]["propagators"]) + + self.update_symbol_table(synapse) + else: + ASTUtils.add_timestep_symbol(synapse) + self.update_symbol_table(synapse) + + synapse_name_stripped = removesuffix(removesuffix(synapse.name.split("_with_")[0], "_"), FrontendConfiguration.suffix) + # special case for NEST delay variable (state or parameter) + + ASTUtils.update_blocktype_for_common_parameters(synapse) + #assert synapse_name_stripped in self.get_option("delay_variable").keys(), "Please specify a delay variable for synapse '" + synapse_name_stripped + "' in the code generator options" + #assert ASTUtils.get_variable_by_name(synapse, self.get_option("delay_variable")[synapse_name_stripped]), "Delay variable '" + self.get_option("delay_variable")[synapse_name_stripped] + "' not found in synapse '" + synapse_name_stripped + "'" + + return spike_updates + + def create_ode_indict(self, neuron: ASTModel, parameters_block: ASTBlockWithVariables, @@ -513,7 +638,7 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: neuron, self.analytic_solver[neuron.get_name()]["propagators"]) # generate how to calculate the next spike update - self.update_symbol_table(neuron, kernel_buffers) + self.update_symbol_table(neuron) # find any spike update expressions defined by the user spike_updates = 
self.get_spike_update_expressions( neuron, kernel_buffers, [analytic_solver, numeric_solver], delta_factors) @@ -525,7 +650,7 @@ def compute_name_of_generated_file(self, jinja_file_name, neuron): jinja_file_name).split(".")[0] file_name_calculators = { - "CompartmentCurrents": self.get_cm_syns_compartmentcurrents_file_prefix, + "NeuronCurrents": self.get_cm_syns_neuroncurrents_file_prefix, "Tree": self.get_cm_syns_tree_file_prefix, "Main": self.get_cm_syns_main_file_prefix, } @@ -557,7 +682,7 @@ def getUniqueSuffix(self, neuron: ASTModel) -> str: underscore_pos = ret.find("_") return ret - def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: + def _get_neuron_model_namespace(self, neuron: ASTModel, paired_synapse: ASTModel = None) -> Dict: """ Returns a standard namespace for generating neuron code for NEST :param neuron: a single neuron instance @@ -589,6 +714,20 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: namespace["nestml_printer"] = NESTMLPrinter() namespace["type_symbol_printer"] = self._type_symbol_printer + class VectorPrinter(): + def __init__(self, neuron, printer): + self.printer = ASTVectorParameterSetterAndPrinterFactory(neuron, printer) + self.std_vector_parameter = None + + def print(self, expression, index = "i"): + self.std_vector_parameter = index + index_printer = self.printer.create_ast_vector_parameter_setter_and_printer(index) + return index_printer.print(expression) + + vector_printer = VectorPrinter(neuron, self._printer_no_origin) + + namespace["vector_printer"] = vector_printer + # NESTML syntax keywords namespace["PyNestMLLexer"] = {} from pynestml.generated.PyNestMLLexer import PyNestMLLexer @@ -600,6 +739,7 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: namespace["neuronName"] = neuron.get_name() namespace["neuron"] = neuron + #namespace["synapse"] = synapse namespace["moduleName"] = FrontendConfiguration.get_module_name() namespace["has_spike_input"] = ASTUtils.has_spike_input( 
neuron.get_body()) @@ -612,6 +752,8 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: "neuron_parent_class_include") namespace["PredefinedUnits"] = pynestml.symbols.predefined_units.PredefinedUnits + namespace["PredefinedFunctions"] = pynestml.symbols.predefined_functions.PredefinedFunctions + namespace["UnitTypeSymbol"] = pynestml.symbols.unit_type_symbol.UnitTypeSymbol namespace["SymbolKind"] = pynestml.symbols.symbol.SymbolKind @@ -694,20 +836,37 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: namespace["chan_info"] = ChannelProcessing.get_mechs_info(neuron) namespace["chan_info"] = ChanInfoEnricher.enrich_with_additional_info(neuron, namespace["chan_info"]) - namespace["syns_info"] = SynapseProcessing.get_mechs_info(neuron) - namespace["syns_info"] = SynsInfoEnricher.enrich_with_additional_info(neuron, namespace["syns_info"]) + namespace["recs_info"] = ReceptorProcessing.get_mechs_info(neuron) + namespace["recs_info"] = RecsInfoEnricher.enrich_with_additional_info(neuron, namespace["recs_info"]) namespace["conc_info"] = ConcentrationProcessing.get_mechs_info(neuron) namespace["conc_info"] = ConcInfoEnricher.enrich_with_additional_info(neuron, namespace["conc_info"]) + namespace["con_in_info"] = ContinuousInputProcessing.get_mechs_info(neuron) + namespace["con_in_info"] = ConInInfoEnricher.enrich_with_additional_info(neuron, namespace["con_in_info"]) + + if paired_synapse: + namespace["syns_info"] = SynapseProcessing.get_syn_info(paired_synapse) + namespace["syns_info"] = SynsInfoEnricher.enrich_with_additional_info(paired_synapse, namespace["syns_info"], namespace["chan_info"], namespace["recs_info"], namespace["conc_info"], namespace["con_in_info"]) + else: + namespace["syns_info"] = dict() + + namespace["global_info"] = GlobalProcessing.get_global_info(neuron) + namespace["global_info"] = GlobalInfoEnricher.enrich_with_additional_info(neuron, namespace["global_info"]) + chan_info_string = 
MechanismProcessing.print_dictionary(namespace["chan_info"], 0) - syns_info_string = MechanismProcessing.print_dictionary(namespace["syns_info"], 0) + recs_info_string = MechanismProcessing.print_dictionary(namespace["recs_info"], 0) conc_info_string = MechanismProcessing.print_dictionary(namespace["conc_info"], 0) - code, message = Messages.get_mechs_dictionary_info(chan_info_string, syns_info_string, conc_info_string) + con_in_info_string = MechanismProcessing.print_dictionary(namespace["con_in_info"], 0) + if paired_synapse: + syns_info_string = MechanismProcessing.print_dictionary(namespace["syns_info"], 0) + else: + syns_info_string = "" + global_info_string = MechanismProcessing.print_dictionary(namespace["global_info"], 0) + code, message = Messages.get_mechs_dictionary_info(chan_info_string, recs_info_string, conc_info_string, con_in_info_string, syns_info_string, global_info_string) Logger.log_message(None, code, message, None, LoggingLevel.DEBUG) - neuron_specific_filenames = { - "compartmentcurrents": self.get_cm_syns_compartmentcurrents_file_prefix(neuron), + "neuroncurrents": self.get_cm_syns_neuroncurrents_file_prefix(neuron), "main": self.get_cm_syns_main_file_prefix(neuron), "tree": self.get_cm_syns_tree_file_prefix(neuron)} @@ -719,9 +878,12 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: namespace["types_printer"] = self._type_symbol_printer + # python utils + namespace["set"] = set + return namespace - def update_symbol_table(self, neuron, kernel_buffers): + def update_symbol_table(self, neuron): """ Update symbol table and scope. 
""" @@ -932,3 +1094,47 @@ def transform_ode_and_kernels_to_json( )] = self._ode_toolbox_printer.print(decl.get_expression()) return odetoolbox_indict + + def generate_compartmental_neuron_code(self, neuron: ASTModel, paired_synapse = None) -> None: + self.generate_model_code(neuron.get_name(), + model_templates=self._model_templates["neuron"], + template_namespace=self._get_neuron_model_namespace(neuron, paired_synapse), + model_name_escape_string="@NEURON_NAME@") + + def generate_compartmental_neurons(self, neurons: Sequence[ASTModel], paired_synapses: dict) -> None: + """ + Generate code for the given neurons. + + :param neurons: a list of neurons. + """ + from pynestml.frontend.frontend_configuration import FrontendConfiguration + neuron_index = 0 + for neuron in neurons: + paired_syn_exists = False + for synapse in paired_synapses[neuron.get_name()]: + paired_syn_exists = True + self.generate_compartmental_neuron_code(neuron, synapse) + if not Logger.has_errors(neuron): + code, message = Messages.get_code_generated(neuron.get_name(), FrontendConfiguration.get_target_path()) + Logger.log_message(neuron, code, message, neuron.get_source_position(), LoggingLevel.INFO) + if not paired_syn_exists: + self.generate_compartmental_neuron_code(neuron) + if not Logger.has_errors(neuron): + code, message = Messages.get_code_generated(neuron.get_name(), + FrontendConfiguration.get_target_path()) + Logger.log_message(neuron, code, message, neuron.get_source_position(), LoggingLevel.INFO) + neuron_index += 1 + + def arrange_synapses_per_neuron(self, neurons: Sequence[ASTModel], synapses: Sequence[ASTModel]): + paired_synapses = dict() + for neuron in neurons: + paired_synapses[neuron.get_name()] = list() + + neuron_synapse_pairs = self.get_option("neuron_synapse_pairs") + for pair in neuron_synapse_pairs: + for synapse in synapses: + if synapse.get_name() == (pair["synapse"]+"_nestml"): + paired_synapses[pair["neuron"]+"_nestml"].append(synapse) + + return paired_synapses + 
diff --git a/pynestml/codegeneration/printers/cpp_function_call_printer.py b/pynestml/codegeneration/printers/cpp_function_call_printer.py index 11beba1bd..1edc411b8 100644 --- a/pynestml/codegeneration/printers/cpp_function_call_printer.py +++ b/pynestml/codegeneration/printers/cpp_function_call_printer.py @@ -32,6 +32,7 @@ from pynestml.utils.ast_utils import ASTUtils from pynestml.meta_model.ast_node import ASTNode from pynestml.meta_model.ast_variable import ASTVariable +from pynestml.frontend.frontend_configuration import FrontendConfiguration class CppFunctionCallPrinter(FunctionCallPrinter): @@ -83,6 +84,9 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) -> """ function_name = function_call.get_name() + if function_name == PredefinedFunctions.HEAVISIDE: + return '({!s} > 0)' + if function_name == PredefinedFunctions.CLIP: # the arguments of this function must be swapped and are therefore [v_max, v_min, v] return 'std::min({2!s}, std::max({1!s}, {0!s}))' @@ -93,6 +97,9 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) -> if function_name == PredefinedFunctions.MIN: return 'std::min({!s}, {!s})' + if function_name == 'Min': + return 'std::min({!s}, {!s})' + if function_name == PredefinedFunctions.ABS: return 'std::abs({!s})' diff --git a/pynestml/codegeneration/printers/nest_variable_printer.py b/pynestml/codegeneration/printers/nest_variable_printer.py index 48bbef59e..df2b6438b 100644 --- a/pynestml/codegeneration/printers/nest_variable_printer.py +++ b/pynestml/codegeneration/printers/nest_variable_printer.py @@ -20,6 +20,7 @@ # along with NEST. If not, see . 
from __future__ import annotations +from typing import Dict, Optional from pynestml.utils.ast_utils import ASTUtils @@ -49,6 +50,7 @@ def __init__(self, expression_printer: ExpressionPrinter, with_origin: bool = Tr self.with_vector_parameter = with_vector_parameter self.enforce_getter = enforce_getter self.variables_special_cases = variables_special_cases + self.cpp_variable_suffix = "" def print_variable(self, variable: ASTVariable) -> str: """ @@ -104,7 +106,9 @@ def print_variable(self, variable: ASTVariable) -> str: s = "" if not units_conversion_factor == 1: s += "(" + str(units_conversion_factor) + " * " - s += "B_." + self._print_buffer_value(variable) + if self.cpp_variable_suffix == "": + s += "B_." + s += self._print_buffer_value(variable) if not units_conversion_factor == 1: s += ")" return s @@ -112,17 +116,17 @@ def print_variable(self, variable: ASTVariable) -> str: if symbol.is_inline_expression: # there might not be a corresponding defined state variable; insist on calling the getter function if self.enforce_getter: - return "get_" + self._print(variable, symbol, with_origin=False) + vector_param + "()" + return "get_" + self._print(variable, symbol, with_origin=False) + vector_param + "()" + self.cpp_variable_suffix # modification to not enforce getter function: else: - return self._print(variable, symbol, with_origin=False) + return self._print(variable, symbol, with_origin=False) + self.cpp_variable_suffix assert not symbol.is_kernel(), "Cannot print kernel; kernel should have been converted during code generation" if symbol.is_state() or symbol.is_inline_expression: - return self._print(variable, symbol, with_origin=self.with_origin) + vector_param + return self._print(variable, symbol, with_origin=self.with_origin) + vector_param + self.cpp_variable_suffix - return self._print(variable, symbol, with_origin=self.with_origin) + vector_param + return self._print(variable, symbol, with_origin=self.with_origin) + vector_param + 
self.cpp_variable_suffix def _print_delay_variable(self, variable: ASTVariable) -> str: """ @@ -153,6 +157,9 @@ def _print_buffer_value(self, variable: ASTVariable) -> str: var_name += "_" + str(variable.get_vector_parameter()) return "spike_inputs_grid_sum_[" + var_name + " - MIN_SPIKE_RECEPTOR]" + if self.cpp_variable_suffix: + return variable_symbol.get_symbol_name() + self.cpp_variable_suffix + return variable_symbol.get_symbol_name() + '_grid_sum_' def _print(self, variable: ASTVariable, symbol, with_origin: bool = True) -> str: diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/FunctionDeclaration.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/FunctionDeclaration.jinja2 index 13c8c5ab1..cf9cf2d4d 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/FunctionDeclaration.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/FunctionDeclaration.jinja2 @@ -1,4 +1,4 @@ -{%- macro FunctionDeclaration(ast_function, namespace_prefix) -%} +{%- macro FunctionDeclaration(ast_function, namespace_prefix, pass_by_reference = false) -%} {%- with function_symbol = ast_function.get_scope().resolve_to_symbol(ast_function.get_name(), SymbolKind.FUNCTION) -%} {%- if function_symbol is none -%} {{ raise('Cannot resolve the method ' + ast_function.get_name()) }} @@ -8,7 +8,7 @@ {%- for param in ast_function.get_parameters() %} {%- with typeSym = param.get_data_type().get_type_symbol() -%} {%- filter indent(1, True) -%} -{{ type_symbol_printer.print(typeSym) }} {{ param.get_name() }} +{{ type_symbol_printer.print(typeSym) }}{% if pass_by_reference %}&{% endif %} {{ param.get_name() }} {%- if not loop.last -%} , {%- endif -%} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.cpp.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.cpp.jinja2 index 83e82f2d1..04cbca661 100644 --- 
a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.cpp.jinja2 +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.cpp.jinja2 @@ -76,7 +76,7 @@ nest::{{neuronSpecificFileNamesCmSyns["main"]}}::{{neuronSpecificFileNamesCmSyns * ---------------------------------------------------------------- */ void -{{neuronSpecificFileNamesCmSyns["main"]}}::get_status( DictionaryDatum& statusdict ) const +{{neuronSpecificFileNamesCmSyns["main"]}}::get_status( DictionaryDatum& statusdict ) { def< double >( statusdict, names::V_th, V_th_ ); ArchivingNode::get_status( statusdict ); @@ -99,7 +99,7 @@ void compartment_ad.push_back( dd ); // add receptor info - compartment->compartment_currents.add_receptor_info( receptor_ad, compartment->comp_index ); + c_tree_.neuron_currents.add_receptor_info( receptor_ad, compartment->comp_index ); } // add compartment info and receptor info to the status dictionary def< ArrayDatum >( statusdict, names::compartments, compartment_ad ); @@ -233,15 +233,14 @@ nest::{{neuronSpecificFileNamesCmSyns["main"]}}::add_receptor_( DictionaryDatum& syn_buffers_.push_back( buffer ); // add the receptor to the compartment - Compartment{{cm_unique_suffix}}* compartment = c_tree_.get_compartment( compartment_idx ); if ( dd->known( names::params ) ) { - compartment->compartment_currents.add_synapse( - receptor_type, syn_idx, getValue< DictionaryDatum >( dd, names::params ) ); + c_tree_.neuron_currents.add_mechanism( + receptor_type, compartment_idx, getValue< DictionaryDatum >( dd, names::params ), syn_idx ); } else { - compartment->compartment_currents.add_synapse( receptor_type, syn_idx ); + c_tree_.neuron_currents.add_mechanism( receptor_type, compartment_idx, syn_idx ); } } @@ -312,18 +311,20 @@ nest::{{neuronSpecificFileNamesCmSyns["main"]}}::update( Time const& origin, con for ( long lag = from; lag < to; ++lag ) { - const double v_0_prev = c_tree_.get_root()->v_comp; + const double v_0_prev = 
*(c_tree_.get_root()->v_comp); c_tree_.construct_matrix( lag ); c_tree_.solve_matrix(); // threshold crossing - if ( c_tree_.get_root()->v_comp >= V_th_ && v_0_prev < V_th_ ) + if ( *(c_tree_.get_root()->v_comp) >= V_th_ && v_0_prev < V_th_ ) { set_spiketime( Time::step( origin.get_steps() + lag + 1 ) ); SpikeEvent se; kernel().event_delivery_manager.send( *this, se, lag ); + + c_tree_.neuron_currents.postsynaptic_synaptic_processing(); } logger_.record_data( origin.get_steps() + lag ); @@ -353,8 +354,11 @@ nest::{{neuronSpecificFileNamesCmSyns["main"]}}::handle( CurrentEvent& e ) const double c = e.get_current(); const double w = e.get_weight(); - Compartment{{cm_unique_suffix}}* compartment = c_tree_.get_compartment_opt( e.get_rport() ); - compartment->currents.add_value( e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), w * c ); + assert( e.get_delay_steps() > 0 ); + assert( ( e.get_rport() >= 0 ) && ( ( size_t ) e.get_rport() < syn_buffers_.size() ) ); + + syn_buffers_[ e.get_rport() ].add_value( + e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ), c*w ); } void @@ -363,4 +367,341 @@ nest::{{neuronSpecificFileNamesCmSyns["main"]}}::handle( DataLoggingRequest& e ) { logger_.handle( e ); } +{%- if paired_synapse is defined %} +// ------------------------------------------------------------------------- +// Methods for neuron/synapse co-generation +// ------------------------------------------------------------------------- + +inline double +{{neuronName}}::get_spiketime_ms() const +{ + return last_spike_; +} + +void +{{neuronName}}::register_stdp_connection( double t_first_read, double delay ) +{ + // Mark all entries in the deque, which we will not read in future as read by + // this input, so that we safely increment the incoming number of + // connections afterwards without leaving spikes in the history. + // For details see bug #218.
MH 08-04-22 + + for ( std::deque< histentry__{{neuronName}} >::iterator runner = history_.begin(); + runner != history_.end() and ( t_first_read - runner->t_ > -1.0 * nest::kernel().connection_manager.get_stdp_eps() ); + ++runner ) + { + ( runner->access_counter_ )++; + } + + n_incoming_++; + + max_delay_ = std::max( delay, max_delay_ ); +} + + +void +{{neuronName}}::get_history__( double t1, + double t2, + std::deque< histentry__{{neuronName}} >::iterator* start, + std::deque< histentry__{{neuronName}} >::iterator* finish ) +{ + *finish = history_.end(); + if ( history_.empty() ) + { + *start = *finish; + return; + } + std::deque< histentry__{{neuronName}} >::reverse_iterator runner = history_.rbegin(); + const double t2_lim = t2 + nest::kernel().connection_manager.get_stdp_eps(); + const double t1_lim = t1 + nest::kernel().connection_manager.get_stdp_eps(); + while ( runner != history_.rend() and runner->t_ >= t2_lim ) + { + ++runner; + } + *finish = runner.base(); + while ( runner != history_.rend() and runner->t_ >= t1_lim ) + { + runner->access_counter_++; + ++runner; + } + *start = runner.base(); +} + +void +{{neuronName}}::set_spiketime( nest::Time const& t_sp, double offset ) +{ + {{neuron_parent_class}}::set_spiketime( t_sp, offset ); + + unsigned int num_transferred_variables = 0; +{%- for var in transferred_variables %} + ++num_transferred_variables; {# XXX: TODO: make this into a const member variable #} +{%- endfor %} + + const double t_sp_ms = t_sp.get_ms() - offset; + + if ( n_incoming_ ) + { + // prune all spikes from history which are no longer needed + // only remove a spike if: + // - its access counter indicates it has been read out by all connected + // STDP synapses, and + // - there is another, later spike, that is strictly more than + // (min_global_delay + max_delay_ + eps) away from the new spike (at t_sp_ms) + while ( history_.size() > 1 ) + { + const double next_t_sp = history_[ 1 ].t_; + // Note that ``access_counter`` now has an extra 
multiplicative factor equal (``n_incoming_``) to the number of trace values that exist, so that spikes are removed from the history only after they have been read out for the sake of computing each trace. + // see https://www.frontiersin.org/files/Articles/1382/fncom-04-00141-r1/image_m/fncom-04-00141-g003.jpg (Potjans et al. 2010) + + if ( history_.front().access_counter_ >= n_incoming_ * num_transferred_variables + and t_sp_ms - next_t_sp > max_delay_ + nest::Time::delay_steps_to_ms(nest::kernel().connection_manager.get_min_delay()) + nest::kernel().connection_manager.get_stdp_eps() ) + { + history_.pop_front(); + } + else + { + break; + } + } + + if (history_.size() > 0) + { + assert(history_.back().t_ == last_spike_); +{# +{%- for var in purely_numeric_state_variables_moved|sort %} + {{ printer.print(utils.get_state_variable_by_name(astnode, var)) }} = history_.back().{{var}}_; +{%- endfor %} +{%- for var in analytic_state_variables_moved|sort %} + {{ printer.print(utils.get_state_variable_by_name(astnode, var)) }} = history_.back().{{var}}_; +{%- endfor %} + } + else + { +{%- for var in purely_numeric_state_variables_moved|sort %} + {{ printer.print(utils.get_state_variable_by_name(astnode, var)) }} = {{ utils.initial_value_or_zero(astnode, var) }}; // initial value for convolution is always 0 +{%- endfor %} +{%- for var in analytic_state_variables_moved|sort %} + {{ printer.print(utils.get_state_variable_by_name(astnode, var)) }} = {{ utils.initial_value_or_zero(astnode, var) }}; // initial value for convolution is always 0 +{%- endfor %}#} + } +{# + /** + * update state variables transferred from synapse from `last_spike_` to `t_sp_ms` + * + * variables that will be integrated: {{ purely_numeric_state_variables_moved + analytic_state_variables_moved }} + **/ + + const double old___h = V_.__h; + V_.__h = t_sp_ms - last_spike_; + if (V_.__h > 1E-12) + { + recompute_internal_variables(true); + +{%- filter indent(6, True) -%} +{# emulate a call to 
``integrate_odes(purely_numeric_state_variables_moved + analytic_state_variables_moved)`` +{%- set args = utils.resolve_variables_to_expressions(astnode, purely_numeric_state_variables_moved + analytic_state_variables_moved) %} +{%- set ast = ASTNodeFactory.create_ast_function_call("integrate_odes", args) %} +{%- include "directives_cpp/PredefinedFunction_integrate_odes.jinja2" %} +{%- endfilter %} + + V_.__h = old___h; + recompute_internal_variables(true); + } + #} + + /** + * print extra on-emit statements transferred from synapse + **/ + +{%- filter indent(4, True) %} +{%- for stmt in extra_on_emit_spike_stmts_from_synapse %} +{%- include "directives_cpp/Statement.jinja2" %} +{%- endfor %} +{%- endfilter %} + + /** + * print updates due to convolutions + **/ + +{%- for _, spike_update in post_spike_updates.items() %} + {{ printer.print(utils.get_variable_by_name(astnode, spike_update.get_variable().get_complete_name())) }} += 1.; +{%- endfor %} + + last_spike_ = t_sp_ms; + history_.push_back( histentry__{{neuronName}}( last_spike_ +{# +{%- for var in purely_numeric_state_variables_moved|sort %} + , get_{{var}}() +{%- endfor %} +{%- for var in analytic_state_variables_moved|sort %} + , get_{{var}}() +{%- endfor %} +#} +, 0 + ) ); + } + else + { + last_spike_ = t_sp_ms; + } +} + + +void +{{neuronName}}::clear_history() +{ + last_spike_ = -1.0; + history_.clear(); +} + + +{# + generate getter functions for the transferred variables +#} + +{%- for var in transferred_variables %} +{%- with variable_symbol = transferred_variables_syms[var] %} + +{%- if not var == variable_symbol.get_symbol_name() %} +{{ raise('Error in resolving variable to symbol') }} +{%- endif %} + +double +{{neuronName}}::get_{{var}}( double t, const bool before_increment ) +{ +#ifdef DEBUG + std::cout << "{{neuronName}}::get_{{var}}: getting value at t = " << t << std::endl; +#endif + + // case when the neuron has not yet spiked + if ( history_.empty() ) + { +#ifdef DEBUG + std::cout << 
"{{neuronName}}::get_{{var}}: \thistory empty, returning initial value = " << {{var}}__iv << std::endl; +#endif + // return initial value + return {{var}}__iv; + } + + // search for the latest post spike in the history buffer that came strictly before `t` + int i = history_.size() - 1; + double eps = 0.; + if ( before_increment ) + { + eps = nest::kernel().connection_manager.get_stdp_eps(); + } + while ( i >= 0 ) + { + if ( t - history_[ i ].t_ >= eps ) + { +#ifdef DEBUG + std::cout<<"{{neuronName}}::get_{{var}}: \tspike occurred at history[i].t_ = " << history_[i].t_ << std::endl; +#endif +{# +{%- for var_ in purely_numeric_state_variables_moved %} + {{ printer.print(utils.get_variable_by_name(astnode, var_)) }} = history_[ i ].{{var_}}_; +{%- endfor %} +{%- for var_ in analytic_state_variables_moved %} + {{ printer.print(utils.get_variable_by_name(astnode, var_)) }} = history_[ i ].{{var_}}_; +{%- endfor %} +#} + + /** + * update state variables transferred from synapse from `history[i].t_` to `t` + * + * variables that will be integrated: {{ purely_numeric_state_variables_moved + analytic_state_variables_moved }} + **/ + + if ( t - history_[ i ].t_ >= nest::kernel().connection_manager.get_stdp_eps() ) + { + const double old___h = V_.__h; + V_.__h = t - history_[i].t_; + assert(V_.__h > 0); + recompute_internal_variables(true); + +{# emulate a call to ``integrate_odes(purely_numeric_state_variables_moved + analytic_state_variables_moved)`` #} + {# +{%- set args = utils.resolve_variables_to_expressions(astnode, purely_numeric_state_variables_moved + analytic_state_variables_moved) %} +{%- set ast = ASTNodeFactory.create_ast_function_call("integrate_odes", args) %} +{%- include "directives_cpp/PredefinedFunction_integrate_odes.jinja2" %} +#} + + V_.__h = old___h; + recompute_internal_variables(true); + } + +#ifdef DEBUG + std::cout << "{{neuronName}}::get_{{var}}: \treturning " << {{ printer.print(utils.get_variable_by_name(astnode, var)) }} << std::endl; +#endif + 
return {{ printer.print(utils.get_variable_by_name(astnode, var)) }}; // type: {{declarations.print_variable_type(variable_symbol)}} + } + --i; + } + + // this case occurs when the trace was requested at a time precisely at that of the first spike in the history + if ( (!before_increment) and t == history_[ 0 ].t_) + { + {# +{%- for var_ in purely_numeric_state_variables_moved %} + {{ printer.print(utils.get_state_variable_by_name(astnode, var_)) }} = history_[ 0 ].{{var_}}_; +{%- endfor %} +{%- for var_ in analytic_state_variables_moved %} + {{ printer.print(utils.get_state_variable_by_name(astnode, var_)) }} = history_[ 0 ].{{var_}}_; +{%- endfor %} +#} + +#ifdef DEBUG + std::cout << "{{neuronName}}::get_{{var}}: \ttrace requested at exact time of history entry 0, returning " << {{ printer.print(utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name())) }} << std::endl; +#endif + return {{ printer.print(utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name())) }}; + } + + // this case occurs when the trace was requested at a time before the first spike in the history + // return initial value propagated in time +#ifdef DEBUG + std::cout << "{{neuronName}}::get_{{var}}: \tfall-through, returning initial value = " << {{var}}__iv << std::endl; +#endif + + if (t == 0.) 
+ { + return 0.; // initial value for convolution is always 0 + } + + // set to initial value + {# +{%- for var_ in purely_numeric_state_variables_moved %} + {{ printer.print(utils.get_state_variable_by_name(astnode, var_)) }} = 0.; // initial value for convolution is always 0 +{%- endfor %} +{%- for var_ in analytic_state_variables_moved %} + {{ printer.print(utils.get_state_variable_by_name(astnode, var_)) }} = 0.; // initial value for convolution is always 0 +{%- endfor %} +#} + + /** + * update state variables transferred from synapse from initial condition to `t` + * + * variables that will be integrated: {{ purely_numeric_state_variables_moved + analytic_state_variables_moved }} + **/ + + const double old___h = V_.__h; + V_.__h = t; // from time 0 to the requested time + assert(V_.__h > 0); + recompute_internal_variables(true); + +{# emulate a call to ``integrate_odes(purely_numeric_state_variables_moved + analytic_state_variables_moved)`` #} +{# +{%- set args = utils.resolve_variables_to_expressions(astnode, purely_numeric_state_variables_moved + analytic_state_variables_moved) %} +{%- set ast = ASTNodeFactory.create_ast_function_call("integrate_odes", args) %} +{%- include "directives_cpp/PredefinedFunction_integrate_odes.jinja2" %} +#} + V_.__h = old___h; + recompute_internal_variables(true); + + return {{ printer.print(utils.get_variable_by_name(astnode, var)) }}; +} +{%- endwith -%} +{%- endfor %} + +{%- endif %} + } // namespace diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.h.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.h.jinja2 index 642a724a7..bfd3c92bf 100644 --- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.h.jinja2 +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/@NEURON_NAME@.h.jinja2 @@ -20,8 +20,8 @@ * */ -#ifndef CM_DEFAULT_H -#define CM_DEFAULT_H +#ifndef CM_{{neuronSpecificFileNamesCmSyns["main"].upper()}} 
+#define CM_{{neuronSpecificFileNamesCmSyns["main"].upper()}} // Includes from nestkernel: #include "archiving_node.h" @@ -29,7 +29,7 @@ #include "nest_types.h" #include "universal_data_logger.h" -#include "{{neuronSpecificFileNamesCmSyns["compartmentcurrents"]}}.h" +#include "{{neuronSpecificFileNamesCmSyns["neuroncurrents"]}}.h" #include "{{neuronSpecificFileNamesCmSyns["tree"]}}.h" namespace nest @@ -225,6 +225,43 @@ NEURON simulator ;-D EndUserDocs */ +{%- if paired_synapse is defined %} + +// entry in the spiking history +class histentry__{{neuronName}} +{ +public: + histentry__{{neuronName}}( double t, +{%- for var in purely_numeric_state_variables_moved|sort%} +double {{var}}, +{%- endfor %} +{%- for var in analytic_state_variables_moved|sort%} +double {{var}}, +{%- endfor %} +size_t access_counter ) + : t_( t ) +{%- for var in purely_numeric_state_variables_moved|sort %} + , {{var}}_( {{var}} ) +{%- endfor %} +{%- for var in analytic_state_variables_moved|sort %} + , {{var}}_( {{var}} ) +{%- endfor %} + , access_counter_( access_counter ) + { + } + + double t_; //!< point in time when spike occurred (in ms) +{%- for var in purely_numeric_state_variables_moved|sort %} + double {{var}}_; +{%- endfor %} +{%- for var in analytic_state_variables_moved|sort %} + double {{var}}_; +{%- endfor %} + size_t access_counter_; //!< access counter to enable removal of the entry, once all neurons read it +}; + +{%- endif %} + // Register the neuron model {%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4") or nest_version.startswith("v3.5") or nest_version.startswith("v3.6")) %} @@ -251,9 +288,54 @@ public: size_t handles_test_event( CurrentEvent&, size_t ); size_t handles_test_event( DataLoggingRequest&, size_t ); - void get_status( DictionaryDatum& ) const; + void get_status( DictionaryDatum& ); void 
set_status( const DictionaryDatum& ); + {% if paired_synapse is defined %} + // support for spike archiving + + /** + * \fn void get_history(long t1, long t2, + * std::deque::iterator* start, + * std::deque::iterator* finish) + * return the spike times (in steps) of spikes which occurred in the range + * (t1,t2]. + * XXX: two underscores to differentiate it from nest::Node::get_history() + */ + void get_history__( double t1, + double t2, + std::deque< histentry__{{neuronName}} >::iterator* start, + std::deque< histentry__{{neuronName}} >::iterator* finish ); + + /** + * Register a new incoming STDP connection. + * + * t_first_read: The newly registered synapse will read the history entries + * with t > t_first_read. + */ + void register_stdp_connection( double t_first_read, double delay ); +{%- endif %} + +protected: +{%- if paired_synapse is defined %} + // support for spike archiving + + /** + * record spike history + */ + void set_spiketime( nest::Time const& t_sp, double offset = 0.0 ); + + /** + * return most recent spike time in ms + */ + inline double get_spiketime_ms() const; + + /** + * clear spike history + */ + void clear_history(); +{%- endif %} + private: void add_compartment_( DictionaryDatum& dd ); void add_receptor_( DictionaryDatum& dd ); @@ -296,6 +378,29 @@ private: DynamicUniversalDataLogger< {{neuronSpecificFileNamesCmSyns["main"]}} > logger_; double V_th_; + +{%- if paired_synapse is defined %} + // support for spike archiving + + // number of incoming connections from stdp connectors. 
+ // needed to determine, if every incoming connection has + // read the spikehistory for a given point in time + size_t n_incoming_; + + double max_delay_; + + double last_spike_; + + // spiking history needed by stdp synapses + std::deque< histentry__{{neuronName}} > history_; + + // cache for initial values +{%- for var in transferred_variables %} + double {{var}}__iv; +{%- endfor %} + + +{%- endif %} }; diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_compartmentcurrents_@NEURON_NAME@.cpp.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_compartmentcurrents_@NEURON_NAME@.cpp.jinja2 deleted file mode 100644 index 98626fc96..000000000 --- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_compartmentcurrents_@NEURON_NAME@.cpp.jinja2 +++ /dev/null @@ -1,417 +0,0 @@ -{#- -cm_compartmentcurrents_@NEURON_NAME@.cpp.jinja2 - -This file is part of NEST. - -Copyright (C) 2004 The NEST Initiative - -NEST is free software: you can redistribute it and/or modify -it under the terms of the GNU General Public License as published by -the Free Software Foundation, either version 2 of the License, or -(at your option) any later version. - -NEST is distributed in the hope that it will be useful, -but WITHOUT ANY WARRANTY; without even the implied warranty of -MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -GNU General Public License for more details. - -You should have received a copy of the GNU General Public License -along with NEST. If not, see . 
-#} -{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} -{%- import 'directives_cpp/FunctionDeclaration.jinja2' as function_declaration with context %} -#include "{{neuronSpecificFileNamesCmSyns["compartmentcurrents"]}}.h" - -{%- set current_conductance_name_prefix = "g" %} -{%- set current_equilibrium_name_prefix = "e" %} -{% macro render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} - {%- if variable_type == "gbar" %} - {{ current_conductance_name_prefix~"_"~ion_channel_name }} - {%- elif variable_type == "e" %} - {{ current_equilibrium_name_prefix~"_"~ion_channel_name }} - {%- endif %} -{%- endmacro %} - -{%- macro render_state_variable_name(pure_variable_name, ion_channel_name) %} - {{ pure_variable_name~"_"~ion_channel_name }} -{%- endmacro %} - -{% macro render_time_resolution_variable(synapse_info) %} -{# we assume here that there is only one such variable ! #} -{%- for analytic_helper_name, analytic_helper_info in synapse_info["analytic_helpers"].items() %} -{%- if analytic_helper_info["is_time_resolution"] %} - {{ analytic_helper_name }} -{%- endif %} -{%- endfor %} -{%- endmacro %} - -{% macro render_function_return_type(function) %} -{%- with %} - {%- set symbol = function.get_scope().resolve_to_symbol(function.get_name(), SymbolKind.FUNCTION) %} - {{ types_printer.print(symbol.get_return_type()) }} -{%- endwith %} -{%- endmacro %} - -{% macro render_inline_expression_type(inline_expression) %} -{%- with %} - {%- set symbol = inline_expression.get_scope().resolve_to_symbol(inline_expression.variable_name, SymbolKind.VARIABLE) %} - {{ types_printer.print(symbol.get_type_symbol()) }} -{%- endwith %} -{%- endmacro %} - -{% macro render_static_channel_variable_name(variable_type, ion_channel_name) %} - -{%- for ion_channel_nm, channel_info in chan_info.items() %} - {%- if ion_channel_nm == ion_channel_name %} - {%- for variable_tp, variable_info in channel_info["channel_parameters"].items() %} - {%- 
if variable_tp == variable_type %} - {%- set variable = variable_info["parameter_block_variable"] %} - {{ variable.name }} - {%- endif %} - {%- endfor %} - {%- endif %} -{%- endfor %} - -{%- endmacro %} - -{% macro render_channel_function(function, ion_channel_name) %} -{{ function_declaration.FunctionDeclaration(function, "nest::"~ion_channel_name~cm_unique_suffix~"::") }} -{ -{%- filter indent(2,True) %} -{%- with ast = function.get_block() %} -{%- include "directives_cpp/Block.jinja2" %} -{%- endwith %} -{%- endfilter %} -} -{%- endmacro %} - - -{%- for ion_channel_name, channel_info in chan_info.items() %} - -// {{ion_channel_name}} channel ////////////////////////////////////////////////////////////////// -nest::{{ion_channel_name}}{{cm_unique_suffix}}::{{ion_channel_name}}{{cm_unique_suffix}}() - -{%- for pure_variable_name, variable_info in channel_info["States"].items() %} -// state variable {{pure_variable_name -}} -{%- set variable = variable_info["ASTVariable"] %} -{%- set rhs_expression = variable_info["rhs_expression"] %} -{% if loop.first %}: {% else %}, {% endif %} -{{- variable.name}}({{ printer_no_origin.print(rhs_expression) -}}) -{%- endfor %} - -{% for variable_type, variable_info in channel_info["Parameters"].items() %} -// channel parameter {{variable_type -}} -{%- set variable = variable_info["ASTVariable"] %} -{%- set rhs_expression = variable_info["rhs_expression"] %} -,{{- variable.name }}({{ printer_no_origin.print(rhs_expression) -}}) -{%- endfor %} -{} - -nest::{{ion_channel_name}}{{cm_unique_suffix}}::{{ion_channel_name}}{{cm_unique_suffix}}(const DictionaryDatum& channel_params) - -{%- for pure_variable_name, variable_info in channel_info["States"].items() %} -// state variable {{pure_variable_name -}} -{%- set variable = variable_info["ASTVariable"] %} -{%- set rhs_expression = variable_info["rhs_expression"] %} -{% if loop.first %}: {% else %}, {% endif %} -{{- variable.name}}({{ printer_no_origin.print(rhs_expression) -}}) -{%- 
endfor %} - -{% for variable_type, variable_info in channel_info["Parameters"].items() %} -// channel parameter {{variable_type -}} -{%- set variable = variable_info["ASTVariable"] %} -{%- set rhs_expression = variable_info["rhs_expression"] %} -,{{- variable.name }}({{ printer_no_origin.print(rhs_expression) -}}) -{%- endfor %} -// update {{ion_channel_name}} channel parameters -{ - {%- for variable_type, variable_info in channel_info["Parameters"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} //have to remove??????????? - // {{ion_channel_name}} channel parameter {{dynamic_variable }} - if( channel_params->known( "{{variable.name}}" ) ) - {{variable.name}} = getValue< double >( channel_params, "{{variable.name}}" ); - {%- endfor %} -} - -void -nest::{{ion_channel_name}}{{cm_unique_suffix}}::append_recordables(std::map< Name, double* >* recordables, - const long compartment_idx) -{ - // add state variables to recordables map - {%- for pure_variable_name, variable_info in channel_info["States"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - ( *recordables )[ Name( "{{variable.name}}" + std::to_string(compartment_idx) )] = &{{variable.name}}; - {%- endfor %} - ( *recordables )[ Name( "i_tot_{{ion_channel_name}}" + std::to_string(compartment_idx) )] = &i_tot_{{ion_channel_name}}; -} - -std::pair< double, double > nest::{{ion_channel_name}}{{cm_unique_suffix}}::f_numstep(const double v_comp{% for ode in channel_info["Dependencies"]["concentrations"] %}, double {{ode.lhs.name}}{% endfor %} - {% for inline in channel_info["Dependencies"]["receptors"] %}, double {{inline.variable_name}}{% endfor %} - {% for inline in channel_info["Dependencies"]["channels"] %}, double {{inline.variable_name}}{% endfor %}) -{ - double g_val = 0., i_val = 0.; - - if({%- for key_zero_param in channel_info["RootInlineKeyZeros"] %} {{ key_zero_param }} > 
1e-9 && {%- endfor %} true ){ - {% if channel_info["ODEs"].items()|length %} double {{ printer_no_origin.print(channel_info["time_resolution_var"]) }} = Time::get_resolution().get_ms(); {% endif %} - - {%- for ode_variable, ode_info in channel_info["ODEs"].items() %} - {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} - double {{ propagator }} = {{ printer_no_origin.print(propagator_info["init_expression"]) }}; - {%- endfor %} - {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} - {{state}} = {{ printer_no_origin.print(state_solution_info["update_expression"]) }}; - {%- endfor %} - {%- endfor %} - - {%- set inline_expression = channel_info["root_expression"] %} - {%- set inline_expression_d = channel_info["inline_derivative"] %} - // compute the conductance of the {{ion_channel_name}} channel - this->i_tot_{{ion_channel_name}} = {{ printer_no_origin.print(inline_expression.get_expression()) }}; - // derivative - double d_i_tot_dv = {{ printer_no_origin.print(inline_expression_d) }}; - - g_val = - d_i_tot_dv / 2.; - i_val = this->i_tot_{{ion_channel_name}} - d_i_tot_dv * v_comp / 2.; - } - return std::make_pair(g_val, i_val); - -} - -{%- for function in channel_info["Functions"] %} -{{render_channel_function(function, ion_channel_name)}} -{%- endfor %} - -double nest::{{ion_channel_name}}{{cm_unique_suffix}}::get_current_{{ion_channel_name}}(){ - return this->i_tot_{{ion_channel_name}}; -} - -// {{ion_channel_name}} channel end /////////////////////////////////////////////////////////// -{% endfor %} -//////////////////////////////////////////////////////////////////////////////// - -{%- for synapse_name, synapse_info in syns_info.items() %} -// {{synapse_name}} synapse //////////////////////////////////////////////////////////////// -nest::{{synapse_name}}{{cm_unique_suffix}}::{{synapse_name}}{{cm_unique_suffix}}( const long syn_index ) - {%- for param_name, 
param_declaration in synapse_info["Parameters"].items() %} - {% if loop.first %}: {% else %}, {% endif %} - {{ param_name }}({{ printer_no_origin.print(param_declaration["rhs_expression"]) }}) - {%- endfor %} -{ - syn_idx = syn_index; -} - -nest::{{synapse_name}}{{cm_unique_suffix}}::{{synapse_name}}{{cm_unique_suffix}}( const long syn_index, const DictionaryDatum& receptor_params ) - {%- for param_name, param_declaration in synapse_info["Parameters"].items() %} - {% if loop.first %}: {% else %}, {% endif %} - {{ param_name }}({{ printer_no_origin.print(param_declaration["rhs_expression"]) }}) - {%- endfor %} -{ - syn_idx = syn_index; - - // update parameters - {%- for param_name, param_declaration in synapse_info["Parameters"].items() %} - if( receptor_params->known( "{{param_name}}" ) ) - {{param_name}} = getValue< double >( receptor_params, "{{param_name}}" ); - {%- endfor %} -} - -void -nest::{{synapse_name}}{{cm_unique_suffix}}::append_recordables(std::map< Name, double* >* recordables) -{ - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} - ( *recordables )[ Name( "{{convolution_info["kernel"]["name"]}}" + std::to_string(syn_idx) )] = &{{convolution}}; - {%- endfor %} - ( *recordables )[ Name( "i_tot_{{synapse_name}}" + std::to_string(syn_idx) )] = &i_tot_{{synapse_name}}; -} - -{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} -void nest::{{synapse_name}}{{cm_unique_suffix}}::calibrate() -{%- else %} -void nest::{{synapse_name}}{{cm_unique_suffix}}::pre_run_hook() -{%- endif %} -{ - - const double {{render_time_resolution_variable(synapse_info)}} = Time::get_resolution().get_ms(); - - // set propagators to ode toolbox returned value - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} - {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} - 
{{state_variable_name}} = {{ printer_no_origin.print(state_variable_info["init_expression"]) }}; - {%- endfor %} - {%- endfor %} - - // initial values for user defined states - // warning: this shadows class variables - {%- for state_name, state_declaration in synapse_info["States"].items() %} - double {{state_name}} = {{ printer_no_origin.print(state_declaration["rhs_expression"]) }}; - {%- endfor %} - - // initial values for kernel state variables, set to zero - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} - {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} - {{state_variable_name}} = 0; - {%- endfor %} - {%- endfor %} - - // user declared internals in order they were declared - {%- for internal_name, internal_declaration in synapse_info["internals_used_declared"] %} - {{internal_name}} = {{ printer_no_origin.print(internal_declaration.get_expression()) }}; - {%- endfor %} - - {{synapse_info["buffer_name"]}}_->clear(); -} - -std::pair< double, double > nest::{{synapse_name}}{{cm_unique_suffix}}::f_numstep( const double v_comp, const long lag {% for ode in synapse_info["Dependencies"]["concentrations"] %}, double {{ode.lhs.name}}{% endfor %} - {% for inline in synapse_info["Dependencies"]["receptors"] %}, double {{inline.variable_name}}{% endfor %} - {% for inline in synapse_info["Dependencies"]["channels"] %}, double {{inline.variable_name}}{% endfor %}) -{ - // get spikes - double s_val = {{synapse_info["buffer_name"]}}_->get_value( lag ); // * g_norm_; - - //update ODE state variable - {% if synapse_info["ODEs"].items()|length %} double {{ printer_no_origin.print(synapse_info["time_resolution_var"]) }} = Time::get_resolution().get_ms(); {% endif %} - {%- for ode_variable, ode_info in synapse_info["ODEs"].items() %} - {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} - double {{ propagator }} = {{ 
printer_no_origin.print(propagator_info["init_expression"]) }}; - {%- endfor %} - {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} - {{state}} = {{ printer_no_origin.print(state_solution_info["update_expression"]) }}; - {%- endfor %} - {%- endfor %} - - // update kernel state variable / compute synaptic conductance - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} - {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items() %} - {{state_variable_name}} = {{ printer_no_origin.print(state_variable_info["update_expression"]) }}; - {{state_variable_name}} += s_val * {{ printer_no_origin.print(state_variable_info["init_expression"]) }}; - - {%- endfor %} - {%- endfor %} - - // total current - // this expression should be the transformed inline expression - this->i_tot_{{synapse_name}} = {{ printer_no_origin.print(synapse_info["root_expression"].get_expression()) }}; - - // derivative of that expression - // voltage derivative of total current - // compute derivative with respect to current with sympy - double d_i_tot_dv = {{ printer_no_origin.print(synapse_info["inline_expression_d"]) }}; - - // for numerical integration - double g_val = - d_i_tot_dv / 2.; - double i_val = this->i_tot_{{synapse_name}} - d_i_tot_dv * v_comp / 2.; - - return std::make_pair(g_val, i_val); - -} - -{%- for function in synapse_info["functions_used"] %} -{{ function_declaration.FunctionDeclaration(function, "nest::"~synapse_name~cm_unique_suffix~"::") }} -{ -{%- filter indent(2,True) %} -{%- with ast = function.get_block() %} -{%- include "directives_cpp/Block.jinja2" %} -{%- endwith %} -{%- endfilter %} -} -{%- endfor %} - - double nest::{{synapse_name}}{{cm_unique_suffix}}::get_current_{{synapse_name}}(){ - return this->i_tot_{{synapse_name}}; - } - -// {{synapse_name}} synapse end /////////////////////////////////////////////////////////// -{%- endfor %} - 
-//////////////////////////////// concentrations -{%- for concentration_name, concentration_info in conc_info.items() %} - -// {{ concentration_name }} concentration ////////////////////////////////////////////////////////////////// -nest::{{ concentration_name }}{{cm_unique_suffix}}::{{ concentration_name }}{{cm_unique_suffix}}(): -{%- set states_written = False %} -{%- for pure_variable_name, variable_info in concentration_info["States"].items() %} -// state variable {{pure_variable_name -}} -{%- set variable = variable_info["ASTVariable"] %} -{%- set rhs_expression = variable_info["rhs_expression"] %} -{% if loop.first %} {%- set states_written = True %} {% else %}, {% endif %} -{{- variable.name}}({{ printer_no_origin.print(rhs_expression) -}}) -{%- endfor %} - -{% for variable_type, variable_info in concentration_info["Parameters"].items() %} -// channel parameter {{variable_type -}} -{%- set variable = variable_info["ASTVariable"] %} -{%- set rhs_expression = variable_info["rhs_expression"] %} -{% if loop.first %} {% if states_written %}, {% endif %} {% else %}, {% endif %} -{{- variable.name }}({{ printer_no_origin.print(rhs_expression) -}}) -{%- endfor %} -{} - -nest::{{ concentration_name }}{{cm_unique_suffix}}::{{ concentration_name }}{{cm_unique_suffix}}(const DictionaryDatum& concentration_params): -{%- set states_written = False %} -{%- for pure_variable_name, variable_info in concentration_info["States"].items() %} -// state variable {{pure_variable_name -}} -{%- set variable = variable_info["ASTVariable"] %} -{%- set rhs_expression = variable_info["rhs_expression"] %} -{% if loop.first %} {%- set states_written = True %} {% else %}, {% endif %} -{{- variable.name}}({{ printer_no_origin.print(rhs_expression) -}}) -{%- endfor %} - -{% for variable_type, variable_info in concentration_info["Parameters"].items() %} -// channel parameter {{variable_type -}} -{%- set variable = variable_info["ASTVariable"] %} -{%- set rhs_expression = 
variable_info["rhs_expression"] %} -{% if loop.first %} {% if states_written %}, {% endif %} {% else %}, {% endif %} -{{- variable.name }}({{ printer_no_origin.print(rhs_expression) -}}) -{%- endfor %} -// update {{ concentration_name }} concentration parameters -{ - {%- for variable_type, variable_info in concentration_info["Parameters"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, concentration_name) %} //have to remove??????????? - // {{ concentration_name }} concentration parameter {{dynamic_variable }} - if( concentration_params->known( "{{variable.name}}" ) ) - {{variable.name}} = getValue< double >( concentration_params, "{{variable.name}}" ); - {%- endfor %} -} - -void -nest::{{ concentration_name }}{{cm_unique_suffix}}::append_recordables(std::map< Name, double* >* recordables, - const long compartment_idx) -{ - // add state variables to recordables map - {%- for pure_variable_name, variable_info in concentration_info["States"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - ( *recordables )[ Name( "{{variable.name}}" + std::to_string(compartment_idx) )] = &{{variable.name}}; - {%- endfor %} - ( *recordables )[ Name( "{{concentration_name}}" + std::to_string(compartment_idx) )] = &{{concentration_name}}; -} - -void nest::{{ concentration_name }}{{cm_unique_suffix}}::f_numstep(const double v_comp{% for ode in concentration_info["Dependencies"]["concentrations"] %}, double {{ode.lhs.name}}{% endfor %} - {% for inline in concentration_info["Dependencies"]["receptors"] %}, double {{inline.variable_name}}{% endfor %} - {% for inline in concentration_info["Dependencies"]["channels"] %}, double {{inline.variable_name}}{% endfor %}) -{ - if({%- for key_zero_param in concentration_info["RootInlineKeyZeros"] %} {{ key_zero_param }} > 1e-9 && {%- endfor %} true ){ - double {{ printer_no_origin.print(concentration_info["time_resolution_var"]) }} = 
Time::get_resolution().get_ms(); - - {%- for ode_variable, ode_info in concentration_info["ODEs"].items() %} - {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} - double {{ propagator }} = {{ printer_no_origin.print(propagator_info["init_expression"]) }}; - {%- endfor %} - {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} - {{state}} = {{ printer_no_origin.print(state_solution_info["update_expression"]) }}; - {%- endfor %} - {%- endfor %} - } -} - -{%- for function in concentration_info["Functions"] %} -{{render_channel_function(function, concentration_name)}} -{%- endfor %} - -double nest::{{concentration_name}}{{cm_unique_suffix}}::get_concentration_{{concentration_name}}(){ - return this->{{concentration_name}}; -} - -// {{concentration_name}} concentration end /////////////////////////////////////////////////////////// -{% endfor %} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_compartmentcurrents_@NEURON_NAME@.h.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_compartmentcurrents_@NEURON_NAME@.h.jinja2 deleted file mode 100644 index 508d3331b..000000000 --- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_compartmentcurrents_@NEURON_NAME@.h.jinja2 +++ /dev/null @@ -1,470 +0,0 @@ -{#- -cm_compartmentcurrents_@NEURON_NAME@.h.jinja2 - -This file is part of NEST. - -Copyright (C) 2004 The NEST Initiative - -NEST is free software: you can redistribute it and/or modify -it under the terms of the GNU General Public License as published by -the Free Software Foundation, either version 2 of the License, or -(at your option) any later version. - -NEST is distributed in the hope that it will be useful, -but WITHOUT ANY WARRANTY; without even the implied warranty of -MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -GNU General Public License for more details. 
- -You should have received a copy of the GNU General Public License -along with NEST. If not, see . -#} -{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} -{%- import 'directives_cpp/FunctionDeclaration.jinja2' as function_declaration with context %} -#ifndef SYNAPSES_NEAT_H_{{cm_unique_suffix | upper }} -#define SYNAPSES_NEAT_H_{{cm_unique_suffix | upper }} - -#include - -#include "ring_buffer.h" - -{% macro render_variable_type(variable) %} -{%- with %} - {%- set symbol = variable.get_scope().resolve_to_symbol(variable.name, SymbolKind.VARIABLE) %} - {{ types_printer.print(symbol.type_symbol) }} -{%- endwith %} -{%- endmacro %} - -namespace nest -{ - -{%- for ion_channel_name, channel_info in chan_info.items() %} - -class {{ion_channel_name}}{{cm_unique_suffix}}{ -private: - // states - {%- for pure_variable_name, variable_info in channel_info["States"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ render_variable_type(variable) }} {{ variable.name }} = {{ printer_no_origin.print(rhs_expression) }}; - {%- endfor %} - - // parameters - {%- for pure_variable_name, variable_info in channel_info["Parameters"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ render_variable_type(variable) }} {{ variable.name }} = {{ printer_no_origin.print(rhs_expression) }}; - {%- endfor %} - - // ion-channel root-inline value - double i_tot_{{ion_channel_name}} = 0; - -public: - // constructor, destructor - {{ion_channel_name}}{{cm_unique_suffix}}(); - {{ion_channel_name}}{{cm_unique_suffix}}(const DictionaryDatum& channel_params); - ~{{ion_channel_name}}{{cm_unique_suffix}}(){}; - - // initialization channel -{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} - void calibrate() { -{%- 
else %} - void pre_run_hook() { -{%- endif %} - // states - {%- for pure_variable_name, variable_info in channel_info["States"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name }} = {{ printer_no_origin.print(rhs_expression) }}; - {%- endfor %} - }; - void append_recordables(std::map< Name, double* >* recordables, - const long compartment_idx); - - // numerical integration step - std::pair< double, double > f_numstep( const double v_comp{% for ode in channel_info["Dependencies"]["concentrations"] %}, double {{ode.lhs.name}}{% endfor %}{% if channel_info["Dependencies"]["receptors"]|length %} - {% endif %}{% for inline in channel_info["Dependencies"]["receptors"] %}, double {{inline.variable_name}}{% endfor %}{% if channel_info["Dependencies"]["channels"]|length %} - {% endif %}{% for inline in channel_info["Dependencies"]["channels"] %}, double {{inline.variable_name}}{% endfor %}); - - // function declarations - -{%- for function in channel_info["Functions"] %} - {{ function_declaration.FunctionDeclaration(function) }}; -{%- endfor %} - - // root_inline getter - double get_current_{{ion_channel_name}}(); - -}; -{% endfor %} - - -////////////////////////////////////////////////// synapses - -{% macro render_time_resolution_variable(synapse_info) %} -{# we assume here that there is only one such variable ! 
#} -{%- for analytic_helper_name, analytic_helper_info in synapse_info["analytic_helpers"].items() %} -{%- if analytic_helper_info["is_time_resolution"] %} - {{ analytic_helper_name }} -{%- endif %} -{%- endfor %} -{%- endmacro %} - -{%- for synapse_name, synapse_info in syns_info.items() %} - -class {{synapse_name}}{{cm_unique_suffix}}{ -private: - // global synapse index - long syn_idx = 0; - - // propagators, initialized via pre_run_hook() or calibrate() - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} - {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} - double {{state_variable_name}}; - {%- endfor %} - {%- endfor %} - - // kernel state variables, initialized via pre_run_hook() or calibrate() - {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} - {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} - double {{state_variable_name}}; - {%- endfor %} - {%- endfor %} - - // user defined parameters, initialized via pre_run_hook() or calibrate() - {%- for param_name, param_declaration in synapse_info["Parameters"].items() %} - double {{param_name}}; - {%- endfor %} - - // states - {%- for pure_variable_name, variable_info in synapse_info["States"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ render_variable_type(variable) }} {{ variable.name }} = {{ printer_no_origin.print(rhs_expression) }}; - {%- endfor %} - double i_tot_{{synapse_name}} = 0; - - // user declared internals in order they were declared, initialized via pre_run_hook() or calibrate() - {%- for internal_name, internal_declaration in synapse_info["internals_used_declared"] %} - double {{internal_name}}; - {%- endfor %} - - - - // spike buffer - RingBuffer* {{synapse_info["buffer_name"]}}_; - -public: - // constructor, destructor - 
{{synapse_name}}{{cm_unique_suffix}}( const long syn_index); - {{synapse_name}}{{cm_unique_suffix}}( const long syn_index, const DictionaryDatum& receptor_params); - ~{{synapse_name}}{{cm_unique_suffix}}(){}; - - long - get_syn_idx() - { - return syn_idx; - }; - - // numerical integration step - std::pair< double, double > f_numstep( const double v_comp, const long lag {% for ode in synapse_info["Dependencies"]["concentrations"] %}, double {{ode.lhs.name}}{% endfor %}{% if synapse_info["Dependencies"]["receptors"]|length %} - {% endif %}{% for inline in synapse_info["Dependencies"]["receptors"] %}, double {{inline.variable_name}}{% endfor %}{% if synapse_info["Dependencies"]["channels"]|length %} - {% endif %}{% for inline in synapse_info["Dependencies"]["channels"] %}, double {{inline.variable_name}}{% endfor %}); - - // calibration -{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} - void calibrate(); -{%- else %} - void pre_run_hook(); -{%- endif %} - void append_recordables(std::map< Name, double* >* recordables); - void set_buffer_ptr( std::vector< RingBuffer >& syn_buffers ) - { - {{synapse_info["buffer_name"]}}_ = &syn_buffers[ syn_idx ]; - }; - - // function declarations - {%- for function in synapse_info["Functions"] %} - {{ function_declaration.FunctionDeclaration(function, "") -}}; - - {% endfor %} - - // root_inline getter - double get_current_{{synapse_name}}(); -}; - - -{% endfor %} - -///////////////////////////////////////////// concentrations - -{%- for concentration_name, concentration_info in conc_info.items() %} - -class {{ concentration_name }}{{cm_unique_suffix}}{ -private: - // parameters - {%- for pure_variable_name, variable_info in concentration_info["Parameters"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ render_variable_type(variable) }} {{ variable.name }} 
= {{ printer_no_origin.print(rhs_expression) }}; - {%- endfor %} - - // states - {%- for pure_variable_name, variable_info in concentration_info["States"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ render_variable_type(variable) }} {{ variable.name }} = {{ printer_no_origin.print(rhs_expression) }}; - {%- endfor %} - - // concentration value (root-ode state) - double {{concentration_name}} = 0; - -public: - // constructor, destructor - {{ concentration_name }}{{cm_unique_suffix}}(); - {{ concentration_name }}{{cm_unique_suffix}}(const DictionaryDatum& concentration_params); - ~{{ concentration_name }}{{cm_unique_suffix}}(){}; - - // initialization channel -{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} - void calibrate() { -{%- else %} - void pre_run_hook() { -{%- endif %} - // states - {%- for pure_variable_name, variable_info in concentration_info["States"].items() %} - {%- set variable = variable_info["ASTVariable"] %} - {%- set rhs_expression = variable_info["rhs_expression"] %} - {{ variable.name }} = {{ printer_no_origin.print(rhs_expression) }}; - {%- endfor %} - }; - void append_recordables(std::map< Name, double* >* recordables, - const long compartment_idx); - - // numerical integration step - void f_numstep( const double v_comp{% for ode in concentration_info["Dependencies"]["concentrations"] %}, double {{ode.lhs.name}}{% endfor %}{% if concentration_info["Dependencies"]["receptors"]|length %} - {% endif %}{% for inline in concentration_info["Dependencies"]["receptors"] %}, double {{inline.variable_name}}{% endfor %}{% if concentration_info["Dependencies"]["channels"]|length %} - {% endif %}{% for inline in concentration_info["Dependencies"]["channels"] %}, double {{inline.variable_name}}{% endfor %}); - - // function declarations -{%- for function in 
concentration_info["Functions"] %} - {{ function_declaration.FunctionDeclaration(function) }}; -{%- endfor %} - - // root_ode getter - double get_concentration_{{concentration_name}}(); - -}; -{% endfor %} - -///////////////////////////////////////////// currents - -{%- set channel_suffix = "_chan_" %} -{%- set concentration_suffix = "_conc_" %} - -class CompartmentCurrents{{cm_unique_suffix}} { -private: - // ion channels -{% with %} - {%- for ion_channel_name, channel_info in chan_info.items() %} - {{ion_channel_name}}{{cm_unique_suffix}} {{ion_channel_name}}{{channel_suffix}}; - {% endfor %} -{% endwith %} - - // synapses - {%- for synapse_name, synapse_info in syns_info.items() %} - std::vector < {{synapse_name}}{{cm_unique_suffix}} > {{synapse_name}}_syns_; - {% endfor %} - - //concentrations -{% with %} - {%- for concentration_name, concentration_info in conc_info.items() %} - {{concentration_name}}{{cm_unique_suffix}} {{concentration_name}}{{concentration_suffix}}; - {% endfor %} -{% endwith %} - -public: - CompartmentCurrents{{cm_unique_suffix}}(){}; - explicit CompartmentCurrents{{cm_unique_suffix}}(const DictionaryDatum& compartment_params) - { - {%- for ion_channel_name, channel_info in chan_info.items() %} - {{ion_channel_name}}{{channel_suffix}} = {{ion_channel_name}}{{cm_unique_suffix}}( compartment_params ); - {% endfor %} - - {%- for concentration_name, concentration_info in conc_info.items() %} - {{ concentration_name }}{{concentration_suffix}} = {{ concentration_name }}{{cm_unique_suffix}}( compartment_params ); - {% endfor %} - }; - ~CompartmentCurrents{{cm_unique_suffix}}(){}; - -{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} - void calibrate() { -{%- else %} - void pre_run_hook() { -{%- endif %} - // initialization of ion channels - {%- for ion_channel_name, channel_info in chan_info.items() %} -{%- if nest_version.startswith("v2") or 
nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} - {{ion_channel_name}}{{channel_suffix}}.calibrate(); -{%- else %} - {{ion_channel_name}}{{channel_suffix}}.pre_run_hook(); -{%- endif %} - {% endfor %} - - // initialization of concentrations - {%- for concentration_name, concentration_info in conc_info.items() %} -{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} - {{ concentration_name }}{{concentration_suffix}}.calibrate(); -{%- else %} - {{ concentration_name }}{{concentration_suffix}}.pre_run_hook(); -{%- endif %} - {% endfor %} - - // initialization of synapses - {%- for synapse_name, synapse_info in syns_info.items() %} - // initialization of {{synapse_name}} synapses - for( auto syn_it = {{synapse_name}}_syns_.begin(); - syn_it != {{synapse_name}}_syns_.end(); - ++syn_it ) - { -{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} - syn_it->calibrate(); -{%- else %} - syn_it->pre_run_hook(); -{%- endif %} - } - {% endfor %} - }; - - void add_synapse( const std::string& type, const long syn_idx ) - { - {%- for synapse_name, synapse_info in syns_info.items() %} - {% if not loop.first %}else{% endif %} if ( type == "{{synapse_name}}" ) - { - {{synapse_name}}_syns_.push_back( {{synapse_name}}{{cm_unique_suffix}}( syn_idx ) ); - } - {% endfor %} - else - { - assert( false ); - } - }; - void add_synapse( const std::string& type, const long syn_idx, const DictionaryDatum& receptor_params ) - { - {%- for synapse_name, synapse_info in syns_info.items() %} - {% if not loop.first %}else{% endif %} if ( type == "{{synapse_name}}" ) - { - {{synapse_name}}_syns_.push_back( {{synapse_name}}{{cm_unique_suffix}}( syn_idx, receptor_params ) ); - } - {% endfor %} - else - { - assert( false ); - } - }; - - void - add_receptor_info( 
ArrayDatum& ad, const long compartment_index ) - { - {%- for synapse_name, synapse_info in syns_info.items() %} - for( auto syn_it = {{synapse_name}}_syns_.begin(); syn_it != {{synapse_name}}_syns_.end(); syn_it++) - { - DictionaryDatum dd = DictionaryDatum( new Dictionary ); - def< long >( dd, names::receptor_idx, syn_it->get_syn_idx() ); - def< long >( dd, names::comp_idx, compartment_index ); - def< std::string >( dd, names::receptor_type, "{{synapse_name}}" ); - ad.push_back( dd ); - } - {% endfor %} - }; - - void - set_syn_buffers( std::vector< RingBuffer >& syn_buffers ) - { - // spike buffers for synapses - {%- for synapse_name, synapse_info in syns_info.items() %} - for( auto syn_it = {{synapse_name}}_syns_.begin(); syn_it != {{synapse_name}}_syns_.end(); syn_it++) - syn_it->set_buffer_ptr( syn_buffers ); - {% endfor %} - }; - - std::map< Name, double* > - get_recordables( const long compartment_idx ) - { - std::map< Name, double* > recordables; - - // append ion channel state variables to recordables - {%- for ion_channel_name, channel_info in chan_info.items() %} - {{ion_channel_name}}{{channel_suffix}}.append_recordables( &recordables, compartment_idx ); - {% endfor %} - - // append concentration state variables to recordables - {%- for concentration_name, concentration_info in conc_info.items() %} - {{concentration_name}}{{concentration_suffix}}.append_recordables( &recordables, compartment_idx ); - {% endfor %} - - // append synapse state variables to recordables - {%- for synapse_name, synapse_info in syns_info.items() %} - for( auto syn_it = {{synapse_name}}_syns_.begin(); syn_it != {{synapse_name}}_syns_.end(); syn_it++) - syn_it->append_recordables( &recordables ); - {% endfor %} - - return recordables; - }; - - std::pair< double, double > - f_numstep( const double v_comp, const long lag ) - { - std::pair< double, double > gi(0., 0.); - double g_val = 0.; - double i_val = 0.; -{%- for synapse_name, synapse_info in syns_info.items() %} - double 
{{synapse_name}}{{channel_suffix}}current_sum = 0; - for( auto syn_it = {{synapse_name}}_syns_.begin(); - syn_it != {{synapse_name}}_syns_.end(); - ++syn_it ) - { - {{synapse_name}}{{channel_suffix}}current_sum += syn_it->get_current_{{synapse_name}}(); - } -{% endfor %} - - {%- for concentration_name, concentration_info in conc_info.items() %} - // computation of {{ concentration_name }} concentration - {{ concentration_name }}{{concentration_suffix}}.f_numstep( v_comp{% for ode in concentration_info["Dependencies"]["concentrations"] %}, {{ode.lhs.name}}{{concentration_suffix}}.get_concentration_{{ode.lhs.name}}(){% endfor %}{% if concentration_info["Dependencies"]["receptors"]|length %} - {% endif %}{% for inline in concentration_info["Dependencies"]["receptors"] %}, {{inline.variable_name}}{{channel_suffix}}_current_sum{% endfor %}{% if concentration_info["Dependencies"]["channels"]|length %} - {% endif %}{% for inline in concentration_info["Dependencies"]["channels"] %}, {{inline.variable_name}}{{channel_suffix}}.get_current_{{inline.variable_name}}(){% endfor %}); - - {% endfor %} - - {%- for ion_channel_name, channel_info in chan_info.items() %} - // contribution of {{ion_channel_name}} channel - gi = {{ion_channel_name}}{{channel_suffix}}.f_numstep( v_comp{% for ode in channel_info["Dependencies"]["concentrations"] %}, {{ode.lhs.name}}{{concentration_suffix}}.get_concentration_{{ode.lhs.name}}(){% endfor %}{% if channel_info["Dependencies"]["receptors"]|length %} - {% endif %}{% for inline in channel_info["Dependencies"]["receptors"] %}, {{inline.variable_name}}{{channel_suffix}}_current_sum{% endfor %}{% if channel_info["Dependencies"]["channels"]|length %} - {% endif %}{% for inline in channel_info["Dependencies"]["channels"] %}, {{inline.variable_name}}{{channel_suffix}}.get_current_{{inline.variable_name}}(){% endfor %}); - - g_val += gi.first; - i_val += gi.second; - - {% endfor %} - - {%- for synapse_name, synapse_info in syns_info.items() %} - // 
contribution of {{synapse_name}} synapses - for( auto syn_it = {{synapse_name}}_syns_.begin(); - syn_it != {{synapse_name}}_syns_.end(); - ++syn_it ) - { - gi = syn_it->f_numstep( v_comp, lag {% for ode in synapse_info["Dependencies"]["concentrations"] %}, {{ode.lhs.name}}{{concentration_suffix}}.get_concentration_{{ode.lhs.name}}(){% endfor %}{% if synapse_info["Dependencies"]["receptors"]|length %} - {% endif %}{% for inline in synapse_info["Dependencies"]["receptors"] %}, {{inline.variable_name}}{{channel_suffix}}_current_sum{% endfor %}{% if synapse_info["Dependencies"]["channels"]|length %} - {% endif %}{% for inline in synapse_info["Dependencies"]["channels"] %}, {{inline.variable_name}}{{channel_suffix}}.get_current{{inline.variable_name}}(){% endfor %}); - - g_val += gi.first; - i_val += gi.second; - } - {% endfor %} - - return std::make_pair(g_val, i_val); - }; -}; - -} // namespace - -#endif /* #ifndef SYNAPSES_NEAT_H_{{cm_unique_suffix | upper }} */ diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Block.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Block.jinja2 new file mode 100644 index 000000000..c1740f094 --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Block.jinja2 @@ -0,0 +1,13 @@ +{# + Handles a complex block statement + @grammar: Block = ( Stmt | NEWLINE )*; + @param ast ASTBlock +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} +{%- for statement in ast.get_stmts() %} +{%- with stmt = statement %} +{%- filter indent(2) %} +{%- include "cm_directives_cpp/Statement.jinja2" %} +{%- endfilter %} +{%- endwith %} +{%- endfor %} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/CompoundStatement.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/CompoundStatement.jinja2 new file mode 
100644 index 000000000..8439f340c --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/CompoundStatement.jinja2 @@ -0,0 +1,18 @@ +{# + Handles the compound statement. + @grammar: Compound_Stmt = IF_Stmt | FOR_Stmt | WHILE_Stmt; +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} +{%- if stmt.is_if_stmt() %} +{%- with ast = stmt.get_if_stmt() %} +{%- include "cm_directives_cpp/IfStatement.jinja2" %} +{%- endwith %} +{%- elif stmt.is_for_stmt() %} +{%- with ast = stmt.get_for_stmt() %} +{%- include "cm_directives_cpp/ForStatement.jinja2" %} +{%- endwith %} +{%- elif stmt.is_while_stmt() %} +{%- with ast = stmt.get_while_stmt() %} +{%- include "cm_directives_cpp/WhileStatement.jinja2" %} +{%- endwith %} +{%- endif %} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Declaration.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Declaration.jinja2 new file mode 100644 index 000000000..b72323c90 --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Declaration.jinja2 @@ -0,0 +1,21 @@ +{# + Generates C++ declaration + @param ast ASTDeclaration +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} +{%- for variable in declarations.get_variables(ast) %} +{%- if ast.has_size_parameter() %} +{{declarations.print_variable_type(variable)}} {{variable.get_symbol_name()}}(P_.{{declarations.print_size_parameter(ast)}}); +{%- if ast.has_expression() %} +for (long i=0; i < get_{{declarations.print_size_parameter(ast)}}(); i++) { + {{variable.get_symbol_name()}}[i] = {{printer.print(ast.getExpr())}}; +} +{%- endif %} +{%- else %} +{%- if ast.has_expression() %} +{{variable.get_symbol_name()}}[{{ printer.std_vector_parameter }}] = {{printer.print(ast.get_expression())}}; +{%- else %} +{{variable.get_symbol_name()}}[{{ 
printer.std_vector_parameter }}] = 0; +{%- endif %} +{%- endif %} +{%- endfor -%} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/FunctionCall.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/FunctionCall.jinja2 new file mode 100644 index 000000000..fb2d13f0c --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/FunctionCall.jinja2 @@ -0,0 +1,11 @@ +{# + Generates C++ function call + @param ast ASTFunctionCall +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} +{%- if ast.get_name() == PredefinedFunctions.EMIT_SPIKE %} +{%- include "cm_directives_cpp/PredefinedFunction_emit_spike.jinja2" %} +{%- else %} +{# call to a non-predefined function #} +{{ printer.print(ast) }}; +{%- endif %} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/IfStatement.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/IfStatement.jinja2 new file mode 100644 index 000000000..72e72c13a --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/IfStatement.jinja2 @@ -0,0 +1,33 @@ +{# + Generates C++ if..then..else statement + @param ast ASTIfStmt +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} +if ({{ printer.print(ast.get_if_clause().get_condition()) }}) +{ +{%- filter indent(2, True) %} +{%- with ast = ast.get_if_clause().get_block() %} +{%- include "cm_directives_cpp/Block.jinja2" %} +{%- endwith %} +{%- endfilter %} +{%- for elif in ast.get_elif_clauses() %} +} +else if ({{ printer.print(elif.get_condition()) }}) +{ +{%- filter indent(2, True) %} +{%- with ast = elif.get_block() %} +{%- include "cm_directives_cpp/Block.jinja2" %} +{%- endwith %} +{%- endfilter %} +{%- endfor %} +{%- if ast.has_else_clause() %} +} +else +{ +{%- filter indent(2, 
True) %} +{%- with ast = ast.get_else_clause().get_block() %} +{%- include "cm_directives_cpp/Block.jinja2" %} +{%- endwith %} +{%- endfilter %} +{%- endif %} +} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/PredefinedFunction_emit_spike.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/PredefinedFunction_emit_spike.jinja2 new file mode 100644 index 000000000..005c24ae1 --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/PredefinedFunction_emit_spike.jinja2 @@ -0,0 +1,31 @@ +{# + Generates code for emit_spike() function call + @param ast ASTFunctionCall +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} + +/** + * generated code for emit_spike() function +**/ +{% if ast.get_args() | length == 0 %} +{#- no parameters -- emit_spike() called from within neuron #} +set_spiketime(nest::Time::step(origin.get_steps() + lag + 1)); +nest::SpikeEvent se; +nest::kernel().event_delivery_manager.send(*this, se, lag); +{%- else %} +{#- weight and delay parameters given -- emit_spike() called from within synapse #} +{#- +set_delay( {{ printer.print(ast.get_args()[1]) }} ); +const long __delay_steps = nest::Time::delay_ms_to_steps( get_delay() ); +set_delay_steps(__delay_steps); +e.set_receiver( *__target ); +e.set_weight( {{ printer.print(ast.get_args()[0]) }} ); +// use accessor functions (inherited from Connection< >) to obtain delay in steps and rport +e.set_delay_steps( get_delay_steps() ); +e.set_rport( get_rport() ); +e(); +#} +s_val[{{ printer.std_vector_parameter }}] = {{ printer.print(ast.get_args()[0]) }}; +{%- endif %} + + diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/SmallStatement.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/SmallStatement.jinja2 new file mode 100644 index 000000000..cd9631768 --- 
/dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/SmallStatement.jinja2 @@ -0,0 +1,22 @@ +{# + Generates a single small statement into equivalent C++ syntax. + @param stmt ASTSmallStmt +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} +{%- if stmt.is_assignment() %} +{%- with ast = stmt.get_assignment() %} +{%- include "directives_cpp/Assignment.jinja2" %} +{%- endwith %} +{%- elif stmt.is_function_call() %} +{%- with ast = stmt.get_function_call() %} +{%- include "cm_directives_cpp/FunctionCall.jinja2" %} +{%- endwith %} +{%- elif stmt.is_declaration() %} +{%- with ast = stmt.get_declaration() %} +{%- include "cm_directives_cpp/Declaration.jinja2" %} +{%- endwith %} +{%- elif stmt.is_return_stmt() %} +{%- with ast = stmt.get_return_stmt() %} +{%- include "directives_cpp/ReturnStatement.jinja2" %} +{%- endwith %} +{%- endif %} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Statement.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Statement.jinja2 new file mode 100644 index 000000000..b0cade8e0 --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/Statement.jinja2 @@ -0,0 +1,16 @@ +{# + Generates a single statement, either a simple or compound, to equivalent C++ syntax. 
+ @param ast ASTSmallStmt or ASTCompoundStmt +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} +{%- if stmt.has_comment() %} +{{stmt.print_comment('//')}}{%- endif %} +{%- if stmt.is_small_stmt() %} +{%- with stmt = stmt.small_stmt %} +{%- include "cm_directives_cpp/SmallStatement.jinja2" %} +{%- endwith %} +{%- elif stmt.is_compound_stmt() %} +{%- with stmt = stmt.compound_stmt %} +{%- include "cm_directives_cpp/CompoundStatement.jinja2" %} +{%- endwith %} +{%- endif %} diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_global_dynamics.cpp.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_global_dynamics.cpp.jinja2 new file mode 100644 index 000000000..e60ee75c9 --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_global_dynamics.cpp.jinja2 @@ -0,0 +1,177 @@ +{%- with %} +// Global channel ////////////////////////////////////////////////////////////////// + +void nest::Global{{cm_unique_suffix}}::new_compartment() +{ + {%- for pure_variable_name, variable_info in global_info["States"].items() %} + // state variable {{pure_variable_name}} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_compartment_count") -}}); + {%- endfor %} + + {% for variable_type, variable_info in global_info["Parameters"].items() %} + // channel parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_compartment_count") -}}); + {%- endfor %} + + {% for variable_type, variable_info in global_info["Internals"].items() %} + // channel parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = 
variable_info["rhs_expression"] %} + {{variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_compartment_count") -}}); + {%- endfor %} + + {%- for in_function_declaration in global_info["InFunctionDeclarationsVars"] %} + {%- for variable in declarations.get_variables(in_function_declaration) %} + {{variable.name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + neuron_compartment_count++; +} + +void nest::Global{{cm_unique_suffix}}::new_compartment(const DictionaryDatum& global_params) +// update {{ion_channel_name}} channel parameters +{ + neuron_compartment_count++; + + {%- for pure_variable_name, variable_info in global_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_compartment_count-1") -}}); + {%- endfor %} + + {%- with %} + {%- for variable_type, variable_info in global_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + // {{ion_channel_name}} channel parameter {{dynamic_variable }} + if( global_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_compartment_count-1] = getValue< double >( global_params, "{{variable.name}}" ); + {%- endfor %} + {% endwith %} + + {%- with %} + {%- for variable_type, variable_info in global_info["ODEs"].items() %} + {%- set variable_name = variable_type %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + // {{ion_channel_name}} channel ODE state {{dynamic_variable }} + if( global_params->known( "{{variable_name}}" ) ) + {{variable_name}}[neuron_compartment_count-1] = getValue< double >( global_params, "{{variable_name}}" ); + {%- endfor %} + {% endwith %} + + {% for variable_type, variable_info in 
global_info["Parameters"].items() %} + // channel parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_compartment_count-1") -}}); + {%- endfor %} + + {%- with %} + {%- for variable_type, variable_info in global_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + // {{ion_channel_name}} channel parameter {{dynamic_variable }} + if( global_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_compartment_count-1] = getValue< double >( global_params, "{{variable.name}}" ); + {%- endfor %} + {% endwith %} + + {%- for pure_variable_name, variable_info in global_info["Internals"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_compartment_count-1") -}}); + {%- endfor %} + + {%- for in_function_declaration in global_info["InFunctionDeclarationsVars"] %} + {%- for variable in declarations.get_variables(in_function_declaration) %} + {{variable.name}}.push_back(0); + {%- endfor %} + {%- endfor %} +} + +void +nest::Global{{cm_unique_suffix}}::append_recordables(std::map< Name, double* >* recordables, + const long compartment_idx) +{ + // add state variables to recordables map + bool found_rec = false; + {%- with %} + {%- for pure_variable_name, variable_info in global_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + found_rec = false; + ( *recordables )[ Name( std::string("{{variable.name}}") + std::to_string(compartment_idx) )] = &{{variable.name}}[compartment_idx]; + {%- endfor %} + {% endwith %} +} + +void 
nest::Global{{cm_unique_suffix}}::f_numstep(std::vector< double > v_comp) +{ + {% if global_info["ODEs"].items()|length %} + std::vector< double > {{ printer_no_origin.print(global_info["time_resolution_var"]) }}(neuron_compartment_count, Time::get_resolution().get_ms()); + {% endif %} + + {%- for ode_variable, ode_info in global_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + std::vector< double > {{ propagator }}(neuron_compartment_count, 0); + {%- endfor %} + {%- endfor %} + + for(std::size_t i = 0; i < neuron_compartment_count; i++){ + //update ODE state variable + {%- for ode_variable, ode_info in global_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; + {%- endfor %} + {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} + {{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + double __resolution = Time::get_resolution().get_ms(); + {%- if global_info["UpdateBlock"] %} + {%- set function = global_info["UpdateBlock"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_block() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/Block.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + } + self_spikes = false; +} + +void nest::Global{{cm_unique_suffix}}::f_self_spike() +{ + self_spikes = true; + for(std::size_t i = 0; i < neuron_compartment_count; i++){ + double __resolution = Time::get_resolution().get_ms(); + {%- if global_info["SelfSpikesFunction"] %} + {%- set function = global_info["SelfSpikesFunction"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_block() %} + {%- set printer = vector_printer %} + {%- include 
"cm_directives_cpp/Block.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + } +} + +{%- for function in global_info["Functions"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_block() %} + {%- set printer = printer_no_origin %} + {%- include "cm_directives_cpp/Block.jinja2" %} + {%- endwith %} + {%- endfilter %} +{%- endfor %} + +// Global end /////////////////////////////////////////////////////////// +{% endwith %} \ No newline at end of file diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_global_dynamics.h.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_global_dynamics.h.jinja2 new file mode 100644 index 000000000..ab9ff662f --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_global_dynamics.h.jinja2 @@ -0,0 +1,87 @@ +///////////////////////////////////// global + +#include + +{%- with %} + +class Global{{cm_unique_suffix}}{ +private: + // states + {%- for pure_variable_name, variable_info in global_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + std::vector<{{ render_variable_type(variable) }}> {{ variable.name }}; + {%- endfor %} + + // parameters + {%- for pure_variable_name, variable_info in global_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + std::vector<{{ render_variable_type(variable) }}> {{ variable.name }}; + {%- endfor %} + + // internals + {%- for pure_variable_name, variable_info in global_info["Internals"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + std::vector<{{ render_variable_type(variable) }}> {{ variable.name }}; + {%- endfor %} + + {%- with %} + {%- for in_function_declaration in global_info["InFunctionDeclarationsVars"] %} + {%- for variable in 
declarations.get_variables(in_function_declaration) %} + std::vector<{{declarations.print_variable_type(variable)}}> {{variable.get_symbol_name()}}; + {%- endfor %} + {%- endfor %} + {%- endwith %} + + bool self_spikes = false; + +public: + // constructor, destructor + Global{{cm_unique_suffix}}(){}; + ~Global{{cm_unique_suffix}}(){}; + + // initialization global +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} + void calibrate() { +{%- else %} + void pre_run_hook() { +{%- endif %} + }; + + void new_compartment(); + void new_compartment(const DictionaryDatum& channel_params); + + //number of channels + std::size_t neuron_compartment_count = 0; + + void append_recordables(std::map< Name, double* >* recordables, const long compartment_idx); + + // numerical integration step + void f_numstep(std::vector< double > v_comp); + + void f_self_spike(); + + // function declarations + +{%- for function in global_info["Functions"] %} + #pragma omp declare simd + __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function) }}; +{%- endfor %} + + // states getters + {%- for pure_variable_name, variable_info in global_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + std::vector<{{ render_variable_type(variable) }}> get_{{ variable.name }}(){ + return {{ variable.name }}; + }; + {%- endfor %} + + bool get_self_spikes(){ + return self_spikes; + }; + +}; +{% endwith -%} \ No newline at end of file diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.cpp.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.cpp.jinja2 new file mode 100644 index 000000000..ec0c5f4ed --- /dev/null +++ 
b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.cpp.jinja2 @@ -0,0 +1,1521 @@ +{#- +cm_neuroncurrents_@NEURON_NAME@.cpp.jinja2 + +This file is part of NEST. + +Copyright (C) 2004 The NEST Initiative + +NEST is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 2 of the License, or +(at your option) any later version. + +NEST is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with NEST. If not, see . +#} +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif -%} +{%- import 'directives_cpp/FunctionDeclaration.jinja2' as function_declaration with context %} +#include "{{neuronSpecificFileNamesCmSyns["neuroncurrents"]}}.h" + +#define NEAR_ZERO 1e-9 + +{%- set current_conductance_name_prefix = "g" %} +{%- set current_equilibrium_name_prefix = "e" %} +{% macro render_dynamic_channel_variable_name(variable_type, ion_channel_name) -%} + {%- if variable_type == "gbar" -%} + {{ current_conductance_name_prefix~"_"~ion_channel_name }} + {%- elif variable_type == "e" -%} + {{ current_equilibrium_name_prefix~"_"~ion_channel_name }} + {%- endif -%} +{%- endmacro -%} + +{%- macro render_state_variable_name(pure_variable_name, ion_channel_name) -%} + {{ pure_variable_name~"_"~ion_channel_name }} +{%- endmacro -%} + +{% macro render_time_resolution_variable(receptor_info) -%} +{# we assume here that there is only one such variable ! 
#} +{%- with %} +{%- for analytic_helper_name, analytic_helper_info in receptor_info["analytic_helpers"].items() -%} +{%- if analytic_helper_info["is_time_resolution"] -%} + {{ analytic_helper_name }} +{%- endif -%} +{%- endfor -%} +{% endwith %} +{%- endmacro %} + +{% macro render_function_return_type(function) -%} +{%- with -%} + {%- set symbol = function.get_scope().resolve_to_symbol(function.get_name(), SymbolKind.FUNCTION) -%} + {{ types_printer.print(symbol.get_return_type()) }} +{%- endwith -%} +{%- endmacro -%} + +{% macro render_inline_expression_type(inline_expression) -%} +{%- with -%} + {%- set symbol = inline_expression.get_scope().resolve_to_symbol(inline_expression.variable_name, SymbolKind.VARIABLE) -%} + {{ types_printer.print(symbol.get_type_symbol()) }} +{%- endwith -%} +{%- endmacro -%} + +{% macro render_static_channel_variable_name(variable_type, ion_channel_name) -%} + +{%- with %} +{%- for ion_channel_nm, channel_info in chan_info.items() -%} + {%- if ion_channel_nm == ion_channel_name -%} + {%- for variable_tp, variable_info in channel_info["channel_parameters"].items() -%} + {%- if variable_tp == variable_type -%} + {%- set variable = variable_info["parameter_block_variable"] -%} + {{ variable.name }} + {%- endif -%} + {%- endfor -%} + {%- endif -%} +{%- endfor -%} +{% endwith %} + +{%- endmacro %} + +{% macro render_channel_function(function, ion_channel_name) -%} +{%- with %} + {%- set printer = printer_no_origin %} +inline {{ function_declaration.FunctionDeclaration(function, "nest::"~ion_channel_name~cm_unique_suffix~"::", true) }} +{ +{%- filter indent(2,True) %} +{%- with ast = function.get_block() %} +{%- include "directives_cpp/Block.jinja2" %} +{%- endwith %} +{%- endfilter %} +} +{% endwith %} +{%- endmacro %} + +{% macro render_vectorized_channel_function(function, ion_channel_name) -%} +{%- with %} +{{ vectorized_function_declaration.FunctionDeclaration(function, "nest::"~ion_channel_name~cm_unique_suffix~"::", true) }} +{ +{%- 
filter indent(2,True) %} +{%- with ast = function.get_block() %} +{%- include "directives_cpp/VectorizedBlock.jinja2" %} +{%- endwith %} +{%- endfilter %} +} +{% endwith %} +{%- endmacro %} + +{%- macro vectorized_function_call(ast_function, ion_channel_name) -%} +{%- with function_symbol = ast_function.get_scope().resolve_to_symbol(ast_function.get_name(), SymbolKind.FUNCTION) -%} +{%- if function_symbol is none -%} +{{ raise('Cannot resolve the method ' + ast_function.get_name()) }} +{%- endif %} +{{ "std::vector< " + type_symbol_printer.print(function_symbol.get_return_type()) + " >" | replace('.', '::') }} {{ ast_function.get_name() }}_v(neuron_{{ ion_channel_name }}_channel_count); +{{ ast_function.get_name() }}( +{%- for param in ast_function.get_parameters() %} +{%- with typeSym = param.get_data_type().get_type_symbol() -%} +{%- filter indent(1, True) -%} +{{ param.get_name() }} +{%- if not loop.last -%} +, +{%- endif -%} +{%- endfilter -%} +{%- endwith -%} +{%- endfor -%} +, {{ ast_function.get_name() }}_v ); +{%- endwith -%} +{%- endmacro -%} + +{% macro render_variable_type(variable) -%} +{%- with -%} + {%- set symbol = variable.get_scope().resolve_to_symbol(variable.name, SymbolKind.VARIABLE) -%} + {{ types_printer.print(symbol.type_symbol) }} +{%- endwith -%} +{%- endmacro %} + + +{%- with %} +{%- for ion_channel_name, channel_info in chan_info.items() %} + +// {{ion_channel_name}} channel ////////////////////////////////////////////////////////////////// +void nest::{{ion_channel_name}}{{cm_unique_suffix}}::new_channel(std::size_t comp_ass) +{ + //Check whether the channel will contribute at all based on initial key-parameters. If not then don't add the channel. 
+ bool channel_contributing = true; + {%- for key_zero_param in channel_info["RootInlineKeyZeros"] %} + {% for variable_type, variable_info in channel_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {% if key_zero_param == variable.name %} + if(std::abs({{ printer_no_origin.print(rhs_expression) }}) <= NEAR_ZERO){ + channel_contributing = false; + } + {% endif %} + {%- endfor %} + {%- endfor %} + + if(channel_contributing){ + neuron_{{ ion_channel_name }}_channel_count++; + i_tot_{{ion_channel_name}}.push_back(0); + compartment_association.push_back(comp_ass); + + {%- for pure_variable_name, variable_info in channel_info["States"].items() %} + // state variable {{pure_variable_name}} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+ion_channel_name+"_channel_count") -}}); + {%- endfor %} + + {% for variable_type, variable_info in channel_info["Parameters"].items() %} + // channel parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+ion_channel_name+"_channel_count") -}}); + {%- endfor %} + + {% for variable_type, variable_info in channel_info["Internals"].items() %} + // channel parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+ion_channel_name+"_channel_count") -}}); + {%- endfor %} + + {% for state in channel_info["Dependencies"]["global"] %} + {{ printer_no_origin.print(state) }}.push_back(0); + {% endfor %} + } +} + +void 
nest::{{ion_channel_name}}{{cm_unique_suffix}}::new_channel(std::size_t comp_ass, const DictionaryDatum& channel_params) +/* update {{ion_channel_name}} channel parameters and states */ +{ + //Check whether the channel will contribute at all based on initial key-parameters. If not then don't add the channel. + bool channel_contributing = true; + {%- for key_zero_param in channel_info["RootInlineKeyZeros"] %} + if( channel_params->known( "{{key_zero_param}}" ) ){ + if(std::abs(getValue< double >( channel_params, "{{key_zero_param}}" )) <= NEAR_ZERO){ + channel_contributing = false; + } + }else{ + {% for variable_type, variable_info in channel_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {% if key_zero_param == variable.name %} + if(std::abs({{ printer_no_origin.print(rhs_expression) }}) <= NEAR_ZERO){ + channel_contributing = false; + } + {% endif %} + {%- endfor %} + } + {%- endfor %} + + if(channel_contributing){ + neuron_{{ ion_channel_name }}_channel_count++; + compartment_association.push_back(comp_ass); + i_tot_{{ion_channel_name}}.push_back(0); + + {%- for pure_variable_name, variable_info in channel_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+ion_channel_name+"_channel_count") -}}); + {%- endfor %} + + {%- with %} + {%- for variable_type, variable_info in channel_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + // {{ion_channel_name}} channel parameter {{dynamic_variable }} + if( channel_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ ion_channel_name }}_channel_count-1] = getValue< 
double >( channel_params, "{{variable.name}}" ); + {%- endfor %} + {% endwith %} + + {%- with %} + {%- for variable_type, variable_info in channel_info["ODEs"].items() %} + {%- set variable_name = variable_type %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + // {{ion_channel_name}} channel ODE state {{dynamic_variable }} + if( channel_params->known( "{{variable_name}}" ) ) + {{variable_name}}[neuron_{{ ion_channel_name }}_channel_count-1] = getValue< double >( channel_params, "{{variable_name}}" ); + {%- endfor %} + {% endwith %} + + {% for variable_type, variable_info in channel_info["Parameters"].items() %} + // channel parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+ion_channel_name+"_channel_count") -}}); + {%- endfor %} + + {%- with %} + {%- for variable_type, variable_info in channel_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + // {{ion_channel_name}} channel parameter {{dynamic_variable }} + if( channel_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ ion_channel_name }}_channel_count-1] = getValue< double >( channel_params, "{{variable.name}}" ); + {%- endfor %} + {% endwith %} + + {%- for pure_variable_name, variable_info in channel_info["Internals"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+ion_channel_name+"_channel_count") -}}); + {%- endfor %} + + {% for state in channel_info["Dependencies"]["global"] %} + {{ printer_no_origin.print(state) }}.push_back(0); 
+ {% endfor %} + } +} + +void +nest::{{ion_channel_name}}{{cm_unique_suffix}}::append_recordables(std::map< Name, double* >* recordables, + const long compartment_idx) +{ + // add state variables to recordables map + bool found_rec = false; + {%- with %} + {%- for pure_variable_name, variable_info in channel_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + found_rec = false; + for(size_t chan_id = 0; chan_id < neuron_{{ ion_channel_name }}_channel_count; chan_id++){ + if(compartment_association[chan_id] == compartment_idx){ + ( *recordables )[ Name( std::string("{{variable.name}}") + std::to_string(compartment_idx))] = &{{variable.name}}[chan_id]; + found_rec = true; + } + } + if(!found_rec) ( *recordables )[ Name( std::string("{{variable.name}}") + std::to_string(compartment_idx))] = &zero_recordable; + {%- endfor %} + {% endwith %} + found_rec = false; + for(size_t chan_id = 0; chan_id < neuron_{{ ion_channel_name }}_channel_count; chan_id++){ + if(compartment_association[chan_id] == compartment_idx){ + ( *recordables )[ Name( std::string("i_tot_{{ion_channel_name}}") + std::to_string(compartment_idx))] = &i_tot_{{ion_channel_name}}[chan_id]; + found_rec = true; + } + } + if(!found_rec) ( *recordables )[ Name( std::string("i_tot_{{ion_channel_name}}") + std::to_string(compartment_idx))] = &zero_recordable; +} + +std::pair< std::vector< double >, std::vector< double > > nest::{{ion_channel_name}}{{cm_unique_suffix}}::f_numstep(bool point_self_spikes, std::vector< double > v_comp{% for ode in channel_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if channel_info["Dependencies"]["receptors"]|length %} + {% endif %}{% for inline in channel_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if channel_info["Dependencies"]["channels"]|length %} + {% endif %}{% for inline in channel_info["Dependencies"]["channels"] %}, std::vector< double > 
{{inline.variable_name}}{% endfor %}{% if channel_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in channel_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if channel_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in channel_info["Dependencies"]["global"] %}, std::vector<{{ render_variable_type(state) }}> {{ printer_no_origin.print(state) }}{% endfor %}) +{ + std::vector< bool > self_spikes(neuron_{{ ion_channel_name }}_channel_count, point_self_spikes); + std::vector< double > g_val(neuron_{{ ion_channel_name }}_channel_count, 0.); + std::vector< double > i_val(neuron_{{ ion_channel_name }}_channel_count, 0.); + + std::vector< double > d_i_tot_dv(neuron_{{ ion_channel_name }}_channel_count, 0.); + + {% if channel_info["ODEs"].items()|length %} std::vector< double > {{ printer_no_origin.print(channel_info["time_resolution_var"]) }}(neuron_{{ ion_channel_name }}_channel_count, Time::get_resolution().get_ms()); {% endif %} + + {%- for ode_variable, ode_info in channel_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + std::vector< double > {{ propagator }}(neuron_{{ ion_channel_name }}_channel_count, 0); + {%- endfor %} + {%- endfor %} + #pragma omp simd + for(std::size_t i = 0; i < neuron_{{ ion_channel_name }}_channel_count; i++){ + {%- for ode_variable, ode_info in channel_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; + {%- endfor %} + {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} + {{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + {%- set inline_expression = channel_info["root_expression"] %} + 
{%- set inline_expression_d = channel_info["inline_derivative"] %} + + // compute the conductance of the {{ion_channel_name}} channel + this->i_tot_{{ion_channel_name}}[i] = {{ vector_printer.print(inline_expression.get_expression(), "i") }}; + + // derivative + d_i_tot_dv[i] = {{ vector_printer.print(inline_expression_d, "i") }}; + g_val[i] = - d_i_tot_dv[i]; + i_val[i] = this->i_tot_{{ion_channel_name}}[i] - d_i_tot_dv[i] * v_comp[i]; + } + return std::make_pair(g_val, i_val); + +} + +{%- for function in channel_info["Functions"] %} +{{render_channel_function(function, ion_channel_name)}} +{%- endfor %} +void nest::{{ion_channel_name}}{{cm_unique_suffix}}::get_currents_per_compartment(std::vector< double >& compartment_to_current){ + for(std::size_t comp_id = 0; comp_id < compartment_to_current.size(); comp_id++){ + compartment_to_current[comp_id] = 0; + } + for(std::size_t chan_id = 0; chan_id < neuron_{{ ion_channel_name }}_channel_count; chan_id++){ + compartment_to_current[this->compartment_association[chan_id]] += this->i_tot_{{ion_channel_name}}[chan_id]; + } +} + +std::vector< double > nest::{{ion_channel_name}}{{cm_unique_suffix}}::distribute_shared_vector(std::vector< double > shared_vector){ + std::vector< double > distributed_vector(this->neuron_{{ ion_channel_name }}_channel_count, 0.0); + for(std::size_t chan_id = 0; chan_id < this->neuron_{{ ion_channel_name }}_channel_count; chan_id++){ + distributed_vector[chan_id] = shared_vector[compartment_association[chan_id]]; + } + return distributed_vector; +} + +// {{ion_channel_name}} channel end /////////////////////////////////////////////////////////// +{% endfor -%} +{% endwith %} +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////// concentrations +{%- with %} +{%- for concentration_name, concentration_info in conc_info.items() %} + +// {{ concentration_name }} concentration ///////////////////////////////////////////////////// + +void 
nest::{{concentration_name}}{{cm_unique_suffix}}::new_concentration(std::size_t comp_ass) +{ + //Check whether the concentration will contribute at all based on initial key-parameters. If not then don't add the concentration. + bool concentration_contributing = true; + {%- for key_zero_param in concentration_info["RootInlineKeyZeros"] %} + {% for variable_type, variable_info in concentration_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {% if key_zero_param == variable.name %} + if(std::abs({{ printer_no_origin.print(rhs_expression) }}) <= NEAR_ZERO){ + concentration_contributing = false; + } + {% endif %} + {%- endfor %} + {%- endfor %} + + if(concentration_contributing){ + neuron_{{ concentration_name }}_concentration_count++; + {{concentration_name}}.push_back(0); + compartment_association.push_back(comp_ass); + + {%- for pure_variable_name, variable_info in concentration_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count") -}}); + {%- endfor %} + + {% for variable_type, variable_info in concentration_info["Parameters"].items() %} + // channel parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count") -}}); + {%- endfor %} + + {% for variable_type, variable_info in concentration_info["Internals"].items() %} + // channel parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ 
vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count") -}}); + {%- endfor %} + } +} + +void nest::{{concentration_name}}{{cm_unique_suffix}}::new_concentration(std::size_t comp_ass, const DictionaryDatum& concentration_params) +{ + //Check whether the concentration will contribute at all based on initial key-parameters. If not then don't add the concentration. + bool concentration_contributing = true; + {%- for key_zero_param in concentration_info["RootInlineKeyZeros"] %} + if( concentration_params->known( "{{key_zero_param}}" ) ){ + if(std::abs(getValue< double >( concentration_params, "{{key_zero_param}}" )) <= NEAR_ZERO){ + concentration_contributing = false; + } + }else{ + {% for variable_type, variable_info in concentration_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {% if key_zero_param == variable.name %} + if(std::abs({{ printer_no_origin.print(rhs_expression) }}) <= NEAR_ZERO){ + concentration_contributing = false; + } + {% endif %} + {%- endfor %} + } + {%- endfor %} + + if(concentration_contributing){ + neuron_{{ concentration_name }}_concentration_count++; + {{concentration_name}}.push_back(0); + compartment_association.push_back(comp_ass); + + {%- for pure_variable_name, variable_info in concentration_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count") -}}); + {%- endfor %} + + {%- with %} + {%- for variable_type, variable_info in concentration_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + // {{ion_channel_name}} channel parameter 
{{dynamic_variable }} + if( concentration_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ concentration_name }}_concentration_count-1] = getValue< double >( concentration_params, "{{variable.name}}" ); + {%- endfor %} + {% endwith %} + + {%- with %} + {%- for variable_type, variable_info in concentration_info["ODEs"].items() %} + {%- set variable_name = variable_type %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + // {{concentration_name}} concentration ODE state {{dynamic_variable }} + if( concentration_params->known( "{{variable_name}}" ) ) + {{variable_name}}[neuron_{{ concentration_name }}_concentration_count-1] = getValue< double >( concentration_params, "{{variable_name}}" ); + {%- endfor %} + {% endwith %} + + {% for variable_type, variable_info in concentration_info["Parameters"].items() %} + // channel parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count") -}}); + {%- endfor %} + + {%- with %} + {%- for variable_type, variable_info in concentration_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, concentration_name) %} + // {{ concentration_name }} concentration parameter {{dynamic_variable }} + if( concentration_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ concentration_name }}_concentration_count-1] = getValue< double >( concentration_params, "{{variable.name}}" ); + {%- endfor %} + {% endwith %} + + {%- for pure_variable_name, variable_info in concentration_info["Internals"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = 
variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+concentration_name+"_concentration_count") -}}); + {%- endfor %} + } +} + +void +nest::{{ concentration_name }}{{cm_unique_suffix}}::append_recordables(std::map< Name, double* >* recordables, + const long compartment_idx) +{ + // add state variables to recordables map + bool found_rec = false; + {%- with %} + {%- for pure_variable_name, variable_info in concentration_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + found_rec = false; + for(size_t conc_id = 0; conc_id < neuron_{{ concentration_name }}_concentration_count; conc_id++){ + if(compartment_association[conc_id] == compartment_idx){ + ( *recordables )[ Name( std::string("{{variable.name}}") + std::to_string(compartment_idx))] = &{{variable.name}}[conc_id]; + found_rec = true; + } + } + if(!found_rec) ( *recordables )[ Name( std::string("{{variable.name}}") + std::to_string(compartment_idx))] = &zero_recordable; + {%- endfor %} + {% endwith %} + found_rec = false; + for(size_t conc_id = 0; conc_id < neuron_{{ concentration_name }}_concentration_count; conc_id++){ + if(compartment_association[conc_id] == compartment_idx){ + ( *recordables )[ Name( std::string("{{concentration_name}}") + std::to_string(compartment_idx))] = &{{concentration_name}}[conc_id]; + found_rec = true; + } + } + if(!found_rec) ( *recordables )[ Name( std::string("{{concentration_name}}") + std::to_string(compartment_idx))] = &zero_recordable; +} + +void nest::{{ concentration_name }}{{cm_unique_suffix}}::f_numstep(bool point_self_spikes, std::vector< double > v_comp{% for ode in concentration_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if concentration_info["Dependencies"]["receptors"]|length %} + {% endif %}{% for inline in concentration_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if 
concentration_info["Dependencies"]["channels"]|length %} + {% endif %}{% for inline in concentration_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if concentration_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in concentration_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if concentration_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in concentration_info["Dependencies"]["global"] %}, std::vector<{{ render_variable_type(state) }}> {{ printer_no_origin.print(state) }}{% endfor %}) +{ + std::vector< bool > self_spikes(neuron_{{ concentration_name }}_concentration_count, point_self_spikes); + std::vector< double > {{ printer_no_origin.print(concentration_info["time_resolution_var"]) }}(neuron_{{ concentration_name }}_concentration_count, Time::get_resolution().get_ms()); + + {%- for ode_variable, ode_info in concentration_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + std::vector< double > {{ propagator }}(neuron_{{ concentration_name }}_concentration_count, 0); + {%- endfor %} + {%- endfor %} + + #pragma omp simd + for(std::size_t i = 0; i < neuron_{{ concentration_name }}_concentration_count; i++){ + {%- for ode_variable, ode_info in concentration_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; + {%- endfor %} + {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} + {{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + } +} + +{%- for function in concentration_info["Functions"] %} +{{render_channel_function(function, concentration_name)}} +{%- endfor %} 
+ +void nest::{{concentration_name}}{{cm_unique_suffix}}::get_concentrations_per_compartment(std::vector< double >& compartment_to_concentration){ + for(std::size_t comp_id = 0; comp_id < compartment_to_concentration.size(); comp_id++){ + compartment_to_concentration[comp_id] = 0; + } + for(std::size_t conc_id = 0; conc_id < neuron_{{ concentration_name }}_concentration_count; conc_id++){ + compartment_to_concentration[this->compartment_association[conc_id]] += this->{{concentration_name}}[conc_id]; + } +} + +std::vector< double > nest::{{concentration_name}}{{cm_unique_suffix}}::distribute_shared_vector(std::vector< double > shared_vector){ + std::vector< double > distributed_vector(this->neuron_{{ concentration_name }}_concentration_count, 0.0); + for(std::size_t conc_id = 0; conc_id < this->neuron_{{ concentration_name }}_concentration_count; conc_id++){ + distributed_vector[conc_id] = shared_vector[compartment_association[conc_id]]; + } + return distributed_vector; +} + +// {{concentration_name}} concentration end /////////////////////////////////////////////////////////// +{% endfor -%} +{% endwith %} + + +////////////////////////////////////// receptors + +{%- for receptor_name, receptor_info in recs_info.items() %} +// {{receptor_name}} receptor //////////////////////////////////////////////////////////////// + +void nest::{{receptor_name}}{{cm_unique_suffix}}::new_receptor(std::size_t comp_ass, const long rec_index) +{ + neuron_{{ receptor_name }}_receptor_count++; + i_tot_{{receptor_name}}.push_back(0); + compartment_association.push_back(comp_ass); + rec_idx.push_back(rec_index); + + {%- for pure_variable_name, variable_info in receptor_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count") -}}); + {%- endfor %} + + 
{% for variable_type, variable_info in receptor_info["Parameters"].items() %}
+ // receptor parameter {{variable_type }}
+ {%- set variable = variable_info["ASTVariable"] %}
+ {%- set rhs_expression = variable_info["rhs_expression"] %}
+ {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count") -}});
+ {%- endfor %}
+
+ // set propagators to ode toolbox returned value
+ {%- for convolution, convolution_info in receptor_info["convolutions"].items() %}
+ {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%}
+ {{state_variable_name}}.push_back(0);
+ {%- endfor %}
+ {%- endfor %}
+
+ // initial values for kernel state variables, set to zero
+ {%- for convolution, convolution_info in receptor_info["convolutions"].items() %}
+ {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%}
+ {{state_variable_name}}.push_back(0);
+ {%- endfor %}
+ {%- endfor %}
+
+ // user declared internals in order they were declared
+ {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %}
+ {{internal_name}}.push_back(0);
+ {%- endfor %}
+}
+
+void nest::{{receptor_name}}{{cm_unique_suffix}}::new_receptor(std::size_t comp_ass, const long rec_index, const DictionaryDatum& receptor_params)
+/* update {{receptor_name}} receptor parameters and states */
+{
+ neuron_{{ receptor_name }}_receptor_count++;
+ compartment_association.push_back(comp_ass);
+ i_tot_{{receptor_name}}.push_back(0);
+ rec_idx.push_back(rec_index);
+
+ {%- for pure_variable_name, variable_info in receptor_info["States"].items() %}
+ // state variable {{pure_variable_name }}
+ {%- set variable = variable_info["ASTVariable"] %}
+ {%- set rhs_expression = variable_info["rhs_expression"] %}
+ {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count") -}});
+ {%- endfor %}
+ 
{%- with %}
+ {%- for variable_type, variable_info in receptor_info["States"].items() %}
+ {%- set variable = variable_info["ASTVariable"] %}
+ if( receptor_params->known( "{{variable.name}}" ) )
+ {{variable.name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable.name}}" );
+ {%- endfor %}
+ {% endwith %}
+
+ {%- with %}
+ {%- for variable_type, variable_info in receptor_info["ODEs"].items() %}
+ {%- set variable_name = variable_type %}
+ {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, receptor_name) %}
+ // {{receptor_name}} receptor ODE state {{dynamic_variable }}
+ if( receptor_params->known( "{{variable_name}}" ) )
+ {{variable_name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable_name}}" );
+ {%- endfor %}
+ {% endwith %}
+
+ {% for variable_type, variable_info in receptor_info["Parameters"].items() %}
+ // receptor parameter {{variable_type }}
+ {%- set variable = variable_info["ASTVariable"] %}
+ {%- set rhs_expression = variable_info["rhs_expression"] %}
+ {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count") -}});
+ {%- endfor %}
+
+ {%- with %}
+ {%- for variable_type, variable_info in receptor_info["Parameters"].items() %}
+ {%- set variable = variable_info["ASTVariable"] %}
+ if( receptor_params->known( "{{variable.name}}" ) )
+ {{variable.name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable.name}}" );
+ {%- endfor %}
+ {% endwith %}
+
+ // set propagators to ode toolbox returned value
+ {%- for convolution, convolution_info in receptor_info["convolutions"].items() %}
+ {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%}
+ {{state_variable_name}}.push_back(0);
+ {%- endfor %}
+ {%- endfor %}
+
+ // initial values for kernel state variables, set to 
zero
+ {%- for convolution, convolution_info in receptor_info["convolutions"].items() %}
+ {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%}
+ {{state_variable_name}}.push_back(0);
+ {%- endfor %}
+ {%- endfor %}
+
+ // user declared internals in order they were declared
+ {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %}
+ {{internal_name}}.push_back(0);
+ {%- endfor %}
+}
+
+void
+nest::{{receptor_name}}{{cm_unique_suffix}}::append_recordables(std::map< Name, double* >* recordables, const long compartment_idx)
+{
+ {%- for convolution, convolution_info in receptor_info["convolutions"].items() %}
+ for(size_t recs_id = 0; recs_id < neuron_{{ receptor_name }}_receptor_count; recs_id++){
+ if(compartment_association[recs_id] == compartment_idx){
+ ( *recordables )[ Name( "{{convolution_info["kernel"]["name"]}}" + std::to_string(recs_id) )] = &{{convolution}}[recs_id];
+ }
+ }
+ {%- endfor %}
+ for(size_t recs_id = 0; recs_id < neuron_{{ receptor_name }}_receptor_count; recs_id++){
+ if(compartment_association[recs_id] == compartment_idx){
+ ( *recordables )[ Name( "i_tot_{{receptor_name}}" + std::to_string(recs_id) )] = &i_tot_{{receptor_name}}[recs_id];
+ }
+ }
+}
+
+{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %}
+void nest::{{receptor_name}}{{cm_unique_suffix}}::calibrate()
+{%- else %}
+void nest::{{receptor_name}}{{cm_unique_suffix}}::pre_run_hook()
+{%- endif %}
+{
+
+ std::vector< double > {{ printer_no_origin.print(receptor_info["analytic_helpers"]["__h"]["ASTVariable"]) }}(neuron_{{ receptor_name }}_receptor_count, Time::get_resolution().get_ms());
+
+ {%- for state_name, state_declaration in receptor_info["States"].items() %}
+ std::vector< double > {{state_name}}(neuron_{{ receptor_name }}_receptor_count, {{ 
printer_no_origin.print(state_declaration["rhs_expression"])}}); + {%- endfor %} + + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}[i] = 0; + {%- endfor %} + {%- endfor %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %} + {{internal_name}}[i] = {{ vector_printer.print(internal_declaration.get_expression(), "i") }}; + {%- endfor %} + + {{receptor_info["buffer_name"]}}_[i]->clear(); + } +} + +std::pair< std::vector< double >, std::vector< double > > nest::{{receptor_name}}{{cm_unique_suffix}}::f_numstep( bool point_self_spikes, std::vector< double > v_comp, const long lag {% for ode in receptor_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if receptor_info["Dependencies"]["receptors"]|length %} + {% endif %}{% for inline in receptor_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if receptor_info["Dependencies"]["channels"]|length %} + {% endif %}{% for inline in receptor_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if receptor_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in 
receptor_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if receptor_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in receptor_info["Dependencies"]["global"] %}, std::vector<{{ render_variable_type(state) }}> {{ printer_no_origin.print(state) }}{% endfor %}) +{ + std::vector< bool > self_spikes(neuron_{{ receptor_name }}_receptor_count, point_self_spikes); + std::vector< double > g_val(neuron_{{ receptor_name }}_receptor_count, 0.); + std::vector< double > i_val(neuron_{{ receptor_name }}_receptor_count, 0.); + std::vector< double > d_i_tot_dv(neuron_{{ receptor_name }}_receptor_count, 0.); + + {%- for ode_variable, ode_info in receptor_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + std::vector< double > {{ propagator }}(neuron_{{ receptor_name }}_receptor_count, 0); + {%- endfor %} + {%- endfor %} + + {% if receptor_info["ODEs"].items()|length %} std::vector< double > {{ printer_no_origin.print(receptor_info["analytic_helpers"]["__h"]["ASTVariable"]) }}(neuron_{{ receptor_name }}_receptor_count, Time::get_resolution().get_ms()); {% endif %} + + std::vector < double > s_val(neuron_{{ receptor_name }}_receptor_count, 0); + + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + // get spikes + s_val[i] = {{receptor_info["buffer_name"]}}_[i]->get_value( lag ); // * g_norm_; + } + + //update ODE state variable + #pragma omp simd + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + {%- for ode_variable, ode_info in receptor_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; + {%- endfor %} + {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} + 
{{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + // update kernel state variable / compute synaptic conductance + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items() %} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["update_expression"], "i") }}; + {{state_variable_name}}[i] += s_val[i] * {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + // total current + // this expression should be the transformed inline expression + + this->i_tot_{{receptor_name}}[i] = {{ vector_printer.print(receptor_info["root_expression"].get_expression(), "i") }}; + + // derivative of that expression + // voltage derivative of total current + // compute derivative with respect to current with sympy + d_i_tot_dv[i] = {{ vector_printer.print(receptor_info["inline_expression_d"], "i") }}; + + // for numerical integration + g_val[i] = - d_i_tot_dv[i]; + i_val[i] = this->i_tot_{{receptor_name}}[i] - d_i_tot_dv[i] * v_comp[i]; + } + + return std::make_pair(g_val, i_val); + +} + +{%- for function in receptor_info["functions_used"] %} +inline {{ function_declaration.FunctionDeclaration(function, "nest::"~receptor_name~cm_unique_suffix~"::", true) }} +{ +{%- filter indent(2,True) %} +{%- with ast = function.get_block() %} +{%- include "directives/Block.jinja2" %} +{%- endwith %} +{%- endfilter %} +} +{%- endfor %} + +void nest::{{receptor_name}}{{cm_unique_suffix}}::get_currents_per_compartment(std::vector< double >& compartment_to_current){ + for(std::size_t comp_id = 0; comp_id < compartment_to_current.size(); comp_id++){ + compartment_to_current[comp_id] = 0; + } + for(std::size_t rec_id = 0; rec_id < neuron_{{ receptor_name }}_receptor_count; rec_id++){ + 
compartment_to_current[this->compartment_association[rec_id]] += this->i_tot_{{receptor_name}}[rec_id]; + } +} + +std::vector< double > nest::{{receptor_name}}{{cm_unique_suffix}}::distribute_shared_vector(std::vector< double > shared_vector){ + std::vector< double > distributed_vector(this->neuron_{{ receptor_name }}_receptor_count, 0.0); + for(std::size_t rec_id = 0; rec_id < this->neuron_{{ receptor_name }}_receptor_count; rec_id++){ + distributed_vector[rec_id] = shared_vector[compartment_association[rec_id]]; + } + return distributed_vector; +} + +// {{receptor_name}} receptor end /////////////////////////////////////////////////////////// +{%- endfor %} + +////////////////////////////////////// receptors with synapses attached +{%- for synapse_name, synapse_info in syns_info.items() %} +{%- for receptor_name, receptor_info in recs_info.items() %} +// {{receptor_name}} receptor //////////////////////////////////////////////////////////////// + +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::new_receptor(std::size_t comp_ass, const long syn_index) +{ + neuron_{{ receptor_name }}_receptor_count++; + i_tot_{{receptor_name}}.push_back(0); + compartment_association.push_back(comp_ass); + syn_idx.push_back(syn_index); + + {%- for pure_variable_name, variable_info in receptor_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count") -}}); + {%- endfor %} + + {% for variable_type, variable_info in receptor_info["Parameters"].items() %} + // receptor parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count") -}}); 
+ {%- endfor %} + + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %} + {{internal_name}}.push_back(0); + {%- endfor %} + + + //synapse components: + {%- for pure_variable_name, variable_info in synapse_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count") -}}); + {%- endfor %} + + {% for variable_type, variable_info in synapse_info["Parameters"].items() %} + // receptor parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count") -}}); + {%- endfor %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in synapse_info["Internals"] %} + {{internal_name}}.push_back(0); + {%- endfor %} + + {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in 
convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for inline_name, inline in synapse_info["Inlines"].items() %} + {{inline_name}}.push_back(0); + {%- endfor %} + + {%- with %} + {%- for in_function_declaration in synapse_info["InFunctionDeclarationsVars"] %} + {%- for variable in declarations.get_variables(in_function_declaration) %} + {{variable.get_symbol_name()}}.push_back(0); + {%- endfor %} + {%- endfor %} + {%- endwith %} +} + +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::new_receptor(std::size_t comp_ass, const long syn_index, const DictionaryDatum& receptor_params) +// update {{receptor_name}} receptor parameters +{ + neuron_{{ receptor_name }}_receptor_count++; + compartment_association.push_back(comp_ass); + i_tot_{{receptor_name}}.push_back(0); + syn_idx.push_back(syn_index); + + {%- for pure_variable_name, variable_info in receptor_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count") -}}); + {%- endfor %} + {%- with %} + {%- for variable_type, variable_info in receptor_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( receptor_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable.name}}" ); + {%- endfor %} + {% endwith %} + + {%- with %} + {%- for variable_type, variable_info in 
receptor_info["ODEs"].items() %} + {%- set variable_name = variable_type %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + // {{concentration_name}} concentration ODE state {{dynamic_variable }} + if( receptor_params->known( "{{variable_name}}" ) ) + {{variable_name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable_name}}" ); + {%- endfor %} + {% endwith %} + + {% for variable_type, variable_info in receptor_info["Parameters"].items() %} + // receptor parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count") -}}); + {%- endfor %} + + {%- with %} + {%- for variable_type, variable_info in receptor_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( receptor_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable.name}}" ); + {%- endfor %} + {% endwith %} + + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in 
receptor_info["internals_used_declared"] %} + {{internal_name}}.push_back(0); + {%- endfor %} + + + //synapse components: + {%- for pure_variable_name, variable_info in synapse_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count") -}}); + {%- endfor %} + {%- for variable_type, variable_info in synapse_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( receptor_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable.name}}" ); + {%- endfor %} + + {% for variable_type, variable_info in synapse_info["Parameters"].items() %} + // receptor parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+receptor_name+"_receptor_count") -}}); + {%- endfor %} + {%- for variable_type, variable_info in synapse_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( receptor_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ receptor_name }}_receptor_count-1] = getValue< double >( receptor_params, "{{variable.name}}" ); + {%- endfor %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in synapse_info["Internals"] %} + {{internal_name}}.push_back(0); + {%- endfor %} + + {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- 
for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}.push_back(0); + {%- endfor %} + {%- endfor %} + + {%- for inline_name, inline in synapse_info["Inlines"].items() %} + {{inline_name}}.push_back(0); + {%- endfor %} + + {%- for in_function_declaration in synapse_info["InFunctionDeclarationsVars"] %} + {%- for variable in declarations.get_variables(in_function_declaration) %} + {{variable.get_symbol_name()}}.push_back(0); + {%- endfor %} + {%- endfor %} +} + +void +nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::append_recordables(std::map< Name, double* >* recordables, const long compartment_idx) +{ + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + for(size_t syns_id = 0; syns_id < neuron_{{ receptor_name }}_receptor_count; syns_id++){ + if(compartment_association[syns_id] == compartment_idx){ + ( *recordables )[ Name( "{{convolution_info["kernel"]["name"]}}" + std::to_string(syns_id) )] = &{{convolution}}[syns_id]; + } + } + {%- endfor %} + for(size_t syns_id = 0; syns_id < neuron_{{ receptor_name }}_receptor_count; syns_id++){ + if(compartment_association[syns_id] == compartment_idx){ + ( *recordables )[ Name( "i_tot_{{receptor_name}}_{{synapse_name}}" + std::to_string(syns_id) )] = &i_tot_{{receptor_name}}[syns_id]; + } + } + + //synapse states + {%- for pure_variable_name, variable_info in synapse_info["States"].items() %} + for(size_t syns_id = 0; syns_id < neuron_{{ receptor_name }}_receptor_count; syns_id++){ + if(compartment_association[syns_id] == compartment_idx){ + ( *recordables )[ Name( "{{pure_variable_name}}" + std::to_string(syns_id) )] = &{{pure_variable_name}}[syns_id]; + } + } + {%- endfor %} + + {%- for inline_name, inline in synapse_info["Inlines"].items() %} + for(size_t syns_id = 0; syns_id < neuron_{{ receptor_name 
}}_receptor_count; syns_id++){ + if(compartment_association[syns_id] == compartment_idx){ + ( *recordables )[ Name( "{{inline_name}}" + std::to_string(syns_id) )] = &{{inline_name}}[syns_id]; + } + } + {%- endfor %} +} + +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::calibrate() +{%- else %} +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::pre_run_hook() +{%- endif %} +{ + + std::vector< double > {{ printer_no_origin.print(receptor_info["analytic_helpers"]["__h"]["ASTVariable"]) }}(neuron_{{ receptor_name }}_receptor_count, Time::get_resolution().get_ms()); + + {%- for state_name, state_declaration in receptor_info["States"].items() %} + std::vector< double > {{state_name}} = (neuron_{{ receptor_name }}_receptor_count, {{ printer_no_origin.print(state_declaration["rhs_expression"])}}); + {%- endfor %} + + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + // set propagators to ode toolbox returned value + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + // initial values for kernel state variables, set to zero + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, 
state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + {{state_variable_name}}[i] = 0; + {%- endfor %} + {%- endfor %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %} + {{internal_name}}[i] = {{ vector_printer.print(internal_declaration.get_expression(), "i") }}; + {%- endfor %} + + {{receptor_info["buffer_name"]}}_[i]->clear(); + } +} + +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::postsynaptic_synaptic_processing(){ + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++) { + {%- set function = synapse_info["PostSpikeFunction"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_block() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/Block.jinja2" %} + {%- endwith %} + {%- endfilter %} + } +} + {%- with %} + {%- set conc_dep = set(receptor_info["Dependencies"]["concentrations"]).union(synapse_info["Dependencies"]["concentrations"])%} + {%- set rec_dep = set(receptor_info["Dependencies"]["receptors"]).union(synapse_info["Dependencies"]["receptors"])%} + {%- set chan_dep = set(receptor_info["Dependencies"]["channels"]).union(synapse_info["Dependencies"]["channels"])%} + {%- set con_in_dep = set(receptor_info["Dependencies"]["continuous"]).union(synapse_info["Dependencies"]["continuous"])%} +std::pair< std::vector< double >, std::vector< double > > nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::f_numstep( bool point_self_spikes, std::vector< double > v_comp, const long lag {% for ode in conc_dep %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if rec_dep|length %} + {% endif %}{% for inline in rec_dep %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if chan_dep|length %} + {% endif %}{% for inline in chan_dep %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if con_in_dep|length %} + {% 
endif %}{% for inline in con_in_dep %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if receptor_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in receptor_info["Dependencies"]["global"] %}, std::vector<{{ render_variable_type(state) }}> {{ printer_no_origin.print(state) }}{% endfor %}) +{ + {%- endwith %} + std::vector< bool > self_spikes(neuron_{{ receptor_name }}_receptor_count, point_self_spikes); + std::vector< double > g_val(neuron_{{ receptor_name }}_receptor_count, 0.); + std::vector< double > i_val(neuron_{{ receptor_name }}_receptor_count, 0.); + std::vector< double > d_i_tot_dv(neuron_{{ receptor_name }}_receptor_count, 0.); + + {%- for ode_variable, ode_info in receptor_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + std::vector< double > {{ propagator }}(neuron_{{ receptor_name }}_receptor_count, 0); + {%- endfor %} + {%- endfor %} + + {% if receptor_info["ODEs"].items()|length %} std::vector< double > {{ printer_no_origin.print(receptor_info["analytic_helpers"]["__h"]["ASTVariable"]) }}(neuron_{{ receptor_name }}_receptor_count, Time::get_resolution().get_ms()); {% endif %} + + std::vector < double > s_val(neuron_{{ receptor_name }}_receptor_count, 0); + + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + // get spikes + s_val[i] = {{receptor_info["buffer_name"]}}_[i]->get_value( lag ); // * g_norm_; + } + + + //synaptic processing: + //continuous synaptic processing + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + //inlines and convolutions + {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items() %} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["update_expression"], "i") }}; + {{state_variable_name}}[i] += 
{%- if convolution_info["post_port"] %}point_self_spikes{%- else %}s_val[i]{%- endif %} * {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + {%- for inline, inline_info in synapse_info["Inlines"].items() %} + {{ inline }}[i] = {{ vector_printer.print(inline_info["inline_expression"].get_expression(), "i") }}; + {%- endfor %} + //update block + {%- if synapse_info["UpdateBlock"] %} + {%- set function = synapse_info["UpdateBlock"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_block() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/Block.jinja2" %} + {%- endwith %} + {%- endfilter %} + {%- endif %} + } + + {% if synapse_info["ODEs"].items()|length %} + std::vector< double > {{ printer_no_origin.print(synapse_info["time_resolution_var"]) }}(neuron_{{ receptor_name }}_receptor_count, Time::get_resolution().get_ms()); + {% endif %} + {%- for ode_variable, ode_info in synapse_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + std::vector< double > {{ propagator }}(neuron_{{ receptor_name }}_receptor_count, 0); + {%- endfor %} + {%- endfor %} + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + {%- for ode_variable, ode_info in synapse_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; + {%- endfor %} + {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} + {{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + } + + //presynaptic spike processing + #pragma omp simd + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + if(s_val[i]!=0) { + {%- set function = 
synapse_info["PreSpikeFunction"] %} + {%- filter indent(2,True) %} + {%- with ast = function.get_block() %} + {%- set printer = vector_printer %} + {%- include "cm_directives_cpp/Block.jinja2" %} + {%- endwith %} + {%- endfilter %} + } + } + //presynaptic spike processing end + + + //update ODE state variable + #pragma omp simd + for(std::size_t i = 0; i < neuron_{{ receptor_name }}_receptor_count; i++){ + {%- for ode_variable, ode_info in receptor_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; + {%- endfor %} + {%- for state, state_solution_info in ode_info["transformed_solutions"][0]["states"].items() %} + {{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + // update kernel state variable / compute synaptic conductance + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items() %} + {{state_variable_name}}[i] = {{ vector_printer.print(state_variable_info["update_expression"], "i") }}; + {{state_variable_name}}[i] += s_val[i] * {{ vector_printer.print(state_variable_info["init_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + // total current + // this expression should be the transformed inline expression + + this->i_tot_{{receptor_name}}[i] = {{ vector_printer.print(receptor_info["root_expression"].get_expression(), "i") }}; + + // derivative of that expression + // voltage derivative of total current + // compute derivative with respect to current with sympy + d_i_tot_dv[i] = {{ vector_printer.print(receptor_info["inline_expression_d"], "i") }}; + + // for numerical integration + g_val[i] = - d_i_tot_dv[i]; + i_val[i] = this->i_tot_{{receptor_name}}[i] - d_i_tot_dv[i] * 
v_comp[i]; + } + + return std::make_pair(g_val, i_val); + +} + +{%- for function in receptor_info["functions_used"] %} +inline {{ function_declaration.FunctionDeclaration(function, "nest::"~receptor_name~cm_unique_suffix~"::") }} +{ +{%- filter indent(2,True) %} +{%- with ast = function.get_block() %} + {%- set printer = vector_printer %} +{%- include "directives/Block.jinja2" %} +{%- endwith %} +{%- endfilter %} +} +{%- endfor %} + +void nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::get_currents_per_compartment(std::vector< double >& compartment_to_current){ + for(std::size_t comp_id = 0; comp_id < compartment_to_current.size(); comp_id++){ + compartment_to_current[comp_id] = 0; + } + for(std::size_t syn_id = 0; syn_id < neuron_{{ receptor_name }}_receptor_count; syn_id++){ + compartment_to_current[this->compartment_association[syn_id]] += this->i_tot_{{receptor_name}}[syn_id]; + } +} + +std::vector< double > nest::{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}::distribute_shared_vector(std::vector< double > shared_vector){ + std::vector< double > distributed_vector(this->neuron_{{ receptor_name }}_receptor_count, 0.0); + for(std::size_t syn_id = 0; syn_id < this->neuron_{{ receptor_name }}_receptor_count; syn_id++){ + distributed_vector[syn_id] = shared_vector[compartment_association[syn_id]]; + } + return distributed_vector; +} + +// {{receptor_name}}_{{synapse_name}} receptor end /////////////////////////////////////////////////////////// +{%- endfor %} +{%- endfor %} + + +////////////////////////////////////// continuous inputs + +{%- for continuous_name, continuous_info in con_in_info.items() %} +// {{continuous_name}} continuous input /////////////////////////////////////////////////// + +void nest::{{continuous_name}}{{cm_unique_suffix}}::new_continuous_input(std::size_t comp_ass, const long con_in_index) +{ + neuron_{{ continuous_name }}_continuous_input_count++; + i_tot_{{continuous_name}}.push_back(0); + 
compartment_association.push_back(comp_ass); + continuous_idx.push_back(con_in_index); + + {%- for pure_variable_name, variable_info in continuous_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+continuous_name+"_continuous_input_count") -}}); + {%- endfor %} + + {% for variable_type, variable_info in continuous_info["Parameters"].items() %} + // parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+continuous_name+"_continuous_input_count") -}}); + {%- endfor %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in continuous_info["internals_used_declared"] %} + {{internal_name}}.push_back(0); + {%- endfor %} +} + +void nest::{{continuous_name}}{{cm_unique_suffix}}::new_continuous_input(std::size_t comp_ass, const long con_in_index, const DictionaryDatum& con_in_params) +/* update {{continuous_name}} continuous input parameters and states */ +{ + neuron_{{ continuous_name }}_continuous_input_count++; + compartment_association.push_back(comp_ass); + i_tot_{{continuous_name}}.push_back(0); + continuous_idx.push_back(con_in_index); + + {%- for pure_variable_name, variable_info in continuous_info["States"].items() %} + // state variable {{pure_variable_name }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+continuous_name+"_continuous_input_count") -}}); + {%- endfor %} + {%- for variable_type, variable_info in continuous_info["States"].items() %} + {%- set variable = 
variable_info["ASTVariable"] %} + if( con_in_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ continuous_name }}_continuous_input_count-1] = getValue< double >( con_in_params, "{{variable.name}}" ); + {%- endfor %} + + {%- with %} + {%- for variable_type, variable_info in continuous_info["ODEs"].items() %} + {%- set variable_name = variable_type %} + {%- set dynamic_variable = render_dynamic_channel_variable_name(variable_type, ion_channel_name) %} + // {{continuous_name}} concentration ODE state {{dynamic_variable }} + if( con_in_params->known( "{{variable_name}}" ) ) + {{variable_name}}[neuron_{{ continuous_name }}_continuous_input_count-1] = getValue< double >( con_in_params, "{{variable_name}}" ); + {%- endfor %} + {% endwith %} + + {% for variable_type, variable_info in continuous_info["Parameters"].items() %} + // continuous parameter {{variable_type }} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name}}.push_back({{ vector_printer.print(rhs_expression, "neuron_"+continuous_name+"_continuous_input_count") -}}); + {%- endfor %} + + {%- with %} + {%- for variable_type, variable_info in continuous_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( con_in_params->known( "{{variable.name}}" ) ) + {{variable.name}}[neuron_{{ continuous_name }}_continuous_input_count-1] = getValue< double >( con_in_params, "{{variable.name}}" ); + {%- endfor %} + {% endwith %} + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in continuous_info["internals_used_declared"] %} + {{internal_name}}.push_back(0); + {%- endfor %} +} + +void +nest::{{continuous_name}}{{cm_unique_suffix}}::append_recordables(std::map< Name, double* >* recordables, const long compartment_idx) +{ + for(size_t con_in_id = 0; con_in_id < neuron_{{ continuous_name }}_continuous_input_count; con_in_id++){ + 
if(compartment_association[con_in_id] == compartment_idx){ + ( *recordables )[ Name( "i_tot_{{continuous_name}}" + std::to_string(con_in_id) )] = &i_tot_{{continuous_name}}[con_in_id]; + } + } +} + +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} +void nest::{{continuous_name}}{{cm_unique_suffix}}::calibrate() +{%- else %} +void nest::{{continuous_name}}{{cm_unique_suffix}}::pre_run_hook() +{%- endif %} +{ + + // user declared internals in order they were declared + {%- for internal_name, internal_declaration in continuous_info["Internals"] %} + {{internal_name}}[i] = {{ vector_printer.print(internal_declaration.get_expression(), "i") }}; + {%- endfor %} + + for(std::size_t i = 0; i < neuron_{{ continuous_name }}_continuous_input_count; i++){ + {% for port_name, port_info in continuous_info["Continuous"].items() %} + {{port_name}}_[i]->clear(); + {% endfor %} + } +} + +std::pair< std::vector< double >, std::vector< double > > nest::{{continuous_name}}{{cm_unique_suffix}}::f_numstep( bool point_self_spikes, std::vector< double > v_comp, const long lag {% for ode in continuous_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if continuous_info["Dependencies"]["receptors"]|length %} + {% endif %}{% for inline in continuous_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if continuous_info["Dependencies"]["channels"]|length %} + {% endif %}{% for inline in continuous_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if continuous_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in continuous_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if continuous_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in 
continuous_info["Dependencies"]["global"] %}, std::vector<{{ render_variable_type(state) }}> {{ printer_no_origin.print(state) }}{% endfor %}) +{ + std::vector< bool > self_spikes(neuron_{{ continuous_name }}_continuous_input_count, point_self_spikes); + std::vector< double > g_val(neuron_{{ continuous_name }}_continuous_input_count, 0.); + std::vector< double > i_val(neuron_{{ continuous_name }}_continuous_input_count, 0.); + std::vector< double > d_i_tot_dv(neuron_{{ continuous_name }}_continuous_input_count, 0.); + + {% if continuous_info["ODEs"].items()|length %} + std::vector< double > {{ printer_no_origin.print(continuous_info["time_resolution_var"]) }}(neuron_{{ continuous_name }}_continuous_input_count, Time::get_resolution().get_ms()); + {% endif %} + + {%- for ode_variable, ode_info in continuous_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + std::vector< double > {{ propagator }}(neuron_{{ continuous_name }}_continuous_input_count, 0); + {%- endfor %} + {%- endfor %} + + {% for port_name, port_info in continuous_info["Continuous"].items() %} + std::vector< double > {{ port_name }}(neuron_{{ continuous_name }}_continuous_input_count, 0.); + {% endfor %} + #pragma omp simd + for(std::size_t i = 0; i < neuron_{{ continuous_name }}_continuous_input_count; i++){ + {% for port_name, port_info in continuous_info["Continuous"].items() %} + {{ port_name }}[i] = {{ port_name }}_[i]->get_value( lag ); + {% endfor %} + } + + #pragma omp simd + for(std::size_t i = 0; i < neuron_{{ continuous_name }}_continuous_input_count; i++){ + //update ODE state variable + {%- for ode_variable, ode_info in continuous_info["ODEs"].items() %} + {%- for propagator, propagator_info in ode_info["transformed_solutions"][0]["propagators"].items() %} + {{ propagator }}[i] = {{ vector_printer.print(propagator_info["init_expression"], "i") }}; + {%- endfor %} + {%- for state, state_solution_info in 
ode_info["transformed_solutions"][0]["states"].items() %} + {{state}}[i] = {{ vector_printer.print(state_solution_info["update_expression"], "i") }}; + {%- endfor %} + {%- endfor %} + + // total current + // this expression should be the transformed inline expression + this->i_tot_{{continuous_name}}[i] = {{ vector_printer.print(continuous_info["root_expression"].get_expression(), "i") }}; + + // derivative of that expression + // voltage derivative of total current + // compute derivative with respect to current with sympy + d_i_tot_dv[i] = {{ vector_printer.print(continuous_info["inline_derivative"], "i") }}; + + // for numerical integration + g_val[i] = - d_i_tot_dv[i]; + i_val[i] = this->i_tot_{{continuous_name}}[i] - d_i_tot_dv[i] * v_comp[i]; + } + + return std::make_pair(g_val, i_val); + +} + +{%- for function in continuous_info["Functions"] %} +inline {{ function_declaration.FunctionDeclaration(function, "nest::"~continuous_name~cm_unique_suffix~"::") }} +{ +{%- filter indent(2,True) %} +{%- with ast = function.get_block() %} +{%- include "directives/Block.jinja2" %} +{%- endwith %} +{%- endfilter %} +} +{%- endfor %} + +void nest::{{continuous_name}}{{cm_unique_suffix}}::get_currents_per_compartment(std::vector< double >& compartment_to_current){ + for(std::size_t comp_id = 0; comp_id < compartment_to_current.size(); comp_id++){ + compartment_to_current[comp_id] = 0; + } + for(std::size_t con_in_id = 0; con_in_id < neuron_{{ continuous_name }}_continuous_input_count; con_in_id++){ + compartment_to_current[this->compartment_association[con_in_id]] += this->i_tot_{{continuous_name}}[con_in_id]; + } +} + +std::vector< double > nest::{{continuous_name}}{{cm_unique_suffix}}::distribute_shared_vector(std::vector< double > shared_vector){ + std::vector< double > distributed_vector(this->neuron_{{ continuous_name }}_continuous_input_count, 0.0); + for(std::size_t con_in_id = 0; con_in_id < this->neuron_{{ continuous_name }}_continuous_input_count; con_in_id++){ + 
distributed_vector[con_in_id] = shared_vector[compartment_association[con_in_id]]; + } + return distributed_vector; +} + +// {{continuous_name}} continuous input end /////////////////////////////////////////////////////////// +{%- endfor %} + + +{%- include "cm_global_dynamics.cpp.jinja2" %} \ No newline at end of file diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.h.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.h.jinja2 new file mode 100644 index 000000000..bb82283ed --- /dev/null +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_neuroncurrents_@NEURON_NAME@.h.jinja2 @@ -0,0 +1,1294 @@ +{#- +cm_compartmentcurrents_@NEURON_NAME@.h.jinja2 + +This file is part of NEST. + +Copyright (C) 2004 The NEST Initiative + +NEST is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 2 of the License, or +(at your option) any later version. + +NEST is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with NEST. If not, see . 
+#} + +{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif -%} +{%- import 'directives_cpp/FunctionDeclaration.jinja2' as function_declaration with context %} +#ifndef RECEPTORS_NEAT_H_{{cm_unique_suffix | upper }} +#define RECEPTORS_NEAT_H_{{cm_unique_suffix | upper }} + +#include +#include +#include + +#include "ring_buffer.h" + +{% macro render_variable_type(variable) -%} +{%- with -%} + {%- set symbol = variable.get_scope().resolve_to_symbol(variable.name, SymbolKind.VARIABLE) -%} + {{ types_printer.print(symbol.type_symbol) }} +{%- endwith -%} +{%- endmacro %} + +//elementwise vector operations: +#include +#include +#include + +namespace nest +{ + +// entry in the spiking history +//class histentry +//{ +//public: +// histentry( double t, +//double post_trace, +//size_t access_counter ) +// : t_( t ) +// , post_trace_( post_trace ) +// , access_counter_( access_counter ) +// { +// } +// +// double t_; //!< point in time when spike occurred (in ms) +// double post_trace_; +// size_t access_counter_; //!< access counter to enable removal of the entry, once all neurons read it +//}; + +///////////////////////////////////// channels + +{%- with %} +{%- for ion_channel_name, channel_info in chan_info.items() %} + +class {{ion_channel_name}}{{cm_unique_suffix}}{ +private: + // states + {%- for pure_variable_name, variable_info in channel_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + std::vector< {{ render_variable_type(variable) }} > {{ variable.name }} = {}; + {%- endfor %} + + // parameters + {%- for pure_variable_name, variable_info in channel_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + std::vector< {{ render_variable_type(variable) }} > {{ variable.name }} = {}; + {%- endfor %} + + // internals + {%- for pure_variable_name, 
variable_info in channel_info["Internals"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + std::vector< {{ render_variable_type(variable) }} > {{ variable.name }} = {}; + {%- endfor %} + + // ion-channel root-inline value + std::vector< double > i_tot_{{ion_channel_name}} = {}; + + //zero recordable variable in case of zero contribution channel + double zero_recordable = 0; + + //Global dependencies + {% for state in channel_info["Dependencies"]["global"] %} + std::vector<{{ render_variable_type(state) }}> {{ printer_no_origin.print(state) }} = {}; + {% endfor %} + +public: + // constructor, destructor + {{ion_channel_name}}{{cm_unique_suffix}}(){}; + ~{{ion_channel_name}}{{cm_unique_suffix}}(){}; + + void new_channel(std::size_t comp_ass); + void new_channel(std::size_t comp_ass, const DictionaryDatum& channel_params); + + //number of channels + std::size_t neuron_{{ ion_channel_name }}_channel_count = 0; + + std::vector< size_t > compartment_association = {}; + + // initialization channel +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} + void calibrate() { +{%- else %} + void pre_run_hook() { +{%- endif %} + }; + + void append_recordables(std::map< Name, double* >* recordables, const long compartment_idx); + + // numerical integration step + std::pair< std::vector< double >, std::vector< double > > f_numstep( bool point_self_spikes, std::vector< double > v_comp{% for ode in channel_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if channel_info["Dependencies"]["receptors"]|length %} + {% endif %}{% for inline in channel_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if channel_info["Dependencies"]["channels"]|length %} + {% endif %}{% for inline in channel_info["Dependencies"]["channels"] 
%}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if channel_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in channel_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if channel_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in channel_info["Dependencies"]["global"] %}, std::vector<{{ render_variable_type(state) }}> {{ printer_no_origin.print(state) }}{% endfor %}); + + // function declarations + +{%- for function in channel_info["Functions"] %} + #pragma omp declare simd + __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function, pass_by_reference = true) }}; +{%- endfor %} + + // root_inline getter + void get_currents_per_compartment(std::vector< double >& compartment_to_current); + + std::vector< double > distribute_shared_vector(std::vector< double > shared_vector); + +}; +{% endfor -%} +{% endwith -%} + + +///////////////////////////////////////////// concentrations + +{%- with %} +{%- for concentration_name, concentration_info in conc_info.items() %} + +class {{ concentration_name }}{{cm_unique_suffix}}{ +private: + // parameters + {%- for pure_variable_name, variable_info in concentration_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + std::vector< {{ render_variable_type(variable) }} > {{ variable.name }} = {}; + {%- endfor %} + + // states + {%- for pure_variable_name, variable_info in concentration_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + std::vector< {{ render_variable_type(variable) }} > {{ variable.name }} = {}; + {%- endfor %} + + // internals + {%- for pure_variable_name, variable_info in concentration_info["Internals"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = 
variable_info["rhs_expression"] %} + std::vector< {{ render_variable_type(variable) }} > {{ variable.name }} = {}; + {%- endfor %} + + // concentration value (root-ode state) + std::vector< double > {{concentration_name}} = {}; + + //zero recordable variable in case of zero contribution concentration + double zero_recordable = 0; + +public: + // constructor, destructor + {{ concentration_name }}{{cm_unique_suffix}}(){}; + ~{{ concentration_name }}{{cm_unique_suffix}}(){}; + + void new_concentration(std::size_t comp_ass); + void new_concentration(std::size_t comp_ass, const DictionaryDatum& concentration_params); + + //number of channels + std::size_t neuron_{{ concentration_name }}_concentration_count = 0; + + std::vector< size_t > compartment_association = {}; + + // initialization concentration +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} + void calibrate() { +{%- else %} + void pre_run_hook() { +{%- endif %} + for(std::size_t concentration_id = 0; concentration_id < neuron_{{ concentration_name }}_concentration_count; concentration_id++){ + // states + {%- for pure_variable_name, variable_info in concentration_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + {%- set rhs_expression = variable_info["rhs_expression"] %} + {{ variable.name }}[concentration_id] = {{ vector_printer.print(rhs_expression, "concentration_id") }}; + {%- endfor %} + } + }; + void append_recordables(std::map< Name, double* >* recordables, const long compartment_idx); + + // numerical integration step + void f_numstep( bool point_self_spikes, std::vector< double > v_comp{% for ode in concentration_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if concentration_info["Dependencies"]["receptors"]|length %} + {% endif %}{% for inline in concentration_info["Dependencies"]["receptors"] %}, std::vector< double > 
{{inline.variable_name}}{% endfor %}{% if concentration_info["Dependencies"]["channels"]|length %} + {% endif %}{% for inline in concentration_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if concentration_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in concentration_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if concentration_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in concentration_info["Dependencies"]["global"] %}, std::vector<{{ render_variable_type(state) }}> {{ printer_no_origin.print(state) }}{% endfor %}); + + // function declarations +{%- for function in concentration_info["Functions"] %} + #pragma omp declare simd + __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function, pass_by_reference = true) }}; +{%- endfor %} + + // root_ode getter + void get_concentrations_per_compartment(std::vector< double >& compartment_to_concentration); + + std::vector< double > distribute_shared_vector(std::vector< double > shared_vector); + +}; +{% endfor -%} +{% endwith -%} + + +////////////////////////////////////////////////// receptors + +{% macro render_time_resolution_variable(receptor_info) -%} +{# we assume here that there is only one such variable ! 
#}
+{%- with %}
+{%- for analytic_helper_name, analytic_helper_info in receptor_info["analytic_helpers"].items() -%}
+{%- if analytic_helper_info["is_time_resolution"] -%}
+  {{ analytic_helper_name }}
+{%- endif -%}
+{%- endfor -%}
+{% endwith %}
+{%- endmacro %}
+
+{%- with %}
+{%- for receptor_name, receptor_info in recs_info.items() %}
+
+class {{receptor_name}}{{cm_unique_suffix}}{
+private:
+  // global receptor index
+  std::vector< long > rec_idx = {};
+
+  // propagators, initialized via pre_run_hook() or calibrate()
+  {%- for convolution, convolution_info in receptor_info["convolutions"].items() %}
+  {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%}
+  std::vector< double > {{state_variable_name}};
+  {%- endfor %}
+  {%- endfor %}
+
+  // kernel state variables, initialized via pre_run_hook() or calibrate()
+  {%- for convolution, convolution_info in receptor_info["convolutions"].items() %}
+  {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%}
+  std::vector< double > {{state_variable_name}};
+  {%- endfor %}
+  {%- endfor %}
+
+  // user defined parameters, initialized via pre_run_hook() or calibrate()
+  {%- for param_name, param_declaration in receptor_info["Parameters"].items() %}
+  std::vector< double > {{param_name}};
+  {%- endfor %}
+
+  // states
+  {%- for pure_variable_name, variable_info in receptor_info["States"].items() %}
+  {%- set variable = variable_info["ASTVariable"] %}
+  {%- set rhs_expression = variable_info["rhs_expression"] %}
+  std::vector<{{ render_variable_type(variable) }}> {{ variable.name }} = {};
+  {%- endfor %}
+
+  std::vector< double > i_tot_{{receptor_name}} = {};
+
+  // user declared internals in order they were declared, initialized via pre_run_hook() or calibrate()
+  {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %}
+  std::vector< double > {{internal_name}};
+  {%-
endfor %} + + // spike buffer + std::vector< RingBuffer* > {{receptor_info["buffer_name"]}}_; + +public: + // constructor, destructor + {{receptor_name}}{{cm_unique_suffix}}(){}; + ~{{receptor_name}}{{cm_unique_suffix}}(){}; + + void new_receptor(std::size_t comp_ass, const long rec_index); + void new_receptor(std::size_t comp_ass, const long rec_index, const DictionaryDatum& receptor_params); + + //number of receptors + std::size_t neuron_{{ receptor_name }}_receptor_count = 0; + + std::vector< size_t > compartment_association = {}; + + // numerical integration step + std::pair< std::vector< double >, std::vector< double > > f_numstep( bool point_self_spikes, std::vector< double > v_comp, const long lag {% for ode in receptor_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if receptor_info["Dependencies"]["receptors"]|length %} + {% endif %}{% for inline in receptor_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if receptor_info["Dependencies"]["channels"]|length %} + {% endif %}{% for inline in receptor_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if receptor_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in receptor_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if receptor_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in receptor_info["Dependencies"]["global"] %}, std::vector<{{ render_variable_type(state) }}> {{ printer_no_origin.print(state) }}{% endfor %}); + + // calibration +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} + void calibrate(); +{%- else %} + void pre_run_hook(); +{%- endif %} + void append_recordables(std::map< Name, double* >* recordables, const long compartment_idx); + void set_buffer_ptr( 
std::vector< RingBuffer >& rec_buffers ) + { + for(std::size_t i = 0; i < rec_idx.size(); i++){ + {{receptor_info["buffer_name"]}}_.push_back(&(rec_buffers[rec_idx[i]])); + } + }; + + // function declarations + {%- for function in receptor_info["Functions"] %} + #pragma omp declare simd + __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function, pass_by_reference = true) -}}; + + {% endfor %} + + // root_inline getter + void get_currents_per_compartment(std::vector< double >& compartment_to_current); + + std::vector< double > distribute_shared_vector(std::vector< double > shared_vector); + +}; + +{% endfor -%} +{% endwith -%} + +////////////////////////////////////////////////// receptors with synapses attached + + +{%- with %} +{%- for synapse_name, synapse_info in syns_info.items() %} +{%- for receptor_name, receptor_info in recs_info.items() %} + +class {{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}{ +private: + // global receptor index + std::vector< long > syn_idx = {}; + + // propagators, initialized via pre_run_hook() or calibrate() + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["propagators"].items()%} + std::vector< double > {{state_variable_name}}; + {%- endfor %} + {%- endfor %} + + // kernel state variables, initialized via pre_run_hook() or calibrate() + {%- for convolution, convolution_info in receptor_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + std::vector< double > {{state_variable_name}}; + {%- endfor %} + {%- endfor %} + + // user defined parameters, initialized via pre_run_hook() or calibrate() + {%- for param_name, param_declaration in receptor_info["Parameters"].items() %} + std::vector< double > {{param_name}}; + {%- endfor %} + + // states + {%- for 
pure_variable_name, variable_info in receptor_info["States"].items() %}
+  {%- set variable = variable_info["ASTVariable"] %}
+  {%- set rhs_expression = variable_info["rhs_expression"] %}
+  std::vector<{{ render_variable_type(variable) }}> {{ variable.name }} = {};
+  {%- endfor %}
+
+  std::vector< double > i_tot_{{receptor_name}} = {};
+
+  // user declared internals in order they were declared, initialized via pre_run_hook() or calibrate()
+  {%- for internal_name, internal_declaration in receptor_info["internals_used_declared"] %}
+  std::vector< double > {{internal_name}};
+  {%- endfor %}
+
+
+
+  // spike buffer
+  std::vector< RingBuffer* > {{receptor_info["buffer_name"]}}_;
+
+  //synapse related variables:
+  // user defined parameters, initialized via pre_run_hook() or calibrate()
+  {%- for param_name, param_declaration in synapse_info["Parameters"].items() %}
+  std::vector< double > {{param_name}};
+  {%- endfor %}
+
+  // states
+  {%- for pure_variable_name, variable_info in synapse_info["States"].items() %}
+  {%- set variable = variable_info["ASTVariable"] %}
+  {%- set rhs_expression = variable_info["rhs_expression"] %}
+  std::vector<{{ render_variable_type(variable) }}> {{ variable.name }} = {};
+  {%- endfor %}
+
+  // user declared internals in order they were declared, initialized via pre_run_hook() or calibrate()
+  {%- for internal_name, internal_declaration in synapse_info["Internals"] %}
+  std::vector< double > {{internal_name}};
+  {%- endfor %}
+
+  {%- with %}
+  {%- for in_function_declaration in synapse_info["InFunctionDeclarationsVars"] %}
+  {%- for variable in declarations.get_variables(in_function_declaration) %}
+  std::vector<{{declarations.print_variable_type(variable)}}> {{variable.get_symbol_name()}} = {};
+  {%- endfor %}
+  {%- endfor %}
+  {%- endwith %}
+
+  {%- for convolution, convolution_info in synapse_info["convolutions"].items() %}
+  {%- for state_variable_name, state_variable_info in
convolution_info["analytic_solution"]["propagators"].items()%} + std::vector< double > {{state_variable_name}}; + {%- endfor %} + {%- endfor %} + + {%- for convolution, convolution_info in synapse_info["convolutions"].items() %} + {%- for state_variable_name, state_variable_info in convolution_info["analytic_solution"]["kernel_states"].items()%} + std::vector< double > {{state_variable_name}}; + {%- endfor %} + {%- endfor %} + + {%- for inline_name, inline in synapse_info["Inlines"].items() %} + std::vector< double > {{inline_name}}; + {%- endfor %} + + +public: + // constructor, destructor + {{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}(){}; + ~{{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}}(){}; + + void new_receptor(std::size_t comp_ass, const long syn_index); + void new_receptor(std::size_t comp_ass, const long syn_index, const DictionaryDatum& receptor_params); + + //number of receptors + std::size_t neuron_{{ receptor_name }}_receptor_count = 0; + + std::vector< size_t > compartment_association = {}; + + // numerical integration step + {%- with %} + {%- set conc_dep = set(receptor_info["Dependencies"]["concentrations"]).union(synapse_info["Dependencies"]["concentrations"])%} + {%- set rec_dep = set(receptor_info["Dependencies"]["receptors"]).union(synapse_info["Dependencies"]["receptors"])%} + {%- set chan_dep = set(receptor_info["Dependencies"]["channels"]).union(synapse_info["Dependencies"]["channels"])%} + {%- set con_in_dep = set(receptor_info["Dependencies"]["continuous"]).union(synapse_info["Dependencies"]["continuous"])%} + std::pair< std::vector< double >, std::vector< double > > f_numstep( bool point_self_spikes, std::vector< double > v_comp, const long lag {% for ode in conc_dep %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if rec_dep|length %} + {% endif %}{% for inline in rec_dep %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if chan_dep|length %} + {% endif %}{% for inline in chan_dep %}, 
std::vector< double > {{inline.variable_name}}{% endfor %}{% if con_in_dep|length %} + {% endif %}{% for inline in con_in_dep %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if receptor_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in receptor_info["Dependencies"]["global"] %}, std::vector<{{ render_variable_type(state) }}> {{ printer_no_origin.print(state) }}{% endfor %}); + {%- endwith %} + void postsynaptic_synaptic_processing(); + + // calibration +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} + void calibrate(); +{%- else %} + void pre_run_hook(); +{%- endif %} + void append_recordables(std::map< Name, double* >* recordables, const long compartment_idx); + void set_buffer_ptr( std::vector< RingBuffer >& syn_buffers ) + { + for(std::size_t i = 0; i < syn_idx.size(); i++){ + {{receptor_info["buffer_name"]}}_.push_back(&(syn_buffers[syn_idx[i]])); + } + }; + + // function declarations + {%- for function in receptor_info["Functions"] %} + #pragma omp declare simd + __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function, pass_by_reference = true) -}}; + + {% endfor %} + + // root_inline getter + void get_currents_per_compartment(std::vector< double >& compartment_to_current); + + std::vector< double > distribute_shared_vector(std::vector< double > shared_vector); + + void get_history__( double t1, + double t2, + std::deque< histentry >::iterator* start, + std::deque< histentry >::iterator* finish ); + + +}; + +{% endfor -%} +{% endfor %} +{% endwith -%} + + +////////////////////////////////////////////////// continuous inputs + +{%- with %} +{%- for continuous_name, continuous_info in con_in_info.items() %} + +class {{continuous_name}}{{cm_unique_suffix}}{ +private: + // global continuous input index + std::vector< long > continuous_idx = {}; + + // user defined parameters, initialized via 
pre_run_hook() or calibrate()
+  {%- for param_name, param_declaration in continuous_info["Parameters"].items() %}
+  std::vector< double > {{param_name}};
+  {%- endfor %}
+
+  // states
+  {%- for pure_variable_name, variable_info in continuous_info["States"].items() %}
+  {%- set variable = variable_info["ASTVariable"] %}
+  {%- set rhs_expression = variable_info["rhs_expression"] %}
+  std::vector<{{ render_variable_type(variable) }}> {{ variable.name }} = {};
+  {%- endfor %}
+
+  std::vector< double > i_tot_{{continuous_name}} = {};
+
+  // user declared internals in order they were declared, initialized via pre_run_hook() or calibrate()
+  {%- for internal_name, internal_declaration in continuous_info["Internals"] %}
+  std::vector< double > {{internal_name}};
+  {%- endfor %}
+
+
+
+  // continuous buffer
+  {% for port_name, port_info in continuous_info["Continuous"].items() %}
+  std::vector< RingBuffer* > {{ port_name }}_;
+  {% endfor %}
+
+public:
+  // constructor, destructor
+  {{continuous_name}}{{cm_unique_suffix}}(){};
+  ~{{continuous_name}}{{cm_unique_suffix}}(){};
+
+  void new_continuous_input(std::size_t comp_ass, const long con_in_index);
+  void new_continuous_input(std::size_t comp_ass, const long con_in_index, const DictionaryDatum& con_in_params);
+
+  //number of continuous inputs
+  std::size_t neuron_{{ continuous_name }}_continuous_input_count = 0;
+
+  std::vector< size_t > compartment_association = {};
+
+  // numerical integration step
+  std::pair< std::vector< double >, std::vector< double > > f_numstep( bool point_self_spikes, std::vector< double > v_comp, const long lag {% for ode in continuous_info["Dependencies"]["concentrations"] %}, std::vector< double > {{ode.lhs.name}}{% endfor %}{% if continuous_info["Dependencies"]["receptors"]|length %}
+  {% endif %}{% for inline in continuous_info["Dependencies"]["receptors"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if continuous_info["Dependencies"]["channels"]|length %}
+  {% endif
%}{% for inline in continuous_info["Dependencies"]["channels"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if continuous_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in continuous_info["Dependencies"]["continuous"] %}, std::vector< double > {{inline.variable_name}}{% endfor %}{% if continuous_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in continuous_info["Dependencies"]["global"] %}, std::vector<{{ render_variable_type(state) }}> {{ printer_no_origin.print(state) }}{% endfor %}); + + // calibration +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} + void calibrate(); +{%- else %} + void pre_run_hook(); +{%- endif %} + void append_recordables(std::map< Name, double* >* recordables, const long compartment_idx); + void set_buffer_ptr( std::vector< RingBuffer >& continuous_buffers ) + { + for(std::size_t i = 0; i < continuous_idx.size(); i++){ + {% for port_name, port_info in continuous_info["Continuous"].items() %} + {{port_name}}_.push_back(&(continuous_buffers[continuous_idx[i]])); + {% endfor %} + } + }; + + // function declarations + {%- for function in continuous_info["Functions"] %} + #pragma omp declare simd + __attribute__((always_inline)) inline {{ function_declaration.FunctionDeclaration(function, pass_by_reference = true) -}}; + + {% endfor %} + + // root_inline getter + void get_currents_per_compartment(std::vector< double >& compartment_to_current); + + std::vector< double > distribute_shared_vector(std::vector< double > shared_vector); + +}; + +{% endfor -%} +{% endwith -%} + + +{%- include "cm_global_dynamics.h.jinja2" %} + + +///////////////////////////////////////////// currents + +{%- set channel_suffix = "_chan_" %} +{%- set concentration_suffix = "_conc_" %} +{%- set receptor_suffix = "_syn_" %} +{%- set continuous_suffix = "_con_in_" %} + +class NeuronCurrents{{cm_unique_suffix}} { 
+private: + //mechanisms + // ion channels +{% with %} + {%- for ion_channel_name, channel_info in chan_info.items() %} + {{ion_channel_name}}{{cm_unique_suffix}} {{ion_channel_name}}{{channel_suffix}}; + {% endfor -%} +{% endwith %} + // concentrations +{% with %} + {%- for concentration_name, concentration_info in conc_info.items() %} + {{concentration_name}}{{cm_unique_suffix}} {{concentration_name}}{{concentration_suffix}}; + {% endfor -%} +{% endwith %} + // receptors +{% with %} + {%- for receptor_name, receptor_info in recs_info.items() %} + {{receptor_name}}{{cm_unique_suffix}} {{receptor_name}}{{receptor_suffix}}; + {% endfor -%} +{% endwith %} + // receptors with synapses +{%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + {{receptor_name}}{{cm_unique_suffix}}_con_{{synapse_name}} {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}; + {% endfor -%} +{% endfor -%} + // continuous inputs +{% with %} + {%- for continuous_name, continuous_info in con_in_info.items() %} + {{continuous_name}}{{cm_unique_suffix}} {{continuous_name}}{{continuous_suffix}}; + {% endfor -%} +{% endwith %} + + //number of compartments + std::size_t compartment_number = 0; + + //interdependency shared reference vectors and consecutive area vectors + // ion channels +{% with %} + {%- for ion_channel_name, channel_info in chan_info.items() %} + std::vector < double > {{ion_channel_name}}{{channel_suffix}}_shared_current; + std::vector < std::pair< std::size_t, int > > {{ion_channel_name}}{{channel_suffix}}_con_area; + {% endfor -%} +{% endwith %} + // concentrations +{% with %} + {%- for concentration_name, concentration_info in conc_info.items() %} + std::vector < double > {{concentration_name}}{{concentration_suffix}}_shared_concentration; + std::vector < std::pair< std::size_t, int > > {{concentration_name}}{{concentration_suffix}}_con_area; + {% endfor -%} +{% endwith %} + // receptors +{% with %} + {%- 
for receptor_name, receptor_info in recs_info.items() %} + std::vector < double > {{receptor_name}}{{receptor_suffix}}_shared_current; + std::vector < std::pair< std::size_t, int > > {{receptor_name}}{{receptor_suffix}}_con_area; + {% endfor -%} +{% endwith %} + // receptors with synapses +{% with %} + {%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + std::vector < double > {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_shared_current; + std::vector < std::pair< std::size_t, int > > {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area; + {% endfor -%} + {% endfor -%} +{% endwith %} +// continuous inputs +{% with %} + {%- for continuous_name, continuous_info in con_in_info.items() %} + std::vector < double > {{continuous_name}}{{continuous_suffix}}_shared_current; + std::vector < std::pair< std::size_t, int > > {{continuous_name}}{{continuous_suffix}}_con_area; + {% endfor -%} +{% endwith %} + + //compartment gi states + std::vector < std::pair < double, double > > comps_gi; + + // global dynamics + Global{{cm_unique_suffix}} global_dynamics; + +public: + NeuronCurrents{{cm_unique_suffix}}(){}; + ~NeuronCurrents{{cm_unique_suffix}}(){}; + +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} + void calibrate() { +{%- else %} + void pre_run_hook() { +{%- endif %} + // initialization of ion channels + {%- with %} +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} + {%- for ion_channel_name, channel_info in chan_info.items() %} + {{ion_channel_name}}{{channel_suffix}}.calibrate(); + {% endfor -%} + {%- for concentration_name, concentration_info in conc_info.items() %} + {{concentration_name}}{{concentration_suffix}}.calibrate(); + {% endfor -%} + {%- for receptor_name, receptor_info in 
recs_info.items() %} + {{receptor_name}}{{receptor_suffix}}.calibrate(); + {% endfor -%} + {%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.calibrate(); + {% endfor -%} + {% endfor -%} + {%- for continuous_name, continuous_info in con_in_info.items() %} + {{continuous_name}}{{continuous_suffix}}.calibrate(); + {% endfor -%} +{%- else %} + {%- for ion_channel_name, channel_info in chan_info.items() %} + {{ion_channel_name}}{{channel_suffix}}.pre_run_hook(); + {% endfor -%} + {%- for concentration_name, concentration_info in conc_info.items() %} + {{concentration_name}}{{concentration_suffix}}.pre_run_hook(); + {% endfor -%} + {%- for receptor_name, receptor_info in recs_info.items() %} + {{receptor_name}}{{receptor_suffix}}.pre_run_hook(); + {% endfor -%} + {%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.pre_run_hook(); + {% endfor -%} + {% endfor -%} + {%- for continuous_name, continuous_info in con_in_info.items() %} + {{continuous_name}}{{continuous_suffix}}.pre_run_hook(); + {% endfor -%} +{%- endif %} + int con_end_index; + {%- for ion_channel_name, channel_info in chan_info.items() %} + if({{ion_channel_name}}{{channel_suffix}}.neuron_{{ ion_channel_name }}_channel_count){ + con_end_index = int({{ion_channel_name}}{{channel_suffix}}.compartment_association[0]); + {{ion_channel_name}}{{channel_suffix}}_con_area.push_back(std::pair< std::size_t, int >(0, con_end_index)); + } + for(std::size_t chan_id = 1; chan_id < {{ion_channel_name}}{{channel_suffix}}.neuron_{{ ion_channel_name }}_channel_count; chan_id++){ + if(!({{ion_channel_name}}{{channel_suffix}}.compartment_association[chan_id] == size_t(int(chan_id) + con_end_index))){ + con_end_index = 
int({{ion_channel_name}}{{channel_suffix}}.compartment_association[chan_id]) - int(chan_id); + {{ion_channel_name}}{{channel_suffix}}_con_area.push_back(std::pair< std::size_t, int >(chan_id, con_end_index)); + } + } + {% endfor -%} + {%- for concentration_name, concentration_info in conc_info.items() %} + if({{concentration_name}}{{concentration_suffix}}.neuron_{{ concentration_name }}_concentration_count){ + con_end_index = int({{concentration_name}}{{concentration_suffix}}.compartment_association[0]); + {{concentration_name}}{{concentration_suffix}}_con_area.push_back(std::pair< std::size_t, int >(0, con_end_index)); + } + for(std::size_t conc_id = 0; conc_id < {{concentration_name}}{{concentration_suffix}}.neuron_{{ concentration_name }}_concentration_count; conc_id++){ + if(!({{concentration_name}}{{concentration_suffix}}.compartment_association[conc_id] == size_t(int(conc_id) + con_end_index))){ + con_end_index = int({{concentration_name}}{{concentration_suffix}}.compartment_association[conc_id]) - int(conc_id); + {{concentration_name}}{{concentration_suffix}}_con_area.push_back(std::pair< std::size_t, int >(conc_id, con_end_index)); + } + } + {% endfor -%} + {%- for receptor_name, receptor_info in recs_info.items() %} + if({{receptor_name}}{{receptor_suffix}}.neuron_{{ receptor_name }}_receptor_count){ + con_end_index = int({{receptor_name}}{{receptor_suffix}}.compartment_association[0]); + {{receptor_name}}{{receptor_suffix}}_con_area.push_back(std::pair< std::size_t, int >(0, con_end_index)); + } + for(std::size_t syn_id = 0; syn_id < {{receptor_name}}{{receptor_suffix}}.neuron_{{ receptor_name }}_receptor_count; syn_id++){ + if(!({{receptor_name}}{{receptor_suffix}}.compartment_association[syn_id] == size_t(int(syn_id) + con_end_index))){ + con_end_index = int({{receptor_name}}{{receptor_suffix}}.compartment_association[syn_id]) - int(syn_id); + {{receptor_name}}{{receptor_suffix}}_con_area.push_back(std::pair< std::size_t, int >(syn_id, con_end_index)); 
+ } + } + {% endfor -%} + {%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + if({{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.neuron_{{ receptor_name }}_receptor_count){ + con_end_index = int({{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.compartment_association[0]); + {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area.push_back(std::pair< std::size_t, int >(0, con_end_index)); + } + for(std::size_t syn_id = 0; syn_id < {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.neuron_{{ receptor_name }}_receptor_count; syn_id++){ + if(!({{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.compartment_association[syn_id] == size_t(int(syn_id) + con_end_index))){ + con_end_index = int({{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.compartment_association[syn_id]) - int(syn_id); + {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area.push_back(std::pair< std::size_t, int >(syn_id, con_end_index)); + } + } + {% endfor -%} + {% endfor -%} + {%- for continuous_name, continuous_info in con_in_info.items() %} + if({{continuous_name}}{{continuous_suffix}}.neuron_{{ continuous_name }}_continuous_input_count){ + con_end_index = int({{continuous_name}}{{continuous_suffix}}.compartment_association[0]); + {{continuous_name}}{{continuous_suffix}}_con_area.push_back(std::pair< std::size_t, int >(0, con_end_index)); + } + for(std::size_t cont_id = 0; cont_id < {{continuous_name}}{{continuous_suffix}}.neuron_{{ continuous_name }}_continuous_input_count; cont_id++){ + if(!({{continuous_name}}{{continuous_suffix}}.compartment_association[cont_id] == size_t(int(cont_id) + con_end_index))){ + con_end_index = int({{continuous_name}}{{continuous_suffix}}.compartment_association[cont_id]) - int(cont_id); + {{continuous_name}}{{continuous_suffix}}_con_area.push_back(std::pair< std::size_t, int >(cont_id, con_end_index)); + } + } + {% endfor -%} + {% 
endwith -%} + }; + + void add_mechanism( const std::string& type, const std::size_t compartment_id, const long multi_mech_index = 0) + { + {%- with %} + bool mech_found = false; + {%- for ion_channel_name, channel_info in chan_info.items() %} + if ( type == "{{ion_channel_name}}" ) + { + {{ion_channel_name}}{{channel_suffix}}.new_channel(compartment_id); + mech_found = true; + } + {% endfor -%} + + {%- for concentration_name, concentration_info in conc_info.items() %} + if ( type == "{{concentration_name}}" ) + { + {{concentration_name}}{{concentration_suffix}}.new_concentration(compartment_id); + mech_found = true; + } + {% endfor -%} + + {%- for receptor_name, receptor_info in recs_info.items() %} + if ( type == "{{receptor_name}}" ) + { + {{receptor_name}}{{receptor_suffix}}.new_receptor(compartment_id, multi_mech_index); + mech_found = true; + } + {% endfor -%} + + {%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + if ( type == "{{receptor_name}}_{{synapse_name}}" ) + { + {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.new_receptor(compartment_id, multi_mech_index); + mech_found = true; + } + {% endfor -%} + {% endfor -%} + + {%- for continuous_name, continuous_info in con_in_info.items() %} + if ( type == "{{continuous_name}}" ) + { + {{continuous_name}}{{continuous_suffix}}.new_continuous_input(compartment_id, multi_mech_index); + mech_found = true; + } + {% endfor -%} + + {% endwith -%} + if(!mech_found) + { + assert( false ); + } + }; + + void add_mechanism( const std::string& type, const std::size_t compartment_id, const DictionaryDatum& mechanism_params, const long multi_mech_index = 0) + { + {%- with %} + bool mech_found = false; + {%- for ion_channel_name, channel_info in chan_info.items() %} + if ( type == "{{ion_channel_name}}" ) + { + {{ion_channel_name}}{{channel_suffix}}.new_channel(compartment_id, mechanism_params); + mech_found = true; + } + {% endfor -%} + + {%- 
for concentration_name, concentration_info in conc_info.items() %} + if ( type == "{{concentration_name}}" ) + { + {{concentration_name}}{{concentration_suffix}}.new_concentration(compartment_id, mechanism_params); + mech_found = true; + } + {% endfor -%} + + {%- for receptor_name, receptor_info in recs_info.items() %} + if ( type == "{{receptor_name}}" ) + { + {{receptor_name}}{{receptor_suffix}}.new_receptor(compartment_id, multi_mech_index, mechanism_params); + mech_found = true; + } + {% endfor -%} + + {%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + if ( type == "{{receptor_name}}_{{synapse_name}}" ) + { + {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.new_receptor(compartment_id, multi_mech_index, mechanism_params); + mech_found = true; + } + {% endfor -%} + {% endfor -%} + + {%- for continuous_name, continuous_info in con_in_info.items() %} + if ( type == "{{continuous_name}}" ) + { + {{continuous_name}}{{continuous_suffix}}.new_continuous_input(compartment_id, multi_mech_index, mechanism_params); + mech_found = true; + } + {% endfor -%} + {% endwith -%} + if(!mech_found) + { + assert( false ); + } + }; + + void add_compartment(){ + global_dynamics.new_compartment(); + + {%- for ion_channel_name, channel_info in chan_info.items() %} + this->add_mechanism("{{ ion_channel_name }}", compartment_number); + {% endfor -%} + + {%- for concentration_name, concentration_info in conc_info.items() %} + this->add_mechanism("{{ concentration_name }}", compartment_number); + {% endfor -%} + + compartment_number++; + + {%- for ion_channel_name, channel_info in chan_info.items() %} + this->{{ion_channel_name}}{{channel_suffix}}_shared_current.push_back(0.0); + {% endfor -%} + + {%- for concentration_name, concentration_info in conc_info.items() %} + this->{{concentration_name}}{{concentration_suffix}}_shared_concentration.push_back(0.0); + {% endfor -%} + + {%- for receptor_name, 
receptor_info in recs_info.items() %} + this->{{receptor_name}}{{receptor_suffix}}_shared_current.push_back(0.0); + {% endfor -%} + + {%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + this->{{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_shared_current.push_back(0.0); + {% endfor -%} + {% endfor -%} + + {%- for continuous_name, continuous_info in con_in_info.items() %} + this->{{continuous_name}}{{continuous_suffix}}_shared_current.push_back(0.0); + {% endfor -%} + + }; + + void add_compartment(const DictionaryDatum& compartment_params){ + global_dynamics.new_compartment(compartment_params); + + {%- for ion_channel_name, channel_info in chan_info.items() %} + this->add_mechanism("{{ ion_channel_name }}", compartment_number, compartment_params); + {% endfor -%} + + {%- for concentration_name, concentration_info in conc_info.items() %} + this->add_mechanism("{{ concentration_name }}", compartment_number, compartment_params); + {% endfor -%} + + compartment_number++; + + {%- for ion_channel_name, channel_info in chan_info.items() %} + this->{{ion_channel_name}}{{channel_suffix}}_shared_current.push_back(0.0); + {% endfor -%} + + {%- for concentration_name, concentration_info in conc_info.items() %} + this->{{concentration_name}}{{concentration_suffix}}_shared_concentration.push_back(0.0); + {% endfor -%} + + {%- for receptor_name, receptor_info in recs_info.items() %} + this->{{receptor_name}}{{receptor_suffix}}_shared_current.push_back(0.0); + {% endfor -%} + + {%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + this->{{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_shared_current.push_back(0.0); + {% endfor -%} + {% endfor -%} + + {%- for continuous_name, continuous_info in con_in_info.items() %} + this->{{continuous_name}}{{continuous_suffix}}_shared_current.push_back(0.0); + {% endfor -%} + }; + + void 
add_receptor_info( ArrayDatum& ad, long compartment_index ) + { + {%- with %} + {%- for receptor_name, receptor_info in recs_info.items() %} + for( std::size_t syn_it = 0; syn_it != {{receptor_name}}{{receptor_suffix}}.neuron_{{receptor_name}}_receptor_count; syn_it++) + { + DictionaryDatum dd = DictionaryDatum( new Dictionary ); + def< long >( dd, names::receptor_idx, syn_it ); + def< long >( dd, names::comp_idx, compartment_index ); + def< std::string >( dd, names::receptor_type, "{{receptor_name}}" ); + ad.push_back( dd ); + } + {% endfor -%} + + {%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + for( std::size_t syn_it = 0; syn_it != {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.neuron_{{receptor_name}}_receptor_count; syn_it++) + { + DictionaryDatum dd = DictionaryDatum( new Dictionary ); + def< long >( dd, names::receptor_idx, syn_it ); + def< long >( dd, names::comp_idx, compartment_index ); + def< std::string >( dd, names::receptor_type, "{{receptor_name}}_{{synapse_name}}" ); + ad.push_back( dd ); + } + {% endfor -%} + {% endfor -%} + + {%- for continuous_name, continuous_info in con_in_info.items() %} + for( std::size_t con_it = 0; con_it != {{continuous_name}}{{continuous_suffix}}.neuron_{{continuous_name}}_continuous_input_count; con_it++) + { + DictionaryDatum dd = DictionaryDatum( new Dictionary ); + def< long >( dd, names::receptor_idx, con_it ); + def< long >( dd, names::comp_idx, compartment_index ); + def< std::string >( dd, names::receptor_type, "{{continuous_name}}" ); + ad.push_back( dd ); + } + {% endfor -%} + + {% endwith -%} + }; + + void set_buffers( std::vector< RingBuffer >& buffers) + { + // spike and continuous buffers for receptors and continuous inputs + + {%- with %} + {%- for receptor_name, receptor_info in recs_info.items() %} + {{receptor_name}}{{ receptor_suffix }}.set_buffer_ptr( buffers ); + {% endfor -%} + {%- for receptor_name, receptor_info 
in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + {{receptor_name}}{{ receptor_suffix }}_con_{{synapse_name}}.set_buffer_ptr( buffers ); + {% endfor -%} + {% endfor -%} + {%- for continuous_name, continuous_info in con_in_info.items() %} + {{continuous_name}}{{ continuous_suffix }}.set_buffer_ptr( buffers ); + {% endfor -%} + {% endwith %} + + }; + + std::map< Name, double* > get_recordables( const long compartment_idx ) + { + std::map< Name, double* > recordables; + + // append ion channel state variables to recordables + {%- with %} + {%- for ion_channel_name, channel_info in chan_info.items() %} + {{ion_channel_name}}{{channel_suffix}}.append_recordables( &recordables, compartment_idx ); + {% endfor %} + {% endwith %} + + // append concentration state variables to recordables + {%- with %} + {%- for concentration_name, concentration_info in conc_info.items() %} + {{concentration_name}}{{concentration_suffix}}.append_recordables( &recordables, compartment_idx ); + {% endfor %} + {% endwith %} + + // append receptor state variables to recordables + {%- with %} + {%- for receptor_name, receptor_info in recs_info.items() %} + {{receptor_name}}{{receptor_suffix}}.append_recordables( &recordables, compartment_idx ); + {% endfor %} + {% endwith %} + + // append receptor with synapse state variables to recordables + {%- with %} + {%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.append_recordables( &recordables, compartment_idx ); + {% endfor %} + {% endfor %} + {% endwith %} + + // append continuous input state variables to recordables + {%- with %} + {%- for continuous_name, continuous_info in con_in_info.items() %} + {{continuous_name}}{{continuous_suffix}}.append_recordables( &recordables, compartment_idx ); + {% endfor %} + {% endwith %} + + global_dynamics.append_recordables( &recordables, 
compartment_idx ); + + return recordables; + }; + + std::vector< std::pair< double, double > > f_numstep( std::vector< double > v_comp_vec, const long lag ) + { + std::vector< std::pair< double, double > > comp_to_gi(compartment_number, std::make_pair(0., 0.)); +{%- for receptor_name, receptor_info in recs_info.items() %} + {{receptor_name}}{{receptor_suffix}}.get_currents_per_compartment({{receptor_name}}{{receptor_suffix}}_shared_current); +{% endfor %} +{%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.get_currents_per_compartment({{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_shared_current); + for(size_t i = 0; i < {{receptor_name}}{{receptor_suffix}}_shared_current.size(); i++){ + {{receptor_name}}{{receptor_suffix}}_shared_current[i] += {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_shared_current[i]; + } + {% endfor %} +{% endfor %} +{%- for continuous_name, continuous_info in con_in_info.items() %} + {{continuous_name}}{{continuous_suffix}}.get_currents_per_compartment({{continuous_name}}{{continuous_suffix}}_shared_current); +{% endfor %} +{%- for concentration_name, concentration_info in conc_info.items() %} + {{ concentration_name }}{{concentration_suffix}}.get_concentrations_per_compartment({{concentration_name}}{{concentration_suffix}}_shared_concentration); +{% endfor -%} +{%- for ion_channel_name, channel_info in chan_info.items() %} + {{ion_channel_name}}{{channel_suffix}}.get_currents_per_compartment({{ion_channel_name}}{{channel_suffix}}_shared_current); +{% endfor -%} + + + {%- with %} + {%- for concentration_name, concentration_info in conc_info.items() %} + // computation of {{ concentration_name }} concentration + {{ concentration_name }}{{concentration_suffix}}.f_numstep( global_dynamics.get_self_spikes(), {{ concentration_name 
}}{{concentration_suffix}}.distribute_shared_vector(v_comp_vec){% for ode in concentration_info["Dependencies"]["concentrations"] %}, {{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if concentration_info["Dependencies"]["receptors"]|length %} + {% endif %}{% for inline in concentration_info["Dependencies"]["receptors"] %}, {{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector({{inline.variable_name}}{{receptor_suffix}}_shared_current){% endfor %}{% if concentration_info["Dependencies"]["channels"]|length %} + {% endif %}{% for inline in concentration_info["Dependencies"]["channels"] %}, {{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector({{inline.variable_name}}{{channel_suffix}}_shared_current){% endfor %}{% if concentration_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in concentration_info["Dependencies"]["continuous"] %}, {{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector({{inline.variable_name}}{{continuous_suffix}}_shared_current){% endfor %}{% if concentration_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in concentration_info["Dependencies"]["global"] %}, {{ concentration_name }}{{concentration_suffix}}.distribute_shared_vector(global_dynamics.get_{{ printer_no_origin.print(state) }}()){% endfor %}); + + {% endfor -%} + {% endwith -%} + + std::pair< std::vector< double >, std::vector< double > > gi_mech; + std::size_t con_area_count; + + {%- with %} + {%- for ion_channel_name, channel_info in chan_info.items() %} + // contribution of {{ion_channel_name}} channel + gi_mech = {{ion_channel_name}}{{channel_suffix}}.f_numstep( global_dynamics.get_self_spikes(), {{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector(v_comp_vec){% for ode in channel_info["Dependencies"]["concentrations"] %}, 
{{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if channel_info["Dependencies"]["receptors"]|length %} + {% endif %}{% for inline in channel_info["Dependencies"]["receptors"] %}, {{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector({{inline.variable_name}}{{receptor_suffix}}_shared_current){% endfor %}{% if channel_info["Dependencies"]["channels"]|length %} + {% endif %}{% for inline in channel_info["Dependencies"]["channels"] %}, {{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector({{inline.variable_name}}{{channel_suffix}}_shared_current){% endfor %}{% if channel_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in channel_info["Dependencies"]["continuous"] %}, {{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector({{inline.variable_name}}{{continuous_suffix}}_shared_current){% endfor %}{% if channel_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in channel_info["Dependencies"]["global"] %}, {{ion_channel_name}}{{channel_suffix}}.distribute_shared_vector(global_dynamics.get_{{ printer_no_origin.print(state) }}()){% endfor %}); + + con_area_count = {{ion_channel_name}}{{channel_suffix}}_con_area.size(); + if(con_area_count > 0){ + for(std::size_t con_area_index = 0; con_area_index < con_area_count-1; con_area_index++){ + std::size_t con_area = {{ion_channel_name}}{{channel_suffix}}_con_area[con_area_index].first; + std::size_t next_con_area = {{ion_channel_name}}{{channel_suffix}}_con_area[con_area_index+1].first; + int offset = {{ion_channel_name}}{{channel_suffix}}_con_area[con_area_index].second; + + #pragma omp simd + for(std::size_t chan_id = con_area; chan_id < next_con_area; chan_id++){ + comp_to_gi[chan_id+offset].first += gi_mech.first[chan_id]; + comp_to_gi[chan_id+offset].second += gi_mech.second[chan_id]; + } + } + + std::size_t con_area = 
{{ion_channel_name}}{{channel_suffix}}_con_area[con_area_count-1].first; + int offset = {{ion_channel_name}}{{channel_suffix}}_con_area[con_area_count-1].second; + + #pragma omp simd + for(std::size_t chan_id = con_area; chan_id < {{ion_channel_name}}{{channel_suffix}}.neuron_{{ ion_channel_name }}_channel_count; chan_id++){ + comp_to_gi[chan_id+offset].first += gi_mech.first[chan_id]; + comp_to_gi[chan_id+offset].second += gi_mech.second[chan_id]; + } + } + {% endfor -%} + {% endwith -%} + + {%- with %} + {%- for receptor_name, receptor_info in recs_info.items() %} + // contribution of {{receptor_name}} receptors + gi_mech = {{receptor_name}}{{receptor_suffix}}.f_numstep( global_dynamics.get_self_spikes(), {{receptor_name}}{{receptor_suffix}}.distribute_shared_vector(v_comp_vec), lag {% for ode in receptor_info["Dependencies"]["concentrations"] %}, {{receptor_name}}{{receptor_suffix}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if receptor_info["Dependencies"]["receptors"]|length %} + {% endif %}{% for inline in receptor_info["Dependencies"]["receptors"] %}, {{receptor_name}}{{receptor_suffix}}.distribute_shared_vector({{inline.variable_name}}{{receptor_suffix}}_shared_current){% endfor %}{% if receptor_info["Dependencies"]["channels"]|length %} + {% endif %}{% for inline in receptor_info["Dependencies"]["channels"] %}, {{receptor_name}}{{receptor_suffix}}.distribute_shared_vector({{inline.variable_name}}{{channel_suffix}}_shared_current){% endfor %}{% if receptor_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in receptor_info["Dependencies"]["continuous"] %}, {{receptor_name}}{{receptor_suffix}}.distribute_shared_vector({{inline.variable_name}}{{continuous_suffix}}_shared_current){% endfor %}{% if receptor_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in receptor_info["Dependencies"]["global"] %}, 
{{receptor_name}}{{receptor_suffix}}.distribute_shared_vector(global_dynamics.get_{{ printer_no_origin.print(state) }}()){% endfor %}); + + con_area_count = {{receptor_name}}{{receptor_suffix}}_con_area.size(); + if(con_area_count > 0){ + for(std::size_t con_area_index = 0; con_area_index < con_area_count-1; con_area_index++){ + std::size_t con_area = {{receptor_name}}{{receptor_suffix}}_con_area[con_area_index].first; + std::size_t next_con_area = {{receptor_name}}{{receptor_suffix}}_con_area[con_area_index+1].first; + int offset = {{receptor_name}}{{receptor_suffix}}_con_area[con_area_index].second; + + #pragma omp simd + for(std::size_t syn_id = con_area; syn_id < next_con_area; syn_id++){ + comp_to_gi[syn_id+offset].first += gi_mech.first[syn_id]; + comp_to_gi[syn_id+offset].second += gi_mech.second[syn_id]; + } + } + + std::size_t con_area = {{receptor_name}}{{receptor_suffix}}_con_area[con_area_count-1].first; + int offset = {{receptor_name}}{{receptor_suffix}}_con_area[con_area_count-1].second; + + #pragma omp simd + for(std::size_t syn_id = con_area; syn_id < {{receptor_name}}{{receptor_suffix}}.neuron_{{ receptor_name }}_receptor_count; syn_id++){ + comp_to_gi[syn_id+offset].first += gi_mech.first[syn_id]; + comp_to_gi[syn_id+offset].second += gi_mech.second[syn_id]; + } + } + {% endfor -%} + {% endwith -%} + +{%- with %} + {%- for receptor_name, receptor_info in recs_info.items() %} + {%- for synapse_name, synapse_info in syns_info.items() %} + // contribution of {{receptor_name}}_{{synapse_name}} receptors + {%- with %} + {%- set conc_dep = set(receptor_info["Dependencies"]["concentrations"]).union(synapse_info["Dependencies"]["concentrations"])%} + {%- set rec_dep = set(receptor_info["Dependencies"]["receptors"]).union(synapse_info["Dependencies"]["receptors"])%} + {%- set chan_dep = set(receptor_info["Dependencies"]["channels"]).union(synapse_info["Dependencies"]["channels"])%} + {%- set con_in_dep = 
set(receptor_info["Dependencies"]["continuous"]).union(synapse_info["Dependencies"]["continuous"])%} + gi_mech = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.f_numstep( global_dynamics.get_self_spikes(), {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.distribute_shared_vector(v_comp_vec), lag {% for ode in conc_dep %}, {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if rec_dep|length %} + {% endif %}{% for inline in rec_dep %}, {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.distribute_shared_vector({{inline.variable_name}}{{receptor_suffix}}_shared_current){% endfor %}{% if chan_dep|length %} + {% endif %}{% for inline in chan_dep %}, {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.distribute_shared_vector({{inline.variable_name}}{{channel_suffix}}_shared_current){% endfor %}{% if con_in_dep|length %} + {% endif %}{% for inline in con_in_dep %}, {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.distribute_shared_vector({{inline.variable_name}}{{continuous_suffix}}_shared_current){% endfor %}{% if receptor_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in receptor_info["Dependencies"]["global"] %}, {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.distribute_shared_vector(global_dynamics.get_{{ printer_no_origin.print(state) }}()){% endfor %}); + {%- endwith %} + con_area_count = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area.size(); + if(con_area_count > 0){ + for(std::size_t con_area_index = 0; con_area_index < con_area_count-1; con_area_index++){ + std::size_t con_area = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area[con_area_index].first; + std::size_t next_con_area = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area[con_area_index+1].first; + int offset = 
{{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area[con_area_index].second; + + #pragma omp simd + for(std::size_t syn_id = con_area; syn_id < next_con_area; syn_id++){ + comp_to_gi[syn_id+offset].first += gi_mech.first[syn_id]; + comp_to_gi[syn_id+offset].second += gi_mech.second[syn_id]; + } + } + + std::size_t con_area = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area[con_area_count-1].first; + int offset = {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}_con_area[con_area_count-1].second; + + #pragma omp simd + for(std::size_t syn_id = con_area; syn_id < {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.neuron_{{ receptor_name }}_receptor_count; syn_id++){ + comp_to_gi[syn_id+offset].first += gi_mech.first[syn_id]; + comp_to_gi[syn_id+offset].second += gi_mech.second[syn_id]; + } + } + {% endfor -%} + {% endfor -%} + {% endwith -%} + + {%- with %} + {%- for continuous_name, continuous_info in con_in_info.items() %} + // contribution of {{continuous_name}} continuous inputs + gi_mech = {{continuous_name}}{{continuous_suffix}}.f_numstep( global_dynamics.get_self_spikes(), {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector(v_comp_vec), lag {% for ode in continuous_info["Dependencies"]["concentrations"] %}, {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector({{ode.lhs.name}}{{concentration_suffix}}_shared_concentration){% endfor %}{% if continuous_info["Dependencies"]["receptors"]|length %} + {% endif %}{% for inline in continuous_info["Dependencies"]["receptors"] %}, {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector({{inline.variable_name}}{{receptor_suffix}}_shared_current){% endfor %}{% if continuous_info["Dependencies"]["channels"]|length %} + {% endif %}{% for inline in continuous_info["Dependencies"]["channels"] %}, {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector({{inline.variable_name}}{{channel_suffix}}_shared_current){% endfor %}{% if 
continuous_info["Dependencies"]["continuous"]|length %} + {% endif %}{% for inline in continuous_info["Dependencies"]["continuous"] %}, {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector({{inline.variable_name}}{{continuous_suffix}}_shared_current){% endfor %}{% if continuous_info["Dependencies"]["global"]|length %} + {% endif %}{% for state in continuous_info["Dependencies"]["global"] %}, {{continuous_name}}{{continuous_suffix}}.distribute_shared_vector(global_dynamics.get_{{ printer_no_origin.print(state) }}()){% endfor %}); + + con_area_count = {{continuous_name}}{{continuous_suffix}}_con_area.size(); + if(con_area_count > 0){ + for(std::size_t con_area_index = 0; con_area_index < con_area_count-1; con_area_index++){ + std::size_t con_area = {{continuous_name}}{{continuous_suffix}}_con_area[con_area_index].first; + std::size_t next_con_area = {{continuous_name}}{{continuous_suffix}}_con_area[con_area_index+1].first; + int offset = {{continuous_name}}{{continuous_suffix}}_con_area[con_area_index].second; + + #pragma omp simd + for(std::size_t cont_id = con_area; cont_id < next_con_area; cont_id++){ + comp_to_gi[cont_id+offset].first += gi_mech.first[cont_id]; + comp_to_gi[cont_id+offset].second += gi_mech.second[cont_id]; + } + } + + std::size_t con_area = {{continuous_name}}{{continuous_suffix}}_con_area[con_area_count-1].first; + int offset = {{continuous_name}}{{continuous_suffix}}_con_area[con_area_count-1].second; + + #pragma omp simd + for(std::size_t cont_id = con_area; cont_id < {{continuous_name}}{{continuous_suffix}}.neuron_{{ continuous_name }}_continuous_input_count; cont_id++){ + comp_to_gi[cont_id+offset].first += gi_mech.first[cont_id]; + comp_to_gi[cont_id+offset].second += gi_mech.second[cont_id]; + } + } + {% endfor -%} + {% endwith -%} + + global_dynamics.f_numstep(v_comp_vec); + + return comp_to_gi; + }; + + void postsynaptic_synaptic_processing(){ + {%- for receptor_name, receptor_info in recs_info.items() %} + {%- for 
synapse_name, synapse_info in syns_info.items() %} + {{receptor_name}}{{receptor_suffix}}_con_{{synapse_name}}.postsynaptic_synaptic_processing(); + {% endfor -%} + {% endfor -%} + global_dynamics.f_self_spike(); + }; +}; + +} // namespace + +#endif /* #ifndef receptorS_NEAT_H_{{cm_unique_suffix | upper }} */ diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_tree_@NEURON_NAME@.cpp.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_tree_@NEURON_NAME@.cpp.jinja2 index 38bf6d446..19d97ebc8 100644 --- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_tree_@NEURON_NAME@.cpp.jinja2 +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_tree_@NEURON_NAME@.cpp.jinja2 @@ -21,52 +21,69 @@ */ #include "{{neuronSpecificFileNamesCmSyns["tree"]}}.h" - -nest::Compartment{{cm_unique_suffix}}::Compartment{{cm_unique_suffix}}( const long compartment_index, const long parent_index ) +nest::Compartment{{cm_unique_suffix}}::Compartment{{cm_unique_suffix}}( const long compartment_index, const long parent_index) : xx_( 0.0 ) , yy_( 0.0 ) , comp_index( compartment_index ) , p_index( parent_index ) , parent( nullptr ) - , v_comp( 0.0 ) + , v_comp( new double(0.0) ) , ca( 1.0 ) , gc( 0.01 ) , gl( 0.1 ) , el( -70. 
) - , gg0( 0.0 ) - , ca__div__dt( 0.0 ) - , gl__div__2( 0.0 ) - , gc__div__2( 0.0 ) - , gl__times__el( 0.0 ) - , ff( 0.0 ) - , gg( 0.0 ) + , gg0( nullptr ) + , ca__div__dt( nullptr ) + , gl__times__el( nullptr ) + , ff( nullptr ) + , gg( nullptr ) , hh( 0.0 ) , n_passed( 0 ) { - v_comp = el; + *v_comp = el; +} - compartment_currents = CompartmentCurrents{{cm_unique_suffix}}(); +nest::Compartment{{cm_unique_suffix}}::Compartment{{cm_unique_suffix}}( const long compartment_index, const long parent_index, + double* v_comp_ref, double* ca__div__dt_ref, double* gl__times__el_ref, double* gg0_ref, double* gg_ref, double* ff_ref ) + : xx_( 0.0 ) + , yy_( 0.0 ) + , comp_index( compartment_index ) + , p_index( parent_index ) + , parent( nullptr ) + , v_comp( v_comp_ref ) + , ca( 1.0 ) + , gc( 0.01 ) + , gl( 0.1 ) + , el( -70. ) + , gg0( gg0_ref ) + , ca__div__dt( ca__div__dt_ref ) + , gl__times__el( gl__times__el_ref ) + , ff( ff_ref ) + , gg( gg_ref ) + , hh( 0.0 ) + , n_passed( 0 ) +{ + *v_comp = el; } nest::Compartment{{cm_unique_suffix}}::Compartment{{cm_unique_suffix}}( const long compartment_index, const long parent_index, - const DictionaryDatum& compartment_params ) + const DictionaryDatum& compartment_params, + double* v_comp_ref, double* ca__div__dt_ref, double* gl__times__el_ref, double* gg0_ref, double* gg_ref, double* ff_ref) : xx_( 0.0 ) , yy_( 0.0 ) , comp_index( compartment_index ) , p_index( parent_index ) , parent( nullptr ) - , v_comp( 0.0 ) + , v_comp( v_comp_ref ) , ca( 1.0 ) , gc( 0.01 ) , gl( 0.1 ) , el( -70. 
) - , gg0( 0.0 ) - , ca__div__dt( 0.0 ) - , gl__div__2( 0.0 ) - , gc__div__2( 0.0 ) - , gl__times__el( 0.0 ) - , ff( 0.0 ) - , gg( 0.0 ) + , gg0( gg0_ref ) + , ca__div__dt( ca__div__dt_ref ) + , gl__times__el( gl__times__el_ref ) + , ff( ff_ref ) + , gg( gg_ref ) , hh( 0.0 ) , n_passed( 0 ) { @@ -75,10 +92,10 @@ nest::Compartment{{cm_unique_suffix}}::Compartment{{cm_unique_suffix}}( const lo updateValue< double >( compartment_params, names::g_C, gc ); updateValue< double >( compartment_params, names::g_L, gl ); updateValue< double >( compartment_params, names::e_L, el ); + double v_comp_update = el; + if( compartment_params->known( "v_comp" ) ) updateValue< double >( compartment_params, "v_comp", v_comp_update); - v_comp = el; - - compartment_currents = CompartmentCurrents{{cm_unique_suffix}}( compartment_params ); + *v_comp = v_comp_update; } void @@ -88,73 +105,39 @@ nest::Compartment{{cm_unique_suffix}}::calibrate() nest::Compartment{{cm_unique_suffix}}::pre_run_hook() {%- endif %} { -{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} - compartment_currents.calibrate(); -{%- else %} - compartment_currents.pre_run_hook(); -{%- endif %} const double dt = Time::get_resolution().get_ms(); - ca__div__dt = ca / dt; - gl__div__2 = gl / 2.; - gg0 = ca__div__dt + gl__div__2; - gc__div__2 = gc / 2.; - gl__times__el = gl * el; - - // initialize the buffer - currents.clear(); + // used in vectorized passive loop + *ca__div__dt = ca / dt; + *gl__times__el = gl * el; + *gg0 = *ca__div__dt + gl; } std::map< Name, double* > nest::Compartment{{cm_unique_suffix}}::get_recordables() { - std::map< Name, double* > recordables = compartment_currents.get_recordables( comp_index ); + std::map< Name, double* > recordables; - recordables.insert( recordables.begin(), recordables.end() ); - recordables[ Name( "v_comp" + std::to_string( comp_index ) ) ] = &v_comp; + recordables[ Name( "v_comp" + 
std::to_string( comp_index ) ) ] = v_comp; return recordables; } // for matrix construction void -nest::Compartment{{cm_unique_suffix}}::construct_matrix_element( const long lag ) +nest::Compartment{{cm_unique_suffix}}::construct_matrix_coupling_elements() { - // matrix diagonal element - gg = gg0; - if ( parent != nullptr ) { - gg += gc__div__2; + *gg += gc; // matrix off diagonal element - hh = -gc__div__2; - } - - for ( auto child_it = children.begin(); child_it != children.end(); ++child_it ) - { - gg += ( *child_it ).gc__div__2; - } - - // right hand side - ff = ( ca__div__dt - gl__div__2 ) * v_comp + gl__times__el; - - if ( parent != nullptr ) - { - ff -= gc__div__2 * ( v_comp - parent->v_comp ); + hh = -gc; } for ( auto child_it = children.begin(); child_it != children.end(); ++child_it ) { - ff -= ( *child_it ).gc__div__2 * ( v_comp - ( *child_it ).v_comp ); + *gg += ( *child_it ).gc; } - - // add all currents to compartment - std::pair< double, double > gi = compartment_currents.f_numstep( v_comp, lag ); - gg += gi.first; - ff += gi.second; - - // add input current - ff += currents.get_value( lag ); } @@ -174,14 +157,128 @@ nest::CompTree{{cm_unique_suffix}}::CompTree{{cm_unique_suffix}}() void nest::CompTree{{cm_unique_suffix}}::add_compartment( const long parent_index ) { - Compartment{{cm_unique_suffix}}* compartment = new Compartment{{cm_unique_suffix}}( size_, parent_index ); + v_comp_vec.push_back(0.0); + ca__div__dt_vec.push_back(0.0); + gl__times__el_vec.push_back(0.0); + gg0_vec.push_back(0.0); + gg_vec.push_back(0.0); + ff_vec.push_back(0.0); + size_t comp_index = v_comp_vec.size()-1; + + Compartment{{cm_unique_suffix}}* compartment = new Compartment{{cm_unique_suffix}}( + size_, parent_index, &(v_comp_vec[comp_index]), &(ca__div__dt_vec[comp_index]), + &(gl__times__el_vec[comp_index]), &(gg0_vec[comp_index]), &(gg_vec[comp_index]), &(ff_vec[comp_index]) + ); + + neuron_currents.add_compartment(); add_compartment( compartment, parent_index ); } 
void nest::CompTree{{cm_unique_suffix}}::add_compartment( const long parent_index, const DictionaryDatum& compartment_params ) { - Compartment{{cm_unique_suffix}}* compartment = new Compartment{{cm_unique_suffix}}( size_, parent_index, compartment_params ); + //Check whether all passed parameters exist within the used neuron model: + Dictionary* comp_param_copy = new Dictionary(*compartment_params); + + comp_param_copy->remove(names::C_m); + comp_param_copy->remove(names::g_C); + comp_param_copy->remove(names::g_L); + comp_param_copy->remove(names::e_L); + comp_param_copy->remove("v_comp"); +{%- for ion_channel_name, channel_info in chan_info.items() %} + {%- for variable_type, variable_info in channel_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}"); + {%- endfor %} + {%- for variable_type, variable_info in channel_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}"); + {%- endfor %} + {%- for variable_type, variable_info in channel_info["ODEs"].items() %} + if( comp_param_copy->known( "{{variable_type}}" ) ) comp_param_copy->remove("{{variable_type}}"); + {%- endfor %} +{%- endfor %} +{%- for concentration_name, concentration_info in conc_info.items() %} + {%- for variable_type, variable_info in concentration_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}"); + {%- endfor %} +{%- endfor %} +{%- for concentration_name, concentration_info in conc_info.items() %} + {%- for variable_type, variable_info in concentration_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( comp_param_copy->known( "{{variable.name}}" ) ) 
comp_param_copy->remove("{{variable.name}}"); + {%- endfor %} + {%- for variable_type, variable_info in concentration_info["ODEs"].items() %} + if( comp_param_copy->known( "{{variable_type}}" ) ) comp_param_copy->remove("{{variable_type}}"); + {%- endfor %} +{%- endfor %} +{%- for receptor_name, receptor_info in recs_info.items() %} + {%- for variable_type, variable_info in receptor_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}"); + {%- endfor %} +{%- endfor %} +{%- for receptor_name, receptor_info in recs_info.items() %} + {%- for variable_type, variable_info in receptor_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}"); + {%- endfor %} + {%- for variable_type, variable_info in receptor_info["ODEs"].items() %} + if( comp_param_copy->known( "{{variable_type}}" ) ) comp_param_copy->remove("{{variable_type}}"); + {%- endfor %} +{%- endfor %} +{%- for continuous_name, continuous_info in con_in_info.items() %} + {%- for variable_type, variable_info in continuous_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}"); + {%- endfor %} + {%- for variable_type, variable_info in continuous_info["ODEs"].items() %} + if( comp_param_copy->known( "{{variable_type}}" ) ) comp_param_copy->remove("{{variable_type}}"); + {%- endfor %} +{%- endfor %} +{%- for continuous_name, continuous_info in con_in_info.items() %} + {%- for variable_type, variable_info in continuous_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}"); + {%- endfor %} +{%- endfor %} +//global vars + {%- for 
variable_type, variable_info in global_info["Parameters"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}"); + {%- endfor %} + {%- for variable_type, variable_info in global_info["ODEs"].items() %} + if( comp_param_copy->known( "{{variable_type}}" ) ) comp_param_copy->remove("{{variable_type}}"); + {%- endfor %} + {%- for variable_type, variable_info in global_info["States"].items() %} + {%- set variable = variable_info["ASTVariable"] %} + if( comp_param_copy->known( "{{variable.name}}" ) ) comp_param_copy->remove("{{variable.name}}"); + {%- endfor %} + + if(!comp_param_copy->empty()){ + std::string msg = "Following parameters are invalid: "; + for(auto& param : *comp_param_copy){ + msg += param.first.toString(); + msg += "\n"; + } + throw BadParameter(msg); + } + + v_comp_vec.push_back(0.0); + ca__div__dt_vec.push_back(0.0); + gl__times__el_vec.push_back(0.0); + gg0_vec.push_back(0.0); + gg_vec.push_back(0.0); + ff_vec.push_back(0.0); + size_t comp_index = v_comp_vec.size()-1; + + Compartment{{cm_unique_suffix}}* compartment = new Compartment{{cm_unique_suffix}}( + size_, parent_index, compartment_params, &(v_comp_vec[comp_index]), &(ca__div__dt_vec[comp_index]), + &(gl__times__el_vec[comp_index]), &(gg0_vec[comp_index]), &(gg_vec[comp_index]), &(ff_vec[comp_index]) + ); + + neuron_currents.add_compartment(compartment_params); add_compartment( compartment, parent_index ); } @@ -285,6 +382,7 @@ nest::CompTree{{cm_unique_suffix}}::init_pointers() { set_parents(); set_compartments(); + set_compartment_variables(); set_leafs(); } @@ -320,6 +418,23 @@ nest::CompTree{{cm_unique_suffix}}::set_compartments() } } +/** + * Set pointer variables within a compartment + */ +void +nest::CompTree{{cm_unique_suffix}}::set_compartment_variables() +{ + //reset compartment pointers due to unsafe pointers in vectors when resizing during compartment creation + for( size_t i 
= 0; i < v_comp_vec.size(); i++){ + compartments_[i]->v_comp = &(v_comp_vec[i]); + compartments_[i]->ca__div__dt = &(ca__div__dt_vec[i]); + compartments_[i]->gl__times__el = &(gl__times__el_vec[i]); + compartments_[i]->gg0 = &(gg0_vec[i]); + compartments_[i]->gg = &(gg_vec[i]); + compartments_[i]->ff = &(ff_vec[i]); + } +} + /** * Creates a vector of compartment pointers of compartments that are also leafs of the tree. */ @@ -337,15 +452,12 @@ nest::CompTree{{cm_unique_suffix}}::set_leafs() } /** - * Initializes pointers for the spike buffers for all synapse receptors + * Initializes pointers for the spike buffers for all receptors */ void nest::CompTree{{cm_unique_suffix}}::set_syn_buffers( std::vector< RingBuffer >& syn_buffers ) { - for ( auto compartment_it = compartments_.begin(); compartment_it != compartments_.end(); ++compartment_it ) - { - ( *compartment_it )->compartment_currents.set_syn_buffers( syn_buffers ); - } + neuron_currents.set_buffers( syn_buffers ); } /** @@ -362,7 +474,11 @@ nest::CompTree{{cm_unique_suffix}}::get_recordables() */ for ( auto compartment_it = compartments_.begin(); compartment_it != compartments_.end(); ++compartment_it ) { - std::map< Name, double* > recordables_comp = ( *compartment_it )->get_recordables(); + long comp_index = (*compartment_it)->comp_index; + std::map< Name, double* > recordables_comp = neuron_currents.get_recordables( comp_index ); + recordables.insert( recordables_comp.begin(), recordables_comp.end() ); + + recordables_comp = ( *compartment_it )->get_recordables(); recordables.insert( recordables_comp.begin(), recordables_comp.end() ); } return recordables; @@ -393,6 +509,11 @@ nest::CompTree{{cm_unique_suffix}}::pre_run_hook() ( *compartment_it )->pre_run_hook(); {%- endif %} } +{%- if nest_version.startswith("v2") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") %} + neuron_currents.calibrate(); +{%- else %} + neuron_currents.pre_run_hook();
+{%- endif %} } /** @@ -401,12 +522,7 @@ nest::CompTree{{cm_unique_suffix}}::pre_run_hook() std::vector< double > nest::CompTree{{cm_unique_suffix}}::get_voltage() const { - std::vector< double > v_comps; - for ( auto compartment_it = compartments_.cbegin(); compartment_it != compartments_.cend(); ++compartment_it ) - { - v_comps.push_back( ( *compartment_it )->v_comp ); - } - return v_comps; + return v_comp_vec; } /** @@ -415,7 +531,7 @@ nest::CompTree{{cm_unique_suffix}}::get_voltage() const double nest::CompTree{{cm_unique_suffix}}::get_compartment_voltage( const long compartment_index ) { - return compartments_[ compartment_index ]->v_comp; + return *(compartments_[ compartment_index ]->v_comp); } /** @@ -424,9 +540,23 @@ nest::CompTree{{cm_unique_suffix}}::get_compartment_voltage( const long compartm void nest::CompTree{{cm_unique_suffix}}::construct_matrix( const long lag ) { + // compute all channel currents, receptor currents, and input currents + std::vector< std::pair< double, double > > comps_gi = neuron_currents.f_numstep( v_comp_vec, lag ); + + #pragma omp simd + for( size_t i = 0; i < v_comp_vec.size(); i++ ){ + // passive current left hand side + gg_vec[i] = gg0_vec[i]; + // passive currents right hand side + ff_vec[i] = ca__div__dt_vec[i] * v_comp_vec[i] + gl__times__el_vec[i]; + + // add all currents to compartment + gg_vec[i] += comps_gi[i].first; + ff_vec[i] += comps_gi[i].second; + } for ( auto compartment_it = compartments_.begin(); compartment_it != compartments_.end(); ++compartment_it ) { - ( *compartment_it )->construct_matrix_element( lag ); + ( *compartment_it )->construct_matrix_coupling_elements(); } } diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_tree_@NEURON_NAME@.h.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_tree_@NEURON_NAME@.h.jinja2 index fe1942c03..294970755 100644 --- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_tree_@NEURON_NAME@.h.jinja2 +++ 
b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/cm_tree_@NEURON_NAME@.h.jinja2 @@ -56,7 +56,7 @@ along with NEST. If not, see . #include "ring_buffer.h" // compartmental model -#include "{{neuronSpecificFileNamesCmSyns["compartmentcurrents"]}}.h" +#include "{{neuronSpecificFileNamesCmSyns["neuroncurrents"]}}.h" // Includes from libnestutil: #include "dict_util.h" @@ -71,6 +71,8 @@ along with NEST. If not, see . #include "dict.h" #include "dictutils.h" +#include + namespace nest { @@ -90,34 +92,33 @@ public: // tree structure indices Compartment{{cm_unique_suffix}}* parent; std::vector< Compartment{{cm_unique_suffix}} > children; - // vector for synapses - CompartmentCurrents{{cm_unique_suffix}} compartment_currents; // buffer for currents RingBuffer currents; // voltage variable - double v_comp; + double* v_comp; // electrical parameters double ca; // compartment capacitance [uF] double gc; // coupling conductance with parent (meaningless if root) [uS] double gl; // leak conductance of compartment [uS] double el; // leak current reversal potential [mV] // auxiliary variables for efficienchy - double gg0; - double ca__div__dt; - double gl__div__2; - double gc__div__2; - double gl__times__el; + double* gg0; + double* ca__div__dt; + double* gl__times__el; // for numerical integration - double ff; - double gg; + double* ff; + double* gg; double hh; // passage counter for recursion int n_passed; // constructor, destructor - Compartment{{cm_unique_suffix}}( const long compartment_index, const long parent_index ); - Compartment{{cm_unique_suffix}}( const long compartment_index, const long parent_index, const DictionaryDatum& compartment_params ); + Compartment{{cm_unique_suffix}}( const long compartment_index, const long parent_index); + Compartment{{cm_unique_suffix}}( const long compartment_index, const long parent_index, + double* v_comp_ref, double* ca__div__dt_ref, double* gl__times__el_ref, double* gg0_ref, double* gg_ref, double* ff_ref); + 
Compartment{{cm_unique_suffix}}( const long compartment_index, const long parent_index, const DictionaryDatum& compartment_params, + double* v_comp_ref, double* ca__div__dt_ref, double* gl__times__el_ref, double* gg0_ref, double* gg_ref, double* ff_ref); ~Compartment{{cm_unique_suffix}}(){}; // initialization @@ -129,7 +130,7 @@ public: std::map< Name, double* > get_recordables(); // matrix construction - void construct_matrix_element( const long lag ); + void construct_matrix_coupling_elements(); // maxtrix inversion inline void gather_input( const std::pair< double, double >& in ); @@ -137,7 +138,6 @@ public: inline double calc_v( const double v_in ); }; // Compartment - /* Short helper functions for solving the matrix equation. Can hopefully be inlined */ @@ -151,12 +151,12 @@ inline std::pair< double, double > nest::Compartment{{cm_unique_suffix}}::io() { // include inputs from child compartments - gg -= xx_; - ff -= yy_; + *gg -= xx_; + *ff -= yy_; // output values - double g_val( hh * hh / gg ); - double f_val( ff * hh / gg ); + double g_val( hh * hh / *gg ); + double f_val( *ff * hh / *gg ); return std::make_pair( g_val, f_val ); } @@ -168,9 +168,9 @@ nest::Compartment{{cm_unique_suffix}}::calc_v( const double v_in ) yy_ = 0.0; // compute voltage - v_comp = ( ff - v_in * hh ) / gg; + *v_comp = ( *ff - v_in * hh ) / *gg; - return v_comp; + return *v_comp; } @@ -194,13 +194,24 @@ private: // functions for pointer initialization void set_parents(); void set_compartments(); + void set_compartment_variables(); void set_leafs(); + std::vector< double > v_comp_vec; + std::vector< double > ca__div__dt_vec; + std::vector< double > gl__div__2_vec; + std::vector< double > gl__times__el_vec; + std::vector< double > gg0_vec; + std::vector< double > gg_vec; + std::vector< double > ff_vec; + public: // constructor, destructor CompTree{{cm_unique_suffix}}(); ~CompTree{{cm_unique_suffix}}(){}; + NeuronCurrents{{cm_unique_suffix}} neuron_currents; + // initialization 
functions for tree structure void add_compartment( const long parent_index ); void add_compartment( const long parent_index, const DictionaryDatum& compartment_params ); diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/setup/CMakeLists.txt.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/setup/CMakeLists.txt.jinja2 index dc0c1f506..40529e635 100644 --- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/setup/CMakeLists.txt.jinja2 +++ b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/setup/CMakeLists.txt.jinja2 @@ -59,7 +59,7 @@ set( MODULE_NAME ${SHORT_NAME} ) set( MODULE_SOURCES {{moduleName}}.h {{moduleName}}.cpp {%- for neuron in neurons %} - {{perNeuronFileNamesCm[neuron.get_name()]["compartmentcurrents"]}}.cpp {{perNeuronFileNamesCm[neuron.get_name()]["compartmentcurrents"]}}.h + {{perNeuronFileNamesCm[neuron.get_name()]["neuroncurrents"]}}.cpp {{perNeuronFileNamesCm[neuron.get_name()]["neuroncurrents"]}}.h {{perNeuronFileNamesCm[neuron.get_name()]["main"]}}.cpp {{perNeuronFileNamesCm[neuron.get_name()]["main"]}}.h {{perNeuronFileNamesCm[neuron.get_name()]["tree"]}}.cpp {{perNeuronFileNamesCm[neuron.get_name()]["tree"]}}.h {% endfor -%} @@ -255,7 +255,7 @@ if ( BUILD_SHARED_LIBS ) add_library( ${MODULE_NAME}_module MODULE ${MODULE_SOURCES} ) set_target_properties( ${MODULE_NAME}_module PROPERTIES - COMPILE_FLAGS "${NEST_CXXFLAGS} -DLTX_MODULE" + COMPILE_FLAGS "${NEST_CXXFLAGS} -DLTX_MODULE -march=native -ffast-math" LINK_FLAGS "${NEST_LIBS}" PREFIX "" OUTPUT_NAME ${MODULE_NAME} ) diff --git a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/setup/common/ModuleClassMaster.jinja2 b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/setup/common/ModuleClassMaster.jinja2 index a5053d159..3d3758e48 100644 --- a/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/setup/common/ModuleClassMaster.jinja2 +++ 
b/pynestml/codegeneration/resources_nest_compartmental/cm_neuron/setup/common/ModuleClassMaster.jinja2 @@ -51,6 +51,9 @@ {% for neuron in neurons %} #include "{{perNeuronFileNamesCm[neuron.get_name()]["main"]}}.h" {% endfor %} +{% for synapse in synapses %} +#include "{{perSynapseFileNamesCm[synapse.get_name()]["main"]}}.h" +{% endfor %} class {{moduleName}} : public nest::NESTExtensionInterface { @@ -68,4 +71,7 @@ void {{moduleName}}::initialize() {%- for neuron in neurons %} nest::register_{{perNeuronFileNamesCm[neuron.get_name()]["main"]}}("{{perNeuronFileNamesCm[neuron.get_name()]["main"]}}"); {%- endfor %} +{%- for synapse in synapses %} + nest::register_{{perSynapseFileNamesCm[synapse.get_name()]["main"]}}("{{perSynapseFileNamesCm[synapse.get_name()]["main"]}}"); +{%- endfor %} } \ No newline at end of file diff --git a/pynestml/frontend/pynestml_frontend.py b/pynestml/frontend/pynestml_frontend.py index 11e7c2f00..3b684a76a 100644 --- a/pynestml/frontend/pynestml_frontend.py +++ b/pynestml/frontend/pynestml_frontend.py @@ -76,7 +76,7 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st options = synapse_post_neuron_co_generation.set_options(options) transformers.append(synapse_post_neuron_co_generation) - if target_name.upper() == "NEST": + if target_name.upper() in ["NEST"]: from pynestml.transformers.synapse_post_neuron_transformer import SynapsePostNeuronTransformer # co-generate neuron and synapse @@ -363,6 +363,10 @@ def generate_nest_compartmental_target(input_path: Union[str, Sequence[str]], ta codegen_opts : Optional[Mapping[str, Any]] A dictionary containing additional options for the target code generator. 
""" + if codegen_opts == None: + codegen_opts: Mapping[str, Any] = {"fastexp": True} + if not "fastexp" in codegen_opts: + codegen_opts["fastexp"] = True generate_target(input_path, target_platform="NEST_compartmental", target_path=target_path, logging_level=logging_level, module_name=module_name, store_log=store_log, suffix=suffix, install_path=install_path, dev=dev, codegen_opts=codegen_opts) diff --git a/pynestml/meta_model/ast_expression.py b/pynestml/meta_model/ast_expression.py index c476bb58f..ed4354647 100644 --- a/pynestml/meta_model/ast_expression.py +++ b/pynestml/meta_model/ast_expression.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- + # -*- coding: utf-8 -*- # # ast_expression.py # diff --git a/pynestml/symbols/predefined_functions.py b/pynestml/symbols/predefined_functions.py index ebd6b2290..13c74c239 100644 --- a/pynestml/symbols/predefined_functions.py +++ b/pynestml/symbols/predefined_functions.py @@ -29,6 +29,7 @@ class PredefinedFunctions: This class is used to represent all predefined functions of NESTML. 
""" + HEAVISIDE = "Heaviside" TIME_RESOLUTION = "resolution" TIME_STEPS = "steps" EMIT_SPIKE = "emit_spike" diff --git a/pynestml/transformers/synapse_post_neuron_transformer.py b/pynestml/transformers/synapse_post_neuron_transformer.py index b58f526c7..d4ad6497f 100644 --- a/pynestml/transformers/synapse_post_neuron_transformer.py +++ b/pynestml/transformers/synapse_post_neuron_transformer.py @@ -365,6 +365,17 @@ def transform_neuron_synapse_pair_(self, neuron, synapse): # XXX: TODO + # + # move state variable declarations from synapse to neuron + # + for state_var in syn_to_neuron_state_vars: + decls = ASTUtils.move_decls(state_var, + neuron.get_state_blocks()[0], + synapse.get_state_blocks()[0], + var_name_suffix, + block_type=BlockType.STATE) + ASTUtils.add_suffix_to_variable_names(decls, var_name_suffix) + # # move defining equations for variables from synapse to neuron # @@ -586,6 +597,7 @@ def mark_post_port(_expr=None): return new_neuron, new_synapse def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]: + new_models = [] for neuron_synapse_pair in self.get_option("neuron_synapse_pairs"): neuron_name = neuron_synapse_pair["neuron"] synapse_name = neuron_synapse_pair["synapse"] @@ -598,8 +610,10 @@ def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, raise Exception("Synapse used in pair (\"" + synapse_name + "\") not found") # XXX: log error new_neuron, new_synapse = self.transform_neuron_synapse_pair_(neuron, synapse) - models.append(new_neuron) - models.append(new_synapse) + new_models.append(new_neuron) + new_models.append(new_synapse) + for model in new_models: + models.append(model) # remove the synapses used in neuron-synapse pairs, as they can potentially not be generated independently of a neuron and would otherwise result in an error for neuron_synapse_pair in self.get_option("neuron_synapse_pairs"): diff --git a/pynestml/utils/ast_global_information_collector.py 
b/pynestml/utils/ast_global_information_collector.py new file mode 100644 index 000000000..132fe13ec --- /dev/null +++ b/pynestml/utils/ast_global_information_collector.py @@ -0,0 +1,560 @@ +# -*- coding: utf-8 -*- +# +# ast_global_information_collector.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from collections import defaultdict + +from pynestml.meta_model.ast_node import ASTNode +from pynestml.frontend.frontend_configuration import FrontendConfiguration +from pynestml.meta_model.ast_on_receive_block import ASTOnReceiveBlock +from pynestml.symbols.predefined_units import PredefinedUnits +#from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor +from pynestml.visitors.ast_visitor import ASTVisitor +from pynestml.utils.port_signal_type import PortSignalType + + +class ASTGlobalInformationCollector(object): + """This class contains all basic mechanism information collection. 
Further collectors may be implemented to collect + further information for specific mechanism types (example: ASTReceptorInformationCollector)""" + collector_visitor = None + synapse = None + + @classmethod + def __init__(cls, neuron): + cls.neuron = neuron + cls.collector_visitor = ASTMechanismInformationCollectorVisitor() + neuron.accept(cls.collector_visitor) + + @classmethod + def collect_update_block(cls, synapse, global_info): + update_block_collector_visitor = ASTUpdateBlockVisitor() + synapse.accept(update_block_collector_visitor) + global_info["UpdateBlock"] = update_block_collector_visitor.update_block + return global_info + + @classmethod + def collect_self_spike_function(cls, neuron, global_info): + on_receive_collector_visitor = ASTOnReceiveBlockCollectorVisitor() + neuron.accept(on_receive_collector_visitor) + + for function in on_receive_collector_visitor.all_on_receive_blocks: + if function.get_port_name() == "self_spikes": + global_info["SelfSpikesFunction"] = function + + return global_info + + @classmethod + def extend_variables_with_initialisations(cls, neuron, global_info): + """collects initialization expressions for all variables and parameters contained in global_info""" + var_init_visitor = VariableInitializationVisitor(global_info) + neuron.accept(var_init_visitor) + global_info["States"] = var_init_visitor.states + global_info["Parameters"] = var_init_visitor.parameters + global_info["Internals"] = var_init_visitor.internals + + return global_info + + @classmethod + def extend_variable_list_name_based_restricted(cls, extended_list, appending_list, restrictor_list): + """go through appending_list and append every variable that is not in restrictor_list to extended_list for the + purpose of not re-searching the same variable""" + for app_item in appending_list: + appendable = True + for rest_item in restrictor_list: + if rest_item.name == app_item.name: + appendable = False + break + if appendable: + extended_list.append(app_item) + + 
return extended_list + + @classmethod + def extend_function_call_list_name_based_restricted(cls, extended_list, appending_list, restrictor_list): + """go through appending_list and append every variable that is not in restrictor_list to extended_list for the + purpose of not re-searching the same function""" + for app_item in appending_list: + appendable = True + for rest_item in restrictor_list: + if rest_item.callee_name == app_item.callee_name: + appendable = False + break + if appendable: + extended_list.append(app_item) + + return extended_list + + @classmethod + def collect_related_definitions(cls, neuron, global_info): + """Collects all parts of the nestml code the root expressions previously collected depend on. search + is cut at other mechanisms root expressions""" + from pynestml.meta_model.ast_inline_expression import ASTInlineExpression + from pynestml.meta_model.ast_ode_equation import ASTOdeEquation + + variable_collector = ASTVariableCollectorVisitor() + neuron.accept(variable_collector) + global_states = variable_collector.all_states + global_parameters = variable_collector.all_parameters + global_internals = variable_collector.all_internals + + function_collector = ASTFunctionCollectorVisitor() + neuron.accept(function_collector) + global_functions = function_collector.all_functions + + inline_collector = ASTInlineEquationCollectorVisitor() + neuron.accept(inline_collector) + global_inlines = inline_collector.all_inlines + + ode_collector = ASTODEEquationCollectorVisitor() + neuron.accept(ode_collector) + global_odes = ode_collector.all_ode_equations + + kernel_collector = ASTKernelCollectorVisitor() + neuron.accept(kernel_collector) + global_kernels = kernel_collector.all_kernels + + continuous_input_collector = ASTContinuousInputDeclarationVisitor() + neuron.accept(continuous_input_collector) + global_continuous_inputs = continuous_input_collector.ports + + mechanism_states = list() + mechanism_parameters = list() + mechanism_internals = list() 
+ mechanism_functions = list() + mechanism_inlines = list() + mechanism_odes = list() + synapse_kernels = list() + mechanism_continuous_inputs = list() + mechanism_dependencies = defaultdict() + mechanism_dependencies["concentrations"] = list() + mechanism_dependencies["channels"] = list() + mechanism_dependencies["receptors"] = list() + mechanism_dependencies["continuous"] = list() + + search_variables = list() + search_functions = list() + + found_variables = list() + found_functions = list() + + if "SelfSpikesFunction" in global_info and global_info["SelfSpikesFunction"] is not None: + local_variable_collector = ASTVariableCollectorVisitor() + global_info["SelfSpikesFunction"].accept(local_variable_collector) + search_variables_self_spike = local_variable_collector.all_variables + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + search_variables_self_spike, + search_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + global_info["SelfSpikesFunction"].accept(local_function_call_collector) + search_functions_self_spike = local_function_call_collector.all_function_calls + search_functions = cls.extend_function_call_list_name_based_restricted(search_functions, + search_functions_self_spike, + search_functions) + + if "UpdateBlock" in global_info and global_info["UpdateBlock"] is not None: + local_variable_collector = ASTVariableCollectorVisitor() + global_info["UpdateBlock"].accept(local_variable_collector) + search_variables_update = local_variable_collector.all_variables + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + search_variables_update, + search_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + global_info["UpdateBlock"].accept(local_function_call_collector) + search_functions_update = local_function_call_collector.all_function_calls + search_functions = 
cls.extend_function_call_list_name_based_restricted(search_functions, + search_functions_update, + search_functions) + + while len(search_functions) > 0 or len(search_variables) > 0: + if len(search_functions) > 0: + function_call = search_functions[0] + for function in global_functions: + if function.name == function_call.callee_name: + mechanism_functions.append(function) + found_functions.append(function_call) + + local_variable_collector = ASTVariableCollectorVisitor() + function.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + function.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted(search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + # IMPLEMENT CATCH NONDEFINED!!! 
+ search_functions.remove(function_call) + + elif len(search_variables) > 0: + variable = search_variables[0] + if not (variable.name == "v_comp" or variable.name in PredefinedUnits.get_units()): + is_dependency = False + for inline in global_inlines: + if variable.name == inline.variable_name: + if isinstance(inline.get_decorators(), list): + if "mechanism" in [e.namespace for e in inline.get_decorators()]: + is_dependency = True + if not (isinstance(global_info["root_expression"], + ASTInlineExpression) and inline.variable_name == + global_info["root_expression"].variable_name): + if "channel" in [e.name for e in inline.get_decorators()]: + if not inline.variable_name in [i.variable_name for i in + mechanism_dependencies["channels"]]: + mechanism_dependencies["channels"].append(inline) + if "receptor" in [e.name for e in inline.get_decorators()]: + if not inline.variable_name in [i.variable_name for i in + mechanism_dependencies["receptors"]]: + mechanism_dependencies["receptors"].append(inline) + if "continuous" in [e.name for e in inline.get_decorators()]: + if not inline.variable_name in [i.variable_name for i in + mechanism_dependencies["continuous"]]: + mechanism_dependencies["continuous"].append(inline) + + if not is_dependency: + mechanism_inlines.append(inline) + + local_variable_collector = ASTVariableCollectorVisitor() + inline.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + inline.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted( + search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + + for ode in global_odes: + if variable.name == ode.lhs.name: + if isinstance(ode.get_decorators(), list): + if "mechanism" in 
[e.namespace for e in ode.get_decorators()]: + is_dependency = True + if not (isinstance(global_info["root_expression"], + ASTOdeEquation) and ode.lhs.name == global_info[ + "root_expression"].lhs.name): + if "concentration" in [e.name for e in ode.get_decorators()]: + if not ode.lhs.name in [o.lhs.name for o in + mechanism_dependencies["concentrations"]]: + mechanism_dependencies["concentrations"].append(ode) + + if not is_dependency: + mechanism_odes.append(ode) + + local_variable_collector = ASTVariableCollectorVisitor() + ode.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + ode.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted( + search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + + for state in global_states: + if variable.name == state.name and not is_dependency: + mechanism_states.append(state) + + for parameter in global_parameters: + if variable.name == parameter.name: + mechanism_parameters.append(parameter) + + for internal in global_internals: + if variable.name == internal.name: + mechanism_internals.append(internal) + + for kernel in global_kernels: + if variable.name == kernel.get_variables()[0].name: + synapse_kernels.append(kernel) + + local_variable_collector = ASTVariableCollectorVisitor() + kernel.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + kernel.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted(search_functions, + 
local_function_call_collector.all_function_calls, + search_functions + found_functions) + + for input in global_continuous_inputs: + if variable.name == input.name: + mechanism_continuous_inputs.append(input) + search_variables.remove(variable) + found_variables.append(variable) + # IMPLEMENT CATCH NONDEFINED!!! + + global_info["States"] = mechanism_states + global_info["Parameters"] = mechanism_parameters + global_info["Internals"] = mechanism_internals + global_info["Functions"] = mechanism_functions + global_info["SecondaryInlineExpressions"] = mechanism_inlines + global_info["ODEs"] = mechanism_odes + global_info["Continuous"] = mechanism_continuous_inputs + global_info["Dependencies"] = mechanism_dependencies + + return global_info + + +class ASTMechanismInformationCollectorVisitor(ASTVisitor): + + def __init__(self): + super(ASTMechanismInformationCollectorVisitor, self).__init__() + self.inEquationsBlock = False + self.inlinesInEquationsBlock = list() + self.odes = list() + + def visit_equations_block(self, node): + self.inEquationsBlock = True + + def endvisit_equations_block(self, node): + self.inEquationsBlock = False + + def visit_inline_expression(self, node): + if self.inEquationsBlock: + self.inlinesInEquationsBlock.append(node) + + def visit_ode_equation(self, node): + self.odes.append(node) + + +class ASTUpdateBlockVisitor(ASTVisitor): + def __init__(self): + super(ASTUpdateBlockVisitor, self).__init__() + self.inside_update_block = False + self.update_block = None + + def visit_update_block(self, node): + self.inside_update_block = True + self.update_block = node.clone() + + def endvisit_update_block(self, node): + self.inside_update_block = False + + +class VariableInitializationVisitor(ASTVisitor): + def __init__(self, channel_info): + super(VariableInitializationVisitor, self).__init__() + self.inside_variable = False + self.inside_declaration = False + self.inside_parameter_block = False + self.inside_state_block = False + 
self.inside_internal_block = False + self.current_declaration = None + self.states = defaultdict() + self.parameters = defaultdict() + self.internals = defaultdict() + self.channel_info = channel_info + + def visit_declaration(self, node): + self.inside_declaration = True + self.current_declaration = node + + def endvisit_declaration(self, node): + self.inside_declaration = False + self.current_declaration = None + + def visit_block_with_variables(self, node): + if node.is_state: + self.inside_state_block = True + if node.is_parameters: + self.inside_parameter_block = True + if node.is_internals: + self.inside_internal_block = True + + def endvisit_block_with_variables(self, node): + self.inside_state_block = False + self.inside_parameter_block = False + self.inside_internal_block = False + + def visit_variable(self, node): + self.inside_variable = True + if self.inside_state_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["States"]): + self.states[node.name] = defaultdict() + self.states[node.name]["ASTVariable"] = node.clone() + self.states[node.name]["rhs_expression"] = self.current_declaration.get_expression() + + if self.inside_parameter_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["Parameters"]): + self.parameters[node.name] = defaultdict() + self.parameters[node.name]["ASTVariable"] = node.clone() + self.parameters[node.name]["rhs_expression"] = self.current_declaration.get_expression() + + if self.inside_internal_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["Internals"]): + self.internals[node.name] = defaultdict() + self.internals[node.name]["ASTVariable"] = node.clone() + self.internals[node.name]["rhs_expression"] = self.current_declaration.get_expression() + + def endvisit_variable(self, node): + self.inside_variable = False + + +class ASTODEEquationCollectorVisitor(ASTVisitor): 
+ def __init__(self): + super(ASTODEEquationCollectorVisitor, self).__init__() + self.inside_ode_expression = False + self.all_ode_equations = list() + + def visit_ode_equation(self, node): + self.inside_ode_expression = True + self.all_ode_equations.append(node.clone()) + + def endvisit_ode_equation(self, node): + self.inside_ode_expression = False + + +class ASTVariableCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTVariableCollectorVisitor, self).__init__() + self.inside_variable = False + self.inside_block_with_variables = False + self.all_states = list() + self.all_parameters = list() + self.all_internals = list() + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_internals_block = False + self.all_variables = list() + + def visit_block_with_variables(self, node): + self.inside_block_with_variables = True + if node.is_state: + self.inside_states_block = True + if node.is_parameters: + self.inside_parameters_block = True + if node.is_internals: + self.inside_internals_block = True + + def endvisit_block_with_variables(self, node): + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_internals_block = False + self.inside_block_with_variables = False + + def visit_variable(self, node): + self.inside_variable = True + self.all_variables.append(node.clone()) + if self.inside_states_block: + self.all_states.append(node.clone()) + if self.inside_parameters_block: + self.all_parameters.append(node.clone()) + if self.inside_internals_block: + self.all_internals.append(node.clone()) + + def endvisit_variable(self, node): + self.inside_variable = False + + +class ASTFunctionCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTFunctionCollectorVisitor, self).__init__() + self.inside_function = False + self.all_functions = list() + + def visit_function(self, node): + self.inside_function = True + self.all_functions.append(node.clone()) + + def endvisit_function(self, node): + 
self.inside_function = False + + +class ASTInlineEquationCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTInlineEquationCollectorVisitor, self).__init__() + self.inside_inline_expression = False + self.all_inlines = list() + + def visit_inline_expression(self, node): + self.inside_inline_expression = True + self.all_inlines.append(node.clone()) + + def endvisit_inline_expression(self, node): + self.inside_inline_expression = False + + +class ASTFunctionCallCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTFunctionCallCollectorVisitor, self).__init__() + self.inside_function_call = False + self.all_function_calls = list() + + def visit_function_call(self, node): + self.inside_function_call = True + self.all_function_calls.append(node.clone()) + + def endvisit_function_call(self, node): + self.inside_function_call = False + + +class ASTKernelCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTKernelCollectorVisitor, self).__init__() + self.inside_kernel = False + self.all_kernels = list() + + def visit_kernel(self, node): + self.inside_kernel = True + self.all_kernels.append(node.clone()) + + def endvisit_kernel(self, node): + self.inside_kernel = False + + +class ASTContinuousInputDeclarationVisitor(ASTVisitor): + def __init__(self): + super(ASTContinuousInputDeclarationVisitor, self).__init__() + self.inside_port = False + self.current_port = None + self.ports = list() + + def visit_input_port(self, node): + self.inside_port = True + self.current_port = node + if self.current_port.is_continuous(): + self.ports.append(node.clone()) + + def endvisit_input_port(self, node): + self.inside_port = False + +class ASTOnReceiveBlockCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTOnReceiveBlockCollectorVisitor, self).__init__() + self.inside_on_receive_block = False + self.all_on_receive_blocks = list() + + def visit_on_receive_block(self, node): + self.inside_on_receive_block = True + 
self.all_on_receive_blocks.append(node.clone()) + + def endvisit_on_receive_block(self, node): + self.inside_on_receive_block = False diff --git a/pynestml/utils/ast_mechanism_information_collector.py b/pynestml/utils/ast_mechanism_information_collector.py index eed37601d..de978d4fa 100644 --- a/pynestml/utils/ast_mechanism_information_collector.py +++ b/pynestml/utils/ast_mechanism_information_collector.py @@ -22,12 +22,13 @@ from collections import defaultdict from pynestml.frontend.frontend_configuration import FrontendConfiguration +from pynestml.symbols.predefined_units import PredefinedUnits from pynestml.visitors.ast_visitor import ASTVisitor class ASTMechanismInformationCollector(object): """This class contains all basic mechanism information collection. Further collectors may be implemented to collect - further information for specific mechanism types (example: ASTSynapseInformationCollector)""" + further information for specific mechanism types (example: ASTReceptorInformationCollector)""" collector_visitor = None neuron = None @@ -101,11 +102,12 @@ def extend_variables_with_initialisations(cls, neuron, mechs_info): neuron.accept(var_init_visitor) mechs_info[mechanism_name]["States"] = var_init_visitor.states mechs_info[mechanism_name]["Parameters"] = var_init_visitor.parameters + mechs_info[mechanism_name]["Internals"] = var_init_visitor.internals return mechs_info @classmethod - def collect_mechanism_related_definitions(cls, neuron, mechs_info): + def collect_mechanism_related_definitions(cls, neuron, mechs_info, global_info): """Collects all parts of the nestml code the root expressions previously collected depend on. 
search is cut at other mechanisms root expressions""" from pynestml.meta_model.ast_inline_expression import ASTInlineExpression @@ -116,6 +118,7 @@ def collect_mechanism_related_definitions(cls, neuron, mechs_info): neuron.accept(variable_collector) global_states = variable_collector.all_states global_parameters = variable_collector.all_parameters + global_internals = variable_collector.all_internals function_collector = ASTFunctionCollectorVisitor() neuron.accept(function_collector) @@ -133,22 +136,27 @@ def collect_mechanism_related_definitions(cls, neuron, mechs_info): neuron.accept(kernel_collector) global_kernels = kernel_collector.all_kernels + continuous_input_collector = ASTContinuousInputDeclarationVisitor() + neuron.accept(continuous_input_collector) + global_continuous_inputs = continuous_input_collector.ports + mechanism_states = list() mechanism_parameters = list() + mechanism_internals = list() mechanism_functions = list() mechanism_inlines = list() mechanism_odes = list() synapse_kernels = list() + mechanism_continuous_inputs = list() mechanism_dependencies = defaultdict() mechanism_dependencies["concentrations"] = list() mechanism_dependencies["channels"] = list() mechanism_dependencies["receptors"] = list() + mechanism_dependencies["continuous"] = list() + mechanism_dependencies["global"] = list() mechanism_inlines.append(mechs_info[mechanism_name]["root_expression"]) - search_variables = list() - search_functions = list() - found_variables = list() found_functions = list() @@ -184,7 +192,7 @@ def collect_mechanism_related_definitions(cls, neuron, mechs_info): elif len(search_variables) > 0: variable = search_variables[0] - if not variable.name == "v_comp": + if not (variable.name == "v_comp" or variable.name in PredefinedUnits.get_units()): is_dependency = False for inline in global_inlines: if variable.name == inline.variable_name: @@ -200,6 +208,10 @@ def collect_mechanism_related_definitions(cls, neuron, mechs_info): if not inline.variable_name 
in [i.variable_name for i in mechanism_dependencies["receptors"]]: mechanism_dependencies["receptors"].append(inline) + if "continuous" in [e.name for e in inline.get_decorators()]: + if not inline.variable_name in [i.variable_name for i in + mechanism_dependencies["continuous"]]: + mechanism_dependencies["continuous"].append(inline) if not is_dependency: mechanism_inlines.append(inline) @@ -219,6 +231,9 @@ def collect_mechanism_related_definitions(cls, neuron, mechs_info): for ode in global_odes: if variable.name == ode.lhs.name: + if ode.lhs.name in global_info["States"]: + is_dependency = True + #mechanism_dependencies["global"].append(ode) if isinstance(ode.get_decorators(), list): if "mechanism" in [e.namespace for e in ode.get_decorators()]: is_dependency = True @@ -245,13 +260,21 @@ def collect_mechanism_related_definitions(cls, neuron, mechs_info): search_functions + found_functions) for state in global_states: - if variable.name == state.name and not is_dependency: - mechanism_states.append(state) + if variable.name == state.name: + if state.name in global_info["States"]: + is_dependency = True + mechanism_dependencies["global"].append(state) + if not is_dependency: + mechanism_states.append(state) for parameter in global_parameters: if variable.name == parameter.name: mechanism_parameters.append(parameter) + for internal in global_internals: + if variable.name == internal.name: + mechanism_internals.append(internal) + for kernel in global_kernels: if variable.name == kernel.get_variables()[0].name: synapse_kernels.append(kernel) @@ -268,15 +291,21 @@ def collect_mechanism_related_definitions(cls, neuron, mechs_info): local_function_call_collector.all_function_calls, search_functions + found_functions) + for input in global_continuous_inputs: + if variable.name == input.name: + mechanism_continuous_inputs.append(input) + search_variables.remove(variable) found_variables.append(variable) # IMPLEMENT CATCH NONDEFINED!!! 
mechs_info[mechanism_name]["States"] = mechanism_states mechs_info[mechanism_name]["Parameters"] = mechanism_parameters + mechs_info[mechanism_name]["Internals"] = mechanism_internals mechs_info[mechanism_name]["Functions"] = mechanism_functions mechs_info[mechanism_name]["SecondaryInlineExpressions"] = mechanism_inlines mechs_info[mechanism_name]["ODEs"] = mechanism_odes + mechs_info[mechanism_name]["Continuous"] = mechanism_continuous_inputs mechs_info[mechanism_name]["Dependencies"] = mechanism_dependencies return mechs_info @@ -312,9 +341,11 @@ def __init__(self, channel_info): self.inside_declaration = False self.inside_parameter_block = False self.inside_state_block = False + self.inside_internal_block = False self.current_declaration = None self.states = defaultdict() self.parameters = defaultdict() + self.internals = defaultdict() self.channel_info = channel_info def visit_declaration(self, node): @@ -330,10 +361,13 @@ def visit_block_with_variables(self, node): self.inside_state_block = True if node.is_parameters: self.inside_parameter_block = True + if node.is_internals: + self.inside_internal_block = True def endvisit_block_with_variables(self, node): self.inside_state_block = False self.inside_parameter_block = False + self.inside_internal_block = False def visit_variable(self, node): self.inside_variable = True @@ -349,6 +383,12 @@ def visit_variable(self, node): self.parameters[node.name]["ASTVariable"] = node.clone() self.parameters[node.name]["rhs_expression"] = self.current_declaration.get_expression() + if self.inside_internal_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["Internals"]): + self.internals[node.name] = defaultdict() + self.internals[node.name]["ASTVariable"] = node.clone() + self.internals[node.name]["rhs_expression"] = self.current_declaration.get_expression() + def endvisit_variable(self, node): self.inside_variable = False @@ -374,8 +414,10 @@ def __init__(self): 
self.inside_block_with_variables = False self.all_states = list() self.all_parameters = list() + self.all_internals = list() self.inside_states_block = False self.inside_parameters_block = False + self.inside_internals_block = False self.all_variables = list() def visit_block_with_variables(self, node): @@ -384,10 +426,13 @@ def visit_block_with_variables(self, node): self.inside_states_block = True if node.is_parameters: self.inside_parameters_block = True + if node.is_internals: + self.inside_internals_block = True def endvisit_block_with_variables(self, node): self.inside_states_block = False self.inside_parameters_block = False + self.inside_internals_block = False self.inside_block_with_variables = False def visit_variable(self, node): @@ -397,6 +442,8 @@ def visit_variable(self, node): self.all_states.append(node.clone()) if self.inside_parameters_block: self.all_parameters.append(node.clone()) + if self.inside_internals_block: + self.all_internals.append(node.clone()) def endvisit_variable(self, node): self.inside_variable = False @@ -456,3 +503,20 @@ def visit_kernel(self, node): def endvisit_kernel(self, node): self.inside_kernel = False + + +class ASTContinuousInputDeclarationVisitor(ASTVisitor): + def __init__(self): + super(ASTContinuousInputDeclarationVisitor, self).__init__() + self.inside_port = False + self.current_port = None + self.ports = list() + + def visit_input_port(self, node): + self.inside_port = True + self.current_port = node + if self.current_port.is_continuous(): + self.ports.append(node.clone()) + + def endvisit_input_port(self, node): + self.inside_port = False diff --git a/pynestml/utils/ast_receptor_information_collector.py b/pynestml/utils/ast_receptor_information_collector.py new file mode 100644 index 000000000..0e64abe95 --- /dev/null +++ b/pynestml/utils/ast_receptor_information_collector.py @@ -0,0 +1,349 @@ +# -*- coding: utf-8 -*- +# +# ast_receptor_information_collector.py +# +# This file is part of NEST. 
+# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from _collections import defaultdict +import copy + +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression +from pynestml.meta_model.ast_kernel import ASTKernel +from pynestml.symbols.predefined_variables import PredefinedVariables +from pynestml.visitors.ast_visitor import ASTVisitor + + +class ASTReceptorInformationCollector(ASTVisitor): + """ + for each inline expression inside the equations block, + collect all synapse relevant information + + """ + + def __init__(self): + super(ASTReceptorInformationCollector, self).__init__() + + # various dicts to store collected information + self.kernel_name_to_kernel = defaultdict() + self.inline_expression_to_kernel_args = defaultdict(lambda: set()) + self.inline_expression_to_function_calls = defaultdict(lambda: set()) + self.kernel_to_function_calls = defaultdict(lambda: set()) + self.parameter_name_to_declaration = defaultdict(lambda: None) + self.state_name_to_declaration = defaultdict(lambda: None) + self.variable_name_to_declaration = defaultdict(lambda: None) + self.internal_var_name_to_declaration = defaultdict(lambda: None) + self.inline_expression_to_variables = defaultdict(lambda: set()) + self.kernel_to_rhs_variables = defaultdict(lambda: set()) + self.declaration_to_rhs_variables = defaultdict(lambda: set()) + self.input_port_name_to_input_port = defaultdict() 
+ + # traversal states and nodes + self.inside_parameter_block = False + self.inside_state_block = False + self.inside_internals_block = False + self.inside_equations_block = False + self.inside_input_block = False + self.inside_inline_expression = False + self.inside_kernel = False + self.inside_kernel_call = False + self.inside_declaration = False + # self.inside_variable = False + self.inside_simple_expression = False + self.inside_expression = False + # self.inside_function_call = False + + self.current_inline_expression = None + self.current_kernel = None + self.current_expression = None + self.current_simple_expression = None + self.current_declaration = None + # self.current_variable = None + + self.current_synapse_name = None + + def get_state_declaration(self, variable_name): + return self.state_name_to_declaration[variable_name] + + def get_variable_declaration(self, variable_name): + return self.variable_name_to_declaration[variable_name] + + def get_kernel_by_name(self, name: str): + return self.kernel_name_to_kernel[name] + + def get_inline_expressions_with_kernels(self): + return self.inline_expression_to_kernel_args.keys() + + def get_kernel_function_calls(self, kernel: ASTKernel): + return self.kernel_to_function_calls[kernel] + + def get_inline_function_calls(self, inline: ASTInlineExpression): + return self.inline_expression_to_function_calls[inline] + + def get_variable_names_of_synapse(self, synapse_inline: ASTInlineExpression, exclude_names: set = set(), exclude_ignorable=True) -> set: + """extracts all variables specific to a single synapse + (which is defined by the inline expression containing kernels) + independently of what block they are declared in + it also cascades over all right hand side variables until all + variables are included""" + if exclude_ignorable: + exclude_names.update(self.get_variable_names_to_ignore()) + + # find all variables used in the inline + potential_variables = 
self.inline_expression_to_variables[synapse_inline] + + # find all kernels referenced by the inline + # and collect variables used by those kernels + kernel_arg_pairs = self.get_extracted_kernel_args(synapse_inline) + for kernel_var, spikes_var in kernel_arg_pairs: + kernel = self.get_kernel_by_name(kernel_var.get_name()) + potential_variables.update(self.kernel_to_rhs_variables[kernel]) + + # find declarations for all variables and check + # what variables their rhs expressions use + # for example if we have + # a = b * c + # then check if b and c are already in potential_variables + # if not, add those as well + potential_variables_copy = copy.copy(potential_variables) + + potential_variables_prev_count = len(potential_variables) + while True: + for potential_variable in potential_variables_copy: + var_name = potential_variable.get_name() + if var_name in exclude_names: + continue + declaration = self.get_variable_declaration(var_name) + if declaration is None: + continue + variables_referenced = self.declaration_to_rhs_variables[var_name] + potential_variables.update(variables_referenced) + if potential_variables_prev_count == len(potential_variables): + break + potential_variables_prev_count = len(potential_variables) + + # transform variables into their names and filter + # out anything form exclude_names + result = set() + for potential_variable in potential_variables: + var_name = potential_variable.get_name() + if var_name not in exclude_names: + result.add(var_name) + + return result + + @classmethod + def get_variable_names_to_ignore(cls): + return set(PredefinedVariables.get_variables().keys()).union({"v_comp"}) + + def get_synapse_specific_internal_declarations(self, synapse_inline: ASTInlineExpression) -> defaultdict: + synapse_variable_names = self.get_variable_names_of_synapse( + synapse_inline) + + # now match those variable names with + # variable declarations from the internals block + dereferenced = defaultdict() + for potential_internals_name in 
synapse_variable_names: + if potential_internals_name in self.internal_var_name_to_declaration: + dereferenced[potential_internals_name] = self.internal_var_name_to_declaration[potential_internals_name] + return dereferenced + + def get_synapse_specific_state_declarations(self, synapse_inline: ASTInlineExpression) -> defaultdict: + synapse_variable_names = self.get_variable_names_of_synapse( + synapse_inline) + + # now match those variable names with + # variable declarations from the state block + dereferenced = defaultdict() + for potential_state_name in synapse_variable_names: + if potential_state_name in self.state_name_to_declaration: + dereferenced[potential_state_name] = self.state_name_to_declaration[potential_state_name] + return dereferenced + + def get_synapse_specific_parameter_declarations(self, synapse_inline: ASTInlineExpression) -> defaultdict: + synapse_variable_names = self.get_variable_names_of_synapse( + synapse_inline) + + # now match those variable names with + # variable declarations from the parameter block + dereferenced = defaultdict() + for potential_param_name in synapse_variable_names: + if potential_param_name in self.parameter_name_to_declaration: + dereferenced[potential_param_name] = self.parameter_name_to_declaration[potential_param_name] + return dereferenced + + def get_extracted_kernel_args(self, inline_expression: ASTInlineExpression) -> set: + return self.inline_expression_to_kernel_args[inline_expression] + + def get_basic_kernel_variable_names(self, synapse_inline): + """ + for every occurrence of convolve(port, spikes) generate "port__X__spikes" variable + gather those variables for this synapse inline and return their list + + note that those variables will occur as substrings in other kernel variables i.e "port__X__spikes__d" or "__P__port__X__spikes__port__X__spikes" + + so we can use the result to identify all the other kernel variables related to the + specific synapse inline declaration + """ + order = 0 + results = [] 
+ for syn_inline, args in self.inline_expression_to_kernel_args.items(): + if synapse_inline.variable_name == syn_inline.variable_name: + for kernel_var, spike_var in args: + kernel_name = kernel_var.get_name() + spike_input_port = self.input_port_name_to_input_port[spike_var.get_name( + )] + kernel_variable_name = self.construct_kernel_X_spike_buf_name( + kernel_name, spike_input_port, order) + results.append(kernel_variable_name) + + return results + + def get_used_kernel_names(self, inline_expression: ASTInlineExpression): + return [kernel_var.get_name() for kernel_var, _ in self.get_extracted_kernel_args(inline_expression)] + + def get_input_port_by_name(self, name): + return self.input_port_name_to_input_port[name] + + def get_used_spike_names(self, inline_expression: ASTInlineExpression): + return [spikes_var.get_name() for _, spikes_var in self.get_extracted_kernel_args(inline_expression)] + + def visit_kernel(self, node): + self.current_kernel = node + self.inside_kernel = True + if self.inside_equations_block: + kernel_name = node.get_variables()[0].get_name_of_lhs() + self.kernel_name_to_kernel[kernel_name] = node + + def visit_function_call(self, node): + if self.inside_equations_block: + if self.inside_inline_expression and self.inside_simple_expression: + if node.get_name() == "convolve": + self.inside_kernel_call = True + kernel, spikes = node.get_args() + kernel_var = kernel.get_variables()[0] + spikes_var = spikes.get_variables()[0] + if "mechanism::receptor" in [(e.namespace + "::" + e.name) for e in self.current_inline_expression.get_decorators()]: + self.inline_expression_to_kernel_args[self.current_inline_expression].add( + (kernel_var, spikes_var)) + else: + self.inline_expression_to_function_calls[self.current_inline_expression].add( + node) + if self.inside_kernel and self.inside_simple_expression: + self.kernel_to_function_calls[self.current_kernel].add(node) + + def endvisit_function_call(self, node): + self.inside_kernel_call = False + + 
def endvisit_kernel(self, node): + self.current_kernel = None + self.inside_kernel = False + + def visit_variable(self, node): + if self.inside_inline_expression and not self.inside_kernel_call: + self.inline_expression_to_variables[self.current_inline_expression].add( + node) + elif self.inside_kernel and (self.inside_expression or self.inside_simple_expression): + self.kernel_to_rhs_variables[self.current_kernel].add(node) + elif self.inside_declaration and self.inside_expression: + declared_variable = self.current_declaration.get_variables()[ + 0].get_name() + self.declaration_to_rhs_variables[declared_variable].add(node) + + def visit_inline_expression(self, node): + self.inside_inline_expression = True + self.current_inline_expression = node + + def endvisit_inline_expression(self, node): + self.inside_inline_expression = False + self.current_inline_expression = None + + def visit_equations_block(self, node): + self.inside_equations_block = True + + def endvisit_equations_block(self, node): + self.inside_equations_block = False + + def visit_input_block(self, node): + self.inside_input_block = True + + def visit_input_port(self, node): + self.input_port_name_to_input_port[node.get_name()] = node + + def endvisit_input_block(self, node): + self.inside_input_block = False + + def visit_block_with_variables(self, node): + if node.is_state: + self.inside_state_block = True + if node.is_parameters: + self.inside_parameter_block = True + if node.is_internals: + self.inside_internals_block = True + + def endvisit_block_with_variables(self, node): + if node.is_state: + self.inside_state_block = False + if node.is_parameters: + self.inside_parameter_block = False + if node.is_internals: + self.inside_internals_block = False + + def visit_simple_expression(self, node): + self.inside_simple_expression = True + self.current_simple_expression = node + + def endvisit_simple_expression(self, node): + self.inside_simple_expression = False + self.current_simple_expression = 
None + + def visit_declaration(self, node): + self.inside_declaration = True + self.current_declaration = node + + # collect decalarations generally + variable_name = node.get_variables()[0].get_name() + self.variable_name_to_declaration[variable_name] = node + + # collect declarations per block + if self.inside_parameter_block: + self.parameter_name_to_declaration[variable_name] = node + elif self.inside_state_block: + self.state_name_to_declaration[variable_name] = node + elif self.inside_internals_block: + self.internal_var_name_to_declaration[variable_name] = node + + def endvisit_declaration(self, node): + self.inside_declaration = False + self.current_declaration = None + + def visit_expression(self, node): + self.inside_expression = True + self.current_expression = node + + def endvisit_expression(self, node): + self.inside_expression = False + self.current_expression = None + + # this method was copied over from ast_transformer + # in order to avoid a circular dependency + @staticmethod + def construct_kernel_X_spike_buf_name(kernel_var_name: str, spike_input_port, order: int, diff_order_symbol="__d"): + assert type(kernel_var_name) is str + assert type(order) is int + assert type(diff_order_symbol) is str + return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + str(spike_input_port) + diff_order_symbol * order diff --git a/pynestml/utils/ast_synapse_information_collector.py b/pynestml/utils/ast_synapse_information_collector.py index f5a6763bc..f6f345c33 100644 --- a/pynestml/utils/ast_synapse_information_collector.py +++ b/pynestml/utils/ast_synapse_information_collector.py @@ -18,25 +18,345 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . 
- -from _collections import defaultdict import copy +from collections import defaultdict +from pynestml.meta_model.ast_node import ASTNode +from pynestml.frontend.frontend_configuration import FrontendConfiguration from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from pynestml.meta_model.ast_kernel import ASTKernel +from pynestml.meta_model.ast_on_receive_block import ASTOnReceiveBlock +from pynestml.symbols.predefined_units import PredefinedUnits from pynestml.symbols.predefined_variables import PredefinedVariables +#from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor from pynestml.visitors.ast_visitor import ASTVisitor +from pynestml.utils.port_signal_type import PortSignalType + + +class ASTSynapseInformationCollector(object): + """This class contains all basic mechanism information collection. Further collectors may be implemented to collect + further information for specific mechanism types (example: ASTReceptorInformationCollector)""" + collector_visitor = None + synapse = None + + @classmethod + def __init__(cls, synapse): + cls.synapse = synapse + cls.collector_visitor = ASTMechanismInformationCollectorVisitor() + synapse.accept(cls.collector_visitor) + + @classmethod + def collect_definitions(cls, synapse, syn_info): + # variables + var_collector_visitor = ASTVariableCollectorVisitor() + synapse.accept(var_collector_visitor) + syn_info["States"] = var_collector_visitor.all_states + syn_info["Parameters"] = var_collector_visitor.all_parameters + syn_info["Internals"] = var_collector_visitor.all_internals + + # ODEs + ode_collector_visitor = ASTODEEquationCollectorVisitor() + synapse.accept(ode_collector_visitor) + syn_info["ODEs"] = ode_collector_visitor.all_ode_equations + + # inlines + inline_collector_visitor = ASTInlineEquationCollectorVisitor() + synapse.accept(inline_collector_visitor) + syn_info["Inlines"] = inline_collector_visitor.all_inlines + + # functions + function_collector_visitor = 
ASTFunctionCollectorVisitor() + synapse.accept(function_collector_visitor) + syn_info["Functions"] = function_collector_visitor.all_functions + + return syn_info + + @classmethod + def collect_on_receive_blocks(cls, synapse, syn_info, pre_port, post_port): + pre_spike_collector_visitor = ASTOnReceiveBlockVisitor(pre_port) + synapse.accept(pre_spike_collector_visitor) + syn_info["PreSpikeFunction"] = pre_spike_collector_visitor.on_receive_block + + post_spike_collector_visitor = ASTOnReceiveBlockVisitor(post_port) + synapse.accept(post_spike_collector_visitor) + syn_info["PostSpikeFunction"] = post_spike_collector_visitor.on_receive_block + + return syn_info + + @classmethod + def collect_update_block(cls, synapse, syn_info): + update_block_collector_visitor = ASTUpdateBlockVisitor() + synapse.accept(update_block_collector_visitor) + syn_info["UpdateBlock"] = update_block_collector_visitor.update_block + return syn_info + + @classmethod + def collect_ports(cls, synapse, syn_info): + port_collector_visitor = ASTPortVisitor() + synapse.accept(port_collector_visitor) + syn_info["SpikingPorts"] = port_collector_visitor.spiking_ports + syn_info["ContinuousPorts"] = port_collector_visitor.continuous_ports + return syn_info + + @classmethod + def collect_potential_dependencies(cls, synapse, syn_info): + non_dec_asmt_visitor = ASTNonDeclaringAssignmentVisitor() + synapse.accept(non_dec_asmt_visitor) + + potential_dependencies = copy.deepcopy(syn_info["States"]) + for state in syn_info["States"]: + for assignment in non_dec_asmt_visitor.non_declaring_assignments: + if state == assignment.get_variable().get_name(): + if state in potential_dependencies: + del potential_dependencies[state] + + syn_info["PotentialDependencies"] = potential_dependencies + return syn_info -class ASTSynapseInformationCollector(ASTVisitor): - """ - for each inline expression inside the equations block, - collect all synapse relevant information + @classmethod + def 
extend_variables_with_initialisations(cls, synapse, syn_info): + """collects initialization expressions for all variables and parameters contained in syn_info""" + var_init_visitor = VariableInitializationVisitor(syn_info) + synapse.accept(var_init_visitor) + syn_info["States"] = var_init_visitor.states + syn_info["Parameters"] = var_init_visitor.parameters + syn_info["Internals"] = var_init_visitor.internals + + return syn_info - """ + @classmethod + def extend_variable_list_name_based_restricted(cls, extended_list, appending_list, restrictor_list): + """go through appending_list and append every variable that is not in restrictor_list to extended_list for the + purpose of not re-searching the same variable""" + for app_item in appending_list: + appendable = True + for rest_item in restrictor_list: + if rest_item.name == app_item.name: + appendable = False + break + if appendable: + extended_list.append(app_item) + + return extended_list + + @classmethod + def extend_function_call_list_name_based_restricted(cls, extended_list, appending_list, restrictor_list): + """go through appending_list and append every variable that is not in restrictor_list to extended_list for the + purpose of not re-searching the same function""" + for app_item in appending_list: + appendable = True + for rest_item in restrictor_list: + if rest_item.callee_name == app_item.callee_name: + appendable = False + break + if appendable: + extended_list.append(app_item) + + return extended_list + @classmethod + def collect_mechanism_related_definitions(cls, neuron, syn_info): + """Collects all parts of the nestml code the root expressions previously collected depend on. 
search + is cut at other mechanisms root expressions""" + from pynestml.meta_model.ast_inline_expression import ASTInlineExpression + from pynestml.meta_model.ast_ode_equation import ASTOdeEquation + + for mechanism_name, mechanism_info in syn_info.items(): + variable_collector = ASTVariableCollectorVisitor() + neuron.accept(variable_collector) + global_states = variable_collector.all_states + global_parameters = variable_collector.all_parameters + global_internals = variable_collector.all_internals + + function_collector = ASTFunctionCollectorVisitor() + neuron.accept(function_collector) + global_functions = function_collector.all_functions + + inline_collector = ASTInlineEquationCollectorVisitor() + neuron.accept(inline_collector) + global_inlines = inline_collector.all_inlines + + ode_collector = ASTODEEquationCollectorVisitor() + neuron.accept(ode_collector) + global_odes = ode_collector.all_ode_equations + + kernel_collector = ASTKernelCollectorVisitor() + neuron.accept(kernel_collector) + global_kernels = kernel_collector.all_kernels + + continuous_input_collector = ASTContinuousInputDeclarationVisitor() + neuron.accept(continuous_input_collector) + global_continuous_inputs = continuous_input_collector.ports + + mechanism_states = list() + mechanism_parameters = list() + mechanism_internals = list() + mechanism_functions = list() + mechanism_inlines = list() + mechanism_odes = list() + synapse_kernels = list() + mechanism_continuous_inputs = list() + mechanism_dependencies = defaultdict() + mechanism_dependencies["concentrations"] = list() + mechanism_dependencies["channels"] = list() + mechanism_dependencies["receptors"] = list() + mechanism_dependencies["continuous"] = list() + + mechanism_inlines.append(syn_info[mechanism_name]["root_expression"]) + + search_variables = list() + search_functions = list() + + found_variables = list() + found_functions = list() + + local_variable_collector = ASTVariableCollectorVisitor() + 
mechanism_inlines[0].accept(local_variable_collector) + search_variables = local_variable_collector.all_variables + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + mechanism_inlines[0].accept(local_function_call_collector) + search_functions = local_function_call_collector.all_function_calls + + while len(search_functions) > 0 or len(search_variables) > 0: + if len(search_functions) > 0: + function_call = search_functions[0] + for function in global_functions: + if function.name == function_call.callee_name: + mechanism_functions.append(function) + found_functions.append(function_call) + + local_variable_collector = ASTVariableCollectorVisitor() + function.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + function.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted(search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + # IMPLEMENT CATCH NONDEFINED!!! 
+ search_functions.remove(function_call) + + elif len(search_variables) > 0: + variable = search_variables[0] + if not variable.name == "v_comp": + is_dependency = False + for inline in global_inlines: + if variable.name == inline.variable_name: + if isinstance(inline.get_decorators(), list): + if "mechanism" in [e.namespace for e in inline.get_decorators()]: + is_dependency = True + if not (isinstance(mechanism_info["root_expression"], ASTInlineExpression) and inline.variable_name == mechanism_info["root_expression"].variable_name): + if "channel" in [e.name for e in inline.get_decorators()]: + if not inline.variable_name in [i.variable_name for i in + mechanism_dependencies["channels"]]: + mechanism_dependencies["channels"].append(inline) + if "receptor" in [e.name for e in inline.get_decorators()]: + if not inline.variable_name in [i.variable_name for i in + mechanism_dependencies["receptors"]]: + mechanism_dependencies["receptors"].append(inline) + if "continuous" in [e.name for e in inline.get_decorators()]: + if not inline.variable_name in [i.variable_name for i in + mechanism_dependencies["continuous"]]: + mechanism_dependencies["continuous"].append(inline) + + if not is_dependency: + mechanism_inlines.append(inline) + + local_variable_collector = ASTVariableCollectorVisitor() + inline.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + inline.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted( + search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + + for ode in global_odes: + if variable.name == ode.lhs.name: + if isinstance(ode.get_decorators(), list): + if "mechanism" in [e.namespace for e in ode.get_decorators()]: + 
is_dependency = True + if not (isinstance(mechanism_info["root_expression"], ASTOdeEquation) and ode.lhs.name == mechanism_info["root_expression"].lhs.name): + if "concentration" in [e.name for e in ode.get_decorators()]: + if not ode.lhs.name in [o.lhs.name for o in + mechanism_dependencies["concentrations"]]: + mechanism_dependencies["concentrations"].append(ode) + + if not is_dependency: + mechanism_odes.append(ode) + + local_variable_collector = ASTVariableCollectorVisitor() + ode.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + ode.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted( + search_functions, + local_function_call_collector.all_function_calls, + search_functions + found_functions) + + for state in global_states: + if variable.name == state.name and not is_dependency: + mechanism_states.append(state) + + for parameter in global_parameters: + if variable.name == parameter.name: + mechanism_parameters.append(parameter) + + for internal in global_internals: + if variable.name == internal.name: + mechanism_internals.append(internal) + + for kernel in global_kernels: + if variable.name == kernel.get_variables()[0].name: + synapse_kernels.append(kernel) + + local_variable_collector = ASTVariableCollectorVisitor() + kernel.accept(local_variable_collector) + search_variables = cls.extend_variable_list_name_based_restricted(search_variables, + local_variable_collector.all_variables, + search_variables + found_variables) + + local_function_call_collector = ASTFunctionCallCollectorVisitor() + kernel.accept(local_function_call_collector) + search_functions = cls.extend_function_call_list_name_based_restricted(search_functions, + local_function_call_collector.all_function_calls, + 
search_functions + found_functions) + + for input in global_continuous_inputs: + if variable.name == input.name: + mechanism_continuous_inputs.append(input) + + search_variables.remove(variable) + found_variables.append(variable) + # IMPLEMENT CATCH NONDEFINED!!! + + syn_info[mechanism_name]["States"] = mechanism_states + syn_info[mechanism_name]["Parameters"] = mechanism_parameters + syn_info[mechanism_name]["Internals"] = mechanism_internals + syn_info[mechanism_name]["Functions"] = mechanism_functions + syn_info[mechanism_name]["SecondaryInlineExpressions"] = mechanism_inlines + syn_info[mechanism_name]["ODEs"] = mechanism_odes + syn_info[mechanism_name]["Continuous"] = mechanism_continuous_inputs + syn_info[mechanism_name]["Dependencies"] = mechanism_dependencies + + return syn_info + +class ASTKernelInformationCollectorVisitor(ASTVisitor): def __init__(self): - super(ASTSynapseInformationCollector, self).__init__() + super(ASTKernelInformationCollectorVisitor, self).__init__() # various dicts to store collected information self.kernel_name_to_kernel = defaultdict() @@ -94,7 +414,8 @@ def get_kernel_function_calls(self, kernel: ASTKernel): def get_inline_function_calls(self, inline: ASTInlineExpression): return self.inline_expression_to_function_calls[inline] - def get_variable_names_of_synapse(self, synapse_inline: ASTInlineExpression, exclude_names: set = set(), exclude_ignorable=True) -> set: + def get_variable_names_of_synapse(self, synapse_inline: ASTInlineExpression, exclude_names: set = set(), + exclude_ignorable=True) -> set: """extracts all variables specific to a single synapse (which is defined by the inline expression containing kernels) independently of what block they are declared in @@ -159,7 +480,8 @@ def get_synapse_specific_internal_declarations(self, synapse_inline: ASTInlineEx dereferenced = defaultdict() for potential_internals_name in synapse_variable_names: if potential_internals_name in self.internal_var_name_to_declaration: - 
dereferenced[potential_internals_name] = self.internal_var_name_to_declaration[potential_internals_name] + dereferenced[potential_internals_name] = self.internal_var_name_to_declaration[ + potential_internals_name] return dereferenced def get_synapse_specific_state_declarations(self, synapse_inline: ASTInlineExpression) -> defaultdict: @@ -189,6 +511,11 @@ def get_synapse_specific_parameter_declarations(self, synapse_inline: ASTInlineE def get_extracted_kernel_args(self, inline_expression: ASTInlineExpression) -> set: return self.inline_expression_to_kernel_args[inline_expression] + def get_extracted_kernel_args_by_name(self, inline_name: str) -> set: + inline_expression = [inline for inline in self.inline_expression_to_kernel_args.keys() if inline.get_variable_name() == inline_name] + + return self.inline_expression_to_kernel_args[inline_expression[0]] + def get_basic_kernel_variable_names(self, synapse_inline): """ for every occurence of convolve(port, spikes) generate "port__X__spikes" variable @@ -237,9 +564,8 @@ def visit_function_call(self, node): kernel, spikes = node.get_args() kernel_var = kernel.get_variables()[0] spikes_var = spikes.get_variables()[0] - if "mechanism::receptor" in [(e.namespace + "::" + e.name) for e in self.current_inline_expression.get_decorators()]: - self.inline_expression_to_kernel_args[self.current_inline_expression].add( - (kernel_var, spikes_var)) + self.inline_expression_to_kernel_args[self.current_inline_expression].add( + (kernel_var, spikes_var)) else: self.inline_expression_to_function_calls[self.current_inline_expression].add( node) @@ -342,8 +668,305 @@ def endvisit_expression(self, node): # this method was copied over from ast_transformer # in order to avoid a circular dependency @staticmethod - def construct_kernel_X_spike_buf_name(kernel_var_name: str, spike_input_port, order: int, diff_order_symbol="__d"): + def construct_kernel_X_spike_buf_name(kernel_var_name: str, spike_input_port, order: int, + 
diff_order_symbol="__d"): assert type(kernel_var_name) is str assert type(order) is int assert type(diff_order_symbol) is str - return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + str(spike_input_port) + diff_order_symbol * order + return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + str( + spike_input_port) + diff_order_symbol * order + + +class ASTMechanismInformationCollectorVisitor(ASTVisitor): + + def __init__(self): + super(ASTMechanismInformationCollectorVisitor, self).__init__() + self.inEquationsBlock = False + self.inlinesInEquationsBlock = list() + self.odes = list() + + def visit_equations_block(self, node): + self.inEquationsBlock = True + + def endvisit_equations_block(self, node): + self.inEquationsBlock = False + + def visit_inline_expression(self, node): + if self.inEquationsBlock: + self.inlinesInEquationsBlock.append(node) + + def visit_ode_equation(self, node): + self.odes.append(node) + + +# Helper collectors: +class VariableInitializationVisitor(ASTVisitor): + def __init__(self, channel_info): + super(VariableInitializationVisitor, self).__init__() + self.inside_variable = False + self.inside_declaration = False + self.inside_parameter_block = False + self.inside_state_block = False + self.inside_internal_block = False + self.current_declaration = None + self.states = defaultdict() + self.parameters = defaultdict() + self.internals = defaultdict() + self.channel_info = channel_info + + def visit_declaration(self, node): + self.inside_declaration = True + self.current_declaration = node + + def endvisit_declaration(self, node): + self.inside_declaration = False + self.current_declaration = None + + def visit_block_with_variables(self, node): + if node.is_state: + self.inside_state_block = True + if node.is_parameters: + self.inside_parameter_block = True + if node.is_internals: + self.inside_internal_block = True + + def endvisit_block_with_variables(self, node): + self.inside_state_block = False + self.inside_parameter_block = 
False + self.inside_internal_block = False + + def visit_variable(self, node): + self.inside_variable = True + if self.inside_state_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["States"]): + self.states[node.name] = defaultdict() + self.states[node.name]["ASTVariable"] = node.clone() + self.states[node.name]["rhs_expression"] = self.current_declaration.get_expression() + + if self.inside_parameter_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["Parameters"]): + self.parameters[node.name] = defaultdict() + self.parameters[node.name]["ASTVariable"] = node.clone() + self.parameters[node.name]["rhs_expression"] = self.current_declaration.get_expression() + + if self.inside_internal_block and self.inside_declaration: + if any(node.name == variable.name for variable in self.channel_info["Internals"]): + self.internals[node.name] = defaultdict() + self.internals[node.name]["ASTVariable"] = node.clone() + self.internals[node.name]["rhs_expression"] = self.current_declaration.get_expression() + + def endvisit_variable(self, node): + self.inside_variable = False + + +class ASTODEEquationCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTODEEquationCollectorVisitor, self).__init__() + self.inside_ode_expression = False + self.all_ode_equations = list() + + def visit_ode_equation(self, node): + self.inside_ode_expression = True + self.all_ode_equations.append(node.clone()) + + def endvisit_ode_equation(self, node): + self.inside_ode_expression = False + + +class ASTVariableCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTVariableCollectorVisitor, self).__init__() + self.inside_variable = False + self.inside_block_with_variables = False + self.all_states = list() + self.all_parameters = list() + self.all_internals = list() + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_internals_block = False + 
self.all_variables = list() + + def visit_block_with_variables(self, node): + self.inside_block_with_variables = True + if node.is_state: + self.inside_states_block = True + if node.is_parameters: + self.inside_parameters_block = True + if node.is_internals: + self.inside_internals_block = True + + def endvisit_block_with_variables(self, node): + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_internals_block = False + self.inside_block_with_variables = False + + def visit_variable(self, node): + self.inside_variable = True + if not (node.name == "v_comp" or node.name in PredefinedUnits.get_units()): + self.all_variables.append(node.clone()) + if self.inside_states_block: + self.all_states.append(node.clone()) + if self.inside_parameters_block: + self.all_parameters.append(node.clone()) + if self.inside_internals_block: + self.all_internals.append(node.clone()) + + def endvisit_variable(self, node): + self.inside_variable = False + + +class ASTFunctionCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTFunctionCollectorVisitor, self).__init__() + self.inside_function = False + self.all_functions = list() + + def visit_function(self, node): + self.inside_function = True + self.all_functions.append(node.clone()) + + def endvisit_function(self, node): + self.inside_function = False + + +class ASTInlineEquationCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTInlineEquationCollectorVisitor, self).__init__() + self.inside_inline_expression = False + self.all_inlines = list() + + def visit_inline_expression(self, node): + self.inside_inline_expression = True + self.all_inlines.append(node.clone()) + + def endvisit_inline_expression(self, node): + self.inside_inline_expression = False + + +class ASTFunctionCallCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTFunctionCallCollectorVisitor, self).__init__() + self.inside_function_call = False + self.all_function_calls = list() + + def 
visit_function_call(self, node): + self.inside_function_call = True + self.all_function_calls.append(node.clone()) + + def endvisit_function_call(self, node): + self.inside_function_call = False + + +class ASTKernelCollectorVisitor(ASTVisitor): + def __init__(self): + super(ASTKernelCollectorVisitor, self).__init__() + self.inside_kernel = False + self.all_kernels = list() + + def visit_kernel(self, node): + self.inside_kernel = True + self.all_kernels.append(node.clone()) + + def endvisit_kernel(self, node): + self.inside_kernel = False + + +class ASTContinuousInputDeclarationVisitor(ASTVisitor): + def __init__(self): + super(ASTContinuousInputDeclarationVisitor, self).__init__() + self.inside_port = False + self.current_port = None + self.ports = list() + + def visit_input_port(self, node): + self.inside_port = True + self.current_port = node + if self.current_port.is_continuous(): + self.ports.append(node.clone()) + + def endvisit_input_port(self, node): + self.inside_port = False + + +class ASTOnReceiveBlockVisitor(ASTVisitor): + def __init__(self, port_name): + super(ASTOnReceiveBlockVisitor, self).__init__() + self.inside_on_receive = False + self.port_name = port_name + self.on_receive_block = None + + def visit_on_receive_block(self, node): + self.inside_on_receive = True + if node.port_name in self.port_name: + self.on_receive_block = node.clone() + + def endvisit_on_receive_block(self, node): + self.inside_on_receive = False + + +class ASTUpdateBlockVisitor(ASTVisitor): + def __init__(self): + super(ASTUpdateBlockVisitor, self).__init__() + self.inside_update_block = False + self.update_block = None + + def visit_update_block(self, node): + self.inside_update_block = True + self.update_block = node.clone() + + def endvisit_update_block(self, node): + self.inside_update_block = False + +class ASTPortVisitor(ASTVisitor): + def __init__(self): + super(ASTPortVisitor, self).__init__() + self.inside_port = False + self.spiking_ports = list() + 
self.continuous_ports = list() + + def visit_input_port(self, node): + self.inside_port = True + if node.is_spike(): + self.spiking_ports.append(node.clone()) + if node.is_continuous(): + self.continuous_ports.append(node.clone()) + + def endvisit_input_port(self, node): + self.inside_port = False + +class ASTNonDeclaringAssignmentVisitor(ASTVisitor): + def __init__(self): + super(ASTNonDeclaringAssignmentVisitor, self).__init__() + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_internals_block = False + self.inside_assignment = False + self.non_declaring_assignments = list() + + def visit_states_block(self, node): + self.inside_states_block = True + + def endvisit_states_block(self, node): + self.inside_states_block = False + + def visit_parameters_block(self, node): + self.inside_parameters_block = True + + def endvisit_parameters_block(self, node): + self.inside_parameters_block = False + + def visit_internals_block(self, node): + self.inside_internals_block = True + + def endvisit_internals_block(self, node): + self.inside_internals_block = False + + def visit_assignment(self, node): + self.inside_assignment = True + if not self.inside_parameters_block or not self.inside_internals_block or self.inside_states_block: + self.non_declaring_assignments.append(node.clone()) + + def endvisit_assignment(self, node): + self.inside_assignment = False + diff --git a/pynestml/utils/ast_vector_parameter_setter_and_printer.py b/pynestml/utils/ast_vector_parameter_setter_and_printer.py new file mode 100644 index 000000000..2df7d817e --- /dev/null +++ b/pynestml/utils/ast_vector_parameter_setter_and_printer.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# +# ast_vector_parameter_setter_and_printer.py +# +# This file is part of NEST. 
+# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from pynestml.codegeneration.printers.ast_printer import ASTPrinter +from pynestml.codegeneration.printers.nest_variable_printer import NESTVariablePrinter + +from pynestml.utils.model_parser import ModelParser +from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor +from pynestml.symbol_table.scope import Scope, ScopeType, Symbol, SymbolKind +from pynestml.symbols.variable_symbol import VariableSymbol + +class ASTVectorParameterSetterAndPrinter(ASTPrinter): + def __init__(self): + super(ASTVectorParameterSetterAndPrinter, self).__init__() + self.inside_variable = False + self.vector_parameter = "" + self.printer = None + self.model = None + + def set_vector_parameter(self, node, vector_parameter=None): + self.vector_parameter = vector_parameter + node.accept(self) + + def print(self, node): + assert isinstance(self.printer._simple_expression_printer._variable_printer, NESTVariablePrinter) + + self.printer._simple_expression_printer._variable_printer.cpp_variable_suffix = "" + + if self.vector_parameter: + self.printer._simple_expression_printer._variable_printer.cpp_variable_suffix = "[" + self.vector_parameter + "]" + + text = self.printer.print(node) + + self.printer._simple_expression_printer._variable_printer.cpp_variable_suffix = "" + + return text diff --git 
a/pynestml/utils/ast_vector_parameter_setter_and_printer_factory.py b/pynestml/utils/ast_vector_parameter_setter_and_printer_factory.py new file mode 100644 index 000000000..13c3b08d5 --- /dev/null +++ b/pynestml/utils/ast_vector_parameter_setter_and_printer_factory.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# +# ast_vector_parameter_setter_and_printer_factory.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . 
+ +from pynestml.utils.ast_vector_parameter_setter_and_printer import ASTVectorParameterSetterAndPrinter +from pynestml.visitors.ast_visitor import ASTVisitor + +from pynestml.utils.model_parser import ModelParser +from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor +from pynestml.symbol_table.scope import Scope, ScopeType, Symbol, SymbolKind +from pynestml.symbols.variable_symbol import VariableSymbol + + +class ASTVectorParameterSetterAndPrinterFactory: + + def __init__(self, model, printer): + self.printer = printer + self.model = model + + def create_ast_vector_parameter_setter_and_printer(self, vector_parameter=None): + my_printer = ASTVectorParameterSetterAndPrinter() + my_printer.printer = self.printer + my_printer.model = self.model + my_printer.vector_parameter = vector_parameter + return my_printer diff --git a/pynestml/utils/con_in_info_enricher.py b/pynestml/utils/con_in_info_enricher.py new file mode 100644 index 000000000..fa8228ac8 --- /dev/null +++ b/pynestml/utils/con_in_info_enricher.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# +# con_in_info_enricher.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . 
+ +from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor +from pynestml.utils.mechs_info_enricher import MechsInfoEnricher +from pynestml.utils.model_parser import ModelParser + +import sympy + + +class ConInInfoEnricher(MechsInfoEnricher): + """Class extends MechsInfoEnricher by the computation of the inline derivative. This hasn't been done in the + channel processing because it would cause a circular dependency through the coco checks used by the ModelParser + which we need to use.""" + def __init__(self, params): + super(MechsInfoEnricher, self).__init__(params) + + @classmethod + def enrich_mechanism_specific(cls, neuron, mechs_info): + mechs_info = cls.compute_expression_derivative(mechs_info) + return mechs_info + + @classmethod + def compute_expression_derivative(cls, chan_info): + for ion_channel_name, ion_channel_info in chan_info.items(): + inline_expression = chan_info[ion_channel_name]["root_expression"] + expr_str = str(inline_expression.get_expression()) + sympy_expr = sympy.parsing.sympy_parser.parse_expr(expr_str) + sympy_expr = sympy.diff(sympy_expr, "v_comp") + + ast_expression_d = ModelParser.parse_expression(str(sympy_expr)) + # copy scope of the original inline_expression into the the derivative + ast_expression_d.update_scope(inline_expression.get_scope()) + ast_expression_d.accept(ASTSymbolTableVisitor()) + + chan_info[ion_channel_name]["inline_derivative"] = ast_expression_d + + return chan_info diff --git a/pynestml/utils/conc_info_enricher.py b/pynestml/utils/conc_info_enricher.py index e4ed0507d..f62d4989e 100644 --- a/pynestml/utils/conc_info_enricher.py +++ b/pynestml/utils/conc_info_enricher.py @@ -23,6 +23,7 @@ class ConcInfoEnricher(MechsInfoEnricher): - """Just created for consistency. No more than the base-class enriching needs to be done""" + """Just created for consistency with the rest of the mechanism generation process. 
No more than the base-class + enriching needs to be done""" def __init__(self, params): super(MechsInfoEnricher, self).__init__(params) diff --git a/pynestml/utils/continuous_input_processing.py b/pynestml/utils/continuous_input_processing.py new file mode 100644 index 000000000..980217fc1 --- /dev/null +++ b/pynestml/utils/continuous_input_processing.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# +# continuous_input_processing.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . 
+ +import copy + +from pynestml.utils.mechanism_processing import MechanismProcessing + +from collections import defaultdict + + +class ContinuousInputProcessing(MechanismProcessing): + mechType = "continuous_input" + + def __init__(self, params): + super(MechanismProcessing, self).__init__(params) + + @classmethod + def collect_information_for_specific_mech_types(cls, neuron, mechs_info): + for continuous_name, continuous_info in mechs_info.items(): + continuous = defaultdict() + for port in continuous_info["Continuous"]: + continuous[port.name] = copy.deepcopy(port) + mechs_info[continuous_name]["Continuous"] = continuous + + return mechs_info diff --git a/pynestml/utils/global_info_enricher.py b/pynestml/utils/global_info_enricher.py new file mode 100644 index 000000000..c06cef2f4 --- /dev/null +++ b/pynestml/utils/global_info_enricher.py @@ -0,0 +1,264 @@ +# -*- coding: utf-8 -*- +# +# global_info_enricher.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . 
+ +from collections import defaultdict + +from executing.executing import node_linenos + +from pynestml.meta_model.ast_model import ASTModel +from pynestml.visitors.ast_parent_visitor import ASTParentVisitor +from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor +from pynestml.utils.ast_utils import ASTUtils +from pynestml.visitors.ast_visitor import ASTVisitor +from pynestml.utils.model_parser import ModelParser +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.symbols.symbol import SymbolKind + +from collections import defaultdict + + +class GlobalInfoEnricher: + """ + Adds information collection that can't be done in the processing class since that is used in the cocos. + Here we use the ModelParser which would lead to a cyclic dependency. + + Additionally, we require information about the paired neurons mechanism to confirm what dependencies are actually existent in the neuron. + """ + + def __init__(self): + pass + + @classmethod + def enrich_with_additional_info(cls, neuron: ASTModel, global_info: dict): + global_info = cls.transform_ode_solutions(neuron, global_info) + global_info = cls.extract_infunction_declarations(global_info) + #global_info = cls.substituteNoneWithEmptyBlocks(global_info) + + return global_info + + @classmethod + def transform_ode_solutions(cls, neuron, global_info): + for ode_var_name, ode_info in global_info["ODEs"].items(): + global_info["ODEs"][ode_var_name]["transformed_solutions"] = list() + + for ode_solution_index in range(len(ode_info["ode_toolbox_output"])): + solution_transformed = defaultdict() + solution_transformed["states"] = defaultdict() + solution_transformed["propagators"] = defaultdict() + + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index][ + "initial_values"].items(): + variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) + + expression = 
ModelParser.parse_expression(rhs_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as neurons have been + # defined to get here + expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) + + update_expr_str = ode_info["ode_toolbox_output"][ode_solution_index]["update_expressions"][ + variable_name] + update_expr_ast = ModelParser.parse_expression( + update_expr_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as differential equations + # must have been defined to get here + update_expr_ast.update_scope( + neuron.get_equations_blocks()[0].get_scope()) + update_expr_ast.accept(ASTSymbolTableVisitor()) + + solution_transformed["states"][variable_name] = { + "ASTVariable": variable, + "init_expression": expression, + "update_expression": update_expr_ast, + } + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index][ + "propagators"].items(): + prop_variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) + if prop_variable is None: + ASTUtils.add_declarations_to_internals( + neuron, ode_info["ode_toolbox_output"][ode_solution_index]["propagators"]) + prop_variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol( + variable_name, + SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression(rhs_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as neurons have been + # defined to get here + expression.update_scope( + neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) + + solution_transformed["propagators"][variable_name] = { + "ASTVariable": prop_variable, "init_expression": expression, } + expression_variable_collector = ASTEnricherInfoCollectorVisitor() + expression.accept(expression_variable_collector) + + 
neuron_internal_declaration_collector = ASTEnricherInfoCollectorVisitor() + neuron.accept(neuron_internal_declaration_collector) + + for variable in expression_variable_collector.all_variables: + for internal_declaration in neuron_internal_declaration_collector.internal_declarations: + if variable.get_name() == internal_declaration.get_variables()[0].get_name() \ + and internal_declaration.get_expression().is_function_call() \ + and internal_declaration.get_expression().get_function_call().callee_name == \ + PredefinedFunctions.TIME_RESOLUTION: + global_info["time_resolution_var"] = variable + + global_info["ODEs"][ode_var_name]["transformed_solutions"].append(solution_transformed) + + neuron.accept(ASTParentVisitor()) + + return global_info + + @classmethod + def extract_infunction_declarations(cls, global_info): + declaration_visitor = ASTDeclarationCollectorAndUniqueRenamerVisitor() + if "SelfSpikesFunction" in global_info and global_info["SelfSpikesFunction"] is not None: + self_spike_function = global_info["SelfSpikesFunction"] + self_spike_function.accept(declaration_visitor) + if "UpdateBlock" in global_info and global_info["UpdateBlock"] is not None: + update_block = global_info["UpdateBlock"] + update_block.accept(declaration_visitor) + + declaration_vars = list() + for decl in declaration_visitor.declarations: + for var in decl.get_variables(): + declaration_vars.append(var.get_name()) + + global_info["InFunctionDeclarationsVars"] = declaration_visitor.declarations + return global_info + + @classmethod + def substituteNoneWithEmptyBlocks(cls, global_info): + if (not "UpdateBlock" in global_info) or (global_info["UpdateBlock"] is None): + empty = ModelParser.parse_block("") + global_info["UpdateBlock"] = empty.clone() + if (not "SelfSpikesFunction" in global_info) or (global_info["SelfSpikesFunction"] is None): + empty = ModelParser.parse_block("") + global_info["SelfSpikesFunction"] = empty.clone() + + return global_info + + + +class 
ASTEnricherInfoCollectorVisitor(ASTVisitor): + + def __init__(self): + super(ASTEnricherInfoCollectorVisitor, self).__init__() + self.inside_variable = False + self.inside_block_with_variables = False + self.all_states = list() + self.all_parameters = list() + self.inside_states_block = False + self.inside_parameters_block = False + self.all_variables = list() + self.inside_internals_block = False + self.inside_declaration = False + self.internal_declarations = list() + + def visit_block_with_variables(self, node): + self.inside_block_with_variables = True + if node.is_state: + self.inside_states_block = True + if node.is_parameters: + self.inside_parameters_block = True + if node.is_internals: + self.inside_internals_block = True + + def endvisit_block_with_variables(self, node): + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_block_with_variables = False + self.inside_internals_block = False + + def visit_variable(self, node): + self.inside_variable = True + self.all_variables.append(node.clone()) + if self.inside_states_block: + self.all_states.append(node.clone()) + if self.inside_parameters_block: + self.all_parameters.append(node.clone()) + + def endvisit_variable(self, node): + self.inside_variable = False + + def visit_declaration(self, node): + self.inside_declaration = True + if self.inside_internals_block: + self.internal_declarations.append(node) + + def endvisit_declaration(self, node): + self.inside_declaration = False + + +class ASTDeclarationCollectorAndUniqueRenamerVisitor(ASTVisitor): + def __init__(self): + super(ASTDeclarationCollectorAndUniqueRenamerVisitor, self).__init__() + self.declarations = list() + self.variable_names = dict() + self.inside_declaration = False + self.inside_block = False + self.current_block = None + + def visit_block(self, node): + self.inside_block = True + self.current_block = node + + def endvisit_block(self, node): + self.inside_block = False + self.current_block = None + + 
def visit_declaration(self, node): + self.inside_declaration = True + for variable in node.get_variables(): + if variable.get_name() in self.variable_names: + self.variable_names[variable.get_name()] += 1 + else: + self.variable_names[variable.get_name()] = 0 + new_name = variable.get_name() + '_' + str(self.variable_names[variable.get_name()]) + name_replacer = ASTVariableNameReplacerVisitor(variable.get_name(), new_name) + self.current_block.accept(name_replacer) + node.accept(ASTSymbolTableVisitor()) + self.declarations.append(node.clone()) + + def endvisit_declaration(self, node): + self.inside_declaration = False + + +class ASTVariableNameReplacerVisitor(ASTVisitor): + def __init__(self, old_name, new_name): + super(ASTVariableNameReplacerVisitor, self).__init__() + self.inside_variable = False + self.new_name = new_name + self.old_name = old_name + + def visit_variable(self, node): + self.inside_variable = True + if node.get_name() == self.old_name: + node.set_name(self.new_name) + + def endvisit_variable(self, node): + self.inside_variable = False \ No newline at end of file diff --git a/pynestml/utils/global_processing.py b/pynestml/utils/global_processing.py new file mode 100644 index 000000000..3f1055d05 --- /dev/null +++ b/pynestml/utils/global_processing.py @@ -0,0 +1,188 @@ +# -*- coding: utf-8 -*- +# +# global_processing.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. 
+# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from collections import defaultdict + +import copy + +from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter +from pynestml.codegeneration.printers.constant_printer import ConstantPrinter +from pynestml.codegeneration.printers.ode_toolbox_expression_printer import ODEToolboxExpressionPrinter +from pynestml.codegeneration.printers.ode_toolbox_function_call_printer import ODEToolboxFunctionCallPrinter +from pynestml.codegeneration.printers.ode_toolbox_variable_printer import ODEToolboxVariablePrinter +from pynestml.codegeneration.printers.unitless_cpp_simple_expression_printer import UnitlessCppSimpleExpressionPrinter +from pynestml.frontend.frontend_configuration import FrontendConfiguration +from pynestml.meta_model.ast_expression import ASTExpression +from pynestml.meta_model.ast_model import ASTModel +from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression +from pynestml.utils.ast_global_information_collector import ASTGlobalInformationCollector +from pynestml.utils.ast_utils import ASTUtils + +from odetoolbox import analysis + +class GlobalProcessing: + """Manages the collection of basic information necesary for all types of mechanisms and uses the + collect_information_for_specific_mech_types interface that needs to be implemented by the specific mechanism type + processing classes""" + + # used to keep track of whenever check_co_co was already called + # see inside check_co_co + first_time_run = defaultdict(lambda: True) + # stores neuron from the first call of check_co_co + global_info = defaultdict() + + # ODE-toolbox printers + _constant_printer = ConstantPrinter() + _ode_toolbox_variable_printer = ODEToolboxVariablePrinter(None) + _ode_toolbox_function_call_printer = ODEToolboxFunctionCallPrinter(None) + _ode_toolbox_printer = ODEToolboxExpressionPrinter( + 
simple_expression_printer=UnitlessCppSimpleExpressionPrinter( + variable_printer=_ode_toolbox_variable_printer, + constant_printer=_constant_printer, + function_call_printer=_ode_toolbox_function_call_printer)) + + _ode_toolbox_variable_printer._expression_printer = _ode_toolbox_printer + _ode_toolbox_function_call_printer._expression_printer = _ode_toolbox_printer + + @classmethod + def prepare_equations_for_ode_toolbox(cls, synapse, syn_info): + """Transforms the collected ode equations to the required input format of ode-toolbox and adds it to the + syn_info dictionary""" + + mechanism_odes = defaultdict() + for ode in syn_info["ODEs"]: + nestml_printer = NESTMLPrinter() + ode_nestml_expression = nestml_printer.print_ode_equation(ode) + mechanism_odes[ode.lhs.name] = defaultdict() + mechanism_odes[ode.lhs.name]["ASTOdeEquation"] = ode + mechanism_odes[ode.lhs.name]["ODENestmlExpression"] = ode_nestml_expression + syn_info["ODEs"] = mechanism_odes + + for ode_variable_name, ode_info in syn_info["ODEs"].items(): + # Expression: + odetoolbox_indict = {"dynamics": []} + lhs = ASTUtils.to_ode_toolbox_name(ode_info["ASTOdeEquation"].get_lhs().get_complete_name()) + rhs = cls._ode_toolbox_printer.print(ode_info["ASTOdeEquation"].get_rhs()) + entry = {"expression": lhs + " = " + rhs, "initial_values": {}} + + # Initial values: + symbol_order = ode_info["ASTOdeEquation"].get_lhs().get_differential_order() + for order in range(symbol_order): + iv_symbol_name = ode_info["ASTOdeEquation"].get_lhs().get_name() + "'" * order + initial_value_expr = synapse.get_initial_value(iv_symbol_name) + entry["initial_values"][ + ASTUtils.to_ode_toolbox_name(iv_symbol_name)] = cls._ode_toolbox_printer.print( + initial_value_expr) + + odetoolbox_indict["dynamics"].append(entry) + syn_info["ODEs"][ode_variable_name]["ode_toolbox_input"] = odetoolbox_indict + + return syn_info + + @classmethod + def collect_raw_odetoolbox_output(cls, syn_info): + """calls ode-toolbox for each ode 
individually and collects the raw output""" + for ode_variable_name, ode_info in syn_info["ODEs"].items(): + solver_result = analysis(ode_info["ode_toolbox_input"], disable_stiffness_check=True) + syn_info["ODEs"][ode_variable_name]["ode_toolbox_output"] = solver_result + + return syn_info + + @classmethod + def ode_toolbox_processing(cls, neuron, global_info): + global_info = cls.prepare_equations_for_ode_toolbox(neuron, global_info) + global_info = cls.collect_raw_odetoolbox_output(global_info) + return global_info + + @classmethod + def get_global_info(cls, neuron): + """ + returns previously generated global_info + as a deep copy so it can't be changed externally + via object references + :param neuron: a single neuron instance. + """ + return copy.deepcopy(cls.global_info[neuron.get_name()]) + + @classmethod + def check_co_co(cls, neuron: ASTModel): + """ + Checks if mechanism conditions apply for the handed over neuron. + :param neuron: a single neuron instance. + """ + + # make sure we only run this a single time + # subsequent calls will be after AST has been transformed + # and there would be no kernels or inlines anymore + if cls.first_time_run[neuron.get_name()]: + # collect root expressions and initialize collector + info_collector = ASTGlobalInformationCollector(neuron) + + # collect and process all basic mechanism information + global_info = defaultdict() + + global_info = info_collector.collect_update_block(neuron, global_info) + global_info = info_collector.collect_self_spike_function(neuron, global_info) + + global_info = info_collector.collect_related_definitions(neuron, global_info) + global_info = info_collector.extend_variables_with_initialisations(neuron, global_info) + global_info = cls.ode_toolbox_processing(neuron, global_info) + + cls.global_info[neuron.get_name()] = copy.deepcopy(global_info) + cls.first_time_run[neuron.get_name()] = False + + @classmethod + def print_element(cls, name, element, rec_step): + message = "" + for indent in 
range(rec_step): + message += "----" + message += name + ": " + if isinstance(element, defaultdict): + message += "\n" + message += cls.print_dictionary(element, rec_step + 1) + else: + if hasattr(element, 'name'): + message += element.name + elif isinstance(element, str): + message += element + elif isinstance(element, dict): + message += "\n" + message += cls.print_dictionary(element, rec_step + 1) + elif isinstance(element, list): + for index in range(len(element)): + message += "\n" + message += cls.print_element(str(index), element[index], rec_step + 1) + elif isinstance(element, ASTExpression) or isinstance(element, ASTSimpleExpression): + message += cls._ode_toolbox_printer.print(element) + + message += "(" + type(element).__name__ + ")" + return message + + @classmethod + def print_dictionary(cls, dictionary, rec_step): + """ + Print the mechanisms info dictionaries. + """ + message = "" + for name, element in dictionary.items(): + message += cls.print_element(name, element, rec_step) + message += "\n" + return message \ No newline at end of file diff --git a/pynestml/utils/mechanism_processing.py b/pynestml/utils/mechanism_processing.py index e53c2d05a..336a42412 100644 --- a/pynestml/utils/mechanism_processing.py +++ b/pynestml/utils/mechanism_processing.py @@ -23,6 +23,7 @@ import copy +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter from pynestml.codegeneration.printers.constant_printer import ConstantPrinter from pynestml.codegeneration.printers.ode_toolbox_expression_printer import ODEToolboxExpressionPrinter @@ -148,7 +149,7 @@ def get_mechs_info(cls, neuron: ASTModel): return copy.deepcopy(cls.mechs_info[neuron][cls.mechType]) @classmethod - def check_co_co(cls, neuron: ASTModel): + def check_co_co(cls, neuron: ASTModel, global_info): """ Checks if mechanism conditions apply for the handed over neuron. :param neuron: a single neuron instance. 
@@ -163,7 +164,7 @@ def check_co_co(cls, neuron: ASTModel): mechs_info = info_collector.detect_mechs(cls.mechType) # collect and process all basic mechanism information - mechs_info = info_collector.collect_mechanism_related_definitions(neuron, mechs_info) + mechs_info = info_collector.collect_mechanism_related_definitions(neuron, mechs_info, global_info) mechs_info = info_collector.extend_variables_with_initialisations(neuron, mechs_info) mechs_info = cls.ode_toolbox_processing(neuron, mechs_info) @@ -187,6 +188,8 @@ def print_element(cls, name, element, rec_step): message += element.name elif isinstance(element, str): message += element + elif isinstance(element, bool): + message += str(element) elif isinstance(element, dict): message += "\n" message += cls.print_dictionary(element, rec_step + 1) @@ -196,6 +199,8 @@ def print_element(cls, name, element, rec_step): message += cls.print_element(str(index), element[index], rec_step + 1) elif isinstance(element, ASTExpression) or isinstance(element, ASTSimpleExpression): message += cls._ode_toolbox_printer.print(element) + elif isinstance(element, ASTInlineExpression): + message += cls._ode_toolbox_printer.print(element.get_expression()) message += "(" + type(element).__name__ + ")" return message diff --git a/pynestml/utils/mechs_info_enricher.py b/pynestml/utils/mechs_info_enricher.py index c5514bff8..456ece178 100644 --- a/pynestml/utils/mechs_info_enricher.py +++ b/pynestml/utils/mechs_info_enricher.py @@ -22,13 +22,13 @@ from collections import defaultdict from pynestml.meta_model.ast_model import ASTModel -from pynestml.symbols.predefined_functions import PredefinedFunctions -from pynestml.symbols.symbol import SymbolKind -from pynestml.utils.ast_utils import ASTUtils -from pynestml.utils.model_parser import ModelParser from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor +from pynestml.utils.ast_utils import ASTUtils from 
pynestml.visitors.ast_visitor import ASTVisitor +from pynestml.utils.model_parser import ModelParser +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.symbols.symbol import SymbolKind class MechsInfoEnricher: diff --git a/pynestml/utils/messages.py b/pynestml/utils/messages.py index f39d9ecdb..892e68088 100644 --- a/pynestml/utils/messages.py +++ b/pynestml/utils/messages.py @@ -1304,10 +1304,13 @@ def get_integrate_odes_wrong_arg(cls, arg: str): return MessageCode.INTEGRATE_ODES_WRONG_ARG, message @classmethod - def get_mechs_dictionary_info(cls, chan_info, syns_info, conc_info): + def get_mechs_dictionary_info(cls, chan_info, recs_info, conc_info, con_in_info, syns_info, global_info): message = "" message += "chan_info:\n" + chan_info + "\n" - message += "syns_info:\n" + syns_info + "\n" + message += "recs_info:\n" + recs_info + "\n" message += "conc_info:\n" + conc_info + "\n" + message += "con_in_info:\n" + con_in_info + "\n" + message += "syns_info:\n" + syns_info + "\n" + message += "global_info:\n" + global_info + "\n" return MessageCode.MECHS_DICTIONARY_INFO, message diff --git a/pynestml/utils/receptor_processing.py b/pynestml/utils/receptor_processing.py new file mode 100644 index 000000000..673c0eece --- /dev/null +++ b/pynestml/utils/receptor_processing.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- +# +# receptor_processing.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. 
class ReceptorProcessing(MechanismProcessing):
    """Processing of "receptor"-type mechanisms for compartmental models.

    Collects the kernels, spike input ports and convolutions referenced by each
    receptor inline expression and runs ode-toolbox on every convolution to
    obtain an analytic propagator solution.
    """

    mechType = "receptor"

    def __init__(self, params):
        # Fixed: the original called ``super(MechanismProcessing, self).__init__``,
        # which starts the MRO lookup *after* the direct base class and therefore
        # skips ``MechanismProcessing.__init__`` entirely (forwarding ``params``
        # to ``object.__init__``, which raises TypeError).
        super().__init__(params)

    @classmethod
    def collect_information_for_specific_mech_types(cls, neuron, mechs_info):
        """Framework hook: collect receptor-specific info for all mechanisms of this type.

        :param neuron: the neuron model being processed
        :param mechs_info: per-mechanism info dictionary (keyed by mechanism name)
        :return: the augmented ``mechs_info`` dictionary
        """
        mechs_info, add_info_collector = cls.collect_additional_base_infos(neuron, mechs_info)
        if len(mechs_info) > 0:
            # only do this if any synapses were found,
            # otherwise tests may fail
            mechs_info = cls.collect_and_check_inputs_per_synapse(mechs_info)

        mechs_info = cls.convolution_ode_toolbox_processing(neuron, mechs_info)

        return mechs_info

    @classmethod
    def collect_additional_base_infos(cls, neuron, syns_info):
        """
        Collect internals, kernels, inputs and convolutions associated with the synapse.

        :return: tuple of (augmented ``syns_info``, the information collector that visited the neuron)
        """
        info_collector = ASTReceptorInformationCollector()
        neuron.accept(info_collector)
        for synapse_name, synapse_info in syns_info.items():
            synapse_inline = syns_info[synapse_name]["root_expression"]
            syns_info[synapse_name][
                "internals_used_declared"] = info_collector.get_synapse_specific_internal_declarations(synapse_inline)
            syns_info[synapse_name]["total_used_declared"] = info_collector.get_variable_names_of_synapse(
                synapse_inline)
            syns_info[synapse_name]["convolutions"] = defaultdict()

            kernel_arg_pairs = info_collector.get_extracted_kernel_args(
                synapse_inline)
            for kernel_var, spikes_var in kernel_arg_pairs:
                kernel_name = kernel_var.get_name()
                spikes_name = spikes_var.get_name()
                convolution_name = info_collector.construct_kernel_X_spike_buf_name(
                    kernel_name, spikes_name, 0)
                syns_info[synapse_name]["convolutions"][convolution_name] = {
                    "kernel": {
                        "name": kernel_name,
                        "ASTKernel": info_collector.get_kernel_by_name(kernel_name),
                    },
                    "spikes": {
                        "name": spikes_name,
                        "ASTInputPort": info_collector.get_input_port_by_name(spikes_name),
                    },
                }
        return syns_info, info_collector

    @classmethod
    def collect_and_check_inputs_per_synapse(
            cls,
            syns_info: dict):
        """Record the spike buffers used per synapse and log an error unless exactly one is used.

        :param syns_info: per-synapse info dictionary
        :return: shallow copy of ``syns_info`` with a ``"buffers_used"`` set per synapse
        """
        new_syns_info = copy.copy(syns_info)

        # collect all buffers used
        for synapse_name, synapse_info in syns_info.items():
            new_syns_info[synapse_name]["buffers_used"] = set()
            for convolution_name, convolution_info in synapse_info["convolutions"].items():
                input_name = convolution_info["spikes"]["name"]
                new_syns_info[synapse_name]["buffers_used"].add(input_name)

        # now make sure each synapse is using exactly one buffer
        for synapse_name, synapse_info in syns_info.items():
            buffers = new_syns_info[synapse_name]["buffers_used"]
            if len(buffers) != 1:
                code, message = Messages.get_syns_bad_buffer_count(
                    buffers, synapse_name)
                # Fixed: the key used throughout this class is "root_expression";
                # the original read "inline_expression", so the error-reporting
                # path itself raised KeyError before the message could be logged.
                causing_object = synapse_info["root_expression"]
                Logger.log_message(
                    code=code,
                    message=message,
                    error_position=causing_object.get_source_position(),
                    log_level=LoggingLevel.ERROR,
                    node=causing_object)

        return new_syns_info

    @classmethod
    def convolution_ode_toolbox_processing(cls, neuron, syns_info):
        """Solve every collected convolution with ode-toolbox and store its analytic solution.

        Returns ``syns_info`` unchanged if the neuron has no parameters block.
        """
        if not neuron.get_parameters_blocks():
            return syns_info

        parameters_block = neuron.get_parameters_blocks()[0]

        for synapse_name, synapse_info in syns_info.items():
            for convolution_name, convolution_info in synapse_info["convolutions"].items():
                kernel_buffer = (convolution_info["kernel"]["ASTKernel"], convolution_info["spikes"]["ASTInputPort"])
                convolution_solution = cls.ode_solve_convolution(neuron, parameters_block, kernel_buffer)
                syns_info[synapse_name]["convolutions"][convolution_name]["analytic_solution"] = convolution_solution
        return syns_info

    @classmethod
    def ode_solve_convolution(cls,
                              neuron: ASTModel,
                              parameters_block: ASTBlockWithVariables,
                              kernel_buffer):
        """Run ode-toolbox on a single (kernel, spike buffer) pair.

        :return: the analytic solver result dict, or None if no analytic solution exists
        """
        odetoolbox_indict = cls.create_ode_indict(
            neuron, parameters_block, kernel_buffer)
        full_solver_result = analysis(
            odetoolbox_indict,
            disable_stiffness_check=True,
            log_level=FrontendConfiguration.logging_level)
        analytic_solver = None
        analytic_solvers = [
            x for x in full_solver_result if x["solver"] == "analytical"]
        assert len(
            analytic_solvers) <= 1, "More than one analytic solver not presently supported"
        if len(analytic_solvers) > 0:
            analytic_solver = analytic_solvers[0]

        return analytic_solver

    @classmethod
    def create_ode_indict(cls,
                          neuron: ASTModel,
                          parameters_block: ASTBlockWithVariables,
                          kernel_buffer):
        """Build the full ode-toolbox input dictionary for one kernel/buffer pair."""
        kernel_buffers = {tuple(kernel_buffer)}
        odetoolbox_indict = cls.transform_ode_and_kernels_to_json(
            neuron, parameters_block, kernel_buffers)
        odetoolbox_indict["options"] = {}
        odetoolbox_indict["options"]["output_timestep_symbol"] = "__h"
        return odetoolbox_indict

    @classmethod
    def transform_ode_and_kernels_to_json(
            cls,
            neuron: ASTModel,
            parameters_block,
            kernel_buffers):
        """
        Converts AST node to a JSON representation suitable for passing to ode-toolbox.

        Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements

            convolve(G, ex_spikes)
            convolve(G, in_spikes)

        then `kernel_buffers` will contain the pairs `(G, ex_spikes)` and `(G, in_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__ex_spikes` and `G__X__in_spikes`.

        :param parameters_block: ASTBlockWithVariables
        :return: Dict
        """
        odetoolbox_indict = {"dynamics": []}

        equations_block = neuron.get_equations_blocks()[0]

        for kernel, spike_input_port in kernel_buffers:
            if ASTUtils.is_delta_kernel(kernel):
                # delta function -- skip passing this to ode-toolbox
                continue

            for kernel_var in kernel.get_variables():
                expr = ASTUtils.get_expr_from_kernel_var(
                    kernel, kernel_var.get_complete_name())
                kernel_order = kernel_var.get_differential_order()
                kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_X_spike_buf_name(
                    kernel_var.get_name(), spike_input_port.get_name(), kernel_order, diff_order_symbol="'")

                ASTUtils.replace_rhs_variables(expr, kernel_buffers)

                entry = {"expression": kernel_X_spike_buf_name_ticks + " = " + str(expr), "initial_values": {}}

                # initial values need to be declared for order 1 up to kernel
                # order (e.g. none for kernel function f(t) = ...; 1 for kernel
                # ODE f'(t) = ...; 2 for f''(t) = ... and so on)
                for order in range(kernel_order):
                    # NOTE(review): the raw ``spike_input_port`` object (not
                    # ``.get_name()``) is passed here, unlike the call above --
                    # this mirrors upstream usage, but confirm it is intended.
                    iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_X_spike_buf_name(
                        kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'")
                    symbol_name_ = kernel_var.get_name() + "'" * order
                    symbol = equations_block.get_scope().resolve_to_symbol(
                        symbol_name_, SymbolKind.VARIABLE)
                    assert symbol is not None, "Could not find initial value for variable " + symbol_name_
                    initial_value_expr = symbol.get_declaring_expression()
                    assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_
                    entry["initial_values"][iv_sym_name_ode_toolbox] = cls._ode_toolbox_printer.print(
                        initial_value_expr)

                odetoolbox_indict["dynamics"].append(entry)

        odetoolbox_indict["parameters"] = {}
        if parameters_block is not None:
            for decl in parameters_block.get_declarations():
                for var in decl.variables:
                    odetoolbox_indict["parameters"][var.get_complete_name(
                    )] = cls._ode_toolbox_printer.print(decl.get_expression())

        return odetoolbox_indict
class RecsInfoEnricher(MechsInfoEnricher):
    """
    input: a neuron after ODE-toolbox transformations

    the kernel analysis solves all kernels at the same time;
    this splits the variables on a per-kernel basis
    """

    def __init__(self, params):
        # Fixed: the original called ``super(MechsInfoEnricher, self).__init__``,
        # which starts the MRO lookup *after* the direct base class, skipping
        # ``MechsInfoEnricher.__init__`` and passing ``params`` on to
        # ``object.__init__`` (TypeError at runtime).
        super().__init__(params)

    @classmethod
    def enrich_mechanism_specific(cls, neuron, mechs_info):
        """Framework hook: run receptor-specific enrichment on ``mechs_info``."""
        specific_enricher_visitor = RecsInfoEnricherVisitor()
        neuron.accept(specific_enricher_visitor)
        mechs_info = cls.transform_convolutions_analytic_solutions(neuron, mechs_info)
        mechs_info = cls.restore_order_internals(neuron, mechs_info)
        return mechs_info

    @classmethod
    def transform_convolutions_analytic_solutions(cls, neuron: ASTModel, cm_syns_info: dict):
        """Parse the raw ode-toolbox analytic solutions (strings) back into scoped AST expressions."""
        enriched_syns_info = copy.copy(cm_syns_info)
        for synapse_name, synapse_info in cm_syns_info.items():
            for convolution_name in synapse_info["convolutions"].keys():
                analytic_solution = enriched_syns_info[synapse_name][
                    "convolutions"][convolution_name]["analytic_solution"]
                analytic_solution_transformed = defaultdict(
                    lambda: defaultdict())

                for variable_name, expression_str in analytic_solution["initial_values"].items():
                    variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name,
                                                                                             SymbolKind.VARIABLE)

                    expression = ModelParser.parse_expression(expression_str)
                    # pretend that update expressions are in the "equations" block,
                    # which should always be present, as synapses have been
                    # defined to get here
                    expression.update_scope(neuron.get_equations_blocks()[0].get_scope())
                    expression.accept(ASTSymbolTableVisitor())

                    update_expr_str = analytic_solution["update_expressions"][variable_name]
                    update_expr_ast = ModelParser.parse_expression(
                        update_expr_str)
                    # pretend that update expressions are in the "equations" block,
                    # which should always be present, as differential equations
                    # must have been defined to get here
                    update_expr_ast.update_scope(
                        neuron.get_equations_blocks()[0].get_scope())
                    update_expr_ast.accept(ASTSymbolTableVisitor())

                    analytic_solution_transformed['kernel_states'][variable_name] = {
                        "ASTVariable": variable,
                        "init_expression": expression,
                        "update_expression": update_expr_ast,
                    }

                for variable_name, expression_string in analytic_solution["propagators"].items():
                    variable = RecsInfoEnricherVisitor.internal_variable_name_to_variable[variable_name]
                    expression = ModelParser.parse_expression(
                        expression_string)
                    # pretend that update expressions are in the "equations" block,
                    # which should always be present, as synapses have been
                    # defined to get here
                    expression.update_scope(
                        neuron.get_equations_blocks()[0].get_scope())
                    expression.accept(ASTSymbolTableVisitor())
                    analytic_solution_transformed['propagators'][variable_name] = {
                        "ASTVariable": variable, "init_expression": expression, }

                enriched_syns_info[synapse_name]["convolutions"][convolution_name]["analytic_solution"] = \
                    analytic_solution_transformed

            # only one buffer allowed, so allow direct access
            # to it instead of a list
            if "buffer_name" not in enriched_syns_info[synapse_name]:
                buffers_used = list(
                    enriched_syns_info[synapse_name]["buffers_used"])
                del enriched_syns_info[synapse_name]["buffers_used"]
                enriched_syns_info[synapse_name]["buffer_name"] = buffers_used[0]

            inline_expression_name = enriched_syns_info[synapse_name]["root_expression"].variable_name
            enriched_syns_info[synapse_name]["root_expression"] = \
                RecsInfoEnricherVisitor.inline_name_to_transformed_inline[inline_expression_name]
            enriched_syns_info[synapse_name]["inline_expression_d"] = \
                cls.compute_expression_derivative(
                    enriched_syns_info[synapse_name]["root_expression"])

            # now also identify analytic helper variables such as __h
            enriched_syns_info[synapse_name]["analytic_helpers"] = cls.get_analytic_helper_variable_declarations(
                enriched_syns_info[synapse_name])

        return enriched_syns_info

    @classmethod
    def restore_order_internals(cls, neuron: ASTModel, cm_syns_info: dict):
        """orders user defined internals
        back to the order they were originally defined;
        this is important if one such variable uses another —
        the user needs to have control over the order.
        Assign each variable a rank
        that corresponds to the order in
        RecsInfoEnricherVisitor.declarations_ordered"""
        variable_name_to_order = {}
        for index, declaration in enumerate(
                RecsInfoEnricherVisitor.declarations_ordered):
            variable_name = declaration.get_variables()[0].get_name()
            variable_name_to_order[variable_name] = index

        enriched_syns_info = copy.copy(cm_syns_info)
        for synapse_name, synapse_info in cm_syns_info.items():
            user_internals = enriched_syns_info[synapse_name]["internals_used_declared"]
            user_internals_sorted = sorted(
                user_internals.items(), key=lambda x: variable_name_to_order[x[0]])
            enriched_syns_info[synapse_name]["internals_used_declared"] = user_internals_sorted

        return enriched_syns_info

    @classmethod
    def compute_expression_derivative(
            cls, inline_expression: ASTInlineExpression) -> ASTExpression:
        """Symbolically differentiate the inline expression w.r.t. the compartment voltage ``v_comp``."""
        expr_str = str(inline_expression.get_expression())
        sympy_expr = sympy.parsing.sympy_parser.parse_expr(expr_str)
        sympy_expr = sympy.diff(sympy_expr, "v_comp")

        ast_expression_d = ModelParser.parse_expression(str(sympy_expr))
        # copy the scope of the original inline_expression into the derivative
        ast_expression_d.update_scope(inline_expression.get_scope())
        ast_expression_d.accept(ASTSymbolTableVisitor())

        return ast_expression_d

    @classmethod
    def get_variable_names_used(cls, node) -> set:
        """Return the set of variable names referenced anywhere inside ``node``."""
        variable_names_extractor = ASTUsedVariableNamesExtractor(node)
        return variable_names_extractor.variable_names

    @classmethod
    def get_all_synapse_variables(cls, single_synapse_info):
        """returns all variable names referenced by the synapse inline
        and by the analytical solution;
        assumes that the model has already been transformed"""

        # get all variables from transformed inline
        inline_variables = cls.get_variable_names_used(
            single_synapse_info["root_expression"])

        analytic_solution_vars = set()
        # get all variables from transformed analytic solution
        for convolution_name, convolution_info in single_synapse_info["convolutions"].items():
            analytic_sol = convolution_info["analytic_solution"]
            # get variables from init and update expressions
            # for each kernel
            for kernel_var_name, kernel_info in analytic_sol["kernel_states"].items():
                analytic_solution_vars.add(kernel_var_name)

                update_vars = cls.get_variable_names_used(
                    kernel_info["update_expression"])
                init_vars = cls.get_variable_names_used(
                    kernel_info["init_expression"])

                analytic_solution_vars.update(update_vars)
                analytic_solution_vars.update(init_vars)

            # get variables from init expressions
            # for each propagator
            # include propagator variable itself
            for propagator_var_name, propagator_info in analytic_sol["propagators"].items():
                analytic_solution_vars.add(propagator_var_name)

                init_vars = cls.get_variable_names_used(
                    propagator_info["init_expression"])

                analytic_solution_vars.update(init_vars)

        return analytic_solution_vars.union(inline_variables)

    @classmethod
    def get_new_variables_after_transformation(cls, single_synapse_info):
        """Variables present after the transformation that were not declared by the user."""
        return cls.get_all_synapse_variables(single_synapse_info).difference(
            single_synapse_info["total_used_declared"])

    @classmethod
    def get_analytic_helper_variable_names(cls, single_synapse_info):
        """get new variables that only occur on the right hand side of analytic solution expressions
        but for which the analytic solution does not offer any values;
        this can isolate additional variables that suddenly appear such as __h
        whose initial values are not included in the output of the analytic solver"""

        analytic_lhs_vars = set()

        for convolution_name, convolution_info in single_synapse_info["convolutions"].items():
            analytic_sol = convolution_info["analytic_solution"]

            # get variables representing convolutions by kernel
            for kernel_var_name, kernel_info in analytic_sol["kernel_states"].items():
                analytic_lhs_vars.add(kernel_var_name)

            # get propagator variable names
            for propagator_var_name, propagator_info in analytic_sol["propagators"].items():
                analytic_lhs_vars.add(propagator_var_name)

        return cls.get_new_variables_after_transformation(
            single_synapse_info).symmetric_difference(analytic_lhs_vars)

    @classmethod
    def get_analytic_helper_variable_declarations(cls, single_synapse_info):
        """Resolve helper variable names (e.g. ``__h``) to their internal declarations.

        Marks each helper with ``is_time_resolution`` when its initializer is a
        call to the predefined ``resolution()`` function.
        """
        variable_names = cls.get_analytic_helper_variable_names(
            single_synapse_info)
        result = dict()
        for variable_name in variable_names:
            if variable_name not in RecsInfoEnricherVisitor.internal_variable_name_to_variable:
                continue
            variable = RecsInfoEnricherVisitor.internal_variable_name_to_variable[variable_name]
            expression = RecsInfoEnricherVisitor.variables_to_internal_declarations[variable]
            result[variable_name] = {
                "ASTVariable": variable,
                "init_expression": expression,
            }
            if expression.is_function_call() and expression.get_function_call(
            ).callee_name == PredefinedFunctions.TIME_RESOLUTION:
                result[variable_name]["is_time_resolution"] = True
            else:
                result[variable_name]["is_time_resolution"] = False

        return result
class RecsInfoEnricherVisitor(ASTVisitor):
    """Visitor that records inline expressions and internal declarations of the neuron.

    Results are kept in class-level dictionaries so that RecsInfoEnricher's
    classmethods can look them up after traversal.
    """

    # variable object -> its declaring expression (internals block only)
    variables_to_internal_declarations = {}
    # internal variable name -> variable object
    internal_variable_name_to_variable = {}
    # inline expression name -> (transformed) inline expression node
    inline_name_to_transformed_inline = {}

    # assuming depth-first traversal:
    # collect declarations in the order
    # in which they were present in the neuron
    declarations_ordered = []

    def __init__(self):
        super(RecsInfoEnricherVisitor, self).__init__()

        self.inside_parameter_block = False
        self.inside_state_block = False
        self.inside_internals_block = False
        # Fixed: this flag was initialised twice in a row in the original.
        self.inside_inline_expression = False
        self.inside_declaration = False
        self.inside_simple_expression = False

    def visit_inline_expression(self, node):
        self.inside_inline_expression = True
        inline_name = node.variable_name
        RecsInfoEnricherVisitor.inline_name_to_transformed_inline[inline_name] = node

    def endvisit_inline_expression(self, node):
        self.inside_inline_expression = False

    def visit_block_with_variables(self, node):
        if node.is_state:
            self.inside_state_block = True
        if node.is_parameters:
            self.inside_parameter_block = True
        if node.is_internals:
            self.inside_internals_block = True

    def endvisit_block_with_variables(self, node):
        if node.is_state:
            self.inside_state_block = False
        if node.is_parameters:
            self.inside_parameter_block = False
        if node.is_internals:
            self.inside_internals_block = False

    def visit_simple_expression(self, node):
        self.inside_simple_expression = True

    def endvisit_simple_expression(self, node):
        self.inside_simple_expression = False

    def visit_declaration(self, node):
        self.declarations_ordered.append(node)
        self.inside_declaration = True
        if self.inside_internals_block:
            # NOTE(review): only the first variable of the declaration is
            # recorded -- presumably internals declarations are single-variable;
            # confirm against the grammar.
            variable = node.get_variables()[0]
            expression = node.get_expression()
            RecsInfoEnricherVisitor.variables_to_internal_declarations[variable] = expression
            RecsInfoEnricherVisitor.internal_variable_name_to_variable[variable.get_name(
            )] = variable

    def endvisit_declaration(self, node):
        self.inside_declaration = False


class ASTUsedVariableNamesExtractor(ASTVisitor):
    """Visitor collecting the names of all variables referenced under a given node."""

    def __init__(self, node):
        super(ASTUsedVariableNamesExtractor, self).__init__()
        self.variable_names = set()
        node.accept(self)

    def visit_variable(self, node):
        self.variable_names.add(node.get_name())
odetoolbox import analysis + from pynestml.utils.logger import Logger, LoggingLevel -from pynestml.utils.mechanism_processing import MechanismProcessing from pynestml.utils.messages import Messages -from odetoolbox import analysis +class SynapseProcessing: + """Manages the collection of basic information necesary for all types of mechanisms and uses the + collect_information_for_specific_mech_types interface that needs to be implemented by the specific mechanism type + processing classes""" + + # used to keep track of whenever check_co_co was already called + # see inside check_co_co + first_time_run = defaultdict(lambda: True) + # stores synapse from the first call of check_co_co + syn_info = defaultdict() + + # ODE-toolbox printers + _constant_printer = ConstantPrinter() + _ode_toolbox_variable_printer = ODEToolboxVariablePrinter(None) + _ode_toolbox_function_call_printer = ODEToolboxFunctionCallPrinter(None) + _ode_toolbox_printer = ODEToolboxExpressionPrinter( + simple_expression_printer=UnitlessCppSimpleExpressionPrinter( + variable_printer=_ode_toolbox_variable_printer, + constant_printer=_constant_printer, + function_call_printer=_ode_toolbox_function_call_printer)) -class SynapseProcessing(MechanismProcessing): - mechType = "receptor" + _ode_toolbox_variable_printer._expression_printer = _ode_toolbox_printer + _ode_toolbox_function_call_printer._expression_printer = _ode_toolbox_printer - def __init__(self, params): - super(MechanismProcessing, self).__init__(params) + @classmethod + def prepare_equations_for_ode_toolbox(cls, synapse, syn_info): + """Transforms the collected ode equations to the required input format of ode-toolbox and adds it to the + syn_info dictionary""" + + mechanism_odes = defaultdict() + for ode in syn_info["ODEs"]: + nestml_printer = NESTMLPrinter() + ode_nestml_expression = nestml_printer.print_ode_equation(ode) + mechanism_odes[ode.lhs.name] = defaultdict() + mechanism_odes[ode.lhs.name]["ASTOdeEquation"] = ode + 
mechanism_odes[ode.lhs.name]["ODENestmlExpression"] = ode_nestml_expression + syn_info["ODEs"] = mechanism_odes + + for ode_variable_name, ode_info in syn_info["ODEs"].items(): + # Expression: + odetoolbox_indict = {"dynamics": []} + lhs = ASTUtils.to_ode_toolbox_name(ode_info["ASTOdeEquation"].get_lhs().get_complete_name()) + rhs = cls._ode_toolbox_printer.print(ode_info["ASTOdeEquation"].get_rhs()) + entry = {"expression": lhs + " = " + rhs, "initial_values": {}} + + # Initial values: + symbol_order = ode_info["ASTOdeEquation"].get_lhs().get_differential_order() + for order in range(symbol_order): + iv_symbol_name = ode_info["ASTOdeEquation"].get_lhs().get_name() + "'" * order + initial_value_expr = synapse.get_initial_value(iv_symbol_name) + entry["initial_values"][ + ASTUtils.to_ode_toolbox_name(iv_symbol_name)] = cls._ode_toolbox_printer.print( + initial_value_expr) + + odetoolbox_indict["dynamics"].append(entry) + syn_info["ODEs"][ode_variable_name]["ode_toolbox_input"] = odetoolbox_indict + + return syn_info @classmethod - def collect_information_for_specific_mech_types(cls, neuron, mechs_info): - mechs_info, add_info_collector = cls.collect_additional_base_infos(neuron, mechs_info) - if len(mechs_info) > 0: - # only do this if any synapses found - # otherwise tests may fail - mechs_info = cls.collect_and_check_inputs_per_synapse(mechs_info) + def collect_raw_odetoolbox_output(cls, syn_info): + """calls ode-toolbox for each ode individually and collects the raw output""" + for ode_variable_name, ode_info in syn_info["ODEs"].items(): + solver_result = analysis(ode_info["ode_toolbox_input"], disable_stiffness_check=True) + syn_info["ODEs"][ode_variable_name]["ode_toolbox_output"] = solver_result - mechs_info = cls.convolution_ode_toolbox_processing(neuron, mechs_info) + return syn_info - return mechs_info + @classmethod + def ode_toolbox_processing(cls, synapse, syn_info): + syn_info = cls.prepare_equations_for_ode_toolbox(synapse, syn_info) + syn_info = 
cls.collect_raw_odetoolbox_output(syn_info) + return syn_info @classmethod - def collect_additional_base_infos(cls, neuron, syns_info): + def collect_information_for_specific_mech_types(cls, synapse, syn_info): + # to be implemented for specific mechanisms by child class (concentration, synapse, channel) + pass + + @classmethod + def determine_dependencies(cls, syn_info): + for mechanism_name, mechanism_info in syn_info.items(): + dependencies = list() + for inline in mechanism_info["Inlines"]: + if isinstance(inline.get_decorators(), list): + if "mechanism" in [e.namespace for e in inline.get_decorators()]: + dependencies.append(inline) + for ode in mechanism_info["ODEs"]: + if isinstance(ode.get_decorators(), list): + if "mechanism" in [e.namespace for e in ode.get_decorators()]: + dependencies.append(ode) + syn_info[mechanism_name]["dependencies"] = dependencies + return syn_info + + @classmethod + def get_port_names(cls, syn_info): + spiking_port_names = list() + continuous_port_names = list() + for port in syn_info["SpikingPorts"]: + spiking_port_names.append(port.get_name()) + for port in syn_info["ContinuousPorts"]: + continuous_port_names.append(port.get_name()) + + return spiking_port_names, continuous_port_names + + @classmethod + def collect_kernels(cls, neuron, syn_info, neuron_synapse_pairs): """ Collect internals, kernels, inputs and convolutions associated with the synapse. 
""" - info_collector = ASTSynapseInformationCollector() + syn_info["convolutions"] = defaultdict() + info_collector = ASTKernelInformationCollectorVisitor() neuron.accept(info_collector) - for synapse_name, synapse_info in syns_info.items(): - synapse_inline = syns_info[synapse_name]["root_expression"] - syns_info[synapse_name][ + for inline in syn_info["Inlines"]: + synapse_inline = inline + syn_info[ "internals_used_declared"] = info_collector.get_synapse_specific_internal_declarations(synapse_inline) - syns_info[synapse_name]["total_used_declared"] = info_collector.get_variable_names_of_synapse( + syn_info["total_used_declared"] = info_collector.get_variable_names_of_synapse( synapse_inline) - syns_info[synapse_name]["convolutions"] = defaultdict() - kernel_arg_pairs = info_collector.get_extracted_kernel_args( - synapse_inline) + kernel_arg_pairs = info_collector.get_extracted_kernel_args_by_name( + inline.get_variable_name()) for kernel_var, spikes_var in kernel_arg_pairs: kernel_name = kernel_var.get_name() spikes_name = spikes_var.get_name() convolution_name = info_collector.construct_kernel_X_spike_buf_name( kernel_name, spikes_name, 0) - syns_info[synapse_name]["convolutions"][convolution_name] = { + syn_info["convolutions"][convolution_name] = { "kernel": { "name": kernel_name, "ASTKernel": info_collector.get_kernel_by_name(kernel_name), @@ -84,52 +181,37 @@ def collect_additional_base_infos(cls, neuron, syns_info): "name": spikes_name, "ASTInputPort": info_collector.get_input_port_by_name(spikes_name), }, + "post_port": (len([dict for dict in neuron_synapse_pairs if dict["synapse"]+"_nestml" == neuron.name and spikes_name in dict["post_ports"]]) > 0), } - return syns_info, info_collector + return syn_info @classmethod def collect_and_check_inputs_per_synapse( cls, - syns_info: dict): - new_syns_info = copy.copy(syns_info) + syn_info: dict): + new_syn_info = copy.copy(syn_info) # collect all buffers used - for synapse_name, synapse_info in 
syns_info.items(): - new_syns_info[synapse_name]["buffers_used"] = set() - for convolution_name, convolution_info in synapse_info["convolutions"].items( - ): - input_name = convolution_info["spikes"]["name"] - new_syns_info[synapse_name]["buffers_used"].add(input_name) - - # now make sure each synapse is using exactly one buffer - for synapse_name, synapse_info in syns_info.items(): - buffers = new_syns_info[synapse_name]["buffers_used"] - if len(buffers) != 1: - code, message = Messages.get_syns_bad_buffer_count( - buffers, synapse_name) - causing_object = synapse_info["inline_expression"] - Logger.log_message( - code=code, - message=message, - error_position=causing_object.get_source_position(), - log_level=LoggingLevel.ERROR, - node=causing_object) - - return new_syns_info - - @classmethod - def convolution_ode_toolbox_processing(cls, neuron, syns_info): + new_syn_info["buffers_used"] = set() + for convolution_name, convolution_info in syn_info["convolutions"].items( + ): + input_name = convolution_info["spikes"]["name"] + new_syn_info["buffers_used"].add(input_name) + + return new_syn_info + + @classmethod + def convolution_ode_toolbox_processing(cls, neuron, syn_info): if not neuron.get_parameters_blocks(): - return syns_info + return syn_info parameters_block = neuron.get_parameters_blocks()[0] - for synapse_name, synapse_info in syns_info.items(): - for convolution_name, convolution_info in synapse_info["convolutions"].items(): - kernel_buffer = (convolution_info["kernel"]["ASTKernel"], convolution_info["spikes"]["ASTInputPort"]) - convolution_solution = cls.ode_solve_convolution(neuron, parameters_block, kernel_buffer) - syns_info[synapse_name]["convolutions"][convolution_name]["analytic_solution"] = convolution_solution - return syns_info + for convolution_name, convolution_info in syn_info["convolutions"].items(): + kernel_buffer = (convolution_info["kernel"]["ASTKernel"], convolution_info["spikes"]["ASTInputPort"]) + convolution_solution = 
cls.ode_solve_convolution(neuron, parameters_block, kernel_buffer) + syn_info["convolutions"][convolution_name]["analytic_solution"] = convolution_solution + return syn_info @classmethod def ode_solve_convolution(cls, @@ -228,3 +310,93 @@ def transform_ode_and_kernels_to_json( )] = cls._ode_toolbox_printer.print(decl.get_expression()) return odetoolbox_indict + + @classmethod + def get_syn_info(cls, synapse: ASTModel): + """ + returns previously generated syn_info + as a deep copy so it can't be changed externally + via object references + :param synapse: a single synapse instance. + """ + return copy.deepcopy(cls.syn_info) + + @classmethod + def process(cls, synapse: ASTModel, neuron_synapse_pairs): + """ + Checks if mechanism conditions apply for the handed over synapse. + :param synapse: a single synapse instance. + """ + + # make sure we only run this a single time + # subsequent calls will be after AST has been transformed + # and there would be no kernels or inlines any more + if cls.first_time_run[synapse]: + # collect root expressions and initialize collector + info_collector = ASTSynapseInformationCollector(synapse) + + # collect and process all basic mechanism information + syn_info = defaultdict() + syn_info = info_collector.collect_definitions(synapse, syn_info) + syn_info = info_collector.extend_variables_with_initialisations(synapse, syn_info) + syn_info = cls.ode_toolbox_processing(synapse, syn_info) + + # collect all spiking ports + syn_info = info_collector.collect_ports(synapse, syn_info) + + # collect the onReceive function of pre- and post-spikes + spiking_port_names, continuous_port_names = cls.get_port_names(syn_info) + post_ports = FrontendConfiguration.get_codegen_opts()["neuron_synapse_pairs"][0]["post_ports"] + pre_ports = list(set(spiking_port_names) - set(post_ports)) + syn_info = info_collector.collect_on_receive_blocks(synapse, syn_info, pre_ports, post_ports) + + # collect the update block + syn_info = 
info_collector.collect_update_block(synapse, syn_info) + + # collect dependencies (defined mechanism in neuron and no LHS appearance in synapse) + syn_info = info_collector.collect_potential_dependencies(synapse, syn_info) + + syn_info = cls.collect_kernels(synapse, syn_info, neuron_synapse_pairs) + + syn_info = cls.convolution_ode_toolbox_processing(synapse, syn_info) + + cls.syn_info[synapse.get_name()] = syn_info + cls.first_time_run[synapse.get_name()] = False + + @classmethod + def print_element(cls, name, element, rec_step): + message = "" + for indent in range(rec_step): + message += "----" + message += name + ": " + if isinstance(element, defaultdict): + message += "\n" + message += cls.print_dictionary(element, rec_step + 1) + else: + if hasattr(element, 'name'): + message += element.name + elif isinstance(element, str): + message += element + elif isinstance(element, dict): + message += "\n" + message += cls.print_dictionary(element, rec_step + 1) + elif isinstance(element, list): + for index in range(len(element)): + message += "\n" + message += cls.print_element(str(index), element[index], rec_step + 1) + elif isinstance(element, ASTExpression) or isinstance(element, ASTSimpleExpression): + message += cls._ode_toolbox_printer.print(element) + + message += "(" + type(element).__name__ + ")" + return message + + @classmethod + def print_dictionary(cls, dictionary, rec_step): + """ + Print the mechanisms info dictionaries. + """ + message = "" + for name, element in dictionary.items(): + message += cls.print_element(name, element, rec_step) + message += "\n" + return message diff --git a/pynestml/utils/syns_info_enricher.py b/pynestml/utils/syns_info_enricher.py index 5c4b639ec..7224f3a26 100644 --- a/pynestml/utils/syns_info_enricher.py +++ b/pynestml/utils/syns_info_enricher.py @@ -18,139 +18,264 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . 
- -from _collections import defaultdict - import copy +from collections import defaultdict + import sympy +from executing.executing import node_linenos from pynestml.meta_model.ast_expression import ASTExpression from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from pynestml.meta_model.ast_model import ASTModel -from pynestml.symbols.predefined_functions import PredefinedFunctions -from pynestml.symbols.symbol import SymbolKind -from pynestml.utils.mechs_info_enricher import MechsInfoEnricher -from pynestml.utils.model_parser import ModelParser +from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor +from pynestml.utils.ast_utils import ASTUtils from pynestml.visitors.ast_visitor import ASTVisitor +from pynestml.utils.model_parser import ModelParser +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.symbols.symbol import SymbolKind +from collections import defaultdict -class SynsInfoEnricher(MechsInfoEnricher): + +class SynsInfoEnricher: """ - input: a neuron after ODE-toolbox transformations + Adds information collection that can't be done in the processing class since that is used in the cocos. + Here we use the ModelParser which would lead to a cyclic dependency. - the kernel analysis solves all kernels at the same time - this splits the variables on per kernel basis + Additionally we require information about the paired synapses mechanism to confirm what dependencies are actually existent in the synapse. 
""" - def __init__(self, params): - super(MechsInfoEnricher, self).__init__(params) + def __init__(self): + pass @classmethod - def enrich_mechanism_specific(cls, neuron, mechs_info): + def enrich_with_additional_info(cls, synapse: ASTModel, syns_info: dict, chan_info: dict, recs_info: dict, conc_info: dict, con_in_info: dict): specific_enricher_visitor = SynsInfoEnricherVisitor() - neuron.accept(specific_enricher_visitor) - mechs_info = cls.transform_convolutions_analytic_solutions(neuron, mechs_info) - mechs_info = cls.restore_order_internals(neuron, mechs_info) - return mechs_info + synapse.accept(specific_enricher_visitor) + synapse_info = syns_info[synapse.get_name()] + synapse_info = cls.transform_ode_solutions(synapse, synapse_info) + synapse_info = cls.confirm_dependencies(synapse_info, chan_info, recs_info, conc_info, con_in_info) + synapse_info = cls.extract_infunction_declarations(synapse_info) - @classmethod - def transform_convolutions_analytic_solutions(cls, neuron: ASTModel, cm_syns_info: dict): + synapse_info = cls.transform_convolutions_analytic_solutions(synapse, synapse_info) + syns_info[synapse.get_name()] = synapse_info - enriched_syns_info = copy.copy(cm_syns_info) - for synapse_name, synapse_info in cm_syns_info.items(): - for convolution_name in synapse_info["convolutions"].keys(): - analytic_solution = enriched_syns_info[synapse_name][ - "convolutions"][convolution_name]["analytic_solution"] - analytic_solution_transformed = defaultdict( - lambda: defaultdict()) - - for variable_name, expression_str in analytic_solution["initial_values"].items(): - variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + return syns_info + + + @classmethod + def transform_ode_solutions(cls, synapse, syns_info): + for ode_var_name, ode_info in syns_info["ODEs"].items(): + syns_info["ODEs"][ode_var_name]["transformed_solutions"] = list() + + for ode_solution_index in range(len(ode_info["ode_toolbox_output"])): + 
solution_transformed = defaultdict() + solution_transformed["states"] = defaultdict() + solution_transformed["propagators"] = defaultdict() + + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index][ + "initial_values"].items(): + variable = synapse.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE) - expression = ModelParser.parse_expression(expression_str) + expression = ModelParser.parse_expression(rhs_str) # pretend that update expressions are in "equations" block, # which should always be present, as synapses have been # defined to get here - expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) + expression.update_scope(synapse.get_equations_blocks()[0].get_scope()) expression.accept(ASTSymbolTableVisitor()) - update_expr_str = analytic_solution["update_expressions"][variable_name] + update_expr_str = ode_info["ode_toolbox_output"][ode_solution_index]["update_expressions"][ + variable_name] update_expr_ast = ModelParser.parse_expression( update_expr_str) # pretend that update expressions are in "equations" block, # which should always be present, as differential equations # must have been defined to get here update_expr_ast.update_scope( - neuron.get_equations_blocks()[0].get_scope()) + synapse.get_equations_blocks()[0].get_scope()) update_expr_ast.accept(ASTSymbolTableVisitor()) - analytic_solution_transformed['kernel_states'][variable_name] = { + solution_transformed["states"][variable_name] = { "ASTVariable": variable, "init_expression": expression, "update_expression": update_expr_ast, } - - for variable_name, expression_string in analytic_solution["propagators"].items( - ): - variable = SynsInfoEnricherVisitor.internal_variable_name_to_variable[variable_name] - expression = ModelParser.parse_expression( - expression_string) + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index][ + "propagators"].items(): + prop_variable = 
synapse.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) + if prop_variable is None: + ASTUtils.add_declarations_to_internals( + synapse, ode_info["ode_toolbox_output"][ode_solution_index]["propagators"]) + prop_variable = synapse.get_equations_blocks()[0].get_scope().resolve_to_symbol( + variable_name, + SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression(rhs_str) # pretend that update expressions are in "equations" block, # which should always be present, as synapses have been # defined to get here expression.update_scope( - neuron.get_equations_blocks()[0].get_scope()) + synapse.get_equations_blocks()[0].get_scope()) expression.accept(ASTSymbolTableVisitor()) - analytic_solution_transformed['propagators'][variable_name] = { - "ASTVariable": variable, "init_expression": expression, } - - enriched_syns_info[synapse_name]["convolutions"][convolution_name]["analytic_solution"] = \ - analytic_solution_transformed - - # only one buffer allowed, so allow direct access - # to it instead of a list - if "buffer_name" not in enriched_syns_info[synapse_name]: - buffers_used = list( - enriched_syns_info[synapse_name]["buffers_used"]) - del enriched_syns_info[synapse_name]["buffers_used"] - enriched_syns_info[synapse_name]["buffer_name"] = buffers_used[0] - - inline_expression_name = enriched_syns_info[synapse_name]["root_expression"].variable_name - enriched_syns_info[synapse_name]["root_expression"] = \ - SynsInfoEnricherVisitor.inline_name_to_transformed_inline[inline_expression_name] - enriched_syns_info[synapse_name]["inline_expression_d"] = \ - cls.compute_expression_derivative( - enriched_syns_info[synapse_name]["root_expression"]) - # now also identify analytic helper variables such as __h - enriched_syns_info[synapse_name]["analytic_helpers"] = cls.get_analytic_helper_variable_declarations( - enriched_syns_info[synapse_name]) + solution_transformed["propagators"][variable_name] = { + "ASTVariable": 
prop_variable, "init_expression": expression, } + expression_variable_collector = ASTEnricherInfoCollectorVisitor() + expression.accept(expression_variable_collector) - return enriched_syns_info + synapse_internal_declaration_collector = ASTEnricherInfoCollectorVisitor() + synapse.accept(synapse_internal_declaration_collector) + + for variable in expression_variable_collector.all_variables: + for internal_declaration in synapse_internal_declaration_collector.internal_declarations: + if variable.get_name() == internal_declaration.get_variables()[0].get_name() \ + and internal_declaration.get_expression().is_function_call() \ + and internal_declaration.get_expression().get_function_call().callee_name == \ + PredefinedFunctions.TIME_RESOLUTION: + syns_info["time_resolution_var"] = variable + + syns_info["ODEs"][ode_var_name]["transformed_solutions"].append(solution_transformed) + + synapse.accept(ASTParentVisitor()) + + return syns_info + + @classmethod + def confirm_dependencies(cls, syns_info: dict, chan_info: dict, recs_info: dict, conc_info: dict, con_in_info: dict): + actual_dependencies = dict() + chan_deps = list() + rec_deps = list() + conc_deps = list() + con_in_deps = list() + for pot_dep, dep_info in syns_info["PotentialDependencies"].items(): + for channel_name, channel_info in chan_info.items(): + if pot_dep == channel_name: + chan_deps.append(channel_info["root_expression"]) + for receptor_name, receptor_info in recs_info.items(): + if pot_dep == receptor_name: + rec_deps.append(receptor_info["root_expression"]) + for concentration_name, concentration_info in conc_info.items(): + if pot_dep == concentration_name: + conc_deps.append(concentration_info["root_expression"]) + for continuous_name, continuous_info in con_in_info.items(): + if pot_dep == continuous_name: + con_in_deps.append(continuous_info["root_expression"]) + + actual_dependencies["channels"] = chan_deps + actual_dependencies["receptors"] = rec_deps + actual_dependencies["concentrations"] = 
conc_deps + actual_dependencies["continuous"] = con_in_deps + syns_info["Dependencies"] = actual_dependencies + return syns_info @classmethod - def restore_order_internals(cls, neuron: ASTModel, cm_syns_info: dict): - """orders user defined internals - back to the order they were originally defined - this is important if one such variable uses another - user needs to have control over the order - assign each variable a rank - that corresponds to the order in - SynsInfoEnricher.declarations_ordered""" - variable_name_to_order = {} - for index, declaration in enumerate( - SynsInfoEnricherVisitor.declarations_ordered): - variable_name = declaration.get_variables()[0].get_name() - variable_name_to_order[variable_name] = index + def extract_infunction_declarations(cls, syn_info): + pre_spike_function = syn_info["PreSpikeFunction"] + post_spike_function = syn_info["PostSpikeFunction"] + update_block = syn_info["UpdateBlock"] + #general_functions = syn_info["Functions"] + declaration_visitor = ASTDeclarationCollectorAndUniqueRenamerVisitor() + if pre_spike_function is not None: + pre_spike_function.accept(declaration_visitor) + if post_spike_function is not None: + post_spike_function.accept(declaration_visitor) + if update_block is not None: + update_block.accept(declaration_visitor) + + declaration_vars = list() + for decl in declaration_visitor.declarations: + for var in decl.get_variables(): + declaration_vars.append(var.get_name()) + + syn_info["InFunctionDeclarationsVars"] = declaration_visitor.declarations #list(declaration_vars) + return syn_info + + @classmethod + def transform_convolutions_analytic_solutions(cls, neuron: ASTModel, cm_syns_info: dict): enriched_syns_info = copy.copy(cm_syns_info) - for synapse_name, synapse_info in cm_syns_info.items(): - user_internals = enriched_syns_info[synapse_name]["internals_used_declared"] - user_internals_sorted = sorted( - user_internals.items(), key=lambda x: variable_name_to_order[x[0]]) - 
enriched_syns_info[synapse_name]["internals_used_declared"] = user_internals_sorted + for convolution_name in cm_syns_info["convolutions"].keys(): + analytic_solution = enriched_syns_info[ + "convolutions"][convolution_name]["analytic_solution"] + analytic_solution_transformed = defaultdict( + lambda: defaultdict()) + + for variable_name, expression_str in analytic_solution["initial_values"].items(): + variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) + if variable is None: + ASTUtils.add_declarations_to_internals( + neuron, analytic_solution["initial_values"]) + variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol( + variable_name, + SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression(expression_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as synapses have been + # defined to get here + expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) + + update_expr_str = analytic_solution["update_expressions"][variable_name] + update_expr_ast = ModelParser.parse_expression( + update_expr_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as differential equations + # must have been defined to get here + update_expr_ast.update_scope( + neuron.get_equations_blocks()[0].get_scope()) + update_expr_ast.accept(ASTSymbolTableVisitor()) + + analytic_solution_transformed['kernel_states'][variable_name] = { + "ASTVariable": variable, + "init_expression": expression, + "update_expression": update_expr_ast, + } + + for variable_name, expression_string in analytic_solution["propagators"].items( + ): + variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) + if variable is None: + ASTUtils.add_declarations_to_internals( + neuron, analytic_solution["propagators"]) + 
variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol( + variable_name, + SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression( + expression_string) + # pretend that update expressions are in "equations" block, + # which should always be present, as synapses have been + # defined to get here + expression.update_scope( + neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) + analytic_solution_transformed['propagators'][variable_name] = { + "ASTVariable": variable, "init_expression": expression, } + + enriched_syns_info["convolutions"][convolution_name]["analytic_solution"] = \ + analytic_solution_transformed + + transformed_inlines = dict() + for inline in enriched_syns_info["Inlines"]: + transformed_inlines[inline.get_variable_name()] = dict() + transformed_inlines[inline.get_variable_name()]["inline_expression"] = \ + SynsInfoEnricherVisitor.inline_name_to_transformed_inline[inline.get_variable_name()] + transformed_inlines[inline.get_variable_name()]["inline_expression_d"] = \ + cls.compute_expression_derivative( + transformed_inlines[inline.get_variable_name()]["inline_expression"]) + enriched_syns_info["Inlines"] = transformed_inlines + + # now also identify analytic helper variables such as __h + enriched_syns_info["analytic_helpers"] = cls.get_analytic_helper_variable_declarations( + enriched_syns_info) + + neuron.accept(ASTParentVisitor()) return enriched_syns_info @@ -169,9 +294,57 @@ def compute_expression_derivative( return ast_expression_d @classmethod - def get_variable_names_used(cls, node) -> set: - variable_names_extractor = ASTUsedVariableNamesExtractor(node) - return variable_names_extractor.variable_names + def get_analytic_helper_variable_declarations(cls, single_synapse_info): + variable_names = cls.get_analytic_helper_variable_names( + single_synapse_info) + result = dict() + for variable_name in variable_names: + if variable_name not in 
SynsInfoEnricherVisitor.internal_variable_name_to_variable: + continue + variable = SynsInfoEnricherVisitor.internal_variable_name_to_variable[variable_name] + expression = SynsInfoEnricherVisitor.variables_to_internal_declarations[variable] + result[variable_name] = { + "ASTVariable": variable, + "init_expression": expression, + } + if expression.is_function_call() and expression.get_function_call( + ).callee_name == PredefinedFunctions.TIME_RESOLUTION: + result[variable_name]["is_time_resolution"] = True + else: + result[variable_name]["is_time_resolution"] = False + + return result + + @classmethod + def get_analytic_helper_variable_names(cls, single_synapse_info): + """get new variables that only occur on the right hand side of analytic solution Expressions + but for wich analytic solution does not offer any values + this can isolate out additional variables that suddenly appear such as __h + whose initial values are not inlcuded in the output of analytic solver""" + + analytic_lhs_vars = set() + + for convolution_name, convolution_info in single_synapse_info["convolutions"].items( + ): + analytic_sol = convolution_info["analytic_solution"] + + # get variables representing convolutions by kernel + for kernel_var_name, kernel_info in analytic_sol["kernel_states"].items( + ): + analytic_lhs_vars.add(kernel_var_name) + + # get propagator variable names + for propagator_var_name, propagator_info in analytic_sol["propagators"].items( + ): + analytic_lhs_vars.add(propagator_var_name) + + return cls.get_new_variables_after_transformation( + single_synapse_info).symmetric_difference(analytic_lhs_vars) + + @classmethod + def get_new_variables_after_transformation(cls, single_synapse_info): + return cls.get_all_synapse_variables(single_synapse_info).difference( + single_synapse_info["total_used_declared"]) @classmethod def get_all_synapse_variables(cls, single_synapse_info): @@ -179,9 +352,9 @@ def get_all_synapse_variables(cls, single_synapse_info): and by the 
analytical solution assumes that the model has already been transformed""" - # get all variables from transformed inline - inline_variables = cls.get_variable_names_used( - single_synapse_info["root_expression"]) + inline_variables = set() + for inline_name, inline in single_synapse_info["Inlines"].items(): + inline_variables = cls.get_variable_names_used(inline["inline_expression"]) analytic_solution_vars = set() # get all variables from transformed analytic solution @@ -217,57 +390,111 @@ def get_all_synapse_variables(cls, single_synapse_info): return analytic_solution_vars.union(inline_variables) @classmethod - def get_new_variables_after_transformation(cls, single_synapse_info): - return cls.get_all_synapse_variables(single_synapse_info).difference( - single_synapse_info["total_used_declared"]) + def get_variable_names_used(cls, node) -> set: + variable_names_extractor = ASTUsedVariableNamesExtractor(node) + return variable_names_extractor.variable_names - @classmethod - def get_analytic_helper_variable_names(cls, single_synapse_info): - """get new variables that only occur on the right hand side of analytic solution Expressions - but for wich analytic solution does not offer any values - this can isolate out additional variables that suddenly appear such as __h - whose initial values are not inlcuded in the output of analytic solver""" - analytic_lhs_vars = set() - for convolution_name, convolution_info in single_synapse_info["convolutions"].items( - ): - analytic_sol = convolution_info["analytic_solution"] - # get variables representing convolutions by kernel - for kernel_var_name, kernel_info in analytic_sol["kernel_states"].items( - ): - analytic_lhs_vars.add(kernel_var_name) +class ASTEnricherInfoCollectorVisitor(ASTVisitor): - # get propagator variable names - for propagator_var_name, propagator_info in analytic_sol["propagators"].items( - ): - analytic_lhs_vars.add(propagator_var_name) + def __init__(self): + super(ASTEnricherInfoCollectorVisitor, 
self).__init__() + self.inside_variable = False + self.inside_block_with_variables = False + self.all_states = list() + self.all_parameters = list() + self.inside_states_block = False + self.inside_parameters_block = False + self.all_variables = list() + self.inside_internals_block = False + self.inside_declaration = False + self.internal_declarations = list() - return cls.get_new_variables_after_transformation( - single_synapse_info).symmetric_difference(analytic_lhs_vars) + def visit_block_with_variables(self, node): + self.inside_block_with_variables = True + if node.is_state: + self.inside_states_block = True + if node.is_parameters: + self.inside_parameters_block = True + if node.is_internals: + self.inside_internals_block = True - @classmethod - def get_analytic_helper_variable_declarations(cls, single_synapse_info): - variable_names = cls.get_analytic_helper_variable_names( - single_synapse_info) - result = dict() - for variable_name in variable_names: - if variable_name not in SynsInfoEnricherVisitor.internal_variable_name_to_variable: - continue - variable = SynsInfoEnricherVisitor.internal_variable_name_to_variable[variable_name] - expression = SynsInfoEnricherVisitor.variables_to_internal_declarations[variable] - result[variable_name] = { - "ASTVariable": variable, - "init_expression": expression, - } - if expression.is_function_call() and expression.get_function_call( - ).callee_name == PredefinedFunctions.TIME_RESOLUTION: - result[variable_name]["is_time_resolution"] = True + def endvisit_block_with_variables(self, node): + self.inside_states_block = False + self.inside_parameters_block = False + self.inside_block_with_variables = False + self.inside_internals_block = False + + def visit_variable(self, node): + self.inside_variable = True + self.all_variables.append(node.clone()) + if self.inside_states_block: + self.all_states.append(node.clone()) + if self.inside_parameters_block: + self.all_parameters.append(node.clone()) + + def 
endvisit_variable(self, node): + self.inside_variable = False + + def visit_declaration(self, node): + self.inside_declaration = True + if self.inside_internals_block: + self.internal_declarations.append(node) + + def endvisit_declaration(self, node): + self.inside_declaration = False + + +class ASTDeclarationCollectorAndUniqueRenamerVisitor(ASTVisitor): + def __init__(self): + super(ASTDeclarationCollectorAndUniqueRenamerVisitor, self).__init__() + self.declarations = list() + self.variable_names = dict() + self.inside_declaration = False + self.inside_block = False + self.current_block = None + + def visit_block(self, node): + self.inside_block = True + self.current_block = node + + def endvisit_block(self, node): + self.inside_block = False + self.current_block = None + + def visit_declaration(self, node): + self.inside_declaration = True + for variable in node.get_variables(): + if variable.get_name() in self.variable_names: + self.variable_names[variable.get_name()] += 1 else: - result[variable_name]["is_time_resolution"] = False + self.variable_names[variable.get_name()] = 0 + new_name = variable.get_name() + '_' + str(self.variable_names[variable.get_name()]) + name_replacer = ASTVariableNameReplacerVisitor(variable.get_name(), new_name) + self.current_block.accept(name_replacer) + node.accept(ASTSymbolTableVisitor()) + self.declarations.append(node.clone()) - return result + def endvisit_declaration(self, node): + self.inside_declaration = False + + +class ASTVariableNameReplacerVisitor(ASTVisitor): + def __init__(self, old_name, new_name): + super(ASTVariableNameReplacerVisitor, self).__init__() + self.inside_variable = False + self.new_name = new_name + self.old_name = old_name + + def visit_variable(self, node): + self.inside_variable = True + if node.get_name() == self.old_name: + node.set_name(self.new_name) + + def endvisit_variable(self, node): + self.inside_variable = False class SynsInfoEnricherVisitor(ASTVisitor): @@ -343,3 +570,5 @@ def 
__init__(self, node): def visit_variable(self, node): self.variable_names.add(node.get_name()) + + diff --git a/setup.py b/setup.py index 73bf415f7..1397eddbd 100755 --- a/setup.py +++ b/setup.py @@ -56,6 +56,7 @@ "codegeneration/resources_nest/point_neuron/setup/common/*.jinja2", "codegeneration/resources_nest_compartmental/cm_neuron/*.jinja2", "codegeneration/resources_nest_compartmental/cm_neuron/directives_cpp/*.jinja2", + "codegeneration/resources_nest_compartmental/cm_neuron/cm_directives_cpp/*.jinja2", "codegeneration/resources_nest_compartmental/cm_neuron/setup/*.jinja2", "codegeneration/resources_nest_compartmental/cm_neuron/setup/common/*.jinja2", "codegeneration/resources_python_standalone/point_neuron/*.jinja2", diff --git a/tests/nest_compartmental_tests/FR_RELU_DC.py b/tests/nest_compartmental_tests/FR_RELU_DC.py new file mode 100644 index 000000000..4683bc025 --- /dev/null +++ b/tests/nest_compartmental_tests/FR_RELU_DC.py @@ -0,0 +1,212 @@ +# +# RELU code with DC stimulation +# +# First version: 25/01/2023 +# Author: Elena Pastorelli, INFN, Rome (IT) +# +# Description: code adapted from MC-Adex-BAC optimizee +# + + +from collections import namedtuple + +import numpy as np +import matplotlib.pyplot as plt +import random +import statistics as stat +import sys +import yaml +import time +import datetime + +from utility_RELU import stimula, stimulaNoisy, spikeInSteps, stimula_soma, FRInLastSilence +from plot_FR import plot_FR, plot_FR1, plot_FR_paper + +class MCADEXBACOptimizee(): + + def __init__(self): + + return + + + def simulate(self): + """ + :param ~l2l.utils.trajectory.Trajectory traj: Trajectory + :return: a single element :obj:`tuple` containing the fitness of the simulation + """ + global nest + import nest + from mpi4py import MPI + comm = MPI.COMM_WORLD + self.comm = comm + self.rank = self.comm.Get_rank() + assert (nest.Rank() == self.rank) + print(self.rank) + self.__init__() + individual = self.create_individual() + print(individual) + 
#self.id = traj.individual.ind_idx + + self.C_m_d = individual['C_m_d'] + self.C_m_s = individual['C_m_s'] + self.delta_T = individual['delta_T'] + self.e_K = individual['e_K'] + self.e_L_d = individual['e_L_d'] + self.e_L_s = individual['e_L_s'] + self.e_Na_Adex = individual['e_Na_Adex'] + self.g_C_d = individual['g_C_d'] + self.g_L_d = individual['g_L_d'] + self.g_L_s = individual['g_L_s'] + self.gbar_Ca = individual['gbar_Ca'] + self.gbar_K_Ca = individual['gbar_K_Ca'] + self.phi = individual['phi'] + self.tau_decay_Ca = individual['tau_decay_Ca'] + self.m_half = individual['m_half'] + self.h_half = individual['h_half'] + self.m_slope = individual['m_slope'] + self.h_slope = individual['h_slope'] + self.tau_m = individual['tau_m'] + self.tau_h = individual['tau_h'] + self.tau_m_K_Ca = individual['tau_m_K_Ca'] + self.Ca_th = individual['Ca_th'] + self.exp_K_Ca = individual['exp_K_Ca'] + self.t_ref = individual['t_ref'] + self.d_BAP = individual['d_BAP'] + self.w_BAP = individual['w_BAP'] + self.V_reset = individual['V_reset'] + self.g_w = individual['g_w'] + + # simulation step (ms). + self.dt = 0.1 + self.local_num_threads = 1 + + plotVars = 0 + silence = 3000 + stimulusStart = 0. + stimulusDuration = 2000. + Is_start = 0. + Is_end = 1000. + Id_start = 0. + Id_end = 1000. 
+ delta_mu = 40 + step = int((Id_end-Id_start)//delta_mu) + sigma = 0 + + SimTime = step * (stimulusDuration + silence) + I = np.arange(Is_start,Is_end,delta_mu) + extent=[Is_start,Is_end,Id_end,Id_start] + + self.neuron_model = 'cm_default' + + self.soma_params = {'C_m': self.C_m_s, # [pF] Soma capacitance + 'g_L': self.g_L_s, # [nS] Soma leak conductance + 'e_L': self.e_L_s, # [mV] Soma reversal potential + 'gbar_Na_Adex': self.g_L_s, # [nS] Adex conductance + 'e_Na_Adex': self.e_Na_Adex, # [mV] Adex threshold + 'delta_T': self.delta_T # [mV] Adex slope factor + } + + self.distal_params = {'C_m': self.C_m_d, # [pF] Distal capacitance + 'g_L': self.g_L_d, # [nS] Distal leak conductance + 'g_C': self.g_C_d, # [nS] Soma-distal coupling conductance + 'e_L': self.e_L_d, # [mV] Distal reversal potential + 'gbar_Ca': self.gbar_Ca, # [nS] Ca maximal conductance + 'gbar_K_Ca': self.gbar_K_Ca, # [nS] K_Ca maximal conductance + 'e_K': self.e_K, # [mV] K reversal potential + 'tau_decay_Ca': self.tau_decay_Ca, # [ms] decay of Ca concentration + 'phi': self.phi, # [-] scale factor + 'm_half': self.m_half, # [mV] m half-value for Ca + 'h_half': self.h_half, # [mV] h half-value for Ca + 'm_slope': self.m_slope, # [-] m slope factor for Ca + 'h_slope': self.h_slope, # [-] h slope factor for Ca + 'tau_m': self.tau_m, # [ms] m tau decay for Ca + 'tau_h': self.tau_h, # [ms] h tau decay dor Ca + 'tau_m_K_Ca': self.tau_m_K_Ca, # [ms] m tau decay for K_Ca + 'Ca_0': self.default_param["Ca_0"],# [mM] Baseline intracellular Ca conc + 'Ca_th': self.Ca_th, # [mM] Threshold Ca conc for Ca channel opening + 'exp_K_Ca': self.exp_K_Ca # [-] Exponential factor in K_Ca current with Hay dyn + } + + + #============================================================================== + # Soma and dist + #============================================================================== + + spikesInStep_somadist = [] + events_somadist = [] + + print("Current Time: ", datetime.datetime.now()) + + n_neu = 
int((Id_end-Id_start)//delta_mu) ## same value as "step" + print('Total number of neurons: ', n_neu) + + Tstart = time.time() + + nest.ResetKernel() + nest.set_verbosity('M_ERROR') + nest.SetKernelStatus({'resolution': self.dt}) + nest.SetKernelStatus({'local_num_threads':self.local_num_threads}) + + #self.create_cm_neuron(n_neu) + self.cm = nest.Create(xxx) + sr = nest.Create('spike_recorder',n_neu) + nest.Connect(self.cm, sr, 'one_to_one') + + mu0 = Is_start + sigma0 = sigma + delta_mu0 = delta_mu + step0 = step + sigma1 = sigma + delta_mu1 = 0. + step1 = step + neuID = 0 + + for index in range(int(Id_start),int(Id_end),delta_mu): + + print("Iteration n. ", index//delta_mu) + + mu1 = index + + ############################################################################### + # create and connect current generators to compartments + stimula(self.cm[neuID],stimulusStart,stimulusDuration,mu0,sigma0,delta_mu0,step0,mu1,sigma1,delta_mu1,step1,silence) + neuID = neuID + 1 + + print("Simulation started") + nest.Simulate(SimTime) + + neuID = 0 + for index in range(int(Id_start),int(Id_end),delta_mu): + + print("Results from iteration n. 
", index//delta_mu) + + events_somadist_values = nest.GetStatus(sr[neuID])[0]['events'] + events_somadist.append(events_somadist_values['times']) + spikesInStep_somadist.append(spikeInSteps(events_somadist_values,stimulusStart,stimulusDuration,step0,silence)) + neuID = neuID + 1 + + #============================================================================== + # End simulation + #============================================================================== + Tend = time.time() + print("Simulation completed") + print('Execution time: ', Tend-Tstart) + + #============================================================================== + # Plot + #============================================================================== + print("Elaborating plot...") + print("firing_grid size is ", len(spikesInStep_somadist)) + fr=spikesInStep_somadist + fig=plot_FR_paper(fr, extent, yInvert=1, maskWhite=1, contour=1, centralAxis=1, Dlevels=10) + + Id_current = np.arange(Id_start,Id_end,delta_mu) + spikesInStep_dist = [spikesInStep_somadist[i][np.argmax(Id_current==0)] for i in range (n_neu)] + label_somadist = 'soma + dist @' + str(Id_end) + + return(spikesInStep_somadist,events_somadist) + +myRun = MCADEXBACOptimizee() +firing_grid,events_grid = myRun.simulate() + +plt.show() diff --git a/tests/nest_compartmental_tests/MC_ISI.py b/tests/nest_compartmental_tests/MC_ISI.py new file mode 100644 index 000000000..31407342b --- /dev/null +++ b/tests/nest_compartmental_tests/MC_ISI.py @@ -0,0 +1,234 @@ +# +# MC_ISI +# +# First version: 25/09/2024 +# Author: Elena Pastorelli, INFN, Rome (IT) +# +# Description: Comparison between ISI of pure somatic DC input in Ca-AdEx vs AdEx +# The AdEx is built with a specific set of parameters agains which the multi-comp had been fitted +# + + + +import nest +import numpy as np +import matplotlib.pyplot as plt +import random +import statistics as stat +import sys +import yaml + + + +""" +0 - Poisson +1 - single exc spike on ALPHA +2 - single exc 
spikes of increasing weight on ALPHA +3 - single inh spike on ALPHA +4 - single inh spike on GABA +5 - single spikes of increasing weight on ALPHA +6 - single spikes of increasing weight on GABA +""" + +action = 1 + + +stimulusStart = 10000 +stimulusDuration = 2000 +stimulusStop = stimulusStart + stimulusDuration +SimTime = stimulusStop + 3000 +countWindow = stimulusDuration + + +I_s = 300 + +aeif_dict = { + "a": 0., + "b": 40., + "t_ref": 0., + "Delta_T": 2., + "C_m": 200., + "g_L": 10., + "E_L": -63., + "V_reset": -65., + "tau_w": 500., + "V_th": -50., + "V_peak": -40., + } + +cm_dict = { + "C_mD": 10.0, + "C_m": 362.5648533496359, + "Ca_0": 0.0001, + "Ca_th": 0.00043, + "V_reset": -62.12885359171539, + #"d_BAP": 2.4771369535227308, + "Delta_T": 2.0, + "E_K": -90.0, + "E_LD": -80.0, + "E_L": -58.656837907086036, + "V_th": -50.0, + "exp_K_Ca": 4.8, + "g_C": 17.55192973190035, + #"g_C": 0.0, + "g_LD": 2.5088334130360064, + "g_L": 6.666182946322264, + "g_Ca": 22.9883727668534, + "g_K": 18.361017565618574, + "h_half_Ca": -21.0, + "h_slope_Ca": -0.5, + "m_half_Ca": -9.0, + "m_slope_Ca": 0.5, + "phi_ca": 2.200252914099994e-08, + "refr_T": 0.0, + "tau_Ca": 129.45363748885939, + "tau_h_Ca": 80.0, + "tau_m_Ca": 15.0, + "tau_K": 1.0, + #"w_BAP": 32.39598141845997, + "tau_w": 500.0, + "a": 0, + "b": 40.0, + "V_peak": -40.0, + } + +w_BAP = 32.39598141845997 +d_BAP = 2.4771369535227308 + +nest.ResetKernel() + +nest.Install('nestmlmodule') + +aeif = nest.Create("aeif_cond_alpha", params=aeif_dict) +cm = nest.Create('aeif_cond_alpha_neuron', params=cm_dict) +#nest.Connect(cm, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': w_BAP, 'delay': d_BAP, 'receptor_type': 1}) + +############################# +# Test for Poisson stimulus # +############################# + +if action == 0: + + SimTime = 10 + stimulusStart = 0.0 + stimulusStop = SimTime + countWindow = stimulusStop-stimulusStart + + # Poisson parameters + spreading_factor = 4 + basic_rate = 600.0 + basic_weight = 
0.6 + weight = basic_weight * spreading_factor + rate = basic_rate / spreading_factor + + cf = 1. + + # Create and connct Poisson generator + pg0 = nest.Create('poisson_generator', 20, params={'rate': rate, 'start': stimulusStart, 'stop': stimulusStop}) + nest.Connect(pg0, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': weight*cf, 'delay': 1., 'receptor_type': 0}) + nest.Connect(pg0, aeif, syn_spec={'synapse_model': 'static_synapse', 'weight': weight, 'delay': 1.}) + +############################# +# Test for spike generator # +############################# + +elif action == 1: + + SimTime = 100 + stimulusStart = 0.0 + stimulusStop = SimTime + countWindow = stimulusStop-stimulusStart + + weight = 300 + cf = 1. + + # Create and connct spike generator + sg0 = nest.Create('spike_generator', 1, {'spike_times': [50]}) + nest.Connect(sg0, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': weight, 'delay': 1., 'receptor_type': 0}) + nest.Connect(sg0, aeif, syn_spec={'synapse_model': 'static_synapse', 'weight': weight, 'delay': 1.}) + +############################# +# Test for DC input # +############################# + +elif action == 2: + + # Create and connect current generators (to soma in cm) + dcgs = nest.Create('dc_generator', {'start': stimulusStart, 'stop': stimulusStop, 'amplitude': I_s}) + nest.Connect(dcgs, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': 1., 'delay': .1, 'receptor_type': 0}) + nest.Connect(dcgs, aeif, syn_spec={'synapse_model': 'static_synapse', 'weight': 1., 'delay': .1}) + + +# create multimeters to record compartment voltages and various state variables +rec_list = ['v_comp0', 'v_comp1', 'w_5','m_Ca_1','h_Ca_1','i_AMPA_9'] +mm_cm = nest.Create('multimeter', 1, {'record_from': ['V_m','V_mD','w'], 'interval': .1}) +mm_aeif = nest.Create('multimeter', 1, {'record_from': ['V_m','w'], 'interval': .1}) +nest.Connect(mm_cm, cm) +nest.Connect(mm_aeif, aeif) + +# create and connect a spike recorder +sr_cm = 
nest.Create('spike_recorder') +sr_aeif = nest.Create('spike_recorder') +nest.Connect(cm, sr_cm) +nest.Connect(aeif, sr_aeif) + +nest.Simulate(SimTime) + +print('I_s current = ', I_s) + +res_cm = nest.GetStatus(mm_cm, 'events')[0] +events_cm = nest.GetStatus(sr_cm)[0]['events'] +res_aeif = nest.GetStatus(mm_aeif, 'events')[0] +events_aeif = nest.GetStatus(sr_aeif)[0]['events'] + +totalSpikes_cm = sum(map(lambda x: x>stimulusStart and xstimulusStart and x. - -import os -import pytest - -import nest - -from pynestml.codegeneration.nest_tools import NESTTools -from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target - -# set to `True` to plot simulation traces -TEST_PLOTS = True -try: - import matplotlib - import matplotlib.pyplot as plt -except BaseException as e: - # always set TEST_PLOTS to False if matplotlib can not be imported - TEST_PLOTS = False - - -class TestCompartmentalConcmech: - @pytest.fixture(scope="module", autouse=True) - def setup(self): - nest.ResetKernel() - nest.SetKernelStatus(dict(resolution=.1)) - - generate_nest_compartmental_target(input_path=os.path.join(os.path.realpath(os.path.dirname(__file__)), "resources", "concmech.nestml"), - suffix="_nestml", - logging_level="DEBUG", - module_name="concmech_mockup_module") - nest.Install("concmech_mockup_module") - - def test_concmech(self): - cm = nest.Create('multichannel_test_model_nestml') - - soma_params = {'C_m': 10.0, 'g_c': 0.0, 'g_L': 1.5, 'e_L': -70.0, 'gbar_Ca_HVA': 1.0, 'gbar_Ca_LVAst': 0.0} - dend_params = {'C_m': 0.1, 'g_c': 0.1, 'g_L': 0.1, 'e_L': -70.0} - - # nest.AddCompartment(cm, 0, -1, soma_params) - cm.compartments = [ - {"parent_idx": -1, "params": soma_params} - # {"parent_idx": 0, "params": dend_params}, - # {"parent_idx": 0, "params": dend_params} - ] - # nest.AddCompartment(cm, 1, 0, dend_params) - # nest.AddCompartment(cm, 2, 0, dend_params) - - # cm.V_th = -50. 
- - cm.receptors = [ - {"comp_idx": 0, "receptor_type": "AMPA"} - # {"comp_idx": 1, "receptor_type": "AMPA"}, - # {"comp_idx": 2, "receptor_type": "AMPA"} - ] - - # syn_idx_GABA = 0 - # syn_idx_AMPA = 1 - # syn_idx_NMDA = 2 - - # sg1 = nest.Create('spike_generator', 1, {'spike_times': [50., 100., 125., 137., 143., 146., 600.]}) - sg1 = nest.Create('spike_generator', 1, {'spike_times': [100., 1000., 1100., 1200., 1300., 1400., 1500., 1600., 1700., 1800., 1900., 2000., 5000.]}) - # sg1 = nest.Create('spike_generator', 1, {'spike_times': [(item*6000) for item in range(1, 20)]}) - # sg2 = nest.Create('spike_generator', 1, {'spike_times': [115., 155., 160., 162., 170., 254., 260., 272., 278.]}) - # sg3 = nest.Create('spike_generator', 1, {'spike_times': [250., 255., 260., 262., 270.]}) - - nest.Connect(sg1, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': 4.0, 'delay': 0.5, 'receptor_type': 0}) - # nest.Connect(sg2, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': .2, 'delay': 0.5, 'receptor_type': 1}) - # nest.Connect(sg3, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': .3, 'delay': 0.5, 'receptor_type': 2}) - - mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'c_Ca0', 'i_tot_Ca_LVAst0', 'i_tot_Ca_HVA0'], 'interval': .1}) - - nest.Connect(mm, cm) - - nest.Simulate(6000.) 
- - res = nest.GetStatus(mm, 'events')[0] - - fig, axs = plt.subplots(5) - - axs[0].plot(res['times'], res['v_comp0'], c='b', label='V_m_0') - axs[1].plot(res['times'], res['i_tot_Ca_LVAst0'], c='r', label='i_Ca_LVAst_0') - axs[1].plot(res['times'], res['i_tot_Ca_HVA0'], c='g', label='i_Ca_HVA_0') - axs[2].plot(res['times'], res['c_Ca0'], c='r', label='c_Ca_0') - - axs[0].set_title('V_m_0') - axs[1].set_title('i_Ca_HVA/LVA_0') - axs[2].set_title('c_Ca_0') - # plt.plot(res['times'], res['v_comp2'], c='g', label='V_m_2') - - axs[0].legend() - axs[1].legend() - axs[2].legend() - - plt.savefig("concmech_test.png") diff --git a/tests/nest_compartmental_tests/resources/adex_test.nestml b/tests/nest_compartmental_tests/resources/adex_test.nestml new file mode 100644 index 000000000..8a631ee63 --- /dev/null +++ b/tests/nest_compartmental_tests/resources/adex_test.nestml @@ -0,0 +1,184 @@ +""" +aeif_cond_alpha - Conductance based exponential integrate-and-fire neuron model +############################################################################### + +Description ++++++++++++ + +aeif_cond_alpha is the adaptive exponential integrate and fire neuron according to Brette and Gerstner (2005), with post-synaptic conductances in the form of a bi-exponential ("alpha") function. + +The membrane potential is given by the following differential equation: + +.. math:: + + C_m \frac{dv_comp}{dt} = + -g_L(v_comp-E_L)+g_L\Delta_T\exp\left(\frac{v_comp-V_{th}}{\Delta_T}\right) - + g_e(t)(v_comp-E_e) \\ + -g_i(t)(v_comp-E_i)-w + I_e + +and + +.. math:: + + \tau_w \frac{dw}{dt} = a(v_comp-E_L) - w + +Note that the membrane potential can diverge to positive infinity due to the exponential term. To avoid numerical instabilities, instead of :math:`v_comp`, the value :math:`\min(v_comp,V_{peak})` is used in the dynamical equations. + + +References +++++++++++ + +.. [1] Brette R and Gerstner W (2005). 
Adaptive exponential + integrate-and-fire model as an effective description of neuronal + activity. Journal of Neurophysiology. 943637-3642 + DOI: https://doi.org/10.1152/jn.00686.2005 + + +See also +++++++++ + +iaf_cond_alpha, aeif_cond_exp +""" +model aeif_cond_alpha_neuron: + + state: + v_comp mV = 0 mV # Membrane potential + w pA = 0 pA # Spike-adaptation current + refr_t ms = 0 ms # Refractory period timer + c_Ca umol = 0.1 umol + m_Ca real = 1. + h_Ca real = 0.0001 + + m_K real = 0. + + I_bp real = 0 # back propagation + + is_refr real = 0. + + w_add real = 0. + + equations: + kernel g_inh = (e / tau_syn_inh) * t * exp(-t / tau_syn_inh) + kernel g_exc = (e / tau_syn_exc) * t * exp(-t / tau_syn_exc) + + # Add inlines to simplify the equation definition of v_comp + # Soma + + inline I_spike pA = g_L * Delta_T * exp((min(v_comp, V_peak) - v_comp) / Delta_T) @mechanism::channel + inline I_syn_exc pA = convolve(g_exc, exc_spikes) * nS * (min(v_comp, V_peak) - E_exc) @mechanism::receptor + inline I_syn_inh pA = convolve(g_inh, inh_spikes) * nS * (min(v_comp, V_peak) - E_inh) @mechanism::receptor + inline external_stim pA = I_stim @mechanism::continuous_input + inline refr real = G_refr * is_refr * (V_reset - v_comp) @mechanism::channel + inline adapt pA = G_adapt * w @mechanism::channel + + w' = (SthA * (min(v_comp, V_peak) - E_L) - w) / tau_w + + inline I_Ca pA = g_Ca * m_Ca * h_Ca * (E_Ca - v_comp) @mechanism::channel + inline I_K pA = g_K * m_K * (E_K - v_comp) @mechanism::channel + + m_Ca' = (m_inf_Ca(v_comp, m_slope_Ca, m_half_Ca) - m_Ca) / tau_m_Ca + h_Ca' = (h_inf_Ca(v_comp, h_slope_Ca, h_half_Ca) - h_Ca) / tau_h_Ca + c_Ca' = phi_ca * I_Ca + (c_Ca - Ca_0) / tau_Ca @mechanism::concentration + m_K' = (m_inf_K(c_Ca, Ca_th, exp_K_Ca) - m_K) / tau_K + + parameters: + # membrane parameters + + refr_T ms = 2 ms # Duration of refractory period + V_reset mV = -60.0 mV # Reset Potential + g_L nS = 0 nS # Leak Conductance + g_Ca nS = 0 nS + g_K nS = 0 nS + E_L mV = 
-70.6 mV # Leak reversal Potential (aka resting potential) + E_Ca mV = 50 mV + E_K mV = -90 mV + + G_adapt real = 0. + + # spike adaptation parameters + + SthA nS = 4 nS # Subthreshold adaptation + b pA = 80.5 pA # Spike-triggered adaptation + Delta_T mV = 2.0 mV # Slope factor + tau_w ms = 144.0 ms # Adaptation time constant + V_peak mV = 0 mV # Spike detection threshold + + # synaptic parameters + + E_exc mV = 0 mV # Excitatory reversal Potential + tau_syn_exc ms = 0.2 ms # Synaptic Time Constant Excitatory Synapse + E_inh mV = -85.0 mV # Inhibitory reversal Potential + tau_syn_inh ms = 2.0 ms # Synaptic Time Constant for Inhibitory Synapse + + # Distal + # synaptic parameters + + phi_ca pA**-1 = 2.2e-08 pA**-1 + Ca_th umol = 0.43 umol + Ca_0 real = 0.1 + tau_Ca ms = 129 ms + + m_slope_Ca real = 0.5 + m_half_Ca mV = -9 mV + tau_m_Ca ms = 15 ms + + h_slope_Ca real = -0.5 + h_half_Ca mV = -21 mV + tau_h_Ca ms = 80 ms + + exp_K_Ca real = 4.8 + tau_K ms = 1.0 ms + + # Constant external input current + I_e pA = 0 pA + + G_refr real = 0. 
+ internals: + # Impulse to add to DG_EXC on spike arrival to evoke unit-amplitude conductance excursion + PSConInit_E nS/ms = nS * e / tau_syn_exc + + # Impulse to add to DG_INH on spike arrival to evoke unit-amplitude conductance excursion + PSConInit_I nS/ms = nS * e / tau_syn_inh + + input: + self_spikes <- spike + + exc_spikes <- excitatory spike + inh_spikes <- inhibitory spike + I_stim pA <- continuous + + output: + spike + + update: + if w_add > 0: + w += b + if refr_t > resolution() / 2: + refr_t -= resolution() + else: + refr_t = 0 ms + is_refr = 0 + + I_bp = 0 + w_add = 0 + + onReceive(self_spikes): + is_refr = 1 + refr_t = refr_T + w_add = 1 + I_bp = 10 + + function m_inf_Ca(v mV, m_slope_Ca_f real, m_half_Ca_f real) real: + m_inf real = 0 + m_inf = 1 / (1 + exp(m_slope_Ca_f * (v - m_half_Ca_f))) + return m_inf + + function h_inf_Ca(v mV, h_slope_Ca_f real, h_half_Ca_f real) real: + h_inf real = 0 + h_inf = 1 / (1 + exp(h_slope_Ca_f * (v - h_half_Ca_f))) + return h_inf + + function m_inf_K(Ca real, Ca_th_f real, exp_K_Ca_f real) real: + m_inf real = 0 + m_inf = 1 / (1 + pow((Ca_th_f / max(Ca, 0.1)), exp_K_Ca_f)) + return m_inf diff --git a/tests/nest_compartmental_tests/resources/cm_default.nestml b/tests/nest_compartmental_tests/resources/cm_default.nestml index 83e4b5e44..932e60ce3 100644 --- a/tests/nest_compartmental_tests/resources/cm_default.nestml +++ b/tests/nest_compartmental_tests/resources/cm_default.nestml @@ -4,15 +4,33 @@ Example compartmental model for NESTML Description +++++++++++ Corresponds to standard compartmental model implemented in NEST. + + +Copyright statement ++++++++++++++++++++ + +This file is part of NEST. + +Copyright (C) 2004 The NEST Initiative + +NEST is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 2 of the License, or +(at your option) any later version. 
+ +NEST is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with NEST. If not, see . """ + model cm_default: state: - # compartmental voltage variable, - # rhs value is irrelevant but the state must exist so that the nestml parser doesn't complain - v_comp real = 0 - ### ion channels ### # initial values state variables sodium channel m_Na real = 0.01696863 @@ -21,6 +39,10 @@ model cm_default: # initial values state variables potassium channel n_K real = 0.00014943 + # compartmental voltage variable, + # rhs value is irrelevant but the state must exist so that the nestml parser doesn't complain + v_comp real = 0 + parameters: ### ion channels ### diff --git a/tests/nest_compartmental_tests/resources/cm_iaf_psc_exp_dend_neuron.nestml b/tests/nest_compartmental_tests/resources/cm_iaf_psc_exp_dend_neuron.nestml new file mode 100644 index 000000000..6514ad8e0 --- /dev/null +++ b/tests/nest_compartmental_tests/resources/cm_iaf_psc_exp_dend_neuron.nestml @@ -0,0 +1,89 @@ +""" +iaf_psc_exp_dend - Leaky integrate-and-fire neuron model with exponential PSCs +######################################################################### + +Description ++++++++++++ + +iaf_psc_exp is an implementation of a leaky integrate-and-fire model +with exponential-kernel postsynaptic currents (PSCs) according to [1]_. +Thus, postsynaptic currents have an infinitely short rise time. + +The threshold crossing is followed by an absolute refractory period (t_ref) +during which the membrane potential is clamped to the resting potential +and spiking is prohibited. + +.. note:: + If tau_m is very close to tau_syn_ex or tau_syn_in, numerical problems + may arise due to singularities in the propagator matrics. 
If this is + the case, replace equal-valued parameters by a single parameter. + + For details, please see ``IAF_neurons_singularity.ipynb`` in + the NEST source code (``docs/model_details``). + + +References +++++++++++ + +.. [1] Tsodyks M, Uziel A, Markram H (2000). Synchrony generation in recurrent + networks with frequency-dependent synapses. The Journal of Neuroscience, + 20,RC50:1-5. URL: https://infoscience.epfl.ch/record/183402 + + +See also +++++++++ + +iaf_cond_exp +""" +model iaf_psc_exp_cm_dend: + + state: + v_comp real = 0 # Membrane potential + refr_t ms = 0 ms # Refractory period timer + + is_refr real = 0.0 + + + equations: + kernel I_kernel_inh = exp(-t/tau_syn_inh) + kernel I_kernel_exc = exp(-t/tau_syn_exc) + + inline leak real = (E_l - v_comp) * C_m / tau_m @mechanism::channel + inline syn_exc real = convolve(I_kernel_exc, exc_spikes) @mechanism::receptor + inline syn_inh real = convolve(I_kernel_inh, inh_spikes) @mechanism::receptor + inline refr real = G_refr * is_refr * (V_reset - v_comp) @mechanism::channel + + parameters: + C_m pF = 250 pF # Capacity of the membrane + tau_m ms = 10 ms # Membrane time constant + tau_syn_inh ms = 2 ms # Time constant of inhibitory synaptic current + tau_syn_exc ms = 2 ms # Time constant of excitatory synaptic current + refr_T ms = 5 ms # Duration of refractory period + E_l mV = -70 mV # Resting potential + V_reset mV = -70 mV # Reset potential of the membrane + V_th mV = -55 mV # Spike threshold potential + + # constant external input current + I_e pA = 0 pA + + G_refr real = 0. 
+ + input: + exc_spikes <- excitatory spike + inh_spikes <- inhibitory spike + I_stim pA <- continuous + + output: + spike + + update: + if refr_t > resolution() / 2: + # neuron is absolute refractory, do not evolve V_m + refr_t -= resolution() + else: + refr_t = 0 ms + is_refr = 0 + + onReceive(self_spikes): + is_refr = 1 + refr_t = refr_T \ No newline at end of file diff --git a/tests/nest_compartmental_tests/resources/concmech.nestml b/tests/nest_compartmental_tests/resources/concmech.nestml index 703c33fbf..486403661 100644 --- a/tests/nest_compartmental_tests/resources/concmech.nestml +++ b/tests/nest_compartmental_tests/resources/concmech.nestml @@ -1,3 +1,24 @@ +""" +Copyright statement ++++++++++++++++++++ + +This file is part of NEST. + +Copyright (C) 2004 The NEST Initiative + +NEST is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 2 of the License, or +(at your option) any later version. + +NEST is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with NEST. If not, see . 
+""" model multichannel_test_model: parameters: @@ -31,7 +52,7 @@ model multichannel_test_model: inf_Ca real = 0.0001 state: - v_comp real = -7500.00000000 + v_comp real = -75.00000000 # state variables Ca_HVA h_Ca_HVA real = 0.69823671 @@ -69,17 +90,17 @@ model multichannel_test_model: h_Ca_LVAst' = ( h_inf_Ca_LVAst(v_comp) - h_Ca_LVAst ) / (tau_h_Ca_LVAst(v_comp)*1s) # equations NaTa_t - #inline NaTa_t real = gbar_NaTa_t * (h_NaTa_t*m_NaTa_t**3) * (e_NaTa_t - v_comp) @mechanism::channel - #m_NaTa_t' = ( m_inf_NaTa_t(v_comp) - m_NaTa_t ) / (tau_m_NaTa_t(v_comp)*1s) - #h_NaTa_t' = ( h_inf_NaTa_t(v_comp) - h_NaTa_t ) / (tau_h_NaTa_t(v_comp)*1s) + inline NaTa_t real = gbar_NaTa_t * (h_NaTa_t*m_NaTa_t**3) * (e_NaTa_t - v_comp) @mechanism::channel + m_NaTa_t' = ( m_inf_NaTa_t(v_comp) - m_NaTa_t ) / (tau_m_NaTa_t(v_comp)*1s) + h_NaTa_t' = ( h_inf_NaTa_t(v_comp) - h_NaTa_t ) / (tau_h_NaTa_t(v_comp)*1s) # equations SKv3_1 #inline SKv3_1 real = gbar_SKv3_1 * (z_SKv3_1) * (e_SKv3_1 - v_comp) @mechanism::channel #z_SKv3_1' = ( z_inf_SKv3_1(v_comp) - z_SKv3_1 ) / (tau_z_SKv3_1(v_comp)*1s) # equations SK_E2 - #inline SK_E2 real = gbar_SK_E2 * (z_SK_E2) * (e_SK_E2 - v_comp) @mechanism::channel - #z_SK_E2' = ( z_inf_SK_E2(c_Ca) - z_SK_E2) / 1.0s + inline SK_E2 real = gbar_SK_E2 * (z_SK_E2) * (e_SK_E2 - v_comp) @mechanism::channel + z_SK_E2' = ( z_inf_SK_E2(c_Ca) - z_SK_E2) / 1.0s # equations Ca concentration mechanism c_Ca' = (inf_Ca - c_Ca) / (tau_Ca*1s) + (gamma_Ca * (Ca_HVA + Ca_LVAst)) / 1s @mechanism::concentration diff --git a/tests/nest_compartmental_tests/resources/continuous_test.nestml b/tests/nest_compartmental_tests/resources/continuous_test.nestml new file mode 100644 index 000000000..b3aa4d5ec --- /dev/null +++ b/tests/nest_compartmental_tests/resources/continuous_test.nestml @@ -0,0 +1,44 @@ +""" +Copyright statement ++++++++++++++++++++ + +This file is part of NEST. 
+ +Copyright (C) 2004 The NEST Initiative + +NEST is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 2 of the License, or +(at your option) any later version. + +NEST is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with NEST. If not, see . +""" + +model continuous_test_model: + state: + v_comp real = 0 + + parameters: + e_AMPA real = 0 mV + tau_r_AMPA real = 0.2 ms + tau_d_AMPA real = 3.0 ms + + equations: + inline con_in real = (I_stim*10) @mechanism::continuous_input + + kernel g_AMPA = g_norm_AMPA * ( -exp(-t / tau_r_AMPA) + exp(-t / tau_d_AMPA) ) + inline AMPA real = convolve(g_AMPA, spikes_AMPA) * (e_AMPA - v_comp) @mechanism::receptor + + internals: + tp_AMPA real = (tau_r_AMPA * tau_d_AMPA) / (tau_d_AMPA - tau_r_AMPA) * ln( tau_d_AMPA / tau_r_AMPA ) + g_norm_AMPA real = 1. 
/ ( -exp( -tp_AMPA / tau_r_AMPA ) + exp( -tp_AMPA / tau_d_AMPA ) ) + + input: + I_stim real <- continuous + spikes_AMPA <- spike diff --git a/tests/nest_compartmental_tests/resources/invalid/CoCoCmVcompExists.nestml b/tests/nest_compartmental_tests/resources/invalid/CoCoCmVcompExists.nestml index 518e6b4a2..8557c3649 100644 --- a/tests/nest_compartmental_tests/resources/invalid/CoCoCmVcompExists.nestml +++ b/tests/nest_compartmental_tests/resources/invalid/CoCoCmVcompExists.nestml @@ -38,7 +38,8 @@ model cm_model_eight_invalid: state: # compartmental voltage variable, # rhs value is irrelevant but the state must exist so that the nestml parser doesn't complain - m_Na real = 0.0 + m_Na real = 0.0 + h_Na real = 0.0 #sodium function m_inf_Na(v_comp real) real: @@ -54,7 +55,8 @@ model cm_model_eight_invalid: return 0.3115264797507788/((-0.0091000000000000004*v_comp - 0.68261830000000012)/(1.0 - 3277527.8765015295*exp(0.20000000000000001*v_comp)) + (0.024*v_comp + 1.200312)/(1.0 - 4.5282043263959816e-5*exp(-0.20000000000000001*v_comp))) equations: - inline Na real = m_Na**3 * h_Na**1 + inline Na real = gbar_Na * m_Na**3 * h_Na * (e_Na - v_comp) @mechanism::channel parameters: - foo real = 1. + gbar_Na real = 0. + e_Na real = 50. diff --git a/tests/nest_compartmental_tests/resources/stdp_synapse.nestml b/tests/nest_compartmental_tests/resources/stdp_synapse.nestml new file mode 100644 index 000000000..91aecb726 --- /dev/null +++ b/tests/nest_compartmental_tests/resources/stdp_synapse.nestml @@ -0,0 +1,79 @@ +""" +stdp - Synapse model for spike-timing dependent plasticity +######################################################### + +Description ++++++++++++ + +stdp_synapse is a synapse with spike-timing dependent plasticity (as defined in [1]_). Here the weight dependence exponent can be set separately for potentiation and depression. 
Examples: + +=================== ==== ============================= +Multiplicative STDP [2]_ mu_plus = mu_minus = 1 +Additive STDP [3]_ mu_plus = mu_minus = 0 +Guetig STDP [1]_ mu_plus, mu_minus in [0, 1] +Van Rossum STDP [4]_ mu_plus = 0 mu_minus = 1 +=================== ==== ============================= + + +References +++++++++++ + +.. [1] Guetig et al. (2003) Learning Input Correlations through Nonlinear + Temporally Asymmetric Hebbian Plasticity. Journal of Neuroscience + +.. [2] Rubin, J., Lee, D. and Sompolinsky, H. (2001). Equilibrium + properties of temporally asymmetric Hebbian plasticity, PRL + 86,364-367 + +.. [3] Song, S., Miller, K. D. and Abbott, L. F. (2000). Competitive + Hebbian learning through spike-timing-dependent synaptic + plasticity,Nature Neuroscience 3:9,919--926 + +.. [4] van Rossum, M. C. W., Bi, G-Q and Turrigiano, G. G. (2000). + Stable Hebbian learning from spike timing-dependent + plasticity, Journal of Neuroscience, 20:23,8812--8821 +""" +model stdp_synapse: + state: + w real = 1. # Synaptic weight + pre_trace real = 0. + post_trace real = 0. + + parameters: + d ms = 1 ms # Synaptic transmission delay + lambda real = .01 + tau_tr_pre ms = 20 ms + tau_tr_post ms = 20 ms + alpha real = 1 + mu_plus real = 1 + mu_minus real = 1 + Wmax real = 100. + Wmin real = 0. + + equations: + pre_trace' = -pre_trace / tau_tr_pre + post_trace' = -post_trace*pre_trace / tau_tr_post + + input: + pre_spikes <- spike + post_spikes <- spike + + output: + spike + + onReceive(post_spikes): + post_trace += 1 + + # potentiate synapse + w_ real = Wmax * ( w / Wmax + (lambda * ( 1. 
- ( w / Wmax ) )**mu_plus * pre_trace )) + w = min(Wmax, w_) + + onReceive(pre_spikes): + pre_trace += 1 + + # depress synapse + w_ real = Wmax * ( w / Wmax - ( alpha * lambda * ( w / Wmax )**mu_minus * post_trace )) + w = max(Wmin, w_) + + # deliver spike to postsynaptic partner + emit_spike(w, d) diff --git a/tests/nest_compartmental_tests/resources/third_factor_stdp_synapse.nestml b/tests/nest_compartmental_tests/resources/third_factor_stdp_synapse.nestml new file mode 100644 index 000000000..2aff80a38 --- /dev/null +++ b/tests/nest_compartmental_tests/resources/third_factor_stdp_synapse.nestml @@ -0,0 +1,93 @@ +""" +third_factor_stdp_synapse - Synapse model for spike-timing dependent plasticity with postsynaptic third-factor modulation +######################################################################################################################### + +Description ++++++++++++ + +third_factor_stdp_synapse is a synapse with spike time dependent plasticity (as defined in [1]). Here the weight dependence exponent can be set separately for potentiation and depression. Examples:: + +Multiplicative STDP [2] mu_plus = mu_minus = 1 +Additive STDP [3] mu_plus = mu_minus = 0 +Guetig STDP [1] mu_plus, mu_minus in [0, 1] +Van Rossum STDP [4] mu_plus = 0 mu_minus = 1 + +The weight changes are modulated by a "third factor", in this case the postsynaptic dendritic current ``I_post_dend``. + +``I_post_dend`` "gates" the weight update, so that if the current is 0, the weight is constant, whereas for a current of 1 pA, the weight change is maximal. + +Do not use values of ``I_post_dend`` larger than 1 pA! + +References +++++++++++ + +[1] Guetig et al. (2003) Learning Input Correlations through Nonlinear + Temporally Asymmetric Hebbian Plasticity. Journal of Neuroscience + +[2] Rubin, J., Lee, D. and Sompolinsky, H. (2001). Equilibrium + properties of temporally asymmetric Hebbian plasticity, PRL + 86,364-367 + +[3] Song, S., Miller, K. D. and Abbott, L. F. (2000). 
Competitive + Hebbian learning through spike-timing-dependent synaptic + plasticity,Nature Neuroscience 3:9,919--926 + +[4] van Rossum, M. C. W., Bi, G-Q and Turrigiano, G. G. (2000). + Stable Hebbian learning from spike timing-dependent + plasticity, Journal of Neuroscience, 20:23,8812--8821 +""" +model third_factor_stdp_synapse: + state: + w real = 1. # Synaptic weight + I_post_dend pA = 0 pA + AMPA pA = 0 pA + Ca_HVA pA = 0 pA + Ca_LVAst pA = 0 pA + NaTa_t pA = 0 pA + SK_E2 pA = 0 pA + + parameters: + d ms = 1 ms # Synaptic transmission delay + lambda real = .01 + tau_tr_pre ms = 20 ms + tau_tr_post ms = 20 ms + alpha real = 1. + mu_plus real = 1. + mu_minus real = 1. + Wmax real = 100. + Wmin real = 0. + + equations: + kernel pre_trace_kernel = exp(-t / tau_tr_pre) + inline pre_trace real = convolve(pre_trace_kernel, pre_spikes) + + # all-to-all trace of postsynaptic neuron + kernel post_trace_kernel = exp(-t / tau_tr_post) + inline post_trace real = convolve(post_trace_kernel, post_spikes) + + input: + pre_spikes <- spike + post_spikes <- spike + + output: + spike + + onReceive(post_spikes): + # potentiate synapse + w_ real = Wmax * ( w / Wmax + (lambda * ( 1. 
- ( w / Wmax ) )**mu_plus * pre_trace )) + if I_post_dend <= 1 pA: + w_ = (I_post_dend / pA) * w_ + (1 - I_post_dend / pA) * w # "gating" of the weight update + w = min(Wmax, w_) + + onReceive(pre_spikes): + # depress synapse + w_ real = Wmax * ( w / Wmax - ( alpha * lambda * ( w / Wmax )**mu_minus * post_trace )) + if I_post_dend <= 1 pA: + w_ = (I_post_dend / pA) * w_ + (1 - I_post_dend / pA) * w # "gating" of the weight update + w = max(Wmin, w_) + + # deliver spike to postsynaptic partner + emit_spike(w, d) + + update: + I_post_dend = AMPA + Ca_HVA + Ca_LVAst + NaTa_t + SK_E2 diff --git a/tests/nest_compartmental_tests/run_cm_tests.py b/tests/nest_compartmental_tests/run_cm_tests.py new file mode 100644 index 000000000..e947d3a05 --- /dev/null +++ b/tests/nest_compartmental_tests/run_cm_tests.py @@ -0,0 +1,40 @@ +import subprocess +import os +import matplotlib.pyplot as plt + + +def run_tests(): + # Enable interactive mode for matplotlib + plt.ion() + + # Specify the directory containing the tests + test_directory = "./" + + # Check if the directory exists + if not os.path.exists(test_directory): + print(f"Error: The directory '{test_directory}' does not exist.") + return + + # Run pytest in the specified directory + try: + # Run pytest in the specified directory with live output + process = subprocess.Popen( + ["pytest", test_directory], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True + ) + + # Print output line by line as it happens + for line in iter(process.stdout.readline, ""): + print(line, end="") # Output each line immediately + + process.stdout.close() # Close stdout once done + process.wait() # Wait for the process to finish + + except Exception as e: + print(f"An error occurred while running the tests: {e}") + + +if __name__ == "__main__": + run_tests() diff --git a/tests/nest_compartmental_tests/test__adex.py b/tests/nest_compartmental_tests/test__adex.py new file mode 100644 index 000000000..f8f27c950 --- /dev/null +++ 
b/tests/nest_compartmental_tests/test__adex.py @@ -0,0 +1,266 @@ +# -*- coding: utf-8 -*- +# +# test__continuous_input.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +import os + +import pytest + +import nest + +from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target + +# set to `True` to plot simulation traces +TEST_PLOTS = True +try: + import matplotlib + import matplotlib.pyplot as plt +except BaseException as e: + # always set TEST_PLOTS to False if matplotlib can not be imported + TEST_PLOTS = False + + +class TestContinuousInput: + @pytest.fixture(scope="module", autouse=True) + def setup(self): + tests_path = os.path.realpath(os.path.dirname(__file__)) + input_path = os.path.join( + tests_path, + "resources", + "adex_test.nestml" + ) + target_path = os.path.join( + tests_path, + "target/" + ) + + if not os.path.exists(target_path): + os.makedirs(target_path) + + print( + f"Compiled nestml model 'cm_main_cm_default_nestml' not found, installing in:" + f" {target_path}" + ) + + nest.ResetKernel() + nest.SetKernelStatus(dict(resolution=.1)) + + if True: + generate_nest_compartmental_target( + input_path=input_path, + target_path=target_path, + module_name="aeif_cond_alpha_neuron_module", + suffix="_nestml", + logging_level="DEBUG" + ) + + nest.Install("aeif_cond_alpha_neuron_module.so") + + def 
test_continuous_input(self): + """We test the continuous input mechanism by just comparing the input current at a certain critical point in + time to a previously achieved value at this point""" + I_s = 300 + + aeif_dict = { + "a": 0., + "b": 40., + "t_ref": 0., + "Delta_T": 2., + "C_m": 200., + "g_L": 10., + "E_L": -63., + "V_reset": -65., + "tau_w": 500., + "V_th": -50., + "V_peak": -40., + } + + aeif = nest.Create("aeif_cond_alpha", params=aeif_dict) + + cm = nest.Create('aeif_cond_alpha_neuron_nestml') + + soma_params = { + "C_m": 362.5648533496359, + "Ca_0": 0.0001, + "Ca_th": 0.00043, + "V_reset": -62.12885359171539, + "Delta_T": 2.0, + "E_K": -90.0, + "E_L": -58.656837907086036, + #"V_th": -50.0, + "exp_K_Ca": 4.8, + "g_C": 17.55192973190035, + "g_L": 6.666182946322264, + "g_Ca": 22.9883727668534, + "g_K": 18.361017565618574, + "h_half_Ca": -21.0, + "h_slope_Ca": -0.5, + "m_half_Ca": -9.0, + "m_slope_Ca": 0.5, + "phi_ca": 2.200252914099994e-08, + #"refr_T": 0.0, + "tau_Ca": 129.45363748885939, + "tau_h_Ca": 80.0, + "tau_m_Ca": 15.0, + "tau_K": 1.0, + "tau_w": 500.0, + "SthA": 0, + "b": 40.0, + "V_peak": -40.0, + "G_refr": 1000. + } + + dendritic_params ={ + "C_m": 10.0, + "E_L": -80.0, + "g_L": 2.5088334130360064, + "g_Ca": 22.9883727668534, + "g_K": 18.361017565618574, + } + + cm.compartments = [ + {"parent_idx": -1, "params": soma_params}, + {"parent_idx": 0, "params": dendritic_params} + ] + + cm.receptors = [ + {"comp_idx": 0, "receptor_type": "I_syn_exc"} + ] + + SimTime = 10 + stimulusStart = 0.0 + stimulusStop = SimTime + countWindow = stimulusStop - stimulusStart + + # Poisson parameters + spreading_factor = 4 + basic_rate = 600.0 + basic_weight = 0.6 + weight = basic_weight * spreading_factor + rate = basic_rate / spreading_factor + + cf = 1. 
+ + # Create and connect Poisson generator + pg0 = nest.Create('poisson_generator', 20, params={'rate': rate, 'start': stimulusStart, 'stop': stimulusStop}) + nest.Connect(pg0, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': weight * cf, 'delay': 1., + 'receptor_type': 0}) + nest.Connect(pg0, aeif, syn_spec={'synapse_model': 'static_synapse', 'weight': weight, 'delay': 1.}) + + # create multimeters to record compartment voltages and various state variables + rec_list = [ + 'v_comp0', 'w0', 'i_tot_I_spike0', 'i_tot_I_syn_exc0', 'i_tot_refr0', 'i_tot_adapt0', 'i_tot_I_Ca0', 'i_tot_I_K0', 'c_Ca0', + ] + mm_cm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'v_comp1', 'w0', 'i_tot_I_spike0', 'i_tot_I_syn_exc0', 'i_tot_refr0', 'i_tot_adapt0', 'i_tot_I_Ca0', 'i_tot_I_K0', 'c_Ca0'], 'interval': .1}) + mm_aeif = nest.Create('multimeter', 1, {'record_from': ['V_m', 'w'], 'interval': .1}) + nest.Connect(mm_cm, cm) + nest.Connect(mm_aeif, aeif) + + # create and connect a spike recorder + sr_cm = nest.Create('spike_recorder') + sr_aeif = nest.Create('spike_recorder') + nest.Connect(cm, sr_cm) + nest.Connect(aeif, sr_aeif) + + nest.Simulate(SimTime) + + print('I_s current = ', I_s) + + res_cm = nest.GetStatus(mm_cm, 'events')[0] + events_cm = nest.GetStatus(sr_cm)[0]['events'] + res_aeif = nest.GetStatus(mm_aeif, 'events')[0] + events_aeif = nest.GetStatus(sr_aeif)[0]['events'] + + totalSpikes_cm = sum(map(lambda x: x > stimulusStart and x < stimulusStop, events_cm['times'])) + totalSpikes_aeif = sum(map(lambda x: x > stimulusStart and x < stimulusStop, events_aeif['times'])) + print("Total spikes multiComp = ", totalSpikes_cm) + print("Total spikes adex = ", totalSpikes_aeif) + print("FR multiComp = ", totalSpikes_cm * 1000 / countWindow) + print("FR adex = ", totalSpikes_aeif * 1000 / countWindow) + + print("Spike times multiComp:\n") + print(events_cm['times']) + print("Spike times adex:\n") + print(events_aeif['times']) + + stdtest = True + + if
stdtest: + plt.figure('ISI @ Is = ' + str(I_s)) + ############################################################################### + plt.subplot(411) + plt.plot(res_aeif['times'], res_aeif['V_m'], c='r', label='v_m adex') + plt.plot(res_cm['times'], res_cm['v_comp0'], c='b', label='v_m soma cm') + plt.plot(res_cm['times'], res_cm['v_comp1'], c='g', label='v_m dist cm') + plt.legend() + plt.xlim(0, SimTime) + plt.ylabel('Vm [mV]') + plt.title('MultiComp (blue) and adex (red) voltage') + + plt.subplot(412) + # plt.plot(res_cm['times'], res_cm['m_Ca_1'], c='b', ls='--', lw=2., label='m') + # plt.plot(res_cm['times'], res_cm['h_Ca_1'], c='r', ls='--', lw=2., label='h') + # plt.plot(res_cm['times'], res_cm['m_Ca_1']*res_cm['h_Ca_1'], c='k', ls='--', lw=2., label='g') + plt.legend() + plt.xlim(0, SimTime) + plt.ylabel('Ca') + plt.title('Distal Ca activation') + + plt.subplot(413) + plt.plot(res_cm['times'], res_cm['w0'], c='b', ls='--', lw=2., label='W cm') + plt.plot(res_aeif['times'], res_aeif['w'], c='r', ls='--', lw=2., label='W adex') + plt.legend() + plt.xlim(0, SimTime) + plt.ylabel('W') + plt.title('Adaptation') + + plt.subplot(414) + events_cm = nest.GetStatus(sr_cm)[0]['events'] + plt.eventplot(events_cm['times'], linelengths=0.2, color='b') + events_aeif = nest.GetStatus(sr_aeif)[0]['events'] + plt.eventplot(events_aeif['times'], linelengths=0.2, color='r') + plt.xlim(0, SimTime) + plt.ylabel('Spikes') + plt.title('Raster - cm (blue) VS adex (red)') + plt.xlabel('Time [ms]') + + #plt.show() + #else: + fig, axs = plt.subplots(7) + + axs[0].plot(res_cm['times'], res_cm['i_tot_I_spike0'], c='b', label='I_spike0') + axs[1].plot(res_cm['times'], res_cm['i_tot_I_syn_exc0'], c='b', label='I_syn_exc0') + #plt.plot(res_cm['times'], res_cm['i_tot_I_syn_inh0'], c='b', label='3') + #plt.plot(res_cm['times'], res_cm['i_tot_external_stim0'], c='b', label='4') + axs[2].plot(res_cm['times'], res_cm['i_tot_refr0'], c='b', label='refr0') + axs[3].plot(res_cm['times'], 
res_cm['i_tot_adapt0'], c='b', label='adapt0') + axs[4].plot(res_cm['times'], res_cm['i_tot_I_Ca0'], c='b', label='I_Ca0') + axs[5].plot(res_cm['times'], res_cm['i_tot_I_K0'], c='b', label='I_K0') + axs[6].plot(res_cm['times'], res_cm['c_Ca0'], c='b', label='c_Ca0') + + axs[0].legend() + axs[1].legend() + axs[2].legend() + axs[3].legend() + axs[4].legend() + axs[5].legend() + axs[6].legend() + + plt.show() diff --git a/tests/nest_compartmental_tests/test__cm_iaf_psc_exp_dend_neuron.py b/tests/nest_compartmental_tests/test__cm_iaf_psc_exp_dend_neuron.py new file mode 100644 index 000000000..aa30378ed --- /dev/null +++ b/tests/nest_compartmental_tests/test__cm_iaf_psc_exp_dend_neuron.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# +# test__cm_iaf_psc_exp_dend_neuron.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see .
+ +import os + +import pytest + +import nest + +from pynestml.codegeneration.nest_tools import NESTTools +from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target + +# set to `True` to plot simulation traces +TEST_PLOTS = True +try: + import matplotlib + import matplotlib.pyplot as plt +except BaseException as e: + # always set TEST_PLOTS to False if matplotlib can not be imported + TEST_PLOTS = False + + +class TestCompartmentalIAF: + @pytest.fixture(scope="module", autouse=True) + def setup(self): + tests_path = os.path.realpath(os.path.dirname(__file__)) + input_path = os.path.join( + tests_path, + "resources/cm_iaf_psc_exp_dend_neuron.nestml" + ) + target_path = os.path.join( + tests_path, + "target/" + ) + + if not os.path.exists(target_path): + os.makedirs(target_path) + + print( + f"Compiled nestml model 'cm_main_cm_default_nestml' not found, installing in:" + f" {target_path}" + ) + + nest.ResetKernel() + nest.SetKernelStatus(dict(resolution=.1)) + + if True: + generate_nest_compartmental_target( + input_path=input_path, + target_path=target_path, + module_name="iaf_psc_exp_dend_neuron_compartmental_module", + suffix="_nestml", + logging_level="DEBUG" + ) + + nest.Install("iaf_psc_exp_dend_neuron_compartmental_module.so") + + def test_iaf(self): + """We test the concentration mechanism by comparing the concentration value at a certain critical point in + time to a previously achieved value at this point""" + cm = nest.Create('iaf_psc_exp_cm_dend_nestml') + + params = {"G_refr": 1000.} + + cm.compartments = [ + {"parent_idx": -1, "params": params} + ] + + cm.receptors = [ + {"comp_idx": 0, "receptor_type": "syn_exc"} + ] + + sg1 = nest.Create('spike_generator', 1, {'spike_times': [1., 2., 3., 4.]}) + + nest.Connect(sg1, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': 1000.0, 'delay': 0.5, 'receptor_type': 0}) + + mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'i_tot_leak0', 'i_tot_refr0'], 'interval': .1}) 
+ + nest.Connect(mm, cm) + + nest.Simulate(10.) + + res = nest.GetStatus(mm, 'events')[0] + + step_time_delta = res['times'][1] - res['times'][0] + data_array_index = int(200 / step_time_delta) + + fig, axs = plt.subplots(3) + + axs[0].plot(res['times'], res['v_comp0'], c='r', label='V_m_0') + axs[1].plot(res['times'], res['i_tot_leak0'], c='y', label='leak0') + axs[2].plot(res['times'], res['i_tot_refr0'], c='b', label='refr0') + + axs[0].set_title('V_m_0') + axs[1].set_title('leak0') + axs[2].set_title('refr0') + + axs[0].legend() + axs[1].legend() + axs[2].legend() + + plt.show() + plt.savefig("cm_iaf_test.png") + + #assert res['c_Ca0'][data_array_index] == expected_conc, ("the concentration (left) is not as expected (right). (" + str(res['c_Ca0'][data_array_index]) + "!=" + str(expected_conc) + ")") diff --git a/tests/nest_compartmental_tests/cocos_test.py b/tests/nest_compartmental_tests/test__cocos.py similarity index 66% rename from tests/nest_compartmental_tests/cocos_test.py rename to tests/nest_compartmental_tests/test__cocos.py index b3355d8e3..dc4daa28c 100644 --- a/tests/nest_compartmental_tests/cocos_test.py +++ b/tests/nest_compartmental_tests/test__cocos.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# cocos_test.py +# test__cocos.py # # This file is part of NEST. 
# @@ -22,7 +22,8 @@ from __future__ import print_function import os -import unittest +import pytest + from pynestml.frontend.frontend_configuration import FrontendConfiguration from pynestml.utils.ast_source_location import ASTSourceLocation @@ -35,23 +36,25 @@ from pynestml.utils.model_parser import ModelParser -class CoCosTest(unittest.TestCase): +@pytest.fixture +def setUp(): + Logger.init_logger(LoggingLevel.INFO) + SymbolTable.initialize_symbol_table( + ASTSourceLocation( + start_line=0, + start_column=0, + end_line=0, + end_column=0)) + PredefinedUnits.register_units() + PredefinedTypes.register_types() + PredefinedVariables.register_variables() + PredefinedFunctions.register_functions() + FrontendConfiguration.target_platform = "NEST_COMPARTMENTAL" + - def setUp(self): - Logger.init_logger(LoggingLevel.INFO) - SymbolTable.initialize_symbol_table( - ASTSourceLocation( - start_line=0, - start_column=0, - end_line=0, - end_column=0)) - PredefinedUnits.register_units() - PredefinedTypes.register_types() - PredefinedVariables.register_variables() - PredefinedFunctions.register_functions() - FrontendConfiguration.target_platform = "NEST_COMPARTMENTAL" +class TestCoCos: - def test_invalid_cm_variables_declared(self): + def test_invalid_cm_variables_declared(self, setUp): model = ModelParser.parse_file( os.path.join( os.path.realpath( @@ -59,10 +62,10 @@ def test_invalid_cm_variables_declared(self): os.path.dirname(__file__), 'resources', 'invalid')), 'CoCoCmVariablesDeclared.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)), 5) + assert len(Logger.get_all_messages_of_level_and_or_node( + model.get_model_list()[0], LoggingLevel.ERROR)) == 5 - def test_valid_cm_variables_declared(self): + def test_valid_cm_variables_declared(self, setUp): Logger.set_logging_level(LoggingLevel.INFO) model = ModelParser.parse_file( os.path.join( @@ -71,12 +74,12 @@ def test_valid_cm_variables_declared(self): 
os.path.dirname(__file__), 'resources', 'valid')), 'CoCoCmVariablesDeclared.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)), 0) + assert len(Logger.get_all_messages_of_level_and_or_node( + model.get_model_list()[0], LoggingLevel.ERROR)) == 0 # it is currently not enforced for the non-cm parameter block, but cm # needs that - def test_invalid_cm_variable_has_rhs(self): + def test_invalid_cm_variable_has_rhs(self, setUp): model = ModelParser.parse_file( os.path.join( os.path.realpath( @@ -84,10 +87,10 @@ def test_invalid_cm_variable_has_rhs(self): os.path.dirname(__file__), 'resources', 'invalid')), 'CoCoCmVariableHasRhs.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)), 2) + assert len(Logger.get_all_messages_of_level_and_or_node( + model.get_model_list()[0], LoggingLevel.ERROR)) == 2 - def test_valid_cm_variable_has_rhs(self): + def test_valid_cm_variable_has_rhs(self, setUp): Logger.set_logging_level(LoggingLevel.INFO) model = ModelParser.parse_file( os.path.join( @@ -96,12 +99,12 @@ def test_valid_cm_variable_has_rhs(self): os.path.dirname(__file__), 'resources', 'valid')), 'CoCoCmVariableHasRhs.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)), 0) + assert len(Logger.get_all_messages_of_level_and_or_node( + model.get_model_list()[0], LoggingLevel.ERROR)) == 0 # it is currently not enforced for the non-cm parameter block, but cm # needs that - def test_invalid_cm_v_comp_exists(self): + def test_invalid_cm_v_comp_exists(self, setUp): model = ModelParser.parse_file( os.path.join( os.path.realpath( @@ -109,10 +112,10 @@ def test_invalid_cm_v_comp_exists(self): os.path.dirname(__file__), 'resources', 'invalid')), 'CoCoCmVcompExists.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node( - 
model.get_model_list()[0], LoggingLevel.ERROR)), 4) + assert len(Logger.get_all_messages_of_level_and_or_node( + model.get_model_list()[0], LoggingLevel.ERROR)) == 4 - def test_valid_cm_v_comp_exists(self): + def test_valid_cm_v_comp_exists(self, setUp): Logger.set_logging_level(LoggingLevel.INFO) model = ModelParser.parse_file( os.path.join( @@ -121,5 +124,5 @@ def test_valid_cm_v_comp_exists(self): os.path.dirname(__file__), 'resources', 'valid')), 'CoCoCmVcompExists.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)), 0) + assert len(Logger.get_all_messages_of_level_and_or_node( + model.get_model_list()[0], LoggingLevel.ERROR)) == 0 diff --git a/tests/nest_compartmental_tests/compartmental_model_test.py b/tests/nest_compartmental_tests/test__compartmental_model.py similarity index 96% rename from tests/nest_compartmental_tests/compartmental_model_test.py rename to tests/nest_compartmental_tests/test__compartmental_model.py index 9af75bb43..95e9ff9b7 100644 --- a/tests/nest_compartmental_tests/compartmental_model_test.py +++ b/tests/nest_compartmental_tests/test__compartmental_model.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# compartmental_model_test.py +# test__compartmental_model.py # # This file is part of NEST. 
# @@ -23,7 +23,6 @@ import os import copy import pytest -import unittest import nest @@ -76,7 +75,7 @@ } -class CMTest(unittest.TestCase): +class TestCM(): def reset_nest(self): nest.ResetKernel() @@ -103,10 +102,10 @@ def install_nestml_model(self): generate_nest_compartmental_target( input_path=input_path, - target_path="/tmp/nestml-component/", + target_path=target_path, module_name="cm_defaultmodule", suffix="_nestml", - logging_level="DEBUG" + logging_level="ERROR" ) def get_model(self, reinstall_flag=True): @@ -262,6 +261,8 @@ def run_model(self): @pytest.mark.skipif(NESTTools.detect_nest_version().startswith("v2"), reason="This test does not support NEST 2") def test_compartmental_model(self): + """We numerically compare the output of the standard nest compartmental model to the equivalent nestml + compartmental model""" self.nestml_flag = False recordables_nest = self.get_rec_list() res_act_nest, res_pas_nest = self.run_model() @@ -528,38 +529,35 @@ def test_compartmental_model(self): for var_nest, var_nestml in zip( recordables_nest[:8], recordables_nestml[:8]): if var_nest == "v_comp0": - atol = 0.51 + atol = 1.0 elif var_nest == "v_comp1": - atol = 0.15 + atol = 0.3 else: - atol = 0.01 - self.assertTrue(np.allclose( + atol = 0.02 + assert (np.allclose( res_act_nest[var_nest], res_act_nestml[var_nestml], atol=atol )) for var_nest, var_nestml in zip( recordables_nest[:8], recordables_nestml[:8]): - if var_nest == "v_comp0": - atol = 0.51 - elif var_nest == "v_comp1": - atol = 0.15 - else: - atol = 0.01 - self.assertTrue(np.allclose( - res_pas_nest[var_nest], res_pas_nestml[var_nestml], atol=atol - )) + if not var_nest in ["h_Na_1", "m_Na_1", "n_K_1"]: + if var_nest == "v_comp0": + atol = 1.0 + elif var_nest == "v_comp1": + atol = 0.3 + else: + atol = 0.02 + assert (np.allclose( + res_pas_nest[var_nest], res_pas_nestml[var_nestml], atol=atol + )) # check if synaptic conductances are equal - self.assertTrue( + assert ( np.allclose( 
res_act_nest['g_r_AN_AMPA_1'] + res_act_nest['g_d_AN_AMPA_1'], res_act_nestml['g_AN_AMPA1'], 5e-3)) - self.assertTrue( + assert ( np.allclose( res_act_nest['g_r_AN_NMDA_1'] + res_act_nest['g_d_AN_NMDA_1'], res_act_nestml['g_AN_NMDA1'], 5e-3)) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/nest_compartmental_tests/test__compartmental_stdp.py b/tests/nest_compartmental_tests/test__compartmental_stdp.py new file mode 100644 index 000000000..0941fb023 --- /dev/null +++ b/tests/nest_compartmental_tests/test__compartmental_stdp.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# +# test__compartmental_stdp.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see .
+ +import os +import unittest + +import numpy as np +import pytest + +import nest + +from pynestml.codegeneration.nest_tools import NESTTools +from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target + +# set to `True` to plot simulation traces +TEST_PLOTS = True +try: + import matplotlib + import matplotlib.pyplot as plt +except BaseException as e: + # always set TEST_PLOTS to False if matplotlib can not be imported + TEST_PLOTS = False + +class TestCompartmentalConcmech(unittest.TestCase): + @pytest.fixture(scope="module", autouse=True) + def setup(self): + tests_path = os.path.realpath(os.path.dirname(__file__)) + neuron_input_path = os.path.join( + tests_path, + "resources", + "concmech.nestml" + ) + synapse_input_path = os.path.join( + tests_path, + "resources", + "third_factor_stdp_synapse.nestml" + ) + target_path = os.path.join( + tests_path, + "target/" + ) + + if not os.path.exists(target_path): + os.makedirs(target_path) + + print( + f"Compiled nestml model 'cm_main_cm_default_nestml' not found, installing in:" + f" {target_path}" + ) + + nest.ResetKernel() + nest.SetKernelStatus(dict(resolution=.1)) + if True: + generate_nest_compartmental_target( + input_path=[neuron_input_path, synapse_input_path], + target_path=target_path, + module_name="cm_stdp_module", + suffix="_nestml", + logging_level="DEBUG", + codegen_opts={"neuron_synapse_pairs": [{"neuron": "multichannel_test_model", + "synapse": "third_factor_stdp_synapse", + "post_ports": ["post_spikes"]}], + "delay_variable": {"stdp_synapse": "d"}, + "weight_variable": {"stdp_synapse": "w"} + } + ) + + nest.Install("cm_stdp_module.so") + + def test_cm_stdp(self): + """ + Test the interaction between the pre- and post-synaptic spikes using STDP (Spike-Timing-Dependent Plasticity). + + This function sets up a simulation environment using NEST Simulator to demonstrate synaptic dynamics with pre-defined spike times for pre- and post-synaptic neurons. 
The function creates neuron models, assigns parameters, sets up connections, and records data from the simulation. It then plots the results for voltage, synaptic weight, spike timing, and pre- and post-synaptic traces. + + Simulation Procedure: + 1. Define pre- and post-synaptic spike timings and calculate simulation duration. + 2. Set up neuron models: + a. `spike_generator` to provide external spike input. + b. `parrot_neuron` for relaying spikes. + c. Custom `multichannel_test_model_nestml` neuron for the postsynaptic side, with compartments and receptor configurations specified. + 3. Create recording devices: + a. `multimeter` to record voltage, synaptic weights, currents, and traces. + b. `spike_recorder` to record spikes from pre- and post-synaptic neurons. + 4. Establish connections: + a. Connect spike generators to pre and post-neurons with static synaptic configurations. + b. Connect pre-neuron to post-neuron using a configured STDP synapse. + c. Connect recording devices to the respective neurons. + 5. Simulate the network for the specified time duration. + 6. Retrieve data from the multimeter and spike recorders. + 7. Plot the recorded data: + a. Membrane voltage of the post-synaptic neuron. + b. Synaptic weight change. + c. Pre- and post-spike timings marked with vertical lines. + d. Pre- and post-synaptic traces. + + Results: + The plots generated illustrate the effects of spike timing on various properties of the post-synaptic neuron, highlighting STDP-driven synaptic weight changes and trace dynamics. 
+ """ + pre_spike_times = [11, 50] + post_spike_times = [12, 45] + sim_time = max(np.amax(pre_spike_times), np.amax(post_spike_times)) + 20 + #wr = nest.Create("weight_recorder") + #nest.CopyModel("stdp_synapse_nestml__with_multichannel_test_model_nestml", "stdp_nestml_rec", + # {"weight_recorder": wr[0], "w": 1., "d": 1., "receptor_type": 0}) + external_input_pre = nest.Create("spike_generator", params={"spike_times": pre_spike_times}) + external_input_post = nest.Create("spike_generator", params={"spike_times": post_spike_times}) + pre_neuron = nest.Create("parrot_neuron") + post_neuron = nest.Create('multichannel_test_model_nestml') + print("created") + + params = {'C_m': 10.0, 'g_C': 0.0, 'g_L': 1.5, 'e_L': -70.0, 'gbar_Ca_HVA': 1.0, 'gbar_SK_E2': 1.0} + post_neuron.compartments = [ + {"parent_idx": -1, "params": params} + ] + print("comps") + post_neuron.receptors = [ + {"comp_idx": 0, "receptor_type": "AMPA"}, + {"comp_idx": 0, "receptor_type": "AMPA_third_factor_stdp_synapse_nestml", "params": {'w': 50.0}} + ] + print("syns") + mm = nest.Create('multimeter', 1, { + 'record_from': ['v_comp0', 'w0', 'i_tot_AMPA0', 'i_tot_AMPA_third_factor_stdp_synapse_nestml0', 'pre_trace0', 'post_trace0'], 'interval': .1}) + spikedet_pre = nest.Create("spike_recorder") + spikedet_post = nest.Create("spike_recorder") + + nest.Connect(external_input_pre, pre_neuron, "one_to_one", syn_spec={'synapse_model': 'static_synapse', 'weight': 2.0, 'delay': 0.1}) + nest.Connect(external_input_post, post_neuron, "one_to_one", syn_spec={'synapse_model': 'static_synapse', 'weight': 5.0, 'delay': 0.1, 'receptor_type': 0}) + nest.Connect(pre_neuron, post_neuron, "one_to_one", syn_spec={'synapse_model': 'static_synapse', 'weight': 1.0, 'delay': 0.1, 'receptor_type': 1}) + nest.Connect(mm, post_neuron) + nest.Connect(pre_neuron, spikedet_pre) + nest.Connect(post_neuron, spikedet_post) + print("pre sim") + nest.Simulate(sim_time) + res = nest.GetStatus(mm, 'events')[0] + pre_spikes_rec = 
nest.GetStatus(spikedet_pre, 'events')[0] + post_spikes_rec = nest.GetStatus(spikedet_post, 'events')[0] + + fig, axs = plt.subplots(4) + + axs[0].plot(res['times'], res['v_comp0'], c='r', label='V_m_0') + axs[1].plot(res['times'], res['w0'], c='r', label="weight") + #axs[1].plot(res['times'], res['pre_trace_AMPA0'], c='b', label="pre_trace") + #axs[1].plot(res['times'], res['post_trace_AMPA0'], c='g', label="post_trace") + axs[2].plot(res['times'], res['i_tot_AMPA0'], c='b', label="AMPA") + axs[2].plot(res['times'], res['i_tot_AMPA_third_factor_stdp_synapse_nestml0'], c='g', label="AMPA STDP") + label_set = False + for spike in pre_spikes_rec['times']: + if(label_set): + axs[2].axvline(x=spike, color='purple', linestyle='--', linewidth=1) + else: + axs[2].axvline(x=spike, color='purple', linestyle='--', linewidth=1, label="pre syn spikes") + label_set = True + + label_set = False + for spike in post_spikes_rec['times']: + if(label_set): + axs[2].axvline(x=spike, color='orange', linestyle='--', linewidth=1) + else: + axs[2].axvline(x=spike, color='orange', linestyle='--', linewidth=1, label="post syn spikes") + label_set = True + + axs[3].plot(res['times'], res['pre_trace0'], c='b', label="pre_trace") + axs[3].plot(res['times'], res['post_trace0'], c='g', label="post_trace") + + + axs[0].set_title('V_m_0') + axs[1].set_title('weight') + axs[2].set_title('spikes') + axs[3].set_title('traces') + + axs[0].legend() + axs[1].legend() + axs[2].legend() + axs[3].legend() + + plt.show() diff --git a/tests/nest_compartmental_tests/test__concmech_model.py b/tests/nest_compartmental_tests/test__concmech_model.py new file mode 100644 index 000000000..7c4add105 --- /dev/null +++ b/tests/nest_compartmental_tests/test__concmech_model.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# +# test__concmech_model.py +# +# This file is part of NEST. 
+# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +import os + +import pytest + +import nest + +from pynestml.codegeneration.nest_tools import NESTTools +from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target + +# set to `True` to plot simulation traces +TEST_PLOTS = True +try: + import matplotlib + import matplotlib.pyplot as plt +except BaseException as e: + # always set TEST_PLOTS to False if matplotlib can not be imported + TEST_PLOTS = False + + +class TestCompartmentalConcmech: + @pytest.fixture(scope="module", autouse=True) + def setup(self): + tests_path = os.path.realpath(os.path.dirname(__file__)) + input_path = os.path.join( + tests_path, + "resources", + "concmech.nestml" + ) + target_path = os.path.join( + tests_path, + "target/" + ) + + if not os.path.exists(target_path): + os.makedirs(target_path) + + print( + f"Compiled nestml model 'cm_main_cm_default_nestml' not found, installing in:" + f" {target_path}" + ) + + nest.ResetKernel() + nest.SetKernelStatus(dict(resolution=.1)) + + generate_nest_compartmental_target( + input_path=input_path, + target_path=target_path, + module_name="concmech_mockup_module", + suffix="_nestml", + logging_level="DEBUG" + ) + + nest.Install("concmech_mockup_module.so") + + def test_concmech(self): + """We test the concentration mechanism by comparing the concentration value at a certain critical point in + time to 
a previously achieved value at this point""" + cm = nest.Create('multichannel_test_model_nestml') + + params = {'C_m': 10.0, 'g_C': 0.0, 'g_L': 1.5, 'e_L': -70.0, 'gbar_Ca_HVA': 1.0, 'gbar_SK_E2': 1.0} + + cm.compartments = [ + {"parent_idx": -1, "params": params} + ] + + cm.receptors = [ + {"comp_idx": 0, "receptor_type": "AMPA"} + ] + + sg1 = nest.Create('spike_generator', 1, {'spike_times': [100.]}) + + nest.Connect(sg1, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': 4.0, 'delay': 0.5, 'receptor_type': 0}) + + mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'c_Ca0', 'i_tot_Ca_LVAst0', 'i_tot_Ca_HVA0', 'i_tot_SK_E20', 'm_Ca_HVA0', 'h_Ca_HVA0'], 'interval': .1}) + + nest.Connect(mm, cm) + + nest.Simulate(1000.) + + res = nest.GetStatus(mm, 'events')[0] + + step_time_delta = res['times'][1] - res['times'][0] + data_array_index = int(200 / step_time_delta) + + expected_conc = 0.03559438228347359 + + fig, axs = plt.subplots(5) + + axs[0].plot(res['times'], res['v_comp0'], c='r', label='V_m_0') + axs[1].plot(res['times'], res['c_Ca0'], c='y', label='c_Ca_0') + axs[2].plot(res['times'], res['i_tot_Ca_HVA0'], c='b', label='i_tot_Ca_HVA0') + axs[3].plot(res['times'], res['i_tot_SK_E20'], c='b', label='i_tot_SK_E20') + axs[4].plot(res['times'], res['m_Ca_HVA0'], c='g', label='gating var m') + axs[4].plot(res['times'], res['h_Ca_HVA0'], c='r', label='gating var h') + + axs[0].set_title('V_m_0') + axs[1].set_title('c_Ca_0') + axs[2].set_title('i_Ca_HVA_0') + axs[3].set_title('i_tot_SK_E20') + axs[4].set_title('gating vars') + + axs[0].legend() + axs[1].legend() + axs[2].legend() + axs[3].legend() + axs[4].legend() + + plt.savefig("concmech test.png") + + assert res['c_Ca0'][data_array_index] == expected_conc, ("the concentration (left) is not as expected (right). 
(" + str(res['c_Ca0'][data_array_index]) + "!=" + str(expected_conc) + ")") diff --git a/tests/nest_compartmental_tests/test__continuous_input.py b/tests/nest_compartmental_tests/test__continuous_input.py new file mode 100644 index 000000000..6f8a60060 --- /dev/null +++ b/tests/nest_compartmental_tests/test__continuous_input.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +# +# test__continuous_input.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . 
+ +import os + +import pytest + +import nest + +from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target + +# set to `True` to plot simulation traces +TEST_PLOTS = True +try: + import matplotlib + import matplotlib.pyplot as plt +except BaseException as e: + # always set TEST_PLOTS to False if matplotlib can not be imported + TEST_PLOTS = False + + +class TestContinuousInput: + @pytest.fixture(scope="module", autouse=True) + def setup(self): + tests_path = os.path.realpath(os.path.dirname(__file__)) + input_path = os.path.join( + tests_path, + "resources", + "continuous_test.nestml" + ) + target_path = os.path.join( + tests_path, + "target/" + ) + + if not os.path.exists(target_path): + os.makedirs(target_path) + + print( + f"Compiled nestml model 'cm_main_cm_default_nestml' not found, installing in:" + f" {target_path}" + ) + + nest.ResetKernel() + nest.SetKernelStatus(dict(resolution=.1)) + + generate_nest_compartmental_target( + input_path=input_path, + target_path=target_path, + module_name="continuous_test_module", + suffix="_nestml", + logging_level="DEBUG" + ) + + nest.Install("continuous_test_module.so") + + def test_continuous_input(self): + """We test the continuous input mechanism by just comparing the input current at a certain critical point in + time to a previously achieved value at this point""" + cm = nest.Create('continuous_test_model_nestml') + + soma_params = {'C_m': 10.0, 'g_C': 0.0, 'g_L': 1.5, 'e_L': -70.0} + + cm.compartments = [ + {"parent_idx": -1, "params": soma_params} + ] + + cm.receptors = [ + {"comp_idx": 0, "receptor_type": "con_in"}, + {"comp_idx": 0, "receptor_type": "AMPA"} + ] + + dcg = nest.Create("ac_generator", {"amplitude": 2.0, "start": 200, "stop": 800, "frequency": 20}) + + nest.Connect(dcg, cm, syn_spec={"synapse_model": "static_synapse", "weight": 1.0, "delay": 0.1, "receptor_type": 0}) + + sg1 = nest.Create('spike_generator', 1, {'spike_times': [205]}) + + nest.Connect(sg1, cm, 
syn_spec={'synapse_model': 'static_synapse', 'weight': 3.0, 'delay': 0.5, 'receptor_type': 1}) + + mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'i_tot_con_in0', 'i_tot_AMPA0'], 'interval': .1}) + + nest.Connect(mm, cm) + + nest.Simulate(1000.) + + res = nest.GetStatus(mm, 'events')[0] + + fig, axs = plt.subplots(2) + + axs[0].plot(res['times'], res['v_comp0'], c='b', label='V_m_0') + axs[1].plot(res['times'], res['i_tot_con_in0'], c='r', label='continuous') + axs[1].plot(res['times'], res['i_tot_AMPA0'], c='g', label='synapse') + + axs[0].set_title('V_m_0') + axs[1].set_title('inputs') + + axs[0].legend() + axs[1].legend() + + plt.savefig("continuous input test.png") + + step_time_delta = res['times'][1] - res['times'][0] + data_array_index = int(212 / step_time_delta) + + assert 19.9 < res['i_tot_con_in0'][data_array_index] < 20.1, ("the current (left) is not close enough to expected (right). (" + str(res['i_tot_con_in0'][data_array_index]) + " != " + "20.0 +- 0.1" + ")") diff --git a/tests/nest_compartmental_tests/test__interaction_with_disabled_mechanism.py b/tests/nest_compartmental_tests/test__interaction_with_disabled_mechanism.py new file mode 100644 index 000000000..8834c5c7f --- /dev/null +++ b/tests/nest_compartmental_tests/test__interaction_with_disabled_mechanism.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +# +# test__interaction_with_disabled_mechanism.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. 
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+import os
+
+import pytest
+
+import nest
+
+from pynestml.codegeneration.nest_tools import NESTTools
+from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target
+
+# set to `True` to plot simulation traces
+TEST_PLOTS = True
+try:
+    import matplotlib
+    import matplotlib.pyplot as plt
+except ImportError:
+    # always set TEST_PLOTS to False if matplotlib can not be imported
+    TEST_PLOTS = False
+
+
+class TestCompartmentalMechDisabled():
+    @pytest.fixture(scope="module", autouse=True)
+    def setup(self):
+        tests_path = os.path.realpath(os.path.dirname(__file__))
+        input_path = os.path.join(
+            tests_path,
+            "resources",
+            "concmech.nestml"
+        )
+        target_path = os.path.join(
+            tests_path,
+            "target/"
+        )
+
+        if not os.path.exists(target_path):
+            os.makedirs(target_path)
+
+        print(
+            f"Compiled nestml model 'cm_main_cm_default_nestml' not found, installing in:"
+            f" {target_path}"
+        )
+
+        nest.ResetKernel()
+        nest.SetKernelStatus(dict(resolution=.1))
+
+        generate_nest_compartmental_target(
+            input_path=input_path,
+            target_path=target_path,
+            module_name="concmech_mockup_module",
+            suffix="_nestml",
+            logging_level="DEBUG"
+        )
+
+        nest.Install("concmech_mockup_module.so")
+
+    def test_interaction_with_disabled(self):
+        """We test the interaction of active mechanisms (the concentration in this case) with disabled mechanisms
+        (zero key parameters) by just comparing the concentration value at a certain critical point in
+        time to a previously achieved value at this point"""
+        cm = nest.Create('multichannel_test_model_nestml')
+
+        params = {'C_m': 10.0, 'g_C': 0.0, 'g_L': 1.5, 'e_L': -70.0, 'gbar_Ca_HVA': 0.0, 'gbar_SK_E2': 1.0}
+
+        cm.compartments = [
+            {"parent_idx": -1, "params": params}
+        ]
+
+        cm.receptors = [
+            {"comp_idx": 0, "receptor_type": "AMPA"}
+        ]
+
+        sg1 = nest.Create('spike_generator', 1, {'spike_times': [100.]})
+
+        nest.Connect(sg1, cm, syn_spec={'synapse_model': 'static_synapse', 'weight': 4.0, 'delay': 0.5, 'receptor_type': 0})
+
+        mm = nest.Create('multimeter', 1, {'record_from': ['v_comp0', 'c_Ca0', 'i_tot_Ca_LVAst0', 'i_tot_Ca_HVA0', 'i_tot_SK_E20'], 'interval': .1})
+
+        nest.Connect(mm, cm)
+
+        nest.Simulate(1000.)
+
+        res = nest.GetStatus(mm, 'events')[0]
+
+        step_time_delta = res['times'][1] - res['times'][0]
+        data_array_index = int(200 / step_time_delta)
+
+        expected_conc = 2.8159902294145262e-05
+
+        fig, axs = plt.subplots(4)
+
+        axs[0].plot(res['times'], res['v_comp0'], c='r', label='V_m_0')
+        axs[1].plot(res['times'], res['c_Ca0'], c='y', label='c_Ca_0')
+        axs[2].plot(res['times'], res['i_tot_Ca_HVA0'], c='b', label='i_tot_Ca_HVA0')
+        axs[3].plot(res['times'], res['i_tot_SK_E20'], c='b', label='i_tot_SK_E20')
+
+        axs[0].set_title('V_m_0')
+        axs[1].set_title('c_Ca_0')
+        axs[2].set_title('i_Ca_HVA_0')
+        axs[3].set_title('i_tot_SK_E20')
+
+        axs[0].legend()
+        axs[1].legend()
+        axs[2].legend()
+        axs[3].legend()
+
+        plt.savefig("interaction with disabled mechanism test.png")
+
+        assert res['c_Ca0'][data_array_index] == expected_conc, ("the concentration (left) is not as expected (right). (" + str(res['c_Ca0'][data_array_index]) + "!=" + str(expected_conc) + ")")
diff --git a/tests/nest_compartmental_tests/test__model_variable_initialization.py b/tests/nest_compartmental_tests/test__model_variable_initialization.py
new file mode 100644
index 000000000..3a8b313f7
--- /dev/null
+++ b/tests/nest_compartmental_tests/test__model_variable_initialization.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+#
+# test__model_variable_initialization.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+import os
+
+import pytest
+
+import nest
+
+from pynestml.codegeneration.nest_tools import NESTTools
+from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target
+
+# set to `True` to plot simulation traces
+TEST_PLOTS = True
+try:
+    import matplotlib
+    import matplotlib.pyplot as plt
+except ImportError:
+    # always set TEST_PLOTS to False if matplotlib can not be imported
+    TEST_PLOTS = False
+
+
+class TestInitialization():
+    @pytest.fixture(scope="module", autouse=True)
+    def setup(self):
+        nest.ResetKernel()
+        nest.SetKernelStatus(dict(resolution=0.1))
+
+        tests_path = os.path.realpath(os.path.dirname(__file__))
+        input_path = os.path.join(
+            tests_path,
+            "resources",
+            "concmech.nestml"
+        )
+        target_path = os.path.join(
+            tests_path,
+            "target/"
+        )
+
+        if not os.path.exists(target_path):
+            os.makedirs(target_path)
+
+        print(
+            f"Compiled nestml model 'cm_main_cm_default_nestml' not found, installing in:"
+            f" {target_path}"
+        )
+
+        generate_nest_compartmental_target(
+            input_path=input_path,
+            target_path=target_path,
+            module_name="concmech_module",
+            suffix="_nestml",
+            logging_level="DEBUG"
+        )
+
+        nest.Install("concmech_module.so")
+
+    def test_non_existing_param(self):
+        nest.ResetKernel()
+        nest.SetKernelStatus(dict(resolution=0.1))
+        nest.Install("concmech_module.so")
+
+        params = {'C_m': 10.0, 'g_C': 0.0, 'g_L': 1., 'e_L': -70.0, 'non_existing': 1.0}
+
+        with pytest.raises(nest.NESTErrors.BadParameter):
+            cm = nest.Create('multichannel_test_model_nestml')
+            cm.compartments = [{"parent_idx": -1, "params": params}]
+
+    def test_existing_states(self):
+        """Testing 
whether the python initialization of variables works by looking up the variables at the very start of + the simulation. Since the values change dramatically in the very first step, before which we can not sample them + we test whether they are still large enough and not whether they are the same""" + nest.ResetKernel() + nest.SetKernelStatus(dict(resolution=0.1)) + nest.Install("concmech_module.so") + + params = {'C_m': 10.0, 'g_C': 0.0, 'g_L': 1., 'e_L': -70.0, 'gbar_NaTa_t': 1.0, 'h_NaTa_t': 1000.0, 'c_Ca': 1000.0, 'v_comp': 1000.0} + + cm = nest.Create('multichannel_test_model_nestml') + cm.compartments = [{"parent_idx": -1, "params": params}] + + mm = nest.Create('multimeter', 1, { + 'record_from': ['v_comp0', 'c_Ca0', 'h_NaTa_t0'], 'interval': .1}) + + nest.Connect(mm, cm) + + nest.Simulate(1000.) + + res = nest.GetStatus(mm, 'events')[0] + + data_array_index = 0 + + fig, axs = plt.subplots(3) + + axs[0].plot(res['times'], res['v_comp0'], c='r', label='v_comp0') + axs[1].plot(res['times'], res['c_Ca0'], c='y', label='c_Ca0') + axs[2].plot(res['times'], res['h_NaTa_t0'], c='b', label='h_NaTa_t0') + + axs[0].set_title('v_comp0') + axs[1].set_title('c_Ca0') + axs[2].set_title('h_NaTa_t') + + axs[0].legend() + axs[1].legend() + axs[2].legend() + + plt.savefig("initialization test.png") + + assert res['v_comp0'][data_array_index] > 50.0, ("the voltage (left) is not as expected (right). (" + str(res['v_comp0'][data_array_index]) + "<" + str(50.0) + ")") + + assert res['c_Ca0'][data_array_index] > 900.0, ("the concentration (left) is not as expected (right). (" + str(res['c_Ca0'][data_array_index]) + "<" + str(900.0) + ")") + + assert res['h_NaTa_t0'][data_array_index] > 5.0, ("the gating variable state (left) is not as expected (right). 
(" + str(res['h_NaTa_t0'][data_array_index]) + "<" + str(5.0) + ")") diff --git a/tests/nest_tests/test_gap_junction.py b/tests/nest_tests/test_gap_junction.py index e94a34de6..25ced348d 100644 --- a/tests/nest_tests/test_gap_junction.py +++ b/tests/nest_tests/test_gap_junction.py @@ -23,6 +23,7 @@ import os import pytest import scipy +import scipy.signal import nest @@ -58,7 +59,7 @@ def generate_code(self, neuron_model: str): files = [os.path.join("models", "neurons", neuron_model + ".nestml")] input_path = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join(os.pardir, os.pardir, s))) for s in files] generate_nest_target(input_path=input_path, - logging_level="DEBUG", + logging_level="WARNING", module_name="nestml_gap_" + neuron_model + "_module", suffix="_nestml", codegen_opts=codegen_opts)