Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace a few more Into<Purl>'s hither and yon. #86

Merged
merged 3 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions common/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,45 @@ use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use tempfile::TempDir;

#[derive(Copy, Clone)]
pub enum Transactional<'db> {
pub enum Transactional {
None,
Some(&'db DatabaseTransaction),
Some(DatabaseTransaction),
}

impl Transactional {
pub async fn commit(self) -> Result<(), DbErr> {
match self {
Transactional::None => {}
Transactional::Some(inner) => {
inner.commit().await?;
}
}

Ok(())
}
}

impl AsRef<Transactional> for Transactional {
fn as_ref(&self) -> &Transactional {
self
}
}

impl AsRef<Transactional> for () {
fn as_ref(&self) -> &Transactional {
&Transactional::None
}
}

/*
impl<'db> From<&'db DatabaseTransaction> for Transactional<'db> {
fn from(inner: &'db DatabaseTransaction) -> Self {
Self::Some(inner)
}
}

*/

#[derive(Clone)]
pub enum ConnectionOrTransaction<'db> {
Connection(&'db DatabaseConnection),
Expand Down
85 changes: 55 additions & 30 deletions graph/src/graph/advisory/advisory_vulnerability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,17 @@ impl<'g> From<(&AdvisoryContext<'g>, entity::advisory_vulnerability::Model)>
}

