Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up NESTML grammar definition #1144

Merged
merged 6 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pynestml/cocos/co_co_user_defined_function_correctly_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ def check_co_co(cls, _node=None):
symbol = userDefinedFunction.get_scope().resolve_to_symbol(userDefinedFunction.get_name(),
SymbolKind.FUNCTION)
# first ensure that the block contains at least one statement
if symbol is not None and len(userDefinedFunction.get_block().get_stmts()) > 0:
if symbol is not None and len(userDefinedFunction.get_stmts_body().get_stmts()) > 0:
# now check that the last statement is a return
cls.__check_return_recursively(symbol.get_return_type(),
userDefinedFunction.get_block().get_stmts(), False)
userDefinedFunction.get_stmts_body().get_stmts(), False)
# now if it does not have a statement, but uses a return type, it is an error
elif symbol is not None and userDefinedFunction.has_return_type() and \
not symbol.get_return_type().equals(PredefinedTypes.get_void_type()):
Expand Down Expand Up @@ -135,19 +135,19 @@ def __check_return_recursively(cls, type_symbol=None, stmts=None, ret_defined=Fa
# otherwise it is a compound stmt, thus check recursively
if stmt.is_if_stmt():
cls.__check_return_recursively(type_symbol,
stmt.get_if_stmt().get_if_clause().get_block().get_stmts(),
stmt.get_if_stmt().get_if_clause().get_stmts_body().get_stmts(),
ret_defined)
for else_ifs in stmt.get_if_stmt().get_elif_clauses():
cls.__check_return_recursively(type_symbol, else_ifs.get_block().get_stmts(), ret_defined)
cls.__check_return_recursively(type_symbol, else_ifs.get_stmts_body().get_stmts(), ret_defined)
if stmt.get_if_stmt().has_else_clause():
cls.__check_return_recursively(type_symbol,
stmt.get_if_stmt().get_else_clause().get_block().get_stmts(),
stmt.get_if_stmt().get_else_clause().get_stmts_body().get_stmts(),
ret_defined)
elif stmt.is_while_stmt():
cls.__check_return_recursively(type_symbol, stmt.get_while_stmt().get_block().get_stmts(),
cls.__check_return_recursively(type_symbol, stmt.get_while_stmt().get_stmts_body().get_stmts(),
ret_defined)
elif stmt.is_for_stmt():
cls.__check_return_recursively(type_symbol, stmt.get_for_stmt().get_block().get_stmts(),
cls.__check_return_recursively(type_symbol, stmt.get_for_stmt().get_stmts_body().get_stmts(),
ret_defined)
# now, if a return statement has not been defined in the corresponding higher level block, we have
# to ensure that it is defined here
Expand Down
8 changes: 4 additions & 4 deletions pynestml/codegeneration/printers/model_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pynestml.meta_model.ast_arithmetic_operator import ASTArithmeticOperator
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_bit_operator import ASTBitOperator
from pynestml.meta_model.ast_block import ASTBlock
from pynestml.meta_model.ast_stmts_body import ASTStmtsBody
from pynestml.meta_model.ast_block_with_variables import ASTBlockWithVariables
from pynestml.meta_model.ast_compound_stmt import ASTCompoundStmt
from pynestml.meta_model.ast_data_type import ASTDataType
Expand Down Expand Up @@ -78,7 +78,7 @@ def print_assignment(self, node: ASTAssignment) -> str:
def print_bit_operator(self, node: ASTBitOperator) -> str:
raise Exception("Printer does not support printing this node type")

def print_block(self, node: ASTBlock) -> str:
def print_stmts_body(self, node: ASTStmtsBody) -> str:
raise Exception("Printer does not support printing this node type")

def print_block_with_variables(self, node: ASTBlockWithVariables) -> str:
Expand Down Expand Up @@ -196,8 +196,8 @@ def print(self, node: ASTNode) -> str:
if isinstance(node, ASTBitOperator):
return self.print_bit_operator(node)

if isinstance(node, ASTBlock):
return self.print_block(node)
if isinstance(node, ASTStmtsBody):
return self.print_stmts_body(node)

if isinstance(node, ASTBlockWithVariables):
return self.print_block_with_variables(node)
Expand Down
89 changes: 62 additions & 27 deletions pynestml/codegeneration/printers/nestml_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pynestml.meta_model.ast_arithmetic_operator import ASTArithmeticOperator
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_bit_operator import ASTBitOperator
from pynestml.meta_model.ast_block import ASTBlock
from pynestml.meta_model.ast_stmts_body import ASTStmtsBody
from pynestml.meta_model.ast_block_with_variables import ASTBlockWithVariables
from pynestml.meta_model.ast_comparison_operator import ASTComparisonOperator
from pynestml.meta_model.ast_compound_stmt import ASTCompoundStmt
Expand Down Expand Up @@ -76,10 +76,9 @@ def __init__(self):

def print_model(self, node: ASTModel) -> str:
ret = print_ml_comments(node.pre_comments, self.indent, False)
self.inc_indent()
ret += "model " + node.get_name() + ":" + print_sl_comment(node.in_comment)
ret += "\n" + self.print(node.get_body())
self.dec_indent()

return ret

def print_arithmetic_operator(celf, node: ASTArithmeticOperator) -> str:
Expand Down Expand Up @@ -138,39 +137,44 @@ def print_bit_operator(self, node: ASTBitOperator) -> str:

raise RuntimeError("Unknown bit operator")

def print_block(self, node: ASTBlock) -> str:
def print_stmts_body(self, node: ASTStmtsBody) -> str:
ret = ""
self.inc_indent()
for stmt in node.stmts:
ret += self.print(stmt)

self.dec_indent()

return ret

def print_block_with_variables(self, node: ASTBlockWithVariables) -> str:
temp_indent = self.indent
self.inc_indent()
ret = print_ml_comments(node.pre_comments, temp_indent, False)
ret += print_n_spaces(temp_indent)
ret = print_ml_comments(node.pre_comments, self.indent, False)
ret += print_n_spaces(self.indent)

if node.is_state:
ret += "state"
elif node.is_parameters:
ret += "parameters"
else:
assert node.is_internals
ret += "internals"

ret += ":" + print_sl_comment(node.in_comment) + "\n"

if node.get_declarations() is not None:
self.inc_indent()
for decl in node.get_declarations():
ret += self.print(decl)
self.dec_indent()

self.dec_indent()

return ret

def print_model_body(self, node: ASTModelBody) -> str:
self.inc_indent()
ret = ""
for elem in node.body_elements:
ret += self.print(elem)

self.dec_indent()

return ret

def print_comparison_operator(self, node: ASTComparisonOperator) -> str:
Expand Down Expand Up @@ -257,11 +261,20 @@ def print_declaration(self, node: ASTDeclaration) -> str:
return ret

def print_elif_clause(self, node: ASTElifClause) -> str:
return (print_n_spaces(self.indent) + "elif " + self.print(node.get_condition())
+ ":\n" + self.print(node.get_block()))
ret = print_n_spaces(self.indent) + "elif " + self.print(node.get_condition()) + ":\n"
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def print_else_clause(self, node: ASTElseClause) -> str:
return print_n_spaces(self.indent) + "else:\n" + self.print(node.get_block())
ret = print_n_spaces(self.indent) + "else:\n"
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def print_equations_block(self, node: ASTEquationsBlock) -> str:
temp_indent = self.indent
Expand Down Expand Up @@ -301,20 +314,25 @@ def print_for_stmt(self, node: ASTForStmt) -> str:
ret += ("for " + node.get_variable() + " in " + self.print(node.get_start_from()) + "..."
+ self.print(node.get_end_at()) + " step "
+ str(node.get_step()) + ":" + print_sl_comment(node.in_comment) + "\n")
ret += self.print(node.get_block())
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()
return ret

def print_function(self, node: ASTFunction) -> str:
ret = print_ml_comments(node.pre_comments, self.indent)
ret += "function " + node.get_name() + "("
ret += print_n_spaces(self.indent) + "function " + node.get_name() + "("
if node.has_parameters():
for par in node.get_parameters():
ret += self.print(par)
ret += ")"
if node.has_return_type():
ret += " " + self.print(node.get_return_type())
ret += ":" + print_sl_comment(node.in_comment) + "\n"
ret += self.print(node.get_block()) + "\n"
self.inc_indent()
ret += self.print(node.get_stmts_body()) + "\n"
self.dec_indent()

return ret

def print_function_call(self, node: ASTFunctionCall) -> str:
Expand All @@ -323,14 +341,19 @@ def print_function_call(self, node: ASTFunctionCall) -> str:
ret += self.print(node.get_args()[i])
if i < len(node.get_args()) - 1: # in the case that it is not the last arg, print also a comma
ret += ","

ret += ")"

return ret

def print_if_clause(self, node: ASTIfClause) -> str:
ret = print_ml_comments(node.pre_comments, self.indent)
ret += print_n_spaces(self.indent) + "if " + self.print(node.get_condition()) + ":"
ret += print_sl_comment(node.in_comment) + "\n"
ret += self.print(node.get_block())
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def print_if_stmt(self, node: ASTIfStmt) -> str:
Expand All @@ -341,6 +364,7 @@ def print_if_stmt(self, node: ASTIfStmt) -> str:
if node.get_else_clause() is not None:
ret += self.print(node.get_else_clause())
ret += print_n_spaces(self.indent) + "\n"

return ret

def print_input_block(self, node: ASTInputBlock) -> str:
Expand Down Expand Up @@ -519,31 +543,39 @@ def print_unit_type(self, node: ASTUnitType) -> str:
return self.print(node.base) + "**" + str(node.exponent)

if node.is_arithmetic_expression():
t_lhs = (
self.print(node.get_lhs()) if isinstance(node.get_lhs(), ASTUnitType) else str(node.get_lhs()))
t_lhs = self.print(node.get_lhs()) if isinstance(node.get_lhs(), ASTUnitType) else str(node.get_lhs())
if node.is_times:
return t_lhs + "*" + self.print(node.get_rhs())
else:
return t_lhs + "/" + self.print(node.get_rhs())

return t_lhs + "/" + self.print(node.get_rhs())

return node.unit

def print_on_receive_block(self, node: ASTOnReceiveBlock) -> str:
ret = print_ml_comments(node.pre_comments, self.indent, False)
ret += print_n_spaces(self.indent) + "onReceive(" + node.port_name + "):" + print_sl_comment(node.in_comment) + "\n"
ret += self.print(node.get_block())
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def print_on_condition_block(self, node: ASTOnConditionBlock) -> str:
ret = print_ml_comments(node.pre_comments, self.indent, False)
ret += print_n_spaces(self.indent) + "onCondition(" + self.print(node.get_cond_expr()) + "):" + print_sl_comment(node.in_comment) + "\n"
ret += self.print(node.get_block())
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def print_update_block(self, node: ASTUpdateBlock):
ret = print_ml_comments(node.pre_comments, self.indent, False)
ret += print_n_spaces(self.indent) + "update:" + print_sl_comment(node.in_comment) + "\n"
ret += self.print(node.get_block())
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def print_variable(self, node: ASTVariable):
Expand All @@ -561,7 +593,10 @@ def print_while_stmt(self, node: ASTWhileStmt) -> str:
ret = print_ml_comments(node.pre_comments, self.indent, False)
ret += (print_n_spaces(self.indent) + "while " + self.print(node.get_condition())
+ ":" + print_sl_comment(node.in_comment) + "\n")
ret += self.print(node.get_block())
self.inc_indent()
ret += self.print(node.get_stmts_body())
self.dec_indent()

return ret

def inc_indent(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,8 @@ void {{ neuronName }}::pre_run_hook()
{{ function_declaration.FunctionDeclaration(function, neuronName + "::") }}
{
{%- filter indent(2,True) %}
{%- with ast = function.get_block() %}
{%- include "directives_cpp/Block.jinja2" %}
{%- with ast = function.get_stmts_body() %}
{%- include "directives_cpp/StmtsBody.jinja2" %}
{%- endwith %}
{%- endfilter %}
}
Expand Down Expand Up @@ -782,13 +782,13 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const
{% if neuron.get_update_blocks() %}
{%- filter indent(2) %}
{%- for block in neuron.get_update_blocks() %}
{%- set ast = block.get_block() %}
{%- set ast = block.get_stmts_body() %}
{%- if ast.print_comment('*')|length > 1 %}
/*
{{ast.print_comment('*')}}
*/
{%- endif %}
{%- include "directives_cpp/Block.jinja2" %}
{%- include "directives_cpp/StmtsBody.jinja2" %}
{%- endfor %}
{%- endfilter %}
{%- endif %}
Expand Down Expand Up @@ -853,14 +853,14 @@ void {{ neuronName }}::update(nest::Time const & origin, const long from, const
{%- for block in neuron.get_on_condition_blocks() %}
if ({{ printer.print(block.get_cond_expr()) }})
{
{%- set ast = block.get_block() %}
{%- set ast = block.get_stmts_body() %}
{%- if ast.print_comment('*') | length > 1 %}
/*
{{ast.print_comment('*')}}
*/
{%- endif %}
{%- filter indent(6) %}
{%- include "directives_cpp/Block.jinja2" %}
{%- include "directives_cpp/StmtsBody.jinja2" %}
{%- endfilter %}
}
{%- endfor %}
Expand Down Expand Up @@ -1141,14 +1141,14 @@ void {{ neuronName }}::handle(nest::CurrentEvent& e)
// -------------------------------------------------------------------------

{%- for blk in neuron.get_on_receive_blocks() %}
{%- set ast = blk.get_block() %}
{%- set ast = blk.get_stmts_body() %}
void
{{ neuronName }}::on_receive_block_{{ blk.get_port_name() }}()
{
const double __timestep = nest::Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function

{%- filter indent(2, True) -%}
{%- include "directives_cpp/Block.jinja2" %}
{%- include "directives_cpp/StmtsBody.jinja2" %}
{%- endfilter %}
}

Expand Down
Loading
Loading