Skip to content

Commit

Permalink
chore(refactor): introduce visitor pattern for policy expr execution
Browse files Browse the repository at this point in the history
  • Loading branch information
j-lanson committed Oct 11, 2024
1 parent 984ccd5 commit 7ae8621
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 58 deletions.
35 changes: 19 additions & 16 deletions hipcheck/src/policy_exprs/env.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// SPDX-License-Identifier: Apache-2.0

use crate::policy_exprs::{eval, Error, Expr, Ident, Primitive, Result, F64};
use crate::policy_exprs::{
pass::ExprMutator, Array as StructArray, Error, Expr, ExprVisitor, Function as StructFunction,
Ident, Lambda as StructLambda, Primitive, Result, F64,
};
use itertools::Itertools as _;
use jiff::{Span, Zoned};
use std::{cmp::Ordering, collections::HashMap, ops::Not as _};
Expand Down Expand Up @@ -137,8 +140,8 @@ fn partially_evaluate(fn_name: &'static str, arg: Expr) -> Result<Expr> {
let var_name = "x";
let var = Ident(String::from(var_name));
let func = Ident(String::from(fn_name));
let op = Function(func, vec![Primitive(Identifier(var.clone())), arg]);
let lambda = Lambda(var, Box::new(op));
let op = StructFunction::new(func, vec![Primitive(Identifier(var.clone())), arg]).into();
let lambda = StructLambda::new(var, Box::new(op)).into();
Ok(lambda)
}

Expand All @@ -157,12 +160,12 @@ where

check_num_args(name, args, 2)?;

