diff --git a/.flake8 b/.flake8 index f1520990..65e53ac8 100644 --- a/.flake8 +++ b/.flake8 @@ -2,3 +2,6 @@ exclude = .git,.tox,__pycache__,.eggs,dist,.venv*,docs,build max-line-length = 88 extend-ignore = W503,W504,E203 + +# in pyi stubs, spacing rules are different (black handles this) +per-file-ignores = *.pyi:E302,E305 diff --git a/pyproject.toml b/pyproject.toml index b2bbe0e1..01412980 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,10 @@ where = ["src"] namespaces = false [tool.setuptools.package-data] -globus_sdk = ["py.typed"] +globus_sdk = [ + "py.typed", + "__init__.pyi", +] "globus_sdk.login_flows.local_server_login_flow_manager.html_files" = ["*.html"] [tool.setuptools.dynamic.version] @@ -189,6 +192,9 @@ disable = [ "import-error", # "disallowed" usage of our own classes and objects gets underfoot "protected-access", + # incorrect mis-reporting of lazily loaded attributes makes this lint + # unusable + "no-name-in-module", # objections to log messages doing eager (vs lazy) string formatting # the perf benefit of deferred logging doesn't always outweigh the readability cost "logging-fstring-interpolation", "logging-format-interpolation", diff --git a/scripts/ensure_exports_are_documented.py b/scripts/ensure_exports_are_documented.py index 0ce77713..095d77f3 100755 --- a/scripts/ensure_exports_are_documented.py +++ b/scripts/ensure_exports_are_documented.py @@ -14,7 +14,7 @@ _ALL_NAME_PATTERN = re.compile(r'\s+"(\w+)",?') PACKAGE_LOCS_TO_SCAN = ( - "globus_sdk/", + "globus_sdk/__init__.pyi", "globus_sdk/login_flows/", "globus_sdk/gare/", "globus_sdk/globus_app/", diff --git a/src/globus_sdk/__init__.py b/src/globus_sdk/__init__.py index 40d4cc69..d95f12e9 100644 --- a/src/globus_sdk/__init__.py +++ b/src/globus_sdk/__init__.py @@ -1,8 +1,12 @@ import logging import sys -import typing as t -from .version import __version__ +from ._lazy_import import ( + default_dir_implementation, + default_getattr_implementation, + load_all_tuple, +) +from .version import __version__ # noqa: F401 def _force_eager_imports() -> None: @@ -12,281 +16,19 @@ def _force_eager_imports() -> None: getattr(current_module, attr) -if t.TYPE_CHECKING: - from .authorizers import ( - AccessTokenAuthorizer, - BasicAuthorizer, - ClientCredentialsAuthorizer, - NullAuthorizer, - RefreshTokenAuthorizer, - ) - from .client import BaseClient - from .exc import ( - ErrorSubdocument, - GlobusAPIError, - GlobusConnectionError, - GlobusConnectionTimeoutError, - GlobusError, - GlobusSDKUsageError, - GlobusTimeoutError, - NetworkError, - RemovedInV4Warning, - ValidationError, - ) - from .globus_app import ClientApp, GlobusApp, GlobusAppConfig, UserApp - from .local_endpoint import ( - GlobusConnectPersonalOwnerInfo, - LocalGlobusConnectPersonal, - LocalGlobusConnectServer, - ) - from .response import ArrayResponse, GlobusHTTPResponse, IterableResponse - from .scopes import Scope, ScopeCycleError, ScopeParseError - from .services.auth import ( - AuthAPIError, - AuthClient, - AuthLoginClient, - ConfidentialAppAuthClient, - DependentScopeSpec, - GetConsentsResponse, - GetIdentitiesResponse, - IdentityMap, - NativeAppAuthClient, - OAuthAuthorizationCodeResponse, - OAuthClientCredentialsResponse, - OAuthDependentTokenResponse, - OAuthRefreshTokenResponse, - OAuthTokenResponse, - ) - from .services.compute import ( - ComputeAPIError, - ComputeClient, - ComputeClientV2, - ComputeClientV3, - ComputeFunctionDocument, - ComputeFunctionMetadata, - ) - from .services.flows import ( - FlowsAPIError, - FlowsClient, - IterableFlowsResponse, - SpecificFlowClient, - ) - from .services.gcs import ( - ActiveScaleStoragePolicies, - AzureBlobStoragePolicies, - BlackPearlStoragePolicies, - BoxStoragePolicies, - CephStoragePolicies, - CollectionDocument, - CollectionPolicies, - ConnectorTable, - EndpointDocument, - GCSAPIError, - GCSClient, - GCSRoleDocument, - GlobusConnectServerConnector, - GoogleCloudStorageCollectionPolicies, - GoogleCloudStoragePolicies, - GoogleDriveStoragePolicies, - GuestCollectionDocument, - HPSSStoragePolicies, - IrodsStoragePolicies, - IterableGCSResponse, - MappedCollectionDocument, - OneDriveStoragePolicies, - POSIXCollectionPolicies, - POSIXStagingCollectionPolicies, - POSIXStagingStoragePolicies, - POSIXStoragePolicies, - S3StoragePolicies, - StorageGatewayDocument, - StorageGatewayPolicies, - UnpackingGCSResponse, - UserCredentialDocument, - ) - from .services.groups import ( - BatchMembershipActions, - GroupMemberVisibility, - GroupPolicies, - GroupRequiredSignupFields, - GroupRole, - GroupsAPIError, - GroupsClient, - GroupsManager, - GroupVisibility, - ) - from .services.search import ( - SearchAPIError, - SearchClient, - SearchQuery, - SearchQueryV1, - SearchScrollQuery, - ) - from .services.timer import TimerAPIError, TimerClient - from .services.timers import ( - OnceTimerSchedule, - RecurringTimerSchedule, - TimerJob, - TimersAPIError, - TimersClient, - TransferTimer, - ) - from .services.transfer import ( - ActivationRequirementsResponse, - DeleteData, - IterableTransferResponse, - TransferAPIError, - TransferClient, - TransferData, - ) - from .utils import MISSING, MissingType - -else: - - def __dir__() -> t.List[str]: - # dir(globus_sdk) should include everything exported in __all__ - # as well as some explicitly selected attributes from the default dir() output - # on a module - # - # see also: - # https://discuss.python.org/t/how-to-properly-extend-standard-dir-search-with-module-level-dir/4202 - return list(__all__) + [ - # __all__ itself can be inspected - "__all__", - # useful to figure out where a package is installed - "__file__", - "__path__", - ] - - def __getattr__(name: str) -> t.Any: - from ._lazy_import import load_attr +# +# all lazy SDK attributes are defined in __init__.pyi +# +# to add an attribute, write the relevant import in `__init__.pyi` and update +# the `__all__` tuple there +# +__all__ = load_all_tuple(__name__, "__init__.pyi") +__getattr__ = default_getattr_implementation(__name__, "__init__.pyi") +__dir__ = default_dir_implementation(__name__) - if name in __all__: - value = load_attr(__name__, name) - setattr(sys.modules[__name__], name, value) - return value - - raise AttributeError(f"module {__name__} has no attribute {name}") - - -__all__ = ( - "__version__", - "_force_eager_imports", - "AccessTokenAuthorizer", - "ActivationRequirementsResponse", - "ActiveScaleStoragePolicies", - "ArrayResponse", - "AuthAPIError", - "AuthClient", - "AuthLoginClient", - "AzureBlobStoragePolicies", - "BaseClient", - "BasicAuthorizer", - "BatchMembershipActions", - "BlackPearlStoragePolicies", - "BoxStoragePolicies", - "CephStoragePolicies", - "ClientApp", - "ClientCredentialsAuthorizer", - "CollectionDocument", - "CollectionPolicies", - "ComputeAPIError", - "ComputeClient", - "ComputeClientV2", - "ComputeClientV3", - "ComputeFunctionDocument", - "ComputeFunctionMetadata", - "ConfidentialAppAuthClient", - "ConnectorTable", - "DeleteData", - "DependentScopeSpec", - "EndpointDocument", - "ErrorSubdocument", - "FlowsAPIError", - "FlowsClient", - "GCSAPIError", - "GCSClient", - "GCSRoleDocument", - "GetConsentsResponse", - "GetIdentitiesResponse", - "GlobusAPIError", - "GlobusApp", - "GlobusAppConfig", - "GlobusConnectPersonalOwnerInfo", - "GlobusConnectServerConnector", - "GlobusConnectionError", - "GlobusConnectionTimeoutError", - "GlobusError", - "GlobusHTTPResponse", - "GlobusSDKUsageError", - "GlobusTimeoutError", - "GoogleCloudStorageCollectionPolicies", - "GoogleCloudStoragePolicies", - "GoogleDriveStoragePolicies", - "GroupMemberVisibility", - "GroupPolicies", - "GroupRequiredSignupFields", - "GroupRole", - "GroupVisibility", - "GroupsAPIError", - "GroupsClient", - "GroupsManager", - "GuestCollectionDocument", - "HPSSStoragePolicies", - "IdentityMap", - "IrodsStoragePolicies", - "IterableFlowsResponse", - "IterableGCSResponse", - "IterableResponse", - "IterableTransferResponse", - "LocalGlobusConnectPersonal", - "LocalGlobusConnectServer", - "MISSING", - "MappedCollectionDocument", - "MissingType", - "NativeAppAuthClient", - "NetworkError", - "NullAuthorizer", - "OAuthAuthorizationCodeResponse", - "OAuthClientCredentialsResponse", - "OAuthDependentTokenResponse", - "OAuthRefreshTokenResponse", - "OAuthTokenResponse", - "OnceTimerSchedule", - "OneDriveStoragePolicies", - "POSIXCollectionPolicies", - "POSIXStagingCollectionPolicies", - "POSIXStagingStoragePolicies", - "POSIXStoragePolicies", - "RecurringTimerSchedule", - "RefreshTokenAuthorizer", - "RemovedInV4Warning", - "S3StoragePolicies", - "Scope", - "ScopeCycleError", - "ScopeParseError", - "SearchAPIError", - "SearchClient", - "SearchQuery", - "SearchQueryV1", - "SearchScrollQuery", - "SpecificFlowClient", - "StorageGatewayDocument", - "StorageGatewayPolicies", - "TimerAPIError", - "TimerClient", - "TimerJob", - "TimersAPIError", - "TimersClient", - "TransferAPIError", - "TransferClient", - "TransferData", - "TransferTimer", - "UnpackingGCSResponse", - "UserApp", - "UserCredentialDocument", - "ValidationError", -) +del load_all_tuple +del default_getattr_implementation +del default_dir_implementation # configure logging for a library, per python best practices: diff --git a/src/globus_sdk/__init__.pyi b/src/globus_sdk/__init__.pyi new file mode 100644 index 00000000..6af51037 --- /dev/null +++ b/src/globus_sdk/__init__.pyi @@ -0,0 +1,249 @@ +from .authorizers import ( + AccessTokenAuthorizer, + BasicAuthorizer, + ClientCredentialsAuthorizer, + NullAuthorizer, + RefreshTokenAuthorizer, +) +from .client import BaseClient +from .exc import ( + ErrorSubdocument, + GlobusAPIError, + GlobusConnectionError, + GlobusConnectionTimeoutError, + GlobusError, + GlobusSDKUsageError, + GlobusTimeoutError, + NetworkError, + RemovedInV4Warning, + ValidationError, +) +from .globus_app import ClientApp, GlobusApp, GlobusAppConfig, UserApp +from .local_endpoint import ( + GlobusConnectPersonalOwnerInfo, + LocalGlobusConnectPersonal, + LocalGlobusConnectServer, +) +from .response import ArrayResponse, GlobusHTTPResponse, IterableResponse +from .scopes import Scope, ScopeCycleError, ScopeParseError +from .services.auth import ( + AuthAPIError, + AuthClient, + AuthLoginClient, + ConfidentialAppAuthClient, + DependentScopeSpec, + GetConsentsResponse, + GetIdentitiesResponse, + IdentityMap, + NativeAppAuthClient, + OAuthAuthorizationCodeResponse, + OAuthClientCredentialsResponse, + OAuthDependentTokenResponse, + OAuthRefreshTokenResponse, + OAuthTokenResponse, +) +from .services.compute import ( + ComputeAPIError, + ComputeClient, + ComputeClientV2, + ComputeClientV3, + ComputeFunctionDocument, + ComputeFunctionMetadata, +) +from .services.flows import ( + FlowsAPIError, + FlowsClient, + IterableFlowsResponse, + SpecificFlowClient, +) +from .services.gcs import ( + ActiveScaleStoragePolicies, + AzureBlobStoragePolicies, + BlackPearlStoragePolicies, + BoxStoragePolicies, + CephStoragePolicies, + CollectionDocument, + CollectionPolicies, + ConnectorTable, + EndpointDocument, + GCSAPIError, + GCSClient, + GCSRoleDocument, + GlobusConnectServerConnector, + GoogleCloudStorageCollectionPolicies, + GoogleCloudStoragePolicies, + GoogleDriveStoragePolicies, + GuestCollectionDocument, + HPSSStoragePolicies, + IrodsStoragePolicies, + IterableGCSResponse, + MappedCollectionDocument, + OneDriveStoragePolicies, + POSIXCollectionPolicies, + POSIXStagingCollectionPolicies, + POSIXStagingStoragePolicies, + POSIXStoragePolicies, + S3StoragePolicies, + StorageGatewayDocument, + StorageGatewayPolicies, + UnpackingGCSResponse, + UserCredentialDocument, +) +from .services.groups import ( + BatchMembershipActions, + GroupMemberVisibility, + GroupPolicies, + GroupRequiredSignupFields, + GroupRole, + GroupsAPIError, + GroupsClient, + GroupsManager, + GroupVisibility, +) +from .services.search import ( + SearchAPIError, + SearchClient, + SearchQuery, + SearchQueryV1, + SearchScrollQuery, +) +from .services.timer import TimerAPIError, TimerClient +from .services.timers import ( + OnceTimerSchedule, + RecurringTimerSchedule, + TimerJob, + TimersAPIError, + TimersClient, + TransferTimer, +) +from .services.transfer import ( + ActivationRequirementsResponse, + DeleteData, + IterableTransferResponse, + TransferAPIError, + TransferClient, + TransferData, +) +from .utils import MISSING, MissingType +from .version import __version__ + +def _force_eager_imports() -> None: ... + +__all__ = ( + "AccessTokenAuthorizer", + "BasicAuthorizer", + "ClientCredentialsAuthorizer", + "NullAuthorizer", + "RefreshTokenAuthorizer", + "BaseClient", + "ErrorSubdocument", + "GlobusAPIError", + "GlobusConnectionError", + "GlobusConnectionTimeoutError", + "GlobusError", + "GlobusSDKUsageError", + "GlobusTimeoutError", + "NetworkError", + "RemovedInV4Warning", + "ValidationError", + "ClientApp", + "GlobusApp", + "GlobusAppConfig", + "UserApp", + "GlobusConnectPersonalOwnerInfo", + "LocalGlobusConnectPersonal", + "LocalGlobusConnectServer", + "ArrayResponse", + "GlobusHTTPResponse", + "IterableResponse", + "Scope", + "ScopeCycleError", + "ScopeParseError", + "AuthAPIError", + "AuthClient", + "AuthLoginClient", + "ConfidentialAppAuthClient", + "DependentScopeSpec", + "GetConsentsResponse", + "GetIdentitiesResponse", + "IdentityMap", + "NativeAppAuthClient", + "OAuthAuthorizationCodeResponse", + "OAuthClientCredentialsResponse", + "OAuthDependentTokenResponse", + "OAuthRefreshTokenResponse", + "OAuthTokenResponse", + "ComputeAPIError", + "ComputeClient", + "ComputeClientV2", + "ComputeClientV3", + "ComputeFunctionDocument", + "ComputeFunctionMetadata", + "FlowsAPIError", + "FlowsClient", + "IterableFlowsResponse", + "SpecificFlowClient", + "ActiveScaleStoragePolicies", + "AzureBlobStoragePolicies", + "BlackPearlStoragePolicies", + "BoxStoragePolicies", + "CephStoragePolicies", + "CollectionDocument", + "CollectionPolicies", + "ConnectorTable", + "EndpointDocument", + "GCSAPIError", + "GCSClient", + "GCSRoleDocument", + "GlobusConnectServerConnector", + "GoogleCloudStorageCollectionPolicies", + "GoogleCloudStoragePolicies", + "GoogleDriveStoragePolicies", + "GuestCollectionDocument", + "HPSSStoragePolicies", + "IrodsStoragePolicies", + "IterableGCSResponse", + "MappedCollectionDocument", + "OneDriveStoragePolicies", + "POSIXCollectionPolicies", + "POSIXStagingCollectionPolicies", + "POSIXStagingStoragePolicies", + "POSIXStoragePolicies", + "S3StoragePolicies", + "StorageGatewayDocument", + "StorageGatewayPolicies", + "UnpackingGCSResponse", + "UserCredentialDocument", + "BatchMembershipActions", + "GroupMemberVisibility", + "GroupPolicies", + "GroupRequiredSignupFields", + "GroupRole", + "GroupsAPIError", + "GroupsClient", + "GroupsManager", + "GroupVisibility", + "SearchAPIError", + "SearchClient", + "SearchQuery", + "SearchQueryV1", + "SearchScrollQuery", + "TimerAPIError", + "TimerClient", + "OnceTimerSchedule", + "RecurringTimerSchedule", + "TimerJob", + "TimersAPIError", + "TimersClient", + "TransferTimer", + "ActivationRequirementsResponse", + "DeleteData", + "IterableTransferResponse", + "TransferAPIError", + "TransferClient", + "TransferData", + "MISSING", + "MissingType", + "__version__", + "_force_eager_imports", +) diff --git a/src/globus_sdk/_lazy_import.py b/src/globus_sdk/_lazy_import.py index fedf97a6..20e74cc3 100644 --- a/src/globus_sdk/_lazy_import.py +++ b/src/globus_sdk/_lazy_import.py @@ -1,52 +1,147 @@ """ Tooling for an extremely simple lazy-import system, based on inspection of -FromImports in an `if t.TYPE_CHECKING` branch. +pyi files. + +Given a base module name (used for lookup and error messages) and the name +of a pyi file, we can use the pyi file to lookup import locations. + +i.e. Given + + foo.py + foo.pyi + bar.py + +then if `foo.pyi` has an import `from .bar import BarType`, it is possible to +*read* `foo.pyi` at runtime and use that information to load +`BarType` from `bar`. """ from __future__ import annotations import ast -import importlib -import inspect import sys import typing as t -def load_attr(modname: str, attrname: str) -> t.Any: - mod_ast = _parse_module(modname) - attr_source = find_source_module(modname, attrname, mod_ast=mod_ast) +def load_all_tuple(modname: str, pyi_filename: str) -> tuple[str, ...]: + """ + Load the __all__ tuple from a ``.pyi`` file. + + This should run before the getattr and dir implementations are defined, as those use + the runtime ``__all__`` tuple. + + :param modname: The name of the module doing the load. Usually ``__name__``. + :param pyi_filename: The name of the ``pyi`` file relative to ``modname``. + ``importlib.resources`` will use both of these fields to load the ``pyi`` + data, so the file must be in package metadata. + """ + pyi_ast = _parse_pyi(modname, pyi_filename) + return tuple(_extract_all_tuple_names(modname, pyi_filename, pyi_ast)) + + +def default_getattr_implementation( + modname: str, pyi_filename: str +) -> t.Callable[[str], t.Any]: + """ + Build an implementation of module ``__getattr__`` given the module name and + the pyi file which will drive lazy imports. + + :param modname: The name of the module where ``__getattr__`` is being added. + Usually ``__name__``. + :param pyi_filename: The name of the ``pyi`` file relative to ``modname``. + ``importlib.resources`` will use both of these fields to load the ``pyi`` + data, so the file must be in package metadata. + """ + module_object = sys.modules[modname] + all_tuple = module_object.__all__ + + def getattr_implementation(name: str) -> t.Any: + if name in all_tuple: + value = load_attr(modname, pyi_filename, name) + setattr(module_object, name, value) + return value + + raise AttributeError(f"module {modname} has no attribute {name}") + + return getattr_implementation + + +def default_dir_implementation(modname: str) -> t.Callable[[], list[str]]: + """ + Build an implementation of module ``__dir__`` given the module name. + + :param modname: The name of the module where ``__dir__`` is being added. + Usually ``__name__``. + """ + # dir(globus_sdk) should include everything exported in __all__ + # as well as some explicitly selected attributes from the default dir() output + # on a module + # + # see also: + # https://discuss.python.org/t/how-to-properly-extend-standard-dir-search-with-module-level-dir/4202 + module_object = sys.modules[modname] + all_tuple = module_object.__all__ + + def dir_implementation() -> list[str]: + return list(all_tuple) + [ + # __all__ itself can be inspected + "__all__", + # useful to figure out where a package is installed + "__file__", + "__path__", + ] + + return dir_implementation + + +def load_attr(modname: str, pyi_filename: str, attrname: str) -> t.Any: + """ + Execute an import of a single attribute in the manner that it was declared in a + ``.pyi`` file. + + The import in the pyi data is expected to be a `from x import y` statement. + Only the specific attribute will be imported, even if the pyi declares multiple + imports from the same module. + + :param modname: The name of the module importing the attribute. + Usually ``__name__``. + :param pyi_filename: The name of the ``pyi`` file relative to ``modname``. + ``importlib.resources`` will use both of these fields to load the ``pyi`` + data, so the file must be in package metadata. + :param attrname: The name of the attribute to load. + """ + import importlib + + attr_source = find_source_module(modname, pyi_filename, attrname) attr_source_mod = importlib.import_module(attr_source, modname) return getattr(attr_source_mod, attrname) -def find_source_module( - modname: str, attrname: str, mod_ast: ast.Module | None = None -) -> str: - if mod_ast is None: - mod_ast = _parse_module(modname) - import_from = find_type_checking_import_from(modname, mod_ast, attrname) +def find_source_module(modname: str, pyi_filename: str, attrname: str) -> str: + """ + Find the source module which provides an attribute, based on a declared import in a + ``.pyi`` file. + + The ``.pyi`` data will be parsed as AST and scanned for an appropriate import. + + :param modname: The name of the module importing the attribute. + Usually ``__name__``. + :param pyi_filename: The name of the ``pyi`` file relative to ``modname``. + ``importlib.resources`` will use both of these fields to load the ``pyi`` + data, so the file must be in package metadata. + :param attrname: The name of the attribute to load. + """ + pyi_ast = _parse_pyi(modname, pyi_filename) + import_from = _find_import_from(modname, pyi_ast, attrname) # type ignore the possibility of 'import_from.module == None' # as it's not possible from parsed code return ("." * import_from.level) + import_from.module # type: ignore[operator] -def _parse_module(modname: str) -> ast.Module: - if modname not in _parsed_module_cache: - mod = sys.modules[modname] - source = inspect.getsource(mod) - _parsed_module_cache[modname] = ast.parse(source) - return _parsed_module_cache[modname] - - -_parsed_module_cache: dict[str, ast.Module] = {} - - -def find_type_checking_import_from( - modname: str, mod_ast: ast.Module, attrname: str +def _find_import_from( + modname: str, pyi_ast: ast.Module, attrname: str ) -> ast.ImportFrom: - if_clause = _find_type_checking_if(modname, mod_ast) - if_body = if_clause.body - for statement in if_body: + for statement in pyi_ast.body: if not isinstance(statement, ast.ImportFrom): continue @@ -56,25 +151,49 @@ def find_type_checking_import_from( raise LookupError(f"Could not find import of '{attrname}' in '{modname}'.") -def _find_type_checking_if(modname: str, mod_ast: ast.Module) -> ast.If: - if modname in _type_checking_if_cache: - return _type_checking_if_cache[modname] +def _parse_pyi(anchor_module_name: str, pyi_filename: str) -> ast.Module: + import importlib.resources - for statement in mod_ast.body: - if not isinstance(statement, ast.If): - continue - if not isinstance(statement.test, ast.Attribute): - continue + if (anchor_module_name, pyi_filename) not in _parsed_module_cache: + if sys.version_info >= (3, 9): + source = ( + importlib.resources.files(anchor_module_name) + .joinpath(pyi_filename) + .read_bytes() + ) + else: + source = importlib.resources.read_binary(anchor_module_name, pyi_filename) + _parsed_module_cache[(anchor_module_name, pyi_filename)] = ast.parse(source) + return _parsed_module_cache[(anchor_module_name, pyi_filename)] + + +_parsed_module_cache: dict[tuple[str, str], ast.Module] = {} - attr_node: ast.Attribute = statement.test - if not isinstance(attr_node.value, ast.Name): - continue - name_node: ast.Name = attr_node.value - if name_node.id == "t" and attr_node.attr == "TYPE_CHECKING": - _type_checking_if_cache[modname] = statement - return statement - raise LookupError("Could not find 'TYPE_CHECKING' branch in '{modname}'.") +def _extract_all_tuple_names( + modname: str, pyi_filename: str, pyi_ast: ast.Module +) -> t.Iterator[str]: + all_value = _find_all_value(modname, pyi_filename, pyi_ast) + for element in all_value.elts: + if not isinstance(element, ast.Constant): + continue + yield element.value -_type_checking_if_cache: dict[str, ast.If] = {} +def _find_all_value(modname: str, pyi_filename: str, pyi_ast: ast.Module) -> ast.Tuple: + for statement in pyi_ast.body: + if not isinstance(statement, ast.Assign): + continue + if len(statement.targets) != 1: + continue + target = statement.targets[0] + if not isinstance(target, ast.Name): + continue + if target.id != "__all__": + continue + if not isinstance(statement.value, ast.Tuple): + break + return statement.value + raise LookupError( + f"Could not load '__all__' tuple from '{pyi_filename}' for '{modname}'." + ) diff --git a/tests/non-pytest/lazy-imports/test_for_import_cycles.py b/tests/non-pytest/lazy-imports/test_for_import_cycles.py index db13cae6..f34ff634 100644 --- a/tests/non-pytest/lazy-imports/test_for_import_cycles.py +++ b/tests/non-pytest/lazy-imports/test_for_import_cycles.py @@ -22,7 +22,7 @@ MODULE_NAMES = sorted( { - find_source_module("globus_sdk", attr).lstrip(".") + find_source_module("globus_sdk", "__init__.pyi", attr).lstrip(".") for attr in globus_sdk.__all__ if not attr.startswith("_") }