Skip to content

Commit

Permalink
Replace a few more Into<Purl>'s hither and yon.
Browse files Browse the repository at this point in the history
Add type constraints to make using transactions a bit easier.
A `Transactional::Some(_)` owns its transaction, reducing the things
you need to hang onto.
A `Transactional::Some(_)` is provided ready-to-eat from Graph::transaction().
A `()` can be used for non-transactional as shorthand for `Transactional::None`.
  • Loading branch information
Bob McWhirter committed Mar 15, 2024
1 parent 706b9d8 commit ce12c9e
Show file tree
Hide file tree
Showing 14 changed files with 481 additions and 415 deletions.
33 changes: 30 additions & 3 deletions common/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,45 @@ use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use tempfile::TempDir;

#[derive(Copy, Clone)]
pub enum Transactional<'db> {
pub enum Transactional {
None,
Some(&'db DatabaseTransaction),
Some(DatabaseTransaction),
}

impl Transactional {
pub async fn commit(self) -> Result<(), DbErr> {
match self {
Transactional::None => {}
Transactional::Some(inner) => {
inner.commit().await?;
}
}

Ok(())
}
}

impl AsRef<Transactional> for Transactional {
fn as_ref(&self) -> &Transactional {
self
}
}

impl AsRef<Transactional> for () {
fn as_ref(&self) -> &Transactional {
&Transactional::None
}
}

/*
impl<'db> From<&'db DatabaseTransaction> for Transactional<'db> {
fn from(inner: &'db DatabaseTransaction) -> Self {
Self::Some(inner)
}
}
*/

#[derive(Clone)]
pub enum ConnectionOrTransaction<'db> {
Connection(&'db DatabaseConnection),
Expand Down
85 changes: 55 additions & 30 deletions graph/src/graph/advisory/advisory_vulnerability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,17 @@ impl<'g> From<(&AdvisoryContext<'g>, entity::advisory_vulnerability::Model)>
}

