Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[f2dace/dev, fortran] Constant propagation in the array subscripts. #1878

Draft
wants to merge 13 commits into
base: f2dace/dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 214 additions & 47 deletions dace/frontend/fortran/ast_desugaring.py

Large diffs are not rendered by default.

158 changes: 50 additions & 108 deletions dace/frontend/fortran/ast_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,133 +768,79 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
line_number=node.line_number)


class ArgumentExtractorNodeLister(NodeVisitor):
"""
Finds all arguments in function calls in the AST node and its children that have to be extracted into independent expressions
"""

def __init__(self):
self.nodes: List[ast_internal_classes.Call_Expr_Node] = []

def visit_For_Stmt_Node(self, node: ast_internal_classes.For_Stmt_Node):
return

def visit_If_Then_Stmt_Node(self, node: ast_internal_classes.If_Stmt_Node):
return

def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
stop = False
# if hasattr(node, "subroutine"):
# if node.subroutine is True:
# stop = True

from dace.frontend.fortran.intrinsics import FortranIntrinsics
if not stop and node.name.name not in [
"malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()
]:
for i in node.args:
if isinstance(i, (ast_internal_classes.Name_Node, ast_internal_classes.Literal,
ast_internal_classes.Array_Subscript_Node, ast_internal_classes.Data_Ref_Node,
ast_internal_classes.Actual_Arg_Spec_Node)):
continue
else:
self.nodes.append(i)
return self.generic_visit(node)

def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node):
return


class ArgumentExtractor(NodeTransformer):
"""
Uses the ArgumentExtractorNodeLister to find all function calls
in the AST node and its children that have to be extracted into independent expressions
It then creates a new temporary variable for each of them and replaces the call with the variable.
"""

def __init__(self, program, count=0):
self.count = count
self.program = program

def __init__(self, program):
self._count = 0
ParentScopeAssigner().visit(program)
self.scope_vars = ScopeVarsDeclarations(program)
self.scope_vars.visit(program)
# For a nesting of execution parts (rare, but in case it happens), after visiting each direct child of it,
# `self.execution_preludes[-1]` will contain all the temporary variable assignments necessary for that node.
self.execution_preludes: List[List[ast_internal_classes.BinOp_Node]] = []

def _get_tempvar_name(self):
tmpname, self._count = f"tmp_arg_{self._count}", self._count + 1
return tmpname

def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
DIRECTLY_REFERNCEABLE = (ast_internal_classes.Name_Node, ast_internal_classes.Literal,
ast_internal_classes.Array_Subscript_Node, ast_internal_classes.Data_Ref_Node)

from dace.frontend.fortran.intrinsics import FortranIntrinsics
if node.name.name in ["malloc", "pow", "cbrt", "__dace_epsilon",
*FortranIntrinsics.call_extraction_exemptions()]:
return self.generic_visit(node)
# if node.subroutine:
# return self.generic_visit(node)
if not hasattr(self, "count"):
self.count = 0
tmp = self.count
result = ast_internal_classes.Call_Expr_Node(type=node.type, subroutine=node.subroutine,
name=node.name, args=[], line_number=node.line_number,
parent=node.parent)
result = ast_internal_classes.Call_Expr_Node(
name=node.name, args=[], line_number=node.line_number,
type=node.type, subroutine=node.subroutine, parent=node.parent)

for i, arg in enumerate(node.args):
# Ensure we allow to extract function calls from arguments
if isinstance(arg, (ast_internal_classes.Name_Node, ast_internal_classes.Literal,
ast_internal_classes.Array_Subscript_Node, ast_internal_classes.Data_Ref_Node,
ast_internal_classes.Actual_Arg_Spec_Node)):
if (isinstance(arg, DIRECTLY_REFERNCEABLE)
or (isinstance(arg, ast_internal_classes.Actual_Arg_Spec_Node)
and isinstance(arg.arg, DIRECTLY_REFERNCEABLE))):
# If it is a node type that's allowed to be directly referenced in a (possibly keyworded) function
# argument, then we keep the node as is.
result.args.append(arg)
continue

# These needs to be extracted, so register a temporary variable.
tmpname = self._get_tempvar_name()
decl = ast_internal_classes.Decl_Stmt_Node(
vardecl=[ast_internal_classes.Var_Decl_Node(name=tmpname, type='VOID', sizes=None, init=None)])
node.parent.specification_part.specifications.append(decl)

if isinstance(arg, ast_internal_classes.Actual_Arg_Spec_Node):
self.generic_visit(arg.arg)
result.args.append(ast_internal_classes.Actual_Arg_Spec_Node(
arg_name=arg.arg_name, arg=ast_internal_classes.Name_Node(name=tmpname, type=arg.arg.type)))
asgn = ast_internal_classes.BinOp_Node(
op="=", lval=ast_internal_classes.Name_Node(name=tmpname, type=arg.arg.type),
rval=arg.arg, line_number=node.line_number, parent=node.parent)
else:
result.args.append(ast_internal_classes.Name_Node(name="tmp_arg_" + str(tmp), type='VOID'))
tmp = tmp + 1
self.count = tmp
self.generic_visit(arg)
result.args.append(ast_internal_classes.Name_Node(name=tmpname, type=arg.type))
asgn = ast_internal_classes.BinOp_Node(
op="=", lval=ast_internal_classes.Name_Node(name=tmpname, type=arg.type),
rval=arg, line_number=node.line_number, parent=node.parent)

self.execution_preludes[-1].append(asgn)
return result

def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node):
newbody = []

for child in node.execution:
lister = ArgumentExtractorNodeLister()
lister.visit(child)
res = lister.nodes
for i in res:
if i == child:
res.pop(res.index(i))

if res is not None:

# Variables are counted from 0...end, starting from main node, to all calls nested
# in main node arguments.
# However, we need to define nested ones first.
# We go in reverse order, counting from end-1 to 0.
temp = self.count + len(res) - 1
for i in reversed(range(0, len(res))):

if isinstance(res[i], ast_internal_classes.Data_Ref_Node):
struct_def, cur_var, _ = self.program.structures.find_definition(self.scope_vars, res[i])

var_type = cur_var.type
else:
var_type = res[i].type

node.parent.specification_part.specifications.append(
ast_internal_classes.Decl_Stmt_Node(vardecl=[
ast_internal_classes.Var_Decl_Node(
name="tmp_arg_" + str(temp),
type='VOID',
sizes=None,
init=None,
)
])
)
newbody.append(
ast_internal_classes.BinOp_Node(op="=",
lval=ast_internal_classes.Name_Node(name="tmp_arg_" +
str(temp),
type=res[i].type),
rval=res[i],
line_number=child.line_number, parent=child.parent))
temp = temp - 1

newbody.append(self.visit(child))

return ast_internal_classes.Execution_Part_Node(execution=newbody)
self.execution_preludes.append([])
for ex in node.execution:
ex = self.visit(ex)
newbody.extend(reversed(self.execution_preludes[-1]))
newbody.append(ex)
self.execution_preludes[-1].clear()
self.execution_preludes.pop()
return ast_internal_classes.Execution_Part_Node(execution = newbody)


class FunctionCallTransformer(NodeTransformer):
Expand Down Expand Up @@ -2816,10 +2762,6 @@ def visit_Data_Ref_Node(self, node: ast_internal_classes.Data_Ref_Node):
return node

def visit_Actual_Arg_Spec_Node(self, node: ast_internal_classes.Actual_Arg_Spec_Node):

if node.type != 'VOID':
return node

node.arg = self.visit(node.arg)

func_arg_name_type = self._get_type(node.arg)
Expand Down
Loading