Skip to content

Commit

Permalink
Merge branch 'vraymond/m12_documentation' of https://github.com/ml4ai…
Browse files Browse the repository at this point in the history
…/skema into vraymond/m12_documentation
  • Loading branch information
vincentraymond-ua committed Mar 26, 2024
2 parents 3d7195d + 2aff436 commit b5e351d
Show file tree
Hide file tree
Showing 17 changed files with 1,025 additions and 57 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies=[
"dill==0.3.7",
"networkx==2.8.8",
"PyYAML",
"tree-sitter",
"tree-sitter==0.20.4",
"neo4j==5.14.1",
"requests",
"beautifulsoup4", # used to remove comments etc from pMML before sending to MORAE
Expand All @@ -43,7 +43,7 @@ dynamic = ["readme"]
# Pygraphviz is often tricky to install, so we reserve it for the dev extras
# list.
# - six: Required by auto-generated Swagger models
dev = ["pytest", "pytest-cov", "pytest-xdist", "pytest-asyncio", "black", "mypy", "coverage", "pygraphviz", "six"]
dev = ["pytest", "pytest-cov", "pytest-xdist", "pytest-asyncio", "pytest-mock", "black", "mypy", "coverage", "pygraphviz", "six"]

demo = ["notebook"]

Expand Down
2 changes: 2 additions & 0 deletions skema/program_analysis/CAST/python/node_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
",",
"=",
"==",
"[",
"]",
"(",
")",
":",
Expand Down
116 changes: 112 additions & 4 deletions skema/program_analysis/CAST/python/ts2cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def visit(self, node: Node):
return self.visit_dict_comprehension(node)
elif node.type == "lambda":
return self.visit_lambda(node)
elif node.type == "subscript":
return self.visit_subscript(node)
elif node.type == "slice":
return self.visit_slice(node)
elif node.type == "pair":
return self.visit_pair(node)
elif node.type == "while_statement":
Expand All @@ -170,6 +174,8 @@ def visit(self, node: Node):
return self.visit_import(node)
elif node.type == "import_from_statement":
return self.visit_import_from(node)
elif node.type == "class_definition":
return self.visit_class_definition(node)
elif node.type == "yield":
return self.visit_yield(node)
elif node.type == "assert_statement":
Expand Down Expand Up @@ -259,7 +265,7 @@ def visit_return(self, node: Node) -> ModelReturn:
ret_val = node.children[1]
ret_cast = self.visit(ret_val)

return ModelReturn(value=ret_cast, source_refs=[ref])
return ModelReturn(value=get_operand_node(ret_cast), source_refs=[ref])

def visit_call(self, node: Node) -> Call:
ref = self.node_helper.get_source_ref(node)
Expand Down Expand Up @@ -432,8 +438,8 @@ def visit_binary_op(self, node: Node) -> Operator:
op = get_op(self.node_helper.get_operator(node.children[1]))
left, _, right = node.children

left_cast = get_name_node(self.visit(left))
right_cast = get_name_node(self.visit(right))
left_cast = get_operand_node(self.visit(left))
right_cast = get_operand_node(self.visit(right))

return Operator(
op=op,
Expand Down Expand Up @@ -592,6 +598,61 @@ def visit_attribute(self, node: Node):

return Attribute(value= get_name_node(obj_cast), attr=get_name_node(attr_cast), source_refs=ref)

def visit_subscript(self, node: Node):
ref = self.node_helper.get_source_ref(node)
values = get_non_control_children(node)
name_cast = self.visit(values[0])
subscript_list = values[1:]
subscript_casts = []
for subscript in subscript_list:
cast = self.visit(subscript)
if isinstance(cast, list):
for elem in cast:
subscript_casts.append(get_func_name_node(cast))
else:
subscript_casts.append(get_func_name_node(cast))

get_func = self.get_gromet_function_node("_get")

get_call = Call(
func = get_func,
arguments = [get_func_name_node(name_cast)] + subscript_casts,
source_refs=ref
)

return get_call

def visit_slice(self, node: Node):
ref = self.node_helper.get_source_ref(node)
indices = get_non_control_children(node)
index_cast = []

for index in indices:
cast = self.visit(index)
if isinstance(cast ,list):
index_cast.extend(cast)
else:
index_cast.append(cast)

start = index_cast[0]
end = index_cast[1]
if len(index_cast) == 3:
step = index_cast[2]
else:
step = CASTLiteralValue(
value_type=ScalarType.INTEGER,
value="1",
source_code_data_type=["Python", "3.8", "Float"],
source_refs=ref,
)

return CASTLiteralValue(value_type=StructureType.LIST,
value=[start,end,step],
source_code_data_type=["Python", "3.8", "List"],
source_refs=ref
)


def handle_for_clause(self, node: Node):
# Given the "for x in seq" clause of a list comprehension
# we translate it to a CAST for loop, leaving the actual
Expand Down Expand Up @@ -973,6 +1034,43 @@ def visit_for(self, node: Node) -> Loop:
source_refs = ref
)

def retrieve_init_func(self, functions: List[FunctionDef]):
# Given a list of CAST function defs, we
# attempt to retrieve the CAST function that corresponds to
# "__init__"
for func in functions:
if func.name.name == "__init__":
return func
return None

def retrieve_class_attrs(self, init_func: FunctionDef):
attrs = []
for stmt in init_func.body:
if isinstance(stmt, Assignment):
if isinstance(stmt.left, Attribute):
if stmt.left.value.name == "self":
attrs.append(stmt.left.attr)

return attrs

def visit_class_definition(self, node):
class_name_node = get_first_child_by_type(node, "identifier")
class_cast = self.visit(class_name_node)

function_defs = get_children_by_types(get_children_by_types(node, "block")[0], "function_definition")
func_defs_cast = []
for func in function_defs:
func_cast = self.visit(func)
if isinstance(func_cast, List):
func_defs_cast.extend(func_cast)
else:
func_defs_cast.append(func_cast)

init_func = self.retrieve_init_func(func_defs_cast)
attributes = self.retrieve_class_attrs(init_func)

return RecordDef(name=get_name_node(class_cast).name, bases=[], funcs=func_defs_cast, fields=attributes)


def visit_name(self, node):
# First, we will check if this name is already defined, and if it is return the name node generated previously
Expand Down Expand Up @@ -1039,7 +1137,7 @@ def get_name_node(node):
if isinstance(cur_node, Var):
return cur_node.val
else:
return node
return cur_node

def get_func_name_node(node):
# Given a CAST node, we attempt to extract the appropriate name element
Expand All @@ -1049,3 +1147,13 @@ def get_func_name_node(node):
return cur_node.val
else:
return cur_node

def get_operand_node(node):
# Given a CAST/list node, we extract the appropriate operand for the operator from it
cur_node = node
if isinstance(node, list):
cur_node = node[0]
if isinstance(cur_node, Var):
return cur_node.val
else:
return cur_node
8 changes: 7 additions & 1 deletion skema/program_analysis/module_locate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
def identify_source_type(source: str):
if not source:
return "Unknown"
if "github" in source:
elif source == "https://github.com/python/cpython":
return "Compiled"
elif "github" in source:
return "Repository"
elif source.startswith("http"):
return "Url"
Expand Down Expand Up @@ -53,6 +55,10 @@ def module_locate(module_name: str) -> str:
:return: The module's file path, GitHub URL, or tarball URL.
"""

# Check if module is compiled into Python
if module_name in sys.builtin_module_names:
return "https://github.com/python/cpython"

# Attempt to find the module in the local environment
try:
module_obj = importlib.import_module(module_name)
Expand Down
6 changes: 3 additions & 3 deletions skema/program_analysis/tests/test_function_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,10 @@ def test_fun4():

func_def_body = func_def_node.body[2]
assert isinstance(func_def_body, ModelReturn)
assert isinstance(func_def_body.value, Var)
assert isinstance(func_def_body.value, Name)

assert func_def_body.value.val.name == "z"
assert func_def_body.value.val.id == 4
assert func_def_body.value.name == "z"
assert func_def_body.value.id == 4

#######################################################
func_asg_node = fun_cast.nodes[0].body[1]
Expand Down
Loading

0 comments on commit b5e351d

Please sign in to comment.