Skip to content

Commit

Permalink
refactor, cover more cases and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pratyai committed Jan 18, 2025
1 parent 14cd6d0 commit ef76b30
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 12 deletions.
93 changes: 81 additions & 12 deletions dace/frontend/fortran/ast_desugaring.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,12 +1681,7 @@ def _poison_variable(v: Base):
assert fnspec
fnspec = ident_spec(alias_map[fnspec])
kv = _make_keyword_mapping_of_call(fcall, alias_map)
if kv is None:
# Cannot do much with intrinsic functions, so everything is poisoned.
args = args.children if args else tuple()
for a in args:
_poison_variable(a)
continue
assert kv is not None

for k, v in kv.items():
aspec = fnspec + (k,)
Expand Down Expand Up @@ -1743,9 +1738,6 @@ def _make_keyword_mapping_of_call(fcall: Union[Function_Reference, Call_Stmt], a
fnspec = search_real_local_alias_spec(fn, alias_map)
assert fnspec
fnstmt = alias_map[fnspec]
if isinstance(fnstmt.parent.parent, Program):
# Global subroutines cannot be keyworded.
return None
fnargs = atmost_one(children_of_type(fnstmt, Dummy_Arg_List))
fnargs = fnargs.children if fnargs else tuple()
assert len(args) <= len(fnargs), f"Cannot pass more arguments({len(args)}) than defined ({len(fnargs)})"
Expand Down Expand Up @@ -1812,7 +1804,6 @@ def make_practically_constant_arguments_constants(ast: Program, keepers: List[SP
fnargs_undecidables: Set[SPEC] = set()
fnargs_optional_presence: Dict[SPEC, Set[bool]] = {}
for fcall in walk(ast, (Function_Reference, Call_Stmt)):
kv = _make_keyword_mapping_of_call(fcall, alias_map)
fn, args = fcall.children
if isinstance(fn, Intrinsic_Name):
# Cannot do anything with intrinsic functions.
Expand All @@ -1826,6 +1817,8 @@ def make_practically_constant_arguments_constants(ast: Program, keepers: List[SP
continue
fnargs = atmost_one(children_of_type(fnstmt, Dummy_Arg_List))
fnargs = fnargs.children if fnargs else tuple()
kv = _make_keyword_mapping_of_call(fcall, alias_map)
assert kv is not None
assert len(kv) <= len(fnargs), f"Cannot pass more arguments({len(kv)}) than defined ({len(fnargs)})"
for a in fnargs:
aspec = search_real_local_alias_spec(a, alias_map)
Expand Down Expand Up @@ -2691,8 +2684,7 @@ def const_eval_nodes(ast: Program) -> Program:
LITERAL_CLASSES, Expr, Add_Operand, Or_Operand, Mult_Operand, Level_2_Expr, Level_3_Expr, Level_4_Expr,
Level_5_Expr, Intrinsic_Function_Reference)
NON_EXPRESSION_CLASSES = (
Explicit_Shape_Spec, Loop_Control, Call_Stmt, Function_Reference, Initialization, Component_Initialization,
Section_Subscript_List)
Explicit_Shape_Spec, Loop_Control, Initialization, Component_Initialization, Section_Subscript_List)

alias_map = alias_specs(ast)

Expand Down Expand Up @@ -2725,6 +2717,35 @@ def _const_eval_node(n: Base) -> bool:
for nm in reversed(walk(node, Name)):
_const_eval_node(nm)

for fcall in reversed(walk(ast, (Call_Stmt, Function_Reference))):
fn, args = fcall.children
if isinstance(fn, Intrinsic_Name):
# Cannot do anything with intrinsic functions.
continue
kv = _make_keyword_mapping_of_call(fcall, alias_map)
assert kv is not None
fnspec = search_real_local_alias_spec(fn, alias_map)
assert fnspec
fnstmt = alias_map[fnspec]
fnargs = atmost_one(children_of_type(fnstmt, Dummy_Arg_List))
fnargs = fnargs.children if fnargs else tuple()
assert len(kv) <= len(fnargs), f"Cannot pass more arguments({len(kv)}) than defined ({len(fnargs)})"
for a in fnargs:
aspec = search_real_local_alias_spec(a, alias_map)
assert aspec
adecl = alias_map[aspec]
atype = find_type_of_entity(adecl, alias_map)
assert atype
if aspec[-1] not in kv:
continue
v = kv[aspec[-1]]
if atype.out and isinstance(v, (Name, Part_Ref, Data_Ref)):
# TODO: This should not happen in the first place. But after some pruning the problematic calls go away.
# So we ignore it for now, but should be revisted later.
# We're passing a writeable object.
continue
_const_eval_node(v)

return ast


Expand Down Expand Up @@ -2833,6 +2854,49 @@ def inject_const_evals(ast: Program,
inject_consts = inject_consts or []
alias_map = alias_specs(ast)

def _can_inject_in_function_argument(v: Base) -> bool:
"""
Determine whether we are in a function argument, and is it allowed to inject in there.
"""
var = v
while var.parent and not isinstance(var.parent, (Actual_Arg_Spec, Actual_Arg_Spec_List)):
var = var.parent
if not isinstance(var.parent, (Actual_Arg_Spec, Actual_Arg_Spec_List)):
# Not a function argument anyway.
return True
if not isinstance(var, (Name, Data_Ref)):
# Not a writeable object anyway.
return True
fcall = var.parent
while fcall and not isinstance(fcall, (Call_Stmt, Function_Reference, Intrinsic_Function_Reference)):
fcall = fcall.parent
if not fcall:
# Not a function argument anyway.
return True
fn, args = fcall.children
if isinstance(fn, Intrinsic_Name):
# Cannot do anything with intrinsic functions.
return False
fnspec = search_real_local_alias_spec(fn, alias_map)
assert fnspec
fnspec = ident_spec(alias_map[fnspec])
kv = _make_keyword_mapping_of_call(fcall, alias_map)
assert kv is not None

for k, v in kv.items():
if v is not var:
continue
aspec = fnspec + (k,)
assert aspec in alias_map
adecl = alias_map[aspec]
atype = find_type_of_entity(adecl, alias_map)
assert atype

# TODO: This should not happen in the first place. But after some pruning the problematic calls go
# away. So we ignore it for now, but should be revisted later.
return not atype.out
return True

TOPLEVEL_SPEC = ('*',)

items_by_scopes = {}
Expand Down Expand Up @@ -2891,6 +2955,8 @@ def inject_const_evals(ast: Program,
lv, _, _ = dr.parent.children
if lv is dr:
continue
if not _can_inject_in_function_argument(dr):
continue
item = _find_matching_item(items, dr, alias_map)
if not item:
continue
Expand All @@ -2902,6 +2968,9 @@ def inject_const_evals(ast: Program,
# We don't want to replace the values in their declarations or imports, but only where their
# values are being used.
continue
if not _can_inject_in_function_argument(nm):
continue

loc = search_real_local_alias_spec(nm, alias_map)
if not loc or not isinstance(alias_map[loc], Entity_Decl):
continue
Expand Down
24 changes: 24 additions & 0 deletions tests/fortran/ast_desugaring_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,11 @@ def test_config_injection_type():
type(config), intent(inout) :: this
this%b = 5.1
end subroutine fun
subroutine update(x)
implicit none
real, intent(inout) :: x
x = 1.1
end subroutine update
end module lib
""").add_file("""
subroutine main(cfg)
Expand All @@ -1539,6 +1544,7 @@ def test_config_injection_type():
type(big_config), intent(in) :: cfg
real :: a = 1
a = cfg%big%b + a * globalo%a
call update(globalo%b)
end subroutine main
""").check_with_gfortran().get()
ast = parse_and_improve(sources)
Expand Down Expand Up @@ -1566,13 +1572,19 @@ def test_config_injection_type():
TYPE(config), INTENT(INOUT) :: this
this % b = 5.1
END SUBROUTINE fun
SUBROUTINE update(x)
IMPLICIT NONE
REAL, INTENT(INOUT) :: x
x = 1.1
END SUBROUTINE update
END MODULE lib
SUBROUTINE main(cfg)
USE lib
IMPLICIT NONE
TYPE(big_config), INTENT(IN) :: cfg
REAL :: a = 1
a = 10000.0 + a * 42
CALL update(globalo % b)
END SUBROUTINE main
""".strip()
assert got == want
Expand Down Expand Up @@ -1833,13 +1845,19 @@ def test_practically_constant_global_vars_constants():
logical, intent(out) :: what
what = .true.
end subroutine update
subroutine noop(what)
implicit none
logical, intent(in) :: what
end subroutine noop
end module lib
""").add_file("""
subroutine main
use lib
implicit none
real :: a = 1.0
call update(movable_cond)
call noop(fixed_cond)
call noop(what = fixed_cond)
movable_cond = .not. movable_cond
if (fixed_cond .and. movable_cond) a = 7.1
end subroutine main
Expand All @@ -1859,12 +1877,18 @@ def test_practically_constant_global_vars_constants():
LOGICAL, INTENT(OUT) :: what
what = .TRUE.
END SUBROUTINE update
SUBROUTINE noop(what)
IMPLICIT NONE
LOGICAL, INTENT(IN) :: what
END SUBROUTINE noop
END MODULE lib
SUBROUTINE main
USE lib
IMPLICIT NONE
REAL :: a = 1.0
CALL update(movable_cond)
CALL noop(fixed_cond)
CALL noop(what = fixed_cond)
movable_cond = .NOT. movable_cond
IF (fixed_cond .AND. movable_cond) a = 7.1
END SUBROUTINE main
Expand Down

0 comments on commit ef76b30

Please sign in to comment.