diff --git a/common/src/db.rs b/common/src/db.rs index 991ea17c4..421becd0a 100644 --- a/common/src/db.rs +++ b/common/src/db.rs @@ -9,18 +9,45 @@ use std::ops::{Deref, DerefMut}; use std::sync::Arc; use tempfile::TempDir; -#[derive(Copy, Clone)] -pub enum Transactional<'db> { +pub enum Transactional { None, - Some(&'db DatabaseTransaction), + Some(DatabaseTransaction), } +impl Transactional { + pub async fn commit(self) -> Result<(), DbErr> { + match self { + Transactional::None => {} + Transactional::Some(inner) => { + inner.commit().await?; + } + } + + Ok(()) + } +} + +impl AsRef for Transactional { + fn as_ref(&self) -> &Transactional { + self + } +} + +impl AsRef for () { + fn as_ref(&self) -> &Transactional { + &Transactional::None + } +} + +/* impl<'db> From<&'db DatabaseTransaction> for Transactional<'db> { fn from(inner: &'db DatabaseTransaction) -> Self { Self::Some(inner) } } + */ + #[derive(Clone)] pub enum ConnectionOrTransaction<'db> { Connection(&'db DatabaseConnection), diff --git a/graph/src/graph/advisory/advisory_vulnerability.rs b/graph/src/graph/advisory/advisory_vulnerability.rs index 6152d941b..476018984 100644 --- a/graph/src/graph/advisory/advisory_vulnerability.rs +++ b/graph/src/graph/advisory/advisory_vulnerability.rs @@ -28,12 +28,17 @@ impl<'g> From<(&AdvisoryContext<'g>, entity::advisory_vulnerability::Model)> } impl<'g> AdvisoryVulnerabilityContext<'g> { - pub async fn get_fixed_package_version( + pub async fn get_fixed_package_version>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - if let Some(package_version) = self.advisory.graph.get_package_version(purl, tx).await? { + if let Some(package_version) = self + .advisory + .graph + .get_package_version(purl, tx.as_ref()) + .await? + { Ok(entity::fixed_package_version::Entity::find() .filter( entity::fixed_package_version::Column::AdvisoryId.eq(self.advisory.advisory.id), @@ -42,7 +47,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { entity::fixed_package_version::Column::PackageVersionId .eq(package_version.package_version.id), ) - .one(&self.advisory.graph.connection(tx)) + .one(&self.advisory.graph.connection(tx.as_ref())) .await? .map(|affected| (self, affected).into())) } else { @@ -50,12 +55,17 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { } } - pub async fn get_not_affected_package_version( + pub async fn get_not_affected_package_version>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - if let Some(package_version) = self.advisory.graph.get_package_version(purl, tx).await? { + if let Some(package_version) = self + .advisory + .graph + .get_package_version(purl, tx.as_ref()) + .await? + { Ok(entity::not_affected_package_version::Entity::find() .filter( entity::not_affected_package_version::Column::AdvisoryId @@ -65,7 +75,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { entity::not_affected_package_version::Column::PackageVersionId .eq(package_version.package_version.id), ) - .one(&self.advisory.graph.connection(tx)) + .one(&self.advisory.graph.connection(tx.as_ref())) .await? .map(|not_affected_package_version| (self, not_affected_package_version).into())) } else { @@ -73,19 +83,17 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { } } - pub async fn get_affected_package_range>( + pub async fn get_affected_package_range>( &self, - pkg: P, + purl: Purl, start: &str, end: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - let purl = pkg.into(); - if let Some(package_version_range) = self .advisory .graph - .get_package_version_range(purl.clone(), start, end, tx) + .get_package_version_range(purl.clone(), start, end, tx.as_ref()) .await? { Ok(entity::affected_package_version_range::Entity::find() @@ -97,7 +105,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { entity::affected_package_version_range::Column::PackageVersionRangeId .eq(package_version_range.package_version_range.id), ) - .one(&self.advisory.graph.connection(tx)) + .one(&self.advisory.graph.connection(tx.as_ref())) .await? .map(|affected| (self, affected).into())) } else { @@ -105,19 +113,23 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { } } - pub async fn ingest_not_affected_package_version( + pub async fn ingest_not_affected_package_version>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result { if let Some(found) = self - .get_not_affected_package_version(purl.clone(), tx) + .get_not_affected_package_version(purl.clone(), tx.as_ref()) .await? { return Ok(found); } - let package_version = self.advisory.graph.ingest_package_version(purl, tx).await?; + let package_version = self + .advisory + .graph + .ingest_package_version(purl, tx.as_ref()) + .await?; let entity = entity::not_affected_package_version::ActiveModel { id: Default::default(), @@ -127,21 +139,30 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { Ok(( self, - entity.insert(&self.advisory.graph.connection(tx)).await?, + entity + .insert(&self.advisory.graph.connection(tx.as_ref())) + .await?, ) .into()) } - pub async fn ingest_fixed_package_version( + pub async fn ingest_fixed_package_version>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result { - if let Some(found) = self.get_fixed_package_version(purl.clone(), tx).await? { + if let Some(found) = self + .get_fixed_package_version(purl.clone(), tx.as_ref()) + .await? + { return Ok(found); } - let package_version = self.advisory.graph.ingest_package_version(purl, tx).await?; + let package_version = self + .advisory + .graph + .ingest_package_version(purl, tx.as_ref()) + .await?; let entity = entity::fixed_package_version::ActiveModel { id: Default::default(), @@ -151,20 +172,22 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { Ok(( self, - entity.insert(&self.advisory.graph.connection(tx)).await?, + entity + .insert(&self.advisory.graph.connection(tx.as_ref())) + .await?, ) .into()) } - pub async fn ingest_affected_package_range( + pub async fn ingest_affected_package_range>( &self, purl: Purl, start: &str, end: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result { if let Some(found) = self - .get_affected_package_range(purl.clone(), start, end, tx) + .get_affected_package_range(purl.clone(), start, end, tx.as_ref()) .await? { return Ok(found); @@ -173,7 +196,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { let package_version_range = self .advisory .graph - .ingest_package_version_range(purl, start, end, tx) + .ingest_package_version_range(purl, start, end, tx.as_ref()) .await?; let entity = entity::affected_package_version_range::ActiveModel { @@ -184,7 +207,9 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { Ok(( self, - entity.insert(&self.advisory.graph.connection(tx)).await?, + entity + .insert(&self.advisory.graph.connection(tx.as_ref())) + .await?, ) .into()) } diff --git a/graph/src/graph/advisory/csaf/mod.rs b/graph/src/graph/advisory/csaf/mod.rs index f7067ecd2..ba0a947f6 100644 --- a/graph/src/graph/advisory/csaf/mod.rs +++ b/graph/src/graph/advisory/csaf/mod.rs @@ -18,7 +18,7 @@ impl<'g> AdvisoryContext<'g> { let mut entity: entity::advisory::ActiveModel = self.advisory.clone().into(); entity.title = Set(Some(csaf.document.title.clone().to_string())); entity - .update(&self.graph.connection(Transactional::None)) + .update(&self.graph.connection(&Transactional::None)) .await?; // Ingest vulnerabilities diff --git a/graph/src/graph/advisory/mod.rs b/graph/src/graph/advisory/mod.rs index bc7c7f8bb..e9878c74e 100644 --- a/graph/src/graph/advisory/mod.rs +++ b/graph/src/graph/advisory/mod.rs @@ -23,13 +23,13 @@ pub mod not_affected_package_version; pub mod csaf; impl Graph { - pub(crate) async fn get_advisory_by_id( + pub(crate) async fn get_advisory_by_id>( &self, id: i32, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { Ok(entity::advisory::Entity::find_by_id(id) - .one(&self.connection(tx)) + .one(&self.connection(tx.as_ref())) .await? .map(|advisory| (self, advisory).into())) } @@ -49,12 +49,12 @@ impl Graph { .map(|sbom| (self, sbom).into())) } - pub async fn ingest_advisory( + pub async fn ingest_advisory>( &self, identifier: impl Into, location: impl Into, sha256: impl Into, - tx: Transactional<'_>, + tx: TX, ) -> Result { let identifier = identifier.into(); let location = location.into(); @@ -100,10 +100,10 @@ impl<'g> From<(&'g Graph, entity::advisory::Model)> for AdvisoryContext<'g> { } impl<'g> AdvisoryContext<'g> { - pub async fn get_vulnerability( + pub async fn get_vulnerability>( &self, identifier: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result>, Error> { Ok(entity::advisory_vulnerability::Entity::find() .join( @@ -112,33 +112,40 @@ impl<'g> AdvisoryContext<'g> { ) .filter(entity::advisory_vulnerability::Column::AdvisoryId.eq(self.advisory.id)) .filter(entity::vulnerability::Column::Identifier.eq(identifier)) - .one(&self.graph.connection(tx)) + .one(&self.graph.connection(tx.as_ref())) .await? .map(|vuln| (self, vuln).into())) } - pub async fn ingest_vulnerability( + pub async fn ingest_vulnerability>( &self, identifier: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result { - if let Some(found) = self.get_vulnerability(identifier, tx).await? { + if let Some(found) = self.get_vulnerability(identifier, tx.as_ref()).await? { return Ok(found); } - let cve = self.graph.ingest_vulnerability(identifier, tx).await?; + let cve = self + .graph + .ingest_vulnerability(identifier, tx.as_ref()) + .await?; let entity = entity::advisory_vulnerability::ActiveModel { advisory_id: Set(self.advisory.id), vulnerability_id: Set(cve.vulnerability.id), }; - Ok((self, entity.insert(&self.graph.connection(tx)).await?).into()) + Ok(( + self, + entity.insert(&self.graph.connection(tx.as_ref())).await?, + ) + .into()) } - pub async fn vulnerabilities( + pub async fn vulnerabilities>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { Ok(entity::advisory_vulnerability::Entity::find() .join( @@ -148,20 +155,20 @@ impl<'g> AdvisoryContext<'g> { .rev(), ) .filter(entity::advisory_vulnerability::Column::AdvisoryId.eq(self.advisory.id)) - .all(&self.graph.connection(tx)) + .all(&self.graph.connection(tx.as_ref())) .await? .drain(0..) .map(|e| (self, e).into()) .collect()) } - pub async fn vulnerability_assertions( + pub async fn vulnerability_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { - let affected = self.affected_assertions(tx).await?; - let not_affected = self.not_affected_assertions(tx).await?; - let fixed = self.fixed_assertions(tx).await?; + let affected = self.affected_assertions(tx.as_ref()).await?; + let not_affected = self.not_affected_assertions(tx.as_ref()).await?; + let fixed = self.fixed_assertions(tx.as_ref()).await?; let mut merged = affected.assertions.clone(); @@ -182,9 +189,9 @@ impl<'g> AdvisoryContext<'g> { Ok(AdvisoryVulnerabilityAssertions { assertions: merged }) } - pub async fn affected_assertions( + pub async fn affected_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { #[derive(FromQueryResult, Debug)] struct AffectedVersion { @@ -233,7 +240,7 @@ impl<'g> AdvisoryContext<'g> { ) .filter(entity::affected_package_version_range::Column::AdvisoryId.eq(self.advisory.id)) .into_model::() - .all(&self.graph.connection(tx)) + .all(&self.graph.connection(tx.as_ref())) .await?; let mut assertions = HashMap::new(); @@ -260,9 +267,9 @@ impl<'g> AdvisoryContext<'g> { Ok(AdvisoryVulnerabilityAssertions { assertions }) } - pub async fn not_affected_assertions( + pub async fn not_affected_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { #[derive(FromQueryResult, Debug)] struct NotAffectedVersion { @@ -309,7 +316,7 @@ impl<'g> AdvisoryContext<'g> { ) .filter(entity::not_affected_package_version::Column::AdvisoryId.eq(self.advisory.id)) .into_model::() - .all(&self.graph.connection(tx)) + .all(&self.graph.connection(tx.as_ref())) .await?; let mut assertions = HashMap::new(); @@ -335,9 +342,9 @@ impl<'g> AdvisoryContext<'g> { Ok(AdvisoryVulnerabilityAssertions { assertions }) } - pub async fn fixed_assertions( + pub async fn fixed_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { #[derive(FromQueryResult, Debug)] struct FixedVersion { @@ -384,7 +391,7 @@ impl<'g> AdvisoryContext<'g> { ) .filter(entity::fixed_package_version::Column::AdvisoryId.eq(self.advisory.id)) .into_model::() - .all(&self.graph.connection(tx)) + .all(&self.graph.connection(tx.as_ref())) .await?; let mut assertions = HashMap::new(); diff --git a/graph/src/graph/cpe22.rs b/graph/src/graph/cpe22.rs index d2c379af7..495f72402 100644 --- a/graph/src/graph/cpe22.rs +++ b/graph/src/graph/cpe22.rs @@ -9,10 +9,10 @@ use trustify_common::db::Transactional; use trustify_entity as entity; impl Graph { - pub async fn get_cpe22>( + pub async fn get_cpe22, TX: AsRef>( &self, cpe: C, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { let cpe = cpe.into(); @@ -55,35 +55,35 @@ impl Graph { Value(inner) => query.filter(entity::cpe22::Column::Edition.eq(inner)), }; - if let Some(found) = query.one(&self.connection(tx)).await? { + if let Some(found) = query.one(&self.connection(tx.as_ref())).await? { Ok(Some((self, found).into())) } else { Ok(None) } } - pub(crate) async fn get_cpe22_by_query( + pub(crate) async fn get_cpe22_by_query>( &self, query: SelectStatement, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { Ok(entity::cpe22::Entity::find() .filter(entity::cpe22::Column::Id.in_subquery(query)) - .all(&self.connection(tx)) + .all(&self.connection(tx.as_ref())) .await? .drain(0..) .map(|cpe22| (self, cpe22).into()) .collect()) } - pub async fn ingest_cpe22>( + pub async fn ingest_cpe22, TX: AsRef>( &self, cpe: C, - tx: Transactional<'_>, + tx: TX, ) -> Result { let cpe = cpe.into(); - if let Some(found) = self.get_cpe22(cpe.clone(), tx).await? { + if let Some(found) = self.get_cpe22(cpe.clone(), tx.as_ref()).await? { return Ok(found); } @@ -123,7 +123,7 @@ impl Graph { language: Default::default(), }; - Ok((self, entity.insert(&self.connection(tx)).await?).into()) + Ok((self, entity.insert(&self.connection(tx.as_ref())).await?).into()) } } diff --git a/graph/src/graph/mod.rs b/graph/src/graph/mod.rs index 756b85493..35031685d 100644 --- a/graph/src/graph/mod.rs +++ b/graph/src/graph/mod.rs @@ -62,14 +62,14 @@ impl Graph { Self { db } } - pub async fn transaction(&self) -> Result { - Ok(self.db.begin().await?) + //pub async fn transaction(&self) -> Result { + //Ok(self.db.begin().await?) + //} + pub async fn transaction(&self) -> Result { + Ok(Transactional::Some(self.db.begin().await?)) } - pub(crate) fn connection<'db>( - &'db self, - tx: Transactional<'db>, - ) -> ConnectionOrTransaction<'db> { + pub(crate) fn connection<'db>(&'db self, tx: &'db Transactional) -> ConnectionOrTransaction { match tx { Transactional::None => ConnectionOrTransaction::Connection(&self.db), Transactional::Some(tx) => ConnectionOrTransaction::Transaction(tx), diff --git a/graph/src/graph/package/mod.rs b/graph/src/graph/package/mod.rs index 43e8d47b2..4a11c939e 100644 --- a/graph/src/graph/package/mod.rs +++ b/graph/src/graph/package/mod.rs @@ -13,6 +13,7 @@ use sea_orm::{ QuerySelect, QueryTrait, Set, }; use sea_query::{JoinType, SelectStatement, UnionType}; +use std::borrow::Borrow; use std::fmt::{Debug, Formatter}; use trustify_common::db::Transactional; use trustify_common::package::{Assertion, Claimant, PackageVulnerabilityAssertions}; @@ -30,16 +31,21 @@ impl Graph { /// /// The `pkg` parameter does not necessarily require the presence of qualifiers, but /// is assumed to be *complete*. - pub async fn ingest_qualified_package( + pub async fn ingest_qualified_package>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result { - if let Some(found) = self.get_qualified_package(purl.clone(), tx).await? { + if let Some(found) = self + .get_qualified_package(purl.clone(), tx.as_ref()) + .await? + { return Ok(found); } - let package_version = self.ingest_package_version(purl.clone(), tx).await?; + let package_version = self + .ingest_package_version(purl.clone(), tx.as_ref()) + .await?; package_version.ingest_qualified_package(purl, tx).await } @@ -47,45 +53,47 @@ impl Graph { /// Ensure the graph knows about and contains a record for a *versioned* package. /// /// This method will ensure the package being referenced is also ingested. - pub async fn ingest_package_version( + pub async fn ingest_package_version>( &self, pkg: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result { - if let Some(found) = self.get_package_version(pkg.clone(), tx).await? { + if let Some(found) = self.get_package_version(pkg.clone(), tx.as_ref()).await? { return Ok(found); } - let package = self.ingest_package(pkg.clone(), tx).await?; + let package = self.ingest_package(pkg.clone(), tx.as_ref()).await?; - package.ingest_package_version(pkg.clone(), tx).await + package + .ingest_package_version(pkg.clone(), tx.as_ref()) + .await } /// Ensure the graph knows about and contains a record for a *versioned range* of a package. /// /// This method will ensure the package being referenced is also ingested. - pub async fn ingest_package_version_range( + pub async fn ingest_package_version_range>( &self, pkg: Purl, start: &str, end: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result { - let package = self.ingest_package(pkg.clone(), tx).await?; + let package = self.ingest_package(pkg.clone(), tx.as_ref()).await?; package - .ingest_package_version_range(pkg.clone(), start, end, tx) + .ingest_package_version_range(pkg.clone(), start, end, tx.as_ref()) .await } /// Ensure the graph knows about and contains a record for a *versionless* package. /// /// This method will ensure the package being referenced is also ingested. - pub async fn ingest_package( + pub async fn ingest_package>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result { - if let Some(found) = self.get_package(purl.clone(), tx).await? { + if let Some(found) = self.get_package(purl.clone(), tx.as_ref()).await? { Ok(found) } else { let model = entity::package::ActiveModel { @@ -95,33 +103,35 @@ impl Graph { name: Set(purl.name.clone()), }; - Ok((self, model.insert(&self.connection(tx)).await?).into()) + Ok((self, model.insert(&self.connection(tx.as_ref())).await?).into()) } } /// Retrieve a *fully-qualified* package entry, if it exists. /// /// Non-mutating to the graph. - pub async fn get_qualified_package( + pub async fn get_qualified_package>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - if let Some(package_version) = self.get_package_version(purl.clone(), tx).await? { - package_version.get_qualified_package(purl, tx).await + if let Some(package_version) = self.get_package_version(purl.clone(), tx.as_ref()).await? { + package_version + .get_qualified_package(purl, tx.as_ref()) + .await } else { Ok(None) } } - pub(crate) async fn get_qualified_package_by_id( + pub(crate) async fn get_qualified_package_by_id>( &self, id: i32, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { let mut found = entity::qualified_package::Entity::find_by_id(id) .find_with_related(entity::package_qualifier::Entity) - .all(&self.connection(tx)) + .all(&self.connection(tx.as_ref())) .await?; if !found.is_empty() { @@ -147,22 +157,22 @@ impl Graph { } } - pub(crate) async fn get_qualified_packages_by_query( + pub(crate) async fn get_qualified_packages_by_query>( &self, query: SelectStatement, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { let mut found = entity::qualified_package::Entity::find() .filter(entity::qualified_package::Column::Id.in_subquery(query)) .find_with_related(entity::package_qualifier::Entity) - .all(&self.connection(tx)) + .all(&self.connection(tx.as_ref())) .await?; let mut package_versions = Vec::new(); for (base, qualifiers) in &found { if let Some(package_version) = self - .get_package_version_by_id(base.package_version_id, tx) + .get_package_version_by_id(base.package_version_id, tx.as_ref()) .await? { let qualifiers = qualifiers @@ -181,29 +191,29 @@ impl Graph { /// Retrieve a *versioned* package entry, if it exists. /// /// Non-mutating to the graph. - pub async fn get_package_version( + pub async fn get_package_version>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result>, Error> { - if let Some(pkg) = self.get_package(purl.clone(), tx).await? { - pkg.get_package_version(purl, tx).await + if let Some(pkg) = self.get_package(purl.clone(), tx.as_ref()).await? { + pkg.get_package_version(purl, tx.as_ref()).await } else { Ok(None) } } - pub(crate) async fn get_package_version_by_id( + pub(crate) async fn get_package_version_by_id>( &self, id: i32, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { if let Some(package_version) = entity::package_version::Entity::find_by_id(id) - .one(&self.connection(tx)) + .one(&self.connection(tx.as_ref())) .await? { if let Some(package) = self - .get_package_by_id(package_version.package_id, tx) + .get_package_by_id(package_version.package_id, tx.as_ref()) .await? { Ok(Some((&package, package_version).into())) @@ -218,16 +228,16 @@ impl Graph { /// Retrieve a *version range* of a package entry, if it exists. /// /// Non-mutating to the graph. - pub async fn get_package_version_range>( + pub async fn get_package_version_range>( &self, - pkg: P, + purl: Purl, start: &str, end: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - let purl = pkg.into(); - if let Some(pkg) = self.get_package(purl.clone(), tx).await? { - pkg.get_package_version_range(purl, start, end, tx).await + if let Some(pkg) = self.get_package(purl.clone(), tx.as_ref()).await? { + pkg.get_package_version_range(purl, start, end, tx.as_ref()) + .await } else { Ok(None) } @@ -236,10 +246,10 @@ impl Graph { /// Retrieve a *versionless* package entry, if it exists. /// /// Non-mutating to the graph. - pub async fn get_package( + pub async fn get_package>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { Ok(entity::package::Entity::find() .filter(entity::package::Column::Type.eq(purl.ty.clone())) @@ -249,18 +259,18 @@ impl Graph { entity::package::Column::Namespace.is_null() }) .filter(entity::package::Column::Name.eq(purl.name.clone())) - .one(&self.connection(tx)) + .one(&self.connection(tx.as_ref())) .await? .map(|package| (self, package).into())) } - pub(crate) async fn get_package_by_id( + pub(crate) async fn get_package_by_id>( &self, id: i32, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { if let Some(found) = entity::package::Entity::find_by_id(id) - .one(&self.connection(tx)) + .one(&self.connection(tx.as_ref())) .await? { Ok(Some((self, found).into())) @@ -291,14 +301,17 @@ impl<'g> From<(&'g Graph, entity::package::Model)> for PackageContext<'g> { impl<'g> PackageContext<'g> { /// Ensure the graph knows about and contains a record for a *version range* of this package. - pub async fn ingest_package_version_range( + pub async fn ingest_package_version_range>( &self, purl: Purl, start: &str, end: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - if let Some(found) = self.get_package_version_range(purl, start, end, tx).await? { + if let Some(found) = self + .get_package_version_range(purl, start, end, tx.as_ref()) + .await? + { Ok(found) } else { let entity = entity::package_version_range::ActiveModel { @@ -308,37 +321,41 @@ impl<'g> PackageContext<'g> { end: Set(end.to_string()), }; - Ok((self, entity.insert(&self.graph.connection(tx)).await?).into()) + Ok(( + self, + entity.insert(&self.graph.connection(tx.as_ref())).await?, + ) + .into()) } } /// Retrieve a *version range* package entry for this package, if it exists. /// /// Non-mutating to the graph. - pub async fn get_package_version_range( + pub async fn get_package_version_range>( &self, purl: Purl, start: &str, end: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result>, Error> { Ok(entity::package_version_range::Entity::find() .filter(entity::package_version_range::Column::PackageId.eq(self.package.id)) .filter(entity::package_version_range::Column::Start.eq(start.to_string())) .filter(entity::package_version_range::Column::End.eq(end.to_string())) - .one(&self.graph.connection(tx)) + .one(&self.graph.connection(tx.as_ref())) .await? .map(|package_version_range| (self, package_version_range).into())) } /// Ensure the graph knows about and contains a record for a *version* of this package. - pub async fn ingest_package_version( + pub async fn ingest_package_version>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { if let Some(version) = &purl.version { - if let Some(found) = self.get_package_version(purl.clone(), tx).await? { + if let Some(found) = self.get_package_version(purl.clone(), tx.as_ref()).await? { Ok(found) } else { let model = entity::package_version::ActiveModel { @@ -347,7 +364,11 @@ impl<'g> PackageContext<'g> { version: Set(version.clone()), }; - Ok((self, model.insert(&self.graph.connection(tx)).await?).into()) + Ok(( + self, + model.insert(&self.graph.connection(tx.as_ref())).await?, + ) + .into()) } } else { Err(Error::Purl(PurlErr::MissingVersion(purl.to_string()))) @@ -357,10 +378,10 @@ impl<'g> PackageContext<'g> { /// Retrieve a *version* package entry for this package, if it exists. /// /// Non-mutating to the graph. - pub async fn get_package_version( + pub async fn get_package_version>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result>, Error> { if let Some(package_version) = entity::package_version::Entity::find() .join( @@ -369,7 +390,7 @@ impl<'g> PackageContext<'g> { ) .filter(entity::package::Column::Id.eq(self.package.id)) .filter(entity::package_version::Column::Version.eq(purl.version.clone())) - .one(&self.graph.connection(tx)) + .one(&self.graph.connection(tx.as_ref())) .await? { Ok(Some((self, package_version).into())) @@ -381,25 +402,25 @@ impl<'g> PackageContext<'g> { /// Retrieve known versions of this package. /// /// Non-mutating to the graph. - pub async fn get_versions( + pub async fn get_versions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { Ok(entity::package_version::Entity::find() .filter(entity::package_version::Column::PackageId.eq(self.package.id)) - .all(&self.graph.connection(tx)) + .all(&self.graph.connection(tx.as_ref())) .await? .drain(0..) .map(|each| (self, each).into()) .collect()) } - pub async fn get_versions_paginated( + pub async fn get_versions_paginated>( &self, paginated: Paginated, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - let connection = self.graph.connection(tx); + let connection = self.graph.connection(tx.as_ref()); let pagination = entity::package_version::Entity::find() .filter(entity::package_version::Column::PackageId.eq(self.package.id)) @@ -441,13 +462,13 @@ impl<'g> PackageContext<'g> { /// /// Assertions are a mixture of "affected" and "not affected", for any version /// of this package, from any relevant advisory making statements. - pub async fn vulnerability_assertions( + pub async fn vulnerability_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { - let affected = self.affected_assertions(tx).await?; + let affected = self.affected_assertions(tx.as_ref()).await?; - let not_affected = self.not_affected_assertions(tx).await?; + let not_affected = self.not_affected_assertions(tx.as_ref()).await?; let mut merged = PackageVulnerabilityAssertions::default(); @@ -464,9 +485,9 @@ impl<'g> PackageContext<'g> { /// /// Assertions are "affected" for any version of this package, /// from any relevant advisory making statements. - pub async fn affected_assertions( + pub async fn affected_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { #[derive(FromQueryResult, Debug)] struct AffectedVersion { @@ -499,7 +520,7 @@ impl<'g> PackageContext<'g> { ) .filter(entity::package::Column::Id.eq(self.package.id)) .into_model::() - .all(&self.graph.connection(tx)) + .all(&self.graph.connection(tx.as_ref())) .await?; let mut assertions = PackageVulnerabilityAssertions::default(); @@ -525,9 +546,9 @@ impl<'g> PackageContext<'g> { /// /// Assertions are "not affected" for any version of this package, /// from any relevant advisory making statements. - pub async fn not_affected_assertions( + pub async fn not_affected_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { #[derive(FromQueryResult, Debug)] struct NotAffectedVersion { @@ -554,7 +575,7 @@ impl<'g> PackageContext<'g> { ) .filter(entity::package_version::Column::PackageId.eq(self.package.id)) .into_model::() - .all(&self.graph.connection(tx)) + .all(&self.graph.connection(tx.as_ref())) .await?; let mut assertions = PackageVulnerabilityAssertions::default(); @@ -576,9 +597,9 @@ impl<'g> PackageContext<'g> { } /// Retrieve all advisories mentioning this base package. - pub async fn advisories_mentioning( + pub async fn advisories_mentioning>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result>, Error> { let mut not_affected_subquery = entity::not_affected_package_version::Entity::find() .select_only() @@ -608,7 +629,7 @@ impl<'g> PackageContext<'g> { .to_owned(), ), ) - .all(&self.graph.connection(tx)) + .all(&self.graph.connection(tx.as_ref())) .await?; Ok(advisories diff --git a/graph/src/graph/package/package_version.rs b/graph/src/graph/package/package_version.rs index d72773900..b3948d210 100644 --- a/graph/src/graph/package/package_version.rs +++ b/graph/src/graph/package/package_version.rs @@ -40,14 +40,15 @@ impl<'g> From<(&PackageContext<'g>, entity::package_version::Model)> for Package } impl<'g> PackageVersionContext<'g> { - pub async fn ingest_qualified_package>( + pub async fn ingest_qualified_package>( &self, - pkg: P, - mut tx: Transactional<'_>, + purl: Purl, + tx: TX, ) -> Result, Error> { - let purl = pkg.into(); - - if let Some(found) = self.get_qualified_package(purl.clone(), tx).await? { + if let Some(found) = self + .get_qualified_package(purl.clone(), tx.as_ref()) + .await? + { return Ok(found); } @@ -58,7 +59,7 @@ impl<'g> PackageVersionContext<'g> { }; let qualified_package = qualified_package - .insert(&self.package.graph.connection(tx)) + .insert(&self.package.graph.connection(tx.as_ref())) .await?; for (k, v) in &purl.qualifiers { @@ -69,22 +70,23 @@ impl<'g> PackageVersionContext<'g> { value: Set(v.clone()), }; - qualifier.insert(&self.package.graph.connection(tx)).await?; + qualifier + .insert(&self.package.graph.connection(tx.as_ref())) + .await?; } Ok((self, qualified_package, purl.qualifiers.clone()).into()) } - pub async fn get_qualified_package<'p, P: Into>( - &'p self, - pkg: P, - tx: Transactional<'_>, + pub async fn get_qualified_package>( + &self, + purl: Purl, + tx: TX, ) -> Result>, Error> { - let purl = pkg.into(); let found = entity::qualified_package::Entity::find() .filter(entity::qualified_package::Column::PackageVersionId.eq(self.package_version.id)) .find_with_related(entity::package_qualifier::Entity) - .all(&self.package.graph.connection(tx)) + .all(&self.package.graph.connection(tx.as_ref())) .await?; for (qualified_package, qualifiers) in found { @@ -101,13 +103,13 @@ impl<'g> PackageVersionContext<'g> { Ok(None) } - pub async fn vulnerability_assertions( + pub async fn vulnerability_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { - let affected = self.affected_assertions(tx).await?; + let affected = self.affected_assertions(tx.as_ref()).await?; - let not_affected = self.not_affected_assertions(tx).await?; + let not_affected = self.not_affected_assertions(tx.as_ref()).await?; let mut merged = PackageVulnerabilityAssertions::default(); @@ -120,9 +122,9 @@ impl<'g> PackageVersionContext<'g> { Ok(merged) } - pub async fn affected_assertions( + pub async fn affected_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { let possibly_affected = self.package.affected_assertions(tx).await?; @@ -131,9 +133,9 @@ impl<'g> PackageVersionContext<'g> { Ok(filtered) } - pub async fn not_affected_assertions( + pub async fn not_affected_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { #[derive(FromQueryResult, Debug)] struct NotAffectedVersion { @@ -160,7 +162,7 @@ impl<'g> PackageVersionContext<'g> { ) .filter(entity::package_version::Column::Id.eq(self.package_version.id)) .into_model::() - .all(&self.package.graph.connection(tx)) + .all(&self.package.graph.connection(tx.as_ref())) .await?; let mut assertions = PackageVulnerabilityAssertions::default(); @@ -185,15 +187,15 @@ impl<'g> PackageVersionContext<'g> { /// Retrieve known variants of this package version. /// /// Non-mutating to the graph. - pub async fn get_variants>( + pub async fn get_variants>( &self, - pkg: P, - tx: Transactional<'_>, + pkg: Purl, + tx: TX, ) -> Result, Error> { Ok(entity::qualified_package::Entity::find() .filter(entity::qualified_package::Column::PackageVersionId.eq(self.package_version.id)) .find_with_related(entity::package_qualifier::Entity) - .all(&self.package.graph.connection(tx)) + .all(&self.package.graph.connection(tx.as_ref())) .await? .drain(0..) .map(|(base, qualifiers)| { diff --git a/graph/src/graph/package/qualified_package.rs b/graph/src/graph/package/qualified_package.rs index c812d77ad..2f42074eb 100644 --- a/graph/src/graph/package/qualified_package.rs +++ b/graph/src/graph/package/qualified_package.rs @@ -74,7 +74,10 @@ impl<'g> From> for Purl { } impl<'g> QualifiedPackageContext<'g> { - pub async fn sboms_containing(&self, tx: Transactional<'_>) -> Result, Error> { + pub async fn sboms_containing>( + &self, + tx: TX, + ) -> Result, Error> { /* Ok(entity::sbom::Entity::find() .join( @@ -95,12 +98,12 @@ impl<'g> QualifiedPackageContext<'g> { todo!() } - pub async fn vulnerability_assertions( + pub async fn vulnerability_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { - let affected = self.affected_assertions(tx).await?; - let not_affected = self.not_affected_assertions(tx).await?; + let affected = self.affected_assertions(tx.as_ref()).await?; + let not_affected = self.not_affected_assertions(tx.as_ref()).await?; let mut merged = PackageVulnerabilityAssertions::default(); @@ -113,16 +116,16 @@ impl<'g> QualifiedPackageContext<'g> { Ok(merged) } - pub async fn affected_assertions( + pub async fn affected_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { self.package_version.affected_assertions(tx).await } - pub async fn not_affected_assertions( + pub async fn not_affected_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result { self.package_version.not_affected_assertions(tx).await } diff --git a/graph/src/graph/sbom/mod.rs b/graph/src/graph/sbom/mod.rs index cb3ce2a0c..70f8b1389 100644 --- a/graph/src/graph/sbom/mod.rs +++ b/graph/src/graph/sbom/mod.rs @@ -40,11 +40,11 @@ impl Graph { .map(|sbom| (self, sbom).into())) } - pub async fn ingest_sbom( + pub async fn ingest_sbom>( &self, location: &str, sha256: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result { if let Some(found) = self.get_sbom(location, sha256).await? { return Ok(found); @@ -67,10 +67,10 @@ impl Graph { /// /// If the requested SBOM does not exist in the graph, it will not exist /// after this query either. This function is *non-mutating*. - pub async fn locate_sbom( + pub async fn locate_sbom>( &self, sbom_locator: SbomLocator, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { match sbom_locator { SbomLocator::Id(id) => self.locate_sbom_by_id(id, tx).await, @@ -81,10 +81,10 @@ impl Graph { } } - pub async fn locate_sboms( + pub async fn locate_sboms>( &self, sbom_locator: SbomLocator, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { match sbom_locator { SbomLocator::Id(id) => { @@ -102,46 +102,46 @@ impl Graph { } } - async fn locate_one_sbom( + async fn locate_one_sbom>( &self, query: SelectEntity, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { Ok(query - .one(&self.connection(tx)) + .one(&self.connection(tx.as_ref())) .await? .map(|sbom| (self, sbom).into())) } - async fn locate_many_sboms( + async fn locate_many_sboms>( &self, query: SelectEntity, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { Ok(query - .all(&self.connection(tx)) + .all(&self.connection(tx.as_ref())) .await? .drain(0..) .map(|sbom| (self, sbom).into()) .collect()) } - async fn locate_sbom_by_id( + async fn locate_sbom_by_id>( &self, id: i32, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { let query = entity::sbom::Entity::find_by_id(id); Ok(entity::sbom::Entity::find_by_id(id) - .one(&self.connection(tx)) + .one(&self.connection(tx.as_ref())) .await? .map(|sbom| (self, sbom).into())) } - async fn locate_sbom_by_location( + async fn locate_sbom_by_location>( &self, location: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { self.locate_one_sbom( entity::sbom::Entity::find() @@ -151,10 +151,10 @@ impl Graph { .await } - async fn locate_sboms_by_location( + async fn locate_sboms_by_location>( &self, location: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { self.locate_many_sboms( entity::sbom::Entity::find() @@ -164,10 +164,10 @@ impl Graph { .await } - async fn locate_sbom_by_sha256( + async fn locate_sbom_by_sha256>( &self, sha256: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { self.locate_one_sbom( entity::sbom::Entity::find() @@ -177,10 +177,10 @@ impl Graph { .await } - async fn locate_sboms_by_sha256( + async fn locate_sboms_by_sha256>( &self, sha256: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { self.locate_many_sboms( entity::sbom::Entity::find() @@ -190,12 +190,12 @@ impl Graph { .await } - async fn locate_sbom_by_purl( + async fn locate_sbom_by_purl>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - let package = self.get_qualified_package(purl, tx).await?; + let package = self.get_qualified_package(purl, tx.as_ref()).await?; if let Some(package) = package { self.locate_one_sbom( @@ -208,7 +208,7 @@ impl Graph { entity::sbom_describes_package::Column::QualifiedPackageId .eq(package.qualified_package.id), ), - tx, + tx.as_ref(), ) .await } else { @@ -216,12 +216,12 @@ impl Graph { } } - async fn locate_sboms_by_purl( + async fn locate_sboms_by_purl>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - let package = self.get_qualified_package(purl, tx).await?; + let package = self.get_qualified_package(purl, tx.as_ref()).await?; if let Some(package) = package { self.locate_many_sboms( @@ -234,7 +234,7 @@ impl Graph { entity::sbom_describes_package::Column::QualifiedPackageId .eq(package.qualified_package.id), ), - tx, + tx.as_ref(), ) .await } else { @@ -242,12 +242,12 @@ impl Graph { } } - async fn locate_sbom_by_cpe22( + async fn locate_sbom_by_cpe22>( &self, cpe: &Cpe22, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - if let Some(cpe) = self.get_cpe22(cpe.clone(), tx).await? { + if let Some(cpe) = self.get_cpe22(cpe.clone(), tx.as_ref()).await? { self.locate_one_sbom( entity::sbom::Entity::find() .join( @@ -255,7 +255,7 @@ impl Graph { entity::sbom_describes_cpe22::Relation::Sbom.def().rev(), ) .filter(entity::sbom_describes_cpe22::Column::Cpe22Id.eq(cpe.cpe22.id)), - tx, + tx.as_ref(), ) .await } else { @@ -263,12 +263,12 @@ impl Graph { } } - async fn locate_sboms_by_cpe22>( + async fn locate_sboms_by_cpe22, TX: AsRef>( &self, cpe: C, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - if let Some(found) = self.get_cpe22(cpe, tx).await? { + if let Some(found) = self.get_cpe22(cpe, tx.as_ref()).await? { self.locate_many_sboms( entity::sbom::Entity::find() .join( @@ -276,7 +276,7 @@ impl Graph { entity::sbom_describes_cpe22::Relation::Sbom.def().rev(), ) .filter(entity::sbom_describes_cpe22::Column::Cpe22Id.eq(found.cpe22.id)), - tx, + tx.as_ref(), ) .await } else { @@ -287,7 +287,7 @@ impl Graph { #[derive(Clone)] pub struct SbomContext { - pub(crate) system: Graph, + pub(crate) graph: Graph, pub(crate) sbom: entity::sbom::Model, } @@ -306,24 +306,24 @@ impl Debug for SbomContext { impl From<(&Graph, entity::sbom::Model)> for SbomContext { fn from((system, sbom): (&Graph, entity::sbom::Model)) -> Self { Self { - system: system.clone(), + graph: system.clone(), sbom, } } } impl SbomContext { - pub async fn ingest_describes_cpe22>( + pub async fn ingest_describes_cpe22, TX: AsRef>( &self, cpe: C, - tx: Transactional<'_>, + tx: TX, ) -> Result<(), Error> { - let cpe = self.system.ingest_cpe22(cpe, tx).await?; + let cpe = self.graph.ingest_cpe22(cpe, tx.as_ref()).await?; let fetch = entity::sbom_describes_cpe22::Entity::find() .filter(entity::sbom_describes_cpe22::Column::SbomId.eq(self.sbom.id)) .filter(entity::sbom_describes_cpe22::Column::Cpe22Id.eq(cpe.cpe22.id)) - .one(&self.system.connection(tx)) + .one(&self.graph.connection(tx.as_ref())) .await?; if fetch.is_none() { @@ -332,42 +332,45 @@ impl SbomContext { cpe22_id: Set(cpe.cpe22.id), }; - model.insert(&self.system.connection(tx)).await?; + model.insert(&self.graph.connection(tx.as_ref())).await?; } Ok(()) } - pub async fn ingest_describes_package( + pub async fn ingest_describes_package>( &self, purl: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result<(), Error> { let fetch = entity::sbom_describes_package::Entity::find() .filter( Condition::all() .add(entity::sbom_describes_package::Column::SbomId.eq(self.sbom.id)), ) - .one(&self.system.connection(tx)) + .one(&self.graph.connection(tx.as_ref())) .await?; if fetch.is_none() { - let package = self.system.ingest_qualified_package(purl, tx).await?; + let package = self + .graph + .ingest_qualified_package(purl, tx.as_ref()) + .await?; let model = entity::sbom_describes_package::ActiveModel { sbom_id: Set(self.sbom.id), qualified_package_id: Set(package.qualified_package.id), }; - model.insert(&self.system.connection(tx)).await?; + model.insert(&self.graph.connection(tx.as_ref())).await?; } Ok(()) } - pub async fn describes_packages( + pub async fn describes_packages>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - self.system + self.graph .get_qualified_packages_by_query( entity::sbom_describes_package::Entity::find() .select_only() @@ -379,11 +382,11 @@ impl SbomContext { .await } - pub async fn describes_cpe22s( + pub async fn describes_cpe22s>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - self.system + self.graph .get_cpe22_by_query( entity::sbom_describes_cpe22::Entity::find() .select_only() @@ -397,21 +400,21 @@ impl SbomContext { /// Within the context of *this* SBOM, ingest a relationship between /// two packages. - async fn ingest_package_relates_to_package( + async fn ingest_package_relates_to_package>( &self, left_package_input: Purl, relationship: Relationship, right_package_input: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result<(), Error> { let left_package = self - .system - .ingest_qualified_package(left_package_input.clone(), tx) + .graph + .ingest_qualified_package(left_package_input.clone(), tx.as_ref()) .await; let right_package = self - .system - .ingest_qualified_package(right_package_input.clone(), tx) + .graph + .ingest_qualified_package(right_package_input.clone(), tx.as_ref()) .await; match (&left_package, &right_package) { @@ -429,7 +432,7 @@ impl SbomContext { entity::package_relates_to_package::Column::RightPackageId .eq(right_package.qualified_package.id), ) - .one(&self.system.connection(tx)) + .one(&self.graph.connection(tx.as_ref())) .await? .is_none() { @@ -440,7 +443,7 @@ impl SbomContext { sbom_id: Set(self.sbom.id), }; - entity.insert(&self.system.connection(tx)).await?; + entity.insert(&self.graph.connection(tx.as_ref())).await?; } } (Err(_), Err(_)) => { @@ -467,13 +470,13 @@ impl SbomContext { Ok(()) } - pub async fn related_packages_transitively_x( + pub async fn related_packages_transitively_x>( &self, relationship: Relationship, pkg: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - let pkg = self.system.get_qualified_package(pkg, tx).await?; + let pkg = self.graph.get_qualified_package(pkg, tx.as_ref()).await?; if let Some(pkg) = pkg { #[derive(Debug, FromQueryResult)] @@ -483,7 +486,7 @@ impl SbomContext { } Ok(self - .system + .graph .get_qualified_packages_by_query( Query::select() .column(LeftPackageId) @@ -496,7 +499,7 @@ impl SbomContext { QualifiedPackageTransitive, ) .to_owned(), - tx, + tx.as_ref(), ) .await?) } else { @@ -504,13 +507,13 @@ impl SbomContext { } } - pub async fn related_packages_transitively( + pub async fn related_packages_transitively>( &self, relationships: &[Relationship], pkg: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - let pkg = self.system.get_qualified_package(pkg, tx).await?; + let pkg = self.graph.get_qualified_package(pkg, tx.as_ref()).await?; if let Some(pkg) = pkg { #[derive(Debug, FromQueryResult)] @@ -532,7 +535,7 @@ impl SbomContext { let qualified_package_id: SimpleExpr = pkg.qualified_package.id.into(); Ok(self - .system + .graph .get_qualified_packages_by_query( Query::select() .column(LeftPackageId) @@ -545,7 +548,7 @@ impl SbomContext { QualifiedPackageTransitive, ) .to_owned(), - tx, + tx.as_ref(), ) .await?) } else { @@ -553,13 +556,13 @@ impl SbomContext { } } - pub async fn related_packages( + pub async fn related_packages>( &self, relationship: Relationship, pkg: Purl, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - let pkg = self.system.get_qualified_package(pkg, tx).await?; + let pkg = self.graph.get_qualified_package(pkg, tx.as_ref()).await?; if let Some(pkg) = pkg { let related_query = entity::package_relates_to_package::Entity::find() @@ -576,7 +579,7 @@ impl SbomContext { let mut found = entity::qualified_package::Entity::find() .filter(entity::qualified_package::Column::Id.in_subquery(related_query)) .find_with_related(entity::package_qualifier::Entity) - .all(&self.system.connection(tx)) + .all(&self.graph.connection(tx.as_ref())) .await?; let mut related = Vec::new(); @@ -584,15 +587,15 @@ impl SbomContext { for (base, qualifiers) in found.drain(0..) { if let Some(package_version) = entity::package_version::Entity::find_by_id(base.package_version_id) - .one(&self.system.connection(tx)) + .one(&self.graph.connection(tx.as_ref())) .await? { if let Some(package) = entity::package::Entity::find_by_id(package_version.package_id) - .one(&self.system.connection(tx)) + .one(&self.graph.connection(tx.as_ref())) .await? { - let package = (&self.system, package).into(); + let package = (&self.graph, package).into(); let package_version = (&package, package_version).into(); let qualifiers_map = qualifiers @@ -612,11 +615,11 @@ impl SbomContext { } } - pub async fn vulnerability_assertions( + pub async fn vulnerability_assertions>( &self, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { - let described_packages = self.describes_packages(tx).await?; + let described_packages = self.describes_packages(tx.as_ref()).await?; let mut applicable = HashSet::new(); for pkg in described_packages { @@ -633,9 +636,12 @@ impl SbomContext { let mut assertions = HashMap::new(); for pkg in applicable { - let package_assertions = pkg.vulnerability_assertions(tx).await?; + let package_assertions = pkg.vulnerability_assertions(tx.as_ref()).await?; if !package_assertions.assertions.is_empty() { - assertions.insert(pkg.clone(), pkg.vulnerability_assertions(tx).await?); + assertions.insert( + pkg.clone(), + pkg.vulnerability_assertions(tx.as_ref()).await?, + ); } } diff --git a/graph/src/graph/sbom/spdx.rs b/graph/src/graph/sbom/spdx.rs index e235312b1..e0610c1ff 100644 --- a/graph/src/graph/sbom/spdx.rs +++ b/graph/src/graph/sbom/spdx.rs @@ -23,79 +23,69 @@ impl SbomContext { pub async fn ingest_spdx(&self, sbom_data: SPDX) -> Result<(), anyhow::Error> { // FIXME: not sure this is correct. It may be that we need to use `DatabaseTransaction` instead of the `db` field let sbom = self.clone(); - //let graph = self.graph.clone(); - self.system - .db - .transaction(|tx| { - Box::pin(async move { - let tx: Transactional = tx.into(); - // For each thing described in the SBOM data, link it up to an sbom_cpe or sbom_package. - for described in &sbom_data.document_creation_information.document_describes { - for described_package in sbom_data - .package_information - .iter() - .filter(|each| each.package_spdx_identifier.eq(described)) - { - for reference in &described_package.external_reference { - if reference.reference_type == "purl" { - //log::debug!("describes pkg {}", reference.reference_locator); - sbom.ingest_describes_package( - reference.reference_locator.as_str().try_into()?, - tx, - ) - .await?; - } else if reference.reference_type == "cpe22Type" { - //log::debug!("describes cpe22 {}", reference.reference_locator); - if let Ok(cpe) = cpe::uri::Uri::parse(&reference.reference_locator) { - sbom.ingest_describes_cpe22( - cpe, - tx, - ) - .await?; - } - - } - } + let tx = self.graph.transaction().await?; + + // For each thing described in the SBOM data, link it up to an sbom_cpe or sbom_package. + for described in &sbom_data.document_creation_information.document_describes { + for described_package in sbom_data + .package_information + .iter() + .filter(|each| each.package_spdx_identifier.eq(described)) + { + for reference in &described_package.external_reference { + if reference.reference_type == "purl" { + //log::debug!("describes pkg {}", reference.reference_locator); + sbom.ingest_describes_package( + reference.reference_locator.as_str().try_into()?, + &tx, + ) + .await?; + } else if reference.reference_type == "cpe22Type" { + //log::debug!("describes cpe22 {}", reference.reference_locator); + if let Ok(cpe) = cpe::uri::Uri::parse(&reference.reference_locator) { + sbom.ingest_describes_cpe22(cpe, &tx).await?; + } + } + } - // connect all other tree-ish package trees in the context of this sbom. - for package_info in &sbom_data.package_information { - let package_identifier = &package_info.package_spdx_identifier; - for package_ref in &package_info.external_reference { - if package_ref.reference_type == "purl" { - let package_a = package_ref.reference_locator.clone(); - //log::debug!("pkg_a: {}", package_a); - - for relationship in sbom_data - .relationships_for_spdx_id(package_identifier) - { - if let Some(package) = sbom_data - .package_information - .iter() - .find(|each| { - each.package_spdx_identifier - == relationship.related_spdx_element - }) - { - for reference in &package.external_reference { - if reference.reference_type == "purl" { - let package_b = reference.reference_locator.clone(); - - // Check for the degenerate case that seems to appear where an SBOM inceptions itself. - if package_a != package_b { - if let Ok((left, rel, right)) = SpdxRelationship( - &package_a, - &relationship.relationship_type, - &package_b).try_into() { - sbom.ingest_package_relates_to_package( - left.try_into()?, - rel, - right.try_into()?, - tx, - ).await? - } - } - } + // connect all other tree-ish package trees in the context of this sbom. + for package_info in &sbom_data.package_information { + let package_identifier = &package_info.package_spdx_identifier; + for package_ref in &package_info.external_reference { + if package_ref.reference_type == "purl" { + let package_a = package_ref.reference_locator.clone(); + //log::debug!("pkg_a: {}", package_a); + + for relationship in + sbom_data.relationships_for_spdx_id(package_identifier) + { + if let Some(package) = + sbom_data.package_information.iter().find(|each| { + each.package_spdx_identifier + == relationship.related_spdx_element + }) + { + for reference in &package.external_reference { + if reference.reference_type == "purl" { + let package_b = reference.reference_locator.clone(); + + // Check for the degenerate case that seems to appear where an SBOM inceptions itself. + if package_a != package_b { + if let Ok((left, rel, right)) = SpdxRelationship( + &package_a, + &relationship.relationship_type, + &package_b, + ) + .try_into() + { + sbom.ingest_package_relates_to_package( + left.try_into()?, + rel, + right.try_into()?, + &tx, + ) + .await? } } } @@ -104,11 +94,10 @@ impl SbomContext { } } } - - Ok::<(), Error>(()) - }) - }) - .await?; + } + } + } + tx.commit().await?; Ok(()) } diff --git a/graph/src/graph/vulnerability.rs b/graph/src/graph/vulnerability.rs index 72cf61ffc..0f49e49fc 100644 --- a/graph/src/graph/vulnerability.rs +++ b/graph/src/graph/vulnerability.rs @@ -16,12 +16,12 @@ use trustify_entity::vulnerability::Model; use trustify_entity::{advisory, advisory_vulnerability, vulnerability, vulnerability_description}; impl Graph { - pub async fn ingest_vulnerability( + pub async fn ingest_vulnerability>( &self, identifier: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result { - if let Some(found) = self.get_vulnerability(identifier, tx).await? { + if let Some(found) = self.get_vulnerability(identifier, tx.as_ref()).await? { Ok(found) } else { let entity = vulnerability::ActiveModel { @@ -30,18 +30,18 @@ impl Graph { title: NotSet, }; - Ok((self, entity.insert(&self.connection(tx)).await?).into()) + Ok((self, entity.insert(&self.connection(tx.as_ref())).await?).into()) } } - pub async fn get_vulnerability( + pub async fn get_vulnerability>( &self, identifier: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { Ok(vulnerability::Entity::find() .filter(vulnerability::Column::Identifier.eq(identifier)) - .one(&self.connection(tx)) + .one(&self.connection(tx.as_ref())) .await? .map(|cve| (self, cve).into())) } @@ -69,38 +69,41 @@ impl From<(&Graph, vulnerability::Model)> for VulnerabilityContext { } impl VulnerabilityContext { - pub async fn advisories(&self, tx: Transactional<'_>) -> Result, Error> { + pub async fn advisories>( + &self, + tx: TX, + ) -> Result, Error> { Ok(advisory::Entity::find() .join( JoinType::Join, advisory_vulnerability::Relation::Advisory.def().rev(), ) .filter(advisory_vulnerability::Column::VulnerabilityId.eq(self.vulnerability.id)) - .all(&self.graph.connection(tx)) + .all(&self.graph.connection(tx.as_ref())) .await? .drain(0..) .map(|advisory| (&self.graph, advisory).into()) .collect()) } - pub async fn set_title( + pub async fn set_title>( &self, title: Option, - tx: Transactional<'_>, + tx: TX, ) -> Result<(), Error> { let mut entity: vulnerability::ActiveModel = self.vulnerability.clone().into(); entity.title = Set(title); - entity.save(&self.graph.connection(tx)).await?; + entity.save(&self.graph.connection(tx.as_ref())).await?; Ok(()) } - pub async fn add_description( + pub async fn add_description>( &self, lang: &str, description: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result<(), Error> { let model = vulnerability_description::ActiveModel { id: Default::default(), @@ -109,21 +112,21 @@ impl VulnerabilityContext { description: Set(description.to_string()), }; - model.save(&self.graph.connection(tx)).await?; + model.save(&self.graph.connection(tx.as_ref())).await?; Ok(()) } - pub async fn descriptions( + pub async fn descriptions>( &self, lang: &str, - tx: Transactional<'_>, + tx: TX, ) -> Result, Error> { Ok(self .vulnerability .find_related(entity::vulnerability_description::Entity) .filter(vulnerability_description::Column::Lang.eq(lang)) - .all(&self.graph.connection(tx)) + .all(&self.graph.connection(tx.as_ref())) .await? .drain(..) .map(|e| e.description) diff --git a/ingestors/src/advisory/osv/loader.rs b/ingestors/src/advisory/osv/loader.rs index 1b339ace8..e2cf91324 100644 --- a/ingestors/src/advisory/osv/loader.rs +++ b/ingestors/src/advisory/osv/loader.rs @@ -1,6 +1,5 @@ use std::io::Read; -use trustify_common::db::Transactional; use trustify_graph::graph::Graph; use crate::advisory::osv::schema::Vulnerability; @@ -35,16 +34,11 @@ impl<'g> OsvLoader<'g> { }) { for cve_id in cve_ids { println!("INGEST VULN {}", cve_id); - let vuln = self - .graph - .ingest_vulnerability(cve_id, Transactional::Some(&tx)) - .await?; + let vuln = self.graph.ingest_vulnerability(cve_id, &tx).await?; - vuln.set_title(osv.summary.clone(), Transactional::Some(&tx)) - .await?; + vuln.set_title(osv.summary.clone(), &tx).await?; if let Some(details) = &osv.details { - vuln.add_description("en", details, Transactional::Some(&tx)) - .await? + vuln.add_description("en", details, &tx).await? } } @@ -52,7 +46,7 @@ impl<'g> OsvLoader<'g> { let sha256 = hex::encode(hashes.sha256.as_ref()); self.graph - .ingest_advisory(osv.id, location, sha256, Transactional::Some(&tx)) + .ingest_advisory(osv.id, location, sha256, &tx) .await?; } diff --git a/ingestors/src/cve/loader.rs b/ingestors/src/cve/loader.rs index 35116785b..32634b40c 100644 --- a/ingestors/src/cve/loader.rs +++ b/ingestors/src/cve/loader.rs @@ -1,9 +1,10 @@ +use std::io::Read; + +use trustify_graph::graph::Graph; + use crate::cve::cve_record::v5::CveRecord; use crate::hashing::HashingRead; use crate::Error; -use std::io::Read; -use trustify_common::db::Transactional; -use trustify_graph::graph::Graph; /// Loader capable of parsing a CVE Record JSON file /// and manipulating the Graph to integrate it into @@ -34,20 +35,16 @@ impl<'g> CveLoader<'g> { let vulnerability = self .graph - .ingest_vulnerability(cve.cve_metadata.cve_id(), Transactional::None) + .ingest_vulnerability(cve.cve_metadata.cve_id(), &tx) .await?; vulnerability - .set_title(cve.containers.cna.title.clone(), Transactional::Some(&tx)) + .set_title(cve.containers.cna.title.clone(), &tx) .await?; for description in cve.containers.cna.descriptions { vulnerability - .add_description( - &description.lang, - &description.value, - Transactional::Some(&tx), - ) + .add_description(&description.lang, &description.value, &tx) .await?; } @@ -55,12 +52,7 @@ impl<'g> CveLoader<'g> { let sha256 = hex::encode(hashes.sha256.as_ref()); self.graph - .ingest_advisory( - cve.cve_metadata.cve_id(), - location, - sha256, - Transactional::Some(&tx), - ) + .ingest_advisory(cve.cve_metadata.cve_id(), location, sha256, &tx) .await?; tx.commit().await?; @@ -71,14 +63,17 @@ impl<'g> CveLoader<'g> { #[cfg(test)] mod test { - use crate::cve::loader::CveLoader; use std::fs::File; use std::path::PathBuf; use std::str::FromStr; + use test_log::test; - use trustify_common::db::{Database, Transactional}; + + use trustify_common::db::Database; use trustify_graph::graph::Graph; + use crate::cve::loader::CveLoader; + #[test(tokio::test)] async fn cve_loader() -> Result<(), anyhow::Error> { let db = Database::for_test("ingestors_cve_loader").await?; @@ -90,9 +85,7 @@ mod test { let cve_json = test_data.join("CVE-2024-28111.json"); let cve_file = File::open(cve_json)?; - let loaded_vulnerability = graph - .get_vulnerability("CVE-2024-28111", Transactional::None) - .await?; + let loaded_vulnerability = graph.get_vulnerability("CVE-2024-28111", ()).await?; assert!(loaded_vulnerability.is_none()); @@ -110,9 +103,7 @@ mod test { loader.load("CVE-2024-28111.json", cve_file).await?; - let loaded_vulnerability = graph - .get_vulnerability("CVE-2024-28111", Transactional::None) - .await?; + let loaded_vulnerability = graph.get_vulnerability("CVE-2024-28111", ()).await?; assert!(loaded_vulnerability.is_some()); @@ -128,9 +119,7 @@ mod test { let loaded_vulnerability = loaded_vulnerability.unwrap(); - let descriptions = loaded_vulnerability - .descriptions("en", Transactional::None) - .await?; + let descriptions = loaded_vulnerability.descriptions("en", ()).await?; assert_eq!(1, descriptions.len());