Skip to content

Commit

Permalink
Add typing.
Browse files Browse the repository at this point in the history
  • Loading branch information
jacqueswww committed Apr 15, 2019
1 parent 8d2f75f commit 9c33553
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions vyper/ast_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import ast as python_ast
from typing import (
Generator,
)

import vyper.ast as vyper_ast
from vyper.exceptions import (
Expand All @@ -16,12 +19,11 @@
iterable_cast,
)


DICT_AST_SKIPLIST = ('source_code', )


@iterable_cast(list)
def _build_vyper_ast_list(source_code, node):
def _build_vyper_ast_list(source_code: str, node: list) -> Generator:
for n in node:
yield parse_python_ast(
source_code=source_code,
Expand All @@ -30,14 +32,19 @@ def _build_vyper_ast_list(source_code, node):


@iterable_cast(dict)
def _build_vyper_ast_init_kwargs(source_code, node, vyper_class, class_name):
def _build_vyper_ast_init_kwargs(
source_code: str,
node: python_ast.AST,
vyper_class: vyper_ast.VyperNode,
class_name: str
) -> Generator:
yield ('col_offset', getattr(node, 'col_offset', None))
yield ('lineno', getattr(node, 'lineno', None))
yield ('node_id', node.node_id)
yield ('node_id', node.node_id) # type: ignore
yield ('source_code', source_code)

if isinstance(node, python_ast.ClassDef):
yield ('class_type', node.class_type)
yield ('class_type', node.class_type) # type: ignore

for field_name in node._fields:
val = getattr(node, field_name)
Expand All @@ -60,7 +67,7 @@ def _build_vyper_ast_init_kwargs(source_code, node, vyper_class, class_name):
)


def parse_python_ast(source_code, node):
def parse_python_ast(source_code: str, node: python_ast.Module) -> vyper_ast.Module:
if isinstance(node, list):
return _build_vyper_ast_list(source_code, node)
elif isinstance(node, python_ast.AST):
Expand All @@ -79,7 +86,7 @@ def parse_python_ast(source_code, node):
return node


def parse_to_ast(source_code):
def parse_to_ast(source_code: str) -> list:
if '\x00' in source_code:
raise ParserException('No null bytes (\\x00) allowed in the source code.')
class_types, reformatted_code = pre_parse(source_code)
Expand All @@ -90,17 +97,17 @@ def parse_to_ast(source_code):
source_code=source_code,
node=py_ast,
)
return vyper_ast.body
return vyper_ast.body # type: ignore


@iterable_cast(list)
def _ast_to_list(node):
def _ast_to_list(node: list) -> Generator:
for x in node:
yield ast_to_dict(x)


@iterable_cast(dict)
def _ast_to_dict(node):
def _ast_to_dict(node: vyper_ast.VyperNode) -> Generator:
for f in node.get_slots():
if f not in DICT_AST_SKIPLIST:
yield (f, ast_to_dict(getattr(node, f, None)))
Expand All @@ -118,7 +125,7 @@ def ast_to_dict(node: vyper_ast.VyperNode) -> dict:
raise CompilerPanic('Unknown vyper AST node provided.')


def dict_to_ast(ast_struct):
def dict_to_ast(ast_struct: dict) -> vyper_ast.VyperNode:
if isinstance(ast_struct, dict) and 'ast_type' in ast_struct:
vyper_class = getattr(vyper_ast, ast_struct['ast_type'])
klass = vyper_class(**{
Expand All @@ -138,7 +145,7 @@ def dict_to_ast(ast_struct):
raise CompilerPanic('Unknown ast_struct provided.')


def to_python_ast(vyper_ast_node):
def to_python_ast(vyper_ast_node: vyper_ast.VyperNode) -> python_ast.AST:
if isinstance(vyper_ast_node, list):
return [
to_python_ast(n)
Expand All @@ -154,11 +161,13 @@ def to_python_ast(vyper_ast_node):
)
for k in vyper_ast_node.get_slots()
})
else:
raise CompilerPanic(f'Unknown vyper AST class "{class_name}" provided.')
else:
return vyper_ast_node


def ast_to_string(vyper_ast_node):
def ast_to_string(vyper_ast_node: vyper_ast.VyperNode) -> str:
py_ast_node = to_python_ast(vyper_ast_node)
return python_ast.dump(
python_ast.Module(
Expand Down

0 comments on commit 9c33553

Please sign in to comment.