Skip to content

Commit

Permalink
refactor the filtering logic into a single filtering fn
Browse files Browse the repository at this point in the history
Fixes trustification#168

Signed-off-by: Jim Crossley <[email protected]>
  • Loading branch information
jcrossley3 committed Apr 17, 2024
1 parent e02689b commit c3e8eaa
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 100 deletions.
23 changes: 6 additions & 17 deletions modules/graph/src/graph/sbom/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use trustify_entity as entity;
use trustify_entity::relationship::Relationship;
use trustify_entity::{sbom, vulnerability};
use trustify_module_search::model::SearchOptions;
use trustify_module_search::query::{Filter, Sort};
use trustify_module_search::query::Query as TrustifyQuery;

pub mod spdx;
mod tests;
Expand Down Expand Up @@ -60,22 +60,11 @@ impl Graph {

let SearchOptions { sort, q } = search;

let mut select =
sbom::Entity::find().filter(Filter::<sbom::Entity>::from_str(&q)?.into_condition());

if !sort.is_empty() {
for s in sort
.split(',')
.map(Sort::<sbom::Entity>::from_str)
.collect::<Result<Vec<_>, _>>()?
.iter()
{
select = select.order_by(s.field, s.order.clone());
}
}
select = select.order_by_desc(sbom::Column::Id);

let limiter = select.limiting(&connection, paginated.offset, paginated.limit);
let limiter = sbom::Entity::find().filtering(&q, &sort)?.limiting(
&connection,
paginated.offset,
paginated.limit,
);

Ok(PaginatedResults {
total: limiter.total().await?,
Expand Down
22 changes: 4 additions & 18 deletions modules/graph/src/graph/vulnerability/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use trustify_entity as entity;
use trustify_entity::vulnerability::Model;
use trustify_entity::{advisory, advisory_vulnerability, vulnerability, vulnerability_description};
use trustify_module_search::model::SearchOptions;
use trustify_module_search::query::{Filter, Sort};
use trustify_module_search::query::Query;

impl Graph {
pub async fn vulnerabilities<TX: AsRef<Transactional>>(
Expand All @@ -31,23 +31,9 @@ impl Graph {

let SearchOptions { sort, q } = search;

let mut select = vulnerability::Entity::find()
.filter(Filter::<vulnerability::Entity>::from_str(&q)?.into_condition());

// comma-delimited sort param, e.g. 'field1:asc,field2:desc'
if !sort.is_empty() {
for s in sort
.split(',')
.map(Sort::<vulnerability::Entity>::from_str)
.collect::<Result<Vec<_>, _>>()?
.iter()
{
select = select.order_by(s.field, s.order.clone());
}
}
select = select.order_by_desc(vulnerability::Column::Id);

let limiter = select.limiting(&connection, paginated.offset, paginated.limit);
let limiter = vulnerability::Entity::find()
.filtering(&q, &sort)?
.limiting(&connection, paginated.offset, paginated.limit);

Ok(PaginatedResults {
total: limiter.total().await?,
Expand Down
85 changes: 62 additions & 23 deletions modules/search/src/query.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use crate::service::Error;
use human_date_parser::{from_human_time, ParseResult};
use regex::Regex;
use sea_orm::sea_query::IntoCondition;
use sea_orm::{ColumnTrait, ColumnType, Condition, EntityTrait, Iterable, Order, Value};
use sea_orm::sea_query::{ConditionExpression, IntoCondition};
use sea_orm::{
ColumnTrait, ColumnType, Condition, EntityTrait, Iterable, Order, QueryFilter, QueryOrder,
Select, Value,
};
use std::fmt::Display;
use std::str::FromStr;
use std::sync::OnceLock;
Expand All @@ -14,22 +17,56 @@ use time::{Date, OffsetDateTime};
// Public interface
/////////////////////////////////////////////////////////////////////////

pub struct Filter<T: EntityTrait> {
pub trait Query<T: EntityTrait> {
fn filtering(self, filters: &str, sorts: &str) -> Result<Select<T>, Error>;
}

impl<T: EntityTrait> Query<T> for Select<T> {
fn filtering(self, filters: &str, sorts: &str) -> Result<Self, Error> {
let id = T::Column::from_str("id")
.map_err(|_| Error::SearchSyntax("Entity missing Id field".into()))?;
let result = if sorts.is_empty() {
self.filter(Filter::<T>::from_str(filters)?)
.order_by_desc(id)
} else {
sorts
.split(',')
.map(Sort::<T>::from_str)
.collect::<Result<Vec<_>, _>>()?
.iter()
.fold(self.filter(Filter::<T>::from_str(filters)?), |select, s| {
select.order_by(s.field, s.order.clone())
})
.order_by_desc(id)
};
Ok(result)
}
}

/////////////////////////////////////////////////////////////////////////
// Internal types
/////////////////////////////////////////////////////////////////////////

struct Filter<T: EntityTrait> {
operands: Operand<T>,
operator: Operator,
}

pub struct Sort<T: EntityTrait> {
pub field: T::Column,
pub order: Order,
struct Sort<T: EntityTrait> {
field: T::Column,
order: Order,
}

impl<T: EntityTrait> Filter<T> {
pub fn into_condition(self) -> Condition {
/////////////////////////////////////////////////////////////////////////
// SeaORM impls
/////////////////////////////////////////////////////////////////////////

impl<T: EntityTrait> IntoCondition for Filter<T> {
fn into_condition(self) -> Condition {
match self.operands {
Operand::Simple(col, v) => match self.operator {
Operator::Equal => col.eq(v).into_condition(),
Operator::NotEqual => col.ne(v).into_condition(),
Operator::Equal => col.eq(v),
Operator::NotEqual => col.ne(v),
op @ (Operator::Like | Operator::NotLike) => {
let v = format!(
"%{}%",
Expand All @@ -40,27 +77,29 @@ impl<T: EntityTrait> Filter<T> {
} else {
col.not_like(v)
}
.into_condition()
}
Operator::GreaterThan => col.gt(v).into_condition(),
Operator::GreaterThanOrEqual => col.gte(v).into_condition(),
Operator::LessThan => col.lt(v).into_condition(),
Operator::LessThanOrEqual => col.lte(v).into_condition(),
Operator::GreaterThan => col.gt(v),
Operator::GreaterThanOrEqual => col.gte(v),
Operator::LessThan => col.lt(v),
Operator::LessThanOrEqual => col.lte(v),
_ => unreachable!(),
},
}
.into_condition(),
Operand::Composite(v) => match self.operator {
Operator::And => v
.into_iter()
.fold(Condition::all(), |and, f| and.add(f.into_condition())),
Operator::Or => v
.into_iter()
.fold(Condition::any(), |or, f| or.add(f.into_condition())),
Operator::And => v.into_iter().fold(Condition::all(), |and, f| and.add(f)),
Operator::Or => v.into_iter().fold(Condition::any(), |or, f| or.add(f)),
_ => unreachable!(),
},
}
}
}

impl<T: EntityTrait> From<Filter<T>> for ConditionExpression {
fn from(f: Filter<T>) -> Self {
ConditionExpression::Condition(f.into_condition())
}
}

/////////////////////////////////////////////////////////////////////////
// FromStr impls
/////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -197,7 +236,7 @@ impl FromStr for Operator {
}

/////////////////////////////////////////////////////////////////////////
// Non-public helpers
// Internal helpers
/////////////////////////////////////////////////////////////////////////

enum Operand<T: EntityTrait> {
Expand Down
52 changes: 10 additions & 42 deletions modules/search/src/service.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use std::str::FromStr;

use crate::{
model::{FoundAdvisory, FoundSbom},
query::{Filter, Sort},
query::Query,
};
use actix_web::{body::BoxBody, HttpResponse, ResponseError};
use sea_orm::{EntityTrait, QueryFilter, QueryOrder};
use sea_orm::EntityTrait;
use trustify_common::{
db::{limiter::LimiterTrait, Database},
error::ErrorInformation,
Expand Down Expand Up @@ -55,25 +53,9 @@ impl SearchService {
sort: String,
paginated: Paginated,
) -> Result<PaginatedResults<FoundAdvisory>, Error> {
let mut select = advisory::Entity::find()
.filter(Filter::<advisory::Entity>::from_str(&filters)?.into_condition());

// comma-delimited sort param, e.g. 'field1:asc,field2:desc'
if !sort.is_empty() {
for s in sort
.split(',')
.map(Sort::<advisory::Entity>::from_str)
.collect::<Result<Vec<_>, _>>()?
.iter()
{
select = select.order_by(s.field, s.order.clone());
}
}
// we always sort by ID last, so that we have a stable order for pagination
select = select.order_by_desc(advisory::Column::Id);

let limiting = select.limiting(&self.db, paginated.offset, paginated.limit);

let limiting = advisory::Entity::find()
.filtering(&filters, &sort)?
.limiting(&self.db, paginated.offset, paginated.limit);
Ok(PaginatedResults {
total: limiting.total().await?,
items: limiting
Expand All @@ -93,25 +75,11 @@ impl SearchService {
sort: String,
paginated: Paginated,
) -> Result<PaginatedResults<FoundSbom>, Error> {
let mut select = sbom::Entity::find()
.filter(Filter::<sbom::Entity>::from_str(&filters)?.into_condition());

// comma-delimited sort param, e.g. 'field1:asc,field2:desc'
if !sort.is_empty() {
for s in sort
.split(',')
.map(Sort::<advisory::Entity>::from_str)
.collect::<Result<Vec<_>, _>>()?
.iter()
{
select = select.order_by(s.field, s.order.clone());
}
}
// we always sort by ID last, so that we have a stable order for pagination
select = select.order_by_desc(sbom::Column::Id);

let limiting = select.limiting(&self.db, paginated.offset, paginated.limit);

let limiting = sbom::Entity::find().filtering(&filters, &sort)?.limiting(
&self.db,
paginated.offset,
paginated.limit,
);
Ok(PaginatedResults {
total: limiting.total().await?,
items: limiting
Expand Down

0 comments on commit c3e8eaa

Please sign in to comment.