diff --git a/pynestml/codegeneration/code_generator.py b/pynestml/codegeneration/code_generator.py index abd289601..a0ed82c60 100644 --- a/pynestml/codegeneration/code_generator.py +++ b/pynestml/codegeneration/code_generator.py @@ -117,7 +117,6 @@ def _setup_template_env(self, template_files: List[str], templates_root_dir: str # Environment for neuron templates env = Environment(loader=FileSystemLoader(_template_dirs)) env.globals["raise"] = self.raise_helper - env.globals["is_delta_kernel"] = ASTUtils.is_delta_kernel # Load all the templates _templates = list() diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py index 8fdb04566..333a39fe6 100644 --- a/pynestml/codegeneration/nest_code_generator.py +++ b/pynestml/codegeneration/nest_code_generator.py @@ -53,7 +53,6 @@ from pynestml.frontend.frontend_configuration import FrontendConfiguration from pynestml.meta_model.ast_assignment import ASTAssignment from pynestml.meta_model.ast_input_port import ASTInputPort -from pynestml.meta_model.ast_kernel import ASTKernel from pynestml.meta_model.ast_model import ASTModel from pynestml.meta_model.ast_node_factory import ASTNodeFactory from pynestml.meta_model.ast_ode_equation import ASTOdeEquation @@ -228,6 +227,10 @@ def run_nest_target_specific_cocos(self, neurons: Sequence[ASTModel], synapses: raise Exception("Error(s) occurred during code generation") def generate_code(self, models: Sequence[ASTModel]) -> None: + for model in models: + for equations_block in model.get_equations_blocks(): + assert len(equations_block.get_kernels()) == 0, "Kernels and convolutions should have been removed by ConvolutionsTransformer" + neurons, synapses = CodeGeneratorUtils.get_model_types_from_names(models, neuron_models=self.get_option("neuron_models"), synapse_models=self.get_option("synapse_models")) self.run_nest_target_specific_cocos(neurons, synapses) @@ -264,9 +267,7 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None: for neuron in neurons: code, message = Messages.get_analysing_transforming_model(neuron.get_name()) Logger.log_message(None, code, message, None, LoggingLevel.INFO) - spike_updates, post_spike_updates, equations_with_delay_vars, equations_with_vector_vars = self.analyse_neuron(neuron) - neuron.spike_updates = spike_updates - neuron.post_spike_updates = post_spike_updates + equations_with_delay_vars, equations_with_vector_vars = self.analyse_neuron(neuron) neuron.equations_with_delay_vars = equations_with_delay_vars neuron.equations_with_vector_vars = equations_with_vector_vars @@ -277,14 +278,13 @@ def analyse_transform_synapses(self, synapses: List[ASTModel]) -> None: """ for synapse in synapses: Logger.log_message(None, None, "Analysing/transforming synapse {}.".format(synapse.get_name()), None, LoggingLevel.INFO) - synapse.spike_updates = self.analyse_synapse(synapse) + self.analyse_synapse(synapse) def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment], List[ASTOdeEquation], List[ASTOdeEquation]]: """ Analyse and transform a single neuron. :param neuron: a single neuron. :return: see documentation for get_spike_update_expressions() for more information. - :return: post_spike_updates: list of post-synaptic spike update expressions :return: equations_with_delay_vars: list of equations containing delay variables :return: equations_with_vector_vars: list of equations containing delay variables """ @@ -298,18 +298,15 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di ASTUtils.all_variables_defined_in_block(neuron.get_state_blocks())) ASTUtils.add_timestep_symbol(neuron) - return {}, {}, [], [] + return [], [] if len(neuron.get_equations_blocks()) > 1: raise Exception("Only one equations block per model supported for now") equations_block = neuron.get_equations_blocks()[0] - kernel_buffers = ASTUtils.generate_kernel_buffers(neuron, equations_block) ASTUtils.make_inline_expressions_self_contained(equations_block.get_inline_expressions()) ASTUtils.replace_inline_expressions_through_defining_expressions(equations_block.get_ode_equations(), equations_block.get_inline_expressions()) - delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block) - ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block) # Collect all equations with delay variables and replace ASTFunctionCall to ASTVariable wherever necessary equations_with_delay_vars_visitor = ASTEquationsWithDelayVarsVisitor() @@ -321,7 +318,7 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di neuron.accept(eqns_with_vector_vars_visitor) equations_with_vector_vars = eqns_with_vector_vars_visitor.equations - analytic_solver, numeric_solver = self.ode_toolbox_analysis(neuron, kernel_buffers) + analytic_solver, numeric_solver = self.ode_toolbox_analysis(neuron) self.analytic_solver[neuron.get_name()] = analytic_solver self.numeric_solver[neuron.get_name()] = numeric_solver @@ -335,23 +332,14 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di if ode_eq.get_lhs().get_name() == var.get_name(): used_in_eq = True break - for kern in equations_block.get_kernels(): - for kern_var in kern.get_variables(): - if kern_var.get_name() == var.get_name(): - used_in_eq = True - break if not used_in_eq: self.non_equations_state_variables[neuron.get_name()].append(var) - ASTUtils.remove_initial_values_for_kernels(neuron) - kernels = ASTUtils.remove_kernel_definitions_from_equations_block(neuron) ASTUtils.update_initial_values_for_odes(neuron, [analytic_solver, numeric_solver]) ASTUtils.remove_ode_definitions_from_equations_block(neuron) - ASTUtils.create_initial_values_for_kernels(neuron, [analytic_solver, numeric_solver], kernels) ASTUtils.create_integrate_odes_combinations(neuron) ASTUtils.replace_variable_names_in_expressions(neuron, [analytic_solver, numeric_solver]) - ASTUtils.replace_convolution_aliasing_inlines(neuron) ASTUtils.add_timestep_symbol(neuron) if self.analytic_solver[neuron.get_name()] is not None: @@ -364,9 +352,7 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di # Update the delay parameter parameters after symbol table update ASTUtils.update_delay_parameter_in_state_vars(neuron, state_vars_before_update) - spike_updates, post_spike_updates = self.get_spike_update_expressions(neuron, kernel_buffers, [analytic_solver, numeric_solver], delta_factors) - - return spike_updates, post_spike_updates, equations_with_delay_vars, equations_with_vector_vars + return equations_with_delay_vars, equations_with_vector_vars def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]: """ @@ -376,34 +362,26 @@ def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]: code, message = Messages.get_start_processing_model(synapse.get_name()) Logger.log_message(synapse, code, message, synapse.get_source_position(), LoggingLevel.INFO) - spike_updates = {} if synapse.get_equations_blocks(): if len(synapse.get_equations_blocks()) > 1: raise Exception("Only one equations block per model supported for now") equations_block = synapse.get_equations_blocks()[0] - kernel_buffers = ASTUtils.generate_kernel_buffers(synapse, equations_block) ASTUtils.make_inline_expressions_self_contained(equations_block.get_inline_expressions()) ASTUtils.replace_inline_expressions_through_defining_expressions( equations_block.get_ode_equations(), equations_block.get_inline_expressions()) - delta_factors = ASTUtils.get_delta_factors_(synapse, equations_block) - ASTUtils.replace_convolve_calls_with_buffers_(synapse, equations_block) - analytic_solver, numeric_solver = self.ode_toolbox_analysis(synapse, kernel_buffers) + analytic_solver, numeric_solver = self.ode_toolbox_analysis(synapse) self.analytic_solver[synapse.get_name()] = analytic_solver self.numeric_solver[synapse.get_name()] = numeric_solver - ASTUtils.remove_initial_values_for_kernels(synapse) - kernels = ASTUtils.remove_kernel_definitions_from_equations_block(synapse) ASTUtils.update_initial_values_for_odes(synapse, [analytic_solver, numeric_solver]) ASTUtils.remove_ode_definitions_from_equations_block(synapse) - ASTUtils.create_initial_values_for_kernels(synapse, [analytic_solver, numeric_solver], kernels) ASTUtils.create_integrate_odes_combinations(synapse) ASTUtils.replace_variable_names_in_expressions(synapse, [analytic_solver, numeric_solver]) ASTUtils.add_timestep_symbol(synapse) self.update_symbol_table(synapse) - spike_updates, _ = self.get_spike_update_expressions(synapse, kernel_buffers, [analytic_solver, numeric_solver], delta_factors) if not self.analytic_solver[synapse.get_name()] is None: synapse = ASTUtils.add_declarations_to_internals( @@ -416,8 +394,6 @@ def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]: ASTUtils.update_blocktype_for_common_parameters(synapse) - return spike_updates - def _get_model_namespace(self, astnode: ASTModel) -> Dict: namespace = {} @@ -567,8 +543,6 @@ def _get_synapse_model_namespace(self, synapse: ASTModel) -> Dict: expr_ast.accept(ASTSymbolTableVisitor()) namespace["numeric_update_expressions"][sym] = expr_ast - namespace["spike_updates"] = synapse.spike_updates - return namespace def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: @@ -582,7 +556,6 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: if "paired_synapse" in dir(neuron): namespace["extra_on_emit_spike_stmts_from_synapse"] = neuron.extra_on_emit_spike_stmts_from_synapse namespace["paired_synapse"] = neuron.paired_synapse.get_name() - namespace["post_spike_updates"] = neuron.post_spike_updates namespace["transferred_variables"] = neuron._transferred_variables namespace["transferred_variables_syms"] = {var_name: neuron.scope.resolve_to_symbol( var_name, SymbolKind.VARIABLE) for var_name in namespace["transferred_variables"]} @@ -736,8 +709,6 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: namespace["numerical_state_symbols"] = numeric_state_variable_names ASTUtils.assign_numeric_non_numeric_state_variables(neuron, numeric_state_variable_names, namespace["numeric_update_expressions"] if "numeric_update_expressions" in namespace.keys() else None, namespace["update_expressions"] if "update_expressions" in namespace.keys() else None) - namespace["spike_updates"] = neuron.spike_updates - namespace["recordable_state_variables"] = [] for state_block in neuron.get_state_blocks(): for decl in state_block.get_declarations(): @@ -745,7 +716,6 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: sym = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.VARIABLE) if isinstance(sym.get_type_symbol(), (UnitTypeSymbol, RealTypeSymbol)) \ - and not ASTUtils.is_delta_kernel(neuron.get_kernel_by_name(sym.name)) \ and sym.is_recordable: namespace["recordable_state_variables"].append(var) @@ -755,7 +725,7 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: for var in decl.get_variables(): sym = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.VARIABLE) - if sym.has_declaring_expression() and (not neuron.get_kernel_by_name(sym.name)): + if sym.has_declaring_expression(): namespace["parameter_vars_with_iv"].append(var) namespace["recordable_inline_expressions"] = [sym for sym in neuron.get_inline_expression_symbols() @@ -774,7 +744,7 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: return namespace - def ode_toolbox_analysis(self, neuron: ASTModel, kernel_buffers: Mapping[ASTKernel, ASTInputPort]): + def ode_toolbox_analysis(self, neuron: ASTModel): """ Prepare data for ODE-toolbox input format, invoke ODE-toolbox analysis via its API, and return the output. """ @@ -783,11 +753,11 @@ def ode_toolbox_analysis(self, neuron: ASTModel, kernel_buffers: Mapping[ASTKern equations_block = neuron.get_equations_blocks()[0] - if len(equations_block.get_kernels()) == 0 and len(equations_block.get_ode_equations()) == 0: + if len(equations_block.get_ode_equations()) == 0: # no equations defined -> no changes to the neuron return None, None - odetoolbox_indict = ASTUtils.transform_ode_and_kernels_to_json(neuron, neuron.get_parameters_blocks(), kernel_buffers, printer=self._ode_toolbox_printer) + odetoolbox_indict = ASTUtils.transform_odes_to_json(neuron, neuron.get_parameters_blocks(), printer=self._ode_toolbox_printer) odetoolbox_indict["options"] = {} odetoolbox_indict["options"]["output_timestep_symbol"] = "__h" disable_analytic_solver = self.get_option("solver") != "analytic" @@ -831,107 +801,3 @@ def update_symbol_table(self, neuron) -> None: symbol_table_visitor.after_ast_rewrite_ = True neuron.accept(symbol_table_visitor) SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope()) - - def get_spike_update_expressions(self, neuron: ASTModel, kernel_buffers, solver_dicts, delta_factors) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment]]: - r""" - Generate the equations that update the dynamical variables when incoming spikes arrive. To be invoked after - ode-toolbox. - - For example, a resulting `assignment_str` could be "I_kernel_in += (inh_spikes/nS) * 1". The values are taken from the initial values for each corresponding dynamical variable, either from ode-toolbox or directly from user specification in the model. - from the initial values for each corresponding dynamical variable, either from ode-toolbox or directly from - user specification in the model. - - Note that for kernels, `initial_values` actually contains the increment upon spike arrival, rather than the - initial value of the corresponding ODE dimension. - ``spike_updates`` is a mapping from input port name (as a string) to update expressions. - - ``post_spike_updates`` is a mapping from kernel name (as a string) to update expressions. - """ - spike_updates = {} - post_spike_updates = {} - - for kernel, spike_input_port in kernel_buffers: - if ASTUtils.is_delta_kernel(kernel): - continue - - spike_input_port_name = spike_input_port.get_variable().get_name() - - if not spike_input_port_name in spike_updates.keys(): - spike_updates[str(spike_input_port)] = [] - - if "_is_post_port" in dir(spike_input_port.get_variable()) \ - and spike_input_port.get_variable()._is_post_port: - # it's a port in the neuron ??? that receives post spikes ??? - orig_port_name = spike_input_port_name[:spike_input_port_name.index("__for_")] - buffer_type = neuron.paired_synapse.get_scope().resolve_to_symbol(orig_port_name, SymbolKind.VARIABLE).get_type_symbol() - else: - buffer_type = neuron.get_scope().resolve_to_symbol(spike_input_port_name, SymbolKind.VARIABLE).get_type_symbol() - - assert not buffer_type is None - - for kernel_var in kernel.get_variables(): - for var_order in range(ASTUtils.get_kernel_var_order_from_ode_toolbox_result(kernel_var.get_name(), solver_dicts)): - kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name(kernel_var.get_name(), spike_input_port, var_order) - expr = ASTUtils.get_initial_value_from_ode_toolbox_result(kernel_spike_buf_name, solver_dicts) - assert expr is not None, "Initial value not found for kernel " + kernel_var - expr = str(expr) - if expr in ["0", "0.", "0.0"]: - continue # skip adding the statement if we are only adding zero - - assignment_str = kernel_spike_buf_name + " += " - if "_is_post_port" in dir(spike_input_port.get_variable()) \ - and spike_input_port.get_variable()._is_post_port: - assignment_str += "1." - else: - assignment_str += "(" + str(spike_input_port) + ")" - if not expr in ["1.", "1.0", "1"]: - assignment_str += " * (" + expr + ")" - - if not buffer_type.print_nestml_type() in ["1.", "1.0", "1", "real", "integer"]: - assignment_str += " / (" + buffer_type.print_nestml_type() + ")" - - ast_assignment = ModelParser.parse_assignment(assignment_str) - ast_assignment.update_scope(neuron.get_scope()) - ast_assignment.accept(ASTSymbolTableVisitor()) - - if neuron.get_scope().resolve_to_symbol(spike_input_port_name, SymbolKind.VARIABLE) is None: - # this case covers variables that were moved from synapse to the neuron - post_spike_updates[kernel_var.get_name()] = ast_assignment - elif "_is_post_port" in dir(spike_input_port.get_variable()) and spike_input_port.get_variable()._is_post_port: - Logger.log_message(None, None, "Adding post assignment string: " + str(ast_assignment), None, LoggingLevel.INFO) - spike_updates[str(spike_input_port)].append(ast_assignment) - else: - spike_updates[str(spike_input_port)].append(ast_assignment) - - for k, factor in delta_factors.items(): - var = k[0] - inport = k[1] - assignment_str = var.get_name() + "'" * (var.get_differential_order() - 1) + " += " - if not factor in ["1.", "1.0", "1"]: - factor_expr = ModelParser.parse_expression(factor) - factor_expr.update_scope(neuron.get_scope()) - factor_expr.accept(ASTSymbolTableVisitor()) - assignment_str += "(" + self._printer_no_origin.print(factor_expr) + ") * " - - if "_is_post_port" in dir(inport) and inport._is_post_port: - orig_port_name = inport[:inport.index("__for_")] - buffer_type = neuron.paired_synapse.get_scope().resolve_to_symbol(orig_port_name, SymbolKind.VARIABLE).get_type_symbol() - else: - buffer_type = neuron.get_scope().resolve_to_symbol(inport.get_name(), SymbolKind.VARIABLE).get_type_symbol() - - assignment_str += str(inport) - if not buffer_type.print_nestml_type() in ["1.", "1.0", "1"]: - assignment_str += " / (" + buffer_type.print_nestml_type() + ")" - ast_assignment = ModelParser.parse_assignment(assignment_str) - ast_assignment.update_scope(neuron.get_scope()) - ast_assignment.accept(ASTSymbolTableVisitor()) - - inport_name = inport.get_name() - if inport.has_vector_parameter(): - inport_name += "_" + str(ASTUtils.get_numeric_vector_size(inport)) - if not inport_name in spike_updates.keys(): - spike_updates[inport_name] = [] - - spike_updates[inport_name].append(ast_assignment) - - return spike_updates, post_spike_updates diff --git a/pynestml/codegeneration/nest_compartmental_code_generator.py b/pynestml/codegeneration/nest_compartmental_code_generator.py index 8dc48958d..cf3757481 100644 --- a/pynestml/codegeneration/nest_compartmental_code_generator.py +++ b/pynestml/codegeneration/nest_compartmental_code_generator.py @@ -280,22 +280,16 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None: def create_ode_indict(self, neuron: ASTModel, - parameters_block: ASTBlockWithVariables, - kernel_buffers: Mapping[ASTKernel, - ASTInputPort]): - odetoolbox_indict = self.transform_ode_and_kernels_to_json( - neuron, parameters_block, kernel_buffers) + parameters_block: ASTBlockWithVariables): + odetoolbox_indict = self.transform_ode_and_kernels_to_json(neuron, parameters_block) odetoolbox_indict["options"] = {} odetoolbox_indict["options"]["output_timestep_symbol"] = "__h" return odetoolbox_indict def ode_solve_analytically(self, neuron: ASTModel, - parameters_block: ASTBlockWithVariables, - kernel_buffers: Mapping[ASTKernel, - ASTInputPort]): - odetoolbox_indict = self.create_ode_indict( - neuron, parameters_block, kernel_buffers) + parameters_block: ASTBlockWithVariables): + odetoolbox_indict = self.create_ode_indict(neuron, parameters_block) full_solver_result = analysis( odetoolbox_indict, @@ -314,8 +308,7 @@ def ode_solve_analytically(self, return full_solver_result, analytic_solver - def ode_toolbox_analysis(self, neuron: ASTModel, - kernel_buffers: Mapping[ASTKernel, ASTInputPort]): + def ode_toolbox_analysis(self, neuron: ASTModel): """ Prepare data for ODE-toolbox input format, invoke ODE-toolbox analysis via its API, and return the output. """ @@ -324,15 +317,13 @@ def ode_toolbox_analysis(self, neuron: ASTModel, equations_block = neuron.get_equations_blocks()[0] - if len(equations_block.get_kernels()) == 0 and len( - equations_block.get_ode_equations()) == 0: + if len(equations_block.get_ode_equations()) == 0: # no equations defined -> no changes to the neuron return None, None parameters_block = neuron.get_parameters_blocks()[0] - solver_result, analytic_solver = self.ode_solve_analytically( - neuron, parameters_block, kernel_buffers) + solver_result, analytic_solver = self.ode_solve_analytically(neuron, parameters_block) # if numeric solver is required, generate a stepping function that # includes each state variable @@ -341,8 +332,7 @@ def ode_toolbox_analysis(self, neuron: ASTModel, x for x in solver_result if x["solver"].startswith("numeric")] if numeric_solvers: - odetoolbox_indict = self.create_ode_indict( - neuron, parameters_block, kernel_buffers) + odetoolbox_indict = self.create_ode_indict(neuron, parameters_block) solver_result = analysis( odetoolbox_indict, disable_stiffness_check=True, @@ -417,24 +407,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: return [] - # goes through all convolve() inside ode's from equations block - # if they have delta kernels, use sympy to expand the expression, then - # find the convolve calls and replace them with constant value 1 - # then return every subexpression that had that convolve() replaced - delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block) - - # goes through all convolve() inside equations block - # extracts what kernel is paired with what spike buffer - # returns pairs (kernel, spike_buffer) - kernel_buffers = ASTUtils.generate_kernel_buffers( - neuron, equations_block) - - # replace convolve(g_E, spikes_exc) with g_E__X__spikes_exc[__d] - # done by searching for every ASTSimpleExpression inside equations_block - # which is a convolve call and substituting that call with - # newly created ASTVariable kernel__X__spike_buffer - ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block) - # substitute inline expressions with each other # such that no inline expression references another inline expression ASTUtils.make_inline_expressions_self_contained( @@ -450,14 +422,13 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: # "update_expressions" key in those solvers contains a mapping # {expression1: update_expression1, expression2: update_expression2} - analytic_solver, numeric_solver = self.ode_toolbox_analysis( - neuron, kernel_buffers) + analytic_solver, numeric_solver = self.ode_toolbox_analysis(neuron) """ # separate analytic solutions by kernel # this is is needed for the synaptic case self.kernel_name_to_analytic_solver[neuron.get_name( - )] = self.ode_toolbox_anaysis_cm_syns(neuron, kernel_buffers) + )] = self.ode_toolbox_anaysis_cm_syns(neuron) """ self.analytic_solver[neuron.get_name()] = analytic_solver @@ -472,12 +443,6 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: # by odetoolbox, higher order variables don't get deleted here ASTUtils.remove_initial_values_for_kernels(neuron) - # delete all kernels as they are all converted into buffers - # and corresponding update formulas calculated by odetoolbox - # Remember them in a variable though - kernels = ASTUtils.remove_kernel_definitions_from_equations_block( - neuron) - # Every ODE variable (a variable of order > 0) is renamed according to ODE-toolbox conventions # their initial values are replaced by expressions suggested by ODE-toolbox. # Differential order can now be set to 0, becase they can directly represent the value of the derivative now. @@ -491,22 +456,11 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]: # corresponding updates ASTUtils.remove_ode_definitions_from_equations_block(neuron) - # restore state variables that were referenced by kernels - # and set their initial values by those suggested by ODE-toolbox - ASTUtils.create_initial_values_for_kernels( - neuron, [analytic_solver, numeric_solver], kernels) - # Inside all remaining expressions, translate all remaining variable names # according to the naming conventions of ODE-toolbox. ASTUtils.replace_variable_names_in_expressions( neuron, [analytic_solver, numeric_solver]) - # find all inline kernels defined as ASTSimpleExpression - # that have a single kernel convolution aliasing variable ('__X__') - # translate all remaining variable names according to the naming - # conventions of ODE-toolbox - ASTUtils.replace_convolution_aliasing_inlines(neuron) - # add variable __h to internals block ASTUtils.add_timestep_symbol(neuron) @@ -677,13 +631,9 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: expr_ast.accept(ASTSymbolTableVisitor()) namespace["numeric_update_expressions"][sym] = expr_ast - namespace["spike_updates"] = neuron.spike_updates - namespace["recordable_state_variables"] = [ sym for sym in neuron.get_state_symbols() if namespace["declarations"].get_domain_from_type( - sym.get_type_symbol()) == "double" and sym.is_recordable and not ASTUtils.is_delta_kernel( - neuron.get_kernel_by_name( - sym.name))] + sym.get_type_symbol()) == "double" and sym.is_recordable] namespace["recordable_inline_expressions"] = [ sym for sym in neuron.get_inline_expression_symbols() if namespace["declarations"].get_domain_from_type( sym.get_type_symbol()) == "double" and sym.is_recordable] @@ -807,7 +757,7 @@ def get_spike_update_expressions( for var_order in range( ASTUtils.get_kernel_var_order_from_ode_toolbox_result( kernel_var.get_name(), solver_dicts)): - kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name( + kernel_spike_buf_name = ASTUtils.construct_kernel_spike_buf_name( kernel_var.get_name(), spike_input_port, var_order) expr = ASTUtils.get_initial_value_from_ode_toolbox_result( kernel_spike_buf_name, solver_dicts) @@ -849,18 +799,9 @@ def get_spike_update_expressions( def transform_ode_and_kernels_to_json( self, neuron: ASTModel, - parameters_block, - kernel_buffers): + parameters_block): """ Converts AST node to a JSON representation suitable for passing to ode-toolbox. - - Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements - - convolve(G, ex_spikes) - convolve(G, in_spikes) - - then `kernel_buffers` will contain the pairs `(G, ex_spikes)` and `(G, in_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__ex_spikes` and `G__X__in_spikes`. - :param parameters_block: ASTBlockWithVariables :return: Dict """ @@ -890,43 +831,6 @@ def transform_ode_and_kernels_to_json( iv_symbol_name)] = expr odetoolbox_indict["dynamics"].append(entry) - # write a copy for each (kernel, spike buffer) combination - for kernel, spike_input_port in kernel_buffers: - - if ASTUtils.is_delta_kernel(kernel): - # delta function -- skip passing this to ode-toolbox - continue - - for kernel_var in kernel.get_variables(): - expr = ASTUtils.get_expr_from_kernel_var( - kernel, kernel_var.get_complete_name()) - kernel_order = kernel_var.get_differential_order() - kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'") - - ASTUtils.replace_rhs_variables(expr, kernel_buffers) - - entry = {} - entry["expression"] = kernel_X_spike_buf_name_ticks + " = " + str(expr) - - # initial values need to be declared for order 1 up to kernel - # order (e.g. none for kernel function f(t) = ...; 1 for kernel - # ODE f'(t) = ...; 2 for f''(t) = ... and so on) - entry["initial_values"] = {} - for order in range(kernel_order): - iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'") - symbol_name_ = kernel_var.get_name() + "'" * order - symbol = equations_block.get_scope().resolve_to_symbol( - symbol_name_, SymbolKind.VARIABLE) - assert symbol is not None, "Could not find initial value for variable " + symbol_name_ - initial_value_expr = symbol.get_declaring_expression() - assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_ - entry["initial_values"][iv_sym_name_ode_toolbox] = self._ode_toolbox_printer.print( - initial_value_expr) - - odetoolbox_indict["dynamics"].append(entry) - odetoolbox_indict["parameters"] = {} if parameters_block is not None: for decl in parameters_block.get_declarations(): diff --git a/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2 index 5832d90f3..fedb8a96e 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronClass.jinja2 @@ -262,10 +262,8 @@ std::vector< std::tuple< int, int > > {{neuronName}}::rport_to_nestml_buffer_idx // copy state struct S_ {%- for init in neuron.get_state_symbols() %} -{%- if not is_delta_kernel(neuron.get_kernel_by_name(init.name)) %} {%- set node = utils.get_state_variable_by_name(astnode, init.get_symbol_name()) %} {{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }} = __n.{{ nest_codegen_utils.print_symbol_origin(init, node) % printer_no_origin.print(node) }}; -{%- endif %} {%- endfor %} // copy internals V_ @@ -786,14 +784,6 @@ const {{ type_symbol_printer.print(var_symbol.type_symbol) }} {{variable_name}}_ {%- endfor %} {%- endif %} - - /** - * spike updates due to convolutions - **/ -{% filter indent(4) %} -{%- include "directives_cpp/ApplySpikesFromBuffers.jinja2" %} -{%- endfilter %} - /** * Begin NESTML generated code for the onCondition block(s) **/ @@ -1149,13 +1139,9 @@ void {%- endfor %} /** - * print updates due to convolutions + * push back spike history **/ -{%- for _, spike_update in post_spike_updates.items() %} - {{ printer.print(utils.get_variable_by_name(astnode, spike_update.get_variable().get_complete_name())) }} += 1.; -{%- endfor %} - last_spike_ = t_sp_ms; history_.push_back( histentry__{{neuronName}}( last_spike_ {%- for var in purely_numeric_state_variables_moved|sort %} diff --git a/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2 index b8aa79c41..345372e0a 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/common/NeuronHeader.jinja2 @@ -346,13 +346,11 @@ public: // Getters/setters for state block // ------------------------------------------------------------------------- -{% filter indent(2, True) -%} +{% filter indent(2, True) -%} {%- for variable_symbol in neuron.get_state_symbols() %} -{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %} -{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %} -{% endif %} -{% endfor %} +{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} +{%- include "directives_cpp/MemberVariableGetterSetter.jinja2" %} +{% endfor %} {%- endfilter %} {%- endif %} @@ -962,22 +960,20 @@ inline nest_port_t {{neuronName}}::handles_test_event(nest::DataLoggingRequest& inline void {{neuronName}}::get_status(DictionaryDatum &__d) const { // parameters -{%- for variable_symbol in neuron.get_parameter_symbols() %} -{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- filter indent(2) %} +{%- filter indent(2) %} +{%- for variable_symbol in neuron.get_parameter_symbols() %} +{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} {%- include "directives_cpp/WriteInDictionary.jinja2" %} -{%- endfilter %} -{%- endfor %} +{%- endfor %} +{%- endfilter %} // initial values for state variables in ODE or kernel -{%- for variable_symbol in neuron.get_state_symbols() %} -{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %} -{%- filter indent(2) %} -{%- include "directives_cpp/WriteInDictionary.jinja2" %} -{%- endfilter %} -{%- endif -%} -{%- endfor %} +{%- filter indent(2) %} +{%- for variable_symbol in neuron.get_state_symbols() %} +{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} +{%- include "directives_cpp/WriteInDictionary.jinja2" %} +{%- endfor %} +{%- endfilter %} {{neuron_parent_class}}::get_status( __d ); @@ -1023,11 +1019,9 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d) // initial values for state variables in ODE or kernel {%- for variable_symbol in neuron.get_state_symbols() %} {%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %} -{%- filter indent(2) %} -{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %} -{%- endfilter %} -{%- endif %} +{%- filter indent(2) %} +{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %} +{%- endfilter %} {%- endfor %} // We now know that (ptmp, stmp) are consistent. We do not @@ -1046,11 +1040,9 @@ inline void {{neuronName}}::set_status(const DictionaryDatum &__d) {%- for variable_symbol in neuron.get_state_symbols() -%} {%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.name)) %} -{%- filter indent(2) %} -{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %} -{%- endfilter %} -{%- endif %} +{%- filter indent(2) %} +{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %} +{%- endfilter %} {%- endfor %} {% for invariant in neuron.get_parameter_invariants() %} diff --git a/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 index 3d727ea23..38b6d6ce4 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 @@ -814,17 +814,6 @@ public: {%- endfilter %} } - /** - * update all convolutions with pre spikes - **/ - -{%- for spike_updates_for_port in spike_updates.values() %} -{%- for spike_update in spike_updates_for_port %} - {{ printer.print(spike_update.get_variable()) }} += 1.; // XXX: TODO: increment with initial value instead of 1 -{%- endfor %} -{%- endfor %} - - /** * in case pre and post spike time coincide and pre update takes priority **/ @@ -989,14 +978,12 @@ void {%- filter indent(2,True) %} {%- for variable_symbol in synapse.get_state_symbols() %} {%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- if not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.get_symbol_name())) %} -{%- include "directives_cpp/WriteInDictionary.jinja2" %} -{%- if variable_symbol.get_namespace_decorator("nest")|length > 0 %} +{%- include "directives_cpp/WriteInDictionary.jinja2" %} +{%- if variable_symbol.get_namespace_decorator("nest")|length > 0 %} // special treatment for variable marked with @nest::name decorator -{%- set nest_namespace_name = variable_symbol.get_namespace_decorator("nest") %} -{%- if not variable_symbol.is_internals() %} +{%- set nest_namespace_name = variable_symbol.get_namespace_decorator("nest") %} +{%- if not variable_symbol.is_internals() %} def<{{declarations.print_variable_type(variable_symbol)}}>(__d, names::{{nest_namespace_name}}, get_{{printer_no_origin.print(variable)}}()); -{%- endif %} {%- endif %} {%- endif %} {%- endfor %} @@ -1024,24 +1011,22 @@ void {%- filter indent(2,True) %} {%- for variable_symbol in synapse.get_state_symbols() %} {%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- if not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.name)) %} -{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %} -{%- if variable_symbol.get_namespace_decorator("nest")|length > 0 %} +{%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %} +{%- if variable_symbol.get_namespace_decorator("nest")|length > 0 %} // special treatment for variables marked with @nest::name decorator -{%- set nest_namespace_name = variable_symbol.get_namespace_decorator("nest") %} +{%- set nest_namespace_name = variable_symbol.get_namespace_decorator("nest") %} {#- -------- XXX: TODO: this is almost the content of directives_cpp/ReadFromDictionaryToTmp.jinja2 verbatim, refactor this ---------- #} -{%- if not variable_symbol.is_inline_expression and not variable_symbol.is_state() %} +{%- if not variable_symbol.is_inline_expression and not variable_symbol.is_state() %} tmp_{{ printer_no_origin.print(variable) }} = get_{{ printer_no_origin.print(variable) }}(); updateValue<{{ declarations.print_variable_type(variable_symbol) }}>(__d, "{{ nest_namespace_name }}", tmp_{{ printer_no_origin.print(variable) }}); -{%- elif not variable_symbol.is_inline_expression and variable_symbol.is_state() %} +{%- elif not variable_symbol.is_inline_expression and variable_symbol.is_state() %} tmp_{{ printer_no_origin.print(variable) }} = get_{{ printer_no_origin.print(variable) }}(); updateValue<{{ declarations.print_variable_type(variable_symbol) }}>(__d, "{{ nest_namespace_name }}", tmp_{{ printer_no_origin.print(variable) }}); -{%- else %} +{%- else %} // ignores '{{ printer_no_origin.print(variable) }}' {{ declarations.print_variable_type(variable_symbol) }}' since it is an function and setter isn't defined -{%- endif %} -{#- -------------------------------------------------------------------------------------------------------------------------------- #} {%- endif %} -{%- endif %} +{#- -------------------------------------------------------------------------------------------------------------------------------- #} +{%- endif %} {%- endfor %} {%- endfilter %} @@ -1069,11 +1054,9 @@ updateValue<{{ declarations.print_variable_type(variable_symbol) }}>(__d, "{{ ne // set state {%- for variable_symbol in synapse.get_state_symbols() %} {%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- if not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.name)) %} -{%- filter indent(2,True) %} -{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %} -{%- endfilter %} -{%- endif %} +{%- filter indent(2,True) %} +{%- include "directives_cpp/AssignTmpDictionaryValue.jinja2" %} +{%- endfilter %} {%- endfor %} // check invariants diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/ApplySpikesFromBuffers.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/ApplySpikesFromBuffers.jinja2 deleted file mode 100644 index 881257451..000000000 --- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/ApplySpikesFromBuffers.jinja2 +++ /dev/null @@ -1,6 +0,0 @@ -{% if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} -{%- for spike_updates_for_port in spike_updates.values() %} -{%- for ast in spike_updates_for_port -%} -{%- include "directives_cpp/Assignment.jinja2" %} -{%- endfor %} -{%- endfor %} diff --git a/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 b/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 index 5fc3ae589..7022e06ff 100644 --- a/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 +++ b/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 @@ -220,10 +220,8 @@ class Neuron_{{neuronName}}(Neuron): # ------------------------------------------------------------------------- {% filter indent(2, True) -%} {%- for variable_symbol in neuron.get_state_symbols() %} -{%- if not is_delta_kernel(neuron.get_kernel_by_name(variable_symbol.get_symbol_name())) %} {%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} {%- include "directives_py/MemberVariableGetterSetter.jinja2" %} -{%- endif %} {%- endfor %} {%- endfilter %} @@ -314,13 +312,6 @@ class Neuron_{{neuronName}}(Neuron): {%- include "directives_py/AnalyticIntegrationStep_end.jinja2" %} {%- endwith %} - # ------------------------------------------------------------------------- - # process spikes from buffers - # ------------------------------------------------------------------------- -{%- filter indent(4, True) -%} -{%- include "directives_py/ApplySpikesFromBuffers.jinja2" %} -{%- endfilter %} - # ------------------------------------------------------------------------- # begin NESTML generated code for the onReceive block(s) # ------------------------------------------------------------------------- diff --git a/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/ApplySpikesFromBuffers.jinja2 b/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/ApplySpikesFromBuffers.jinja2 deleted file mode 100644 index c0952b2f5..000000000 --- a/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/ApplySpikesFromBuffers.jinja2 +++ /dev/null @@ -1,6 +0,0 @@ -{%- if tracing %}# generated by {{self._TemplateReference__context.name}}{% endif %} -{%- for spike_updates_for_port in spike_updates.values() %} -{%- for ast in spike_updates_for_port -%} -{%- include "directives_py/Assignment.jinja2" %} -{%- endfor %} -{%- endfor %} diff --git a/pynestml/frontend/pynestml_frontend.py b/pynestml/frontend/pynestml_frontend.py index 53a8cbdd6..5aef28052 100644 --- a/pynestml/frontend/pynestml_frontend.py +++ b/pynestml/frontend/pynestml_frontend.py @@ -37,6 +37,7 @@ from pynestml.symbols.predefined_types import PredefinedTypes from pynestml.symbols.predefined_units import PredefinedUnits from pynestml.symbols.predefined_variables import PredefinedVariables +from pynestml.transformers.convolutions_transformer import ConvolutionsTransformer from pynestml.transformers.transformer import Transformer from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.messages import Messages @@ -59,6 +60,9 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st if options is None: options = {} + # for all targets, add the convolutions transformer + transformers.append(ConvolutionsTransformer()) + if target_name.upper() in ["NEST", "SPINNAKER"]: from pynestml.transformers.illegal_variable_name_transformer import IllegalVariableNameTransformer diff --git a/pynestml/meta_model/ast_model.py b/pynestml/meta_model/ast_model.py index 834e56897..b38629275 100644 --- a/pynestml/meta_model/ast_model.py +++ b/pynestml/meta_model/ast_model.py @@ -670,8 +670,10 @@ def get_on_receive_block(self, port_name: str) -> Optional[ASTOnReceiveBlock]: return self.get_body().get_on_receive_block(port_name) def get_on_condition_blocks(self) -> List[ASTOnConditionBlock]: + r"""See ASTModelBody.get_on_condition_blocks() for the documentation for this function.""" if not self.get_body(): return [] + return self.get_body().get_on_condition_blocks() def get_on_condition_block(self, port_name: str) -> Optional[ASTOnConditionBlock]: diff --git a/pynestml/meta_model/ast_model_body.py b/pynestml/meta_model/ast_model_body.py index 6e32561ce..1908bb88d 100644 --- a/pynestml/meta_model/ast_model_body.py +++ b/pynestml/meta_model/ast_model_body.py @@ -182,10 +182,14 @@ def get_on_condition_block(self, port_name) -> Optional[ASTOnConditionBlock]: return None def get_on_condition_blocks(self) -> List[ASTOnConditionBlock]: + r""" + XXX: TODO: sorting based on priority + """ on_condition_blocks = [] for elem in self.get_body_elements(): if isinstance(elem, ASTOnConditionBlock): on_condition_blocks.append(elem) + return on_condition_blocks def get_equations_blocks(self) -> List[ASTEquationsBlock]: diff --git a/pynestml/transformers/convolutions_transformer.py b/pynestml/transformers/convolutions_transformer.py new file mode 100644 index 000000000..cb38f8f04 --- /dev/null +++ b/pynestml/transformers/convolutions_transformer.py @@ -0,0 +1,501 @@ +# -*- coding: utf-8 -*- +# +# convolutions_transformer.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from __future__ import annotations + +from typing import Any, Dict, List, Sequence, Mapping, Optional, Tuple, Union + +import re + +import odetoolbox + +from pynestml.codegeneration.printers.ast_printer import ASTPrinter +from pynestml.codegeneration.printers.constant_printer import ConstantPrinter +from pynestml.codegeneration.printers.ode_toolbox_expression_printer import ODEToolboxExpressionPrinter +from pynestml.codegeneration.printers.ode_toolbox_function_call_printer import ODEToolboxFunctionCallPrinter +from pynestml.codegeneration.printers.ode_toolbox_variable_printer import ODEToolboxVariablePrinter +from pynestml.codegeneration.printers.unitless_sympy_simple_expression_printer import UnitlessSympySimpleExpressionPrinter +from pynestml.frontend.frontend_configuration import FrontendConfiguration +from pynestml.meta_model.ast_assignment import ASTAssignment +from pynestml.meta_model.ast_declaration import ASTDeclaration +from pynestml.meta_model.ast_equations_block import ASTEquationsBlock +from pynestml.meta_model.ast_expression import ASTExpression +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression +from pynestml.meta_model.ast_input_port import ASTInputPort +from pynestml.meta_model.ast_kernel import ASTKernel +from pynestml.meta_model.ast_model import ASTModel +from pynestml.meta_model.ast_node import ASTNode +from pynestml.meta_model.ast_node_factory import ASTNodeFactory +from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression +from pynestml.meta_model.ast_variable import ASTVariable +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.symbols.symbol import SymbolKind +from pynestml.symbols.variable_symbol import BlockType +from pynestml.transformers.transformer import Transformer +from pynestml.utils.ast_source_location import ASTSourceLocation +from pynestml.utils.ast_utils import ASTUtils +from pynestml.utils.logger import Logger +from pynestml.utils.logger import LoggingLevel +from pynestml.utils.model_parser import ModelParser +from pynestml.utils.string_utils import removesuffix +from pynestml.visitors.ast_parent_visitor import ASTParentVisitor +from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor +from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor +from pynestml.visitors.ast_visitor import ASTVisitor + + +class ConvolutionsTransformer(Transformer): + r"""For each convolution that occurs in the model, allocate one or more needed state variables and replace the convolution() calls by these variable names.""" + + _default_options = { + "convolution_separator": "__conv__", + "diff_order_symbol": "__d", + "simplify_expression": "sympy.logcombine(sympy.powsimp(sympy.expand(expr)))" + } + + def __init__(self, options: Optional[Mapping[str, Any]] = None): + super(Transformer, self).__init__(options) + + # ODE-toolbox printers + self._constant_printer = ConstantPrinter() + self._ode_toolbox_variable_printer = ODEToolboxVariablePrinter(None) + self._ode_toolbox_function_call_printer = ODEToolboxFunctionCallPrinter(None) + self._ode_toolbox_printer = ODEToolboxExpressionPrinter(simple_expression_printer=UnitlessSympySimpleExpressionPrinter(variable_printer=self._ode_toolbox_variable_printer, + constant_printer=self._constant_printer, + function_call_printer=self._ode_toolbox_function_call_printer)) + self._ode_toolbox_variable_printer._expression_printer = self._ode_toolbox_printer + self._ode_toolbox_function_call_printer._expression_printer = self._ode_toolbox_printer + + def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]: + r"""Transform a model or a list of models. Return an updated model or list of models.""" + for model in models: + print("-------- MODEL BEFORE TRANSFORM ------------") + print(model) + kernel_buffers = self.generate_kernel_buffers(model) + odetoolbox_indict = self.transform_kernels_to_json(model, kernel_buffers) + print("odetoolbox indict: " + str(odetoolbox_indict)) + solvers_json, shape_sys, shapes = odetoolbox._analysis(odetoolbox_indict, + disable_stiffness_check=True, + disable_analytic_solver=True, + preserve_expressions=True, + simplify_expression=self.get_option("simplify_expression"), + log_level=FrontendConfiguration.logging_level) + print("odetoolbox outdict: " + str(solvers_json)) + + self.remove_initial_values_for_kernels(model) + self.create_initial_values_for_kernels(model, solvers_json, kernel_buffers) + self.create_spike_update_event_handlers(model, solvers_json, kernel_buffers) + self.replace_convolve_calls_with_buffers_(model) + self.remove_kernel_definitions_from_equations_blocks(model) + print("-------- MODEL AFTER TRANSFORM ------------") + print(model) + print("-------------------------------------------") + + return models + + def construct_kernel_spike_buf_name(self, kernel_var_name: str, spike_input_port: ASTInputPort, order: int, diff_order_symbol: Optional[str] = None): + """ + Construct a kernel-buffer name as ``KERNEL_NAME__conv__INPUT_PORT_NAME`` + + For example, if the kernel is + .. code-block:: + kernel I_kernel = exp(-t / tau_x) + + and the input port is + .. code-block:: + pre_spikes nS <- spike + + then the constructed variable will be ``I_kernel__conv__pre_pikes`` + """ + assert type(kernel_var_name) is str + assert type(order) is int + + if isinstance(spike_input_port, ASTSimpleExpression): + spike_input_port = spike_input_port.get_variable() + + if not isinstance(spike_input_port, str): + spike_input_port_name = spike_input_port.get_name() + else: + spike_input_port_name = spike_input_port + + if isinstance(spike_input_port, ASTVariable): + if spike_input_port.has_vector_parameter(): + spike_input_port_name += "_" + str(self.get_numeric_vector_size(spike_input_port)) + + if not diff_order_symbol: + diff_order_symbol = self.get_option("diff_order_symbol") + + return kernel_var_name.replace("$", "__DOLLAR") + self.get_option("convolution_separator") + spike_input_port_name + diff_order_symbol * order + + def replace_rhs_variable(self, expr: ASTExpression, variable_name_to_replace: str, kernel_var: ASTVariable, + spike_buf: ASTInputPort): + """ + Replace variable names in definitions of kernel dynamics + :param expr: expression in which to replace the variables + :param variable_name_to_replace: variable name to replace in the expression + :param kernel_var: kernel variable instance + :param spike_buf: input port instance + :return: + """ + def replace_kernel_var(node): + if type(node) is ASTSimpleExpression \ + and node.is_variable() \ + and node.get_variable().get_name() == variable_name_to_replace: + var_order = node.get_variable().get_differential_order() + new_variable_name = cls.construct_kernel_X_spike_buf_name( + kernel_var.get_name(), spike_buf, var_order - 1, diff_order_symbol="'") + new_variable = ASTVariable(new_variable_name, var_order) + new_variable.set_source_position(node.get_variable().get_source_position()) + node.set_variable(new_variable) + + expr.accept(ASTHigherOrderVisitor(visit_funcs=replace_kernel_var)) + + def replace_rhs_variables(self, expr: ASTExpression, kernel_buffers: Mapping[ASTKernel, ASTInputPort]): + """ + Replace variable names in definitions of kernel dynamics. + + Say that the kernel is + + .. code-block:: + + G = -G / tau + + Its variable symbol might be replaced by "G__conv__spikesEx": + + .. code-block:: + + G__conv__spikesEx = -G / tau + + This function updates the right-hand side of `expr` so that it would also read (in this example): + + .. code-block:: + + G__conv__spikesEx = -G__conv__spikesEx / tau + + These equations will later on be fed to ode-toolbox, so we use the symbol "'" to indicate differential order. + + Note that for kernels/systems of ODE of dimension > 1, all variable orders and all variables for this kernel will already be present in `kernel_buffers`. + """ + for kernel, spike_buf in kernel_buffers: + for kernel_var in kernel.get_variables(): + variable_name_to_replace = kernel_var.get_name() + self.replace_rhs_variable(expr, variable_name_to_replace=variable_name_to_replace, + kernel_var=kernel_var, spike_buf=spike_buf) + + @classmethod + def remove_initial_values_for_kernels(cls, model: ASTModel) -> None: + r""" + Remove initial values for original declarations (e.g. g_in, g_in', V_m); these will be replaced with the initial value expressions returned from ODE-toolbox. + """ + symbols_to_remove = set() + for equations_block in model.get_equations_blocks(): + for kernel in equations_block.get_kernels(): + for kernel_var in kernel.get_variables(): + kernel_var_order = kernel_var.get_differential_order() + for order in range(kernel_var_order): + symbol_name = kernel_var.get_name() + "'" * order + symbols_to_remove.add(symbol_name) + + decl_to_remove = set() + for symbol_name in symbols_to_remove: + for state_block in model.get_state_blocks(): + for decl in state_block.get_declarations(): + if len(decl.get_variables()) == 1: + if decl.get_variables()[0].get_name() == symbol_name: + decl_to_remove.add(decl) + else: + for var in decl.get_variables(): + if var.get_name() == symbol_name: + decl.variables.remove(var) + + for decl in decl_to_remove: + for state_block in model.get_state_blocks(): + if decl in state_block.get_declarations(): + state_block.get_declarations().remove(decl) + + def create_initial_values_for_kernels(self, model: ASTModel, solver_dicts: List[Dict], kernels: List[ASTKernel]) -> None: + r""" + Add the variables used in kernels from the ode-toolbox result dictionary as ODEs in NESTML AST + """ + for solver_dict in solver_dicts: + if solver_dict is None: + continue + + for var_name, expr in solver_dict["initial_values"].items(): + spike_in_port_name = var_name.split(self.get_option("convolution_separator"))[1] + spike_in_port_name = spike_in_port_name.split("__d")[0] + spike_in_port = ASTUtils.get_input_port_by_name(model.get_input_blocks(), spike_in_port_name) + type_str = "real" + if spike_in_port: + differential_order: int = len(re.findall("__d", var_name)) + if differential_order: + type_str = "(s**-" + str(differential_order) + ")" + + expr = "0 " + type_str # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is 0 (property of the convolution) + if not ASTUtils.declaration_in_state_block(model, var_name): + ASTUtils.add_declaration_to_state_block(model, var_name, expr, type_str) + + def is_delta_kernel(self, kernel: ASTKernel) -> bool: + """ + Catches definition of kernel, or reference (function call or variable name) of a delta kernel function. + """ + if not isinstance(kernel, ASTKernel): + return False + + if len(kernel.get_variables()) != 1: + # delta kernel not allowed if more than one variable is defined in this kernel + return False + + expr = kernel.get_expressions()[0] + + rhs_is_delta_kernel = type(expr) is ASTSimpleExpression \ + and expr.is_function_call() \ + and expr.get_function_call().get_scope().resolve_to_symbol(expr.get_function_call().get_name(), SymbolKind.FUNCTION).equals(PredefinedFunctions.name2function["delta"]) + + rhs_is_multiplied_delta_kernel = type(expr) is ASTExpression \ + and type(expr.get_rhs()) is ASTSimpleExpression \ + and expr.get_rhs().is_function_call() \ + and expr.get_rhs().get_function_call().get_scope().resolve_to_symbol(expr.get_rhs().get_function_call().get_name(), SymbolKind.FUNCTION).equals(PredefinedFunctions.name2function["delta"]) + + return rhs_is_delta_kernel or rhs_is_multiplied_delta_kernel + + def replace_convolve_calls_with_buffers_(self, model: ASTModel) -> None: + r""" + Replace all occurrences of `convolve(kernel[']^n, spike_input_port)` with the corresponding buffer variable, e.g. `g_E__X__spikes_exc[__d]^n` for a kernel named `g_E` and a spike input port named `spikes_exc`. + """ + + def replace_function_call_through_var(_expr=None): + if _expr.is_function_call() and _expr.get_function_call().get_name() == "convolve": + convolve = _expr.get_function_call() + el = (convolve.get_args()[0], convolve.get_args()[1]) + sym = convolve.get_args()[0].get_scope().resolve_to_symbol( + convolve.get_args()[0].get_variable().name, SymbolKind.VARIABLE) + if sym.block_type == BlockType.INPUT: + # swap elements + el = (el[1], el[0]) + var = el[0].get_variable() + spike_input_port = el[1].get_variable() + kernel = model.get_kernel_by_name(var.get_name()) + + _expr.set_function_call(None) + buffer_var = self.construct_kernel_spike_buf_name( + var.get_name(), spike_input_port, var.get_differential_order() - 1) + if self.is_delta_kernel(kernel): + # delta kernels are treated separately, and should be kept out of the dynamics (computing derivates etc.) --> set to zero + _expr.set_variable(None) + _expr.set_numeric_literal(0) + else: + ast_variable = ASTVariable(buffer_var) + ast_variable.set_source_position(_expr.get_source_position()) + _expr.set_variable(ast_variable) + + def func(x): + return replace_function_call_through_var(x) if isinstance(x, ASTSimpleExpression) else True + + for equations_block in model.get_equations_blocks(): + equations_block.accept(ASTHigherOrderVisitor(func)) + + @classmethod + def replace_convolution_aliasing_inlines(cls, neuron: ASTModel) -> None: + """ + Replace all occurrences of kernel names (e.g. ``I_dend`` and ``I_dend'`` for a definition involving a second-order kernel ``inline kernel I_dend = convolve(kern_name, spike_buf)``) with the ODE-toolbox generated variable ``kern_name__X__spike_buf``. + """ + def replace_var(_expr, replace_var_name: str, replace_with_var_name: str): + if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable(): + var = _expr.get_variable() + if var.get_name() == replace_var_name: + ast_variable = ASTVariable(replace_with_var_name + '__d' * var.get_differential_order(), + differential_order=0) + ast_variable.set_source_position(var.get_source_position()) + _expr.set_variable(ast_variable) + + elif isinstance(_expr, ASTVariable): + var = _expr + if var.get_name() == replace_var_name: + var.set_name(replace_with_var_name + '__d' * var.get_differential_order()) + var.set_differential_order(0) + + for equation_block in neuron.get_equations_blocks(): + for decl in equation_block.get_declarations(): + if isinstance(decl, ASTInlineExpression): + expr = decl.get_expression() + if isinstance(expr, ASTExpression): + expr = expr.get_lhs() + + if isinstance(expr, ASTSimpleExpression) \ + and '__X__' in str(expr) \ + and expr.get_variable(): + replace_with_var_name = expr.get_variable().get_name() + neuron.accept(ASTHigherOrderVisitor(lambda x: replace_var( + x, decl.get_variable_name(), replace_with_var_name))) + + def generate_kernel_buffers(self, model: ASTModel) -> Mapping[ASTKernel, ASTInputPort]: + r""" + For every occurrence of a convolution of the form `convolve(var, spike_buf)`: add the element `(kernel, spike_buf)` to the set, with `kernel` being the kernel that contains variable `var`. + """ + kernel_buffers = set() + for equations_block in model.get_equations_blocks(): + convolve_calls = ASTUtils.get_convolve_function_calls(equations_block) + for convolve in convolve_calls: + el = (convolve.get_args()[0], convolve.get_args()[1]) + sym = convolve.get_args()[0].get_scope().resolve_to_symbol(convolve.get_args()[0].get_variable().name, SymbolKind.VARIABLE) + if sym is None: + raise Exception("No initial value(s) defined for kernel with variable \"" + + convolve.get_args()[0].get_variable().get_complete_name() + "\"") + if sym.block_type == BlockType.INPUT: + # swap the order + el = (el[1], el[0]) + + # find the corresponding kernel object + var = el[0].get_variable() + assert var is not None + kernel = model.get_kernel_by_name(var.get_name()) + assert kernel is not None, "In convolution \"convolve(" + str(var.name) + ", " + str( + el[1]) + ")\": no kernel by name \"" + var.get_name() + "\" found in model." + + el = (kernel, el[1]) + kernel_buffers.add(el) + + return kernel_buffers + + def remove_kernel_definitions_from_equations_blocks(self, model: ASTModel) -> ASTDeclaration: + r""" + Removes all kernels in equations blocks. + """ + for equations_block in model.get_equations_blocks(): + decl_to_remove = set() + for decl in equations_block.get_declarations(): + if type(decl) is ASTKernel: + decl_to_remove.add(decl) + + for decl in decl_to_remove: + equations_block.get_declarations().remove(decl) + + def transform_kernels_to_json(self, model: ASTModel, kernel_buffers: List[Tuple[ASTKernel, ASTInputPort]]) -> Dict: + """ + Converts AST node to a JSON representation suitable for passing to ode-toolbox. + + Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements + + .. code-block:: + + convolve(G, exc_spikes) + convolve(G, inh_spikes) + + then `kernel_buffers` will contain the pairs `(G, exc_spikes)` and `(G, inh_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__exc_spikes` and `G__X__inh_spikes`. + """ + odetoolbox_indict = {} + odetoolbox_indict["dynamics"] = [] + + for kernel, spike_input_port in kernel_buffers: + + if self.is_delta_kernel(kernel): + # delta function -- skip passing this to ode-toolbox + continue + + for kernel_var in kernel.get_variables(): + expr = ASTUtils.get_expr_from_kernel_var(kernel, kernel_var.get_complete_name()) + kernel_order = kernel_var.get_differential_order() + kernel_X_spike_buf_name_ticks = self.construct_kernel_spike_buf_name(kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'") + + self.replace_rhs_variables(expr, kernel_buffers) + + entry = {"expression": kernel_X_spike_buf_name_ticks + " = " + str(expr), "initial_values": {}} + + # initial values need to be declared for order 1 up to kernel order (e.g. none for kernel function + # f(t) = ...; 1 for kernel ODE f'(t) = ...; 2 for f''(t) = ... and so on) + for order in range(kernel_order): + iv_sym_name_ode_toolbox = self.construct_kernel_spike_buf_name(kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'") + symbol_name_ = kernel_var.get_name() + "'" * order + symbol = model.get_scope().resolve_to_symbol(symbol_name_, SymbolKind.VARIABLE) + assert symbol is not None, "Could not find initial value for variable " + symbol_name_ + initial_value_expr = symbol.get_declaring_expression() + assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_ + entry["initial_values"][iv_sym_name_ode_toolbox] = self._ode_toolbox_printer.print(initial_value_expr) + + odetoolbox_indict["dynamics"].append(entry) + + odetoolbox_indict["parameters"] = {} + for parameters_block in model.get_parameters_blocks(): + for decl in parameters_block.get_declarations(): + for var in decl.variables: + odetoolbox_indict["parameters"][var.get_complete_name()] = self._ode_toolbox_printer.print(decl.get_expression()) + + return odetoolbox_indict + + def create_spike_update_event_handlers(self, model: ASTModel, solver_dicts, kernel_buffers: List[Tuple[ASTKernel, ASTInputPort]]) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment]]: + r""" + Generate the equations that update the dynamical variables when incoming spikes arrive. To be invoked after + ode-toolbox. + + For example, a resulting `assignment_str` could be "I_kernel_in += (inh_spikes/nS) * 1". The values are taken from the initial values for each corresponding dynamical variable, either from ode-toolbox or directly from user specification in the model. + from the initial values for each corresponding dynamical variable, either from ode-toolbox or directly from + user specification in the model. + + Note that for kernels, `initial_values` actually contains the increment upon spike arrival, rather than the + initial value of the corresponding ODE dimension. + """ + + spike_in_port_to_stmts = {} + for solver_dict in solver_dicts: + for var, expr in solver_dict["initial_values"].items(): + expr = str(expr) + if expr in ["0", "0.", "0.0"]: + continue # skip adding the statement if we are only adding zero + + spike_in_port_name = var.split(self.get_option("convolution_separator"))[1] + spike_in_port_name = spike_in_port_name.split("__d")[0] + spike_in_port = ASTUtils.get_input_port_by_name(model.get_input_blocks(), spike_in_port_name) + type_str = "real" + + assert spike_in_port + differential_order: int = len(re.findall("__d", var)) + if differential_order: + type_str = "(s**-" + str(differential_order) + ")" + + assignment_str = var + " += " + assignment_str += "(" + str(spike_in_port_name) + ")" + if not expr in ["1.", "1.0", "1"]: + assignment_str += " * (" + expr + ")" + + ast_assignment = ModelParser.parse_assignment(assignment_str) + ast_assignment.update_scope(model.get_scope()) + ast_assignment.accept(ASTSymbolTableVisitor()) + + ast_small_stmt = ASTNodeFactory.create_ast_small_stmt(assignment=ast_assignment) + ast_stmt = ASTNodeFactory.create_ast_stmt(small_stmt=ast_small_stmt) + + if not spike_in_port_name in spike_in_port_to_stmts.keys(): + spike_in_port_to_stmts[spike_in_port_name] = [] + + spike_in_port_to_stmts[spike_in_port_name].append(ast_stmt) + + # for every input port, add an onreceive block with its update statements + for in_port, stmts in spike_in_port_to_stmts.items(): + stmts_block = ASTNodeFactory.create_ast_block(stmts, ASTSourceLocation.get_added_source_position()) + on_receive_block = ASTNodeFactory.create_ast_on_receive_block(stmts_block, + in_port, + const_parameters=None, # XXX: TODO: add priority here! + source_position=ASTSourceLocation.get_added_source_position()) + + model.get_body().get_body_elements().append(on_receive_block) + + model.accept(ASTParentVisitor()) diff --git a/pynestml/transformers/synapse_post_neuron_transformer.py b/pynestml/transformers/synapse_post_neuron_transformer.py index 5dd4aa3e0..ccaefd743 100644 --- a/pynestml/transformers/synapse_post_neuron_transformer.py +++ b/pynestml/transformers/synapse_post_neuron_transformer.py @@ -24,13 +24,8 @@ from typing import Any, Sequence, Mapping, Optional, Union from pynestml.frontend.frontend_configuration import FrontendConfiguration -from pynestml.meta_model.ast_assignment import ASTAssignment -from pynestml.meta_model.ast_equations_block import ASTEquationsBlock -from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from pynestml.meta_model.ast_model import ASTModel from pynestml.meta_model.ast_node import ASTNode -from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression -from pynestml.meta_model.ast_variable import ASTVariable from pynestml.symbols.symbol import SymbolKind from pynestml.symbols.variable_symbol import BlockType from pynestml.transformers.transformer import Transformer @@ -40,7 +35,6 @@ from pynestml.utils.string_utils import removesuffix from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor -from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor from pynestml.visitors.ast_visitor import ASTVisitor @@ -149,51 +143,6 @@ def get_neuron_var_name_from_syn_port_name(self, port_name: str, neuron_name: st return None - def get_convolve_with_not_post_vars(self, nodes: Union[ASTEquationsBlock, Sequence[ASTEquationsBlock]], neuron_name: str, synapse_name: str, parent_node: ASTNode): - class ASTVariablesUsedInConvolutionVisitor(ASTVisitor): - _variables = [] - - def __init__(self, node: ASTNode, parent_node: ASTNode, codegen_class): - super(ASTVariablesUsedInConvolutionVisitor, self).__init__() - self.node = node - self.parent_node = parent_node - self.codegen_class = codegen_class - - def visit_function_call(self, node): - func_name = node.get_name() - if func_name == "convolve": - symbol_buffer = node.get_scope().resolve_to_symbol(str(node.get_args()[1]), - SymbolKind.VARIABLE) - input_port = ASTUtils.get_input_port_by_name( - self.parent_node.get_input_blocks(), symbol_buffer.name) - if input_port and not self.codegen_class.is_post_port(input_port.name, neuron_name, synapse_name): - kernel_name = node.get_args()[0].get_variable().name - self._variables.append(kernel_name) - - found_parent_assignment = False - node_ = node - while not found_parent_assignment: - node_ = node_.get_parent() - # XXX TODO also needs to accept normal ASTExpression, ASTAssignment? - if isinstance(node_, ASTInlineExpression): - found_parent_assignment = True - var_name = node_.get_variable_name() - self._variables.append(var_name) - - if not nodes: - return [] - - if isinstance(nodes, ASTNode): - nodes = [nodes] - - variables = [] - for node in nodes: - visitor = ASTVariablesUsedInConvolutionVisitor(node, parent_node, self) - node.accept(visitor) - variables.extend(visitor._variables) - - return variables - def get_all_variables_assigned_to(self, node): r"""Return a list of all variables that are assigned to in ``node``.""" class ASTAssignedToVariablesFinderVisitor(ASTVisitor): @@ -257,13 +206,6 @@ def transform_neuron_synapse_pair_(self, neuron, synapse): all_state_vars = [var.get_complete_name() for var in all_state_vars] - # add names of convolutions - all_state_vars += ASTUtils.get_all_variables_used_in_convolutions(synapse.get_equations_blocks(), synapse) - - # add names of kernels - kernel_buffers = ASTUtils.generate_kernel_buffers(synapse, synapse.get_equations_blocks()) - all_state_vars += [var.name for k in kernel_buffers for var in k[0].variables] - # exclude certain variables from being moved: # exclude any variable assigned to in any block that is not connected to a postsynaptic port strictly_synaptic_vars = ["t"] # "seed" this with the predefined variable t @@ -276,28 +218,25 @@ def transform_neuron_synapse_pair_(self, neuron, synapse): for update_block in synapse.get_update_blocks(): strictly_synaptic_vars += self.get_all_variables_assigned_to(update_block) - # exclude convolutions if they are not with a postsynaptic variable - convolve_with_not_post_vars = self.get_convolve_with_not_post_vars(synapse.get_equations_blocks(), neuron.name, synapse.name, synapse) - # exclude all variables that depend on the ones that are not to be moved strictly_synaptic_vars_dependent = ASTUtils.recursive_dependent_variables_search(strictly_synaptic_vars, synapse) # do set subtraction - syn_to_neuron_state_vars = list(set(all_state_vars) - (set(strictly_synaptic_vars) | set(convolve_with_not_post_vars) | set(strictly_synaptic_vars_dependent))) + syn_to_neuron_state_vars = list(set(all_state_vars) - (set(strictly_synaptic_vars) | set(strictly_synaptic_vars_dependent))) # - # collect all the variable/parameter/kernel/function/etc. names used in defining expressions of `syn_to_neuron_state_vars` + # collect all the variable/parameter/function/etc. names used in defining expressions of `syn_to_neuron_state_vars` # recursive_vars_used = ASTUtils.recursive_necessary_variables_search(syn_to_neuron_state_vars, synapse) new_neuron.recursive_vars_used = recursive_vars_used new_neuron._transferred_variables = [neuron_state_var + var_name_suffix - for neuron_state_var in syn_to_neuron_state_vars if new_synapse.get_kernel_by_name(neuron_state_var) is None] + for neuron_state_var in syn_to_neuron_state_vars] # all state variables that will be moved from synapse to neuron syn_to_neuron_state_vars = [] for var_name in recursive_vars_used: - if ASTUtils.get_state_variable_by_name(synapse, var_name) or ASTUtils.get_inline_expression_by_name(synapse, var_name) or ASTUtils.get_kernel_by_name(synapse, var_name): + if ASTUtils.get_state_variable_by_name(synapse, var_name) or ASTUtils.get_inline_expression_by_name(synapse, var_name): syn_to_neuron_state_vars.append(var_name) Logger.log_message(None, -1, "State variables that will be moved from synapse to neuron: " + str(syn_to_neuron_state_vars), @@ -393,33 +332,6 @@ def transform_neuron_synapse_pair_(self, neuron, synapse): block_type=BlockType.STATE, mode="move") - # - # mark variables in the neuron pertaining to synapse postsynaptic ports - # - # convolutions with them ultimately yield variable updates when post neuron calls emit_spike() - # - - def mark_post_ports(neuron, synapse, mark_node): - post_ports = [] - - def mark_post_port(_expr=None): - var = None - if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable(): - var = _expr.get_variable() - elif isinstance(_expr, ASTVariable): - var = _expr - - if var: - var_base_name = var.name[:-len(var_name_suffix)] # prune the suffix - if self.is_post_port(var_base_name, neuron.name, synapse.name): - post_ports.append(var) - var._is_post_port = True - - mark_node.accept(ASTHigherOrderVisitor(lambda x: mark_post_port(x))) - return post_ports - - mark_post_ports(new_neuron, new_synapse, new_neuron) - # # move statements in post receive block from synapse to new_neuron # @@ -561,6 +473,13 @@ def mark_post_port(_expr=None): return new_neuron, new_synapse def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]: + + # check that there are no convolutions or kernels in the model (these should have been transformed out by the ConvolutionsTransformer) + for model in models: + for equations_block in model.get_equations_blocks(): + assert len(equations_block.get_kernels()) == 0, "Kernels and convolutions should have been removed by ConvolutionsTransformer" + + # transform each (neuron, synapse) pair for neuron_synapse_pair in self.get_option("neuron_synapse_pairs"): neuron_name = neuron_synapse_pair["neuron"] synapse_name = neuron_synapse_pair["synapse"] diff --git a/pynestml/transformers/transformer.py b/pynestml/transformers/transformer.py index 7144a9bec..06dbb37c5 100644 --- a/pynestml/transformers/transformer.py +++ b/pynestml/transformers/transformer.py @@ -39,4 +39,5 @@ def __init__(self, options: Optional[Mapping[str, Any]]=None): @abstractmethod def transform(self, model: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]: + r"""Transform a model or a list of models. Return an updated model or list of models.""" assert False diff --git a/pynestml/utils/ast_synapse_information_collector.py b/pynestml/utils/ast_synapse_information_collector.py index f5a6763bc..e44a539a1 100644 --- a/pynestml/utils/ast_synapse_information_collector.py +++ b/pynestml/utils/ast_synapse_information_collector.py @@ -207,7 +207,7 @@ def get_basic_kernel_variable_names(self, synapse_inline): kernel_name = kernel_var.get_name() spike_input_port = self.input_port_name_to_input_port[spike_var.get_name( )] - kernel_variable_name = self.construct_kernel_X_spike_buf_name( + kernel_variable_name = self.construct_kernel_spike_buf_name( kernel_name, spike_input_port, order) results.append(kernel_variable_name) @@ -338,12 +338,3 @@ def visit_expression(self, node): def endvisit_expression(self, node): self.inside_expression = False self.current_expression = None - - # this method was copied over from ast_transformer - # in order to avoid a circular dependency - @staticmethod - def construct_kernel_X_spike_buf_name(kernel_var_name: str, spike_input_port, order: int, diff_order_symbol="__d"): - assert type(kernel_var_name) is str - assert type(order) is int - assert type(diff_order_symbol) is str - return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + str(spike_input_port) + diff_order_symbol * order diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index 9a67054a1..a8755bb89 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -458,6 +458,17 @@ def create_equations_block(cls, model: ASTModel) -> ASTModel: model.get_body().get_body_elements().append(block) return model + @classmethod + def create_on_receive_block(cls, model: ASTModel, block: ASTBlock, input_port_name: str) -> ASTModel: + """ + Creates a single onReceive block in the handed over model. + :param model: a single model + :return: the modified model + """ + # local import since otherwise circular dependency + + return model + @classmethod def contains_convolve_call(cls, variable: VariableSymbol) -> bool: """ @@ -1262,7 +1273,7 @@ def to_ode_toolbox_name(cls, name: str) -> str: @classmethod def get_expr_from_kernel_var(cls, kernel: ASTKernel, var_name: str) -> Union[ASTExpression, ASTSimpleExpression]: - """ + r""" Get the expression using the kernel variable """ assert isinstance(var_name, str) @@ -1277,122 +1288,9 @@ def all_convolution_variable_names(cls, model: ASTModel) -> List[str]: var_names = [var.get_complete_name() for var in vars if "__X__" in var.get_complete_name()] return var_names - @classmethod - def construct_kernel_X_spike_buf_name(cls, kernel_var_name: str, spike_input_port: ASTInputPort, order: int, - diff_order_symbol="__d"): - """ - Construct a kernel-buffer name as - - For example, if the kernel is - .. code-block:: - kernel I_kernel = exp(-t / tau_x) - - and the input port is - .. code-block:: - pre_spikes nS <- spike - - then the constructed variable will be 'I_kernel__X__pre_pikes' - """ - assert type(kernel_var_name) is str - assert type(order) is int - assert type(diff_order_symbol) is str - - if isinstance(spike_input_port, ASTSimpleExpression): - spike_input_port = spike_input_port.get_variable() - - if not isinstance(spike_input_port, str): - spike_input_port_name = spike_input_port.get_name() - else: - spike_input_port_name = spike_input_port - - if isinstance(spike_input_port, ASTVariable): - if spike_input_port.has_vector_parameter(): - spike_input_port_name += "_" + str(cls.get_numeric_vector_size(spike_input_port)) - - return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + spike_input_port_name + diff_order_symbol * order - - @classmethod - def replace_rhs_variable(cls, expr: ASTExpression, variable_name_to_replace: str, kernel_var: ASTVariable, - spike_buf: ASTInputPort): - """ - Replace variable names in definitions of kernel dynamics - :param expr: expression in which to replace the variables - :param variable_name_to_replace: variable name to replace in the expression - :param kernel_var: kernel variable instance - :param spike_buf: input port instance - :return: - """ - def replace_kernel_var(node): - if type(node) is ASTSimpleExpression \ - and node.is_variable() \ - and node.get_variable().get_name() == variable_name_to_replace: - var_order = node.get_variable().get_differential_order() - new_variable_name = cls.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_buf, var_order - 1, diff_order_symbol="'") - new_variable = ASTVariable(new_variable_name, var_order) - new_variable.set_source_position(node.get_variable().get_source_position()) - node.set_variable(new_variable) - - expr.accept(ASTHigherOrderVisitor(visit_funcs=replace_kernel_var)) - - @classmethod - def replace_rhs_variables(cls, expr: ASTExpression, kernel_buffers: Mapping[ASTKernel, ASTInputPort]): - """ - Replace variable names in definitions of kernel dynamics. - - Say that the kernel is - - .. code-block:: - - G = -G / tau - - Its variable symbol might be replaced by "G__X__spikesEx": - - .. code-block:: - - G__X__spikesEx = -G / tau - - This function updates the right-hand side of `expr` so that it would also read (in this example): - - .. code-block:: - - G__X__spikesEx = -G__X__spikesEx / tau - - These equations will later on be fed to ode-toolbox, so we use the symbol "'" to indicate differential order. - - Note that for kernels/systems of ODE of dimension > 1, all variable orders and all variables for this kernel will already be present in `kernel_buffers`. - """ - for kernel, spike_buf in kernel_buffers: - for kernel_var in kernel.get_variables(): - variable_name_to_replace = kernel_var.get_name() - cls.replace_rhs_variable(expr, variable_name_to_replace=variable_name_to_replace, - kernel_var=kernel_var, spike_buf=spike_buf) - - @classmethod - def is_delta_kernel(cls, kernel: ASTKernel) -> bool: - """ - Catches definition of kernel, or reference (function call or variable name) of a delta kernel function. - """ - if type(kernel) is ASTKernel: - if not len(kernel.get_variables()) == 1: - # delta kernel not allowed if more than one variable is defined in this kernel - return False - expr = kernel.get_expressions()[0] - else: - expr = kernel - - rhs_is_delta_kernel = type(expr) is ASTSimpleExpression \ - and expr.is_function_call() \ - and expr.get_function_call().get_scope().resolve_to_symbol(expr.get_function_call().get_name(), SymbolKind.FUNCTION).equals(PredefinedFunctions.name2function["delta"]) - rhs_is_multiplied_delta_kernel = type(expr) is ASTExpression \ - and type(expr.get_rhs()) is ASTSimpleExpression \ - and expr.get_rhs().is_function_call() \ - and expr.get_rhs().get_function_call().get_scope().resolve_to_symbol(expr.get_rhs().get_function_call().get_name(), SymbolKind.FUNCTION).equals(PredefinedFunctions.name2function["delta"]) - return rhs_is_delta_kernel or rhs_is_multiplied_delta_kernel - @classmethod def get_input_port_by_name(cls, input_blocks: List[ASTInputBlock], port_name: str) -> ASTInputPort: - """ + r""" Get the input port given the port name :param input_block: block to be searched :param port_name: name of the input port @@ -1407,8 +1305,10 @@ def get_input_port_by_name(cls, input_blocks: List[ASTInputBlock], port_name: st port_name, port_index = port_name.split("_") assert int(port_index) > 0 assert int(port_index) <= size_parameter + if input_port.name == port_name: return input_port + return None @classmethod @@ -1424,6 +1324,7 @@ def get_parameter_by_name(cls, node: ASTModel, var_name: str) -> ASTDeclaration: for var in decl.get_variables(): if var.get_name() == var_name: return decl + return None @classmethod @@ -1629,37 +1530,6 @@ def recursive_necessary_variables_search(cls, vars: List[str], model: ASTModel) return list(set(vars_used)) - @classmethod - def remove_initial_values_for_kernels(cls, model: ASTModel) -> None: - """ - Remove initial values for original declarations (e.g. g_in, g_in', V_m); these might conflict with the initial value expressions returned from ODE-toolbox. - """ - symbols_to_remove = set() - for equations_block in model.get_equations_blocks(): - for kernel in equations_block.get_kernels(): - for kernel_var in kernel.get_variables(): - kernel_var_order = kernel_var.get_differential_order() - for order in range(kernel_var_order): - symbol_name = kernel_var.get_name() + "'" * order - symbols_to_remove.add(symbol_name) - - decl_to_remove = set() - for symbol_name in symbols_to_remove: - for state_block in model.get_state_blocks(): - for decl in state_block.get_declarations(): - if len(decl.get_variables()) == 1: - if decl.get_variables()[0].get_name() == symbol_name: - decl_to_remove.add(decl) - else: - for var in decl.get_variables(): - if var.get_name() == symbol_name: - decl.variables.remove(var) - - for decl in decl_to_remove: - for state_block in model.get_state_blocks(): - if decl in state_block.get_declarations(): - state_block.get_declarations().remove(decl) - @classmethod def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict]) -> None: """ @@ -1888,53 +1758,12 @@ def _visit(self, node): return visitor.calls @classmethod - def create_initial_values_for_kernels(cls, model: ASTModel, solver_dicts: List[Dict], kernels: List[ASTKernel]) -> None: - r""" - Add the variables used in kernels from the ode-toolbox result dictionary as ODEs in NESTML AST - """ - for solver_dict in solver_dicts: - if solver_dict is None: - continue - - for var_name in solver_dict["initial_values"].keys(): - if cls.variable_in_kernels(var_name, kernels): - # original initial value expressions should have been removed to make place for ode-toolbox results - assert not cls.declaration_in_state_block(model, var_name) - - for solver_dict in solver_dicts: - if solver_dict is None: - continue - - for var_name, expr in solver_dict["initial_values"].items(): - # overwrite is allowed because initial values might be repeated between numeric and analytic solver - if cls.variable_in_kernels(var_name, kernels): - spike_in_port_name = var_name.split("__X__")[1] - spike_in_port_name = spike_in_port_name.split("__d")[0] - spike_in_port = ASTUtils.get_input_port_by_name(model.get_input_blocks(), spike_in_port_name) - type_str = "real" - if spike_in_port: - differential_order: int = len(re.findall("__d", var_name)) - if differential_order: - type_str = "(s**-" + str(differential_order) + ")" - - expr = "0 " + type_str # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is 0 (property of the convolution) - if not cls.declaration_in_state_block(model, var_name): - cls.add_declaration_to_state_block(model, var_name, expr, type_str) - - @classmethod - def transform_ode_and_kernels_to_json(cls, model: ASTModel, parameters_blocks: Sequence[ASTBlockWithVariables], - kernel_buffers: Mapping[ASTKernel, ASTInputPort], printer: ASTPrinter) -> Dict: + def transform_odes_to_json(cls, + model: ASTModel, + parameters_blocks: Sequence[ASTBlockWithVariables], + printer: ASTPrinter) -> Dict: """ - Converts AST node to a JSON representation suitable for passing to ode-toolbox. - - Each kernel has to be generated for each spike buffer convolve in which it occurs, e.g. if the NESTML model code contains the statements - - .. code-block:: - - convolve(G, exc_spikes) - convolve(G, inh_spikes) - - then `kernel_buffers` will contain the pairs `(G, exc_spikes)` and `(G, inh_spikes)`, from which two ODEs will be generated, with dynamical state (variable) names `G__X__exc_spikes` and `G__X__inh_spikes`. + Converts to a JSON representation suitable for passing to ode-toolbox. """ odetoolbox_indict = {} @@ -1959,37 +1788,6 @@ def transform_ode_and_kernels_to_json(cls, model: ASTModel, parameters_blocks: S odetoolbox_indict["dynamics"].append(entry) - # write a copy for each (kernel, spike buffer) combination - for kernel, spike_input_port in kernel_buffers: - - if cls.is_delta_kernel(kernel): - # delta function -- skip passing this to ode-toolbox - continue - - for kernel_var in kernel.get_variables(): - expr = cls.get_expr_from_kernel_var(kernel, kernel_var.get_complete_name()) - kernel_order = kernel_var.get_differential_order() - kernel_X_spike_buf_name_ticks = cls.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_input_port, kernel_order, diff_order_symbol="'") - - cls.replace_rhs_variables(expr, kernel_buffers) - - entry = {"expression": kernel_X_spike_buf_name_ticks + " = " + str(expr), "initial_values": {}} - - # initial values need to be declared for order 1 up to kernel order (e.g. none for kernel function - # f(t) = ...; 1 for kernel ODE f'(t) = ...; 2 for f''(t) = ... and so on) - for order in range(kernel_order): - iv_sym_name_ode_toolbox = cls.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'") - symbol_name_ = kernel_var.get_name() + "'" * order - symbol = equations_block.get_scope().resolve_to_symbol(symbol_name_, SymbolKind.VARIABLE) - assert symbol is not None, "Could not find initial value for variable " + symbol_name_ - initial_value_expr = symbol.get_declaring_expression() - assert initial_value_expr is not None, "No initial value found for variable name " + symbol_name_ - entry["initial_values"][iv_sym_name_ode_toolbox] = printer.print(initial_value_expr) - - odetoolbox_indict["dynamics"].append(entry) - odetoolbox_indict["parameters"] = {} for parameters_block in parameters_blocks: for decl in parameters_block.get_declarations(): @@ -2079,52 +1877,6 @@ def log_set_source_position(node): return definitions - @classmethod - def get_delta_factors_(cls, neuron: ASTModel, equations_block: ASTEquationsBlock) -> dict: - r""" - For every occurrence of a convolution of the form `x^(n) = a * convolve(kernel, inport) + ...` where `kernel` is a delta function, add the element `(x^(n), inport) --> a` to the set. - """ - delta_factors = {} - - for ode_eq in equations_block.get_ode_equations(): - var = ode_eq.get_lhs() - expr = ode_eq.get_rhs() - conv_calls = ASTUtils.get_convolve_function_calls(expr) - for conv_call in conv_calls: - assert len( - conv_call.args) == 2, "convolve() function call should have precisely two arguments: kernel and spike input port" - kernel = conv_call.args[0] - if cls.is_delta_kernel(neuron.get_kernel_by_name(kernel.get_variable().get_name())): - inport = conv_call.args[1].get_variable() - expr_str = str(expr) - sympy_expr = sympy.parsing.sympy_parser.parse_expr(expr_str, global_dict=odetoolbox.Shape._sympy_globals) - sympy_expr = sympy.expand(sympy_expr) - sympy_conv_expr = sympy.parsing.sympy_parser.parse_expr(str(conv_call), global_dict=odetoolbox.Shape._sympy_globals) - factor_str = [] - for term in sympy.Add.make_args(sympy_expr): - if term.find(sympy_conv_expr): - factor_str.append(str(term.replace(sympy_conv_expr, 1))) - factor_str = " + ".join(factor_str) - delta_factors[(var, inport)] = factor_str - - return delta_factors - - @classmethod - def remove_kernel_definitions_from_equations_block(cls, model: ASTModel) -> ASTDeclaration: - r""" - Removes all kernels in equations blocks. - """ - for equations_block in model.get_equations_blocks(): - decl_to_remove = set() - for decl in equations_block.get_declarations(): - if type(decl) is ASTKernel: - decl_to_remove.add(decl) - - for decl in decl_to_remove: - equations_block.get_declarations().remove(decl) - - return decl_to_remove - @classmethod def add_timestep_symbol(cls, model: ASTModel) -> None: """ @@ -2137,70 +1889,6 @@ def add_timestep_symbol(cls, model: ASTModel) -> None: )], "\"__h\" is a reserved name, please do not use variables by this name in your NESTML file" model.add_to_internals_block(ModelParser.parse_declaration('__h ms = resolution()'), index=0) - @classmethod - def generate_kernel_buffers(cls, model: ASTModel, equations_block: Union[ASTEquationsBlock, List[ASTEquationsBlock]]) -> Mapping[ASTKernel, ASTInputPort]: - """ - For every occurrence of a convolution of the form `convolve(var, spike_buf)`: add the element `(kernel, spike_buf)` to the set, with `kernel` being the kernel that contains variable `var`. - """ - - kernel_buffers = set() - convolve_calls = ASTUtils.get_convolve_function_calls(equations_block) - for convolve in convolve_calls: - el = (convolve.get_args()[0], convolve.get_args()[1]) - sym = convolve.get_args()[0].get_scope().resolve_to_symbol(convolve.get_args()[0].get_variable().name, SymbolKind.VARIABLE) - if sym is None: - raise Exception("No initial value(s) defined for kernel with variable \"" - + convolve.get_args()[0].get_variable().get_complete_name() + "\"") - if sym.block_type == BlockType.INPUT: - # swap the order - el = (el[1], el[0]) - - # find the corresponding kernel object - var = el[0].get_variable() - assert var is not None - kernel = model.get_kernel_by_name(var.get_name()) - assert kernel is not None, "In convolution \"convolve(" + str(var.name) + ", " + str( - el[1]) + ")\": no kernel by name \"" + var.get_name() + "\" found in model." - - el = (kernel, el[1]) - kernel_buffers.add(el) - - return kernel_buffers - - @classmethod - def replace_convolution_aliasing_inlines(cls, neuron: ASTModel) -> None: - """ - Replace all occurrences of kernel names (e.g. ``I_dend`` and ``I_dend'`` for a definition involving a second-order kernel ``inline kernel I_dend = convolve(kern_name, spike_buf)``) with the ODE-toolbox generated variable ``kern_name__X__spike_buf``. - """ - def replace_var(_expr, replace_var_name: str, replace_with_var_name: str): - if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable(): - var = _expr.get_variable() - if var.get_name() == replace_var_name: - ast_variable = ASTVariable(replace_with_var_name + '__d' * var.get_differential_order(), - differential_order=0) - ast_variable.set_source_position(var.get_source_position()) - _expr.set_variable(ast_variable) - - elif isinstance(_expr, ASTVariable): - var = _expr - if var.get_name() == replace_var_name: - var.set_name(replace_with_var_name + '__d' * var.get_differential_order()) - var.set_differential_order(0) - - for equation_block in neuron.get_equations_blocks(): - for decl in equation_block.get_declarations(): - if isinstance(decl, ASTInlineExpression): - expr = decl.get_expression() - if isinstance(expr, ASTExpression): - expr = expr.get_lhs() - - if isinstance(expr, ASTSimpleExpression) \ - and '__X__' in str(expr) \ - and expr.get_variable(): - replace_with_var_name = expr.get_variable().get_name() - neuron.accept(ASTHigherOrderVisitor(lambda x: replace_var( - x, decl.get_variable_name(), replace_with_var_name))) - @classmethod def replace_variable_names_in_expressions(cls, model: ASTModel, solver_dicts: List[dict]) -> None: """ @@ -2229,42 +1917,6 @@ def func(x): model.accept(ASTHigherOrderVisitor(func)) - @classmethod - def replace_convolve_calls_with_buffers_(cls, model: ASTModel, equations_block: ASTEquationsBlock) -> None: - r""" - Replace all occurrences of `convolve(kernel[']^n, spike_input_port)` with the corresponding buffer variable, e.g. `g_E__X__spikes_exc[__d]^n` for a kernel named `g_E` and a spike input port named `spikes_exc`. - """ - - def replace_function_call_through_var(_expr=None): - if _expr.is_function_call() and _expr.get_function_call().get_name() == "convolve": - convolve = _expr.get_function_call() - el = (convolve.get_args()[0], convolve.get_args()[1]) - sym = convolve.get_args()[0].get_scope().resolve_to_symbol( - convolve.get_args()[0].get_variable().name, SymbolKind.VARIABLE) - if sym.block_type == BlockType.INPUT: - # swap elements - el = (el[1], el[0]) - var = el[0].get_variable() - spike_input_port = el[1].get_variable() - kernel = model.get_kernel_by_name(var.get_name()) - - _expr.set_function_call(None) - buffer_var = cls.construct_kernel_X_spike_buf_name( - var.get_name(), spike_input_port, var.get_differential_order() - 1) - if cls.is_delta_kernel(kernel): - # delta kernels are treated separately, and should be kept out of the dynamics (computing derivates etc.) --> set to zero - _expr.set_variable(None) - _expr.set_numeric_literal(0) - else: - ast_variable = ASTVariable(buffer_var) - ast_variable.set_source_position(_expr.get_source_position()) - _expr.set_variable(ast_variable) - - def func(x): - return replace_function_call_through_var(x) if isinstance(x, ASTSimpleExpression) else True - - equations_block.accept(ASTHigherOrderVisitor(func)) - @classmethod def update_blocktype_for_common_parameters(cls, node): r"""Change the BlockType for all homogeneous parameters to BlockType.COMMON_PARAMETER""" @@ -2313,7 +1965,7 @@ def find_model_by_name(cls, model_name: str, models: Iterable[ASTModel]) -> Opti @classmethod def get_convolve_function_calls(cls, nodes: Union[ASTNode, List[ASTNode]]): """ - Returns all sum function calls in the handed over meta_model node or one of its children. + Returns all convolve function calls in the handed over node. :param nodes: a single or list of AST nodes. """ if isinstance(nodes, ASTNode): @@ -2326,13 +1978,12 @@ def get_convolve_function_calls(cls, nodes: Union[ASTNode, List[ASTNode]]): return function_calls @classmethod - def contains_convolve_function_call(cls, ast: ASTNode) -> bool: + def contains_convolve_function_call(cls, node: ASTNode) -> bool: """ - Indicates whether _ast or one of its child nodes contains a sum call. - :param ast: a single meta_model + Indicates whether the node contains any convolve function call. :return: True if sum is contained, otherwise False. """ - return len(cls.get_function_calls(ast, PredefinedFunctions.CONVOLVE)) > 0 + return len(cls.get_function_calls(node, PredefinedFunctions.CONVOLVE)) > 0 @classmethod def get_function_calls(cls, ast_node: ASTNode, function_list: List[str]) -> List[ASTFunctionCall]: @@ -2457,13 +2108,6 @@ def visit_variable(self, node): for expr in numeric_update_expressions.values(): expr.accept(visitor) - for update_expr_list in neuron.spike_updates.values(): - for update_expr in update_expr_list: - update_expr.accept(visitor) - - for update_expr in neuron.post_spike_updates.values(): - update_expr.accept(visitor) - for node in neuron.equations_with_delay_vars + neuron.equations_with_vector_vars: node.accept(visitor) diff --git a/pynestml/utils/synapse_processing.py b/pynestml/utils/synapse_processing.py index 464abd269..298286dd9 100644 --- a/pynestml/utils/synapse_processing.py +++ b/pynestml/utils/synapse_processing.py @@ -73,7 +73,7 @@ def collect_additional_base_infos(cls, neuron, syns_info): for kernel_var, spikes_var in kernel_arg_pairs: kernel_name = kernel_var.get_name() spikes_name = spikes_var.get_name() - convolution_name = info_collector.construct_kernel_X_spike_buf_name( + convolution_name = info_collector.construct_kernel_spike_buf_name( kernel_name, spikes_name, 0) syns_info[synapse_name]["convolutions"][convolution_name] = { "kernel": { @@ -196,7 +196,7 @@ def transform_ode_and_kernels_to_json( expr = ASTUtils.get_expr_from_kernel_var( kernel, kernel_var.get_complete_name()) kernel_order = kernel_var.get_differential_order() - kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_X_spike_buf_name( + kernel_X_spike_buf_name_ticks = ASTUtils.construct_kernel_spike_buf_name( kernel_var.get_name(), spike_input_port.get_name(), kernel_order, diff_order_symbol="'") ASTUtils.replace_rhs_variables(expr, kernel_buffers) @@ -207,7 +207,7 @@ def transform_ode_and_kernels_to_json( # order (e.g. none for kernel function f(t) = ...; 1 for kernel # ODE f'(t) = ...; 2 for f''(t) = ... and so on) for order in range(kernel_order): - iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_X_spike_buf_name( + iv_sym_name_ode_toolbox = ASTUtils.construct_kernel_spike_buf_name( kernel_var.get_name(), spike_input_port, order, diff_order_symbol="'") symbol_name_ = kernel_var.get_name() + "'" * order symbol = equations_block.get_scope().resolve_to_symbol( diff --git a/tests/nest_tests/stdp_nn_restr_symm_test.py b/tests/nest_tests/stdp_nn_restr_symm_test.py index a89ad4c3e..b1d13ce10 100644 --- a/tests/nest_tests/stdp_nn_restr_symm_test.py +++ b/tests/nest_tests/stdp_nn_restr_symm_test.py @@ -44,17 +44,17 @@ class NestSTDPNNRestrSymmSynapseTest(unittest.TestCase): - neuron_model_name = "iaf_psc_exp_neuron_nestml__with_stdp_nn_restr_symm_synapse_nestml" - ref_neuron_model_name = "iaf_psc_exp_neuron_nestml_non_jit" + neuron_model_name = "iaf_psc_alpha_neuron_nestml__with_stdp_nn_restr_symm_synapse_nestml" + ref_neuron_model_name = "iaf_psc_alpha_neuron_nestml_non_jit" - synapse_model_name = "stdp_nn_restr_symm_synapse_nestml__with_iaf_psc_exp_neuron_nestml" + synapse_model_name = "stdp_nn_restr_symm_synapse_nestml__with_iaf_psc_alpha_neuron_nestml" ref_synapse_model_name = "stdp_nn_restr_synapse" def setUp(self): r"""Generate the neuron model code""" # generate the "jit" model (co-generated neuron and synapse), that does not rely on ArchivingNode - files = [os.path.join("models", "neurons", "iaf_psc_exp_neuron.nestml"), + files = [os.path.join("models", "neurons", "iaf_psc_alpha_neuron.nestml"), os.path.join("models", "synapses", "stdp_nn_restr_symm_synapse.nestml")] input_path = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join( os.pardir, os.pardir, s))) for s in files] @@ -65,13 +65,13 @@ def setUp(self): suffix="_nestml", codegen_opts={"neuron_parent_class": "StructuralPlasticityNode", "neuron_parent_class_include": "structural_plasticity_node.h", - "neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron", + "neuron_synapse_pairs": [{"neuron": "iaf_psc_alpha_neuron", "synapse": "stdp_nn_restr_symm_synapse", "post_ports": ["post_spikes"]}]}) # generate the "non-jit" model, that relies on ArchivingNode generate_nest_target(input_path=os.path.realpath(os.path.join(os.path.dirname(__file__), - os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_exp_neuron.nestml"))), + os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_alpha_neuron.nestml"))), target_path="/tmp/nestml-non-jit", logging_level="INFO", module_name="nestml_non_jit_module",