From 04a676e82e52c84f35283fe383bd6583cead7bdb Mon Sep 17 00:00:00 2001 From: "John L. Villalovos" Date: Tue, 12 Sep 2023 09:57:02 -0700 Subject: [PATCH] wip: chore: add type-hints to `imapclient/imapclient.py` --- imapclient/imapclient.py | 326 +++++++++++++++++++++++++--------- imapclient/response_parser.py | 5 +- pyproject.toml | 2 +- 3 files changed, 244 insertions(+), 89 deletions(-) diff --git a/imapclient/imapclient.py b/imapclient/imapclient.py index eea281a8..6d79f7dc 100644 --- a/imapclient/imapclient.py +++ b/imapclient/imapclient.py @@ -2,28 +2,57 @@ # Released subject to the New BSD License # Please see http://en.wikipedia.org/wiki/BSD_licenses +import collections import dataclasses +import datetime import functools import imaplib import itertools +import logging import re import select import socket import ssl as ssl_lib import sys import warnings -from datetime import date, datetime -from logging import getLogger, LoggerAdapter from operator import itemgetter -from typing import List, Optional - -from . import exceptions, imap4, response_lexer, tls +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + Iterator, + List, + MutableMapping, + Optional, + overload, + Sequence, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) + +from . import exceptions, imap4, response_lexer, response_parser, tls from .datetime_util import datetime_to_INTERNALDATE, format_criteria_date from .imap_utf7 import decode as decode_utf7 from .imap_utf7 import encode as encode_utf7 -from .response_parser import parse_fetch_response, parse_message_list, parse_response +from .response_parser import parse_fetch_response, parse_response +from .typing_imapclient import _Atom from .util import assert_imap_protocol, chunk, to_bytes, to_unicode +if TYPE_CHECKING: + from typing import Literal # Only available in Python 3.8 or higher + + # https://github.com/python/typeshed/issues/7855 + _LoggerAdapter = logging.LoggerAdapter[logging.Logger] +else: + _LoggerAdapter = logging.LoggerAdapter + +_ItemsType = Union[Iterable[Union[bytes, str]], bytes, str] +_ItemsTypeWithInt = Union[Iterable[Union[bytes, int, str]], bytes, int, str] + if hasattr(select, "poll"): POLL_SUPPORT = True else: @@ -31,7 +60,7 @@ POLL_SUPPORT = False -logger = getLogger(__name__) +logger = logging.getLogger(__name__) __all__ = [ "IMAPClient", @@ -166,19 +195,23 @@ class Quota: limit: bytes -def require_capability(capability): +# https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators +F = TypeVar("F", bound=Callable[..., Any]) + + +def require_capability(capability: str) -> Callable[[F], F]: """Decorator raising CapabilityError when a capability is not available.""" - def actual_decorator(func): + def actual_decorator(func: F) -> F: @functools.wraps(func) - def wrapper(client, *args, **kwargs): + def wrapper(client: "IMAPClient", *args: Any, **kwargs: Any) -> Any: if not client.has_capability(capability): raise exceptions.CapabilityError( "Server does not support {} capability".format(capability) ) return func(client, *args, **kwargs) - return wrapper + return cast(F, wrapper) return actual_decorator @@ -294,7 +327,9 @@ def __init__( self._set_read_timeout() # Small hack to make imaplib log everything to its own logger - imaplib_logger = IMAPlibLoggerAdapter(getLogger("imapclient.imaplib"), {}) + imaplib_logger = IMAPlibLoggerAdapter( + logging.getLogger("imapclient.imaplib"), {} + ) self._imap.debug = 5 self._imap._mesg = imaplib_logger.debug @@ -315,7 +350,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): except Exception as e: logger.info("Could not close the connection cleanly: %s", e) - def _create_IMAP4(self): + def _create_IMAP4(self) -> imaplib.IMAP4: if self.stream: return imaplib.IMAP4_stream(self.host) @@ -358,7 +393,7 @@ def socket(self): return getattr(self._imap, "sslobj", self._imap.sock) @require_capability("STARTTLS") - def starttls(self, ssl_context=None): + def starttls(self, ssl_context: ssl_lib.SSLContext = None): """Switch to an SSL encrypted connection by sending a STARTTLS command. The *ssl_context* argument is optional and should be a @@ -635,7 +670,7 @@ def _normalise_capabilites(self, raw_response): raw_response = to_bytes(raw_response) return tuple(raw_response.upper().split()) - def has_capability(self, capability): + def has_capability(self, capability: str) -> bool: """Return ``True`` if the IMAP server has the given *capability*.""" # FIXME: this will not detect capabilities that are backwards # compatible with the current level. For instance the SORT @@ -1166,7 +1201,7 @@ def _search(self, criteria, charset): # If the exception is not from a BAD IMAP response, re-raise as-is raise - return parse_message_list(data) + return response_parser.parse_message_list(data) @require_capability("SORT") def sort(self, sort_criteria, criteria="ALL", charset="UTF-8"): @@ -1435,7 +1470,7 @@ def multiappend(self, folder, msgs): Returns the APPEND response from the server. """ - def chunks(): + def chunks() -> Iterator[bytes]: for m in msgs: if isinstance(m, dict): if "flags" in m: @@ -1579,7 +1614,7 @@ def _get_quota(self, quota_root=""): return _parse_quota(self._command_and_check("getquota", _quote(quota_root))) @require_capability("QUOTA") - def get_quota_root(self, mailbox): + def get_quota_root(self, mailbox: str) -> Tuple[MailboxQuotaRoots, List[Quota]]: """Get the quota roots for a mailbox. The IMAP server responds with the quota root and the quotas associated @@ -1593,14 +1628,15 @@ def get_quota_root(self, mailbox): b"GETQUOTAROOT", to_bytes(mailbox), uid=False, response_name="QUOTAROOT" ) quota_rep = self._imap.untagged_responses.pop("QUOTA", []) - quota_root_rep = parse_response(quota_root_rep) + parsed_quota_root_rep = parse_response(quota_root_rep) quota_root = MailboxQuotaRoots( - to_unicode(quota_root_rep[0]), [to_unicode(q) for q in quota_root_rep[1:]] + to_unicode(parsed_quota_root_rep[0]), + [to_unicode(q) for q in parsed_quota_root_rep[1:]], ) return quota_root, _parse_quota(quota_rep) @require_capability("QUOTA") - def set_quota(self, quotas): + def set_quota(self, quotas: List[Quota]): """Set one or more quotas on resources. :param quotas: list of Quota objects @@ -1619,25 +1655,33 @@ def set_quota(self, quotas): set_quota_args.append("{} {}".format(quota.resource, quota.limit)) - set_quota_args = " ".join(set_quota_args) - args = [to_bytes(_quote(quota_root)), to_bytes("({})".format(set_quota_args))] + if TYPE_CHECKING: + assert quota_root is not None + args = [ + to_bytes(_quote(quota_root)), + to_bytes("({})".format(" ".join(set_quota_args))), + ] response = self._raw_command_untagged( b"SETQUOTA", args, uid=False, response_name="QUOTA" ) return _parse_quota(response) - def _check_resp(self, expected, command, typ, data): + def _check_resp( + self, expected: str, command: Union[bytes, str], typ: str, data: List[bytes] + ) -> None: """Check command responses for errors. Raises IMAPClient.Error if the command fails. """ if typ != expected: raise exceptions.IMAPClientError( - "%s failed: %s" % (command, to_unicode(data[0])) + "%s failed: %s" % (to_unicode(command), to_unicode(data[0])) ) - def _consume_until_tagged_response(self, tag, command): + def _consume_until_tagged_response( + self, tag: bytes, command: bytes + ) -> Tuple[bytes, List[Tuple[_Atom, ...]]]: tagged_commands = self._imap.tagged_commands resps = [] while True: @@ -1645,13 +1689,43 @@ def _consume_until_tagged_response(self, tag, command): if tagged_commands[tag]: break resps.append(_parse_untagged_response(line)) - typ, data = tagged_commands.pop(tag) + typ, data = cast(Tuple[str, List[bytes]], tagged_commands.pop(tag)) self._checkok(command, typ, data) return data[0], resps + @overload def _raw_command_untagged( - self, command, args, response_name=None, unpack=False, uid=True - ): + self, + command: bytes, + args: Union[List[bytes], Tuple[bytes, ...], bytes], + *, + unpack: "Literal[False]" = ..., + response_name: Optional[Union[bytes, str]] = ..., + uid: bool = ..., + ) -> List[bytes]: + ... + + @overload + def _raw_command_untagged( + self, + command: bytes, + args: Union[List[bytes], Tuple[bytes, ...], bytes], + *, + unpack: "Literal[True]", + response_name: Optional[Union[bytes, str]] = ..., + uid: bool = ..., + ) -> bytes: + ... + + def _raw_command_untagged( + self, + command: bytes, + args: Union[List[bytes], Tuple[bytes, ...], bytes], + *, + unpack: bool = False, + response_name: Optional[Union[bytes, str]] = None, + uid: bool = True, + ) -> Union[List[bytes], bytes]: # TODO: eventually this should replace _command_and_check (call it _command) typ, data = self._raw_command(command, args, uid=uid) if response_name is None: @@ -1662,7 +1736,12 @@ def _raw_command_untagged( return data[0] return data - def _raw_command(self, command, args, uid=True): + def _raw_command( + self, + command: bytes, + args: Union[List[bytes], Tuple[bytes, ...], bytes], + uid: bool = True, + ) -> Tuple[str, List[bytes]]: """Run the specific command with the arguments given. 8-bit arguments are sent as literals. The return value is (typ, data). @@ -1686,7 +1765,7 @@ def _raw_command(self, command, args, uid=True): prefix.append(b"UID") prefix.append(command) - line = [] + line: List[bytes] = [] for item, is_last in _iter_with_last(prefix + args): if not isinstance(item, bytes): raise ValueError("command args must be passed as bytes") @@ -1715,10 +1794,15 @@ def _raw_command(self, command, args, uid=True): self._imap.send(b"\r\n") - return self._imap._command_complete(to_unicode(command), tag) + return cast( + Tuple[str, List[bytes]], + self._imap._command_complete(to_unicode(command), tag), + ) - def _send_literal(self, tag, item): + def _send_literal(self, tag: str, item: bytes) -> None: """Send a single literal for the command with *tag*.""" + if TYPE_CHECKING: + assert self._cached_capabilities is not None if b"LITERAL+" in self._cached_capabilities: out = b" {" + str(len(item)).encode("ascii") + b"+}\r\n" + item logger.debug("> %s", debug_trunc(out, 64)) @@ -1741,9 +1825,33 @@ def _send_literal(self, tag, item): logger.debug(" (literal) > %s", debug_trunc(item, 256)) self._imap.send(item) + @overload def _command_and_check( - self, command, *args, unpack: bool = False, uid: bool = False - ): + self, + command: Union[bytes, str], + *args: Any, + unpack: "Literal[False]" = ..., + uid: bool = ..., + ) -> List[bytes]: + ... + + @overload + def _command_and_check( + self, + command: Union[bytes, str], + *args: Any, + unpack: "Literal[True]", + uid: bool = ..., + ) -> bytes: + ... + + def _command_and_check( + self, + command: Union[bytes, str], + *args: Any, + unpack: bool = False, + uid: bool = False, + ) -> Union[List[bytes], bytes]: if uid and self.use_uid: command = to_unicode(command) # imaplib must die typ, data = self._imap.uid(command, *args) @@ -1752,23 +1860,38 @@ def _command_and_check( typ, data = meth(*args) self._checkok(command, typ, data) if unpack: - return data[0] + result = data[0] + if TYPE_CHECKING: + assert isinstance(result, bytes) + return result return data - def _checkok(self, command, typ, data): + def _checkok(self, command: Union[bytes, str], typ: str, data: List[bytes]) -> None: self._check_resp("OK", command, typ, data) - def _gm_label_store(self, cmd, messages, labels, silent): + def _gm_label_store( + self, cmd: bytes, messages: _ItemsTypeWithInt, labels: _ItemsType, silent: bool + ) -> Optional[Dict[int, List[str]]]: response = self._store( cmd, messages, self._normalise_labels(labels), b"X-GM-LABELS", silent=silent ) return ( - {msg: utf7_decode_sequence(labels) for msg, labels in response.items()} + { + msg: utf7_decode_sequence(cast(Iterable[Union[bytes, str]], labels)) + for msg, labels in response.items() + } if response else None ) - def _store(self, cmd, messages, flags, fetch_key, silent): + def _store( + self, + cmd: bytes, + messages: _ItemsTypeWithInt, + flags: Iterable[bytes], + fetch_key: bytes, + silent: bool, + ) -> Optional[Dict[int, response_parser._ParseFetchResponseInnerDictValue]]: """Worker function for the various flag manipulation methods. *cmd* is the STORE command to use (eg. '+FLAGS'). @@ -1785,43 +1908,62 @@ def _store(self, cmd, messages, flags, fetch_key, silent): return None return self._filter_fetch_dict(parse_fetch_response(data), fetch_key) - def _filter_fetch_dict(self, fetch_dict, key): + def _filter_fetch_dict( + self, + fetch_dict: "collections.defaultdict[int, response_parser._ParseFetchResponseInnerDict]", + key: bytes, + ) -> Dict[int, response_parser._ParseFetchResponseInnerDictValue]: return dict((msgid, data[key]) for msgid, data in fetch_dict.items()) - def _normalise_folder(self, folder_name): + def _normalise_folder(self, folder_name: Union[bytes, str]) -> Union[bytes, str]: if isinstance(folder_name, bytes): folder_name = folder_name.decode("ascii") if self.folder_encode: folder_name = encode_utf7(folder_name) return _quote(folder_name) - def _normalise_labels(self, labels): + def _normalise_labels(self, labels: _ItemsType) -> List[bytes]: if isinstance(labels, (str, bytes)): labels = (labels,) return [_quote(encode_utf7(label)) for label in labels] @property - def welcome(self): + def welcome(self) -> Optional[bytes]: """access the server greeting message""" try: return self._imap.welcome except AttributeError: pass + return None + +@overload +def _quote(arg: bytes) -> bytes: + ... -def _quote(arg): + +@overload +def _quote(arg: str) -> str: + ... + + +def _quote(arg: Union[bytes, str]) -> Union[bytes, str]: + q: Union[bytes, str] if isinstance(arg, str): arg = arg.replace("\\", "\\\\") arg = arg.replace('"', '\\"') q = '"' - else: - arg = arg.replace(b"\\", b"\\\\") - arg = arg.replace(b'"', b'\\"') - q = b'"' + return q + arg + q + arg = arg.replace(b"\\", b"\\\\") + arg = arg.replace(b'"', b'\\"') + q = b'"' return q + arg + q -def _normalise_search_criteria(criteria, charset=None): +def _normalise_search_criteria( + criteria: _ItemsType, + charset: Optional[str] = None, +) -> List[bytes]: if not criteria: raise exceptions.InvalidCriteriaError("no criteria specified") if not charset: @@ -1834,7 +1976,7 @@ def _normalise_search_criteria(criteria, charset=None): for item in criteria: if isinstance(item, int): out.append(str(item).encode("ascii")) - elif isinstance(item, (datetime, date)): + elif isinstance(item, (datetime.datetime, datetime.date)): out.append(format_criteria_date(item)) elif isinstance(item, (list, tuple)): # Process nested criteria list and wrap in parens. @@ -1847,7 +1989,10 @@ def _normalise_search_criteria(criteria, charset=None): return out -def _normalise_sort_criteria(criteria, charset=None): +def _normalise_sort_criteria( + criteria: _ItemsType, + charset: Optional[str] = None, +) -> bytes: if isinstance(criteria, (str, bytes)): criteria = [criteria] return b"(" + b" ".join(to_bytes(item).upper() for item in criteria) + b")" @@ -1865,8 +2010,10 @@ class _quoted(bytes): They should be created via the *maybe* classmethod. """ + original: bytes + @classmethod - def maybe(cls, original): + def maybe(cls, original: bytes) -> bytes: """Maybe quote a bytes value. If the input requires no quoting it is returned unchanged. @@ -1887,44 +2034,44 @@ def maybe(cls, original): # normalise_text_list, seq_to_parentstr etc have to return unicode # because imaplib handles flags and sort criteria assuming these are # passed as unicode -def normalise_text_list(items): +def normalise_text_list(items: Iterable[str]) -> List[str]: return list(_normalise_text_list(items)) -def seq_to_parenstr(items): +def seq_to_parenstr(items: Iterable[Union[bytes, str]]) -> str: return _join_and_paren(_normalise_text_list(items)) -def seq_to_parenstr_upper(items): +def seq_to_parenstr_upper(items: Iterable[str]) -> str: return _join_and_paren(item.upper() for item in _normalise_text_list(items)) -def _join_and_paren(items): +def _join_and_paren(items: Iterable[str]) -> str: return "(" + " ".join(items) + ")" -def _normalise_text_list(items): +def _normalise_text_list(items: _ItemsType) -> Iterator[str]: if isinstance(items, (str, bytes)): items = (items,) return (to_unicode(c) for c in items) -def join_message_ids(messages): +def join_message_ids(messages: _ItemsTypeWithInt) -> bytes: """Convert a sequence of messages ids or a single integer message id into an id byte string for use with IMAP commands """ if isinstance(messages, (str, bytes, int)): - messages = (to_bytes(messages),) + messages = (to_bytes(cast(Union[bytes, str], messages)),) return b",".join(_maybe_int_to_bytes(m) for m in messages) -def _maybe_int_to_bytes(val): +def _maybe_int_to_bytes(val: Union[bytes, int, str]) -> bytes: if isinstance(val, int): return str(val).encode("us-ascii") return to_bytes(val) -def _parse_untagged_response(text): +def _parse_untagged_response(text: bytes) -> Tuple[_Atom, ...]: assert_imap_protocol(text.startswith(b"* ")) text = text[2:] if text.startswith((b"OK ", b"NO ")): @@ -1932,9 +2079,9 @@ def _parse_untagged_response(text): return parse_response([text]) -def as_pairs(items): +def as_pairs(items: Tuple[_Atom, ...]) -> Iterator[Tuple[_Atom, _Atom]]: i = 0 - last_item = None + last_item: Optional[_Atom] = None for item in items: if i % 2: yield last_item, item @@ -1943,16 +2090,18 @@ def as_pairs(items): i += 1 -def as_triplets(items): +def as_triplets(items: _Atom) -> Iterable[Tuple[_Atom, _Atom, _Atom]]: + if TYPE_CHECKING: + assert isinstance(items, tuple) a = iter(items) return zip(a, a, a) -def _is8bit(data): +def _is8bit(data: bytes) -> bool: return isinstance(data, _literal) or any(b > 127 for b in data) -def _iter_with_last(items): +def _iter_with_last(items: List[bytes]) -> Iterator[Tuple[bytes, bool]]: last_i = len(items) - 1 for i, item in enumerate(items): yield item, i == last_i @@ -1966,43 +2115,43 @@ class _dict_bytes_normaliser: bytes. """ - def __init__(self, d): + def __init__(self, d: Dict[Union[str, bytes], Any]): self._d = d - def iteritems(self): + def iteritems(self) -> Iterator[Tuple[bytes, Any]]: for key, value in self._d.items(): yield to_bytes(key), value # For Python 3 compatibility. items = iteritems - def __contains__(self, ink): - for k in self._gen_keys(ink): + def __contains__(self, key: Union[str, bytes]) -> bool: + for k in self._gen_keys(key): if k in self._d: return True return False - def get(self, ink, default=_not_present): - for k in self._gen_keys(ink): + def get(self, key: Union[str, bytes], default: Any = _not_present) -> Any: + for k in self._gen_keys(key): try: return self._d[k] except KeyError: pass if default == _not_present: - raise KeyError(ink) + raise KeyError(key) return default - def pop(self, ink, default=_not_present): - for k in self._gen_keys(ink): + def pop(self, key: Union[str, bytes], default: Any = _not_present) -> Any: + for k in self._gen_keys(key): try: return self._d.pop(k) except KeyError: pass if default == _not_present: - raise KeyError(ink) + raise KeyError(key) return default - def _gen_keys(self, k): + def _gen_keys(self, k: Union[bytes, str]) -> Iterator[Union[bytes, str]]: yield k if isinstance(k, bytes): yield to_unicode(k) @@ -2010,22 +2159,25 @@ def _gen_keys(self, k): yield to_bytes(k) -def debug_trunc(v, maxlen): +def debug_trunc(v: Sequence[str], maxlen: int) -> str: if len(v) < maxlen: return repr(v) hl = maxlen // 2 return repr(v[:hl]) + "..." + repr(v[-hl:]) -def utf7_decode_sequence(seq): +def utf7_decode_sequence(seq: Iterable[Union[bytes, str]]) -> List[str]: return [decode_utf7(s) for s in seq] -def _parse_quota(quota_rep): - quota_rep = parse_response(quota_rep) +def _parse_quota(quota_rep: List[bytes]) -> List[Quota]: + parsed_quota_rep = parse_response(quota_rep) rv = [] - for quota_root, quota_resource_infos in as_pairs(quota_rep): - for quota_resource_info in as_triplets(quota_resource_infos): + for quota_root, quota_resource_infos in as_pairs(parsed_quota_rep): + if TYPE_CHECKING: + assert isinstance(quota_root, (bytes, str)) + for quota_resource_info_atom in as_triplets(quota_resource_infos): + quota_resource_info = cast(Tuple[bytes, ...], quota_resource_info_atom) rv.append( Quota( quota_root=to_unicode(quota_root), @@ -2037,10 +2189,12 @@ def _parse_quota(quota_rep): return rv -class IMAPlibLoggerAdapter(LoggerAdapter): +class IMAPlibLoggerAdapter(_LoggerAdapter): """Adapter preventing IMAP secrets from going to the logging facility.""" - def process(self, msg, kwargs): + def process( + self, msg: Union[bytes, str], kwargs: MutableMapping[str, Any] + ) -> Tuple[str, MutableMapping[str, Any]]: # msg is usually unicode but see #367. Convert bytes to # unicode if required. if isinstance(msg, bytes): diff --git a/imapclient/response_parser.py b/imapclient/response_parser.py index 9f29e4ca..e8e72dfa 100644 --- a/imapclient/response_parser.py +++ b/imapclient/response_parser.py @@ -100,9 +100,10 @@ def gen_parsed_response(text: List[bytes]) -> Iterator[_Atom]: raise ProtocolError("%s: %r" % (str(err), token)) -_ParseFetchResponseInnerDict = Dict[ - bytes, Optional[Union[datetime.datetime, int, BodyData, Envelope, _Atom]] +_ParseFetchResponseInnerDictValue = Optional[ + Union[datetime.datetime, int, BodyData, Envelope, _Atom] ] +_ParseFetchResponseInnerDict = Dict[bytes, _ParseFetchResponseInnerDictValue] def parse_fetch_response( diff --git a/pyproject.toml b/pyproject.toml index 4478f94f..522bd864 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ warn_unused_ignores = true # Overrides for currently untyped modules [[tool.mypy.overrides]] module = [ - "imapclient.imapclient", +# "imapclient.imapclient", "livetest", ] ignore_errors = true