Skip to content

Commit

Permalink
fix issue with indirect type loading in using * directives
Browse files Browse the repository at this point in the history
  • Loading branch information
t81lal committed Aug 1, 2024
1 parent d83403c commit 62a07a7
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 41 deletions.
23 changes: 9 additions & 14 deletions src/solidity_parser/ast/ast2builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,11 +750,16 @@ def scopes_for_type(self, node: solnodes1.AST1Node, ttype: solnodes2.Types, use_
scope = ttype.value.x.scope
assert isinstance(scope, symtab.FileScope)
return [scope]

user_scopes = node.scope.find_user_type_scope(ttype.value.x.name.text)

filtered_scopes = []
try:
user_scopes = node.scope.find_user_type_scope(ttype.value.x.name.text)
except symtab.TypeNotFound:
# Weird situation where an object of type T is used in a contract during an intermediate computation but
# T isn't imported. E.g. (x.y).z = f() where x.y is of type T and f returns an object of type S but
# T isn't imported in the contract. This is the AddressSlot case in ERC1967Upgrade
return [ttype.scope]

filtered_scopes = []
for s in user_scopes:
if s is not None and s.value != ttype.scope.value:
if s.res_syms_single() != ttype.scope.res_syms_single():
Expand All @@ -763,23 +768,13 @@ def scopes_for_type(self, node: solnodes1.AST1Node, ttype: solnodes2.Types, use_
continue
filtered_scopes.append(s)

user_scopes = filtered_scopes

if not user_scopes:
# Weird situation where an object of type T is used in a contract during an intermediate computation but
# T isn't imported. E.g. (x.y).z = f() where x.y is of type T and f returns an object of type S but
# T isn't imported in the contract. This is the AddressSlot case in ERC1967Upgrade
scopes = [ttype.scope]
else:
scopes = user_scopes
# "Prior to version 0.5.0, Solidity allowed address members to be accessed by a contract instance, for
# example this.balance. This is now forbidden and an explicit conversion to address must be done:
# address(this).balance"
# TODO: add versioncheck
# if ttype.value.x.is_contract():
# scopes.append(scope.find_type(solnodes1.AddressType(False)))

return scopes
return filtered_scopes
elif isinstance(ttype, soltypes.BuiltinType):
scope = node.scope.find_single(ttype.name)
elif isinstance(ttype, soltypes.MetaTypeType):
Expand Down
9 changes: 6 additions & 3 deletions src/solidity_parser/ast/solnodes2.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def __str__(self):
def __repr__(self):
return self.__str__()

def type_key(self, *args, **kwargs):
return self.value.x.name.text

def code_str(self):
return self.value.x.name.text

def is_builtin(self) -> bool:
return False

Expand All @@ -81,9 +87,6 @@ def can_implicitly_cast_from(self, actual_type: soltypes.Type) -> bool:
def get_types_for_declared_type(self) -> list['TopLevelUnit']:
return [self.value.x] + self.value.x.get_subtypes()

def code_str(self):
return self.value.x.name.text


@nodebase.NodeDataclass
class SuperType(soltypes.Type):
Expand Down
33 changes: 27 additions & 6 deletions src/solidity_parser/ast/symtab.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ def _add_to_results(possible_symbols: Collection, results: list, found_already:
results.append(s)


class TypeNotFound(Exception):
pass

class Scopeable:
"""Element that can be added as a child of a Scope"""

Expand Down Expand Up @@ -393,6 +396,7 @@ def find_user_type_scope(self, name, find_base_symbol: bool = False, default=Non
:param predicate: Optional function to filter during the search
:return: A single scope if find_base_symbol is True, or a list of scopes if find_base_symbol is False
:raises TypeNotFound: if no type matching the name was found
"""

if not default and not find_base_symbol:
Expand Down Expand Up @@ -463,6 +467,9 @@ def do_search(using_mode):
if not result:
result = default

if not result:
raise TypeNotFound(f"Could not find scope for type: '{name}'")

return result.res_syms_single() if find_base_symbol else result

# These don't work because Pycharm is broken: See PY-42137
Expand Down Expand Up @@ -640,12 +647,17 @@ def create_builtin_scope(key, value=None, values=None, functions=None):
return scope


def type_key(ttype) -> str:
return f'<type:{ttype.type_key()}>'
def simple_name_resolver(scope, name):
type_scope = scope.find_user_type_scope(name, find_base_symbol=True)
return type_scope.aliases[0]


def type_key(ttype, name_resolver=simple_name_resolver) -> str:
return f'<type:{ttype.type_key(name_resolver)}>'

def meta_type_key(ttype) -> str:
return f'<metatype:{ttype.type_key()}>'

def meta_type_key(ttype, name_resolver=simple_name_resolver) -> str:
return f'<metatype:{ttype.type_key(name_resolver)}>'


class RootScope(Scope):
Expand Down Expand Up @@ -720,7 +732,7 @@ def __init__(self, parser_version: version_util.Version):
def address_object(payable):
# key is <type: address> or <type: address payable>
t = soltypes.AddressType(payable)
scope = BuiltinObject(type_key(t), t)
scope = BuiltinObject(type_key(t, None), t)
scope.add(BuiltinValue('balance', uint()))
scope.add(BuiltinValue('code', bytes()))
scope.add(BuiltinValue('codehash', bytes32()))
Expand Down Expand Up @@ -1258,6 +1270,7 @@ def scope_name(self, base_name, node):
return f'<{base_name}>@{node.location}'

def find_using_target_scope_and_name(self, current_scope, target_type: soltypes.Type):

# TODO: Not sure if this is possible and I don't want to handle it(yet), just want to handle Types
# for target_type
if isinstance(target_type, solnodes.Ident):
Expand Down Expand Up @@ -1399,7 +1412,15 @@ def process_using_any_type(self, context: Context, node: solnodes.UsingDirective
if not func_def.parameters:
continue
target_type = func_def.parameters[0].var_type
target_type_scope, target_scope_name = self.find_using_target_scope_and_name(cur_scope, target_type)
try:
target_type_scope, target_scope_name = self.find_using_target_scope_and_name(cur_scope, target_type)
except TypeNotFound:
# when a contract has a 'using L for *' directive, it doesn't need to import any of the types that will
# be bound to by the first parameter of the library's functions => the target type lookup in the current
# scope might fail(if it wasnt imported). In this case, finding the base scope of the type by looking in
# the library scope(don't take any proxy scopes here as indirect using directives aren't
# inherited/imported).
target_type_scope, target_scope_name = self.find_using_target_scope_and_name(library_scope, target_type)

scope_to_add_to = self.get_proxy_scope_for_type(cur_scope, target_type, target_scope_name, target_type_scope, library_scope, check_lib=False)

Expand Down
75 changes: 58 additions & 17 deletions src/solidity_parser/ast/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def is_void(self) -> bool:
"""
return False

def type_key(self):
def type_key(self, *args, **kwargs):
""" Returns a unique key for the type that can be used to cache types in the symbol table """
return self.code_str()

Expand Down Expand Up @@ -122,6 +122,9 @@ def can_implicitly_cast_from(self, actual_type: 'Type') -> bool:
return True
return actual_type.is_float()

def type_key(self, *args, **kwargs):
raise ValueError('Float types do not have a type key')


@NodeDataclass
class VoidType(Type):
Expand All @@ -137,6 +140,9 @@ def code_str(self):
def __str__(self):
return '<void>'

def type_key(self, *args, **kwargs):
raise ValueError('Void types do not have a type key')


@NodeDataclass
class ArrayType(Type):
Expand All @@ -146,6 +152,9 @@ class ArrayType(Type):
def __str__(self): return f"{self.base_type}[]"

def code_str(self): return f'{self.base_type.code_str()}[]'

def type_key(self, name_resolver=None, *args, **kwargs):
return f"{self.base_type.type_key(name_resolver, *args, **kwargs)}[]"

def is_builtin(self) -> bool:
# e.g. byte[] is builtin, string[] is builtin, MyContract[] is not
Expand Down Expand Up @@ -183,6 +192,12 @@ class FixedLengthArrayType(ArrayType):

def __str__(self): return f"{self.base_type}[{self.size}]"

def code_str(self):
return f'{self.base_type.code_str()}[{str(self.size)}]'

def type_key(self, name_resolver=None, *args, **kwargs):
return f"{self.base_type.type_key(name_resolver, *args, **kwargs)}[{str(self.size)}]"

def is_fixed_size(self) -> bool:
return True

Expand All @@ -200,9 +215,6 @@ def can_implicitly_cast_from(self, actual_type: 'Type') -> bool:

return False

def code_str(self):
return f'{self.base_type.code_str()}[{str(self.size)}]'


@NodeDataclass
class VariableLengthArrayType(ArrayType):
Expand All @@ -214,6 +226,10 @@ def __str__(self): return f"{self.base_type}[{self.size}]"
def code_str(self):
return f'{self.base_type.code_str()}[{self.size.code_str()}]'

def type_key(self, name_resolver=None, *args, **kwargs):
# the size bit is a bit tricky as it might not be a literal, just stringify it for now
return f"{self.base_type.type_key(name_resolver, *args, **kwargs)}[{self.size.code_str()}]"


@NodeDataclass
class AddressType(Type):
Expand All @@ -222,6 +238,9 @@ class AddressType(Type):

def __str__(self): return f"address{' payable' if self.is_payable else ''}"

def code_str(self):
return 'address' + (' payable' if self.is_payable else '')

def can_implicitly_cast_from(self, actual_type: Type) -> bool:
# address_payable(actual_type) can be cast to address implicitly
if actual_type.is_address():
Expand Down Expand Up @@ -257,9 +276,6 @@ def is_builtin(self) -> bool:
def is_address(self) -> bool:
return True

def code_str(self):
return 'address' + (' payable' if self.is_payable else '')


@NodeDataclass
class ByteType(Type):
Expand Down Expand Up @@ -309,11 +325,14 @@ def can_implicitly_cast_from(self, actual_type: 'Type') -> bool:
return super().can_implicitly_cast_from(actual_type)

def __str__(self):
return self.code_str()
return 'bytes'

def code_str(self):
return 'bytes'

def type_key(self, *args, **kwargs):
return 'bytes'


@NodeDataclass
class IntType(Type):
Expand Down Expand Up @@ -384,15 +403,18 @@ class StringType(ArrayType):

def __str__(self): return "string"

def code_str(self):
return 'string'

def type_key(self, *args, **kwargs):
return 'string'

def is_builtin(self) -> bool:
return True

def is_string(self) -> bool:
return True

def code_str(self):
return 'string'


@NodeDataclass
class PreciseStringType(StringType):
Expand Down Expand Up @@ -432,6 +454,12 @@ def _name(ident):
return (' ' + str(ident)) if ident else ''
return f"({self.src}{_name(self.src_name)} => {self.dst}{_name(self.dst_name)})"

def code_str(self):
return str(self)

def type_key(self, name_resolver=None, *args, **kwargs):
return f"({self.src.type_key(name_resolver, *args, **kwargs)} => {self.dst.type_key(name_resolver, *args, **kwargs)})"

def is_mapping(self) -> bool:
return True

Expand All @@ -450,9 +478,6 @@ def flatten(self) -> list[Type]:
next_link = None
return result

def code_str(self):
return str(self)


@NodeDataclass
class UserType(Type):
Expand All @@ -465,6 +490,12 @@ class UserType(Type):

def __str__(self): return str(self.name)

def type_key(self, name_resolver=None, *args, **kwargs):
if name_resolver is None:
raise ValueError(f'Cannot resolve {self.name} without a name resolver')
else:
return name_resolver(self.scope, self.name.text)


@NodeDataclass
class BuiltinType(Type):
Expand Down Expand Up @@ -530,10 +561,10 @@ def code_str(self):
def __str__(self):
return self.code_str()

def type_key(self):
def type_key(self, name_resolver=None, *args, **kwargs):
# doesn't include modifiers for now
input_params = ', '.join([p.type_key() for p in self.inputs])
output_params = ', '.join([p.type_key() for p in self.outputs])
input_params = ', '.join([p.type_key(name_resolver, *args, **kwargs) for p in self.inputs])
output_params = ', '.join([p.type_key(name_resolver, *args, **kwargs) for p in self.outputs])
return f'function ({input_params}) returns ({output_params})'


Expand All @@ -557,6 +588,9 @@ def code_str(self):
def __str__(self):
return f'({", ".join(str(t) for t in self.ttypes)})'

def type_key(self, name_resolver=None, *args, **kwargs):
return f'({", ".join(t.type_key(name_resolver, *args, **kwargs) for t in self.ttypes)})'


@NodeDataclass
class MetaTypeType(Type):
Expand All @@ -576,6 +610,8 @@ def code_str(self):
def __str__(self):
return f'type({self.ttype})'

# TODO: metatype typekey


@NodeDataclass
class VarType(Type):
Expand All @@ -590,6 +626,9 @@ class VarType(Type):

def __str__(self): return "var"

def type_key(self, name_resolver=None, *args, **kwargs):
raise ValueError('Var types do not have a type key')


@NodeDataclass
class AnyType(Type):
Expand All @@ -601,3 +640,5 @@ class AnyType(Type):

def __str__(self): return "*"

def type_key(self, name_resolver=None, *args, **kwargs):
raise ValueError('Any types do not have a type key')
6 changes: 6 additions & 0 deletions test/solidity_parser/ast/snapshots/snap_test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

snapshots = Snapshot()

snapshots['TestASTJSONCases::test_debug 1'] = [
GenericRepr("FileDefinition(source_unit_name='using_for_directive.sol', name=Ident(text='using_for_directive.sol'), parts=[FunctionDefinition(name=Ident(text='f'), inputs=[Parameter(var=Var(name=Ident(text=None), ttype=IntType(is_signed=False, size=256), location=None))], outputs=[], modifiers=[], code=Block(stmts=[], is_unchecked=False), markers=[])])"),
GenericRepr("LibraryDefinition(source_unit_name='using_for_directive.sol', name=Ident(text='L'), parts=[], type_overrides=[])"),
GenericRepr("ContractDefinition(source_unit_name='using_for_directive.sol', name=Ident(text='C'), is_abstract=False, inherits=[], parts=[], type_overrides=[])")
]

snapshots['TestASTJSONCases::test_success_path_abstract_contract_sol 1'] = [
GenericRepr("ContractDefinition(source_unit_name='abstract_contract.sol', name=Ident(text='C'), is_abstract=True, inherits=[], parts=[FunctionDefinition(name=Ident(text='constructor'), inputs=[], outputs=[], modifiers=[], code=Block(stmts=[], is_unchecked=False), markers=[<FunctionMarker.CONSTRUCTOR: 1>])], type_overrides=[])")
]
Expand Down
3 changes: 2 additions & 1 deletion test/solidity_parser/ast/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,13 @@ def test_ast_internal_function_different_ids_export(self):
self._load_separated_file('ast_internal_function_different_ids_export.sol')

# def test_debug(self):
# self._load('using_for_directive.sol')
# self._load('non_utf8.sol')
# units = self.ast2_builder.get_top_level_units()
# self.assertMatchSnapshot(units)
#
# print("x")


class TestSemanticTestCases(LibSolidityTestBase, SnapshotTestCase):
SRC_DIR = 'testcases/libsolidity/semanticTests'
def __init__(self, *args, **kwargs):
Expand Down
Loading

0 comments on commit 62a07a7

Please sign in to comment.