diff --git a/bin/vyper b/bin/vyper index 0f38732562..66cc660cf1 100755 --- a/bin/vyper +++ b/bin/vyper @@ -23,6 +23,7 @@ format_options_help = """Format to print, one or more of: bytecode_runtime - Bytecode at runtime abi - ABI in JSON format abi_python - ABI in python format + ast - AST in JSON format source_map - Vyper source map method_identifiers - Dictionary of method signature to method identifier. combined_json - All of the above format options combined as single JSON output. @@ -105,7 +106,8 @@ if __name__ == '__main__': else: # Normal output. translate_map = { 'abi_python': 'abi', - 'json': 'abi' + 'json': 'abi', + 'ast': 'ast_dict' } formats = [] orig_args = uniq(args.format.split(',')) @@ -156,7 +158,7 @@ if __name__ == '__main__': for out in out_list: for f in orig_args: o = out[translate_map.get(f, f)] - if f in ('abi', 'json'): + if f in ('abi', 'json', 'ast'): print(json.dumps(o)) elif f == 'abi_python': print(o) diff --git a/tests/examples/market_maker/test_on_chain_market_maker.py b/tests/examples/market_maker/test_on_chain_market_maker.py index 376c256907..36f1d1d3ca 100644 --- a/tests/examples/market_maker/test_on_chain_market_maker.py +++ b/tests/examples/market_maker/test_on_chain_market_maker.py @@ -25,7 +25,7 @@ def erc20(get_contract): ) -def test_initial_statet(market_maker): +def test_initial_state(market_maker): assert market_maker.totalEthQty() == 0 assert market_maker.totalTokenQty() == 0 assert market_maker.invariant() == 0 diff --git a/tests/parser/ast_utils/test_ast.py b/tests/parser/ast_utils/test_ast.py new file mode 100644 index 0000000000..416fde41fa --- /dev/null +++ b/tests/parser/ast_utils/test_ast.py @@ -0,0 +1,37 @@ +from vyper.ast_utils import ( + parse_to_ast, +) + + +def test_ast_equal(): + code = """ +@public +def test() -> int128: + a: uint256 = 100 + return 123 + """ + + ast1 = parse_to_ast(code) + ast2 = parse_to_ast("\n \n" + code + "\n\n") + + assert ast1 == ast2 + + +def test_ast_unequal(): + code1 = """ +@public +def test() -> int128: + a: uint256 = 100 + return 123 + """ + code2 = """ +@public +def test() -> int128: + a: uint256 = 100 + return 121 + """ + + ast1 = parse_to_ast(code1) + ast2 = parse_to_ast(code2) + + assert ast1 != ast2 diff --git a/tests/parser/ast_utils/test_ast_dict.py b/tests/parser/ast_utils/test_ast_dict.py new file mode 100644 index 0000000000..5a7746f34f --- /dev/null +++ b/tests/parser/ast_utils/test_ast_dict.py @@ -0,0 +1,84 @@ +from vyper import ( + compiler, +) +from vyper.ast_utils import ( + ast_to_dict, + dict_to_ast, + parse_to_ast, +) + + +def get_node_ids(ast_struct, ids=None): + if ids is None: + ids = [] + + for k, v in ast_struct.items(): + if isinstance(v, dict): + ids = get_node_ids(v, ids) + elif isinstance(v, list): + for x in v: + ids = get_node_ids(x, ids) + elif k == 'node_id': + ids.append(v) + elif v is None or isinstance(v, (str, int)): + continue + else: + raise Exception('Unknown ast_struct provided.') + return ids + + +def test_ast_to_dict_node_id(): + code = """ +@public +def test() -> int128: + a: uint256 = 100 + return 123 + """ + dict_out = compiler.compile_code(code, ['ast_dict']) + node_ids = get_node_ids(dict_out) + + assert len(node_ids) == len(set(node_ids)) + + +def test_basic_ast(): + code = """ +a: int128 + """ + dict_out = compiler.compile_code(code, ['ast_dict']) + assert dict_out['ast_dict']['ast'][0] == { + 'annotation': { + 'ast_type': 'Name', + 'col_offset': 3, + 'id': 'int128', + 'lineno': 2, + 'node_id': 4 + }, + 'ast_type': 'AnnAssign', + 'col_offset': 0, + 'lineno': 2, + 'node_id': 1, + 'simple': 1, + 'target': { + 'ast_type': 'Name', + 'col_offset': 0, + 'id': 'a', + 'lineno': 2, + 'node_id': 2 + }, + 'value': None + } + + +def test_dict_to_ast(): + code = """ +@public +def test() -> int128: + a: uint256 = 100 + return 123 + """ + + original_ast = parse_to_ast(code) + out_dict = ast_to_dict(original_ast) + new_ast = dict_to_ast(out_dict) + + assert new_ast == original_ast diff --git a/tests/parser/ast_utils/test_ast_to_string.py b/tests/parser/ast_utils/test_ast_to_string.py new file mode 100644 index 0000000000..2ea5d4d79b --- /dev/null +++ b/tests/parser/ast_utils/test_ast_to_string.py @@ -0,0 +1,18 @@ +from vyper.ast_utils import ( + ast_to_string, + parse_to_ast, +) + + +def test_ast_to_string(): + code = """ +@public +def testme(a: int128) -> int128: + return a + """ + vyper_ast = parse_to_ast(code) + assert ast_to_string(vyper_ast) == ( + "Module(body=[FunctionDef(name='testme', args=arguments(args=[arg(arg='a'," + " annotation=Name(id='int128'))], defaults=[]), body=[Return(value=Name(id='a'))]," + " decorator_list=[Name(id='public')], returns=Name(id='int128'))])" + ) diff --git a/tests/parser/exceptions/test_invalid_literal_exception.py b/tests/parser/exceptions/test_invalid_literal_exception.py index 78bc366f25..49bdac0d27 100644 --- a/tests/parser/exceptions/test_invalid_literal_exception.py +++ b/tests/parser/exceptions/test_invalid_literal_exception.py @@ -89,11 +89,6 @@ def foo(): """, """ @public -def foo(): - x = convert(-1, uint256) - """, - """ -@public def foo(): x = convert(-(-(-1)), uint256) """, diff --git a/tests/parser/exceptions/test_invalid_type_exception.py b/tests/parser/exceptions/test_invalid_type_exception.py index 39bd9eb09f..edebf42f25 100644 --- a/tests/parser/exceptions/test_invalid_type_exception.py +++ b/tests/parser/exceptions/test_invalid_type_exception.py @@ -43,9 +43,6 @@ def foo(x): pass b: map((int128, decimal), int128) """, """ -b: int128[int128: address] - """, - """ x: wei(wei) """, """ @@ -61,18 +58,12 @@ def foo(x): pass x: int128(wei ** -1) """, """ -x: int128(wei >> 3) - """, - """ x: bytes <= wei """, """ x: string <= 33 """, """ -x: bytes[1:3] - """, - """ x: bytes[33.3] """, """ diff --git a/tests/parser/exceptions/test_structure_exception.py b/tests/parser/exceptions/test_structure_exception.py index 4073f26ca7..3f24b05157 100644 --- a/tests/parser/exceptions/test_structure_exception.py +++ b/tests/parser/exceptions/test_structure_exception.py @@ -27,24 +27,6 @@ def foo(): pass send(0x1234567890123456789012345678901234567890, 5) """, """ -x: int128[5] -@public -def foo(): - self.x[2:4] = 3 - """, - """ -x: int128[5] -@public -def foo(): - z = self.x[2:4] - """, - """ -@public -def foo(): - x: int128[5] - z = x[2:4] - """, - """ @public def foo(): x: int128 = 5 @@ -147,11 +129,6 @@ def foo() -> int128: """, """ @public -def foo(): - x: address = ~self - """, - """ -@public def foo(): x = concat(b"") """, diff --git a/tests/parser/exceptions/test_syntax_exception.py b/tests/parser/exceptions/test_syntax_exception.py new file mode 100644 index 0000000000..05f9cfb3f1 --- /dev/null +++ b/tests/parser/exceptions/test_syntax_exception.py @@ -0,0 +1,55 @@ +import pytest +from pytest import ( + raises, +) + +from vyper import ( + compiler, +) +from vyper.exceptions import ( + SyntaxException, +) + +fail_list = [ + """ +x: bytes[1:3] + """, + """ +b: int128[int128: address] + """, + """ +x: int128[5] +@public +def foo(): + self.x[2:4] = 3 + """, + """ +@public +def foo(): + x: address = ~self + """, + """ +x: int128[5] +@public +def foo(): + z = self.x[2:4] + """, + """ +@public +def foo(): + x: int128[5] + z = x[2:4] + """, + """ +x: int128(wei >> 3) + """, + """ +Transfer: event({_&rom: indexed(address)}) + """, +] + + +@pytest.mark.parametrize('bad_code', fail_list) +def test_syntax_exception(bad_code): + with raises(SyntaxException): + compiler.compile_code(bad_code) diff --git a/tests/parser/parser_utils/test_annotate_and_optimize_ast.py b/tests/parser/parser_utils/test_annotate_and_optimize_ast.py index 4a51d2e32e..10528f836b 100644 --- a/tests/parser/parser_utils/test_annotate_and_optimize_ast.py +++ b/tests/parser/parser_utils/test_annotate_and_optimize_ast.py @@ -1,14 +1,14 @@ -import ast +import ast as python_ast from vyper.parser.parser_utils import ( - annotate_and_optimize_ast, + annotate_ast, ) from vyper.parser.pre_parser import ( pre_parse, ) -class AssertionVisitor(ast.NodeVisitor): +class AssertionVisitor(python_ast.NodeVisitor): def assert_about_node(self, node): assert False @@ -34,11 +34,11 @@ def foo() -> int128: def get_contract_info(source_code): class_types, reformatted_code = pre_parse(source_code) - parsed_ast = ast.parse(reformatted_code) + py_ast = python_ast.parse(reformatted_code) - annotate_and_optimize_ast(parsed_ast, reformatted_code, class_types) + annotate_ast(py_ast, reformatted_code, class_types) - return parsed_ast, reformatted_code + return py_ast, reformatted_code def test_it_annotates_ast_with_source_code(): @@ -67,5 +67,5 @@ def test_it_rewrites_unary_subtractions(): function_def = contract_ast.body[2] return_stmt = function_def.body[0] - assert isinstance(return_stmt.value, ast.Num) + assert isinstance(return_stmt.value, python_ast.Num) assert return_stmt.value.n == -1 diff --git a/tests/parser/syntax/test_no_none.py b/tests/parser/syntax/test_no_none.py index e224963fa5..fc42658254 100644 --- a/tests/parser/syntax/test_no_none.py +++ b/tests/parser/syntax/test_no_none.py @@ -1,5 +1,6 @@ from vyper.exceptions import ( InvalidLiteralException, + SyntaxException, ) @@ -123,7 +124,7 @@ def foo(): for contract in contracts: assert_compile_failed( lambda: get_contract_with_gas_estimation(contract), - InvalidLiteralException + SyntaxException ) diff --git a/tests/parser/syntax/utils/test_event_names.py b/tests/parser/syntax/utils/test_event_names.py index 834901547e..0b1d9d9b1d 100644 --- a/tests/parser/syntax/utils/test_event_names.py +++ b/tests/parser/syntax/utils/test_event_names.py @@ -61,9 +61,6 @@ def foo(i: int128) -> int128: Transfer: eve.t({_from: indexed(address)}) """, InvalidTypeException), """ -Transfer: event({_&rom: indexed(address)}) - """, - """ Transfer: event({_from: i.dexed(address), _to: indexed(address),lue: uint256}) """ ] diff --git a/vyper/ast.py b/vyper/ast.py new file mode 100644 index 0000000000..9522af29ac --- /dev/null +++ b/vyper/ast.py @@ -0,0 +1,277 @@ +from itertools import ( + chain, +) +import typing + +from vyper.exceptions import ( + CompilerPanic, +) + +BASE_NODE_ATTRIBUTES = ('node_id', 'source_code', 'col_offset', 'lineno') + + +class VyperNode: + __slots__ = BASE_NODE_ATTRIBUTES + ignored_fields: typing.Tuple = ('ctx', ) + only_empty_fields: typing.Tuple = () + + @classmethod + def get_slots(cls): + return set(chain.from_iterable( + getattr(klass, '__slots__', []) + for klass in cls.__class__.mro(cls) + )) + + def __init__(self, **kwargs): + + for field_name, value in kwargs.items(): + if field_name in self.get_slots(): + setattr(self, field_name, value) + elif value: + raise CompilerPanic( + f'Unsupported non-empty value field_name: {field_name}, ' + f' class: {type(self)} value: {value}' + ) + + def __eq__(self, other): + if isinstance(other, type(self)): + for field_name in self.get_slots(): + if field_name not in ('node_id', 'source_code', 'col_offset', 'lineno'): + if getattr(self, field_name, None) != getattr(other, field_name, None): + return False + return True + else: + return False + + +class Module(VyperNode): + __slots__ = ('body', ) + + +class Name(VyperNode): + __slots__ = ('id', ) + + +class Subscript(VyperNode): + __slots__ = ('slice', 'value') + + +class Index(VyperNode): + __slots__ = ('value', ) + + +class arg(VyperNode): + __slots__ = ('arg', 'annotation') + + +class Tuple(VyperNode): + __slots__ = ('elts', ) + + +class FunctionDef(VyperNode): + __slots__ = ('args', 'body', 'returns', 'name', 'decorator_list', 'pos') + + +class arguments(VyperNode): + __slots__ = ('args', 'defaults', 'default') + only_empty_fields = ('vararg', 'kwonlyargs', 'kwarg', 'kw_defaults') + + +class Import(VyperNode): + __slots__ = ('names', ) + + +class Call(VyperNode): + __slots__ = ('func', 'args', 'keywords', 'keyword') + + +class keyword(VyperNode): + __slots__ = ('arg', 'value') + + +class Str(VyperNode): + __slots__ = ('s', ) + + +class Compare(VyperNode): + __slots__ = ('comparators', 'ops', 'left', 'right') + + +class Num(VyperNode): + __slots__ = ('n', ) + + +class NameConstant(VyperNode): + __slots__ = ('value', ) + + +class Attribute(VyperNode): + __slots__ = ('attr', 'value',) + + +class Op(VyperNode): + __slots__ = ('op', 'left', 'right') + + +class BoolOp(Op): + __slots__ = ('values', ) + + +class BinOp(Op): + __slots__ = () + + +class UnaryOp(Op): + __slots__ = ('operand', ) + + +class List(VyperNode): + __slots__ = ('elts', ) + + +class Dict(VyperNode): + __slots__ = ('keys', 'values') + + +class Bytes(VyperNode): + __slots__ = ('s', ) + + +class Add(VyperNode): + __slots__ = () + + +class Sub(VyperNode): + __slots__ = () + + +class Mult(VyperNode): + __slots__ = () + + +class Div(VyperNode): + __slots__ = () + + +class Mod(VyperNode): + __slots__ = () + + +class Pow(VyperNode): + __slots__ = () + + +class In(VyperNode): + __slots__ = () + + +class Gt(VyperNode): + __slots__ = () + + +class GtE(VyperNode): + __slots__ = () + + +class LtE(VyperNode): + __slots__ = () + + +class Lt(VyperNode): + __slots__ = () + + +class Eq(VyperNode): + __slots__ = () + + +class NotEq(VyperNode): + __slots__ = () + + +class And(VyperNode): + __slots__ = () + + +class Or(VyperNode): + __slots__ = () + + +class Not(VyperNode): + __slots__ = () + + +class USub(VyperNode): + __slots__ = () + + +class Expr(VyperNode): + __slots__ = ('value', ) + + +class Pass(VyperNode): + __slots__ = () + + +class AnnAssign(VyperNode): + __slots__ = ('target', 'annotation', 'value', 'simple') + + +class Assign(VyperNode): + __slots__ = ('targets', 'value') + + +class If(VyperNode): + __slots__ = ('test', 'body', 'orelse') + + +class Assert(VyperNode): + __slots__ = ('test', 'msg') + + +class For(VyperNode): + __slots__ = ('iter', 'target', 'orelse', 'body') + + +class AugAssign(VyperNode): + __slots__ = ('op', 'target', 'value') + + +class Break(VyperNode): + __slots__ = () + + +class Continue(VyperNode): + __slots__ = () + + +class Return(VyperNode): + __slots__ = ('value', ) + + +class Delete(VyperNode): + __slots__ = ('targets', ) + + +class stmt(VyperNode): + __slots__ = () + + +class ClassDef(VyperNode): + __slots__ = ('class_type', 'name', 'body') + + +class Raise(VyperNode): + __slots__ = ('exc', ) + + +class Slice(VyperNode): + only_empty_fields = ('lower', ) + + +class alias(VyperNode): + __slots__ = ('name', 'asname') + + +class ImportFrom(VyperNode): + __slots__ = ('module', 'names') diff --git a/vyper/ast_utils.py b/vyper/ast_utils.py new file mode 100644 index 0000000000..329a73dc96 --- /dev/null +++ b/vyper/ast_utils.py @@ -0,0 +1,176 @@ +import ast as python_ast +from typing import ( + Generator, +) + +import vyper.ast as vyper_ast +from vyper.exceptions import ( + CompilerPanic, + ParserException, + SyntaxException, +) +from vyper.parser.parser_utils import ( + annotate_ast, +) +from vyper.parser.pre_parser import ( + pre_parse, +) +from vyper.utils import ( + iterable_cast, +) + +DICT_AST_SKIPLIST = ('source_code', ) + + +@iterable_cast(list) +def _build_vyper_ast_list(source_code: str, node: list) -> Generator: + for n in node: + yield parse_python_ast( + source_code=source_code, + node=n, + ) + + +@iterable_cast(dict) +def _build_vyper_ast_init_kwargs( + source_code: str, + node: python_ast.AST, + vyper_class: vyper_ast.VyperNode, + class_name: str +) -> Generator: + yield ('col_offset', getattr(node, 'col_offset', None)) + yield ('lineno', getattr(node, 'lineno', None)) + yield ('node_id', node.node_id) # type: ignore + yield ('source_code', source_code) + + if isinstance(node, python_ast.ClassDef): + yield ('class_type', node.class_type) # type: ignore + + for field_name in node._fields: + val = getattr(node, field_name) + if field_name in vyper_class.ignored_fields: + continue + elif val and field_name in vyper_class.only_empty_fields: + raise SyntaxException( + 'Invalid Vyper Syntax. ' + f'"{field_name}" is an unsupported attribute field ' + f'on Python AST "{class_name}" class.', + val + ) + else: + yield ( + field_name, + parse_python_ast( + source_code=source_code, + node=val, + ) + ) + + +def parse_python_ast(source_code: str, node: python_ast.AST) -> vyper_ast.VyperNode: + if isinstance(node, list): + return _build_vyper_ast_list(source_code, node) + elif isinstance(node, python_ast.AST): + class_name = node.__class__.__name__ + if hasattr(vyper_ast, class_name): + vyper_class = getattr(vyper_ast, class_name) + init_kwargs = _build_vyper_ast_init_kwargs( + source_code, node, vyper_class, class_name + ) + return vyper_class(**init_kwargs) + else: + raise SyntaxException( + f'Invalid syntax (unsupported "{class_name}" Python AST node).', node + ) + else: + return node + + +def parse_to_ast(source_code: str) -> list: + if '\x00' in source_code: + raise ParserException('No null bytes (\\x00) allowed in the source code.') + class_types, reformatted_code = pre_parse(source_code) + py_ast = python_ast.parse(reformatted_code) + annotate_ast(py_ast, source_code, class_types) + # Convert to Vyper AST. + vyper_ast = parse_python_ast( + source_code=source_code, + node=py_ast, + ) + return vyper_ast.body # type: ignore + + +@iterable_cast(list) +def _ast_to_list(node: list) -> Generator: + for x in node: + yield ast_to_dict(x) + + +@iterable_cast(dict) +def _ast_to_dict(node: vyper_ast.VyperNode) -> Generator: + for f in node.get_slots(): + if f not in DICT_AST_SKIPLIST: + yield (f, ast_to_dict(getattr(node, f, None))) + yield ('ast_type', node.__class__.__name__) + + +def ast_to_dict(node: vyper_ast.VyperNode) -> dict: + if isinstance(node, vyper_ast.VyperNode): + return _ast_to_dict(node) + elif isinstance(node, list): + return _ast_to_list(node) + elif node is None or isinstance(node, (str, int)): + return node + else: + raise CompilerPanic('Unknown vyper AST node provided.') + + +def dict_to_ast(ast_struct: dict) -> vyper_ast.VyperNode: + if isinstance(ast_struct, dict) and 'ast_type' in ast_struct: + vyper_class = getattr(vyper_ast, ast_struct['ast_type']) + klass = vyper_class(**{ + k: dict_to_ast(v) + for k, v in ast_struct.items() + if k in vyper_class.get_slots() + }) + return klass + elif isinstance(ast_struct, list): + return [ + dict_to_ast(x) + for x in ast_struct + ] + elif ast_struct is None or isinstance(ast_struct, (str, int)): + return ast_struct + else: + raise CompilerPanic('Unknown ast_struct provided.') + + +def to_python_ast(vyper_ast_node: vyper_ast.VyperNode) -> python_ast.AST: + if isinstance(vyper_ast_node, list): + return [ + to_python_ast(n) + for n in vyper_ast_node + ] + elif isinstance(vyper_ast_node, vyper_ast.VyperNode): + class_name = vyper_ast_node.__class__.__name__ + if hasattr(python_ast, class_name): + py_klass = getattr(python_ast, class_name) + return py_klass(**{ + k: to_python_ast( + getattr(vyper_ast_node, k, None) + ) + for k in vyper_ast_node.get_slots() + }) + else: + raise CompilerPanic(f'Unknown vyper AST class "{class_name}" provided.') + else: + return vyper_ast_node + + +def ast_to_string(vyper_ast_node: vyper_ast.VyperNode) -> str: + py_ast_node = to_python_ast(vyper_ast_node) + return python_ast.dump( + python_ast.Module( + body=py_ast_node + ) + ) diff --git a/vyper/compiler.py b/vyper/compiler.py index fbc6ab3013..45ec10a9cf 100644 --- a/vyper/compiler.py +++ b/vyper/compiler.py @@ -7,6 +7,9 @@ compile_lll, optimizer, ) +from vyper.ast_utils import ( + ast_to_dict, +) from vyper.opcodes import ( opcodes, ) @@ -186,8 +189,17 @@ def _mk_opcodes_runtime(code, contract_name, interface_codes): return get_opcodes(code, contract_name, bytecodes_runtime=True, interface_codes=interface_codes) +def _mk_ast_dict(code, contract_name, interface_codes): + o = { + 'contract_name': contract_name, + 'ast': ast_to_dict(parser.parse_to_ast(code)) + } + return o + + output_formats_map = { 'abi': _mk_abi_output, + 'ast_dict': _mk_ast_dict, 'bytecode': _mk_bytecode_output, 'bytecode_runtime': _mk_bytecode_runtime_output, 'ir': _mk_ir_output, diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 0e23a018a8..5acff7ca83 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -84,6 +84,10 @@ class VersionException(ParserException): pass +class SyntaxException(ParserException): + pass + + class CompilerPanic(Exception): def __init__(self, message): diff --git a/vyper/functions/functions.py b/vyper/functions/functions.py index f468f9077e..ab3751ef8e 100644 --- a/vyper/functions/functions.py +++ b/vyper/functions/functions.py @@ -1,6 +1,6 @@ -import ast import hashlib +from vyper import ast from vyper.exceptions import ( ConstancyViolationException, InvalidLiteralException, diff --git a/vyper/functions/signatures.py b/vyper/functions/signatures.py index 60751cadd9..223cdb27eb 100644 --- a/vyper/functions/signatures.py +++ b/vyper/functions/signatures.py @@ -1,6 +1,9 @@ -import ast import functools +from vyper import ast +from vyper.ast_utils import ( + parse_to_ast, +) from vyper.exceptions import ( InvalidLiteralException, StructureException, @@ -74,7 +77,7 @@ def process_arg(index, arg, expected_arg_typelist, function_name, context): else: # Does not work for unit-endowed types inside compound types, e.g. timestamp[2] parsed_expected_type = context.parse_type( - ast.parse(expected_arg).body[0].value, + parse_to_ast(expected_arg)[0].value, 'memory', ) if isinstance(parsed_expected_type, BaseType): diff --git a/vyper/parser/constants.py b/vyper/parser/constants.py index 9a7e1cd542..c38a5327d2 100644 --- a/vyper/parser/constants.py +++ b/vyper/parser/constants.py @@ -1,6 +1,6 @@ -import ast import copy +from vyper import ast from vyper.exceptions import ( StructureException, TypeMismatchException, diff --git a/vyper/parser/events.py b/vyper/parser/events.py index d884805384..8b25358786 100644 --- a/vyper/parser/events.py +++ b/vyper/parser/events.py @@ -1,5 +1,4 @@ -import ast - +from vyper import ast from vyper.exceptions import ( InvalidLiteralException, TypeMismatchException, diff --git a/vyper/parser/expr.py b/vyper/parser/expr.py index e6b3bef0f5..201f909186 100644 --- a/vyper/parser/expr.py +++ b/vyper/parser/expr.py @@ -1,6 +1,6 @@ -import ast import warnings +from vyper import ast from vyper.exceptions import ( InvalidLiteralException, NonPayableViolationException, @@ -78,7 +78,7 @@ def __init__(self, expr, context): if expr_type in self.expr_table: self.lll_node = self.expr_table[expr_type]() else: - raise Exception("Unsupported operator: %r" % ast.dump(self.expr)) + raise Exception("Unsupported operator.", self.expr) def get_expr(self): return self.expr @@ -398,7 +398,7 @@ def attribute(self): def subscript(self): sub = Expr.parse_variable_location(self.expr.value, self.context) if isinstance(sub.typ, (MappingType, ListType)): - if 'value' not in vars(self.expr.slice): + if not isinstance(self.expr.slice, ast.Index): raise StructureException( "Array access must access a single element, not a slice", self.expr, @@ -467,7 +467,7 @@ def arithmetic(self): self.expr, ) - num = ast.Num(val) + num = ast.Num(n=val) num.source_code = self.expr.source_code num.lineno = self.expr.lineno num.col_offset = self.expr.col_offset @@ -964,7 +964,7 @@ def unary_operations(self): ) if operand.typ.is_literal and 'int' in operand.typ.typ: - num = ast.Num(0 - operand.value) + num = ast.Num(n=0 - operand.value) num.source_code = self.expr.source_code num.lineno = self.expr.lineno num.col_offset = self.expr.col_offset @@ -1069,7 +1069,7 @@ def dict_fail(self): " favor of named structs, see VIP300", DeprecationWarning ) - raise InvalidLiteralException("Invalid literal: %r" % ast.dump(self.expr), self.expr) + raise InvalidLiteralException("Invalid literal.", self.expr) @staticmethod def struct_literals(expr, name, context): @@ -1078,7 +1078,7 @@ def struct_literals(expr, name, context): for key, value in zip(expr.keys, expr.values): if not isinstance(key, ast.Name): raise TypeMismatchException( - "Invalid member variable for struct: %r" % vars(key).get('id', key), + "Invalid member variable for struct: %r" % getattr(key, 'id', ''), key, ) check_valid_varname( diff --git a/vyper/parser/external_call.py b/vyper/parser/external_call.py index bc2980b307..a894c7e9bf 100644 --- a/vyper/parser/external_call.py +++ b/vyper/parser/external_call.py @@ -1,5 +1,4 @@ -import ast - +from vyper import ast from vyper.exceptions import ( FunctionDeclarationException, StructureException, @@ -183,4 +182,4 @@ def make_external_call(stmt_expr, context): ) else: - raise StructureException("Unsupported operator: %r" % ast.dump(stmt_expr), stmt_expr) + raise StructureException("Unsupported operator.", stmt_expr) diff --git a/vyper/parser/global_context.py b/vyper/parser/global_context.py index 8e1e86af1f..f873f4207d 100644 --- a/vyper/parser/global_context.py +++ b/vyper/parser/global_context.py @@ -1,5 +1,7 @@ -import ast - +from vyper import ast +from vyper.ast_utils import ( + parse_to_ast, +) from vyper.exceptions import ( EventDeclarationException, FunctionDeclarationException, @@ -12,7 +14,6 @@ Constants, ) from vyper.parser.parser_utils import ( - annotate_and_optimize_ast, getpos, ) from vyper.signatures.function_signature import ( @@ -243,9 +244,7 @@ def mk_getter(cls, varname, typ): # Parser for a single line @staticmethod def parse_line(code): - parsed_ast = ast.parse(code).body[0] - annotate_and_optimize_ast(parsed_ast, code) - + parsed_ast = parse_to_ast(code)[0] return parsed_ast # A struct is a list of members diff --git a/vyper/parser/parser.py b/vyper/parser/parser.py index 18f942d3f2..2e327f0093 100644 --- a/vyper/parser/parser.py +++ b/vyper/parser/parser.py @@ -1,14 +1,15 @@ -import ast from typing import ( Any, List, - cast, ) +from vyper import ast +from vyper.ast_utils import ( + parse_to_ast, +) from vyper.exceptions import ( EventDeclarationException, FunctionDeclarationException, - ParserException, StructureException, ) from vyper.parser.function_definitions import ( @@ -22,12 +23,6 @@ from vyper.parser.lll_node import ( LLLnode, ) -from vyper.parser.parser_utils import ( - annotate_and_optimize_ast, -) -from vyper.parser.pre_parser import ( - pre_parse, -) from vyper.signatures import ( sig_utils, ) @@ -69,28 +64,6 @@ INITIALIZER_LLL = LLLnode.from_list(INITIALIZER_LIST, typ=None) -def parse_to_ast(source_code: str) -> List[ast.stmt]: - """ - Parses the given vyper source code and returns a list of python AST objects - for all statements in the source. Performs pre-processing of source code - before parsing as well as post-processing of the resulting AST. - - :param source_code: The vyper source code to be parsed. - :return: The post-processed list of python AST objects for each statement in - ``source_code``. - """ - class_types, reformatted_code = pre_parse(source_code) - - if '\x00' in reformatted_code: - raise ParserException('No null bytes (\\x00) allowed in the source code.') - - # The return type depends on the parse mode which is why we need to cast here - parsed_ast = cast(ast.Module, ast.parse(reformatted_code)) - annotate_and_optimize_ast(parsed_ast, reformatted_code, class_types) - - return parsed_ast.body - - def parse_events(sigs, global_ctx): for event in global_ctx._events: sigs[event.target.id] = EventSignature.from_declaration(event, global_ctx) diff --git a/vyper/parser/parser_utils.py b/vyper/parser/parser_utils.py index 3dd5b7c24b..cace2392df 100644 --- a/vyper/parser/parser_utils.py +++ b/vyper/parser/parser_utils.py @@ -1,4 +1,4 @@ -import ast +import ast as python_ast from typing import ( Any, List, @@ -6,6 +6,7 @@ Union, ) +from vyper import ast from vyper.exceptions import ( InvalidLiteralException, StructureException, @@ -712,16 +713,16 @@ def make_setter(left, right, location, pos, in_function_call=False): raise Exception("Invalid type for setters") -def is_return_from_function(node: Union[ast.AST, List[Any]]) -> bool: +def is_return_from_function(node: Union[python_ast.AST, List[Any]]) -> bool: is_selfdestruct = ( - isinstance(node, ast.Expr) - and isinstance(node.value, ast.Call) - and isinstance(node.value.func, ast.Name) + isinstance(node, python_ast.Expr) + and isinstance(node.value, python_ast.Call) + and isinstance(node.value.func, python_ast.Name) and node.value.func.id == 'selfdestruct' ) - if isinstance(node, ast.Return): + if isinstance(node, python_ast.Return): return True - elif isinstance(node, ast.Raise): + elif isinstance(node, python_ast.Raise): return True elif is_selfdestruct: return True @@ -729,12 +730,13 @@ def is_return_from_function(node: Union[ast.AST, List[Any]]) -> bool: return False -class AnnotatingVisitor(ast.NodeTransformer): +class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _class_types: ClassTypes def __init__(self, source_code: str, class_types: Optional[ClassTypes] = None): - self._source_code = source_code + self._source_code: str = source_code + self.counter: int = 0 if class_types is not None: self._class_types = class_types else: @@ -744,6 +746,8 @@ def generic_visit(self, node): # Decorate every node in the AST with the original source code. This is # necessary to facilitate error pretty-printing. node.source_code = self._source_code + node.node_id = self.counter + self.counter += 1 return super().generic_visit(node) @@ -756,30 +760,29 @@ def visit_ClassDef(self, node): return node -class RewriteUnarySubVisitor(ast.NodeTransformer): +class RewriteUnarySubVisitor(python_ast.NodeTransformer): def visit_UnaryOp(self, node): self.generic_visit(node) - - if isinstance(node.op, ast.USub) and isinstance(node.operand, ast.Num): + if isinstance(node.op, python_ast.USub) and isinstance(node.operand, python_ast.Num): node.operand.n = 0 - node.operand.n return node.operand else: return node -class EnsureSingleExitChecker(ast.NodeVisitor): +class EnsureSingleExitChecker(python_ast.NodeVisitor): - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + def visit_FunctionDef(self, node: python_ast.FunctionDef) -> None: self.generic_visit(node) self.check_return_body(node, node.body) - def visit_If(self, node: ast.If) -> None: + def visit_If(self, node: python_ast.If) -> None: self.generic_visit(node) self.check_return_body(node, node.body) if node.orelse: self.check_return_body(node, node.orelse) - def check_return_body(self, node: ast.AST, node_list: List[Any]) -> None: + def check_return_body(self, node: python_ast.AST, node_list: List[Any]) -> None: return_count = len([n for n in node_list if is_return_from_function(n)]) if return_count > 1: raise StructureException( @@ -797,17 +800,17 @@ def check_return_body(self, node: ast.AST, node_list: List[Any]) -> None: ) -class UnmatchedReturnChecker(ast.NodeVisitor): +class UnmatchedReturnChecker(python_ast.NodeVisitor): """ Make sure all return statement are balanced (both branches of if statement should have returns statements). """ - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + def visit_FunctionDef(self, node: python_ast.FunctionDef) -> None: self.generic_visit(node) self.handle_primary_function_def(node) - def handle_primary_function_def(self, node: ast.FunctionDef) -> None: + def handle_primary_function_def(self, node: python_ast.FunctionDef) -> None: if node.returns and not self.return_check(node.body): raise StructureException( f'Missing or Unmatched return statements in function "{node.name}". ' @@ -815,12 +818,12 @@ def handle_primary_function_def(self, node: ast.FunctionDef) -> None: node ) - def return_check(self, node: Union[ast.AST, List[Any]]) -> bool: + def return_check(self, node: Union[python_ast.AST, List[Any]]) -> bool: if is_return_from_function(node): return True elif isinstance(node, list): return any(self.return_check(stmt) for stmt in node) - elif isinstance(node, ast.If): + elif isinstance(node, python_ast.If): if_body_check = self.return_check(node.body) else_body_check = self.return_check(node.orelse) if if_body_check and else_body_check: # both side need to match. @@ -830,8 +833,8 @@ def return_check(self, node: Union[ast.AST, List[Any]]) -> bool: return False -def annotate_and_optimize_ast( - parsed_ast: ast.Module, +def annotate_ast( + parsed_ast: Union[python_ast.AST, python_ast.Module], source_code: str, class_types: Optional[ClassTypes] = None, ) -> None: diff --git a/vyper/parser/stmt.py b/vyper/parser/stmt.py index f34bfd2db7..2125ce2be3 100644 --- a/vyper/parser/stmt.py +++ b/vyper/parser/stmt.py @@ -1,6 +1,9 @@ -import ast import re +from vyper import ast +from vyper.ast_utils import ( + ast_to_dict, +) from vyper.exceptions import ( ConstancyViolationException, EventDeclarationException, @@ -348,7 +351,7 @@ def parse_if(self): def _clear(self): # Create zero node - none = ast.NameConstant(None) + none = ast.NameConstant(value=None) none.lineno = self.stmt.lineno none.col_offset = self.stmt.col_offset zero = Expr(none, self.context).lll_node @@ -549,15 +552,15 @@ def parse_for(self): arg1, ) - if ast.dump(arg0) != ast.dump(arg1.left): + if arg0 != arg1.left: raise StructureException( ( "Two-arg for statements of the form `for i in " "range(x, x + y): ...` must have x identical in both " "places: %r %r" ) % ( - ast.dump(arg0), - ast.dump(arg1.left) + ast_to_dict(arg0), + ast_to_dict(arg1.left) ), self.stmt.iter, ) diff --git a/vyper/signatures/event_signature.py b/vyper/signatures/event_signature.py index 99e64a4f9c..bc289ae792 100644 --- a/vyper/signatures/event_signature.py +++ b/vyper/signatures/event_signature.py @@ -1,5 +1,4 @@ -import ast - +from vyper import ast from vyper.exceptions import ( EventDeclarationException, InvalidTypeException, diff --git a/vyper/signatures/function_signature.py b/vyper/signatures/function_signature.py index 10706ae07a..e4ff1c3197 100644 --- a/vyper/signatures/function_signature.py +++ b/vyper/signatures/function_signature.py @@ -1,8 +1,11 @@ -import ast from collections import ( Counter, ) +from vyper import ast +from vyper.ast_utils import ( + to_python_ast, +) from vyper.exceptions import ( FunctionDeclarationException, InvalidTypeException, @@ -415,5 +418,5 @@ def is_initializer(self): def validate_return_statement_balance(self): # Run balanced return statement check. - UnmatchedReturnChecker().visit(self.func_ast_code) - EnsureSingleExitChecker().visit(self.func_ast_code) + UnmatchedReturnChecker().visit(to_python_ast(self.func_ast_code)) + EnsureSingleExitChecker().visit(to_python_ast(self.func_ast_code)) diff --git a/vyper/signatures/interface.py b/vyper/signatures/interface.py index ee4fa511a4..da22cff875 100644 --- a/vyper/signatures/interface.py +++ b/vyper/signatures/interface.py @@ -1,9 +1,9 @@ -import ast import copy import importlib import os import pkgutil +from vyper import ast from vyper.exceptions import ( ParserException, StructureException, @@ -48,17 +48,17 @@ def render_return(sig): def abi_type_to_ast(atype): if atype in ('int128', 'uint256', 'bool', 'address', 'bytes32'): - return ast.Name(atype, None) + return ast.Name(id=atype) elif atype == 'decimal': - return ast.Name('int128', None) + return ast.Name(id='int128') elif atype == 'bytes': return ast.Subscript( - value=ast.Name('bytes', None), + value=ast.Name(id='bytes'), slice=ast.Index(256) ) elif atype == 'string': return ast.Subscript( - value=ast.Name('string', None), + value=ast.Name(id='string'), slice=ast.Index(256) ) else: @@ -91,11 +91,11 @@ def mk_full_signature_from_json(abi): ] ) - decorator_list = [ast.Name('public', None)] + decorator_list = [ast.Name(id='public')] if func['constant']: - decorator_list.append(ast.Name('constant', None)) + decorator_list.append(ast.Name(id='constant')) if func['payable']: - decorator_list.append(ast.Name('payable', None)) + decorator_list.append(ast.Name(id='payable')) sig = FunctionSignature.from_definition( code=ast.FunctionDef( diff --git a/vyper/types/convert.py b/vyper/types/convert.py index 55afe83094..989e2714bb 100644 --- a/vyper/types/convert.py +++ b/vyper/types/convert.py @@ -1,7 +1,7 @@ -import ast import math import warnings +from vyper import ast from vyper.exceptions import ( InvalidLiteralException, ParserException, diff --git a/vyper/types/types.py b/vyper/types/types.py index 3767ab8a8e..aca95241b5 100644 --- a/vyper/types/types.py +++ b/vyper/types/types.py @@ -1,5 +1,4 @@ import abc -import ast from collections import ( OrderedDict, ) @@ -9,6 +8,7 @@ ) import warnings +from vyper import ast from vyper.exceptions import ( InvalidTypeException, ) @@ -429,8 +429,7 @@ def parse_type(item, location, sigs=None, custom_units=None, custom_structs=None return BaseType(base_type, unit, positional) # Subscripts elif isinstance(item, ast.Subscript): - - if 'value' not in vars(item.slice): + if isinstance(item.slice, ast.Slice): raise InvalidTypeException( "Array / ByteArray access must access a single element, not a slice", item, @@ -478,7 +477,7 @@ def parse_type(item, location, sigs=None, custom_units=None, custom_structs=None " favor of named structs, see VIP300", DeprecationWarning ) - raise InvalidTypeException("Invalid type: %r" % ast.dump(item), item) + raise InvalidTypeException("Invalid type", item) elif isinstance(item, ast.Tuple): members = [ parse_type( @@ -491,7 +490,7 @@ def parse_type(item, location, sigs=None, custom_units=None, custom_structs=None ] return TupleType(members) else: - raise InvalidTypeException("Invalid type: %r" % ast.dump(item), item) + raise InvalidTypeException("Invalid type", item) # Gets the number of memory or storage keys needed to represent a given type diff --git a/vyper/utils.py b/vyper/utils.py index 55ed4c16b2..4d1136227a 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -2,6 +2,7 @@ from collections import ( OrderedDict, ) +import functools import re from vyper.exceptions import ( @@ -258,3 +259,12 @@ def check_valid_varname(varname, def is_instances(instances, instance_type): return all([isinstance(inst, instance_type) for inst in instances]) + + +def iterable_cast(cast_type): + def yf(func): + @functools.wraps(func) + def f(*args, **kwargs): + return cast_type(func(*args, **kwargs)) + return f + return yf