Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract report logic #323

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[pytest]
markers =
db: marks tests as database tests. requires a database container and may be slow.
13 changes: 13 additions & 0 deletions src/mainframe/custom_exceptions.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the filename `exceptions.py` would be better.

Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
from fastapi import HTTPException, status
from dataclasses import dataclass


@dataclass
class PackageNotFound(Exception):
    """Raised when no scan record matches a requested package name and version.

    The fields are plain data so that callers can build their own error
    response (for example an HTTP 404) from them.
    """

    name: str  # name of the package that was looked up
    version: str  # version of the package that was looked up

    def __str__(self) -> str:
        # The dataclass-generated __init__ never hands a message to Exception,
        # so without this override str(exc) is empty when the fields are
        # passed by keyword, which makes tracebacks and logs uninformative.
        return f"Package `{self.name}` version `{self.version}` was not found"


@dataclass
class PackageAlreadyReported(Exception):
    """Raised when a version of the package has already been reported.

    The fields are plain data so that callers can build their own error
    response (for example an HTTP 409) from them.
    """

    name: str  # name of the package
    reported_version: str  # the version that was already reported

    def __str__(self) -> str:
        # The dataclass-generated __init__ never hands a message to Exception,
        # so without this override str(exc) is empty when the fields are
        # passed by keyword, which makes tracebacks and logs uninformative.
        return f"Package `{self.name}` was already reported (version `{self.reported_version}`)"


class BadCredentialsException(HTTPException):
Expand Down
89 changes: 86 additions & 3 deletions src/mainframe/database.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from typing import Generator
from collections.abc import Sequence
import datetime as dt
from functools import cache
from typing import Generator, Optional

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy import create_engine, select
from sqlalchemy import orm
from sqlalchemy.orm import Session, joinedload, sessionmaker

from mainframe.constants import mainframe_settings
from mainframe.models.orm import Scan
from typing import Protocol

# pool_size and max_overflow are set to their default values. There is never
# enough load to justify increasing them.
Expand All @@ -21,3 +27,80 @@ def get_db() -> Generator[Session, None, None]:
yield session
finally:
session.close()


class StorageProtocol(Protocol):
    """Structural interface for a backend that stores and updates `Scan` records."""

    def lookup_packages(
        self, name: Optional[str] = None, version: Optional[str] = None, since: Optional[dt.datetime] = None
    ) -> Sequence[Scan]:
        """
        Lookup information on scanned packages based on name, version, or time
        scanned. If multiple packages are returned, they are ordered with the most
        recently queued package first.

        Args:
            name: The name of the package.
            version: The version of the package.
            since: A datetime marking when to begin the search from.

        Exceptions:
            ValueError: Invalid parameter combination was passed. See below.
                NOTE(review): this is the documented contract; confirm that
                implementations actually enforce it.

        Returns:
            Sequence of `Scan`s, representing the results of the query

        Only certain combinations of parameters are allowed. A query is valid if any of the following combinations are used:
            - `name` and `version`: Return the package with name `name` and version `version`, if it exists.
            - `name` and `since`: Find all packages with name `name` since `since`.
            - `since`: Find all packages since `since`.
            - `name`: Find all packages with name `name`.
        All other combinations are disallowed.

        In more formal terms, a query is valid
        iff `((name and not since) or (not version and since))`
        where a given variable name means that query parameter was passed. Equivalently, a request is invalid
        iff `(not (name or since) or (version and since))`
        """
        ...

    def mark_reported(self, *, scan: Scan, subject: str) -> None:
        """Mark the given `Scan` record as reported by `subject`."""
        ...


class DatabaseStorage(StorageProtocol):
def __init__(self, sessionmaker: orm.sessionmaker[Session]):
    """Keep the session factory on the instance; `get_session` calls it to open sessions."""
    self.sessionmaker = sessionmaker

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each instance should have one session that commits changes when the DatabaseStorage gets destroyed. This is the "Unit of Work" pattern.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally thought of using one DatabaseStorage instance across the whole program as a singleton, though we can also create and destroy them for each endpoint like you're suggesting. What are the advantages of this approach, rather than having each individual method of this class manage its own session and unit of work?

