Skip to content

Commit

Permalink
Add mypy to pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhwaniartefact authored Oct 9, 2024
1 parent 55e50db commit 6eb4807
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 22 deletions.
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,11 @@ repos:
(?x)^(
(README)\.md
)
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
hooks:
- id: mypy
additional_dependencies:
- pytest
- types-requests
- types-python-dateutil
13 changes: 11 additions & 2 deletions fixity/fixity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from datetime import datetime
from datetime import timezone
from time import sleep
from typing import List
from typing import Optional
from typing import TextIO
from typing import Type
from typing import Union
from uuid import uuid4

from . import reporting
Expand Down Expand Up @@ -91,7 +96,7 @@ def fetch_environment_variables(namespace):
namespace.report_user = namespace.report_pass = None


def scan_message(aip_uuid, status, message):
def scan_message(aip_uuid: str, status: bool, message: str) -> str:
if status is True:
succeeded = "succeeded"
elif status is False:
Expand Down Expand Up @@ -305,7 +310,11 @@ def get_handler(stream, timestamps, log_level=None):
return handler


def main(argv=None, logger=None, stream=None):
def main(
argv: Optional[List[str]] = None,
logger: Union[logging.Logger] = None,
stream: Optional[TextIO] = None,
) -> Union[int, bool, Type[Exception]]:
if logger is None:
logger = get_logger()
if stream is None:
Expand Down
7 changes: 5 additions & 2 deletions fixity/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import backref
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import sessionmaker

db_path = os.path.join(os.path.dirname(__file__), "fixity.db")
engine = create_engine(f"sqlite:///{db_path}", echo=False)

Session = sessionmaker(bind=engine)
Base = declarative_base()


class Base(DeclarativeBase):
pass


class AIP(Base):
Expand Down
17 changes: 17 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,20 @@ legacy_tox_ini = """
deps = pre-commit
commands = pre-commit run --all-files --show-diff-on-failure
"""

[tool.mypy]
strict = true

[[tool.mypy.overrides]]
module = [
"fixity.*",
"tests.*",
]
ignore_errors = true

[[tool.mypy.overrides]]
module = [
"tests.test_fixity",
]
ignore_errors = false

61 changes: 43 additions & 18 deletions tests/test_fixity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import uuid
from datetime import datetime
from datetime import timezone
from typing import List
from typing import TextIO
from unittest import mock

import pytest
Expand Down Expand Up @@ -34,14 +36,14 @@


@pytest.fixture
def environment(monkeypatch):
def environment(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("STORAGE_SERVICE_URL", STORAGE_SERVICE_URL)
monkeypatch.setenv("STORAGE_SERVICE_USER", STORAGE_SERVICE_USER)
monkeypatch.setenv("STORAGE_SERVICE_KEY", STORAGE_SERVICE_KEY)


@pytest.fixture
def mock_check_fixity():
def mock_check_fixity() -> List[mock.Mock]:
return [
mock.Mock(
**{
Expand All @@ -54,13 +56,15 @@ def mock_check_fixity():
]


def _assert_stream_content_matches(stream, expected):
def _assert_stream_content_matches(stream: TextIO, expected: List[str]) -> None:
stream.seek(0)
assert [line.strip() for line in stream.readlines()] == expected


@mock.patch("requests.get")
def test_scan(_get, environment, mock_check_fixity):
def test_scan(
_get: mock.Mock, environment: None, mock_check_fixity: List[mock.Mock]
) -> None:
_get.side_effect = mock_check_fixity
aip_id = uuid.uuid4()
stream = io.StringIO()
Expand All @@ -86,8 +90,11 @@ def test_scan(_get, environment, mock_check_fixity):
@mock.patch("time.time")
@mock.patch("requests.get")
def test_scan_if_timestamps_argument_is_passed(
_get, time, environment, mock_check_fixity
):
_get: mock.Mock,
time: mock.Mock,
environment: None,
mock_check_fixity: List[mock.Mock],
) -> None:
_get.side_effect = mock_check_fixity
aip_id = uuid.uuid4()
timestamp = 1514775600
Expand Down Expand Up @@ -126,8 +133,14 @@ def test_scan_if_timestamps_argument_is_passed(
],
)
def test_scan_if_report_url_exists(
_post, _get, utcnow, uuid4, mock_check_fixity, environment, monkeypatch
):
_post: mock.Mock,
_get: mock.Mock,
utcnow: mock.Mock,
uuid4: mock.Mock,
environment: None,
mock_check_fixity: List[mock.Mock],
monkeypatch: pytest.MonkeyPatch,
) -> None:
uuid4.return_value = expected_uuid = uuid.uuid4()
_get.side_effect = mock_check_fixity
monkeypatch.setenv("REPORT_URL", REPORT_URL)
Expand Down Expand Up @@ -197,8 +210,12 @@ def test_scan_if_report_url_exists(
],
)
def test_scan_handles_exceptions_if_report_url_exists(
_post, _get, environment, monkeypatch, mock_check_fixity
):
_post: mock.Mock,
_get: mock.Mock,
environment: None,
mock_check_fixity: List[mock.Mock],
monkeypatch: pytest.MonkeyPatch,
) -> None:
_get.side_effect = mock_check_fixity
aip_id = uuid.uuid4()
stream = io.StringIO()
Expand Down Expand Up @@ -237,7 +254,7 @@ def test_scan_handles_exceptions_if_report_url_exists(
),
],
)
def test_scan_handles_exceptions(_get, environment):
def test_scan_handles_exceptions(_get: mock.Mock, environment: None) -> None:
aip_id = uuid.uuid4()
stream = io.StringIO()

