Skip to content

Commit

Permalink
STASH
Browse files Browse the repository at this point in the history
  • Loading branch information
pratyai committed Jan 18, 2025
1 parent ef76b30 commit d04d39b
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 6 deletions.
35 changes: 32 additions & 3 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from typing import List, Optional, Set, Dict, Tuple, Union
from typing import List, Optional, Set, Dict, Tuple, Union, Any

import networkx as nx
from fparser.common.readfortran import FortranFileReader as ffr, FortranStringReader
Expand Down Expand Up @@ -37,7 +37,7 @@
make_practically_constant_arguments_constants, make_practically_constant_global_vars_constants, \
exploit_locally_constant_variables, assign_globally_unique_variable_names, assign_globally_unique_subprogram_names, \
create_global_initializers, convert_data_statements_into_assignments, make_argument_mapping_explicit
from dace.frontend.fortran.ast_internal_classes import FNode, Main_Program_Node
from dace.frontend.fortran.ast_internal_classes import FNode, Main_Program_Node, Name_Node
from dace.frontend.fortran.ast_utils import children_of_type
from dace.frontend.fortran.intrinsics import IntrinsicSDFGTransformation, NeedsTypeInferenceException
from dace.properties import CodeBlock
Expand Down Expand Up @@ -965,8 +965,12 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node,

# First we need to check if the parameters are literals or variables
for arg_i, variable in enumerate(variables_in_call):
# if 'nlev_var_541' in sdfg.symbols:
# breakpoint()
if isinstance(variable, ast_internal_classes.Actual_Arg_Spec_Node):
keyword, variable = variable.arg_name, variable.arg
if keyword == 'ng_var_339':
breakpoint()

if isinstance(variable, ast_internal_classes.Name_Node):
varname = variable.name
Expand All @@ -986,10 +990,20 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node,
continue

par2.append(parameters[arg_i])
# if varname == 'nlev_var_541':
# breakpoint()
# if 'nlev_var_541' in sdfg.symbols:
# breakpoint()
var2.append(variable)
assert varname not in sdfg.symbols
# assert 'nlev_var_541' not in sdfg.symbols

# This handles the case where the function is called with literals
variables_in_call = var2
# for variable_in_call in variables_in_call:
# if not isinstance(variable_in_call, Name_Node):
# breakpoint()
# assert variable_in_call.name not in sdfg.symbols
parameters = par2
assigns = []
self.local_not_transient_because_assign[my_name_sdfg] = []
Expand All @@ -1004,6 +1018,9 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node,
rval=litval,
op="=",
line_number=node.line_number))
# for variable_in_call in variables_in_call:
# assert variable_in_call.name not in sdfg.symbols
# assert 'nlev_var_541' not in sdfg.symbols
sym_dict = {}
# This handles the case where the function is called with symbols
for parameter, symbol in symbol_arguments:
Expand All @@ -1021,10 +1038,13 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node,
op="=",
line_number=node.line_number))

# assert 'nlev_var_541' not in sdfg.symbols
# This handles the case where the function is called with variables starting with the case that the variable
# is local to the calling SDFG.
# is local to the calling SDFG
needs_replacement = {}
for variable_in_call in variables_in_call:
# assert variable_in_call.name not in sdfg.symbols
# assert 'nlev_var_541' not in sdfg.symbols
local_name = parameters[variables_in_call.index(variable_in_call)]
self.name_mapping[new_sdfg][local_name.name] = new_sdfg._find_new_name(local_name.name)
self.all_array_names.append(self.name_mapping[new_sdfg][local_name.name])
Expand Down Expand Up @@ -1187,6 +1207,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node,
if not self.multiple_sdfgs:
# print("Adding nested sdfg", new_sdfg.name, "to", sdfg.name)
# print(sym_dict)
if 'ng_var_339' in new_sdfg.symbols:
breakpoint()
if node.execution_part is not None:
if node.specification_part is not None and node.specification_part.uses is not None:
for j in node.specification_part.uses:
Expand Down Expand Up @@ -1231,6 +1253,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node,
sym_dict[i] = ast_utils.get_name(var)
memlet_skip.append(ast_utils.get_name(var))

if 'ng_var_339' in new_sdfg.symbols:
breakpoint()
for i in assigns:
self.translate(i, new_sdfg, new_sdfg)
if i.lval.name in new_sdfg.symbols:
Expand Down Expand Up @@ -1431,6 +1455,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node,
# tmp_sdfg=copy.deepcopy(new_sdfg)
new_sdfg.simplify()
new_sdfg.validate()
sdfg.save('/Users/pmz/Downloads/bleh.sdfg')
sdfg.validate()

if self.multiple_sdfgs == True:
Expand Down Expand Up @@ -1529,6 +1554,8 @@ def add_full_object(self, new_sdfg: SDFG,sdfg:SDFG, array: dat.Array, local_name
self.actual_offsets_per_sdfg[sdfg])

else:
if array is None:
breakpoint()
shape= array.shape
offset = array.offset
strides=array.strides
Expand Down Expand Up @@ -1629,6 +1656,8 @@ def process_variable_call(
f"Variable `{var_name}` not found in SDFG {sdfg} or globalSDFG {self.globalsdfg}"
array = self.globalsdfg.arrays[globalsdfg_name]
else:
if sdfg_name not in sdfg.arrays:
breakpoint()
array = sdfg.arrays[sdfg_name]
self.names_of_object_in_parent_sdfg[new_sdfg][local_name.name] = sdfg_name

Expand Down
6 changes: 3 additions & 3 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,9 +750,9 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context
dst = utils.get_global_memlet_path_dst(sdfg, state, edge)
if isinstance(dst, AccessNode):
outputs.add(dst.data)
if len(inputs - outputs) > 0:
raise ValueError(f"Inout connector {conn} is connected to different input ({inputs}) and "
f"output ({outputs}) arrays")
# if len(inputs - outputs) > 0:
# raise ValueError(f"Inout connector {conn} is connected to different input ({inputs}) and "
# f"output ({outputs}) arrays")

# Validate undefined symbols
if self.sdfg:
Expand Down
1 change: 1 addition & 0 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,7 @@ def add_nested_sdfg(
s.symbol_mapping = symbol_mapping
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols:
breakpoint()
raise ValueError('Missing symbols on nested SDFG "%s": %s' % (name, missing_symbols))

# Add new global symbols to nested SDFG
Expand Down
2 changes: 2 additions & 0 deletions dace/transformation/passes/scalar_to_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ def _get_requested_range(self, node: ast.Subscript, memlet_subset: subsets.Subse
j = 0
for i, (start, end, _) in enumerate(memlet_subset.ndrange()):
if start != end:
if i >= min(len(new_tasklet_slice), len(tasklet_slice)):
breakpoint()
new_tasklet_slice[i] = tasklet_slice[j]
j += 1

Expand Down

0 comments on commit d04d39b

Please sign in to comment.