From 156e19311493ab76766c96ae0a056c41bb6f1a3b Mon Sep 17 00:00:00 2001 From: Vincent Raymond Date: Mon, 4 Mar 2024 14:55:30 -0500 Subject: [PATCH] [fortran] Fix for function declarations in derived types (#834) ## Summary of Changes There was a bug discovered in the CTSM source PhotosynthesisMod.F90 where a function is declared in a derived type but defined in the outer module. This PR updates the variable context to support this case. Additionally resolves a bug where function declarations and definitions can have different case sensitivity. ### Related issues Resolves ??? --------- Co-authored-by: titomeister --- .../program_analysis/CAST/fortran/ts2cast.py | 53 ++++++++++--------- .../CAST/fortran/variable_context.py | 35 +++++++++++- 2 files changed, 61 insertions(+), 27 deletions(-) diff --git a/skema/program_analysis/CAST/fortran/ts2cast.py b/skema/program_analysis/CAST/fortran/ts2cast.py index 37629bfe0aa..95b6687ead3 100644 --- a/skema/program_analysis/CAST/fortran/ts2cast.py +++ b/skema/program_analysis/CAST/fortran/ts2cast.py @@ -105,13 +105,7 @@ def run(self, root) -> List[Module]: # TODO: Research the above outer_body_nodes = get_children_by_types(root, ["function", "subroutine"]) if len(outer_body_nodes) > 0: - body = [] - for body_node in outer_body_nodes: - child_cast = self.visit(body_node) - if isinstance(child_cast, List): - body.extend(child_cast) - elif isinstance(child_cast, AstNode): - body.append(child_cast) + body = self.generate_cast_body(outer_body_nodes) modules.append( Module( name=None, @@ -179,14 +173,8 @@ def visit_module(self, node: Node) -> Module: """Visitor for program and module statement. Returns a Module object""" self.variable_context.push_context() - program_body = [] - for child in node.children[1:-1]: # Ignore the start and end program statement - child_cast = self.visit(child) - if isinstance(child_cast, List): - program_body.extend(child_cast) - elif isinstance(child_cast, AstNode): - program_body.append(child_cast) - + program_body = self.generate_cast_body(node.children[1:-1]) + self.variable_context.pop_context() return Module( @@ -231,7 +219,7 @@ def visit_function_def(self, node): # (function_result) - Optional # (identifier) # (body_node) ... - + # Create a new variable context self.variable_context.push_context() @@ -316,6 +304,19 @@ def visit_function_def(self, node): # Pop variable context off of stack before leaving this scope self.variable_context.pop_context() + + # If this is a class function, we need to associate the function def with the class + # We should also return None here so we don't duplicate the function def + if self.variable_context.is_class_function(name.name): + self.variable_context.copy_class_function(name.name, + FunctionDef( + name=name, + func_args=func_args, + body=body, + source_refs=[self.node_helper.get_source_ref(node)], + )) + return None + return FunctionDef( name=name, func_args=func_args, @@ -1020,20 +1021,17 @@ def visit_derived_type(self, node: Node) -> RecordDef: # If we tell the variable context we are in a record definition, it will append the type name as a prefix to all defined variables. self.variable_context.enter_record_definition(record_name) - # Note: + # Note: In derived type declarations, functions are only declared. The actual definition will be in the outer module. funcs = [] - derived_type_procedures_node = get_first_child_by_type( + if derived_type_procedures_node := get_first_child_by_type( node, "derived_type_procedures" - ) - if derived_type_procedures_node: + ): for procedure_node in get_children_by_types( derived_type_procedures_node, ["procedure_statement"] ): - funcs.append( - self.visit_name( - get_first_child_by_type(procedure_node, "method_name") - ) - ) + function_name = self.node_helper.get_identifier(get_first_child_by_type(procedure_node, "method_name", recurse=True)) + funcs.append(self.variable_context.register_module_function(function_name)) + # A derived type can only have variable declarations in its body. fields = [] @@ -1261,12 +1259,14 @@ def get_gromet_function_node(self, func_name: str) -> Name: def generate_cast_body(self, body_nodes: List): body = [] + for node in body_nodes: cast = self.visit(node) + if isinstance(cast, AstNode): body.append(cast) elif isinstance(cast, List): - body.extend(cast) + body.extend([element for element in cast if element is not None]) # Gromet doesn't support empty bodies, so we should create a no_op instead if len(body) == 0: @@ -1274,3 +1274,4 @@ def generate_cast_body(self, body_nodes: List): # TODO: How to add more support for source references return body + diff --git a/skema/program_analysis/CAST/fortran/variable_context.py b/skema/program_analysis/CAST/fortran/variable_context.py index 074f7213cba..8263dbf7e63 100644 --- a/skema/program_analysis/CAST/fortran/variable_context.py +++ b/skema/program_analysis/CAST/fortran/variable_context.py @@ -2,6 +2,7 @@ from skema.program_analysis.CAST2FN.model.cast import ( Var, Name, + FunctionDef ) class VariableContext(object): @@ -27,6 +28,8 @@ def __init__(self): self.stop_condition_id = 0 self.function_name_id = 0 + self.class_functions = {"_class": {"function": FunctionDef()}} + def push_context(self): """Create a new variable context and add it to the stack""" @@ -146,4 +149,34 @@ def set_internal(self): self.internal = True def unset_internal(self): - self.internal = False \ No newline at end of file + self.internal = False + + def register_module_function(self, function: str): + # Fortran variables are case INSENSITIVE so we should lower it first + function = function.lower() + function_def = FunctionDef( + name=Name( + name="", + id=-1, + source_refs=[] + ), + func_args=[], + body=[], + source_refs=[] + ) + self.class_functions[function] = function_def + + return function_def + + + def is_class_function(self, function: str): + function = function.lower() + return function in self.class_functions + + def copy_class_function(self, function: str, function_def: FunctionDef ): + function = function.lower() + self.class_functions[function].name = function_def.name + self.class_functions[function].func_args = function_def.func_args + self.class_functions[function].body = function_def.body + self.class_functions[function].source_refs = function_def.source_refs +