Expand Down Expand Up @@ -272,7 +289,9 @@ def test_scan_handles_exceptions(_get, environment):
),
],
)
def test_scan_handles_exceptions_if_no_scan_attempted(_get, environment):
def test_scan_handles_exceptions_if_no_scan_attempted(
_get: mock.Mock, environment: None
) -> None:
aip_id = uuid.uuid4()

response = fixity.main(["scan", str(aip_id)])
Expand All @@ -291,8 +310,8 @@ def test_scan_handles_exceptions_if_no_scan_attempted(_get, environment):
],
ids=["Success", "Fail", "Did not run"],
)
def test_scan_message(status, error_message):
aip_id = uuid.uuid4()
def test_scan_message(status: bool, error_message: str) -> None:
aip_id = str(uuid.uuid4())

response = fixity.scan_message(
aip_uuid=aip_id, status=status, message=error_message
Expand All @@ -306,7 +325,9 @@ def test_scan_message(status, error_message):
@mock.patch(
"requests.get",
)
def test_scanall(_get, environment, mock_check_fixity):
def test_scanall(
_get: mock.Mock, environment: None, mock_check_fixity: List[mock.Mock]
) -> None:
aip1_uuid = str(uuid.uuid4())
aip2_uuid = str(uuid.uuid4())
_get.side_effect = [
Expand Down Expand Up @@ -351,7 +372,7 @@ def test_scanall(_get, environment, mock_check_fixity):


@mock.patch("requests.get")
def test_scanall_handles_exceptions(_get, environment):
def test_scanall_handles_exceptions(_get: mock.Mock, environment: None) -> None:
aip_id1 = str(uuid.uuid4())
aip_id2 = str(uuid.uuid4())
_get.side_effect = [
Expand Down Expand Up @@ -412,7 +433,9 @@ def test_scanall_handles_exceptions(_get, environment):


@mock.patch("requests.get")
def test_main_handles_exceptions_if_scanall_fails(_get, environment):
def test_main_handles_exceptions_if_scanall_fails(
_get: mock.Mock, environment: None
) -> None:
aip_id1 = str(uuid.uuid4())
aip_id2 = str(uuid.uuid4())
_get.side_effect = [
Expand Down Expand Up @@ -473,7 +496,9 @@ def test_main_handles_exceptions_if_scanall_fails(_get, environment):


@mock.patch("requests.get")
def test_scanall_if_sort_argument_is_passed(_get, environment, mock_check_fixity):
def test_scanall_if_sort_argument_is_passed(
_get: mock.Mock, environment: None, mock_check_fixity: List[mock.Mock]
) -> None:
aip1_uuid = str(uuid.uuid4())
aip2_uuid = str(uuid.uuid4())
aip3_uuid = str(uuid.uuid4())
Expand Down

0 comments on commit 6eb4807

Please sign in to comment.