impl<'g> AdvisoryVulnerabilityContext<'g> {
pub async fn get_fixed_package_version(
pub async fn get_fixed_package_version<TX: AsRef<Transactional>>(
&self,
purl: Purl,
tx: Transactional<'_>,
tx: TX,
) -> Result<Option<FixedPackageVersionContext>, Error> {
if let Some(package_version) = self.advisory.graph.get_package_version(purl, tx).await? {
if let Some(package_version) = self
.advisory
.graph
.get_package_version(purl, tx.as_ref())
.await?
{
Ok(entity::fixed_package_version::Entity::find()
.filter(
entity::fixed_package_version::Column::AdvisoryId.eq(self.advisory.advisory.id),
Expand All @@ -42,20 +47,25 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {
entity::fixed_package_version::Column::PackageVersionId
.eq(package_version.package_version.id),
)
.one(&self.advisory.graph.connection(tx))
.one(&self.advisory.graph.connection(tx.as_ref()))
.await?
.map(|affected| (self, affected).into()))
} else {
Ok(None)
}
}

pub async fn get_not_affected_package_version(
pub async fn get_not_affected_package_version<TX: AsRef<Transactional>>(
&self,
purl: Purl,
tx: Transactional<'_>,
tx: TX,
) -> Result<Option<NotAffectedPackageVersionContext>, Error> {
if let Some(package_version) = self.advisory.graph.get_package_version(purl, tx).await? {
if let Some(package_version) = self
.advisory
.graph
.get_package_version(purl, tx.as_ref())
.await?
{
Ok(entity::not_affected_package_version::Entity::find()
.filter(
entity::not_affected_package_version::Column::AdvisoryId
Expand All @@ -65,27 +75,25 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {
entity::not_affected_package_version::Column::PackageVersionId
.eq(package_version.package_version.id),
)
.one(&self.advisory.graph.connection(tx))
.one(&self.advisory.graph.connection(tx.as_ref()))
.await?
.map(|not_affected_package_version| (self, not_affected_package_version).into()))
} else {
Ok(None)
}
}

pub async fn get_affected_package_range<P: Into<Purl>>(
pub async fn get_affected_package_range<TX: AsRef<Transactional>>(
&self,
pkg: P,
purl: Purl,
start: &str,
end: &str,
tx: Transactional<'_>,
tx: TX,
) -> Result<Option<AffectedPackageVersionRangeContext>, Error> {
let purl = pkg.into();

if let Some(package_version_range) = self
.advisory
.graph
.get_package_version_range(purl.clone(), start, end, tx)
.get_package_version_range(purl.clone(), start, end, tx.as_ref())
.await?
{
Ok(entity::affected_package_version_range::Entity::find()
Expand All @@ -97,27 +105,31 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {
entity::affected_package_version_range::Column::PackageVersionRangeId
.eq(package_version_range.package_version_range.id),
)
.one(&self.advisory.graph.connection(tx))
.one(&self.advisory.graph.connection(tx.as_ref()))
.await?
.map(|affected| (self, affected).into()))
} else {
Ok(None)
}
}

pub async fn ingest_not_affected_package_version(
pub async fn ingest_not_affected_package_version<TX: AsRef<Transactional>>(
&self,
purl: Purl,
tx: Transactional<'_>,
tx: TX,
) -> Result<NotAffectedPackageVersionContext, Error> {
if let Some(found) = self
.get_not_affected_package_version(purl.clone(), tx)
.get_not_affected_package_version(purl.clone(), tx.as_ref())
.await?
{
return Ok(found);
}

let package_version = self.advisory.graph.ingest_package_version(purl, tx).await?;
let package_version = self
.advisory
.graph
.ingest_package_version(purl, tx.as_ref())
.await?;

let entity = entity::not_affected_package_version::ActiveModel {
id: Default::default(),
Expand All @@ -127,21 +139,30 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {

Ok((
self,
entity.insert(&self.advisory.graph.connection(tx)).await?,
entity
.insert(&self.advisory.graph.connection(tx.as_ref()))
.await?,
)
.into())
}

pub async fn ingest_fixed_package_version(
pub async fn ingest_fixed_package_version<TX: AsRef<Transactional>>(
&self,
purl: Purl,
tx: Transactional<'_>,
tx: TX,
) -> Result<FixedPackageVersionContext, Error> {
if let Some(found) = self.get_fixed_package_version(purl.clone(), tx).await? {
if let Some(found) = self
.get_fixed_package_version(purl.clone(), tx.as_ref())
.await?
{
return Ok(found);
}

let package_version = self.advisory.graph.ingest_package_version(purl, tx).await?;
let package_version = self
.advisory
.graph
.ingest_package_version(purl, tx.as_ref())
.await?;

let entity = entity::fixed_package_version::ActiveModel {
id: Default::default(),
Expand All @@ -151,20 +172,22 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {

Ok((
self,
entity.insert(&self.advisory.graph.connection(tx)).await?,
entity
.insert(&self.advisory.graph.connection(tx.as_ref()))
.await?,
)
.into())
}

pub async fn ingest_affected_package_range(
pub async fn ingest_affected_package_range<TX: AsRef<Transactional>>(
&self,
purl: Purl,
start: &str,
end: &str,
tx: Transactional<'_>,
tx: TX,
) -> Result<AffectedPackageVersionRangeContext, Error> {
if let Some(found) = self
.get_affected_package_range(purl.clone(), start, end, tx)
.get_affected_package_range(purl.clone(), start, end, tx.as_ref())
.await?
{
return Ok(found);
Expand All @@ -173,7 +196,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {
let package_version_range = self
.advisory
.graph
.ingest_package_version_range(purl, start, end, tx)
.ingest_package_version_range(purl, start, end, tx.as_ref())
.await?;

let entity = entity::affected_package_version_range::ActiveModel {
Expand All @@ -184,7 +207,9 @@ impl<'g> AdvisoryVulnerabilityContext<'g> {

Ok((
self,
entity.insert(&self.advisory.graph.connection(tx)).await?,
entity
.insert(&self.advisory.graph.connection(tx.as_ref()))
.await?,
)
.into())
}
Expand Down
2 changes: 1 addition & 1 deletion graph/src/graph/advisory/csaf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl<'g> AdvisoryContext<'g> {
let mut entity: entity::advisory::ActiveModel = self.advisory.clone().into();
entity.title = Set(Some(csaf.document.title.clone().to_string()));
entity
.update(&self.graph.connection(Transactional::None))
.update(&self.graph.connection(&Transactional::None))
.await?;

// Ingest vulnerabilities
Expand Down
Loading