diff --git a/pynestml/cocos/co_co_all_variables_defined.py b/pynestml/cocos/co_co_all_variables_defined.py
index e41b0727e..38cfa89ab 100644
--- a/pynestml/cocos/co_co_all_variables_defined.py
+++ b/pynestml/cocos/co_co_all_variables_defined.py
@@ -41,11 +41,10 @@ class CoCoAllVariablesDefined(CoCo):
"""
@classmethod
- def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False):
+ def check_co_co(cls, node: ASTModel):
"""
Checks if this coco applies for the handed over neuron. Models which contain undefined variables are not correct.
:param node: a single neuron instance.
- :param after_ast_rewrite: indicates whether this coco is checked after the code generator has done rewriting of the abstract syntax tree. If True, checks are not as rigorous. Use False where possible.
"""
# for each variable in all expressions, check if the variable has been defined previously
expression_collector_visitor = ASTExpressionCollectorVisitor()
@@ -62,32 +61,6 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False):
# test if the symbol has been defined at least
if symbol is None:
- if after_ast_rewrite: # after ODE-toolbox transformations, convolutions are replaced by state variables, so cannot perform this check properly
- symbol2 = node.get_scope().resolve_to_symbol(var.get_name(), SymbolKind.VARIABLE)
- if symbol2 is not None:
- # an inline expression defining this variable name (ignoring differential order) exists
- if "__X__" in str(symbol2): # if this variable was the result of a convolution...
- continue
- else:
- # for kernels, also allow derivatives of that kernel to appear
-
- inline_expr_names = []
- inline_exprs = []
- for equations_block in node.get_equations_blocks():
- inline_expr_names.extend([inline_expr.variable_name for inline_expr in equations_block.get_inline_expressions()])
- inline_exprs.extend(equations_block.get_inline_expressions())
-
- if var.get_name() in inline_expr_names:
- inline_expr_idx = inline_expr_names.index(var.get_name())
- inline_expr = inline_exprs[inline_expr_idx]
- from pynestml.utils.ast_utils import ASTUtils
- if ASTUtils.inline_aliases_convolution(inline_expr):
- symbol2 = node.get_scope().resolve_to_symbol(var.get_name(), SymbolKind.VARIABLE)
- if symbol2 is not None:
- # actually, no problem detected, skip error
- # XXX: TODO: check that differential order is less than or equal to that of the kernel
- continue
-
# check if this symbol is actually a type, e.g. "mV" in the expression "(1 + 2) * mV"
symbol2 = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.TYPE)
if symbol2 is not None:
@@ -106,9 +79,14 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False):
# in this case its ok if it is recursive or defined later on
continue
+ if symbol.is_predefined:
+ continue
+
+ if symbol.block_type == BlockType.LOCAL and symbol.get_referenced_object().get_source_position().before(var.get_source_position()):
+ continue
+
# check if it has been defined before usage, except for predefined symbols, input ports and variables added by the AST transformation functions
- if (not symbol.is_predefined) \
- and symbol.block_type != BlockType.INPUT \
+ if symbol.block_type != BlockType.INPUT \
and not symbol.get_referenced_object().get_source_position().is_added_source_position():
# except for parameters, those can be defined after
if ((not symbol.get_referenced_object().get_source_position().before(var.get_source_position()))
diff --git a/pynestml/cocos/co_co_function_unique.py b/pynestml/cocos/co_co_function_unique.py
index 15643c0ad..bf0f2be60 100644
--- a/pynestml/cocos/co_co_function_unique.py
+++ b/pynestml/cocos/co_co_function_unique.py
@@ -65,4 +65,5 @@ def check_co_co(cls, model: ASTModel):
log_level=LoggingLevel.ERROR,
message=message, code=code)
checked.append(funcA)
+
checked_funcs_names.append(func.get_name())
diff --git a/pynestml/cocos/co_co_illegal_expression.py b/pynestml/cocos/co_co_illegal_expression.py
index b78396e3b..c362d0dc5 100644
--- a/pynestml/cocos/co_co_illegal_expression.py
+++ b/pynestml/cocos/co_co_illegal_expression.py
@@ -18,13 +18,13 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
-from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
-from pynestml.utils.ast_source_location import ASTSourceLocation
-from pynestml.meta_model.ast_declaration import ASTDeclaration
from pynestml.cocos.co_co import CoCo
+from pynestml.meta_model.ast_declaration import ASTDeclaration
+from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
from pynestml.symbols.error_type_symbol import ErrorTypeSymbol
from pynestml.symbols.predefined_types import PredefinedTypes
+from pynestml.utils.ast_source_location import ASTSourceLocation
from pynestml.utils.logger import LoggingLevel, Logger
from pynestml.utils.logging_helper import LoggingHelper
from pynestml.utils.messages import Messages
diff --git a/pynestml/cocos/co_co_no_kernels_except_in_convolve.py b/pynestml/cocos/co_co_no_kernels_except_in_convolve.py
index 18b862292..e318ae566 100644
--- a/pynestml/cocos/co_co_no_kernels_except_in_convolve.py
+++ b/pynestml/cocos/co_co_no_kernels_except_in_convolve.py
@@ -22,11 +22,14 @@
from typing import List
from pynestml.cocos.co_co import CoCo
+from pynestml.meta_model.ast_declaration import ASTDeclaration
+from pynestml.meta_model.ast_external_variable import ASTExternalVariable
from pynestml.meta_model.ast_function_call import ASTFunctionCall
from pynestml.meta_model.ast_kernel import ASTKernel
from pynestml.meta_model.ast_model import ASTModel
from pynestml.meta_model.ast_node import ASTNode
from pynestml.meta_model.ast_variable import ASTVariable
+from pynestml.symbols.predefined_functions import PredefinedFunctions
from pynestml.symbols.symbol import SymbolKind
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.messages import Messages
@@ -89,24 +92,44 @@ def visit_variable(self, node: ASTNode):
if not (isinstance(node, ASTExternalVariable) and node.get_alternate_name()):
code, message = Messages.get_no_variable_found(kernelName)
Logger.log_message(node=self.__neuron_node, code=code, message=message, log_level=LoggingLevel.ERROR)
+
continue
+
if not symbol.is_kernel():
continue
+
if node.get_complete_name() == kernelName:
- parent = node.get_parent()
- if parent is not None:
+ parent = node
+ correct = False
+ while parent is not None and not isinstance(parent, ASTModel):
+ parent = parent.get_parent()
+ assert parent is not None
+
+ if isinstance(parent, ASTDeclaration):
+ for lhs_var in parent.get_variables():
+ if kernelName == lhs_var.get_complete_name():
+ # kernel name appears on lhs of declaration, assume it is initial state
+ correct = True
+ parent = None # break out of outer loop
+ break
+
if isinstance(parent, ASTKernel):
- continue
- grandparent = parent.get_parent()
- if grandparent is not None and isinstance(grandparent, ASTFunctionCall):
- grandparent_func_name = grandparent.get_name()
- if grandparent_func_name == 'convolve':
- continue
- code, message = Messages.get_kernel_outside_convolve(kernelName)
- Logger.log_message(code=code,
- message=message,
- log_level=LoggingLevel.ERROR,
- error_position=node.get_source_position())
+ # kernel name is used inside kernel definition, e.g. for a node ``g``, it appears in ``kernel g'' = -1/tau**2 * g - 2/tau * g'``
+ correct = True
+ break
+
+ if isinstance(parent, ASTFunctionCall):
+ func_name = parent.get_name()
+ if func_name == PredefinedFunctions.CONVOLVE:
+ # kernel name is used inside convolve call
+ correct = True
+
+ if not correct:
+ code, message = Messages.get_kernel_outside_convolve(kernelName)
+ Logger.log_message(code=code,
+ message=message,
+ log_level=LoggingLevel.ERROR,
+ error_position=node.get_source_position())
class KernelCollectingVisitor(ASTVisitor):
diff --git a/pynestml/cocos/co_co_v_comp_exists.py b/pynestml/cocos/co_co_v_comp_exists.py
index 4ef08c0ec..51308f2cc 100644
--- a/pynestml/cocos/co_co_v_comp_exists.py
+++ b/pynestml/cocos/co_co_v_comp_exists.py
@@ -43,9 +43,6 @@ def check_co_co(cls, neuron: ASTModel):
Models which are supposed to be compartmental but do not contain
state variable called v_comp are not correct.
:param neuron: a single neuron instance.
- :param after_ast_rewrite: indicates whether this coco is checked
- after the code generator has done rewriting of the abstract syntax tree.
- If True, checks are not as rigorous. Use False where possible.
"""
from pynestml.codegeneration.nest_compartmental_code_generator import NESTCompartmentalCodeGenerator
diff --git a/pynestml/cocos/co_cos_manager.py b/pynestml/cocos/co_cos_manager.py
index 01d008890..9ad3b37bf 100644
--- a/pynestml/cocos/co_cos_manager.py
+++ b/pynestml/cocos/co_cos_manager.py
@@ -69,6 +69,7 @@
from pynestml.cocos.co_co_priorities_correctly_specified import CoCoPrioritiesCorrectlySpecified
from pynestml.meta_model.ast_model import ASTModel
from pynestml.frontend.frontend_configuration import FrontendConfiguration
+from pynestml.utils.logger import Logger
class CoCosManager:
@@ -123,12 +124,12 @@ def check_state_variables_initialized(cls, model: ASTModel):
CoCoStateVariablesInitialized.check_co_co(model)
@classmethod
- def check_variables_defined_before_usage(cls, model: ASTModel, after_ast_rewrite: bool) -> None:
+ def check_variables_defined_before_usage(cls, model: ASTModel) -> None:
"""
Checks that all variables are defined before being used.
:param model: a single model.
"""
- CoCoAllVariablesDefined.check_co_co(model, after_ast_rewrite)
+ CoCoAllVariablesDefined.check_co_co(model)
@classmethod
def check_v_comp_requirement(cls, neuron: ASTModel):
@@ -402,17 +403,19 @@ def check_input_port_size_type(cls, model: ASTModel):
CoCoVectorInputPortsCorrectSizeType.check_co_co(model)
@classmethod
- def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bool = False):
+ def check_cocos(cls, model: ASTModel, after_ast_rewrite: bool = False):
"""
Checks all context conditions.
:param model: a single model object.
"""
+ Logger.set_current_node(model)
+
cls.check_each_block_defined_at_most_once(model)
cls.check_function_defined(model)
cls.check_variables_unique_in_scope(model)
cls.check_inline_expression_not_assigned_to(model)
cls.check_state_variables_initialized(model)
- cls.check_variables_defined_before_usage(model, after_ast_rewrite)
+ cls.check_variables_defined_before_usage(model)
if FrontendConfiguration.get_target_platform().upper() == 'NEST_COMPARTMENTAL':
# XXX: TODO: refactor this out; define a ``cocos_from_target_name()`` in the frontend instead.
cls.check_v_comp_requirement(model)
@@ -452,3 +455,5 @@ def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bo
cls.check_co_co_priorities_correctly_specified(model)
cls.check_resolution_func_legally_used(model)
cls.check_input_port_size_type(model)
+
+ Logger.set_current_node(None)
diff --git a/pynestml/codegeneration/builder.py b/pynestml/codegeneration/builder.py
index 2e6757c1a..a9f98bf58 100644
--- a/pynestml/codegeneration/builder.py
+++ b/pynestml/codegeneration/builder.py
@@ -20,12 +20,12 @@
# along with NEST. If not, see .
from __future__ import annotations
-import subprocess
-import os
from typing import Any, Mapping, Optional
from abc import ABCMeta, abstractmethod
+import os
+import subprocess
from pynestml.exceptions.invalid_target_exception import InvalidTargetException
from pynestml.frontend.frontend_configuration import FrontendConfiguration
diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py
index 0551e9a6e..155bea95c 100644
--- a/pynestml/codegeneration/nest_code_generator.py
+++ b/pynestml/codegeneration/nest_code_generator.py
@@ -28,6 +28,7 @@
import pynestml
from pynestml.cocos.co_co_nest_synapse_delay_not_assigned_to import CoCoNESTSynapseDelayNotAssignedTo
+from pynestml.cocos.co_cos_manager import CoCosManager
from pynestml.codegeneration.code_generator import CodeGenerator
from pynestml.codegeneration.code_generator_utils import CodeGeneratorUtils
from pynestml.codegeneration.nest_assignments_helper import NestAssignmentsHelper
@@ -374,6 +375,9 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
if not used_in_eq:
self.non_equations_state_variables[neuron.get_name()].append(var)
+ # cache state variables before symbol table update for the sake of delay variables
+ state_vars_before_update = neuron.get_state_symbols()
+
ASTUtils.remove_initial_values_for_kernels(neuron)
kernels = ASTUtils.remove_kernel_definitions_from_equations_block(neuron)
ASTUtils.update_initial_values_for_odes(neuron, [analytic_solver, numeric_solver])
@@ -388,7 +392,6 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
neuron = ASTUtils.add_declarations_to_internals(
neuron, self.analytic_solver[neuron.get_name()]["propagators"])
- state_vars_before_update = neuron.get_state_symbols()
self.update_symbol_table(neuron)
# Update the delay parameter parameters after symbol table update
@@ -898,8 +901,8 @@ def update_symbol_table(self, neuron) -> None:
"""
SymbolTable.delete_model_scope(neuron.get_name())
symbol_table_visitor = ASTSymbolTableVisitor()
- symbol_table_visitor.after_ast_rewrite_ = True
neuron.accept(symbol_table_visitor)
+ CoCosManager.check_cocos(neuron, after_ast_rewrite=True)
SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope())
def get_spike_update_expressions(self, neuron: ASTModel, kernel_buffers, solver_dicts, delta_factors) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment]]:
diff --git a/pynestml/codegeneration/nest_compartmental_code_generator.py b/pynestml/codegeneration/nest_compartmental_code_generator.py
index 4711bc497..84199c07e 100644
--- a/pynestml/codegeneration/nest_compartmental_code_generator.py
+++ b/pynestml/codegeneration/nest_compartmental_code_generator.py
@@ -740,8 +740,8 @@ def update_symbol_table(self, neuron, kernel_buffers):
"""
SymbolTable.delete_model_scope(neuron.get_name())
symbol_table_visitor = ASTSymbolTableVisitor()
- symbol_table_visitor.after_ast_rewrite_ = True
neuron.accept(symbol_table_visitor)
+ CoCosManager.check_cocos(neuron, after_ast_rewrite=True)
SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope())
def _get_ast_variable(self, neuron, var_name) -> Optional[ASTVariable]:
diff --git a/pynestml/codegeneration/python_standalone_code_generator.py b/pynestml/codegeneration/python_standalone_code_generator.py
index f44123743..d6afaa095 100644
--- a/pynestml/codegeneration/python_standalone_code_generator.py
+++ b/pynestml/codegeneration/python_standalone_code_generator.py
@@ -111,7 +111,6 @@ def setup_printers(self):
# GSL printers
self._gsl_variable_printer = PythonSteppingFunctionVariablePrinter(None)
- print("In Python code generator: created self._gsl_variable_printer = " + str(self._gsl_variable_printer))
self._gsl_function_call_printer = PythonSteppingFunctionFunctionCallPrinter(None)
self._gsl_printer = PythonExpressionPrinter(simple_expression_printer=PythonSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer,
constant_printer=self._constant_printer,
diff --git a/pynestml/codegeneration/spinnaker_code_generator.py b/pynestml/codegeneration/spinnaker_code_generator.py
index 2a8fed7de..dce247e9c 100644
--- a/pynestml/codegeneration/spinnaker_code_generator.py
+++ b/pynestml/codegeneration/spinnaker_code_generator.py
@@ -137,7 +137,6 @@ def setup_printers(self):
# GSL printers
self._gsl_variable_printer = PythonSteppingFunctionVariablePrinter(None)
- print("In Python code generator: created self._gsl_variable_printer = " + str(self._gsl_variable_printer))
self._gsl_function_call_printer = PythonSteppingFunctionFunctionCallPrinter(None)
self._gsl_printer = PythonExpressionPrinter(simple_expression_printer=SpinnakerPythonSimpleExpressionPrinter(
variable_printer=self._gsl_variable_printer,
@@ -216,6 +215,7 @@ def generate_code(self, models: Sequence[ASTModel]) -> None:
for model in models:
cloned_model = model.clone()
cloned_model.accept(ASTSymbolTableVisitor())
+ CoCosManager.check_cocos(cloned_model)
cloned_models.append(cloned_model)
self.codegen_cpp.generate_code(cloned_models)
@@ -224,6 +224,7 @@ def generate_code(self, models: Sequence[ASTModel]) -> None:
for model in models:
cloned_model = model.clone()
cloned_model.accept(ASTSymbolTableVisitor())
+ CoCosManager.check_cocos(cloned_model)
cloned_models.append(cloned_model)
self.codegen_py.generate_code(cloned_models)
diff --git a/pynestml/frontend/frontend_configuration.py b/pynestml/frontend/frontend_configuration.py
index 173534c95..aae1fc29a 100644
--- a/pynestml/frontend/frontend_configuration.py
+++ b/pynestml/frontend/frontend_configuration.py
@@ -244,8 +244,8 @@ def handle_module_name(cls, module_name):
@classmethod
def handle_target_platform(cls, target_platform: Optional[str]):
- if target_platform is None or target_platform.upper() == 'NONE':
- target_platform = '' # make sure `target_platform` is always a string
+ if target_platform is None:
+ target_platform = "NONE" # make sure `target_platform` is always a string
from pynestml.frontend.pynestml_frontend import get_known_targets
diff --git a/pynestml/frontend/pynestml_frontend.py b/pynestml/frontend/pynestml_frontend.py
index c257822de..c3dc2d2ae 100644
--- a/pynestml/frontend/pynestml_frontend.py
+++ b/pynestml/frontend/pynestml_frontend.py
@@ -41,6 +41,8 @@
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.messages import Messages
from pynestml.utils.model_parser import ModelParser
+from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
+from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
def get_known_targets():
@@ -131,10 +133,10 @@ def code_generator_from_target_name(target_name: str, options: Optional[Mapping[
return SpiNNakerCodeGenerator(options)
if target_name.upper() == "NONE":
- # dummy/null target: user requested to not generate any code
+ # dummy/null target: user requested to not generate any code (for instance, when just doing validation of a model)
code, message = Messages.get_no_code_generated()
Logger.log_message(None, code, message, None, LoggingLevel.INFO)
- return CodeGenerator("", options)
+ return CodeGenerator(options)
# cannot reach here due to earlier assert -- silence static checker warnings
assert "Unknown code generator requested: " + target_name
@@ -193,12 +195,17 @@ def generate_target(input_path: Union[str, Sequence[str]], target_platform: str,
Enable development mode: code generation is attempted even for models that contain errors, and extra information is rendered in the generated code.
codegen_opts : Optional[Mapping[str, Any]]
A dictionary containing additional options for the target code generator.
+
+ Return
+ ------
+ errors_occurred
+ Flag indicating whether errors occurred during processing. False if processing was successful; True if errors occurred in any of the models.
"""
configure_front_end(input_path, target_platform, target_path, install_path, logging_level,
module_name, store_log, suffix, dev, codegen_opts)
- if not process() == 0:
- raise Exception("Error(s) occurred while processing the model")
+
+ return process()
def configure_front_end(input_path: Union[str, Sequence[str]], target_platform: str, target_path=None,
@@ -373,34 +380,36 @@ def generate_nest_compartmental_target(input_path: Union[str, Sequence[str]], ta
def main() -> int:
- """
+ r"""
Entry point for the command-line application.
Returns
-------
- The process exit code: 0 for success, > 0 for failure
+ exit_code
+ The process exit code: 0 for success, > 0 for failure
"""
try:
FrontendConfiguration.parse_config(sys.argv[1:])
except InvalidPathException as e:
print(e)
+
return 1
+
# the default Python recursion limit is 1000, which might not be enough in practice when running an AST visitor on a deep tree, e.g. containing an automatically generated expression
sys.setrecursionlimit(10000)
+
# after all argument have been collected, start the actual processing
return int(process())
-def get_parsed_models():
+def get_parsed_models() -> List[ASTModel]:
r"""
Handle the parsing and validation of the NESTML files
Returns
-------
- models: Sequence[ASTModel]
+ models
List of correctly parsed models
- errors_occurred : bool
- Flag indicating whether errors occurred during processing
"""
# init log dir
create_report_dir()
@@ -417,36 +426,29 @@ def get_parsed_models():
for nestml_file in nestml_files:
parsed_unit = ModelParser.parse_file(nestml_file)
- if parsed_unit is None:
- # Parsing error in the NESTML model, return True
- return [], True
-
- compilation_units.append(parsed_unit)
+ if parsed_unit:
+ compilation_units.append(parsed_unit)
- if len(compilation_units) > 0:
- # generate a list of all models
- models: Sequence[ASTModel] = []
- for compilationUnit in compilation_units:
- models.extend(compilationUnit.get_model_list())
+ # generate a list of all models
+ models: Sequence[ASTModel] = []
+ for compilation_unit in compilation_units:
+ CoCosManager.check_model_names_unique(compilation_unit)
+ models.extend(compilation_unit.get_model_list())
- # check that no models with duplicate names have been defined
- CoCosManager.check_no_duplicate_compilation_unit_names(models)
+ # check that no models with duplicate names have been defined
+ CoCosManager.check_no_duplicate_compilation_unit_names(models)
- # now exclude those which are broken, i.e. have errors.
- for model in models:
- if Logger.has_errors(model):
- code, message = Messages.get_model_contains_errors(model.get_name())
- Logger.log_message(node=model, code=code, message=message,
- error_position=model.get_source_position(),
- log_level=LoggingLevel.WARNING)
- return [model], True
+ for model in models:
+ model.accept(ASTParentVisitor())
+ model.accept(ASTSymbolTableVisitor())
- return models, False
+ return models
def transform_models(transformers, models):
for transformer in transformers:
models = transformer.transform(models)
+
return models
@@ -454,14 +456,14 @@ def generate_code(code_generators, models):
code_generators.generate_code(models)
-def process():
+def process() -> bool:
r"""
The main toolchain workflow entry point. For all models: parse, validate, transform, generate code and build.
- Returns
- -------
- errors_occurred : bool
- Flag indicating whether errors occurred during processing
+ Return
+ ------
+ errors_occurred
+ Flag indicating whether errors occurred during processing. False if processing was successful; True if errors occurred in any of the models.
"""
# initialize and set options for transformers, code generator and builder
@@ -478,20 +480,38 @@ def process():
if len(codegen_and_builder_opts) > 0:
raise CodeGeneratorOptionsException("The code generator option(s) \"" + ", ".join(codegen_and_builder_opts.keys()) + "\" do not exist.")
- models, errors_occurred = get_parsed_models()
+ models = get_parsed_models()
+
+ # validation -- check cocos for models that do not have errors already
+ excluded_models = []
+ for model in models:
+ if Logger.has_errors(model.name):
+ code, message = Messages.get_model_contains_errors(model.get_name())
+ Logger.log_message(node=model, code=code, message=message,
+ error_position=model.get_source_position(),
+ log_level=LoggingLevel.WARNING)
+ excluded_models.append(model)
+ else:
+ CoCosManager.check_cocos(model)
+
+ # exclude models that have errors
+ models = list(set(models) - set(excluded_models))
+
+ # transformation(s)
+ models = transform_models(transformers, models)
- if not errors_occurred:
- models = transform_models(transformers, models)
- generate_code(code_generator, models)
+ # generate code
+ generate_code(code_generator, models)
- # perform build
- if _builder is not None:
- _builder.build()
+ # perform build
+ if _builder is not None:
+ _builder.build()
if FrontendConfiguration.store_log:
store_log_to_file()
- return errors_occurred
+ # return a boolean indicating whether errors occurred
+ return len(Logger.get_all_messages_of_level(LoggingLevel.ERROR)) > 0
def init_predefined():
diff --git a/pynestml/meta_model/ast_model.py b/pynestml/meta_model/ast_model.py
index 834e56897..4b50d59b4 100644
--- a/pynestml/meta_model/ast_model.py
+++ b/pynestml/meta_model/ast_model.py
@@ -459,23 +459,27 @@ def add_to_internals_block(self, declaration: ASTDeclaration, index: int = -1) -
Adds the handed over declaration the internals block
:param declaration: a single declaration
"""
- assert len(self.get_internals_blocks()) <= 1, "Only one internals block supported for now"
from pynestml.utils.ast_utils import ASTUtils
+ from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
+ from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
+
+ assert len(self.get_internals_blocks()) <= 1, "Only one internals block supported for now"
+
if not self.get_internals_blocks():
ASTUtils.create_internal_block(self)
+
n_declarations = len(self.get_internals_blocks()[0].get_declarations())
if n_declarations == 0:
index = 0
else:
index = 1 + (index % len(self.get_internals_blocks()[0].get_declarations()))
+
self.get_internals_blocks()[0].get_declarations().insert(index, declaration)
declaration.update_scope(self.get_internals_blocks()[0].get_scope())
- from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
- from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
symtable_vistor = ASTSymbolTableVisitor()
symtable_vistor.block_type_stack.push(BlockType.INTERNALS)
- declaration.accept(symtable_vistor)
- self.get_internals_blocks()[0].accept(ASTParentVisitor())
+ self.accept(ASTParentVisitor())
+ self.accept(symtable_vistor)
symtable_vistor.block_type_stack.pop()
def add_to_state_block(self, declaration: ASTDeclaration) -> None:
@@ -483,24 +487,26 @@ def add_to_state_block(self, declaration: ASTDeclaration) -> None:
Adds the handed over declaration to an arbitrary state block. A state block will be created if none exists.
:param declaration: a single declaration.
"""
- assert len(self.get_state_blocks()) <= 1, "Only one internals block supported for now"
+ from pynestml.symbols.symbol import SymbolKind
from pynestml.utils.ast_utils import ASTUtils
+ from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
+ from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
+
+ assert len(self.get_state_blocks()) <= 1, "Only one internals block supported for now"
+
if not self.get_state_blocks():
ASTUtils.create_state_block(self)
+
self.get_state_blocks()[0].get_declarations().append(declaration)
declaration.update_scope(self.get_state_blocks()[0].get_scope())
- from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
- from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
symtable_vistor = ASTSymbolTableVisitor()
symtable_vistor.block_type_stack.push(BlockType.STATE)
- declaration.accept(symtable_vistor)
- self.get_state_blocks()[0].accept(ASTParentVisitor())
+ self.accept(ASTParentVisitor())
+ self.accept(symtable_vistor)
symtable_vistor.block_type_stack.pop()
- from pynestml.symbols.symbol import SymbolKind
- assert declaration.get_variables()[0].get_scope().resolve_to_symbol(
- declaration.get_variables()[0].get_name(), SymbolKind.VARIABLE) is not None
- assert declaration.get_scope().resolve_to_symbol(declaration.get_variables()[0].get_name(),
- SymbolKind.VARIABLE) is not None
+
+ assert declaration.get_variables()[0].get_scope().resolve_to_symbol(declaration.get_variables()[0].get_name(), SymbolKind.VARIABLE) is not None
+ assert declaration.get_scope().resolve_to_symbol(declaration.get_variables()[0].get_name(), SymbolKind.VARIABLE) is not None
def print_comment(self, prefix: str = "") -> str:
"""
diff --git a/pynestml/symbols/symbol.py b/pynestml/symbols/symbol.py
index 1e294566b..c73435c6d 100644
--- a/pynestml/symbols/symbol.py
+++ b/pynestml/symbols/symbol.py
@@ -18,8 +18,8 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
-from abc import ABCMeta, abstractmethod
+from abc import ABCMeta, abstractmethod
from enum import Enum
diff --git a/pynestml/symbols/type_symbol.py b/pynestml/symbols/type_symbol.py
index 7047cdbca..a3eb28a12 100644
--- a/pynestml/symbols/type_symbol.py
+++ b/pynestml/symbols/type_symbol.py
@@ -18,11 +18,11 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
+
from abc import ABCMeta, abstractmethod
from pynestml.symbols.symbol import Symbol
from pynestml.utils.logger import Logger, LoggingLevel
-from pynestml.utils.messages import Messages
class TypeSymbol(Symbol):
@@ -198,6 +198,7 @@ def is_castable_to(self, _other_type):
def binary_operation_not_defined_error(self, _operator, _other):
from pynestml.symbols.error_type_symbol import ErrorTypeSymbol
+ from pynestml.utils.messages import Messages
result = ErrorTypeSymbol()
code, message = Messages.get_binary_operation_not_defined(
lhs=self.print_nestml_type(), operator=_operator, rhs=_other.print_nestml_type())
@@ -208,6 +209,7 @@ def binary_operation_not_defined_error(self, _operator, _other):
def unary_operation_not_defined_error(self, _operator):
from pynestml.symbols.error_type_symbol import ErrorTypeSymbol
result = ErrorTypeSymbol()
+ from pynestml.utils.messages import Messages
code, message = Messages.get_unary_operation_not_defined(_operator,
self.print_symbol())
Logger.log_message(code=code, message=message, error_position=self.referenced_object.get_source_position(),
@@ -226,6 +228,7 @@ def inverse_of_unit(cls, other):
return result
def warn_implicit_cast_from_to(self, _from, _to):
+ from pynestml.utils.messages import Messages
code, message = Messages.get_implicit_cast_rhs_to_lhs(_to.print_symbol(), _from.print_symbol())
Logger.log_message(code=code, message=message,
error_position=self.get_referenced_object().get_source_position(),
diff --git a/pynestml/symbols/unit_type_symbol.py b/pynestml/symbols/unit_type_symbol.py
index 37c43b035..1f9977de0 100644
--- a/pynestml/symbols/unit_type_symbol.py
+++ b/pynestml/symbols/unit_type_symbol.py
@@ -19,6 +19,7 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
+from typing import Optional
from pynestml.symbols.type_symbol import TypeSymbol
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.messages import Messages
@@ -131,12 +132,12 @@ def __sub__(self, other):
def add_or_sub_another_unit(self, other):
if self.equals(other):
return other
- else:
- return self.attempt_magnitude_cast(other)
+
+ return self.attempt_magnitude_cast(other)
def attempt_magnitude_cast(self, other):
if self.differs_only_in_magnitude(other):
- factor = UnitTypeSymbol.get_conversion_factor(self.astropy_unit, other.astropy_unit)
+ factor = UnitTypeSymbol.get_conversion_factor(other.astropy_unit, self.astropy_unit)
other.referenced_object.set_implicit_conversion_factor(factor)
code, message = Messages.get_implicit_magnitude_conversion(self, other, factor)
Logger.log_message(code=code, message=message,
@@ -144,18 +145,20 @@ def attempt_magnitude_cast(self, other):
log_level=LoggingLevel.INFO)
return self
- else:
- return self.binary_operation_not_defined_error('+/-', other)
- # TODO: change order of parameters to conform with the from_to scheme.
- # TODO: Also rename to reflect that, i.e. get_conversion_factor_from_to
+ return self.binary_operation_not_defined_error('+/-', other)
+
@classmethod
- def get_conversion_factor(cls, to, _from):
+ def get_conversion_factor(cls, _from, to) -> Optional[float]:
"""
- Calculates the conversion factor from _convertee_unit to target_unit.
- Behaviour is only well-defined if both units have the same physical base type
+ Calculates the conversion factor from _convertee_unit to target_unit. Behaviour is only well-defined if both units have the same physical base type.
"""
- factor = (_from / to).si.scale
+ try:
+ factor = (_from / to).si.scale
+ except BaseException:
+ # this can fail in case of e.g. trying to convert from "1/s" to "2/s"
+ return None
+
return factor
def is_castable_to(self, _other_type):
diff --git a/pynestml/transformers/assign_implicit_conversion_factors_transformer.py b/pynestml/transformers/assign_implicit_conversion_factors_transformer.py
new file mode 100644
index 000000000..f44ee12d5
--- /dev/null
+++ b/pynestml/transformers/assign_implicit_conversion_factors_transformer.py
@@ -0,0 +1,335 @@
+# -*- coding: utf-8 -*-
+#
+# assign_implicit_conversion_factors_transformer.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+from typing import Sequence, Union
+
+from pynestml.meta_model.ast_compound_stmt import ASTCompoundStmt
+from pynestml.meta_model.ast_declaration import ASTDeclaration
+from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
+from pynestml.meta_model.ast_node import ASTNode
+from pynestml.meta_model.ast_small_stmt import ASTSmallStmt
+from pynestml.meta_model.ast_stmt import ASTStmt
+from pynestml.symbols.error_type_symbol import ErrorTypeSymbol
+from pynestml.symbols.predefined_types import PredefinedTypes
+from pynestml.symbols.symbol import SymbolKind
+from pynestml.symbols.template_type_symbol import TemplateTypeSymbol
+from pynestml.symbols.variadic_type_symbol import VariadicTypeSymbol
+from pynestml.transformers.transformer import Transformer
+from pynestml.utils.ast_source_location import ASTSourceLocation
+from pynestml.utils.ast_utils import ASTUtils
+from pynestml.utils.logger import LoggingLevel, Logger
+from pynestml.utils.logging_helper import LoggingHelper
+from pynestml.utils.messages import Messages
+from pynestml.utils.type_caster import TypeCaster
+from pynestml.visitors.ast_visitor import ASTVisitor
+
+
+class AssignImplicitConversionFactorsTransformer(Transformer):
+ r"""
+ Assign implicit conversion factors in expressions.
+ """
+
+ def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]:
+ single = False
+ if isinstance(models, ASTNode):
+ single = True
+ models = [models]
+
+ for model in models:
+ model.accept(AssignImplicitConversionFactorVisitor())
+ self.__assign_return_types(model)
+
+ if single:
+ return models[0]
+ return models
+
+ def __assign_return_types(self, _node):
+ for userDefinedFunction in _node.get_functions():
+ symbol = userDefinedFunction.get_scope().resolve_to_symbol(userDefinedFunction.get_name(),
+ SymbolKind.FUNCTION)
+ # first ensure that the block contains at least one statement
+ if symbol is not None and len(userDefinedFunction.get_block().get_stmts()) > 0:
+ # now check that the last statement is a return
+ self.__check_return_recursively(userDefinedFunction,
+ symbol.get_return_type(),
+ userDefinedFunction.get_block().get_stmts(),
+ False)
+ # now if it does not have a statement, but uses a return type, it is an error
+ elif symbol is not None and userDefinedFunction.has_return_type() and \
+ not symbol.get_return_type().equals(PredefinedTypes.get_void_type()):
+ code, message = Messages.get_no_return()
+ Logger.log_message(node=_node, code=code, message=message,
+ error_position=userDefinedFunction.get_source_position(),
+ log_level=LoggingLevel.ERROR)
+
+ def __check_return_recursively(self, processed_function, type_symbol=None, stmts=None, ret_defined: bool = False) -> None:
+ """
+ For a handed over statement, it checks if the statement is a return statement and if it is typed according to the handed over type symbol.
+ :param type_symbol: a single type symbol
+ :type type_symbol: type_symbol
+ :param stmts: a list of statements, either simple or compound
+ :type stmts: list(ASTSmallStmt,ASTCompoundStmt)
+ :param ret_defined: indicates whether a ret has already been defined after this block of stmt, thus is not
+ necessary. Implies that the return has been defined in the higher level block
+ """
+ # in order to ensure that in the sub-blocks, a return is not necessary, we check if the last one in this
+ # block is a return statement, thus it is not required to have a return in the sub-blocks, but optional
+ last_statement = stmts[len(stmts) - 1]
+ ret_defined = False or ret_defined
+ if (len(stmts) > 0 and isinstance(last_statement, ASTStmt)
+ and last_statement.is_small_stmt()
+ and last_statement.small_stmt.is_return_stmt()):
+ ret_defined = True
+
+ # now check that returns are there if necessary and correctly typed
+ for c_stmt in stmts:
+ if c_stmt.is_small_stmt():
+ stmt = c_stmt.small_stmt
+ else:
+ stmt = c_stmt.compound_stmt
+
+ # if it is a small statement, check if it is a return statement
+ if isinstance(stmt, ASTSmallStmt) and stmt.is_return_stmt():
+ # first check if the return is the last one in this block of statements
+ if stmts.index(c_stmt) != (len(stmts) - 1):
+ code, message = Messages.get_not_last_statement('Return')
+ Logger.log_message(error_position=stmt.get_source_position(),
+ code=code, message=message,
+ log_level=LoggingLevel.WARNING)
+
+ # now check that it corresponds to the declared type
+ if stmt.get_return_stmt().has_expression() and type_symbol is PredefinedTypes.get_void_type():
+ code, message = Messages.get_type_different_from_expected(PredefinedTypes.get_void_type(),
+ stmt.get_return_stmt().get_expression().type)
+ Logger.log_message(error_position=stmt.get_source_position(),
+ message=message, code=code, log_level=LoggingLevel.ERROR)
+
+ # if it is not void check if the type corresponds to the one stated
+ if not stmt.get_return_stmt().has_expression() and \
+ not type_symbol.equals(PredefinedTypes.get_void_type()):
+ code, message = Messages.get_type_different_from_expected(PredefinedTypes.get_void_type(),
+ type_symbol)
+ Logger.log_message(error_position=stmt.get_source_position(),
+ message=message, code=code, log_level=LoggingLevel.ERROR)
+
+ if stmt.get_return_stmt().has_expression():
+ type_of_return = stmt.get_return_stmt().get_expression().type
+ if isinstance(type_of_return, ErrorTypeSymbol):
+ code, message = Messages.get_type_could_not_be_derived(processed_function.get_name())
+ Logger.log_message(error_position=stmt.get_source_position(),
+ code=code, message=message, log_level=LoggingLevel.ERROR)
+ elif not type_of_return.equals(type_symbol):
+ TypeCaster.try_to_recover_or_error(type_symbol, type_of_return,
+ stmt.get_return_stmt().get_expression())
+ elif isinstance(stmt, ASTCompoundStmt):
+ # otherwise it is a compound stmt, thus check recursively
+ if stmt.is_if_stmt():
+ self.__check_return_recursively(processed_function,
+ type_symbol,
+ stmt.get_if_stmt().get_if_clause().get_block().get_stmts(),
+ ret_defined)
+ for else_ifs in stmt.get_if_stmt().get_elif_clauses():
+ self.__check_return_recursively(processed_function,
+ type_symbol, else_ifs.get_block().get_stmts(), ret_defined)
+ if stmt.get_if_stmt().has_else_clause():
+ self.__check_return_recursively(processed_function,
+ type_symbol,
+ stmt.get_if_stmt().get_else_clause().get_block().get_stmts(),
+ ret_defined)
+ elif stmt.is_while_stmt():
+ self.__check_return_recursively(processed_function,
+ type_symbol, stmt.get_while_stmt().get_block().get_stmts(),
+ ret_defined)
+ elif stmt.is_for_stmt():
+ self.__check_return_recursively(processed_function,
+ type_symbol, stmt.get_for_stmt().get_block().get_stmts(),
+ ret_defined)
+ # now, if a return statement has not been defined in the corresponding higher level block, we have to ensure that it is defined here
+ elif not ret_defined and stmts.index(c_stmt) == (len(stmts) - 1):
+ if not (isinstance(stmt, ASTSmallStmt) and stmt.is_return_stmt()):
+ code, message = Messages.get_no_return()
+ Logger.log_message(error_position=stmt.get_source_position(), log_level=LoggingLevel.ERROR,
+ code=code, message=message)
+
+
+class AssignImplicitConversionFactorVisitor(ASTVisitor):
+ """
+ This visitor checks that all expression correspond to the expected type.
+ """
+
+ def visit_declaration(self, node):
+ """
+ Visits a single declaration and asserts that type of lhs is equal to type of rhs.
+ :param node: a single declaration.
+ :type node: ASTDeclaration
+ """
+ assert isinstance(node, ASTDeclaration)
+ if node.has_expression():
+ if node.get_expression().get_source_position().equals(ASTSourceLocation.get_added_source_position()):
+ # no type checks are executed for added nodes, since we assume correctness
+ return
+ lhs_type = node.get_data_type().get_type_symbol()
+ rhs_type = node.get_expression().type
+ if isinstance(rhs_type, ErrorTypeSymbol):
+ LoggingHelper.drop_missing_type_error(node)
+ return
+ if self.__types_do_not_match(lhs_type, rhs_type):
+ TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, node.get_expression())
+
+ def visit_inline_expression(self, node):
+ """
+ Visits a single inline expression and asserts that type of lhs is equal to type of rhs.
+ """
+ assert isinstance(node, ASTInlineExpression)
+ lhs_type = node.get_data_type().get_type_symbol()
+ rhs_type = node.get_expression().type
+ if isinstance(rhs_type, ErrorTypeSymbol):
+ LoggingHelper.drop_missing_type_error(node)
+ return
+
+ if self.__types_do_not_match(lhs_type, rhs_type):
+ TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, node.get_expression())
+
+ def visit_assignment(self, node):
+ """
+ Visits a single expression and assures that type(lhs) == type(rhs).
+ :param node: a single assignment.
+ :type node: ASTAssignment
+ """
+ from pynestml.meta_model.ast_assignment import ASTAssignment
+ assert isinstance(node, ASTAssignment)
+
+ if node.get_source_position().equals(ASTSourceLocation.get_added_source_position()):
+ # no type checks are executed for added nodes, since we assume correctness
+ return
+ if node.is_direct_assignment: # case a = b is simple
+ self.handle_simple_assignment(node)
+ else:
+ self.handle_compound_assignment(node) # e.g. a *= b
+
+ def handle_compound_assignment(self, node):
+ rhs_expr = node.get_expression()
+ lhs_variable_symbol = node.get_variable().resolve_in_own_scope()
+ rhs_type_symbol = rhs_expr.type
+
+ if lhs_variable_symbol is None:
+ code, message = Messages.get_equation_var_not_in_state_block(node.get_variable().get_complete_name())
+ Logger.log_message(code=code, message=message, error_position=node.get_source_position(),
+ log_level=LoggingLevel.ERROR)
+ return
+
+ if isinstance(rhs_type_symbol, ErrorTypeSymbol):
+ LoggingHelper.drop_missing_type_error(node)
+ return
+
+ lhs_type_symbol = lhs_variable_symbol.get_type_symbol()
+
+ if node.is_compound_product:
+ if self.__types_do_not_match(lhs_type_symbol, lhs_type_symbol * rhs_type_symbol):
+ TypeCaster.try_to_recover_or_error(lhs_type_symbol, lhs_type_symbol * rhs_type_symbol,
+ node.get_expression())
+ return
+ return
+
+ if node.is_compound_quotient:
+ if self.__types_do_not_match(lhs_type_symbol, lhs_type_symbol / rhs_type_symbol):
+ TypeCaster.try_to_recover_or_error(lhs_type_symbol, lhs_type_symbol / rhs_type_symbol,
+ node.get_expression())
+ return
+ return
+
+ assert node.is_compound_sum or node.is_compound_minus
+ if self.__types_do_not_match(lhs_type_symbol, rhs_type_symbol):
+ TypeCaster.try_to_recover_or_error(lhs_type_symbol, rhs_type_symbol,
+ node.get_expression())
+
+ @staticmethod
+ def __types_do_not_match(lhs_type_symbol, rhs_type_symbol):
+ if lhs_type_symbol is None:
+ return True
+
+ return not lhs_type_symbol.equals(rhs_type_symbol)
+
+ def handle_simple_assignment(self, node):
+ from pynestml.symbols.symbol import SymbolKind
+ lhs_variable_symbol = node.get_scope().resolve_to_symbol(node.get_variable().get_complete_name(),
+ SymbolKind.VARIABLE)
+
+ rhs_type_symbol = node.get_expression().type
+ if isinstance(rhs_type_symbol, ErrorTypeSymbol):
+ LoggingHelper.drop_missing_type_error(node)
+ return
+
+ if lhs_variable_symbol is not None and self.__types_do_not_match(lhs_variable_symbol.get_type_symbol(),
+ rhs_type_symbol):
+ TypeCaster.try_to_recover_or_error(lhs_variable_symbol.get_type_symbol(), rhs_type_symbol,
+ node.get_expression())
+
+ def visit_function_call(self, node):
+ """
+ Check consistency for a single function call: check if the called function has been declared, whether the number and types of arguments correspond to the declaration, etc.
+
+ :param node: a single function call.
+ :type node: ASTFunctionCall
+ """
+ func_name = node.get_name()
+
+ if func_name == 'convolve':
+ return
+
+ symbol = node.get_scope().resolve_to_symbol(node.get_name(), SymbolKind.FUNCTION)
+
+ if symbol is None and ASTUtils.is_function_delay_variable(node):
+ return
+
+ # first check if the function has been declared
+ if symbol is None:
+ code, message = Messages.get_function_not_declared(node.get_name())
+ Logger.log_message(error_position=node.get_source_position(), log_level=LoggingLevel.ERROR,
+ code=code, message=message)
+ return
+
+ # check if the number of arguments is the same as in the symbol; accept anything for variadic types
+ is_variadic: bool = len(symbol.get_parameter_types()) == 1 and isinstance(symbol.get_parameter_types()[0], VariadicTypeSymbol)
+ if (not is_variadic) and len(node.get_args()) != len(symbol.get_parameter_types()):
+ code, message = Messages.get_wrong_number_of_args(str(node), len(symbol.get_parameter_types()),
+ len(node.get_args()))
+ Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR,
+ error_position=node.get_source_position())
+ return
+
+ # finally check if the call is correctly typed
+ expected_types = symbol.get_parameter_types()
+ actual_args = node.get_args()
+ actual_types = [arg.type for arg in actual_args]
+ for actual_arg, actual_type, expected_type in zip(actual_args, actual_types, expected_types):
+ if isinstance(actual_type, ErrorTypeSymbol):
+ code, message = Messages.get_type_could_not_be_derived(actual_arg)
+ Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR,
+ error_position=actual_arg.get_source_position())
+ return
+
+ if isinstance(expected_type, VariadicTypeSymbol):
+ # variadic type symbol accepts anything
+ return
+
+ if not actual_type.equals(expected_type) and not isinstance(expected_type, TemplateTypeSymbol):
+ TypeCaster.try_to_recover_or_error(expected_type, actual_type, actual_arg)
diff --git a/pynestml/transformers/synapse_post_neuron_transformer.py b/pynestml/transformers/synapse_post_neuron_transformer.py
index 68cc70a62..bf5b821dc 100644
--- a/pynestml/transformers/synapse_post_neuron_transformer.py
+++ b/pynestml/transformers/synapse_post_neuron_transformer.py
@@ -23,6 +23,7 @@
from typing import Any, Sequence, Mapping, Optional, Union
+from pynestml.cocos.co_cos_manager import CoCosManager
from pynestml.frontend.frontend_configuration import FrontendConfiguration
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_equations_block import ASTEquationsBlock
@@ -565,8 +566,8 @@ def mark_post_port(_expr=None):
# make sure the moved symbols can be resolved in the scope of the neuron (that's where ``ASTExternalVariable._altscope`` will be pointing to)
ast_symbol_table_visitor = ASTSymbolTableVisitor()
- ast_symbol_table_visitor.after_ast_rewrite_ = True
new_neuron.accept(ast_symbol_table_visitor)
+ CoCosManager.check_cocos(new_neuron)
Logger.log_message(
None, -1, "In synapse: replacing variables with suffixed external variable references", None, LoggingLevel.INFO)
@@ -609,9 +610,10 @@ def mark_post_port(_expr=None):
new_neuron.accept(ASTParentVisitor())
new_synapse.accept(ASTParentVisitor())
ast_symbol_table_visitor = ASTSymbolTableVisitor()
- ast_symbol_table_visitor.after_ast_rewrite_ = True
new_neuron.accept(ast_symbol_table_visitor)
new_synapse.accept(ast_symbol_table_visitor)
+ CoCosManager.check_cocos(new_neuron)
+ CoCosManager.check_cocos(new_synapse)
ASTUtils.update_blocktype_for_common_parameters(new_synapse)
diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py
index d3d6f6ef5..a3983694d 100644
--- a/pynestml/utils/ast_utils.py
+++ b/pynestml/utils/ast_utils.py
@@ -28,7 +28,6 @@
from pynestml.codegeneration.printers.ast_printer import ASTPrinter
from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter
-from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter
from pynestml.frontend.frontend_configuration import FrontendConfiguration
from pynestml.generated.PyNestMLLexer import PyNestMLLexer
from pynestml.meta_model.ast_assignment import ASTAssignment
@@ -66,7 +65,6 @@
from pynestml.utils.messages import Messages
from pynestml.utils.string_utils import removesuffix
from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor
-from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
from pynestml.visitors.ast_visitor import ASTVisitor
@@ -1766,10 +1764,12 @@ def remove_initial_values_for_kernels(cls, model: ASTModel) -> None:
@classmethod
def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict]) -> None:
"""
- Update initial values for original ODE declarations (e.g. V_m', g_ahp'') that are present in the model
- before ODE-toolbox processing, with the formatted variable names and initial values returned by ODE-toolbox.
+ Update initial values for original ODE declarations (e.g. V_m', g_ahp'') that are present in the model before ODE-toolbox processing, with the formatted variable names and initial values returned by ODE-toolbox.
"""
from pynestml.utils.model_parser import ModelParser
+ from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
+ from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
+
assert len(model.get_equations_blocks()) == 1, "Only one equation block should be present"
if not model.get_state_blocks():
@@ -1782,10 +1782,6 @@ def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict
if cls.is_ode_variable(var.get_name(), model):
assert cls.variable_in_solver(cls.to_ode_toolbox_processed_name(var_name), solver_dicts)
- # replace the left-hand side variable name by the ode-toolbox format
- var.set_name(cls.to_ode_toolbox_processed_name(var.get_complete_name()))
- var.set_differential_order(0)
-
# replace the defining expression by the ode-toolbox result
iv_expr = cls.get_initial_value_from_ode_toolbox_result(
cls.to_ode_toolbox_processed_name(var_name), solver_dicts)
@@ -1794,6 +1790,9 @@ def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict
iv_expr.update_scope(state_block.get_scope())
iv_decl.set_expression(iv_expr)
+ model.accept(ASTParentVisitor())
+ model.accept(ASTSymbolTableVisitor())
+
@classmethod
def integrate_odes_args_strs_from_function_call(cls, function_call: ASTFunctionCall):
arg_names = []
@@ -2296,6 +2295,7 @@ def replace_convolve_calls_with_buffers_(cls, model: ASTModel, equations_block:
r"""
Replace all occurrences of `convolve(kernel[']^n, spike_input_port)` with the corresponding buffer variable, e.g. `g_E__X__spikes_exc[__d]^n` for a kernel named `g_E` and a spike input port named `spikes_exc`.
"""
+ from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
def replace_function_call_through_var(_expr=None):
if _expr.is_function_call() and _expr.get_function_call().get_name() == "convolve":
@@ -2326,6 +2326,7 @@ def func(x):
return replace_function_call_through_var(x) if isinstance(x, ASTSimpleExpression) else True
equations_block.accept(ASTHigherOrderVisitor(func))
+ equations_block.accept(ASTSymbolTableVisitor())
@classmethod
def update_blocktype_for_common_parameters(cls, node):
diff --git a/pynestml/utils/logger.py b/pynestml/utils/logger.py
index 06e95b804..8404f1245 100644
--- a/pynestml/utils/logger.py
+++ b/pynestml/utils/logger.py
@@ -19,7 +19,7 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
-from typing import List, Mapping, Optional, Tuple
+from typing import List, Mapping, Optional, Tuple, Union
from collections import OrderedDict
from enum import Enum
@@ -75,6 +75,7 @@ class Logger:
def init_logger(cls, logging_level: LoggingLevel):
"""
Initializes the logger.
+
:param logging_level: the logging level as required
:type logging_level: LoggingLevel
"""
@@ -82,7 +83,6 @@ def init_logger(cls, logging_level: LoggingLevel):
cls.curr_message = 0
cls.log = {}
cls.log_frozen = False
- return
@classmethod
def freeze_log(cls, do_freeze: bool = True):
@@ -95,6 +95,7 @@ def freeze_log(cls, do_freeze: bool = True):
def get_log(cls) -> Mapping[int, Tuple[ASTNode, LoggingLevel, str]]:
"""
Returns the overall log of messages. The structure of the log is: (NODE, LEVEL, MESSAGE)
+
:return: mapping from id to ASTNode, log level and message.
"""
return cls.log
@@ -103,6 +104,7 @@ def get_log(cls) -> Mapping[int, Tuple[ASTNode, LoggingLevel, str]]:
def set_log(cls, log, counter):
"""
Restores log from the 'log' variable
+
:param log: the log
:param counter: the counter
"""
@@ -113,20 +115,19 @@ def set_log(cls, log, counter):
def log_message(cls, node: ASTNode = None, code: MessageCode = None, message: str = None, error_position: ASTSourceLocation = None, log_level: LoggingLevel = None):
"""
Logs the handed over message on the handed over node. If the current logging is appropriate, the message is also printed.
+
:param node: the node in which the error occurred
:param code: a single error code
- :type code: ErrorCode
:param error_position: the position on which the error occurred.
- :type error_position: SourcePosition
:param message: a message.
- :type message: str
:param log_level: the corresponding log level.
- :type log_level: LoggingLevel
"""
if cls.log_frozen:
return
+
if cls.curr_message is None:
cls.init_logger(LoggingLevel.INFO)
+
from pynestml.meta_model.ast_node import ASTNode
from pynestml.utils.ast_source_location import ASTSourceLocation
assert (node is None or isinstance(node, ASTNode)), \
@@ -134,15 +135,23 @@ def log_message(cls, node: ASTNode = None, code: MessageCode = None, message: st
assert (error_position is None or isinstance(error_position, ASTSourceLocation)), \
'(PyNestML.Logger) Wrong type of error position provided (%s)!' % type(error_position)
from pynestml.meta_model.ast_model import ASTModel
+
if isinstance(node, ASTModel):
cls.log[cls.curr_message] = (
node.get_artifact_name(), node, log_level, code, error_position, message)
- elif cls.current_node is not None:
- cls.log[cls.curr_message] = (cls.current_node.get_artifact_name(), cls.current_node,
+ else:
+ if cls.current_node is not None:
+ artifact_name = cls.current_node.get_artifact_name()
+ else:
+ artifact_name = ""
+
+ cls.log[cls.curr_message] = (artifact_name, cls.current_node,
log_level, code, error_position, message)
+
cls.curr_message += 1
if cls.no_print:
return
+
if cls.logging_level.value <= log_level.value:
if isinstance(node, ASTInlineExpression):
node_name = node.variable_name
@@ -163,10 +172,9 @@ def log_message(cls, node: ASTNode = None, code: MessageCode = None, message: st
def string_to_level(cls, string: str) -> LoggingLevel:
"""
Returns the logging level corresponding to the handed over string. If no such exits, returns None.
+
:param string: a single string representing the level.
- :type string: str
:return: a single logging level.
- :rtype: LoggingLevel
"""
if string == 'DEBUG':
return LoggingLevel.DEBUG
@@ -183,7 +191,7 @@ def string_to_level(cls, string: str) -> LoggingLevel:
if string == 'NO' or string == 'NONE':
return LoggingLevel.NO
- raise Exception('Tried to convert unknown string \"' + string + '\" to logging level')
+ raise Exception("Tried to convert unknown string '" + string + "' to logging level")
@classmethod
def level_to_string(cls, level: LoggingLevel) -> str:
@@ -207,7 +215,7 @@ def level_to_string(cls, level: LoggingLevel) -> str:
if level == LoggingLevel.NO:
return 'NO'
- raise Exception('Tried to convert unknown logging level \"' + str(level) + '\" to string')
+ raise Exception("Tried to convert unknown logging level '" + str(level) + "' to string")
@classmethod
def set_logging_level(cls, level: LoggingLevel) -> None:
@@ -218,79 +226,89 @@ def set_logging_level(cls, level: LoggingLevel) -> None:
"""
if cls.log_frozen:
return
+
cls.logging_level = level
@classmethod
def set_current_node(cls, node: Optional[ASTNode]) -> None:
"""
- Sets the handed over node as the currently processed one. This enables a retrieval of messages for a
- specific node.
- :param node: a single node instance
+ Sets the handed over node as the currently processed one. This enables a retrieval of messages for a specific node.
+
+ :param node: a single node instance
"""
cls.current_node = node
@classmethod
- def get_all_messages_of_level_and_or_node(cls, node: ASTNode, level: LoggingLevel) -> List[Tuple[ASTNode, LoggingLevel, str]]:
+ def get_all_messages_of_level_and_or_node(cls, node: Union[ASTNode, str], level: LoggingLevel) -> List[Tuple[ASTNode, LoggingLevel, str]]:
"""
- Returns all messages which have a certain logging level, or have been reported for a certain node, or
- both.
+ Returns all messages which have a certain logging level, or have been reported for a certain node, or both.
+
:param node: a single node instance
:param level: a logging level
- :type level: LoggingLevel
:return: a list of messages with their levels.
- :rtype: list((str,Logging_Level)
"""
if level is None and node is None:
return cls.get_log()
+
+ if isinstance(node, str):
+ # search by artifact name
+ node_artifact_name = node
+ node = None
+ else:
+ # search by artifact class object
+ node_artifact_name = None
+
ret = list()
for (artifactName, node_i, logLevel, code, errorPosition, message) in cls.log.values():
- if (level == logLevel if level is not None else True) and (
- node if node is not None else True) and (
- node.get_artifact_name() == artifactName if node is not None else True):
+ if (level == logLevel if level is not None else True) and (node if node is not None else True) and (node_artifact_name == artifactName if node is not None else True):
ret.append((node, logLevel, message))
+
return ret
@classmethod
def get_all_messages_of_level(cls, level: LoggingLevel) -> List[Tuple[ASTNode, LoggingLevel, str]]:
"""
Returns all messages which have a certain logging level.
+
:param level: a logging level
- :type level: LoggingLevel
:return: a list of messages with their levels.
- :rtype: list((str,Logging_Level)
"""
if level is None:
return cls.get_log()
+
ret = list()
for (artifactName, node, logLevel, code, errorPosition, message) in cls.log.values():
if level == logLevel:
ret.append((node, logLevel, message))
+
return ret
@classmethod
def get_all_messages_of_node(cls, node: ASTNode) -> List[Tuple[ASTNode, LoggingLevel, str]]:
"""
Returns all messages which have been reported for a certain node.
+
:param node: a single node instance
:return: a list of messages with their levels.
- :rtype: list((str,Logging_Level)
"""
if node is None:
return cls.get_log()
+
ret = list()
for (artifactName, node_i, logLevel, code, errorPosition, message) in cls.log.values():
if (node_i == node if node is not None else True) and \
(node.get_artifact_name() == artifactName if node is not None else True):
ret.append((node, logLevel, message))
+
return ret
@classmethod
def has_errors(cls, node: ASTNode) -> bool:
"""
Indicates whether the handed over node, thus the corresponding model, has errors.
+
:param node: a single node instance.
:return: True if errors detected, otherwise False
- :rtype: bool
"""
return len(cls.get_all_messages_of_level_and_or_node(node, LoggingLevel.ERROR)) > 0
@@ -311,6 +329,7 @@ def get_json_format(cls) -> str:
(node.get_name() if node is not None else 'GLOBAL') + '", ' + \
'"severity":"' \
+ str(logLevel.name) + '", '
+
if code is not None:
ret += '"code":"' + \
code.name + \
@@ -323,10 +342,12 @@ def get_json_format(cls) -> str:
'", ' + \
'"message":"' + str(message).replace('"', "'") + '"}'
ret += ','
+
if len(cls.log.keys()) == 0:
parsed = json.loads('[]', object_pairs_hook=OrderedDict)
else:
ret = ret[:-1] # delete the last ","
ret += ']'
parsed = json.loads(ret, object_pairs_hook=OrderedDict)
+
return json.dumps(parsed, indent=2, sort_keys=False)
diff --git a/pynestml/utils/messages.py b/pynestml/utils/messages.py
index 69b32a8f4..27cd5cfcd 100644
--- a/pynestml/utils/messages.py
+++ b/pynestml/utils/messages.py
@@ -18,11 +18,15 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
-from enum import Enum
+
+from __future__ import annotations
+
from typing import Tuple
-from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
from collections.abc import Iterable
+from enum import Enum
+
+from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
from pynestml.meta_model.ast_function import ASTFunction
@@ -158,8 +162,8 @@ def get_input_path_not_found(cls, path):
return MessageCode.INPUT_PATH_NOT_FOUND, message
@classmethod
- def get_unknown_target(cls, target):
- message = 'Unknown target ("%s")' % (target)
+ def get_unknown_target_platform(cls, target: str):
+ message = "Unknown target: '" + target + "'"
return MessageCode.UNKNOWN_TARGET, message
@classmethod
@@ -313,22 +317,13 @@ def get_different_type_rhs_lhs(
return MessageCode.CAST_NOT_POSSIBLE, message
@classmethod
- def get_type_different_from_expected(cls, expected_type, got_type):
+ def get_type_different_from_expected(cls, expected_type, got_type) -> Tuple[MessageCode, str]:
"""
Returns a message indicating that the received type is different from the expected one.
:param expected_type: the expected type
- :type expected_type: TypeSymbol
:param got_type: the actual type
- :type got_type: type_symbol
:return: a message
- :rtype: (MessageCode,str)
"""
- from pynestml.symbols.type_symbol import TypeSymbol
- assert (expected_type is not None and isinstance(expected_type, TypeSymbol)), \
- '(PyNestML.Utils.Message) Not a type symbol provided (%s)!' % type(
- expected_type)
- assert (got_type is not None and isinstance(got_type, TypeSymbol)), \
- '(PyNestML.Utils.Message) Not a type symbol provided (%s)!' % type(got_type)
message = 'Actual type different from expected. Expected: \'%s\', got: \'%s\'!' % (
expected_type.print_symbol(), got_type.print_symbol())
return MessageCode.TYPE_DIFFERENT_FROM_EXPECTED, message
@@ -430,11 +425,10 @@ def get_module_generated(cls, path: str) -> Tuple[MessageCode, str]:
return MessageCode.MODULE_SUCCESSFULLY_GENERATED, message
@classmethod
- def get_variable_used_before_declaration(cls, variable_name):
+ def get_variable_used_before_declaration(cls, variable_name: str):
"""
Returns a message indicating that a variable is used before declaration.
:param variable_name: a variable name
- :type variable_name: str
:return: a message
:rtype: (MessageCode,str)
"""
@@ -701,7 +695,7 @@ def get_model_redeclared(cls, name: str) -> Tuple[MessageCode, str]:
'(PyNestML.Utils.Message) Not a string provided (%s)!' % type(name)
assert (name is not None and isinstance(name, str)), \
'(PyNestML.Utils.Message) Not a string provided (%s)!' % type(name)
- message = 'model \'%s\' redeclared!' % name
+ message = 'Model \'%s\' redeclared!' % name
return MessageCode.MODEL_REDECLARED, message
@classmethod
diff --git a/pynestml/utils/model_parser.py b/pynestml/utils/model_parser.py
index 7fabf361e..62a8669bb 100644
--- a/pynestml/utils/model_parser.py
+++ b/pynestml/utils/model_parser.py
@@ -24,6 +24,7 @@
from antlr4 import CommonTokenStream, FileStream, InputStream
from antlr4.error.ErrorStrategy import BailErrorStrategy, DefaultErrorStrategy
from antlr4.error.ErrorListener import ConsoleErrorListener
+from pynestml.cocos.co_cos_manager import CoCosManager
from pynestml.generated.PyNestMLLexer import PyNestMLLexer
from pynestml.generated.PyNestMLParser import PyNestMLParser
@@ -65,6 +66,7 @@
from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.meta_model.ast_while_stmt import ASTWhileStmt
from pynestml.symbol_table.symbol_table import SymbolTable
+from pynestml.transformers.assign_implicit_conversion_factors_transformer import AssignImplicitConversionFactorsTransformer
from pynestml.utils.ast_source_location import ASTSourceLocation
from pynestml.utils.error_listener import NestMLErrorListener
from pynestml.utils.logger import Logger, LoggingLevel
@@ -142,10 +144,14 @@ def parse_file(cls, file_path=None):
for model in ast.get_model_list():
model.accept(ASTSymbolTableVisitor())
SymbolTable.add_model_scope(model.get_name(), model.get_scope())
+ Logger.set_current_node(model)
+ AssignImplicitConversionFactorsTransformer().transform(model)
+ Logger.set_current_node(None)
# store source paths
for model in ast.get_model_list():
model.file_path = file_path
+
ast.file_path = file_path
return ast
diff --git a/pynestml/utils/type_caster.py b/pynestml/utils/type_caster.py
index 34e4e6ccc..4ce2624dd 100644
--- a/pynestml/utils/type_caster.py
+++ b/pynestml/utils/type_caster.py
@@ -28,12 +28,11 @@ class TypeCaster:
@staticmethod
def do_magnitude_conversion_rhs_to_lhs(_rhs_type_symbol, _lhs_type_symbol, _containing_expression):
"""
- determine conversion factor from rhs to lhs, register it with the relevant expression
+ Determine conversion factor from rhs to lhs, register it with the relevant expression
"""
_containing_expression.set_implicit_conversion_factor(
- UnitTypeSymbol.get_conversion_factor(_lhs_type_symbol.astropy_unit,
- _rhs_type_symbol.astropy_unit))
- _containing_expression.type = _lhs_type_symbol
+ UnitTypeSymbol.get_conversion_factor(_rhs_type_symbol.astropy_unit,
+ _lhs_type_symbol.astropy_unit))
code, message = Messages.get_implicit_magnitude_conversion(_lhs_type_symbol, _rhs_type_symbol,
_containing_expression.get_implicit_conversion_factor())
Logger.log_message(code=code, message=message,
@@ -45,18 +44,26 @@ def try_to_recover_or_error(_lhs_type_symbol, _rhs_type_symbol, _containing_expr
if _rhs_type_symbol.is_castable_to(_lhs_type_symbol):
if isinstance(_lhs_type_symbol, UnitTypeSymbol) \
and isinstance(_rhs_type_symbol, UnitTypeSymbol):
- conversion_factor = UnitTypeSymbol.get_conversion_factor(
- _lhs_type_symbol.astropy_unit, _rhs_type_symbol.astropy_unit)
+ conversion_factor = UnitTypeSymbol.get_conversion_factor(_rhs_type_symbol.astropy_unit, _lhs_type_symbol.astropy_unit)
+
+ if conversion_factor is None:
+ # error during conversion
+ code, message = Messages.get_type_different_from_expected(_lhs_type_symbol, _rhs_type_symbol)
+ Logger.log_message(error_position=_containing_expression.get_source_position(),
+ code=code, message=message, log_level=LoggingLevel.ERROR)
+ return
+
if not conversion_factor == 1.:
# the units are mutually convertible, but require a factor unequal to 1 (e.g. mV and A*Ohm)
- TypeCaster.do_magnitude_conversion_rhs_to_lhs(
- _rhs_type_symbol, _lhs_type_symbol, _containing_expression)
+ TypeCaster.do_magnitude_conversion_rhs_to_lhs(_rhs_type_symbol, _lhs_type_symbol, _containing_expression)
+
# the units are mutually convertible (e.g. V and A*Ohm)
code, message = Messages.get_implicit_cast_rhs_to_lhs(_rhs_type_symbol.print_symbol(),
_lhs_type_symbol.print_symbol())
Logger.log_message(error_position=_containing_expression.get_source_position(),
code=code, message=message, log_level=LoggingLevel.INFO)
- else:
- code, message = Messages.get_type_different_from_expected(_lhs_type_symbol, _rhs_type_symbol)
- Logger.log_message(error_position=_containing_expression.get_source_position(),
- code=code, message=message, log_level=LoggingLevel.ERROR)
+ return
+
+ code, message = Messages.get_type_different_from_expected(_lhs_type_symbol, _rhs_type_symbol)
+ Logger.log_message(error_position=_containing_expression.get_source_position(),
+ code=code, message=message, log_level=LoggingLevel.ERROR)
diff --git a/pynestml/visitors/ast_builder_visitor.py b/pynestml/visitors/ast_builder_visitor.py
index 0e766d530..ff4b66fb9 100644
--- a/pynestml/visitors/ast_builder_visitor.py
+++ b/pynestml/visitors/ast_builder_visitor.py
@@ -52,16 +52,17 @@ def visitNestMLCompilationUnit(self, ctx):
models = list()
for child in ctx.model():
models.append(self.visit(child))
+
# extract the name of the artifact from the context
if hasattr(ctx.start.source[1], 'fileName'):
artifact_name = ntpath.basename(ctx.start.source[1].fileName)
else:
artifact_name = 'parsed_from_string'
+
compilation_unit = ASTNodeFactory.create_ast_nestml_compilation_unit(list_of_models=models,
source_position=create_source_pos(ctx),
artifact_name=artifact_name)
- # first ensure certain properties of the model
- CoCosManager.check_model_names_unique(compilation_unit)
+
return compilation_unit
# Visit a parse tree produced by PyNESTMLParser#datatype.
diff --git a/pynestml/visitors/ast_symbol_table_visitor.py b/pynestml/visitors/ast_symbol_table_visitor.py
index 011182543..bc85d4cdd 100644
--- a/pynestml/visitors/ast_symbol_table_visitor.py
+++ b/pynestml/visitors/ast_symbol_table_visitor.py
@@ -19,7 +19,6 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
-from pynestml.cocos.co_cos_manager import CoCosManager
from pynestml.meta_model.ast_model import ASTModel
from pynestml.meta_model.ast_model_body import ASTModelBody
from pynestml.meta_model.ast_namespace_decorator import ASTNamespaceDecorator
@@ -53,7 +52,6 @@ def __init__(self):
self.symbol_stack = Stack()
self.scope_stack = Stack()
self.block_type_stack = Stack()
- self.after_ast_rewrite_ = False
def visit_model(self, node: ASTModel) -> None:
"""
@@ -79,10 +77,6 @@ def visit_model(self, node: ASTModel) -> None:
node.get_scope().add_symbol(types[symbol])
def endvisit_model(self, node: ASTModel):
- # before following checks occur, we need to ensure several simple properties
- CoCosManager.post_symbol_table_builder_checks(
- node, after_ast_rewrite=self.after_ast_rewrite_)
-
# update the equations
for equation_block in node.get_equations_blocks():
ASTUtils.assign_ode_to_variables(equation_block)
@@ -287,8 +281,7 @@ def visit_declaration(self, node: ASTDeclaration) -> None:
namespace_decorators = {}
for d in node.get_decorators():
if isinstance(d, ASTNamespaceDecorator):
- namespace_decorators[str(d.get_namespace())] = str(
- d.get_name())
+ namespace_decorators[str(d.get_namespace())] = str(d.get_name())
else:
decorators.append(d)
@@ -296,6 +289,7 @@ def visit_declaration(self, node: ASTDeclaration) -> None:
block_type = None
if not self.block_type_stack.is_empty():
block_type = self.block_type_stack.top()
+
for var in node.get_variables(): # for all variables declared create a new symbol
var.update_scope(node.get_scope())
@@ -324,11 +318,14 @@ def visit_declaration(self, node: ASTDeclaration) -> None:
symbol.set_comment(node.get_comment())
node.get_scope().add_symbol(symbol)
var.set_type_symbol(type_symbol)
+
# the data type
node.get_data_type().update_scope(node.get_scope())
+
# the rhs update
if node.has_expression():
node.get_expression().update_scope(node.get_scope())
+
# the invariant update
if node.has_invariant():
node.get_invariant().update_scope(node.get_scope())
diff --git a/tests/cocos_test.py b/tests/cocos_test.py
deleted file mode 100644
index f557faaf0..000000000
--- a/tests/cocos_test.py
+++ /dev/null
@@ -1,698 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# cocos_test.py
-#
-# This file is part of NEST.
-#
-# Copyright (C) 2004 The NEST Initiative
-#
-# NEST is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 2 of the License, or
-# (at your option) any later version.
-#
-# NEST is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with NEST. If not, see .
-
-from __future__ import print_function
-
-import os
-import unittest
-
-from pynestml.utils.ast_source_location import ASTSourceLocation
-from pynestml.symbol_table.symbol_table import SymbolTable
-from pynestml.symbols.predefined_functions import PredefinedFunctions
-from pynestml.symbols.predefined_types import PredefinedTypes
-from pynestml.symbols.predefined_units import PredefinedUnits
-from pynestml.symbols.predefined_variables import PredefinedVariables
-from pynestml.utils.logger import LoggingLevel, Logger
-from pynestml.utils.model_parser import ModelParser
-
-
-class CoCosTest(unittest.TestCase):
-
- def setUp(self):
- Logger.init_logger(LoggingLevel.INFO)
- SymbolTable.initialize_symbol_table(
- ASTSourceLocation(
- start_line=0,
- start_column=0,
- end_line=0,
- end_column=0))
- PredefinedUnits.register_units()
- PredefinedTypes.register_types()
- PredefinedVariables.register_variables()
- PredefinedFunctions.register_functions()
-
- def test_invalid_element_defined_after_usage(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVariableDefinedAfterUsage.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_element_defined_after_usage(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVariableDefinedAfterUsage.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_element_in_same_line(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoElementInSameLine.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_element_in_same_line(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoElementInSameLine.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_integrate_odes_called_if_equations_defined(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_integrate_odes_called_if_equations_defined(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_element_not_defined_in_scope(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVariableNotDefined.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)), 5)
-
- def test_valid_element_not_defined_in_scope(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVariableNotDefined.nestml'))
- self.assertEqual(
- len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)),
- 0)
-
- def test_variable_with_same_name_as_unit(self):
- Logger.set_logging_level(LoggingLevel.NO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVariableWithSameNameAsUnit.nestml'))
- self.assertEqual(
- len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)),
- 3)
-
- def test_invalid_variable_redeclaration(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVariableRedeclared.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_variable_redeclaration(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVariableRedeclared.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_each_block_unique(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoEachBlockUnique.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
-
- def test_valid_each_block_unique(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoEachBlockUnique.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_function_unique_and_defined(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoFunctionNotUnique.nestml'))
- self.assertEqual(
- len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 5)
-
- def test_valid_function_unique_and_defined(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoFunctionNotUnique.nestml'))
- self.assertEqual(
- len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_inline_expressions_have_rhs(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoInlineExpressionHasNoRhs.nestml'))
- assert model is None
-
- def test_valid_inline_expressions_have_rhs(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoInlineExpressionHasNoRhs.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_inline_expression_has_several_lhs(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoInlineExpressionWithSeveralLhs.nestml'))
- assert model is None
-
- def test_valid_inline_expression_has_several_lhs(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoInlineExpressionWithSeveralLhs.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_no_values_assigned_to_input_ports(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoValueAssignedToInputPort.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_no_values_assigned_to_input_ports(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoValueAssignedToInputPort.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_order_of_equations_correct(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoNoOrderOfEquations.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
-
- def test_valid_order_of_equations_correct(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoNoOrderOfEquations.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_numerator_of_unit_one(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoUnitNumeratorNotOne.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)), 2)
-
- def test_valid_numerator_of_unit_one(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoUnitNumeratorNotOne.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_names_of_neurons_unique(self):
- Logger.init_logger(LoggingLevel.INFO)
- ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoMultipleNeuronsWithEqualName.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(None, LoggingLevel.ERROR)), 1)
-
- def test_valid_names_of_neurons_unique(self):
- Logger.init_logger(LoggingLevel.INFO)
- ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoMultipleNeuronsWithEqualName.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(None, LoggingLevel.ERROR)), 0)
-
- def test_invalid_no_nest_collision(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoNestNamespaceCollision.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_no_nest_collision(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoNestNamespaceCollision.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_redundant_input_port_keywords_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoInputPortWithRedundantTypes.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_redundant_input_port_keywords_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoInputPortWithRedundantTypes.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_parameters_assigned_only_in_parameters_block(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoParameterAssignedOutsideBlock.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_parameters_assigned_only_in_parameters_block(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoParameterAssignedOutsideBlock.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_inline_expressions_assigned_only_in_declaration(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoAssignmentToInlineExpression.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_invalid_internals_assigned_only_in_internals_block(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoInternalAssignedOutsideBlock.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_internals_assigned_only_in_internals_block(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoInternalAssignedOutsideBlock.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_function_with_wrong_arg_number_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_function_with_wrong_arg_number_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_init_values_have_rhs_and_ode(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoInitValuesWithoutOde.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), 2)
-
- def test_valid_init_values_have_rhs_and_ode(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoInitValuesWithoutOde.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), 2)
-
- def test_invalid_incorrect_return_stmt_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoIncorrectReturnStatement.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 4)
-
- def test_valid_incorrect_return_stmt_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoIncorrectReturnStatement.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_ode_vars_outside_init_block_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoOdeVarNotInInitialValues.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_ode_vars_outside_init_block_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoOdeVarNotInInitialValues.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_convolve_correctly_defined(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoConvolveNotCorrectlyProvided.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)), 3)
-
- def test_valid_convolve_correctly_defined(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoConvolveNotCorrectlyProvided.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_vector_in_non_vector_declaration_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVectorInNonVectorDeclaration.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_vector_in_non_vector_declaration_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVectorInNonVectorDeclaration.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_vector_parameter_declaration(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVectorParameterDeclaration.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_vector_parameter_declaration(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVectorParameterDeclaration.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_vector_parameter_type(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVectorParameterType.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_vector_parameter_type(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVectorParameterType.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_vector_parameter_size(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVectorDeclarationSize.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
-
- def test_valid_vector_parameter_size(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVectorDeclarationSize.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_convolve_correctly_parameterized(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoConvolveNotCorrectlyParametrized.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
-
- def test_valid_convolve_correctly_parameterized(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoConvolveNotCorrectlyParametrized.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)), 0)
-
- def test_invalid_invariant_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoInvariantNotBool.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_invariant_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoInvariantNotBool.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_expression_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoIllegalExpression.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)), 6)
-
- def test_valid_expression_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoIllegalExpression.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_compound_expression_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CompoundOperatorWithDifferentButCompatibleUnits.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 5)
-
- def test_valid_compound_expression_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CompoundOperatorWithDifferentButCompatibleUnits.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_ode_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoOdeIncorrectlyTyped.nestml'))
- self.assertTrue(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)) > 0)
-
- def test_valid_ode_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoOdeCorrectlyTyped.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_output_block_defined_if_emit_call(self):
- """test that an error is raised when the emit_spike() function is called by the neuron, but an output block is not defined"""
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoOutputPortDefinedIfEmitCall.nestml'))
- self.assertTrue(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)) > 0)
-
- def test_invalid_output_port_defined_if_emit_call(self):
- """test that an error is raised when the emit_spike() function is called by the neuron, but a spiking output port is not defined"""
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoOutputPortDefinedIfEmitCall-2.nestml'))
- self.assertTrue(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)) > 0)
-
- def test_valid_output_port_defined_if_emit_call(self):
- """test that no error is raised when the output block is missing, but not emit_spike() functions are called"""
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoOutputPortDefinedIfEmitCall.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_valid_coco_kernel_type(self):
- """
- Test the functionality of CoCoKernelType.
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoKernelType.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_coco_kernel_type(self):
- """
- Test the functionality of CoCoKernelType.
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoKernelType.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_invalid_coco_kernel_type_initial_values(self):
- """
- Test the functionality of CoCoKernelType.
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoKernelTypeInitialValues.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 4)
-
- def test_valid_coco_state_variables_initialized(self):
- """
- Test that the CoCo condition is applicable for all the variables in the state block initialized with a value
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoStateVariablesInitialized.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_coco_state_variables_initialized(self):
- """
- Test that the CoCo condition is applicable for all the variables in the state block not initialized
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoStateVariablesInitialized.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
-
- def test_invalid_co_co_priorities_correctly_specified(self):
- """
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoPrioritiesCorrectlySpecified.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_co_co_priorities_correctly_specified(self):
- """
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoPrioritiesCorrectlySpecified.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_co_co_resolution_legally_used(self):
- """
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoResolutionLegallyUsed.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
-
- def test_valid_co_co_resolution_legally_used(self):
- """
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoResolutionLegallyUsed.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_valid_co_co_vector_input_port(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVectorInputPortSizeAndType.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_co_co_vector_input_port(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVectorInputPortSizeAndType.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
diff --git a/tests/function_parameter_templating_test.py b/tests/function_parameter_templating_test.py
deleted file mode 100644
index e3cb89e41..000000000
--- a/tests/function_parameter_templating_test.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# function_parameter_templating_test.py
-#
-# This file is part of NEST.
-#
-# Copyright (C) 2004 The NEST Initiative
-#
-# NEST is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 2 of the License, or
-# (at your option) any later version.
-#
-# NEST is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with NEST. If not, see .
-
-import os
-import unittest
-
-from pynestml.symbol_table.symbol_table import SymbolTable
-from pynestml.symbols.predefined_functions import PredefinedFunctions
-from pynestml.symbols.predefined_types import PredefinedTypes
-from pynestml.symbols.predefined_units import PredefinedUnits
-from pynestml.symbols.predefined_variables import PredefinedVariables
-from pynestml.utils.ast_source_location import ASTSourceLocation
-from pynestml.utils.logger import Logger, LoggingLevel
-from pynestml.utils.model_parser import ModelParser
-
-# minor setup steps required
-SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0))
-PredefinedUnits.register_units()
-PredefinedTypes.register_types()
-PredefinedVariables.register_variables()
-PredefinedFunctions.register_functions()
-
-
-class FunctionParameterTemplatingTest(unittest.TestCase):
- """
- This test is used to test the correct derivation of types when functions use templated type parameters.
- """
-
- def test(self):
- Logger.init_logger(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__),
- "resources", "FunctionParameterTemplatingTest.nestml"))))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)), 7)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/tests/nest_tests/nest_delay_based_variables_test.py b/tests/nest_tests/nest_delay_based_variables_test.py
index 51f863e19..a11c280f2 100644
--- a/tests/nest_tests/nest_delay_based_variables_test.py
+++ b/tests/nest_tests/nest_delay_based_variables_test.py
@@ -19,13 +19,12 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
+from typing import List
+
import numpy as np
import os
-from typing import List
import pytest
-import nest
-
try:
import matplotlib
import matplotlib.pyplot as plt
@@ -34,15 +33,12 @@
except BaseException:
TEST_PLOTS = False
+import nest
+
from pynestml.codegeneration.nest_tools import NESTTools
from pynestml.frontend.pynestml_frontend import generate_nest_target
-target_path = "target_delay"
-logging_level = "DEBUG"
-suffix = "_nestml"
-
-
def plot_fig(times, recordable_events_delay: dict, recordable_events: dict, filename: str):
fig, axes = plt.subplots(len(recordable_events), 1, figsize=(7, 9), sharex=True)
for i, recordable_name in enumerate(recordable_events_delay.keys()):
@@ -86,6 +82,9 @@ def run_simulation(neuron_model_name: str, module_name: str, recordables: List[s
("DelayDifferentialEquationsWithNumericSolver.nestml", "dde_numeric_nestml", ["x", "z"]),
("DelayDifferentialEquationsWithMixedSolver.nestml", "dde_mixed_nestml", ["x", "z"])])
def test_dde_with_analytic_solver(file_name: str, neuron_model_name: str, recordables: List[str]):
+ target_path = "target_delay"
+ logging_level = "DEBUG"
+ suffix = "_nestml"
input_path = os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), "resources", file_name)))
module_name = neuron_model_name + "_module"
print("Module name: ", module_name)
@@ -112,16 +111,3 @@ def test_dde_with_analytic_solver(file_name: str, neuron_model_name: str, record
if neuron_model_name == "dde_analytic_nestml":
np.testing.assert_allclose(recordable_events_delay[recordables[1]][int(delay):],
recordable_events[recordables[1]][:-int(delay)])
-
- @pytest.fixture(scope="function", autouse=True)
- def cleanup(self):
- # Run the test
- yield
-
- # clean up
- import shutil
- if self.target_path:
- try:
- shutil.rmtree(self.target_path)
- except Exception:
- pass
diff --git a/tests/nest_tests/resources/integrate_odes_test_params.nestml b/tests/nest_tests/resources/integrate_odes_test_params.nestml
index d07fe8fd4..d6430e537 100644
--- a/tests/nest_tests/resources/integrate_odes_test_params.nestml
+++ b/tests/nest_tests/resources/integrate_odes_test_params.nestml
@@ -8,7 +8,6 @@ model integrate_odes_test:
update:
integrate_odes(2 * test_1)
- integrate_odes(test_3)
integrate_odes(100 ms)
integrate_odes(test_1)
integrate_odes(test_2)
diff --git a/tests/nest_tests/resources/integrate_odes_test_params2.nestml b/tests/nest_tests/resources/integrate_odes_test_params2.nestml
new file mode 100644
index 000000000..616401e48
--- /dev/null
+++ b/tests/nest_tests/resources/integrate_odes_test_params2.nestml
@@ -0,0 +1,10 @@
+"""
+Model for testing the integrate_odes() function.
+"""
+model integrate_odes_test:
+ state:
+ test_1 real = 0.
+ test_2 real = 0.
+
+ update:
+ integrate_odes(test_3)
diff --git a/tests/nest_tests/test_integrate_odes.py b/tests/nest_tests/test_integrate_odes.py
index 99b94c6ca..6ddb699b4 100644
--- a/tests/nest_tests/test_integrate_odes.py
+++ b/tests/nest_tests/test_integrate_odes.py
@@ -27,16 +27,9 @@
import nest
-from pynestml.utils.ast_source_location import ASTSourceLocation
-from pynestml.symbol_table.symbol_table import SymbolTable
-from pynestml.symbols.predefined_functions import PredefinedFunctions
-from pynestml.symbols.predefined_types import PredefinedTypes
-from pynestml.symbols.predefined_units import PredefinedUnits
-from pynestml.symbols.predefined_variables import PredefinedVariables
from pynestml.codegeneration.nest_tools import NESTTools
-from pynestml.frontend.pynestml_frontend import generate_nest_target
+from pynestml.frontend.pynestml_frontend import generate_nest_target, generate_target
from pynestml.utils.logger import LoggingLevel, Logger
-from pynestml.utils.model_parser import ModelParser
try:
import matplotlib
@@ -227,12 +220,15 @@ def test_integrate_odes_nonlinear(self):
def test_integrate_odes_params(self):
r"""Test the integrate_odes() function, in particular with respect to the parameter types."""
- Logger.init_logger(LoggingLevel.INFO)
- SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0))
- PredefinedUnits.register_units()
- PredefinedTypes.register_types()
- PredefinedVariables.register_variables()
- PredefinedFunctions.register_functions()
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml"))))
- assert len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)) == 6
+ fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml")))
+ generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")
+
+ assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2
+
+ def test_integrate_odes_params2(self):
+ r"""Test the integrate_odes() function, in particular with respect to non-existent parameter variables."""
+
+ fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params2.nestml")))
+ generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")
+
+ assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2
diff --git a/tests/test_cocos.py b/tests/test_cocos.py
new file mode 100644
index 000000000..c60d778cf
--- /dev/null
+++ b/tests/test_cocos.py
@@ -0,0 +1,405 @@
+# -*- coding: utf-8 -*-
+#
+# test_cocos.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+from __future__ import print_function
+
+from typing import Optional
+
+import os
+import pytest
+
+from pynestml.meta_model.ast_model import ASTModel
+from pynestml.symbol_table.symbol_table import SymbolTable
+from pynestml.symbols.predefined_functions import PredefinedFunctions
+from pynestml.symbols.predefined_types import PredefinedTypes
+from pynestml.symbols.predefined_units import PredefinedUnits
+from pynestml.symbols.predefined_variables import PredefinedVariables
+from pynestml.utils.ast_source_location import ASTSourceLocation
+from pynestml.utils.logger import LoggingLevel, Logger
+from pynestml.utils.model_parser import ModelParser
+from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
+from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
+
+
+class TestCoCos:
+
+ @pytest.fixture(scope="module", autouse=True)
+ def setUp(self):
+ SymbolTable.initialize_symbol_table(
+ ASTSourceLocation(
+ start_line=0,
+ start_column=0,
+ end_line=0,
+ end_column=0))
+ PredefinedUnits.register_units()
+ PredefinedTypes.register_types()
+ PredefinedVariables.register_variables()
+ PredefinedFunctions.register_functions()
+
+ def test_invalid_element_defined_after_usage(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVariableDefinedAfterUsage.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_element_defined_after_usage(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableDefinedAfterUsage.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_element_in_same_line(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoElementInSameLine.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_element_in_same_line(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoElementInSameLine.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_integrate_odes_called_if_equations_defined(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_integrate_odes_called_if_equations_defined(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_element_not_defined_in_scope(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVariableNotDefined.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 6
+
+ def test_valid_element_not_defined_in_scope(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableNotDefined.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_variable_with_same_name_as_unit(self):
+ Logger.set_logging_level(LoggingLevel.NO)
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableWithSameNameAsUnit.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.WARNING)) == 3
+
+ def test_invalid_variable_redeclaration(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVariableRedeclared.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_variable_redeclaration(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableRedeclared.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_each_block_unique(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoEachBlockUnique.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_each_block_unique(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoEachBlockUnique.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_function_unique_and_defined(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoFunctionNotUnique.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 8
+
+ def test_valid_function_unique_and_defined(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoFunctionNotUnique.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_inline_expressions_have_rhs(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInlineExpressionHasNoRhs.nestml'))
+ assert model is None
+
+ def test_valid_inline_expressions_have_rhs(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInlineExpressionHasNoRhs.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_inline_expression_has_several_lhs(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInlineExpressionWithSeveralLhs.nestml'))
+ assert model is None
+
+ def test_valid_inline_expression_has_several_lhs(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInlineExpressionWithSeveralLhs.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_no_values_assigned_to_input_ports(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoValueAssignedToInputPort.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_no_values_assigned_to_input_ports(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoValueAssignedToInputPort.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_order_of_equations_correct(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoNoOrderOfEquations.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_order_of_equations_correct(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoNoOrderOfEquations.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_numerator_of_unit_one(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoUnitNumeratorNotOne.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_numerator_of_unit_one(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoUnitNumeratorNotOne.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_names_of_neurons_unique(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoMultipleNeuronsWithEqualName.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 3
+
+ def test_valid_names_of_neurons_unique(self):
+ self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoMultipleNeuronsWithEqualName.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(None, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_no_nest_collision(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoNestNamespaceCollision.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_no_nest_collision(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoNestNamespaceCollision.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_redundant_input_port_keywords_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInputPortWithRedundantTypes.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_redundant_input_port_keywords_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInputPortWithRedundantTypes.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_parameters_assigned_only_in_parameters_block(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoParameterAssignedOutsideBlock.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_parameters_assigned_only_in_parameters_block(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoParameterAssignedOutsideBlock.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_inline_expressions_assigned_only_in_declaration(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoAssignmentToInlineExpression.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_invalid_internals_assigned_only_in_internals_block(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInternalAssignedOutsideBlock.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_internals_assigned_only_in_internals_block(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInternalAssignedOutsideBlock.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_function_with_wrong_arg_number_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_function_with_wrong_arg_number_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_init_values_have_rhs_and_ode(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInitValuesWithoutOde.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.WARNING)) == 2
+
+ def test_valid_init_values_have_rhs_and_ode(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInitValuesWithoutOde.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.WARNING)) == 2
+
+ def test_invalid_incorrect_return_stmt_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoIncorrectReturnStatement.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 8
+
+ def test_valid_incorrect_return_stmt_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoIncorrectReturnStatement.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_ode_vars_outside_init_block_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOdeVarNotInInitialValues.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_ode_vars_outside_init_block_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoOdeVarNotInInitialValues.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_convolve_correctly_defined(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoConvolveNotCorrectlyProvided.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_convolve_correctly_defined(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoConvolveNotCorrectlyProvided.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_vector_in_non_vector_declaration_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorInNonVectorDeclaration.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_vector_in_non_vector_declaration_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorInNonVectorDeclaration.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_vector_parameter_declaration(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorParameterDeclaration.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_vector_parameter_declaration(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorParameterDeclaration.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_vector_parameter_type(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorParameterType.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_vector_parameter_type(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorParameterType.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_vector_parameter_size(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorDeclarationSize.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_vector_parameter_size(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorDeclarationSize.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_convolve_correctly_parameterized(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoConvolveNotCorrectlyParametrized.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_convolve_correctly_parameterized(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoConvolveNotCorrectlyParametrized.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_invariant_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInvariantNotBool.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_invariant_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInvariantNotBool.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_expression_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoIllegalExpression.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_expression_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoIllegalExpression.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_compound_expression_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CompoundOperatorWithDifferentButCompatibleUnits.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 10
+
+ def test_valid_compound_expression_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CompoundOperatorWithDifferentButCompatibleUnits.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_ode_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOdeIncorrectlyTyped.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) > 0
+
+ def test_valid_ode_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoOdeCorrectlyTyped.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_output_block_defined_if_emit_call(self):
+ """test that an error is raised when the emit_spike() function is called by the neuron, but an output block is not defined"""
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOutputPortDefinedIfEmitCall.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) > 0
+
+ def test_invalid_output_port_defined_if_emit_call(self):
+ """test that an error is raised when the emit_spike() function is called by the neuron, but a spiking output port is not defined"""
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOutputPortDefinedIfEmitCall-2.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) > 0
+
+ def test_valid_output_port_defined_if_emit_call(self):
+ """test that no error is raised when the output block is missing, but not emit_spike() functions are called"""
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoOutputPortDefinedIfEmitCall.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_valid_coco_kernel_type(self):
+ """
+ Test the functionality of CoCoKernelType.
+ """
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoKernelType.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_coco_kernel_type(self):
+ """
+ Test the functionality of CoCoKernelType.
+ """
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoKernelType.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_invalid_coco_kernel_type_initial_values(self):
+ """
+ Test the functionality of CoCoKernelType.
+ """
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoKernelTypeInitialValues.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 4
+
+ def test_valid_coco_state_variables_initialized(self):
+ """
+ Test that the CoCo condition is applicable for all the variables in the state block initialized with a value
+ """
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoStateVariablesInitialized.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_coco_state_variables_initialized(self):
+ """
+ Test that the CoCo condition is applicable for all the variables in the state block not initialized
+ """
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoStateVariablesInitialized.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_invalid_co_co_priorities_correctly_specified(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoPrioritiesCorrectlySpecified.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_co_co_priorities_correctly_specified(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoPrioritiesCorrectlySpecified.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_co_co_resolution_legally_used(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoResolutionLegallyUsed.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_co_co_resolution_legally_used(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoResolutionLegallyUsed.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_valid_co_co_vector_input_port(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorInputPortSizeAndType.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_co_co_vector_input_port(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorInputPortSizeAndType.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def _parse_and_validate_model(self, fname: str) -> Optional[str]:
+ from pynestml.frontend.pynestml_frontend import generate_target
+
+ Logger.init_logger(LoggingLevel.DEBUG)
+
+ try:
+ generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")
+ except BaseException:
+ return None
+
+ ast_compilation_unit = ModelParser.parse_file(fname)
+ if ast_compilation_unit is None or len(ast_compilation_unit.get_model_list()) == 0:
+ return None
+
+ model: ASTModel = ast_compilation_unit.get_model_list()[0]
+ model_name = model.get_name()
+
+ return model_name
diff --git a/tests/test_function_parameter_templating.py b/tests/test_function_parameter_templating.py
new file mode 100644
index 000000000..b93e06780
--- /dev/null
+++ b/tests/test_function_parameter_templating.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+#
+# test_function_parameter_templating.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+import os
+
+from pynestml.utils.logger import Logger, LoggingLevel
+from pynestml.frontend.pynestml_frontend import generate_target
+
+
+class TestFunctionParameterTemplating:
+ """
+ This test is used to test the correct derivation of types when functions use templated type parameters.
+ """
+
+ def test(self):
+ fname = os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), "resources", "FunctionParameterTemplatingTest.nestml")))
+ generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")
+ assert len(Logger.get_all_messages_of_level_and_or_node("templated_function_parameters_type_test", LoggingLevel.ERROR)) == 5
diff --git a/tests/test_unit_system.py b/tests/test_unit_system.py
new file mode 100644
index 000000000..2cad0b98d
--- /dev/null
+++ b/tests/test_unit_system.py
@@ -0,0 +1,164 @@
+# -*- coding: utf-8 -*-
+#
+# test_unit_system.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+import os
+import pytest
+
+from pynestml.codegeneration.printers.constant_printer import ConstantPrinter
+from pynestml.codegeneration.printers.cpp_expression_printer import CppExpressionPrinter
+from pynestml.codegeneration.printers.cpp_simple_expression_printer import CppSimpleExpressionPrinter
+from pynestml.codegeneration.printers.cpp_type_symbol_printer import CppTypeSymbolPrinter
+from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter
+from pynestml.codegeneration.printers.nest_cpp_function_call_printer import NESTCppFunctionCallPrinter
+from pynestml.codegeneration.printers.nestml_variable_printer import NestMLVariablePrinter
+from pynestml.frontend.pynestml_frontend import generate_target
+from pynestml.symbol_table.symbol_table import SymbolTable
+from pynestml.symbols.predefined_functions import PredefinedFunctions
+from pynestml.symbols.predefined_types import PredefinedTypes
+from pynestml.symbols.predefined_units import PredefinedUnits
+from pynestml.symbols.predefined_variables import PredefinedVariables
+from pynestml.utils.ast_source_location import ASTSourceLocation
+from pynestml.utils.logger import Logger, LoggingLevel
+from pynestml.utils.model_parser import ModelParser
+
+
+class TestUnitSystem:
+ r"""
+ Test class for units system.
+ """
+
+ @pytest.fixture(scope="class", autouse=True)
+ def setUp(self, request):
+ Logger.set_logging_level(LoggingLevel.INFO)
+
+ SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0))
+
+ PredefinedUnits.register_units()
+ PredefinedTypes.register_types()
+ PredefinedVariables.register_variables()
+ PredefinedFunctions.register_functions()
+
+ Logger.init_logger(LoggingLevel.INFO)
+
+ variable_printer = NestMLVariablePrinter(None)
+ function_call_printer = NESTCppFunctionCallPrinter(None)
+ cpp_variable_printer = CppVariablePrinter(None)
+ self.printer = CppExpressionPrinter(CppSimpleExpressionPrinter(cpp_variable_printer,
+ ConstantPrinter(),
+ function_call_printer))
+ cpp_variable_printer._expression_printer = self.printer
+ variable_printer._expression_printer = self.printer
+ function_call_printer._expression_printer = self.printer
+
+ request.cls.printer = self.printer
+
+ def get_first_statement_in_update_block(self, model):
+ if model.get_model_list()[0].get_update_blocks()[0]:
+ return model.get_model_list()[0].get_update_blocks()[0].get_block().get_stmts()[0]
+
+ return None
+
+ def get_first_declaration_in_state_block(self, model):
+ assert len(model.get_model_list()[0].get_state_blocks()) == 1
+
+ return model.get_model_list()[0].get_state_blocks()[0].get_declarations()[0]
+
+ def get_first_declared_function(self, model):
+ return model.get_model_list()[0].get_functions()[0]
+
+ def print_rhs_of_first_assignment_in_update_block(self, model):
+ assignment = self.get_first_statement_in_update_block(model).small_stmt.get_assignment()
+ expression = assignment.get_expression()
+
+ return self.printer.print(expression)
+
+ def print_first_function_call_in_update_block(self, model):
+ function_call = self.get_first_statement_in_update_block(model).small_stmt.get_function_call()
+
+ return self.printer.print(function_call)
+
+ def print_rhs_of_first_declaration_in_state_block(self, model):
+ declaration = self.get_first_declaration_in_state_block(model)
+ expression = declaration.get_expression()
+
+ return self.printer.print(expression)
+
+ def print_first_return_statement_in_first_declared_function(self, model):
+ func = self.get_first_declared_function(model)
+ return_expression = func.get_block().get_stmts()[0].small_stmt.get_return_stmt().get_expression()
+ return self.printer.print(return_expression)
+
+ def test_expression_after_magnitude_conversion_in_direct_assignment(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DirectAssignmentWithDifferentButCompatibleUnits.nestml'))
+ printed_rhs_expression = self.print_rhs_of_first_assignment_in_update_block(model)
+
+ assert printed_rhs_expression == '(1000.0 * (10 * V))'
+
+ def test_expression_after_nested_magnitude_conversion_in_direct_assignment(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DirectAssignmentWithDifferentButCompatibleNestedUnits.nestml'))
+ printed_rhs_expression = self.print_rhs_of_first_assignment_in_update_block(model)
+
+ assert printed_rhs_expression == '(1000.0 * (10 * V + (0.001 * (5 * mV)) + 20 * V + (1000.0 * (1 * kV))))'
+
+ def test_expression_after_magnitude_conversion_in_compound_assignment(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'CompoundAssignmentWithDifferentButCompatibleUnits.nestml'))
+ printed_rhs_expression = self.print_rhs_of_first_assignment_in_update_block(model)
+
+ assert printed_rhs_expression == '(0.001 * (1200 * mV))'
+
+ def test_expression_after_magnitude_conversion_in_declaration(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DeclarationWithDifferentButCompatibleUnitMagnitude.nestml'))
+ printed_rhs_expression = self.print_rhs_of_first_declaration_in_state_block(model)
+
+ assert printed_rhs_expression == '(1000.0 * (10 * V))'
+
+ def test_expression_after_type_conversion_in_declaration(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DeclarationWithDifferentButCompatibleUnits.nestml'))
+ declaration = self.get_first_declaration_in_state_block(model)
+ from astropy import units as u
+
+ assert declaration.get_expression().type.unit.unit == u.mV
+
+ def test_declaration_with_same_variable_name_as_unit(self):
+ Logger.init_logger(LoggingLevel.DEBUG)
+
+ generate_target(input_path=os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DeclarationWithSameVariableNameAsUnit.nestml'), target_platform="NONE", logging_level="DEBUG")
+
+ assert len(Logger.get_all_messages_of_level_and_or_node("BlockTest", LoggingLevel.ERROR)) == 0
+ assert len(Logger.get_all_messages_of_level_and_or_node("BlockTest", LoggingLevel.WARNING)) == 3
+
+ def test_expression_after_magnitude_conversion_in_standalone_function_call(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'FunctionCallWithDifferentButCompatibleUnits.nestml'))
+ printed_function_call = self.print_first_function_call_in_update_block(model)
+
+ assert printed_function_call == 'foo((1000.0 * (10 * V)))'
+
+ def test_expression_after_magnitude_conversion_in_rhs_function_call(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'RhsFunctionCallWithDifferentButCompatibleUnits.nestml'))
+ printed_function_call = self.print_rhs_of_first_assignment_in_update_block(model)
+
+ assert printed_function_call == 'foo((1000.0 * (10 * V)))'
+
+ def test_return_stmt_after_magnitude_conversion_in_function_body(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'FunctionBodyReturnStatementWithDifferentButCompatibleUnits.nestml'))
+ printed_return_stmt = self.print_first_return_statement_in_first_declared_function(model)
+
+ assert printed_return_stmt == '(0.001 * (bar))'
diff --git a/tests/unit_system_test.py b/tests/unit_system_test.py
deleted file mode 100644
index 1f7817b91..000000000
--- a/tests/unit_system_test.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# unit_system_test.py
-#
-# This file is part of NEST.
-#
-# Copyright (C) 2004 The NEST Initiative
-#
-# NEST is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 2 of the License, or
-# (at your option) any later version.
-#
-# NEST is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with NEST. If not, see .
-
-import os
-import unittest
-from pynestml.codegeneration.printers.constant_printer import ConstantPrinter
-
-from pynestml.codegeneration.printers.cpp_expression_printer import CppExpressionPrinter
-from pynestml.codegeneration.printers.cpp_simple_expression_printer import CppSimpleExpressionPrinter
-from pynestml.codegeneration.printers.cpp_type_symbol_printer import CppTypeSymbolPrinter
-from pynestml.codegeneration.printers.nestml_variable_printer import NestMLVariablePrinter
-from pynestml.symbol_table.symbol_table import SymbolTable
-from pynestml.symbols.predefined_functions import PredefinedFunctions
-from pynestml.symbols.predefined_types import PredefinedTypes
-from pynestml.symbols.predefined_units import PredefinedUnits
-from pynestml.symbols.predefined_variables import PredefinedVariables
-from pynestml.utils.ast_source_location import ASTSourceLocation
-from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter
-from pynestml.codegeneration.printers.nest_cpp_function_call_printer import NESTCppFunctionCallPrinter
-from pynestml.codegeneration.printers.cpp_function_call_printer import CppFunctionCallPrinter
-from pynestml.utils.logger import Logger, LoggingLevel
-from pynestml.utils.model_parser import ModelParser
-
-
-SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0))
-
-PredefinedUnits.register_units()
-PredefinedTypes.register_types()
-PredefinedVariables.register_variables()
-PredefinedFunctions.register_functions()
-
-Logger.init_logger(LoggingLevel.INFO)
-
-type_symbol_printer = CppTypeSymbolPrinter()
-variable_printer = NestMLVariablePrinter(None)
-function_call_printer = NESTCppFunctionCallPrinter(None)
-cpp_variable_printer = CppVariablePrinter(None)
-printer = CppExpressionPrinter(CppSimpleExpressionPrinter(cpp_variable_printer,
- ConstantPrinter(),
- function_call_printer))
-cpp_variable_printer._expression_printer = printer
-variable_printer._expression_printer = printer
-function_call_printer._expression_printer = printer
-
-
-def get_first_statement_in_update_block(model):
- if model.get_model_list()[0].get_update_blocks()[0]:
- return model.get_model_list()[0].get_update_blocks()[0].get_block().get_stmts()[0]
- return None
-
-
-def get_first_declaration_in_state_block(model):
- assert len(model.get_model_list()[0].get_state_blocks()) == 1
- return model.get_model_list()[0].get_state_blocks()[0].get_declarations()[0]
-
-
-def get_first_declared_function(model):
- return model.get_model_list()[0].get_functions()[0]
-
-
-def print_rhs_of_first_assignment_in_update_block(model):
- assignment = get_first_statement_in_update_block(model).small_stmt.get_assignment()
- expression = assignment.get_expression()
- return printer.print(expression)
-
-
-def print_first_function_call_in_update_block(model):
- function_call = get_first_statement_in_update_block(model).small_stmt.get_function_call()
- return printer.print(function_call)
-
-
-def print_rhs_of_first_declaration_in_state_block(model):
- declaration = get_first_declaration_in_state_block(model)
- expression = declaration.get_expression()
- return printer.print(expression)
-
-
-def print_first_return_statement_in_first_declared_function(model):
- func = get_first_declared_function(model)
- return_expression = func.get_block().get_stmts()[0].small_stmt.get_return_stmt().get_expression()
- return printer.print(return_expression)
-
-
-class UnitSystemTest(unittest.TestCase):
- """
- Test class for everything Unit related.
- """
-
- def setUp(self):
- Logger.set_logging_level(LoggingLevel.INFO)
-
- def test_expression_after_magnitude_conversion_in_direct_assignment(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'DirectAssignmentWithDifferentButCompatibleUnits.nestml'))
- printed_rhs_expression = print_rhs_of_first_assignment_in_update_block(model)
-
- self.assertEqual(printed_rhs_expression, '(1000.0 * (10 * V))')
-
- def test_expression_after_nested_magnitude_conversion_in_direct_assignment(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'DirectAssignmentWithDifferentButCompatibleNestedUnits.nestml'))
- printed_rhs_expression = print_rhs_of_first_assignment_in_update_block(model)
-
- self.assertEqual(printed_rhs_expression, '(1000.0 * (10 * V + (0.001 * (5 * mV)) + 20 * V + (1000.0 * (1 * kV))))')
-
- def test_expression_after_magnitude_conversion_in_compound_assignment(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'CompoundAssignmentWithDifferentButCompatibleUnits.nestml'))
- printed_rhs_expression = print_rhs_of_first_assignment_in_update_block(model)
- self.assertEqual(printed_rhs_expression, '(0.001 * (1200 * mV))')
-
- def test_expression_after_magnitude_conversion_in_declaration(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'DeclarationWithDifferentButCompatibleUnitMagnitude.nestml'))
- printed_rhs_expression = print_rhs_of_first_declaration_in_state_block(model)
- self.assertEqual(printed_rhs_expression, '(1000.0 * (10 * V))')
-
- def test_expression_after_type_conversion_in_declaration(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'DeclarationWithDifferentButCompatibleUnits.nestml'))
- declaration = get_first_declaration_in_state_block(model)
- from astropy import units as u
- self.assertTrue(declaration.get_expression().type.unit.unit == u.mV)
-
- def test_declaration_with_same_variable_name_as_unit(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'DeclarationWithSameVariableNameAsUnit.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), 3)
-
- def test_expression_after_magnitude_conversion_in_standalone_function_call(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'FunctionCallWithDifferentButCompatibleUnits.nestml'))
- printed_function_call = print_first_function_call_in_update_block(model)
- self.assertEqual(printed_function_call, 'foo((1000.0 * (10 * V)))')
-
- def test_expression_after_magnitude_conversion_in_rhs_function_call(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'RhsFunctionCallWithDifferentButCompatibleUnits.nestml'))
- printed_function_call = print_rhs_of_first_assignment_in_update_block(model)
- self.assertEqual(printed_function_call, 'foo((1000.0 * (10 * V)))')
-
- def test_return_stmt_after_magnitude_conversion_in_function_body(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'FunctionBodyReturnStatementWithDifferentButCompatibleUnits.nestml'))
- printed_return_stmt = print_first_return_statement_in_first_declared_function(model)
- self.assertEqual(printed_return_stmt, '(0.001 * (bar))')