let arg_1 = match eval(env, &args[0])? {
let arg_1 = match env.visit_expr(args[0].clone())? {
Primitive(p) => p,
_ => return Err(Error::BadType(name)),
};

let arg_2 = match eval(env, &args[1])? {
let arg_2 = match env.visit_expr(args[1].clone())? {
Primitive(p) => p,
_ => return Err(Error::BadType(name)),
};
Expand All @@ -183,7 +186,7 @@ where
{
check_num_args(name, args, 1)?;

let primitive = match eval(env, &args[0])? {
let primitive = match env.visit_expr(args[0].clone())? {
Primitive(arg) => arg,
_ => return Err(Error::BadType(name)),
};
Expand All @@ -198,8 +201,8 @@ where
{
check_num_args(name, args, 1)?;

let arr = match eval(env, &args[0])? {
Array(arg) => array_type(&arg[..])?,
let arr = match env.visit_expr(args[0].clone())? {
Array(a) => array_type(&a.elts[..])?,
_ => return Err(Error::BadType(name)),
};

Expand All @@ -213,13 +216,13 @@ where
{
check_num_args(name, args, 2)?;

let (ident, body) = match eval(env, &args[0])? {
Lambda(ident, body) => (ident, body),
let (ident, body) = match env.visit_expr(args[0].clone())? {
Lambda(l) => (l.arg, l.body),
_ => return Err(Error::BadType(name)),
};

let arr = match eval(env, &args[1])? {
Array(arr) => array_type(&arr[..])?,
let arr = match env.visit_expr(args[1].clone())? {
Array(a) => array_type(&a.elts[..])?,
_ => return Err(Error::BadType(name)),
};

Expand Down Expand Up @@ -322,7 +325,7 @@ fn eval_lambda(env: &Env, ident: &Ident, val: Primitive, body: Expr) -> Result<E
return Err(Error::AlreadyBound);
}

eval(&child, &body)
child.visit_expr(body)
}

#[allow(clippy::bool_comparison)]
Expand Down Expand Up @@ -944,7 +947,7 @@ fn filter(env: &Env, args: &[Expr]) -> Result<Expr> {
ArrayType::Empty => Vec::new(),
};

Ok(Array(arr))
Ok(StructArray::new(arr).into())
};

higher_order_array_op(name, env, args, op)
Expand Down Expand Up @@ -1003,7 +1006,7 @@ fn foreach(env: &Env, args: &[Expr]) -> Result<Expr> {
ArrayType::Empty => Vec::new(),
};

Ok(Array(arr))
Ok(StructArray::new(arr).into())
};

higher_order_array_op(name, env, args, op)
Expand All @@ -1013,7 +1016,7 @@ fn dbg(env: &Env, args: &[Expr]) -> Result<Expr> {
let name = "dbg";
check_num_args(name, args, 1)?;
let arg = &args[0];
let result = eval(env, arg)?;
let result = env.visit_expr(arg.clone())?;
log::debug!("{arg} = {result}");
Ok(result)
}
88 changes: 76 additions & 12 deletions hipcheck/src/policy_exprs/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,68 @@ pub enum Expr {
Primitive(Primitive),

/// An array of primitive data.
Array(Vec<Primitive>),
Array(Array),

/// Stores the name of the function, followed by the args.
Function(Ident, Vec<Expr>),
Function(Function),

/// Stores the name of the input variable, followed by the lambda body.
Lambda(Ident, Box<Expr>),
Lambda(Lambda),

/// Stores a late-binding for a JSON value.
JsonPointer(JsonPointer),
}

/// An array of primitives.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Array {
pub elts: Vec<Primitive>,
}
impl Array {
pub fn new(elts: Vec<Primitive>) -> Self {
Array { elts }
}
}
impl From<Array> for Expr {
fn from(value: Array) -> Self {
Expr::Array(value)
}
}

/// A `deke` function to evaluate.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Function {
pub ident: Ident,
pub args: Vec<Expr>,
}
impl Function {
pub fn new(ident: Ident, args: Vec<Expr>) -> Self {
Function { ident, args }
}
}
impl From<Function> for Expr {
fn from(value: Function) -> Self {
Expr::Function(value)
}
}

/// Stores the name of the input variable, followed by the lambda body.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Lambda {
pub arg: Ident,
pub body: Box<Expr>,
}
impl Lambda {
pub fn new(arg: Ident, body: Box<Expr>) -> Self {
Lambda { arg, body }
}
}
impl From<Lambda> for Expr {
fn from(value: Lambda) -> Self {
Expr::Lambda(value)
}
}

/// Primitive data.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Primitive {
Expand Down Expand Up @@ -76,6 +126,11 @@ pub enum Primitive {
/// "P1w1dT1h1m1.1s"
Span(Span),
}
impl From<Primitive> for Expr {
fn from(value: Primitive) -> Self {
Expr::Primitive(value)
}
}

/// A variable or function identifier.
#[derive(Debug, Clone, PartialEq, Eq)]
Expand All @@ -87,6 +142,11 @@ pub struct JsonPointer {
pointer: String,
value: Option<serde_json::Value>,
}
impl From<JsonPointer> for Expr {
fn from(value: JsonPointer) -> Self {
Expr::JsonPointer(value)
}
}

/// A non-NaN 64-bit floating point number.
pub type F64 = NotNan<f64>;
Expand All @@ -96,13 +156,17 @@ impl Display for Expr {
match self {
Expr::Primitive(primitive) => write!(f, "{}", primitive),
Expr::Array(array) => {
write!(f, "[{}]", array.iter().map(ToString::to_string).join(" "))
write!(
f,
"[{}]",
array.elts.iter().map(ToString::to_string).join(" ")
)
}
Expr::Function(ident, args) => {
let args = args.iter().map(ToString::to_string).join(" ");
write!(f, "({} {})", ident, args)
Expr::Function(func) => {
let args = func.args.iter().map(ToString::to_string).join(" ");
write!(f, "({} {})", func.ident, args)
}
Expr::Lambda(arg, body) => write!(f, "(lambda ({}) {}", arg, body),
Expr::Lambda(l) => write!(f, "(lambda ({}) {}", l.arg, l.body),
Expr::JsonPointer(pointer) => write!(f, "${}", pointer.pointer),
}
}
Expand Down Expand Up @@ -206,7 +270,7 @@ fn parse_primitive(input: Input<'_>) -> IResult<Input<'_>, Primitive> {
/// Parse an array.
fn parse_array(input: Input<'_>) -> IResult<Input<'_>, Expr> {
let parser = tuple((Token::OpenBrace, many0(parse_primitive), Token::CloseBrace));
let mut parser = map(parser, |(_, inner, _)| Expr::Array(inner));
let mut parser = map(parser, |(_, inner, _)| Array::new(inner).into());
parser(input)
}

Expand All @@ -225,7 +289,7 @@ fn parse_function(input: Input<'_>) -> IResult<Input<'_>, Expr> {
Token::CloseParen,
));
let mut parser = map(parser, |(_, ident, args, _)| {
Expr::Function(Ident(ident), args)
Function::new(Ident(ident), args).into()
});
parser(input)
}
Expand Down Expand Up @@ -274,7 +338,7 @@ mod tests {

fn func(name: &str, args: Vec<impl IntoExpr>) -> Expr {
let args = args.into_iter().map(|arg| arg.into_expr()).collect();
Expr::Function(Ident(String::from(name)), args)
Function::new(Ident(String::from(name)), args).into()
}

fn int(val: i64) -> Primitive {
Expand All @@ -298,7 +362,7 @@ mod tests {
}

fn array(vals: Vec<Primitive>) -> Expr {
Expr::Array(vals)
Array::new(vals).into()
}

fn json_ptr(name: &str) -> Expr {
Expand Down
4 changes: 2 additions & 2 deletions hipcheck/src/policy_exprs/json_pointer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::policy_exprs::{
error,
error::{Error, Result},
expr::{Expr, Primitive},
expr::{Array, Expr, Primitive},
};
use ordered_float::NotNan;
use regex::{Captures, Regex, RegexBuilder};
Expand Down Expand Up @@ -116,7 +116,7 @@ fn json_to_policy_expr(val: &Value, pointer: &str, context: &Value) -> Result<Ex
.collect::<Result<Vec<Primitive>>>()?;
// NOTE that no checking is done to confirm that all Primitives are the same type.
// That would be a type error in the Policy Expr language.
Ok(Expr::Array(primitives))
Ok(Array::new(primitives).into())
}
// Strings cannot (currently) be represented in the Policy Expr language.
Value::String(_) => Err(Error::JSONPointerUnrepresentableType {
Expand Down
52 changes: 24 additions & 28 deletions hipcheck/src/policy_exprs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ mod env;
mod error;
pub mod expr;
mod json_pointer;
mod pass;
mod token;

use crate::policy_exprs::env::Env;
pub(crate) use crate::policy_exprs::{bridge::Tokens, expr::F64};
pub use crate::policy_exprs::{
error::{Error, Result},
expr::{Expr, Ident},
expr::{Array, Expr, Function, Ident, JsonPointer, Lambda},
pass::{ExprMutator, ExprVisitor},
token::LexingError,
};
use env::Binding;
Expand Down Expand Up @@ -45,34 +47,27 @@ impl Executor {
pub fn parse_and_eval(&self, raw_program: &str, context: &Value) -> Result<Expr> {
let processed_program = process_json_pointers(raw_program, context)?;
let program = parse(&processed_program)?;
let expr = eval(&self.env, &program)?;
let expr = self.env.visit_expr(program)?;
Ok(expr)
}
}

/// Evaluate the `Expr`, returning a boolean.
pub(crate) fn eval(env: &Env, program: &Expr) -> Result<Expr> {
let output = match program {
Expr::Primitive(primitive) => Ok(Expr::Primitive(primitive.resolve(env)?)),
Expr::Array(_) => Ok(program.clone()),
Expr::Function(name, args) => {
let binding = env
.get(name)
.ok_or_else(|| Error::UnknownFunction(name.deref().to_owned()))?;

if let Binding::Fn(op) = binding {
op(env, args)
} else {
Err(Error::FoundVarExpectedFunc(name.deref().to_owned()))
}
impl ExprMutator for Env<'_> {
fn visit_primitive(&self, prim: Primitive) -> Result<Expr> {
Ok(prim.resolve(self)?.into())
}
fn visit_function(&self, f: Function) -> Result<Expr> {
let binding = self
.get(&f.ident)
.ok_or_else(|| Error::UnknownFunction(f.ident.deref().to_owned()))?;
if let Binding::Fn(op) = binding {
(op)(self, &f.args)
} else {
Err(Error::FoundVarExpectedFunc(f.ident.deref().to_owned()))
}
Expr::Lambda(_, body) => Ok((**body).clone()),
Expr::JsonPointer(_) => unreachable!(),
};

log::debug!("input: {program:?}, output: {output:?}");

output
}
fn visit_lambda(&self, l: Lambda) -> Result<Expr> {
Ok((*l.body).clone())
}
}

#[cfg(test)]
Expand Down Expand Up @@ -177,7 +172,7 @@ mod tests {
let program = "(eq 3 (count (filter (gt 8.0) [1.0 2.0 10.0 20.0 30.0])))";
let context = Value::Null;
let result = Executor::std().parse_and_eval(program, &context).unwrap();
assert_eq!(result, Expr::Primitive(Primitive::Bool(true)));
assert_eq!(result, Primitive::Bool(true).into());
}

#[test]
Expand All @@ -186,7 +181,7 @@ mod tests {
"(eq 3 (count (filter (gt 8.0) (foreach (sub 1.0) [1.0 2.0 10.0 20.0 30.0]))))";
let context = Value::Null;
let result = Executor::std().parse_and_eval(program, &context).unwrap();
assert_eq!(result, Expr::Primitive(Primitive::Bool(true)));
assert_eq!(result, Primitive::Bool(true).into());
}

#[test]
Expand All @@ -196,11 +191,12 @@ mod tests {
let result = Executor::std().parse_and_eval(program, &context).unwrap();
assert_eq!(
result,
Expr::Array(vec![
Array::new(vec![
Primitive::Int(0),
Primitive::Int(0),
Primitive::Int(0)
])
.into()
);
}

Expand Down
Loading

0 comments on commit 7ae8621

Please sign in to comment.