Skip to content

Commit

Permalink
Merge pull request #178 from dmtucker/types
Browse files Browse the repository at this point in the history
Add strict type-checking
  • Loading branch information
dmtucker authored Aug 28, 2024
2 parents 749aae2 + fdc1116 commit abda733
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 33 deletions.
94 changes: 67 additions & 27 deletions src/pytest_mypy.py → src/pytest_mypy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,39 @@
"""Mypy static type checker plugin for Pytest"""

from __future__ import annotations

from dataclasses import dataclass
import json
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional, TextIO
import typing
import warnings

from filelock import FileLock # type: ignore
from filelock import FileLock
import mypy.api
import pytest

if typing.TYPE_CHECKING: # pragma: no cover
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
TextIO,
Tuple,
Union,
)

# https://github.com/pytest-dev/pytest/issues/7469
from _pytest._code.code import TerminalRepr

# https://github.com/pytest-dev/pytest/pull/12661
from _pytest.terminal import TerminalReporter

# https://github.com/pytest-dev/pytest-xdist/issues/1121
from xdist.workermanage import WorkerController # type: ignore


@dataclass(frozen=True) # compat python < 3.10 (kw_only=True)
class MypyConfigStash:
Expand All @@ -19,30 +42,34 @@ class MypyConfigStash:
mypy_results_path: Path

@classmethod
def from_serialized(cls, serialized):
def from_serialized(cls, serialized: str) -> MypyConfigStash:
return cls(mypy_results_path=Path(serialized))

def serialized(self):
def serialized(self) -> str:
return str(self.mypy_results_path)


mypy_argv = []
mypy_argv: List[str] = []
nodeid_name = "mypy"
stash_key = {
"config": pytest.StashKey[MypyConfigStash](),
}
terminal_summary_title = "mypy"


def default_file_error_formatter(item, results, errors):
def default_file_error_formatter(
item: MypyItem,
results: MypyResults,
errors: List[str],
) -> str:
"""Create a string to be displayed when mypy finds errors in a file."""
return "\n".join(errors)


file_error_formatter = default_file_error_formatter


def pytest_addoption(parser):
def pytest_addoption(parser: pytest.Parser) -> None:
"""Add options for enabling and running mypy."""
group = parser.getgroup("mypy")
group.addoption("--mypy", action="store_true", help="run mypy on .py files")
Expand All @@ -59,31 +86,33 @@ def pytest_addoption(parser):
)


def _xdist_worker(config):
def _xdist_worker(config: pytest.Config) -> Dict[str, Any]:
try:
return {"input": _xdist_workerinput(config)}
except AttributeError:
return {}


def _xdist_workerinput(node):
def _xdist_workerinput(node: Union[WorkerController, pytest.Config]) -> Any:
try:
return node.workerinput
# mypy complains that pytest.Config does not have this attribute,
# but xdist.remote defines it in worker processes.
return node.workerinput # type: ignore[union-attr]
except AttributeError: # compat xdist < 2.0
return node.slaveinput
return node.slaveinput # type: ignore[union-attr]


class MypyXdistControllerPlugin:
"""A plugin that is only registered on xdist controller processes."""

def pytest_configure_node(self, node):
def pytest_configure_node(self, node: WorkerController) -> None:
"""Pass the config stash to workers."""
_xdist_workerinput(node)["mypy_config_stash_serialized"] = node.config.stash[
stash_key["config"]
].serialized()