def get_session(self) -> Session:
    """Return a new `Session` obtained by calling the stored session factory."""
    return self.sessionmaker()

def lookup_packages(
self, name: Optional[str] = None, version: Optional[str] = None, since: Optional[dt.datetime] = None
) -> Sequence[Scan]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One suggestion to consider: expanding the parameters of this function to allow querying by package status or arbitrary filters. This would enable the function to be reused across different endpoints, not just the /report endpoint. More importantly, it would streamline things by reducing redundancy.
You could potentially absorb the functionality of get_reported_version, and possibly even validate_package.
As it stands;

  • get_reported_version serves as a helper function for parsing through a sequence of scans and verifying if they're reported - and raising an error if so. That could be absorbed by just directly querying (or better yet, filtering out) packages whose reported_at columns are null.

  • validate_package serves as a function for validating whether the given sequence of packages matches the given name and version parameters. This may be better placed as the access layer's responsibility (to ensure that the package matching the given parameters is returned), or just outright absorbed into report_package.

What do you think?

query = (
select(Scan).order_by(Scan.queued_at.desc()).options(joinedload(Scan.rules), joinedload(Scan.download_urls))
)

if name:
query = query.where(Scan.name == name)
if version:
query = query.where(Scan.version == version)
if since:
query = query.where(Scan.finished_at >= since)

session = self.get_session()
with session, session.begin():
return session.scalars(query).unique().all()

def mark_reported(self, *, scan: Scan, subject: str) -> None:
    """
    Mark the given `Scan` record as reported by `subject` and persist the change.

    Args:
        scan: The scan record to mark. It is updated in place as well as in
            the database.
        subject: Identifier of whoever reported the package.
    """
    scan.reported_by = subject
    # Timezone-aware UTC timestamp, not a naive local time.
    scan.reported_at = dt.datetime.now(dt.timezone.utc)

    session = self.get_session()
    with session, session.begin():
        # This session is freshly opened and does not track `scan`, so
        # assigning attributes inside the transaction alone would commit
        # nothing. merge() copies the state of `scan` onto this session's
        # copy of the row, and the commit on exit writes it out.
        session.merge(scan)


@cache
def get_storage() -> DatabaseStorage:
    """
    Return the shared `DatabaseStorage` instance.

    `functools.cache` memoises this zero-argument function, so every call
    returns the same object.
    """
    # NOTE(review): `sessionmaker` is presumably the module-level session
    # factory bound to the engine, not the class imported from
    # `sqlalchemy.orm` - confirm against the rest of this module.
    return DatabaseStorage(sessionmaker)
105 changes: 55 additions & 50 deletions src/mainframe/endpoints/report.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import datetime as dt
from collections.abc import Sequence
from typing import Annotated, Optional

import httpx
import structlog
from fastapi import APIRouter, Depends, HTTPException
from fastapi.encoders import jsonable_encoder
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload

from mainframe.constants import mainframe_settings
from mainframe.database import get_db
from mainframe.custom_exceptions import PackageNotFound, PackageAlreadyReported
from mainframe.database import StorageProtocol, get_storage
from mainframe.dependencies import get_httpx_client, validate_token
from mainframe.json_web_token import AuthenticationData
from mainframe.models.orm import Scan
Expand All @@ -27,57 +26,50 @@
router = APIRouter(tags=["report"])


def _lookup_package(name: str, version: str, session: Session) -> Scan:
def get_reported_version(scans: Sequence[Scan]) -> Optional[Scan]:
"""
Checks if the package is valid according to our database.
Get the version of this scan that was reported.

Returns:
True if the package exists in the database.
`Scan`: The scan record that was reported
`None`: No versions of this package were reported
"""

Raises:
HTTPException: 404 Not Found if the name was not found in the database,
or the specified name and version was not found in the database. 409
Conflict if another version of the same package has already been
reported.
for scan in scans:
if scan.reported_at is not None:
return scan

return None


