diff --git a/pyproject.toml b/pyproject.toml index 20028ab88d8..4e3ef3ce8ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ dependencies=[ "dill==0.3.7", "networkx==2.8.8", "PyYAML", - "tree-sitter", + "tree-sitter==0.20.4", "neo4j==5.14.1", "requests", "beautifulsoup4", # used to remove comments etc from pMML before sending to MORAE @@ -43,7 +43,7 @@ dynamic = ["readme"] # Pygraphviz is often tricky to install, so we reserve it for the dev extras # list. # - six: Required by auto-generated Swagger models -dev = ["pytest", "pytest-cov", "pytest-xdist", "pytest-asyncio", "black", "mypy", "coverage", "pygraphviz", "six"] +dev = ["pytest", "pytest-cov", "pytest-xdist", "pytest-asyncio", "pytest-mock", "black", "mypy", "coverage", "pygraphviz", "six"] demo = ["notebook"] diff --git a/skema/program_analysis/CAST/python/node_helper.py b/skema/program_analysis/CAST/python/node_helper.py index 4f6dcefd3ac..035fc2e0ccb 100644 --- a/skema/program_analysis/CAST/python/node_helper.py +++ b/skema/program_analysis/CAST/python/node_helper.py @@ -8,6 +8,8 @@ ",", "=", "==", + "[", + "]", "(", ")", ":", diff --git a/skema/program_analysis/CAST/python/ts2cast.py b/skema/program_analysis/CAST/python/ts2cast.py index cfe01f34b98..bc65d3b7142 100644 --- a/skema/program_analysis/CAST/python/ts2cast.py +++ b/skema/program_analysis/CAST/python/ts2cast.py @@ -160,6 +160,10 @@ def visit(self, node: Node): return self.visit_dict_comprehension(node) elif node.type == "lambda": return self.visit_lambda(node) + elif node.type == "subscript": + return self.visit_subscript(node) + elif node.type == "slice": + return self.visit_slice(node) elif node.type == "pair": return self.visit_pair(node) elif node.type == "while_statement": @@ -170,6 +174,8 @@ def visit(self, node: Node): return self.visit_import(node) elif node.type == "import_from_statement": return self.visit_import_from(node) + elif node.type == "class_definition": + return self.visit_class_definition(node) elif node.type == "yield": return self.visit_yield(node) elif node.type == "assert_statement": @@ -259,7 +265,7 @@ def visit_return(self, node: Node) -> ModelReturn: ret_val = node.children[1] ret_cast = self.visit(ret_val) - return ModelReturn(value=ret_cast, source_refs=[ref]) + return ModelReturn(value=get_operand_node(ret_cast), source_refs=[ref]) def visit_call(self, node: Node) -> Call: ref = self.node_helper.get_source_ref(node) @@ -432,8 +438,8 @@ def visit_binary_op(self, node: Node) -> Operator: op = get_op(self.node_helper.get_operator(node.children[1])) left, _, right = node.children - left_cast = get_name_node(self.visit(left)) - right_cast = get_name_node(self.visit(right)) + left_cast = get_operand_node(self.visit(left)) + right_cast = get_operand_node(self.visit(right)) return Operator( op=op, @@ -592,6 +598,61 @@ def visit_attribute(self, node: Node): return Attribute(value= get_name_node(obj_cast), attr=get_name_node(attr_cast), source_refs=ref) + def visit_subscript(self, node: Node): + ref = self.node_helper.get_source_ref(node) + values = get_non_control_children(node) + name_cast = self.visit(values[0]) + subscript_list = values[1:] + subscript_casts = [] + for subscript in subscript_list: + cast = self.visit(subscript) + if isinstance(cast, list): + for elem in cast: + subscript_casts.append(get_func_name_node(cast)) + else: + subscript_casts.append(get_func_name_node(cast)) + + get_func = self.get_gromet_function_node("_get") + + get_call = Call( + func = get_func, + arguments = [get_func_name_node(name_cast)] + subscript_casts, + source_refs=ref + ) + + return get_call + + def visit_slice(self, node: Node): + ref = self.node_helper.get_source_ref(node) + indices = get_non_control_children(node) + index_cast = [] + + for index in indices: + cast = self.visit(index) + if isinstance(cast ,list): + index_cast.extend(cast) + else: + index_cast.append(cast) + + start = index_cast[0] + end = index_cast[1] + if len(index_cast) == 3: + step = index_cast[2] + else: + step = CASTLiteralValue( + value_type=ScalarType.INTEGER, + value="1", + source_code_data_type=["Python", "3.8", "Float"], + source_refs=ref, + ) + + return CASTLiteralValue(value_type=StructureType.LIST, + value=[start,end,step], + source_code_data_type=["Python", "3.8", "List"], + source_refs=ref + ) + + def handle_for_clause(self, node: Node): # Given the "for x in seq" clause of a list comprehension # we translate it to a CAST for loop, leaving the actual @@ -973,6 +1034,43 @@ def visit_for(self, node: Node) -> Loop: source_refs = ref ) + def retrieve_init_func(self, functions: List[FunctionDef]): + # Given a list of CAST function defs, we + # attempt to retrieve the CAST function that corresponds to + # "__init__" + for func in functions: + if func.name.name == "__init__": + return func + return None + + def retrieve_class_attrs(self, init_func: FunctionDef): + attrs = [] + for stmt in init_func.body: + if isinstance(stmt, Assignment): + if isinstance(stmt.left, Attribute): + if stmt.left.value.name == "self": + attrs.append(stmt.left.attr) + + return attrs + + def visit_class_definition(self, node): + class_name_node = get_first_child_by_type(node, "identifier") + class_cast = self.visit(class_name_node) + + function_defs = get_children_by_types(get_children_by_types(node, "block")[0], "function_definition") + func_defs_cast = [] + for func in function_defs: + func_cast = self.visit(func) + if isinstance(func_cast, List): + func_defs_cast.extend(func_cast) + else: + func_defs_cast.append(func_cast) + + init_func = self.retrieve_init_func(func_defs_cast) + attributes = self.retrieve_class_attrs(init_func) + + return RecordDef(name=get_name_node(class_cast).name, bases=[], funcs=func_defs_cast, fields=attributes) + def visit_name(self, node): # First, we will check if this name is already defined, and if it is return the name node generated previously @@ -1039,7 +1137,7 @@ def get_name_node(node): if isinstance(cur_node, Var): return cur_node.val else: - return node + return cur_node def get_func_name_node(node): # Given a CAST node, we attempt to extract the appropriate name element @@ -1049,3 +1147,13 @@ def get_func_name_node(node): return cur_node.val else: return cur_node + +def get_operand_node(node): + # Given a CAST/list node, we extract the appropriate operand for the operator from it + cur_node = node + if isinstance(node, list): + cur_node = node[0] + if isinstance(cur_node, Var): + return cur_node.val + else: + return cur_node diff --git a/skema/program_analysis/module_locate.py b/skema/program_analysis/module_locate.py index c9a43e9919d..2a2201a121c 100644 --- a/skema/program_analysis/module_locate.py +++ b/skema/program_analysis/module_locate.py @@ -16,7 +16,9 @@ def identify_source_type(source: str): if not source: return "Unknown" - if "github" in source: + elif source == "https://github.com/python/cpython": + return "Compiled" + elif "github" in source: return "Repository" elif source.startswith("http"): return "Url" @@ -53,6 +55,10 @@ def module_locate(module_name: str) -> str: :return: The module's file path, GitHub URL, or tarball URL. """ + # Check if module is compiled into Python + if module_name in sys.builtin_module_names: + return "https://github.com/python/cpython" + # Attempt to find the module in the local environment try: module_obj = importlib.import_module(module_name) diff --git a/skema/program_analysis/tests/test_function_cast.py b/skema/program_analysis/tests/test_function_cast.py index 1b53f4bd57f..a5a5db78cd2 100644 --- a/skema/program_analysis/tests/test_function_cast.py +++ b/skema/program_analysis/tests/test_function_cast.py @@ -266,10 +266,10 @@ def test_fun4(): func_def_body = func_def_node.body[2] assert isinstance(func_def_body, ModelReturn) - assert isinstance(func_def_body.value, Var) + assert isinstance(func_def_body.value, Name) - assert func_def_body.value.val.name == "z" - assert func_def_body.value.val.id == 4 + assert func_def_body.value.name == "z" + assert func_def_body.value.id == 4 ####################################################### func_asg_node = fun_cast.nodes[0].body[1] diff --git a/skema/program_analysis/tests/test_list_proc_cast.py b/skema/program_analysis/tests/test_list_proc_cast.py new file mode 100644 index 00000000000..e5754158ddb --- /dev/null +++ b/skema/program_analysis/tests/test_list_proc_cast.py @@ -0,0 +1,274 @@ +# import json NOTE: json and Path aren't used right now, +# from pathlib import Path but will be used in the future +from skema.program_analysis.CAST.python.ts2cast import TS2CAST +from skema.program_analysis.CAST2FN.model.cast import ( + Assignment, + FunctionDef, + ModelReturn, + Var, + Call, + Name, + CASTLiteralValue, +) + +def list1(): + return """ +x = [1,2,3] +y = x[0] + """ + +def list2(): + return """ +x = [1,[2,3]] +y = x[1][0] + """ + +def list3(): + return """ +x = [1,2,3,4,5] +y = x[0:3] +z = x[0:3:2] + """ + +def list4(): + return """ +def foo(): + return 2 + +x = [1,2,3,4,5] +y = x[0:foo()] + """ + + +def generate_cast(test_file_string): + # use Python to CAST + out_cast = TS2CAST(test_file_string, from_file=False).out_cast + + return out_cast + +def test_list1(): + cast = generate_cast(list1()) + + asg_node = cast.nodes[0].body[0] + index_node = cast.nodes[0].body[1] + + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "x" + assert asg_node.left.val.id == 0 + + assert isinstance(asg_node.right, CASTLiteralValue) + assert asg_node.right.value_type == "List" + + assert isinstance(index_node, Assignment) + assert isinstance(index_node.left, Var) + assert isinstance(index_node.left.val, Name) + assert index_node.left.val.name == "y" + assert index_node.left.val.id == 2 + + index_call = index_node.right + assert isinstance(index_call, Call) + assert isinstance(index_call.func, Name) + assert index_call.func.name == "_get" + assert index_call.func.id == 1 + assert len(index_call.arguments) == 2 + + assert isinstance(index_call.arguments[0], Name), f"is type{type(index_call.arguments[0])}" + assert index_call.arguments[0].name == "x" + assert index_call.arguments[0].id == 0 + + assert isinstance(index_call.arguments[1], CASTLiteralValue) + assert index_call.arguments[1].value_type == "Integer" + assert index_call.arguments[1].value == "0" + + +def test_list2(): + cast = generate_cast(list2()) + + asg_node = cast.nodes[0].body[0] + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "x" + assert asg_node.left.val.id == 0 + + assert isinstance(asg_node.right, CASTLiteralValue) + assert asg_node.right.value_type == "List" + second_list = asg_node.right.value[1] + + assert isinstance(second_list, CASTLiteralValue) + assert second_list.value_type == "List" + assert isinstance(second_list.value[0], CASTLiteralValue) + assert second_list.value[0].value == "2" + + assert isinstance(second_list.value[1], CASTLiteralValue) + assert second_list.value[1].value == "3" + + index_node = cast.nodes[0].body[1] + assert isinstance(index_node, Assignment) + assert isinstance(index_node.left, Var) + assert isinstance(index_node.left.val, Name) + assert index_node.left.val.name == "y" + assert index_node.left.val.id == 2 + + index_call = index_node.right + assert isinstance(index_call, Call) + assert isinstance(index_call.func, Name) + assert index_call.func.name == "_get" + assert index_call.func.id == 1 + assert len(index_call.arguments) == 2 + + arg_call = index_call.arguments[0] + assert isinstance(arg_call, Call), f"is type{type(index_call.arguments[0])}" + assert arg_call.func.name == "_get" + assert arg_call.func.id == 1 + + assert len(arg_call.arguments) == 2 + assert isinstance(arg_call.arguments[0], Name) + assert arg_call.arguments[0].name == "x" + assert arg_call.arguments[0].id == 0 + + assert isinstance(arg_call.arguments[1], CASTLiteralValue) + assert arg_call.arguments[1].value == "1" + + assert isinstance(index_call.arguments[1], CASTLiteralValue) + assert index_call.arguments[1].value_type == "Integer" + assert index_call.arguments[1].value == "0" + + +def test_list3(): + cast = generate_cast(list3()) + + asg_node = cast.nodes[0].body[0] + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "x" + assert asg_node.left.val.id == 0 + + assert isinstance(asg_node.right, CASTLiteralValue) + assert asg_node.right.value_type == "List" + assert len(asg_node.right.value) == 5 + + index_node = cast.nodes[0].body[1] + assert isinstance(index_node, Assignment) + assert isinstance(index_node.left, Var) + assert isinstance(index_node.left.val, Name) + assert index_node.left.val.name == "y" + assert index_node.left.val.id == 2 + + index_call = index_node.right + assert isinstance(index_call, Call) + assert isinstance(index_call.func, Name) + assert index_call.func.name == "_get" + assert index_call.func.id == 1 + assert len(index_call.arguments) == 2 + + slice1 = index_call.arguments[0] + assert isinstance(slice1, Name) + assert slice1.name == "x" + assert slice1.id == 0 + + slice2 = index_call.arguments[1] + assert isinstance(slice2, CASTLiteralValue) + assert slice2.value_type == "List" + assert len(slice2.value) == 3 + + assert isinstance(slice2.value[0], CASTLiteralValue) + assert slice2.value[0].value == "0" + + assert isinstance(slice2.value[1], CASTLiteralValue) + assert slice2.value[1].value == "3" + + assert isinstance(slice2.value[2], CASTLiteralValue) + assert slice2.value[2].value == "1" + + second_idx = cast.nodes[0].body[2] + assert isinstance(second_idx, Assignment) + assert isinstance(second_idx.left, Var) + assert second_idx.left.val.name == "z" + assert second_idx.left.val.id == 3 + + second_call = second_idx.right + assert isinstance(second_call, Call) + assert isinstance(second_call.func, Name) + assert second_call.func.name == "_get" + assert second_call.func.id == 1 + + second_args = second_call.arguments + assert len(second_args) == 2 + assert isinstance(second_args[0], Name) + assert second_args[0].name == "x" + assert second_args[0].id == 0 + + idx_args = second_args[1] + assert isinstance(idx_args, CASTLiteralValue) + assert idx_args.value_type == "List" + assert len(idx_args.value) == 3 + + assert isinstance(idx_args.value[0], CASTLiteralValue) + assert idx_args.value[0].value == "0" + + assert isinstance(idx_args.value[1], CASTLiteralValue) + assert idx_args.value[1].value == "3" + + assert isinstance(idx_args.value[2], CASTLiteralValue) + assert idx_args.value[2].value == "2" + + +def test_list4(): + cast = generate_cast(list4()) + + func_def_node = cast.nodes[0].body[0] + assert isinstance(func_def_node, FunctionDef) + assert func_def_node.name.name == "foo" + assert func_def_node.name.id == 0 + assert isinstance(func_def_node.body[0], ModelReturn) + assert isinstance(func_def_node.body[0].value, CASTLiteralValue) + assert func_def_node.body[0].value.value == "2" + + asg_node = cast.nodes[0].body[1] + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "x" + assert asg_node.left.val.id == 1 + + assert isinstance(asg_node.right, CASTLiteralValue) + assert asg_node.right.value_type == "List" + assert len(asg_node.right.value) == 5 + + index_node = cast.nodes[0].body[2] + assert isinstance(index_node, Assignment) + assert isinstance(index_node.left, Var) + assert isinstance(index_node.left.val, Name) + assert index_node.left.val.name == "y" + assert index_node.left.val.id == 3 + + index_call = index_node.right + assert isinstance(index_call, Call) + assert isinstance(index_call.func, Name) + assert index_call.func.name == "_get" + assert index_call.func.id == 2 + assert len(index_call.arguments) == 2 + + slice1 = index_call.arguments[0] + assert isinstance(slice1, Name) + assert slice1.name == "x" + assert slice1.id == 1 + + slice2 = index_call.arguments[1] + assert isinstance(slice2, CASTLiteralValue) + assert slice2.value_type == "List" + assert len(slice2.value) == 3 + + assert isinstance(slice2.value[0], CASTLiteralValue) + assert slice2.value[0].value == "0" + + assert isinstance(slice2.value[1], Call) + assert slice2.value[1].func.name == "foo" + assert slice2.value[1].func.id == 0 + + assert isinstance(slice2.value[2], CASTLiteralValue) + assert slice2.value[2].value == "1" diff --git a/skema/program_analysis/tests/test_module_locate.py b/skema/program_analysis/tests/test_module_locate.py new file mode 100644 index 00000000000..c95b4034ecc --- /dev/null +++ b/skema/program_analysis/tests/test_module_locate.py @@ -0,0 +1,41 @@ +import pytest +from unittest.mock import patch +from skema.program_analysis.module_locate import identify_source_type, module_locate + +# Testing identify_source_type +@pytest.mark.parametrize("source,expected_type", [ + ("https://github.com/python/cpython", "Compiled"), + ("https://github.com/other/repository", "Repository"), + ("http://example.com", "Url"), + ("local/path/to/module", "Local"), + ("", "Unknown"), +]) +def test_identify_source_type(source, expected_type): + assert identify_source_type(source) == expected_type + +# Mocking requests.get to test module_locate without actual HTTP requests +@pytest.fixture +def mock_requests_get(mocker): + mock = mocker.patch('skema.program_analysis.module_locate.requests.get') + return mock + +def test_module_locate_builtin_module(): + assert module_locate("sys") == "https://github.com/python/cpython" + +def test_module_locate_from_pypi_with_github_source(mock_requests_get): + mock_requests_get.return_value.json.return_value = { + 'info': {'version': '1.0.0', 'project_urls': {'Source': 'https://github.com/example/project'}}, + 'releases': {'1.0.0': [{'filename': 'example-1.0.0.tar.gz', 'url': 'https://example.com/example-1.0.0.tar.gz'}]} + } + assert module_locate("example") == "https://github.com/example/project" + +def test_module_locate_from_pypi_with_tarball_url(mock_requests_get): + mock_requests_get.return_value.json.return_value = { + 'info': {'version': '1.2.3'}, + 'releases': {'1.2.3': [{'filename': 'package-1.2.3.tar.gz', 'url': 'https://pypi.org/package-1.2.3.tar.gz'}]} + } + assert module_locate("package") == "https://pypi.org/package-1.2.3.tar.gz" + +def test_module_locate_not_found(mock_requests_get): + mock_requests_get.side_effect = Exception("Module not found") + assert module_locate("nonexistent") is None diff --git a/skema/program_analysis/tests/test_record_cast.py b/skema/program_analysis/tests/test_record_cast.py new file mode 100644 index 00000000000..f8680910658 --- /dev/null +++ b/skema/program_analysis/tests/test_record_cast.py @@ -0,0 +1,481 @@ +# import json NOTE: json and Path aren't used right now, +# from pathlib import Path but will be used in the future +from skema.program_analysis.CAST.python.ts2cast import TS2CAST +from skema.program_analysis.CAST2FN.model.cast import ( + Assignment, + Attribute, + Var, + Name, + FunctionDef, + Call, + CASTLiteralValue, + RecordDef, + Operator, + ModelReturn +) + +def class1(): + return """ +class MyClass: + def __init__(self, a: int, b: int): + self.a = a + self.b = b + self.c = a + b + + def get_c(self): + return self.c + +mc = MyClass(2, 3) +x = mc.get_c() + """ + +def class2(): + return """ +class MyClass1: + def __init__(self, a: int): + self.a = a + + def get_a(self): + return self.a + +class MyClass2(MyClass1): + def __init__(self, b: int): + self.b = b + super().__init__(b + 1) + + def get_b(self): + return self.b + + +mc = MyClass2(2) +x = mc.get_a() +y = mc.get_b() + + """ + +def class3(): + return """ +class Foo: + def __init__(self, a): + self.a = a + + def add1(self): + self.a = self.a + 1 + + def sub1(self): + self.a = self.a - 1 + +foo = Foo() + +foo.add1().add1().sub1() + """ + +def generate_cast(test_file_string): + # use Python to CAST + out_cast = TS2CAST(test_file_string, from_file=False).out_cast + + return out_cast + +def test_class1(): + fun_cast = generate_cast(class1()) + record_def_node = fun_cast.nodes[0].body[0] + assert isinstance(record_def_node, RecordDef) + assert record_def_node.name == "MyClass" + assert len(record_def_node.bases) == 0 + assert len(record_def_node.fields) == 3 + fields = record_def_node.fields + assert isinstance(fields[0], Name) + assert fields[0].name == "a" + assert fields[0].id == 3 + + assert isinstance(fields[1], Name) + assert fields[1].name == "b" + assert fields[1].id == 4 + + assert isinstance(fields[2], Name) + assert fields[2].name == "c" + assert fields[2].id == 5 + + assert len(record_def_node.funcs) == 2 + init_func = record_def_node.funcs[0] + assert isinstance(init_func, FunctionDef) + assert isinstance(init_func.name, Name) + assert init_func.name.name == "__init__" + assert init_func.name.id == 1 + + func_args = init_func.func_args + assert len(func_args) == 3 + assert isinstance(func_args[0], Var) + assert func_args[0].val.name == "self" + assert func_args[0].val.id == 2 + + assert isinstance(func_args[1], Var) + assert func_args[1].val.name == "a" + assert func_args[1].val.id == 3 + + assert isinstance(func_args[2], Var) + assert func_args[2].val.name == "b" + assert func_args[2].val.id == 4 + + func_body = init_func.body + assert len(func_body) == 3 + assert isinstance(func_body[0], Assignment) + assert isinstance(func_body[0].left, Attribute) + assert func_body[0].left.value.name == "self" + assert func_body[0].left.attr.name == "a" + + assert isinstance(func_body[0].right, Name) + assert func_body[0].right.name == "a" + + assert isinstance(func_body[1], Assignment) + assert isinstance(func_body[1].left, Attribute) + assert func_body[1].left.value.name == "self" + assert func_body[1].left.attr.name == "b" + + assert isinstance(func_body[1].right, Name) + assert func_body[1].right.name == "b" + + assert isinstance(func_body[2], Assignment) + assert isinstance(func_body[2].left, Attribute) + assert func_body[2].left.value.name == "self" + assert func_body[2].left.attr.name == "c" + + assert isinstance(func_body[2].right, Operator) + assert len(func_body[2].right.operands) == 2 + assert func_body[2].right.op == "ast.Add" + assert func_body[2].right.operands[0].name == "a" + assert func_body[2].right.operands[1].name == "b" + + get_func = record_def_node.funcs[1] + assert isinstance(get_func, FunctionDef) + assert get_func.name.name == "get_c" + func_args = get_func.func_args + + assert len(func_args) == 1 + assert func_args[0].val.name == "self" + + func_body = get_func.body + assert len(func_body) == 1 + assert isinstance(func_body[0], ModelReturn) + assert isinstance(func_body[0].value, Attribute) + assert func_body[0].value.value.name == "self" + assert func_body[0].value.attr.name == "c" + + ####################################################### + asg_node = fun_cast.nodes[0].body[1] + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert asg_node.left.val.name == "mc" + assert asg_node.left.val.id == 9 + + func_call_node = asg_node.right + assert isinstance(func_call_node, Call) + assert func_call_node.func.name == "MyClass" + + assert len(func_call_node.arguments) == 2 + assert isinstance(func_call_node.arguments[0], CASTLiteralValue) + assert func_call_node.arguments[0].value_type == "Integer" + assert func_call_node.arguments[0].value == "2" + + assert isinstance(func_call_node.arguments[1], CASTLiteralValue) + assert func_call_node.arguments[1].value_type == "Integer" + assert func_call_node.arguments[1].value == "3" + + ####################################################### + asg_node = fun_cast.nodes[0].body[2] + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert asg_node.left.val.name == "x" + assert asg_node.left.val.id == 10 + + assert isinstance(asg_node.right, Call) + assert isinstance(asg_node.right.func, Attribute) + assert asg_node.right.func.value.name == "mc" + assert asg_node.right.func.attr.name == "get_c" + +def test_class2(): + fun_cast = generate_cast(class2()) + record_def_node = fun_cast.nodes[0].body[0] + assert isinstance(record_def_node, RecordDef) + assert record_def_node.name == "MyClass1" + assert len(record_def_node.bases) == 0 + assert len(record_def_node.fields) == 1 + fields = record_def_node.fields + assert isinstance(fields[0], Name) + assert fields[0].name == "a" + assert fields[0].id == 3 + + assert len(record_def_node.funcs) == 2 + init_func = record_def_node.funcs[0] + assert isinstance(init_func, FunctionDef) + assert isinstance(init_func.name, Name) + assert init_func.name.name == "__init__" + assert init_func.name.id == 1 + + func_args = init_func.func_args + assert len(func_args) == 2 + assert isinstance(func_args[0], Var) + assert func_args[0].val.name == "self" + assert func_args[0].val.id == 2 + + assert isinstance(func_args[1], Var) + assert func_args[1].val.name == "a" + assert func_args[1].val.id == 3 + + func_body = init_func.body + assert len(func_body) == 1 + assert isinstance(func_body[0], Assignment) + assert isinstance(func_body[0].left, Attribute) + assert func_body[0].left.value.name == "self" + assert func_body[0].left.attr.name == "a" + + assert isinstance(func_body[0].right, Name) + assert func_body[0].right.name == "a" + + get_func = record_def_node.funcs[1] + assert isinstance(get_func, FunctionDef) + assert get_func.name.name == "get_a" + func_args = get_func.func_args + + assert len(func_args) == 1 + assert func_args[0].val.name == "self" + + func_body = get_func.body + assert len(func_body) == 1 + assert isinstance(func_body[0], ModelReturn) + assert isinstance(func_body[0].value, Attribute) + assert func_body[0].value.value.name == "self" + assert func_body[0].value.attr.name == "a" + + ####################################################### + record_def_node = fun_cast.nodes[0].body[1] + assert isinstance(record_def_node, RecordDef) + assert record_def_node.name == "MyClass2" + assert len(record_def_node.bases) == 0 + assert len(record_def_node.fields) == 1 + fields = record_def_node.fields + assert isinstance(fields[0], Name) + assert fields[0].name == "b" + assert fields[0].id == 9 + + assert len(record_def_node.funcs) == 2 + init_func = record_def_node.funcs[0] + assert isinstance(init_func, FunctionDef) + assert isinstance(init_func.name, Name) + assert init_func.name.name == "__init__" + assert init_func.name.id == 1 + + func_args = init_func.func_args + assert len(func_args) == 2 + assert isinstance(func_args[0], Var) + assert func_args[0].val.name == "self" + assert func_args[0].val.id == 8 + + assert isinstance(func_args[1], Var) + assert func_args[1].val.name == "b" + assert func_args[1].val.id == 9 + + func_body = init_func.body + assert len(func_body) == 2 + assert isinstance(func_body[0], Assignment) + assert isinstance(func_body[0].left, Attribute) + assert func_body[0].left.value.name == "self" + assert func_body[0].left.attr.name == "b" + + assert isinstance(func_body[0].right, Name) + assert func_body[0].right.name == "b" + + assert isinstance(func_body[1], Call) + assert isinstance(func_body[1].func, Attribute) + assert isinstance(func_body[1].func.value, Call) + assert func_body[1].func.value.func.name == "super" + assert func_body[1].func.value.func.id == 10 + assert func_body[1].func.attr.name == "__init__" + + func_args = func_body[1].arguments + assert isinstance(func_args[0], Operator) + assert func_args[0].op == "ast.Add" + + assert len(func_args[0].operands) == 2 + assert func_args[0].operands[0].name == "b" + + assert isinstance(func_args[0].operands[1], CASTLiteralValue) + assert func_args[0].operands[1].value == "1" + + get_func = record_def_node.funcs[1] + assert isinstance(get_func, FunctionDef) + assert get_func.name.name == "get_b" + func_args = get_func.func_args + + assert len(func_args) == 1 + assert func_args[0].val.name == "self" + + func_body = get_func.body + assert len(func_body) == 1 + assert isinstance(func_body[0], ModelReturn) + assert isinstance(func_body[0].value, Attribute) + assert func_body[0].value.value.name == "self" + assert func_body[0].value.attr.name == "b" + + # ####################################################### + asg_node = fun_cast.nodes[0].body[2] + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert asg_node.left.val.name == "mc" + assert asg_node.left.val.id == 14 + + assert isinstance(asg_node.right, Call) + init_func = asg_node.right + assert init_func.func.name == "MyClass2" + assert len(init_func.arguments) == 1 + assert isinstance(init_func.arguments[0], CASTLiteralValue) + + asg_node = fun_cast.nodes[0].body[3] + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert asg_node.left.val.name == "x" + + assert isinstance(asg_node.right, Call) + assert isinstance(asg_node.right.func, Attribute) + assert asg_node.right.func.attr.name == "get_a" + assert asg_node.right.func.attr.id == 4 + assert asg_node.right.func.value.name == "mc" + assert asg_node.right.func.value.id == 14 + + + asg_node = fun_cast.nodes[0].body[4] + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert asg_node.left.val.name == "y" + assert asg_node.left.val.id == 16 + + assert isinstance(asg_node.right, Call) + assert isinstance(asg_node.right.func, Attribute) + assert asg_node.right.func.attr.name == "get_b" + assert asg_node.right.func.attr.id == 11 + assert asg_node.right.func.value.name == "mc" + assert asg_node.right.func.value.id == 14 + +def test_class3(): + class_cast = generate_cast(class3()) + record_def_node = class_cast.nodes[0].body[0] + assert isinstance(record_def_node, RecordDef) + assert record_def_node.name == "Foo" + assert len(record_def_node.bases) == 0 + assert len(record_def_node.fields) == 1 + fields = record_def_node.fields + assert isinstance(fields[0], Name) + assert fields[0].name == "a" + assert fields[0].id == 3 + + assert len(record_def_node.funcs) == 3 + init_func = record_def_node.funcs[0] + assert isinstance(init_func, FunctionDef) + assert isinstance(init_func.name, Name) + assert init_func.name.name == "__init__" + assert init_func.name.id == 1 + + func_args = init_func.func_args + assert len(func_args) == 2 + assert isinstance(func_args[0], Var) + assert func_args[0].val.name == "self" + assert func_args[0].val.id == 2 + + assert isinstance(func_args[1], Var) + assert func_args[1].val.name == "a" + assert func_args[1].val.id == 3 + + func_body = init_func.body + assert len(func_body) == 1 + assert isinstance(func_body[0], Assignment) + assert isinstance(func_body[0].left, Attribute) + assert func_body[0].left.value.name == "self" + assert func_body[0].left.attr.name == "a" + + assert isinstance(func_body[0].right, Name) + assert func_body[0].right.name == "a" + + add_func = record_def_node.funcs[1] + assert isinstance(add_func, FunctionDef) + assert add_func.name.name == "add1" + assert add_func.name.id == 4 + func_args = add_func.func_args + + assert len(func_args) == 1 + assert func_args[0].val.name == "self" + + func_body = add_func.body + asg_stmt = func_body[0] + assert isinstance(asg_stmt, Assignment) + asg_left = asg_stmt.left + assert isinstance(asg_left, Attribute) + assert isinstance(asg_left.attr, Name) + assert asg_left.attr.name == "a" + assert isinstance(asg_left.value, Name) + assert asg_left.value.name == "self" + + asg_right = asg_stmt.right + assert isinstance(asg_right, Operator) + assert asg_right.op == "ast.Add" + assert isinstance(asg_right.operands[0], Attribute) + assert isinstance(asg_right.operands[0].attr, Name) + assert asg_right.operands[0].attr.name == "a" + assert isinstance(asg_right.operands[0].value, Name) + assert asg_right.operands[0].value.name == "self" + + assert isinstance(asg_right.operands[1], CASTLiteralValue) + + sub_func = record_def_node.funcs[2] + assert isinstance(sub_func, FunctionDef) + assert sub_func.name.name == "sub1" + assert sub_func.name.id == 7 + func_args = sub_func.func_args + + assert len(func_args) == 1 + assert func_args[0].val.name == "self" + + func_body = sub_func.body + asg_stmt = func_body[0] + assert isinstance(asg_stmt, Assignment) + asg_left = asg_stmt.left + assert isinstance(asg_left, Attribute) + assert isinstance(asg_left.attr, Name) + assert asg_left.attr.name == "a" + assert isinstance(asg_left.value, Name) + assert asg_left.value.name == "self" + + asg_right = asg_stmt.right + assert isinstance(asg_right, Operator) + assert asg_right.op == "ast.Sub" + assert isinstance(asg_right.operands[0], Attribute) + assert isinstance(asg_right.operands[0].attr, Name) + assert asg_right.operands[0].attr.name == "a" + assert isinstance(asg_right.operands[0].value, Name) + assert asg_right.operands[0].value.name == "self" + + assert isinstance(asg_right.operands[1], CASTLiteralValue) + + assignment_node = class_cast.nodes[0].body[1] + assert isinstance(assignment_node, Assignment) + assert isinstance(assignment_node.left, Var) + assert isinstance(assignment_node.left.val, Name) + assert assignment_node.left.val.name == "foo" + assert assignment_node.left.val.id == 10 + + assert isinstance(assignment_node.right, Call) + assert assignment_node.right.func.name == "Foo" + + call_node = class_cast.nodes[0].body[2] + assert isinstance(call_node, Call) + assert isinstance(call_node.func, Attribute) + assert isinstance(call_node.func.value, Call) + assert call_node.func.attr.name == "sub1" + assert call_node.func.attr.id == 7 + assert isinstance(call_node.func.value.func, Attribute) + assert isinstance(call_node.func.value.func.value, Call) + assert call_node.func.value.func.attr.name == "add1" + assert call_node.func.value.func.attr.id == 4 + assert isinstance(call_node.func.value.func.value.func, Attribute) + assert call_node.func.value.func.value.func.value.name == "foo" + assert call_node.func.value.func.value.func.attr.name == "add1" + assert call_node.func.value.func.value.func.attr.id == 4 diff --git a/skema/rest/integrated_text_reading_proxy.py b/skema/rest/integrated_text_reading_proxy.py index 78b1076d7da..8ed11de8d15 100644 --- a/skema/rest/integrated_text_reading_proxy.py +++ b/skema/rest/integrated_text_reading_proxy.py @@ -408,7 +408,7 @@ def integrated_extractions( ) async def integrated_text_extractions( response: Response, - texts: TextReadingInputDocuments, + inputs: TextReadingInputDocuments, annotate_skema: bool = True, annotate_mit: bool = True, ) -> TextReadingAnnotationsOutput: @@ -428,10 +428,12 @@ async def integrated_text_extractions( ``` """ # Get the input plain texts - texts = texts.texts + texts = inputs.texts + + amrs = inputs.amrs # Run the text extractors - return integrated_extractions( + extractions = integrated_extractions( response, annotate_text_with_skema, texts, @@ -440,6 +442,31 @@ async def integrated_text_extractions( annotate_mit ) + # Do the alignment + aligned_amrs = list() + if len(amrs) > 0: + # Build an UploadFile instance from the extractions + json_extractions = extractions.model_dump_json() + extractions_ufile = UploadFile(file=io.BytesIO(json_extractions.encode('utf-8'))) + for amr in amrs: + # amr = json.loads(amr) + amr_ufile = UploadFile(file=io.BytesIO(amr.encode('utf-8'))) + try: + aligned_amr = metal_proxy.link_amr( + amr_file=amr_ufile, + text_extractions_file=extractions_ufile) + aligned_amrs.append(aligned_amr) + except Exception as e: + error = TextReadingError(pipeline="AMR Linker", message=f"Error annotating {amr.filename}: {e}") + if extractions.generalized_errors is None: + extractions.generalized_errors = [error] + else: + extractions.generalized_errors.append(error) + + extractions.aligned_amrs = aligned_amrs + + return extractions + @router.post( "/integrated-pdf-extractions", diff --git a/skema/rest/schema.py b/skema/rest/schema.py index 56ffdbcf8ab..71e9b92d749 100644 --- a/skema/rest/schema.py +++ b/skema/rest/schema.py @@ -5,6 +5,7 @@ from typing import List, Optional, Dict, Any from askem_extractions.data_model import AttributeCollection +from fastapi import UploadFile from pydantic import BaseModel, Field # see https://github.com/pydantic/pydantic/issues/5821#issuecomment-1559196859 @@ -168,6 +169,10 @@ class TextReadingInputDocuments(BaseModel): description="List of input plain texts to be annotated by the text reading pipelines", examples=[["x = 0", "y = 1", "I: Infected population"]], ) + amrs: List[str] = Field( + description="List of optional AMR files to align with the extractions", + examples=[[]] + ) class TextReadingError(BaseModel): diff --git a/skema/rest/tests/test_integrated_text_reading_proxy.py b/skema/rest/tests/test_integrated_text_reading_proxy.py index d4e726e2440..900954beabc 100644 --- a/skema/rest/tests/test_integrated_text_reading_proxy.py +++ b/skema/rest/tests/test_integrated_text_reading_proxy.py @@ -23,7 +23,8 @@ def test_text_integrated_extractions(): "x = 0", "y = 1", "I: Infected population" - ] + ], + "amrs": [] } response = client.post(f"/integrated-text-extractions", params=params, json=payload) diff --git a/skema/rest/workflows.py b/skema/rest/workflows.py index 0f03c7242a5..62d0ae36b65 100644 --- a/skema/rest/workflows.py +++ b/skema/rest/workflows.py @@ -87,7 +87,7 @@ async def equation_to_amrs(data: schema.EquationsToAMRs, client: httpx.AsyncClie res_new = await client.put(f"{SKEMA_RS_ADDESS}/mathml/g-amr", json=eqns) if res_new.status_code != 200: return JSONResponse( - status_code=402, + status_code=422, content={ "error": f"Attempted creation of {data.model} AMR, which failed. Then tried creation of Generalized AMR, which also failed with the following error {res_new.text}. Please check equations, seen as pMathML below.", "payload": eqns, @@ -126,7 +126,7 @@ async def equation_to_amrs(data: schema.EquationsToAMRs, client: httpx.AsyncClie ) else: return JSONResponse( - status_code=401, + status_code=422, content={ "error": f"{data.model} is not a supported model type", "payload": eqns, diff --git a/skema/skema-rs/mathml/src/acset.rs b/skema/skema-rs/mathml/src/acset.rs index 3fa1435bbd1..a6fac061029 100644 --- a/skema/skema-rs/mathml/src/acset.rs +++ b/skema/skema-rs/mathml/src/acset.rs @@ -202,7 +202,7 @@ pub struct Initial { } impl Default for Initial { - fn default() -> Self { Initial { target: "temp".to_string(), expression: "0".to_string(), expression_mathml: "<\\math>".to_string() } } + fn default() -> Self { Initial { target: "temp".to_string(), expression: "0".to_string(), expression_mathml: "".to_string() } } } #[derive( @@ -223,12 +223,12 @@ pub struct RegTransition { /// Note: source is a required field in the schema, but we make it optional since we want to /// reuse this schema for partial extractions as well. #[serde(skip_serializing_if = "Option::is_none")] - pub source: Option>, + pub source: Option, /// Note: target is a required field in the schema, but we make it optional since we want to /// reuse this schema for partial extractions as well. #[serde(skip_serializing_if = "Option::is_none")] - pub target: Option>, + pub target: Option, #[serde(skip_serializing_if = "Option::is_none")] pub sign: Option, @@ -890,8 +890,8 @@ impl From> for RegNet { } // This adds the intial values from the state variables into the parameters vec let parameters = Parameter { - id: state.clone(), - name: Some(state.clone()), + id: r_state.initial.clone().unwrap(), + name: r_state.initial.clone(), description: Some(format!( "The total {} population at timestep 0", state.clone() @@ -929,7 +929,9 @@ impl From> for RegNet { unpaired_terms.remove(*i); } - for (i, t) in transition_pair.iter().enumerate() { + let mut trans_num = 0; + + for (_i, t) in transition_pair.iter().enumerate() { if t.0.exp_states.len() == 1 { // construct transtions for simple transtions let prop = Properties { @@ -938,15 +940,27 @@ impl From> for RegNet { rate_constant: None, }; let trans = RegTransition { - id: format!("t{}", i.clone()), - source: Some([t.1.dyn_state.clone()].to_vec()), - target: Some([t.0.dyn_state.clone()].to_vec()), + id: format!("t{}", trans_num.clone()), + source: Some(t.1.dyn_state.clone()), + target: Some(t.0.dyn_state.clone()), sign: Some(true), grounding: None, properties: Some(prop.clone()), }; + trans_num = trans_num + 1; + transitions_vec.insert(trans.clone()); + let trans = RegTransition { + id: format!("t{}", trans_num.clone()), + source: Some(t.0.dyn_state.clone()), + target: Some(t.1.dyn_state.clone()), + sign: Some(false), + grounding: None, + properties: Some(prop.clone()), + }; + trans_num = trans_num + 1; transitions_vec.insert(trans.clone()); } else { + // construct transitions for complicated transitions // mainly need to construct the output specially, // run by clay @@ -963,22 +977,25 @@ impl From> for RegNet { name: t.0.parameters[0].clone(), rate_constant: None, }; - let trans = RegTransition { - id: format!("t{}", i.clone()), - source: Some(t.1.exp_states.clone()), - target: Some(output.clone()), - sign: Some(true), - grounding: None, - properties: Some(prop.clone()), - }; - transitions_vec.insert(trans.clone()); + for (j, _out) in output.iter().enumerate() { + let trans = RegTransition { + id: format!("t{}", trans_num.clone()), + source: Some(t.1.exp_states[j].clone()), + target: Some(output[j].clone()), + sign: Some(true), + grounding: None, + properties: Some(prop.clone()), + }; + transitions_vec.insert(trans.clone()); + trans_num = trans_num + 1; + } } } for (i, term) in unpaired_terms.iter().enumerate() { println!("Term: {:?}", term.clone()); if term.exp_states.len() > 1 { - let mut output = [term.dyn_state.clone()].to_vec(); + let mut output = term.dyn_state.clone(); let mut input = term.exp_states.clone(); let param_len = term.parameters.len(); @@ -991,28 +1008,28 @@ impl From> for RegNet { input.sort(); input.dedup(); - output.sort(); - output.dedup(); if input.clone().len() > 1 { let old_input = input.clone(); input = [].to_vec(); for term in old_input.clone().iter() { - if *term != output[0] { + if *term != output { input.push(term.clone()); } } } - - let trans = RegTransition { - id: format!("s{}", i.clone()), - source: Some(input.clone()), - target: Some(output.clone()), - sign: Some(true), - grounding: None, - properties: Some(prop.clone()), - }; - transitions_vec.insert(trans.clone()); + for (j, _trm) in input.iter().enumerate() { + let trans = RegTransition { + id: format!("s{}", trans_num.clone()), + source: Some(input[j].clone()), + target: Some(output.clone()), + sign: Some(term.polarity), + grounding: None, + properties: Some(prop.clone()), + }; + transitions_vec.insert(trans.clone()); + trans_num = trans_num + 1; + } } } diff --git a/skema/skema-rs/mathml/src/parsers/decapodes_serialization.rs b/skema/skema-rs/mathml/src/parsers/decapodes_serialization.rs index a82667b9d03..046fe501fd4 100644 --- a/skema/skema-rs/mathml/src/parsers/decapodes_serialization.rs +++ b/skema/skema-rs/mathml/src/parsers/decapodes_serialization.rs @@ -297,9 +297,9 @@ pub fn to_decapodes_serialization( let tgt_idx = table_counts.variable_count; let mut derivative_str = String::new(); if *notation == DerivativeNotation::LeibnizTotal { - derivative_str.push_str(&*format!("D({},{})", order, bound_var)); + derivative_str.push_str(&format!("D({},{})", order, bound_var)); } else if *notation == DerivativeNotation::LeibnizPartialStandard { - derivative_str.push_str(&*format!("PD({},{})", order, bound_var)); + derivative_str.push_str(&format!("PD({},{})", order, bound_var)); } let unary = UnaryOperator { src: to_decapodes_serialization(&rest[0], tables, table_counts), diff --git a/skema/skema-rs/mathml/src/parsers/first_order_ode.rs b/skema/skema-rs/mathml/src/parsers/first_order_ode.rs index 71fbaf350a1..a4ef19c366d 100644 --- a/skema/skema-rs/mathml/src/parsers/first_order_ode.rs +++ b/skema/skema-rs/mathml/src/parsers/first_order_ode.rs @@ -954,6 +954,7 @@ pub fn get_terms_mult(sys_states: Vec, eq: Vec) -> P let mut rhs_vec = Vec::::new(); for arg in arg_terms.iter() { + println!("arg_term: {:?}", arg.clone()); if arg.0 == 0 { lhs_vec.push(arg.1.clone()); } else { diff --git a/skema/skema-rs/skema/src/bin/morae.rs b/skema/skema-rs/skema/src/bin/morae.rs index db73afd9d71..c7f532b938c 100644 --- a/skema/skema-rs/skema/src/bin/morae.rs +++ b/skema/skema-rs/skema/src/bin/morae.rs @@ -1,13 +1,14 @@ use clap::Parser; pub use mathml::mml2pn::{ACSet, Term}; +use mathml::parsers::math_expression_tree::MathExpressionTree; use std::fs; // new imports -use mathml::acset::{GeneralizedAMR, PetriNet}; +use mathml::acset::GeneralizedAMR; use neo4rs::{query, Node}; use schemars::schema_for; use skema::config::Config; -use skema::model_extraction::module_id2mathml_MET_ast; + use std::env; use std::sync::Arc; @@ -66,9 +67,13 @@ async fn main() { env::var("SKEMA_GRAPH_DB_HOST").unwrap_or("graphdb-bolt.askem.lum.ai".to_string()); let db_port = env::var("SKEMA_GRAPH_DB_PORT").unwrap_or("443".to_string()); - let schema = schema_for!(GeneralizedAMR); - let data = format!("{}", serde_json::to_string_pretty(&schema).unwrap()); - fs::write("./schema.txt", data).expect("Unable to write file"); + let schema_met = schema_for!(MathExpressionTree); + let data_met = serde_json::to_string_pretty(&schema_met).unwrap().to_string(); + fs::write("./met_schema.txt", data_met).expect("Unable to write file"); + + let schema_gamr = schema_for!(GeneralizedAMR); + let data_gamr = serde_json::to_string_pretty(&schema_gamr).unwrap().to_string(); + fs::write("./gamr_schema.txt", data_gamr).expect("Unable to write file"); /* let config = Config { db_protocol: db_protocol.clone(), diff --git a/skema/skema-rs/skema/src/services/mathml.rs b/skema/skema-rs/skema/src/services/mathml.rs index 41bf4ef041c..356dd9e0084 100644 --- a/skema/skema-rs/skema/src/services/mathml.rs +++ b/skema/skema-rs/skema/src/services/mathml.rs @@ -14,7 +14,7 @@ use mathml::{ parsers::first_order_ode::{first_order_ode, FirstOrderODE}, }; use petgraph::dot::{Config, Dot}; -use serde_json::from_str; + use utoipa; /// Parse MathML and return a DOT representation of the abstract syntax tree (AST) @@ -151,7 +151,7 @@ request_body = Vec, responses( ( status = 200, -body = Vec +body = Vec ) ) )]