impl<'g> AdvisoryVulnerabilityContext<'g> {
pub async fn get_fixed_package_version(
pub async fn get_fixed_package_version<TX: AsRef<Transactional>>(
&self,
purl: Purl,
tx: Transactional<'_>,
tx: TX,
) -> Result<Option<FixedPackageVersionContext>, Error> {
if let Some(package_version) = self.advisory.graph.get_package_version(purl, tx).await? {
if let Some(package_version) = self
.advisory
.graph
.get_package_version(purl, tx.as_ref())
.await?
{
Ok(entity::fixed_package_version::Entity::find()
.filter(
entity::fixed_package_version::Column::AdvisoryId.eq(self.advisory.advisory.id),
Expand All @@ -42,20 +47,25 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {
entity::fixed_package_version::Column::PackageVersionId
.eq(package_version.package_version.id),
)
.one(&self.advisory.graph.connection(tx))
.one(&self.advisory.graph.connection(tx.as_ref()))
.await?
.map(|affected| (self, affected).into()))
} else {
Ok(None)
}
}

pub async fn get_not_affected_package_version(
pub async fn get_not_affected_package_version<TX: AsRef<Transactional>>(
&self,
purl: Purl,
tx: Transactional<'_>,
tx: TX,
) -> Result<Option<NotAffectedPackageVersionContext>, Error> {
if let Some(package_version) = self.advisory.graph.get_package_version(purl, tx).await? {
if let Some(package_version) = self
.advisory
.graph
.get_package_version(purl, tx.as_ref())
.await?
{
Ok(entity::not_affected_package_version::Entity::find()
.filter(
entity::not_affected_package_version::Column::AdvisoryId
Expand All @@ -65,27 +75,25 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {
entity::not_affected_package_version::Column::PackageVersionId
.eq(package_version.package_version.id),
)
.one(&self.advisory.graph.connection(tx))
.one(&self.advisory.graph.connection(tx.as_ref()))
.await?
.map(|not_affected_package_version| (self, not_affected_package_version).into()))
} else {
Ok(None)
}
}

pub async fn get_affected_package_range<P: Into<Purl>>(
pub async fn get_affected_package_range<TX: AsRef<Transactional>>(
&self,
pkg: P,
purl: Purl,
start: &str,
end: &str,
tx: Transactional<'_>,
tx: TX,
) -> Result<Option<AffectedPackageVersionRangeContext>, Error> {
let purl = pkg.into();

if let Some(package_version_range) = self
.advisory
.graph
.get_package_version_range(purl.clone(), start, end, tx)
.get_package_version_range(purl.clone(), start, end, tx.as_ref())
.await?
{
Ok(entity::affected_package_version_range::Entity::find()
Expand All @@ -97,27 +105,31 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {
entity::affected_package_version_range::Column::PackageVersionRangeId
.eq(package_version_range.package_version_range.id),
)
.one(&self.advisory.graph.connection(tx))
.one(&self.advisory.graph.connection(tx.as_ref()))
.await?
.map(|affected| (self, affected).into()))
} else {
Ok(None)
}
}

pub async fn ingest_not_affected_package_version(
pub async fn ingest_not_affected_package_version<TX: AsRef<Transactional>>(
&self,
purl: Purl,
tx: Transactional<'_>,
tx: TX,
) -> Result<NotAffectedPackageVersionContext, Error> {
if let Some(found) = self
.get_not_affected_package_version(purl.clone(), tx)
.get_not_affected_package_version(purl.clone(), tx.as_ref())
.await?
{
return Ok(found);
}

let package_version = self.advisory.graph.ingest_package_version(purl, tx).await?;
let package_version = self
.advisory
.graph
.ingest_package_version(purl, tx.as_ref())
.await?;

let entity = entity::not_affected_package_version::ActiveModel {
id: Default::default(),
Expand All @@ -127,21 +139,30 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {

Ok((
self,
entity.insert(&self.advisory.graph.connection(tx)).await?,
entity
.insert(&self.advisory.graph.connection(tx.as_ref()))
.await?,
)
.into())
}

pub async fn ingest_fixed_package_version(
pub async fn ingest_fixed_package_version<TX: AsRef<Transactional>>(
&self,
purl: Purl,
tx: Transactional<'_>,
tx: TX,
) -> Result<FixedPackageVersionContext, Error> {
if let Some(found) = self.get_fixed_package_version(purl.clone(), tx).await? {
if let Some(found) = self
.get_fixed_package_version(purl.clone(), tx.as_ref())
.await?
{
return Ok(found);
}

let package_version = self.advisory.graph.ingest_package_version(purl, tx).await?;
let package_version = self
.advisory
.graph
.ingest_package_version(purl, tx.as_ref())
.await?;

let entity = entity::fixed_package_version::ActiveModel {
id: Default::default(),
Expand All @@ -151,20 +172,22 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {

Ok((
self,
entity.insert(&self.advisory.graph.connection(tx)).await?,
entity
.insert(&self.advisory.graph.connection(tx.as_ref()))
.await?,
)
.into())
}

pub async fn ingest_affected_package_range(
pub async fn ingest_affected_package_range<TX: AsRef<Transactional>>(
&self,
purl: Purl,
start: &str,
end: &str,
tx: Transactional<'_>,
tx: TX,
) -> Result<AffectedPackageVersionRangeContext, Error> {
if let Some(found) = self
.get_affected_package_range(purl.clone(), start, end, tx)
.get_affected_package_range(purl.clone(), start, end, tx.as_ref())
.await?
{
return Ok(found);
Expand All @@ -173,7 +196,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {
let package_version_range = self
.advisory
.graph
.ingest_package_version_range(purl, start, end, tx)
.ingest_package_version_range(purl, start, end, tx.as_ref())
.await?;

let entity = entity::affected_package_version_range::ActiveModel {
Expand All @@ -184,7 +207,9 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {

Ok((
self,
entity.insert(&self.advisory.graph.connection(tx)).await?,
entity
.insert(&self.advisory.graph.connection(tx.as_ref()))
.await?,
)
.into())
}
Expand Down
2 changes: 1 addition & 1 deletion graph/src/graph/advisory/csaf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl<'g> AdvisoryContext<'g> {
let mut entity: entity::advisory::ActiveModel = self.advisory.clone().into();
entity.title = Set(Some(csaf.document.title.clone().to_string()));
entity
.update(&self.graph.connection(Transactional::None))
.update(&self.graph.connection(&Transactional::None))
.await?;

// Ingest vulnerabilities
Expand Down
Loading

0 comments on commit ce12c9e

Please sign in to comment.