Skip to content

Commit

Permalink
clean up NESTML grammar definition
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Nov 28, 2024
1 parent 31576af commit c1fb633
Show file tree
Hide file tree
Showing 42 changed files with 1,309 additions and 1,143 deletions.
14 changes: 7 additions & 7 deletions pynestml/cocos/co_co_user_defined_function_correctly_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ def check_co_co(cls, _node=None):
symbol = userDefinedFunction.get_scope().resolve_to_symbol(userDefinedFunction.get_name(),
SymbolKind.FUNCTION)
# first ensure that the block contains at least one statement
if symbol is not None and len(userDefinedFunction.get_block().get_stmts()) > 0:
if symbol is not None and len(userDefinedFunction.get_stmts_body().get_stmts()) > 0:
# now check that the last statement is a return
cls.__check_return_recursively(symbol.get_return_type(),
userDefinedFunction.get_block().get_stmts(), False)
userDefinedFunction.get_stmts_body().get_stmts(), False)
# now if it does not have a statement, but uses a return type, it is an error
elif symbol is not None and userDefinedFunction.has_return_type() and \
not symbol.get_return_type().equals(PredefinedTypes.get_void_type()):
Expand Down Expand Up @@ -135,19 +135,19 @@ def __check_return_recursively(cls, type_symbol=None, stmts=None, ret_defined=Fa
# otherwise it is a compound stmt, thus check recursively
if stmt.is_if_stmt():
cls.__check_return_recursively(type_symbol,
stmt.get_if_stmt().get_if_clause().get_block().get_stmts(),
stmt.get_if_stmt().get_if_clause().get_stmts_body().get_stmts(),
ret_defined)
for else_ifs in stmt.get_if_stmt().get_elif_clauses():
cls.__check_return_recursively(type_symbol, else_ifs.get_block().get_stmts(), ret_defined)
cls.__check_return_recursively(type_symbol, else_ifs.get_stmts_body().get_stmts(), ret_defined)
if stmt.get_if_stmt().has_else_clause():
cls.__check_return_recursively(type_symbol,
stmt.get_if_stmt().get_else_clause().get_block().get_stmts(),
stmt.get_if_stmt().get_else_clause().get_stmts_body().get_stmts(),
ret_defined)
elif stmt.is_while_stmt():
cls.__check_return_recursively(type_symbol, stmt.get_while_stmt().get_block().get_stmts(),
cls.__check_return_recursively(type_symbol, stmt.get_while_stmt().get_stmts_body().get_stmts(),
ret_defined)
elif stmt.is_for_stmt():
cls.__check_return_recursively(type_symbol, stmt.get_for_stmt().get_block().get_stmts(),
cls.__check_return_recursively(type_symbol, stmt.get_for_stmt().get_stmts_body().get_stmts(),
ret_defined)
# now, if a return statement has not been defined in the corresponding higher level block, we have
# to ensure that it is defined here
Expand Down
8 changes: 4 additions & 4 deletions pynestml/codegeneration/printers/model_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pynestml.meta_model.ast_arithmetic_operator import ASTArithmeticOperator
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_bit_operator import ASTBitOperator
from pynestml.meta_model.ast_block import ASTBlock
from pynestml.meta_model.ast_stmts_body import ASTStmtsBody
from pynestml.meta_model.ast_block_with_variables import ASTBlockWithVariables
from pynestml.meta_model.ast_compound_stmt import ASTCompoundStmt
from pynestml.meta_model.ast_data_type import ASTDataType
Expand Down Expand Up @@ -78,7 +78,7 @@ def print_assignment(self, node: ASTAssignment) -> str:
def print_bit_operator(self, node: ASTBitOperator) -> str:
raise Exception("Printer does not support printing this node type")

def print_block(self, node: ASTBlock) -> str:
def print_stmts_body(self, node: ASTStmtsBody) -> str:
raise Exception("Printer does not support printing this node type")

def print_block_with_variables(self, node: ASTBlockWithVariables) -> str:
Expand Down Expand Up @@ -196,8 +196,8 @@ def print(self, node: ASTNode) -> str:
if isinstance(node, ASTBitOperator):
return self.print_bit_operator(node)

if isinstance(node, ASTBlock):
return self.print_block(node)
if isinstance(node, ASTStmtsBody):
return self.print_stmts_body(node)

if isinstance(node, ASTBlockWithVariables):
return self.print_block_with_variables(node)
Expand Down
89 changes: 62 additions & 27 deletions pynestml/codegeneration/printers/nestml_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pynestml.meta_model.ast_arithmetic_operator import ASTArithmeticOperator
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_bit_operator import ASTBitOperator
from pynestml.meta_model.ast_block import ASTBlock
from pynestml.meta_model.ast_stmts_body import ASTStmtsBody
from pynestml.meta_model.ast_block_with_variables import ASTBlockWithVariables
from pynestml.meta_model.ast_comparison_operator import ASTComparisonOperator
from pynestml.meta_model.ast_compound_stmt import ASTCompoundStmt
Expand Down Expand Up @@ -76,10 +76,9 @@ def __init__(self):

def print_model(self, node: ASTModel) -> str:
ret = print_ml_comments(node.pre_comments, self.indent, False)
self.inc_indent()
ret += "model " + node.get_name() + ":" + print_sl_comment(node.in_comment)
ret += "\n" + self.print(node.get_body())
self.dec_indent()

return ret

def print_arithmetic_operator(celf, node: ASTArithmeticOperator) -> str:
Expand Down Expand Up @@ -138,39 +137,44 @@ def print_bit_operator(self, node: ASTBitOperator) -> str:

raise RuntimeError("Unknown bit operator")

def print_block(self, node: ASTBlock) -> str:
def print_stmts_body(self, node: ASTStmtsBody) -> str:
ret = ""
self.inc_indent()
for stmt in node.stmts:
ret += self.print(stmt)

self.dec_indent()

return ret

def print_block_with_variables(self, node: ASTBlockWithVariables) -> str:
temp_indent = self.indent
self.inc_indent()
ret = print_ml_comments(node.pre_comments, temp_indent, False)
ret += print_n_spaces(temp_indent)
ret = print_ml_comments(node.pre_comments, self.indent, False)
ret += print_n_spaces(self.indent)

if node.is_state:
ret += "state"
elif node.is_parameters:
ret += "parameters"
else:
assert node.is_internals
ret += "internals"

ret += ":" + print_sl_comment(node.in_comment) + "\n"

if node.get_declarations() is not None:
self.inc_indent()
for decl in node.get_declarations():
ret += self.print(decl)
self.dec_indent()

self.dec_indent()

return ret

def print_model_body(self, node: ASTModelBody) -> str:
self.inc_indent()
ret = ""
for elem in node.body_elements:
ret += self.print(elem)

self.dec_indent()

return ret

def print_comparison_operator(self, node: ASTComparisonOperator) -> str:
Expand Down Expand Up @@ -257,11 +261,20 @@ def print_declaration(self, node: ASTDeclaration) -> str:
return ret

def print_elif_clause(self, node: ASTElifClause) -> str:
return (print_n_spaces(self.indent) + "elif " + self.print(node.get_condition())
+ ":\n" + self.print(node.get_block()))
ret = print_n_spaces(self.indent) + "elif " + self.print(node.get_condition()) + ":\n"
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def print_else_clause(self, node: ASTElseClause) -> str:
return print_n_spaces(self.indent) + "else:\n" + self.print(node.get_block())
ret = print_n_spaces(self.indent) + "else:\n"
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def print_equations_block(self, node: ASTEquationsBlock) -> str:
temp_indent = self.indent
Expand Down Expand Up @@ -301,20 +314,25 @@ def print_for_stmt(self, node: ASTForStmt) -> str:
ret += ("for " + node.get_variable() + " in " + self.print(node.get_start_from()) + "..."
+ self.print(node.get_end_at()) + " step "
+ str(node.get_step()) + ":" + print_sl_comment(node.in_comment) + "\n")
ret += self.print(node.get_block())
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()
return ret

def print_function(self, node: ASTFunction) -> str:
ret = print_ml_comments(node.pre_comments, self.indent)
ret += "function " + node.get_name() + "("
ret += print_n_spaces(self.indent) + "function " + node.get_name() + "("
if node.has_parameters():
for par in node.get_parameters():
ret += self.print(par)
ret += ")"
if node.has_return_type():
ret += " " + self.print(node.get_return_type())
ret += ":" + print_sl_comment(node.in_comment) + "\n"
ret += self.print(node.get_block()) + "\n"
self.inc_indent()
ret += self.print(node.get_stmts_body()) + "\n"
self.dec_indent()

return ret

def print_function_call(self, node: ASTFunctionCall) -> str:
Expand All @@ -323,14 +341,19 @@ def print_function_call(self, node: ASTFunctionCall) -> str:
ret += self.print(node.get_args()[i])
if i < len(node.get_args()) - 1: # in the case that it is not the last arg, print also a comma
ret += ","

ret += ")"

return ret

def print_if_clause(self, node: ASTIfClause) -> str:
ret = print_ml_comments(node.pre_comments, self.indent)
ret += print_n_spaces(self.indent) + "if " + self.print(node.get_condition()) + ":"
ret += print_sl_comment(node.in_comment) + "\n"
ret += self.print(node.get_block())
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def print_if_stmt(self, node: ASTIfStmt) -> str:
Expand All @@ -341,6 +364,7 @@ def print_if_stmt(self, node: ASTIfStmt) -> str:
if node.get_else_clause() is not None:
ret += self.print(node.get_else_clause())
ret += print_n_spaces(self.indent) + "\n"

return ret

def print_input_block(self, node: ASTInputBlock) -> str:
Expand Down Expand Up @@ -519,31 +543,39 @@ def print_unit_type(self, node: ASTUnitType) -> str:
return self.print(node.base) + "**" + str(node.exponent)

if node.is_arithmetic_expression():
t_lhs = (
self.print(node.get_lhs()) if isinstance(node.get_lhs(), ASTUnitType) else str(node.get_lhs()))
t_lhs = self.print(node.get_lhs()) if isinstance(node.get_lhs(), ASTUnitType) else str(node.get_lhs())
if node.is_times:
return t_lhs + "*" + self.print(node.get_rhs())
else:
return t_lhs + "/" + self.print(node.get_rhs())

return t_lhs + "/" + self.print(node.get_rhs())

return node.unit

def print_on_receive_block(self, node: ASTOnReceiveBlock) -> str:
ret = print_ml_comments(node.pre_comments, self.indent, False)
ret += print_n_spaces(self.indent) + "onReceive(" + node.port_name + "):" + print_sl_comment(node.in_comment) + "\n"
ret += self.print(node.get_block())
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def print_on_condition_block(self, node: ASTOnConditionBlock) -> str:
ret = print_ml_comments(node.pre_comments, self.indent, False)
ret += print_n_spaces(self.indent) + "onCondition(" + self.print(node.get_cond_expr()) + "):" + print_sl_comment(node.in_comment) + "\n"
ret += self.print(node.get_block())
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def print_update_block(self, node: ASTUpdateBlock):
ret = print_ml_comments(node.pre_comments, self.indent, False)
ret += print_n_spaces(self.indent) + "update:" + print_sl_comment(node.in_comment) + "\n"
ret += self.print(node.get_block())
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def print_variable(self, node: ASTVariable):
Expand All @@ -561,7 +593,10 @@ def print_while_stmt(self, node: ASTWhileStmt) -> str:
ret = print_ml_comments(node.pre_comments, self.indent, False)
ret += (print_n_spaces(self.indent) + "while " + self.print(node.get_condition())
+ ":" + print_sl_comment(node.in_comment) + "\n")
ret += self.print(node.get_block())
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def inc_indent(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,8 @@ void {{ neuronName }}::pre_run_hook()
{{ function_declaration.FunctionDeclaration(function, neuronName + "::") }}
{
{%- filter indent(2,True) %}
{%- with ast = function.get_block() %}
{%- include "directives_cpp/Block.jinja2" %}
{%- with ast = function.get_stmts_body() %}
{%- include "directives_cpp/StmtsBody.jinja2" %}
{%- endwith %}
{%- endfilter %}
}
Expand Down Expand Up @@ -782,13 +782,13 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const
{% if neuron.get_update_blocks() %}
{%- filter indent(2) %}
{%- for block in neuron.get_update_blocks() %}
{%- set ast = block.get_block() %}
{%- set ast = block.get_stmts_body() %}
{%- if ast.print_comment('*')|length > 1 %}
/*
{{ast.print_comment('*')}}
*/
{%- endif %}
{%- include "directives_cpp/Block.jinja2" %}
{%- include "directives_cpp/StmtsBody.jinja2" %}
{%- endfor %}
{%- endfilter %}
{%- endif %}
Expand Down Expand Up @@ -853,14 +853,14 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const
{%- for block in neuron.get_on_condition_blocks() %}
if ({{ printer.print(block.get_cond_expr()) }})
{
{%- set ast = block.get_block() %}
{%- set ast = block.get_stmts_body() %}
{%- if ast.print_comment('*') | length > 1 %}
/*
{{ast.print_comment('*')}}
*/
{%- endif %}
{%- filter indent(6) %}
{%- include "directives_cpp/Block.jinja2" %}
{%- include "directives_cpp/StmtsBody.jinja2" %}
{%- endfilter %}
}
{%- endfor %}
Expand Down Expand Up @@ -1141,14 +1141,14 @@ void {{ neuronName }}::handle(nest::CurrentEvent& e)
// -------------------------------------------------------------------------

{%- for blk in neuron.get_on_receive_blocks() %}
{%- set ast = blk.get_block() %}
{%- set ast = blk.get_stmts_body() %}
void
{{ neuronName }}::on_receive_block_{{ blk.get_port_name() }}()
{
const double __timestep = nest::Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function

{%- filter indent(2, True) -%}
{%- include "directives_cpp/Block.jinja2" %}
{%- include "directives_cpp/StmtsBody.jinja2" %}
{%- endfilter %}
}

Expand Down
Loading

0 comments on commit c1fb633

Please sign in to comment.