From 959b1e4344df8713b1df11e0fa1528ffe39d47bd Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 19:25:38 -0500 Subject: [PATCH 01/16] Add custom exceptions --- src/mainframe/custom_exceptions.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/mainframe/custom_exceptions.py b/src/mainframe/custom_exceptions.py index e2ab09fe..bd516eff 100644 --- a/src/mainframe/custom_exceptions.py +++ b/src/mainframe/custom_exceptions.py @@ -1,4 +1,17 @@ from fastapi import HTTPException, status +from dataclasses import dataclass + + +@dataclass +class PackageNotFound(Exception): + name: str + version: str + + +@dataclass +class PackageAlreadyReported(Exception): + name: str + reported_version: str class BadCredentialsException(HTTPException): From 5424588facfcb0d98cb0a7729bb763fb59c232eb Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 19:36:54 -0500 Subject: [PATCH 02/16] Add Storage protocol Add a Storage protocol that abstracts our storage layer from the business logic layer. This provides for a cleaner interface and more ergonomic tests. --- src/mainframe/database.py | 45 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/src/mainframe/database.py b/src/mainframe/database.py index 1198204b..90b63a8b 100644 --- a/src/mainframe/database.py +++ b/src/mainframe/database.py @@ -1,9 +1,14 @@ -from typing import Generator +from collections.abc import Sequence, Generator +import datetime as dt +from typing import Generator, Optional +from typing import Optional from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker from mainframe.constants import mainframe_settings +from mainframe.models.orm import Scan +from typing import Protocol # pool_size and max_overflow are set to their default values. There is never # enough load to justify increasing them. 
@@ -21,3 +26,41 @@ def get_db() -> Generator[Session, None, None]: yield session finally: session.close() +class StorageProtocol(Protocol): + def lookup_packages( + self, name: Optional[str] = None, version: Optional[str] = None, since: Optional[dt.datetime] = None + ) -> Sequence[Scan]: + """ + Look up information on scanned packages based on name, version, or time + scanned. If multiple packages are returned, they are ordered with the most + recently queued package first. + + Args: + name: The name of the package. + version: The version of the package. + since: A `datetime` marking the point in time to begin the + search from. + + Raises: + ValueError: Invalid parameter combination was passed. See below. + + Returns: + Sequence of `Scan`s, representing the results of the query + + Only certain combinations of parameters are allowed. A query is valid if any of the following combinations are used: + - `name` and `version`: Return the package with name `name` and version `version`, if it exists. + - `name` and `since`: Find all packages with name `name` since `since`. + - `since`: Find all packages since `since`. + - `name`: Find all packages with name `name`. + All other combinations are disallowed. + + In more formal terms, a query is valid + iff `((name and not since) or (not version and since))` + where a given variable name means that query parameter was passed. Equivalently, a request is invalid + iff `(not (name or since) or (version and since))` + """ + ... + + def mark_reported(self, *, scan: Scan, subject: str) -> None: + """Mark the given `Scan` record as reported by `subject`.""" + ...
From 1854598344920dca148aac96091908f597e72131 Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 19:43:28 -0500 Subject: [PATCH 03/16] Add real database storage protocol implementation --- src/mainframe/database.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/mainframe/database.py b/src/mainframe/database.py index 90b63a8b..3b662e01 100644 --- a/src/mainframe/database.py +++ b/src/mainframe/database.py @@ -5,6 +5,8 @@ from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy import create_engine, select +from sqlalchemy.orm import Session, joinedload, sessionmaker from mainframe.constants import mainframe_settings from mainframe.models.orm import Scan @@ -64,3 +66,35 @@ def lookup_packages( def mark_reported(self, *, scan: Scan, subject: str) -> None: """Mark the given `Scan` record as reported by `subject`.""" ... + + +class DatabaseStorage(StorageProtocol): + def __init__(self, sessionmaker: orm.sessionmaker[Session]): + self.sessionmaker = sessionmaker + + def get_session(self) -> Session: + return self.sessionmaker() + + def lookup_packages( + self, name: Optional[str] = None, version: Optional[str] = None, since: Optional[dt.datetime] = None + ) -> Sequence[Scan]: + query = ( + select(Scan).order_by(Scan.queued_at.desc()).options(joinedload(Scan.rules), joinedload(Scan.download_urls)) + ) + + if name: + query = query.where(Scan.name == name) + if version: + query = query.where(Scan.version == version) + if since: + query = query.where(Scan.finished_at >= since) + + session = self.get_session() + with session, session.begin(): + return session.scalars(query).unique().all() + + def mark_reported(self, *, scan: Scan, subject: str) -> None: + session = self.get_session() + with session, session.begin(): + scan.reported_by = subject + scan.reported_at = dt.datetime.now(dt.timezone.utc)  # aware UTC timestamp, not naive local time From 999bebf18adbb1fa671517eb0cb10a5f87d7ac79 Mon Sep 17 00:00:00 2001 From: Robin5605 Date:
Sat, 5 Oct 2024 19:44:36 -0500 Subject: [PATCH 04/16] Add dependency to get storage protocol --- src/mainframe/database.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/mainframe/database.py b/src/mainframe/database.py index 3b662e01..4935e736 100644 --- a/src/mainframe/database.py +++ b/src/mainframe/database.py @@ -1,11 +1,10 @@ -from collections.abc import Sequence, Generator +from collections.abc import Sequence import datetime as dt +from functools import cache from typing import Generator, Optional -from typing import Optional -from sqlalchemy import create_engine -from sqlalchemy.orm import Session, sessionmaker from sqlalchemy import create_engine, select +from sqlalchemy import orm from sqlalchemy.orm import Session, joinedload, sessionmaker from mainframe.constants import mainframe_settings @@ -28,6 +27,8 @@ def get_db() -> Generator[Session, None, None]: yield session finally: session.close() + + class StorageProtocol(Protocol): def lookup_packages( self, name: Optional[str] = None, version: Optional[str] = None, since: Optional[dt.datetime] = None @@ -98,3 +99,8 @@ def mark_reported(self, *, scan: Scan, subject: str) -> None: with session, session.begin(): scan.reported_by = subject scan.reported_at = dt.datetime.now(dt.timezone.utc)  # aware UTC timestamp, not naive local time + + +@cache +def get_storage() -> DatabaseStorage: + return DatabaseStorage(sessionmaker) From 61424b26c359fc85481adff77fb0fb7933067e1c Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 19:48:45 -0500 Subject: [PATCH 05/16] Remove autouse for database session fixture Tests are significantly faster when we aren't doing set up and teardown database operations for every test, even ones that aren't database tests.
--- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index ea14071a..3170ed5f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,7 +55,7 @@ def test_data(request: pytest.FixtureRequest) -> list[Scan]: return request.param -@pytest.fixture(autouse=True) +@pytest.fixture def db_session( superuser_engine: Engine, test_data: list[Scan], sm: sessionmaker[Session] ) -> Generator[Session, None, None]: From a7f3d9b1faf5d5a872b4ebc2606f110b224da42d Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 19:53:53 -0500 Subject: [PATCH 06/16] Add mock database class and dependency Add a mock database class which abides by the Storage protocol. This will allow us to mock out the database in some integration tests. --- tests/conftest.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 3170ed5f..4979897d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,8 @@ +from collections.abc import Sequence, Generator import logging from copy import deepcopy from datetime import datetime, timedelta -from typing import Generator +from typing import Optional from unittest.mock import MagicMock import httpx @@ -12,6 +13,7 @@ from sqlalchemy import Engine, create_engine, text from sqlalchemy.orm import Session, sessionmaker +from mainframe.database import DatabaseStorage, StorageProtocol from mainframe.json_web_token import AuthenticationData from mainframe.models.orm import Base, Scan from mainframe.rules import Rules @@ -21,6 +23,36 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__file__) +class MockDatabase(StorageProtocol): + def __init__(self) -> None: + self.db: list[Scan] = [] + + def add(self, scan: Scan) -> None: + self.db.append(scan) + + def lookup_packages( + self, name: Optional[str] = None, version: Optional[str] = None, since: Optional[datetime] = None + ) 
-> Sequence[Scan]: + v: list[Scan] = [] + for scan in self.db: + if ( + (scan.name == name) + or (scan.version == version) + or (scan.queued_at and since and scan.queued_at >= since) + ): + v.append(scan) + + return v + + def mark_reported(self, *, scan: Scan, subject: str) -> None: + for s in self.db: + if s.scan_id == scan.scan_id: + scan.reported_by = subject + scan.reported_at = datetime.now() + +@pytest.fixture +def mock_database() -> MockDatabase: + return MockDatabase() @pytest.fixture(scope="session") def sm(engine: Engine) -> sessionmaker[Session]: From 82edcf7dbaeccbffd3b433e9181cbe595095f035 Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 19:55:30 -0500 Subject: [PATCH 07/16] Add storage fixture Add a fixture which yields a DatabaseStorage. The idea is to replace the db_session fixture over time with this fixture. --- tests/conftest.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 4979897d..d2abab88 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__file__) + class MockDatabase(StorageProtocol): def __init__(self) -> None: self.db: list[Scan] = [] @@ -50,10 +51,12 @@ def mark_reported(self, *, scan: Scan, subject: str) -> None: scan.reported_by = subject scan.reported_at = datetime.now() + @pytest.fixture def mock_database() -> MockDatabase: return MockDatabase() + @pytest.fixture(scope="session") def sm(engine: Engine) -> sessionmaker[Session]: return sessionmaker(bind=engine, expire_on_commit=False, autobegin=False) @@ -82,6 +85,20 @@ def engine(superuser_engine: Engine) -> Engine: return create_engine("postgresql+psycopg2://dragonfly:postgres@db:5432/dragonfly", pool_size=5, max_overflow=10) +@pytest.fixture +def storage( + superuser_engine: Engine, test_data: list[Scan], sm: sessionmaker[Session] +) -> Generator[DatabaseStorage, None, None]: +
Base.metadata.drop_all(superuser_engine) + Base.metadata.create_all(superuser_engine) + with sm() as s, s.begin(): + s.add_all(deepcopy(test_data)) + + yield DatabaseStorage(sm) + + Base.metadata.drop_all(superuser_engine) + + @pytest.fixture(params=data, scope="session") def test_data(request: pytest.FixtureRequest) -> list[Scan]: return request.param From a0535204e743dd8715d13311c07e73424ce80168 Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 19:58:27 -0500 Subject: [PATCH 08/16] Add `db` pytest marker This marker should be used on all tests that reach for a real database to denote that they may be slow and require a database container to be spun up before this test can be run. This will allow developers who are not making any database changes to run their tests very quickly. All tests (including database ones) should still be run in CI. --- pytest.ini | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 pytest.ini diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..26f13f99 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + db: marks tests as database tests. requires a database container and may be slow. 
From 6cd2f036b11359dc7de360ba1dba157f5dc4c458 Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 19:56:59 -0500 Subject: [PATCH 09/16] Add database tests --- tests/test_database.py | 241 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 241 insertions(+) create mode 100644 tests/test_database.py diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 00000000..38f5408d --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,241 @@ +from copy import deepcopy +from datetime import datetime, timedelta +from typing import Optional +import pytest +from sqlalchemy import select +from mainframe.database import DatabaseStorage +from mainframe.models.orm import Scan, Status + + +@pytest.mark.db +def test_mark_reported(storage: DatabaseStorage): + scan = Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now(), + ) + + session = storage.get_session() + with session.begin(): + session.add(scan) + + storage.mark_reported(scan=scan, subject="remmy") + + query = select(Scan).where(Scan.name == "package1").where(Scan.version == "1.0.0") + actual = session.scalar(query) + assert actual is not None + assert actual.reported_by == "remmy" + assert actual.reported_at is not None + + +@pytest.mark.db +@pytest.mark.parametrize( + "scans,spec,expected", + [ + ( + [Scan(name="package1", version="1.0.0", status=Status.QUEUED, queued_by="remmy", queued_at=datetime.now())], + ("package1", None, None), + [("package1", "1.0.0")], + ), + ( + [Scan(name="package1", version="1.0.0", status=Status.QUEUED, queued_by="remmy", queued_at=datetime.now())], + ("package1", "1.0.0", None), + [("package1", "1.0.0")], + ), + ( + [ + Scan( + name="package1", version="1.0.0", status=Status.QUEUED, queued_by="remmy", queued_at=datetime.now() + ), + Scan( + name="package1", version="1.0.1", status=Status.QUEUED, queued_by="remmy", queued_at=datetime.now() + ), + ], + ("package1", None, None), + 
[("package1", "1.0.0"), ("package1", "1.0.1")], + ), + ( + [ + Scan( + name="package1", version="1.0.0", status=Status.QUEUED, queued_by="remmy", queued_at=datetime.now() + ), + Scan( + name="package1", version="1.0.1", status=Status.QUEUED, queued_by="remmy", queued_at=datetime.now() + ), + ], + ("package1", "1.0.1", None), + [("package1", "1.0.1")], + ), + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=2), + ), + ], + ("package1", None, 0), + [("package1", "1.0.0"), ("package1", "1.0.1")], + ), + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=2), + ), + ], + ("package1", None, datetime.now() - timedelta(seconds=4)), + [("package1", "1.0.1")], + ), + # we must use a static time for this test here because it can be flaky otherwise + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime(2024, 10, 4, 2, 4) - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime(2024, 10, 4, 2, 4) - timedelta(seconds=2), + ), + ], + ("package1", None, datetime(2024, 10, 4, 2, 4) - timedelta(seconds=2)), + [("package1", "1.0.1")], 
+ ), + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=2), + ), + ], + ("package1", "1.0.0", datetime.now() - timedelta(seconds=2)), + [], + ), + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=2), + ), + ], + ("package1", None, datetime.now() - timedelta(seconds=1)), + [], + ), + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=2), + ), + ], + ("package2", None, None), + [], + ), + ( + [ + Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=5), + ), + Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now() - timedelta(seconds=2), + ), + ], + ("package1", "1.0.2", None), + [], + ), + ], +) +def test_lookup_packages( + storage: DatabaseStorage, + scans: list[Scan], + 
spec: tuple[Optional[str], Optional[str], Optional[datetime]], + expected: list[tuple[str, str]], +): + session = storage.get_session() + with session, session.begin(): + session.add_all(deepcopy(scans)) + + name, version, since = spec + results = storage.lookup_packages(name=name, version=version, since=since) + + assert sorted((s.name, s.version) for s in results) == sorted(expected) From f9f831ab1a1adb9e5cb767f91957a9875f589821 Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 19:59:24 -0500 Subject: [PATCH 10/16] Rewrite report endpoint --- src/mainframe/endpoints/report.py | 105 ++++++++++++++++-------------- 1 file changed, 55 insertions(+), 50 deletions(-) diff --git a/src/mainframe/endpoints/report.py b/src/mainframe/endpoints/report.py index 22d5cbe2..e637ac14 100644 --- a/src/mainframe/endpoints/report.py +++ b/src/mainframe/endpoints/report.py @@ -1,15 +1,14 @@ -import datetime as dt +from collections.abc import Sequence from typing import Annotated, Optional import httpx import structlog from fastapi import APIRouter, Depends, HTTPException from fastapi.encoders import jsonable_encoder -from sqlalchemy import select -from sqlalchemy.orm import Session, joinedload from mainframe.constants import mainframe_settings -from mainframe.database import get_db +from mainframe.custom_exceptions import PackageNotFound, PackageAlreadyReported +from mainframe.database import StorageProtocol, get_storage from mainframe.dependencies import get_httpx_client, validate_token from mainframe.json_web_token import AuthenticationData from mainframe.models.orm import Scan @@ -27,57 +26,50 @@ router = APIRouter(tags=["report"]) -def _lookup_package(name: str, version: str, session: Session) -> Scan: +def get_reported_version(scans: Sequence[Scan]) -> Optional[Scan]: """ - Checks if the package is valid according to our database. + Get the version of this scan that was reported. Returns: - True if the package exists in the database. 
+ `Scan`: The scan record that was reported + `None`: No versions of this package were reported + """ - Raises: - HTTPException: 404 Not Found if the name was not found in the database, - or the specified name and version was not found in the database. 409 - Conflict if another version of the same package has already been - reported. + for scan in scans: + if scan.reported_at is not None: + return scan + + return None + + +def validate_package(name: str, version: str, scans: Sequence[Scan]) -> Scan: """ + Checks if the package is valid according to our database. + A package is considered valid if there exists a scan with the given name + and version, and that no other versions have been reported. - log = logger.bind(package={"name": name, "version": version}) + Arguments: + name: The name of the package to validate + version: The version of the package to validate + scans: The sequence of Scan records in the database where name=name + + Returns: + `Scan`: The validated `Scan` object - query = select(Scan).where(Scan.name == name).options(joinedload(Scan.rules)) - with session.begin(): - scans = session.scalars(query).unique().all() + Raises: + PackageNotFound: The given name and version combination + PackageAlreadyReported: The package was already reported + """ if not scans: - error = HTTPException(404, detail=f"No records for package `{name}` were found in the database") - log.error( - f"No records for package {name} found in database", error_message=error.detail, tag="package_not_found_db" - ) - raise error + raise PackageNotFound(name=name, version=version) - for scan in scans: - if scan.reported_at is not None: - error = HTTPException( - 409, - detail=( - f"Only one version of a package may be reported at a time. 
" - f"(`{scan.name}@{scan.version}` was already reported)" - ), - ) - log.error( - "Only one version of a package allowed to be reported at a time", - error_message=error.detail, - tag="multiple_versions_prohibited", - ) - raise error + if scan := get_reported_version(scans): + raise PackageAlreadyReported(name=scan.name, reported_version=scan.version) - with session.begin(): - scan = session.scalar(query.where(Scan.version == version)) + scan = next((s for s in scans if (s.name, s.version) == (name, version)), None) if scan is None: - error = HTTPException( - 404, detail=f"Package `{name}` has records in the database, but none with version `{version}`" - ) - log.error(f"No version {version} for package {name} in database", tag="invalid_version") - raise error + raise PackageNotFound(name=name, version=version) return scan @@ -156,7 +148,7 @@ def _validate_pypi(name: str, version: str, http_client: httpx.Client): ) def report_package( body: ReportPackageBody, - session: Annotated[Session, Depends(get_db)], + database: Annotated[StorageProtocol, Depends(get_storage)], auth: Annotated[AuthenticationData, Depends(validate_token)], httpx_client: Annotated[httpx.Client, Depends(get_httpx_client)], ): @@ -198,7 +190,24 @@ def report_package( log = logger.bind(package={"name": name, "version": version}) # Check our database first to avoid unnecessarily using PyPI API. 
- scan = _lookup_package(name, version, session) + try: + scans = database.lookup_packages(name) + scan = validate_package(name, version, scans) + except PackageNotFound as e: + detail = f"No records for package `{e.name} v{e.version}` were found in the database" + error = HTTPException(404, detail=detail) + log.error(detail, error_message=detail, tag="package_not_found_db") + + raise error + except PackageAlreadyReported as e: + detail = ( + f"Only one version of a package may be reported at a time " + f"(`{e.name}@{e.reported_version}` was already reported)" + ) + error = HTTPException(409, detail=detail) + log.error(detail, error_message=error.detail, tag="multiple_versions_prohibited") + + raise error inspector_url = _validate_inspector_url(name, version, body.inspector_url, scan.inspector_url) _validate_additional_information(body, scan) @@ -233,11 +242,7 @@ def report_package( httpx_client.post(f"{mainframe_settings.reporter_url}/report/{name}", json=jsonable_encoder(report)) - with session.begin(): - scan.reported_by = auth.subject - scan.reported_at = dt.datetime.now(dt.timezone.utc) - - session.close() + database.mark_reported(scan=scan, subject=auth.subject) log.info( "Sent report", From a5263c71f4eb59401a68bfaea0275f223675776e Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 19:59:33 -0500 Subject: [PATCH 11/16] Rewrite report endpoint tests --- tests/test_report.py | 243 ++++++++++++++++++++----------------------- 1 file changed, 114 insertions(+), 129 deletions(-) diff --git a/tests/test_report.py b/tests/test_report.py index fde95f05..30338f38 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -1,5 +1,4 @@ from datetime import datetime, timedelta -from copy import deepcopy from typing import Optional from unittest.mock import MagicMock @@ -7,11 +6,12 @@ import pytest from fastapi import HTTPException from fastapi.encoders import jsonable_encoder -from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker 
from mainframe.endpoints.report import ( - _lookup_package, # pyright: ignore [reportPrivateUsage] + PackageAlreadyReported, + PackageNotFound, + validate_package, + get_reported_version, ) from mainframe.endpoints.report import ( _validate_additional_information, # pyright: ignore [reportPrivateUsage] @@ -26,95 +26,98 @@ from mainframe.json_web_token import AuthenticationData from mainframe.models.orm import DownloadURL, Rule, Scan, Status from mainframe.models.schemas import ( - EmailReport, ObservationKind, ObservationReport, ReportPackageBody, ) +from tests.conftest import MockDatabase -@pytest.mark.parametrize( - "body,url,expected", - [ - ( - ReportPackageBody( - name="c", - version="1.0.0", - recipient=None, - inspector_url=None, - additional_information="this package is bad", - use_email=True, - ), - "/report/email", - EmailReport( - name="c", - version="1.0.0", - rules_matched=["rule 1", "rule 2"], - inspector_url="test inspector url", - additional_information="this package is bad", - ), - ), - ( - ReportPackageBody( - name="c", - version="1.0.0", - recipient=None, - inspector_url=None, - additional_information="this package is bad", - ), - "/report/c", - ObservationReport( - kind=ObservationKind.Malware, - summary="this package is bad", - inspector_url="test inspector url", - extra=dict(yara_rules=["rule 1", "rule 2"]), - ), - ), - ], -) -def test_report( - sm: sessionmaker[Session], - db_session: Session, - auth: AuthenticationData, - body: ReportPackageBody, - url: str, - expected: EmailReport | ObservationReport, -): - scan = Scan( - name="c", +def test_get_reported_version(): + scan1 = Scan( + name="package1", + version="1.0.0", + reported_at=datetime.now(), + ) + + scan2 = Scan( + name="package1", + version="1.0.1", + reported_at=None, + ) + + scans = [scan1, scan2] + + assert get_reported_version(scans) == scan1 + + +def test_get_no_reported_version(): + scan1 = Scan( + name="package1", + version="1.0.0", + reported_at=None, + ) + + scan2 = Scan( + 
name="package1", + version="1.0.1", + reported_at=None, + ) + + scans = [scan1, scan2] + + assert get_reported_version(scans) is None + + +def test_validate_package(): + scan1 = Scan( + name="package1", version="1.0.0", status=Status.FINISHED, - score=10, - inspector_url="test inspector url", - rules=[Rule(name="rule 1"), Rule(name="rule 2")], - download_urls=[DownloadURL(url="test download url")], - queued_at=datetime.now() - timedelta(seconds=60), queued_by="remmy", - pending_at=datetime.now() - timedelta(seconds=30), - pending_by="remmy", - finished_at=datetime.now(), - finished_by="remmy", + queued_at=datetime.now(), reported_at=None, - reported_by=None, - fail_reason=None, - commit_hash="test commit hash", ) - with db_session.begin(): - db_session.add(scan) + assert validate_package("package1", "1.0.0", [scan1]) == scan1 - mock_httpx_client = MagicMock() - report_package(body, sm(), auth, mock_httpx_client) +def test_validate_package_not_found(): + scan1 = Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now(), + reported_at=None, + ) + + with pytest.raises(PackageNotFound): + validate_package("package2", "1.0.0", [scan1]) - mock_httpx_client.post.assert_called_once_with(url, json=jsonable_encoder(expected)) - with sm() as sess, sess.begin(): - s = sess.scalar(select(Scan).where(Scan.name == "c").where(Scan.version == "1.0.0")) +def test_validate_package_already_reported(): + scan1 = Scan( + name="package1", + version="1.0.0", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now(), + reported_at=None, + ) + scan2 = Scan( + name="package1", + version="1.0.1", + status=Status.FINISHED, + queued_by="remmy", + queued_at=datetime.now(), + reported_at=datetime.now(), + ) - assert s is not None - assert s.reported_by == auth.subject - assert s.reported_at is not None + with pytest.raises(PackageAlreadyReported) as e: + validate_package("package1", "1.0.0", [scan1, scan2]) + + assert 
(e.value.name, e.value.reported_version) == ("package1", "1.0.1") def test_report_package_not_on_pypi(): @@ -123,16 +126,11 @@ def test_report_package_not_on_pypi(): with pytest.raises(HTTPException) as e: _validate_pypi("c", "1.0.0", mock_httpx_client) - assert e.value.status_code == 404 - -def test_report_unscanned_package(db_session: Session): - with pytest.raises(HTTPException) as e: - _lookup_package("c", "1.0.0", db_session) assert e.value.status_code == 404 -def test_report_invalid_version(db_session: Session): +def test_report(auth: AuthenticationData, mock_database: MockDatabase): scan = Scan( name="c", version="1.0.0", @@ -145,19 +143,39 @@ def test_report_invalid_version(db_session: Session): queued_by="remmy", pending_at=datetime.now() - timedelta(seconds=30), pending_by="remmy", - finished_at=datetime.now() - timedelta(seconds=10), + finished_at=datetime.now(), finished_by="remmy", reported_at=None, - reported_by="remmy", + reported_by=None, fail_reason=None, commit_hash="test commit hash", ) - with db_session.begin(): - db_session.add(scan) - with pytest.raises(HTTPException) as e: - _lookup_package("c", "2.0.0", db_session) - assert e.value.status_code == 404 + mock_database.add(scan) + + body = ReportPackageBody( + name="c", + version="1.0.0", + recipient=None, + inspector_url=None, + additional_information="this package is bad", + ) + + expected = ObservationReport( + kind=ObservationKind.Malware, + summary="this package is bad", + inspector_url="test inspector url", + extra=dict(yara_rules=["rule 1", "rule 2"]), + ) + + mock_httpx_client = MagicMock() + + report_package(body, mock_database, auth, mock_httpx_client) + + mock_httpx_client.post.assert_called_once_with("/report/c", json=jsonable_encoder(expected)) + + assert scan.reported_by is auth.subject + assert scan.reported_at is not None def test_report_missing_inspector_url(): @@ -247,9 +265,9 @@ def test_report_missing_additional_information(body: ReportPackageBody, scan: Sc 
@pytest.mark.parametrize( - ("scans", "name", "version", "expected_status_code"), + ("scans", "name", "version", "expected_exception"), [ - ([], "a", "1.0.0", 404), + ([], "a", "1.0.0", PackageNotFound), ( [ Scan( @@ -293,7 +311,7 @@ def test_report_missing_additional_information(body: ReportPackageBody, scan: Sc ], "c", "1.0.1", - 409, + PackageAlreadyReported, ), ( [ @@ -338,45 +356,12 @@ def test_report_missing_additional_information(body: ReportPackageBody, scan: Sc ], "c", "2.0.0", - 409, + PackageAlreadyReported, ), ], ) def test_report_lookup_package_validation( - db_session: Session, scans: list[Scan], name: str, version: str, expected_status_code: int + scans: list[Scan], name: str, version: str, expected_exception: type[Exception] ): - with db_session.begin(): - db_session.add_all(deepcopy(scans)) - - with pytest.raises(HTTPException) as e: - _lookup_package(name, version, db_session) - assert e.value.status_code == expected_status_code - - -def test_report_lookup_package(db_session: Session): - scan = Scan( - name="c", - version="1.0.0", - status=Status.FINISHED, - score=0, - inspector_url=None, - rules=[], - download_urls=[], - queued_at=datetime.now() - timedelta(seconds=60), - queued_by="remmy", - pending_at=datetime.now() - timedelta(seconds=30), - pending_by="remmy", - finished_at=datetime.now() - timedelta(seconds=10), - finished_by="remmy", - reported_at=None, - reported_by=None, - fail_reason=None, - commit_hash="test commit hash", - ) - - with db_session.begin(): - db_session.add(scan) - - res = _lookup_package("c", "1.0.0", db_session) - - assert res == scan + with pytest.raises(expected_exception): + validate_package(name, version, scans) From 49c56fcca12268fa582312077bc1c6c5267729cf Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 20:05:27 -0500 Subject: [PATCH 12/16] Fix custom exception import location --- tests/test_report.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_report.py 
b/tests/test_report.py index 30338f38..fdd83061 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -7,9 +7,8 @@ from fastapi import HTTPException from fastapi.encoders import jsonable_encoder +from mainframe.custom_exceptions import PackageAlreadyReported, PackageNotFound from mainframe.endpoints.report import ( - PackageAlreadyReported, - PackageNotFound, validate_package, get_reported_version, ) From fa7a8604bb5020fc9b0afafb31bbf71597b34ab9 Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 20:24:26 -0500 Subject: [PATCH 13/16] Add test for package not found --- tests/test_report.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_report.py b/tests/test_report.py index fdd83061..c2ed9750 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -128,7 +128,19 @@ def test_report_package_not_on_pypi(): assert e.value.status_code == 404 +def test_report_package_not_found(auth: AuthenticationData, mock_database: MockDatabase): + body = ReportPackageBody( + name="this-package-does-not-exist", + version="1.0.0", + recipient=None, + inspector_url=None, + additional_information="this package is bad", + ) + + with pytest.raises(HTTPException) as e: + report_package(body, mock_database, auth, MagicMock()) + assert e.value.status_code == 404 def test_report(auth: AuthenticationData, mock_database: MockDatabase): scan = Scan( name="c", From f6c607a60937b762d2ca478fe57e661894832225 Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Sat, 5 Oct 2024 20:24:51 -0500 Subject: [PATCH 14/16] Add test for package already reported --- tests/test_report.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_report.py b/tests/test_report.py index c2ed9750..c06939b6 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -128,6 +128,7 @@ def test_report_package_not_on_pypi(): assert e.value.status_code == 404 + def test_report_package_not_found(auth: AuthenticationData, 
mock_database: MockDatabase): body = ReportPackageBody( name="this-package-does-not-exist", @@ -141,6 +142,46 @@ def test_report_package_not_found(auth: AuthenticationData, mock_database: MockD report_package(body, mock_database, auth, MagicMock()) assert e.value.status_code == 404 + + +@pytest.mark.parametrize("version", ["1.0.0", "1.0.1"]) +def test_report_package_already_reported(auth: AuthenticationData, mock_database: MockDatabase, version: str): + scan = Scan( + name="c", + version="1.0.0", + status=Status.FINISHED, + score=10, + inspector_url="test inspector url", + rules=[Rule(name="rule 1"), Rule(name="rule 2")], + download_urls=[DownloadURL(url="test download url")], + queued_at=datetime.now() - timedelta(seconds=60), + queued_by="remmy", + pending_at=datetime.now() - timedelta(seconds=30), + pending_by="remmy", + finished_at=datetime.now() - timedelta(seconds=15), + finished_by="remmy", + reported_at=datetime.now(), + reported_by="fishy", + fail_reason=None, + commit_hash="test commit hash", + ) + + mock_database.add(scan) + + body = ReportPackageBody( + name="c", + version=version, + recipient=None, + inspector_url=None, + additional_information="this package is bad", + ) + + with pytest.raises(HTTPException) as e: + report_package(body, mock_database, auth, MagicMock()) + + assert e.value.status_code == 409 + + def test_report(auth: AuthenticationData, mock_database: MockDatabase): scan = Scan( name="c", From db1df6c3ded019e20189af2d25e80b8a517142cc Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Wed, 30 Oct 2024 16:30:18 -0500 Subject: [PATCH 15/16] Fix incomplete docstring --- src/mainframe/endpoints/report.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mainframe/endpoints/report.py b/src/mainframe/endpoints/report.py index e637ac14..3f49229d 100644 --- a/src/mainframe/endpoints/report.py +++ b/src/mainframe/endpoints/report.py @@ -57,7 +57,7 @@ def validate_package(name: str, version: str, scans: Sequence[Scan]) -> Scan: 
`Scan`: The validated `Scan` object Raises: - PackageNotFound: The given name and version combination + PackageNotFound: The given name and version combination was not found PackageAlreadyReported: The package was already reported """ From 0a7283d58ef932ceecb31d7124cd0c7facb127cf Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Wed, 30 Oct 2024 18:18:12 -0500 Subject: [PATCH 16/16] Combine get_reported_version into validate_package --- src/mainframe/endpoints/report.py | 32 ++++++------------------ tests/test_report.py | 41 +------------------------------ 2 files changed, 8 insertions(+), 65 deletions(-) diff --git a/src/mainframe/endpoints/report.py b/src/mainframe/endpoints/report.py index fa3b7a56..40eee632 100644 --- a/src/mainframe/endpoints/report.py +++ b/src/mainframe/endpoints/report.py @@ -27,22 +27,6 @@ router = APIRouter(tags=["report"]) -def get_reported_version(scans: Sequence[Scan]) -> Optional[Scan]: - """ - Get the version of this scan that was reported. - - Returns: - `Scan`: The scan record that was reported - `None`: No versions of this package were reported - """ - - for scan in scans: - if scan.reported_at is not None: - return scan - - return None - - def validate_package(name: str, version: str, scans: Sequence[Scan]) -> Scan: """ Checks if the package is valid according to our database. 
@@ -62,17 +46,15 @@ def validate_package(name: str, version: str, scans: Sequence[Scan]) -> Scan: PackageAlreadyReported: The package was already reported """ - if not scans: - raise PackageNotFound(name=name, version=version) - - if scan := get_reported_version(scans): - raise PackageAlreadyReported(name=scan.name, reported_version=scan.version) + for scan in scans: + if scan.reported_at is not None: + raise PackageAlreadyReported(name=scan.name, reported_version=scan.version) - scan = next((s for s in scans if (s.name, s.version) == (name, version)), None) - if scan is None: - raise PackageNotFound(name=name, version=version) + for scan in scans: + if (scan.name, scan.version) == (name, version): + return scan - return scan + raise PackageNotFound(name=name, version=version) def _validate_inspector_url(name: str, version: str, body_url: Optional[str], scan_url: Optional[str]) -> str: diff --git a/tests/test_report.py b/tests/test_report.py index 1a4f6ecd..20bc9239 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -8,10 +8,7 @@ from fastapi.encoders import jsonable_encoder from mainframe.custom_exceptions import PackageAlreadyReported, PackageNotFound -from mainframe.endpoints.report import ( - validate_package, - get_reported_version, -) +from mainframe.endpoints.report import validate_package from mainframe.endpoints.report import ( _validate_inspector_url, # pyright: ignore [reportPrivateUsage] ) @@ -29,42 +26,6 @@ from tests.conftest import MockDatabase -def test_get_reported_version(): - scan1 = Scan( - name="package1", - version="1.0.0", - reported_at=datetime.now(), - ) - - scan2 = Scan( - name="package1", - version="1.0.1", - reported_at=None, - ) - - scans = [scan1, scan2] - - assert get_reported_version(scans) == scan1 - - -def test_get_no_reported_version(): - scan1 = Scan( - name="package1", - version="1.0.0", - reported_at=None, - ) - - scan2 = Scan( - name="package1", - version="1.0.1", - reported_at=None, - ) - - scans = [scan1, scan2] 
- - assert get_reported_version(scans) is None - - def test_validate_package(): scan1 = Scan( name="package1",