def validate_package(name: str, version: str, scans: Sequence[Scan]) -> Scan:
"""
Checks if the package is valid according to our database.
A package is considered valid if there exists a scan with the given name
and version, and that no other versions have been reported.

log = logger.bind(package={"name": name, "version": version})
Arguments:
name: The name of the package to validate
version: The version of the package to validate
scans: The sequence of Scan records in the database where name=name

Returns:
`Scan`: The validated `Scan` object

query = select(Scan).where(Scan.name == name).options(joinedload(Scan.rules))
with session.begin():
scans = session.scalars(query).unique().all()
Raises:
PackageNotFound: The given name and version combination
PackageAlreadyReported: The package was already reported
"""

if not scans:
error = HTTPException(404, detail=f"No records for package `{name}` were found in the database")
log.error(
f"No records for package {name} found in database", error_message=error.detail, tag="package_not_found_db"
)
raise error
raise PackageNotFound(name=name, version=version)

for scan in scans:
if scan.reported_at is not None:
error = HTTPException(
409,
detail=(
f"Only one version of a package may be reported at a time. "
f"(`{scan.name}@{scan.version}` was already reported)"
),
)
log.error(
"Only one version of a package allowed to be reported at a time",
error_message=error.detail,
tag="multiple_versions_prohibited",
)
raise error
if scan := get_reported_version(scans):
raise PackageAlreadyReported(name=scan.name, reported_version=scan.version)

with session.begin():
scan = session.scalar(query.where(Scan.version == version))
scan = next((s for s in scans if (s.name, s.version) == (name, version)), None)
if scan is None:
error = HTTPException(
404, detail=f"Package `{name}` has records in the database, but none with version `{version}`"
)
log.error(f"No version {version} for package {name} in database", tag="invalid_version")
raise error
raise PackageNotFound(name=name, version=version)

return scan