def pytest_configure(config):
def pytest_configure(config: pytest.Config) -> None:
"""
Initialize the path used to cache mypy results,
register a custom marker for MypyItems,
Expand Down Expand Up @@ -125,7 +154,10 @@ def pytest_configure(config):
mypy_argv.append(f"--config-file={mypy_config_file}")


def pytest_collect_file(file_path, parent):
def pytest_collect_file(
file_path: Path,
parent: pytest.Collector,
) -> Optional[MypyFile]:
"""Create a MypyFileItem for every file mypy should run on."""
if file_path.suffix in {".py", ".pyi"} and any(
[
Expand All @@ -145,7 +177,7 @@ def pytest_collect_file(file_path, parent):
class MypyFile(pytest.File):
"""A File that Mypy will run on."""

def collect(self):
def collect(self) -> Iterator[MypyItem]:
"""Create a MypyFileItem for the File."""
yield MypyFileItem.from_parent(parent=self, name=nodeid_name)
# Since mypy might check files that were not collected,
Expand All @@ -163,24 +195,28 @@ class MypyItem(pytest.Item):

MARKER = "mypy"

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.add_marker(self.MARKER)

def repr_failure(self, excinfo):
def repr_failure(
self,
excinfo: pytest.ExceptionInfo[BaseException],
style: Optional[str] = None,
) -> Union[str, TerminalRepr]:
"""
Unwrap mypy errors so we get a clean error message without the
full exception repr.
"""
if excinfo.errisinstance(MypyError):
return excinfo.value.args[0]
return str(excinfo.value.args[0])
return super().repr_failure(excinfo)


class MypyFileItem(MypyItem):
"""A check for Mypy errors in a File."""

def runtest(self):
def runtest(self) -> None:
"""Raise an exception if mypy found errors for this item."""
results = MypyResults.from_session(self.session)
abspath = str(self.path.absolute())
Expand All @@ -193,10 +229,10 @@ def runtest(self):
raise MypyError(file_error_formatter(self, results, errors))
warnings.warn("\n" + "\n".join(errors), MypyWarning)

def reportinfo(self):
def reportinfo(self) -> Tuple[str, None, str]:
"""Produce a heading for the test report."""
return (
self.path,
str(self.path),
None,
str(self.path.relative_to(self.config.invocation_params.dir)),
)
Expand All @@ -205,7 +241,7 @@ def reportinfo(self):
class MypyStatusItem(MypyItem):
"""A check for a non-zero mypy exit status."""

def runtest(self):
def runtest(self) -> None:
"""Raise a MypyError if mypy exited with a non-zero status."""
results = MypyResults.from_session(self.session)
if results.status:
Expand All @@ -216,7 +252,7 @@ def runtest(self):
class MypyResults:
"""Parsed results from Mypy."""

_abspath_errors_type = Dict[str, List[str]]
_abspath_errors_type = typing.Dict[str, typing.List[str]]

opts: List[str]
stdout: str
Expand All @@ -230,7 +266,7 @@ def dump(self, results_f: TextIO) -> None:
return json.dump(vars(self), results_f)

@classmethod
def load(cls, results_f: TextIO) -> "MypyResults":
def load(cls, results_f: TextIO) -> MypyResults:
"""Get results cached by dump()."""
return cls(**json.load(results_f))

Expand All @@ -240,7 +276,7 @@ def from_mypy(
paths: List[Path],
*,
opts: Optional[List[str]] = None,
) -> "MypyResults":
) -> MypyResults:
"""Generate results from mypy."""

if opts is None:
Expand Down Expand Up @@ -275,7 +311,7 @@ def from_mypy(
)

@classmethod
def from_session(cls, session) -> "MypyResults":
def from_session(cls, session: pytest.Session) -> MypyResults:
"""Load (or generate) cached mypy results for a pytest session."""
mypy_results_path = session.config.stash[stash_key["config"]].mypy_results_path
with FileLock(str(mypy_results_path) + ".lock"):
Expand Down Expand Up @@ -309,7 +345,11 @@ class MypyWarning(pytest.PytestWarning):
class MypyReportingPlugin:
"""A Pytest plugin that reports mypy results."""

def pytest_terminal_summary(self, terminalreporter, config):
def pytest_terminal_summary(
self,
terminalreporter: TerminalReporter,
config: pytest.Config,
) -> None:
"""Report stderr and unrecognized lines from stdout."""
mypy_results_path = config.stash[stash_key["config"]].mypy_results_path
try:
Expand Down
Empty file added src/pytest_mypy/py.typed
Empty file.
8 changes: 8 additions & 0 deletions tests/test_pytest_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,3 +540,11 @@ def pytest_configure(config):
result.assert_outcomes(passed=mypy_checks)
assert result.ret == pytest.ExitCode.OK
assert f"= {pytest_mypy.terminal_summary_title} =" not in str(result.stdout)


def test_py_typed(testdir):
"""Mypy recognizes that pytest_mypy is typed."""
name = "typed"
testdir.makepyfile(**{name: "import pytest_mypy"})
result = testdir.run("mypy", f"{name}.py")
assert result.ret == 0
15 changes: 9 additions & 6 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ envlist =
py310-pytest{7.0, 7.x, 8.0, 8.x}-mypy{1.0, 1.x}-xdist{1.x, 2.0, 2.x, 3.0, 3.x}
py311-pytest{7.0, 7.x, 8.0, 8.x}-mypy{1.0, 1.x}-xdist{1.x, 2.0, 2.x, 3.0, 3.x}
py312-pytest{7.0, 7.x, 8.0, 8.x}-mypy{1.0, 1.x}-xdist{1.x, 2.0, 2.x, 3.0, 3.x}
publish
static
publish

[gh-actions]
python =
3.7: py37-pytest{7.0, 7.x}-mypy{1.0, 1.x}-xdist{1.x, 2.0, 2.x, 3.0, 3.x}
3.8: py38-pytest{7.0, 7.x, 8.0, 8.x}-mypy{1.0, 1.x}-xdist{1.x, 2.0, 2.x, 3.0, 3.x}, publish, static
3.8: py38-pytest{7.0, 7.x, 8.0, 8.x}-mypy{1.0, 1.x}-xdist{1.x, 2.0, 2.x, 3.0, 3.x}
3.9: py39-pytest{7.0, 7.x, 8.0, 8.x}-mypy{1.0, 1.x}-xdist{1.x, 2.0, 2.x, 3.0, 3.x}
3.10: py310-pytest{7.0, 7.x, 8.0, 8.x}-mypy{1.0, 1.x}-xdist{1.x, 2.0, 2.x, 3.0, 3.x}
3.11: py311-pytest{7.0, 7.x, 8.0, 8.x}-mypy{1.0, 1.x}-xdist{1.x, 2.0, 2.x, 3.0, 3.x}
3.12: py312-pytest{7.0, 7.x, 8.0, 8.x}-mypy{1.0, 1.x}-xdist{1.x, 2.0, 2.x, 3.0, 3.x}
3.12: py312-pytest{7.0, 7.x, 8.0, 8.x}-mypy{1.0, 1.x}-xdist{1.x, 2.0, 2.x, 3.0, 3.x}, static, publish

[testenv]
constrain_package_deps = true
Expand All @@ -39,7 +39,8 @@ deps =
packaging ~= 21.3
pytest-cov ~= 4.1.0
pytest-randomly ~= 3.4

setenv =
COVERAGE_FILE = .coverage.{envname}
commands = pytest -p no:mypy {posargs:--cov pytest_mypy --cov-branch --cov-fail-under 100 --cov-report term-missing -n auto}

[pytest]
Expand All @@ -56,15 +57,17 @@ commands =
twine {posargs:check} {envtmpdir}/*

[testenv:static]
basepython = py312 # pytest.Node.from_parent uses typing.Self
deps =
bandit ~= 1.7.0
black ~= 24.2.0
flake8 ~= 7.0.0
mypy ~= 1.8.0
mypy ~= 1.11.0
pytest-xdist >= 3.6.0 # needed for type-checking
commands =
black --check src tests
flake8 src tests
mypy src
mypy --strict src
bandit --recursive src

[flake8]
Expand Down

0 comments on commit abda733

Please sign in to comment.