Expand Down Expand Up @@ -156,7 +148,7 @@ def _validate_pypi(name: str, version: str, http_client: httpx.Client):
)
def report_package(
body: ReportPackageBody,
session: Annotated[Session, Depends(get_db)],
database: Annotated[StorageProtocol, Depends(get_storage)],
auth: Annotated[AuthenticationData, Depends(validate_token)],
httpx_client: Annotated[httpx.Client, Depends(get_httpx_client)],
):
Expand Down Expand Up @@ -198,7 +190,24 @@ def report_package(
log = logger.bind(package={"name": name, "version": version})

# Check our database first to avoid unnecessarily using PyPI API.
scan = _lookup_package(name, version, session)
try:
scans = database.lookup_packages(name)
scan = validate_package(name, version, scans)
except PackageNotFound as e:
detail = f"No records for package `{e.name} v{e.version}` were found in the database"
error = HTTPException(404, detail=detail)
log.error(detail, error_message=detail, tag="package_not_found_db")

raise error
except PackageAlreadyReported as e:
detail = (
f"Only one version of a package may be reported at a time "
f"(`{e.name}@{e.reported_version}` was already reported)"
)
error = HTTPException(409, detail=detail)
log.error(detail, error_message=error.detail, tag="multiple_versions_prohibited")

raise error
inspector_url = _validate_inspector_url(name, version, body.inspector_url, scan.inspector_url)
_validate_additional_information(body, scan)

Expand Down Expand Up @@ -233,11 +242,7 @@ def report_package(

httpx_client.post(f"{mainframe_settings.reporter_url}/report/{name}", json=jsonable_encoder(report))

with session.begin():
scan.reported_by = auth.subject
scan.reported_at = dt.datetime.now(dt.timezone.utc)

session.close()
database.mark_reported(scan=scan, subject=auth.subject)

log.info(
"Sent report",
Expand Down
53 changes: 51 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections.abc import Sequence, Generator
import logging
from copy import deepcopy
from datetime import datetime, timedelta
from typing import Generator
from typing import Optional
from unittest.mock import MagicMock

import httpx
Expand All @@ -12,6 +13,7 @@
from sqlalchemy import Engine, create_engine, text
from sqlalchemy.orm import Session, sessionmaker

from mainframe.database import DatabaseStorage, StorageProtocol
from mainframe.json_web_token import AuthenticationData
from mainframe.models.orm import Base, Scan
from mainframe.rules import Rules
Expand All @@ -22,6 +24,39 @@
logger = logging.getLogger(__file__)


class MockDatabase(StorageProtocol):
    """In-memory implementation of `StorageProtocol` backed by a plain list of `Scan`s."""

    def __init__(self) -> None:
        # Records are kept in insertion order.
        self.db: list[Scan] = []

    def add(self, scan: Scan) -> None:
        """Append `scan` to the in-memory store."""
        self.db.append(scan)

    def lookup_packages(
        self, name: Optional[str] = None, version: Optional[str] = None, since: Optional[datetime] = None
    ) -> Sequence[Scan]:
        """
        Return the stored scans that satisfy every filter that was passed.

        Filters that are omitted (or falsy) are ignored, so a call without
        any filter returns every stored scan.

        Args:
            name: Only return scans with this package name.
            version: Only return scans with this package version.
            since: Only return scans queued at or after this time.

        Returns:
            The matching scans, in insertion order.
        """
        matches: list[Scan] = []
        for scan in self.db:
            # Filters are combined with AND. OR-ing them would make a
            # `name` + `version` lookup return every version of the package,
            # plus every other package that happens to share the version.
            if name and scan.name != name:
                continue
            if version and scan.version != version:
                continue
            # NOTE(review): this compares `queued_at`, whereas the
            # database-backed storage appears to filter on `finished_at` -
            # confirm which column is intended.
            if since and not (scan.queued_at and scan.queued_at >= since):
                continue
            matches.append(scan)

        return matches

    def mark_reported(self, *, scan: Scan, subject: str) -> None:
        """
        Mark the stored record with the same `scan_id` as `scan` as reported by `subject`.

        Nothing happens when no stored record has that `scan_id`.
        """
        reported_at = datetime.now()
        for stored in self.db:
            if stored.scan_id == scan.scan_id:
                # Update the stored record itself, so later lookups see the
                # change even when the caller passed in a copy.
                stored.reported_by = subject
                stored.reported_at = reported_at
                # Keep the caller's object in sync as well.
                scan.reported_by = subject
                scan.reported_at = reported_at


@pytest.fixture
def mock_database() -> MockDatabase:
    """Provide a new, empty `MockDatabase` for each test."""
    return MockDatabase()


@pytest.fixture(scope="session")
def sm(engine: Engine) -> sessionmaker[Session]:
    """
    Session-scoped session factory bound to `engine`.

    It is configured with `expire_on_commit=False` and `autobegin=False`, so
    transactions must be started explicitly with `begin()`.
    """
    return sessionmaker(bind=engine, expire_on_commit=False, autobegin=False)
Expand Down Expand Up @@ -50,12 +85,26 @@ def engine(superuser_engine: Engine) -> Engine:
return create_engine("postgresql+psycopg2://dragonfly:postgres@db:5432/dragonfly", pool_size=5, max_overflow=10)


@pytest.fixture
def storage(
    superuser_engine: Engine, test_data: list[Scan], sm: sessionmaker[Session]
) -> Generator[DatabaseStorage, None, None]:
    """
    Yield a `DatabaseStorage` backed by freshly created tables seeded with `test_data`.

    All tables are dropped and recreated before the test, and dropped again
    once the test is done with the fixture.
    """
    Base.metadata.drop_all(superuser_engine)
    Base.metadata.create_all(superuser_engine)
    with sm() as s, s.begin():
        # Insert deep copies rather than the session-scoped `test_data`
        # objects themselves - presumably so the shared originals are never
        # bound to a database session.
        s.add_all(deepcopy(test_data))

    yield DatabaseStorage(sm)

    Base.metadata.drop_all(superuser_engine)


@pytest.fixture(params=data, scope="session")
def test_data(request: pytest.FixtureRequest) -> list[Scan]:
    """Session-scoped fixture parametrised over `data`; returns the current parameter."""
    return request.param


@pytest.fixture(autouse=True)
@pytest.fixture
def db_session(
superuser_engine: Engine, test_data: list[Scan], sm: sessionmaker[Session]
) -> Generator[Session, None, None]:
Expand Down
Loading
Loading