diff --git a/hipcheck/src/policy_exprs/env.rs b/hipcheck/src/policy_exprs/env.rs index 1b86fc52..c727bf13 100644 --- a/hipcheck/src/policy_exprs/env.rs +++ b/hipcheck/src/policy_exprs/env.rs @@ -1,6 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 -use crate::policy_exprs::{eval, Error, Expr, Ident, Primitive, Result, F64}; +use crate::policy_exprs::{ + pass::ExprMutator, Array as StructArray, Error, Expr, ExprVisitor, Function as StructFunction, + Ident, Lambda as StructLambda, Primitive, Result, F64, +}; use itertools::Itertools as _; use jiff::{Span, Zoned}; use std::{cmp::Ordering, collections::HashMap, ops::Not as _}; @@ -137,8 +140,8 @@ fn partially_evaluate(fn_name: &'static str, arg: Expr) -> Result { let var_name = "x"; let var = Ident(String::from(var_name)); let func = Ident(String::from(fn_name)); - let op = Function(func, vec![Primitive(Identifier(var.clone())), arg]); - let lambda = Lambda(var, Box::new(op)); + let op = StructFunction::new(func, vec![Primitive(Identifier(var.clone())), arg]).into(); + let lambda = StructLambda::new(var, Box::new(op)).into(); Ok(lambda) } @@ -157,12 +160,12 @@ where check_num_args(name, args, 2)?; - let arg_1 = match eval(env, &args[0])? { + let arg_1 = match env.visit_expr(args[0].clone())? { Primitive(p) => p, _ => return Err(Error::BadType(name)), }; - let arg_2 = match eval(env, &args[1])? { + let arg_2 = match env.visit_expr(args[1].clone())? { Primitive(p) => p, _ => return Err(Error::BadType(name)), }; @@ -183,7 +186,7 @@ where { check_num_args(name, args, 1)?; - let primitive = match eval(env, &args[0])? { + let primitive = match env.visit_expr(args[0].clone())? { Primitive(arg) => arg, _ => return Err(Error::BadType(name)), }; @@ -198,8 +201,8 @@ where { check_num_args(name, args, 1)?; - let arr = match eval(env, &args[0])? { - Array(arg) => array_type(&arg[..])?, + let arr = match env.visit_expr(args[0].clone())? { + Array(a) => array_type(&a.elts[..])?, _ => return Err(Error::BadType(name)), }; @@ -213,13 +216,13 @@ where { check_num_args(name, args, 2)?; - let (ident, body) = match eval(env, &args[0])? { - Lambda(ident, body) => (ident, body), + let (ident, body) = match env.visit_expr(args[0].clone())? { + Lambda(l) => (l.arg, l.body), _ => return Err(Error::BadType(name)), }; - let arr = match eval(env, &args[1])? { - Array(arr) => array_type(&arr[..])?, + let arr = match env.visit_expr(args[1].clone())? { + Array(a) => array_type(&a.elts[..])?, _ => return Err(Error::BadType(name)), }; @@ -322,7 +325,7 @@ fn eval_lambda(env: &Env, ident: &Ident, val: Primitive, body: Expr) -> Result Result { ArrayType::Empty => Vec::new(), }; - Ok(Array(arr)) + Ok(StructArray::new(arr).into()) }; higher_order_array_op(name, env, args, op) @@ -1003,7 +1006,7 @@ fn foreach(env: &Env, args: &[Expr]) -> Result { ArrayType::Empty => Vec::new(), }; - Ok(Array(arr)) + Ok(StructArray::new(arr).into()) }; higher_order_array_op(name, env, args, op) @@ -1013,7 +1016,7 @@ fn dbg(env: &Env, args: &[Expr]) -> Result { let name = "dbg"; check_num_args(name, args, 1)?; let arg = &args[0]; - let result = eval(env, arg)?; + let result = env.visit_expr(arg.clone())?; log::debug!("{arg} = {result}"); Ok(result) } diff --git a/hipcheck/src/policy_exprs/expr.rs b/hipcheck/src/policy_exprs/expr.rs index ac066252..f7a29e8f 100644 --- a/hipcheck/src/policy_exprs/expr.rs +++ b/hipcheck/src/policy_exprs/expr.rs @@ -27,18 +27,68 @@ pub enum Expr { Primitive(Primitive), /// An array of primitive data. - Array(Vec), + Array(Array), /// Stores the name of the function, followed by the args. - Function(Ident, Vec), + Function(Function), /// Stores the name of the input variable, followed by the lambda body. - Lambda(Ident, Box), + Lambda(Lambda), /// Stores a late-binding for a JSON value. JsonPointer(JsonPointer), } +/// An array of primitives. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Array { + pub elts: Vec, +} +impl Array { + pub fn new(elts: Vec) -> Self { + Array { elts } + } +} +impl From for Expr { + fn from(value: Array) -> Self { + Expr::Array(value) + } +} + +/// A `deke` function to evaluate. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Function { + pub ident: Ident, + pub args: Vec, +} +impl Function { + pub fn new(ident: Ident, args: Vec) -> Self { + Function { ident, args } + } +} +impl From for Expr { + fn from(value: Function) -> Self { + Expr::Function(value) + } +} + +/// Stores the name of the input variable, followed by the lambda body. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Lambda { + pub arg: Ident, + pub body: Box, +} +impl Lambda { + pub fn new(arg: Ident, body: Box) -> Self { + Lambda { arg, body } + } +} +impl From for Expr { + fn from(value: Lambda) -> Self { + Expr::Lambda(value) + } +} + /// Primitive data. #[derive(Debug, PartialEq, Eq, Clone)] pub enum Primitive { @@ -76,6 +126,11 @@ pub enum Primitive { /// "P1w1dT1h1m1.1s" Span(Span), } +impl From for Expr { + fn from(value: Primitive) -> Self { + Expr::Primitive(value) + } +} /// A variable or function identifier. #[derive(Debug, Clone, PartialEq, Eq)] @@ -87,6 +142,11 @@ pub struct JsonPointer { pointer: String, value: Option, } +impl From for Expr { + fn from(value: JsonPointer) -> Self { + Expr::JsonPointer(value) + } +} /// A non-NaN 64-bit floating point number. pub type F64 = NotNan; @@ -96,13 +156,17 @@ impl Display for Expr { match self { Expr::Primitive(primitive) => write!(f, "{}", primitive), Expr::Array(array) => { - write!(f, "[{}]", array.iter().map(ToString::to_string).join(" ")) + write!( + f, + "[{}]", + array.elts.iter().map(ToString::to_string).join(" ") + ) } - Expr::Function(ident, args) => { - let args = args.iter().map(ToString::to_string).join(" "); - write!(f, "({} {})", ident, args) + Expr::Function(func) => { + let args = func.args.iter().map(ToString::to_string).join(" "); + write!(f, "({} {})", func.ident, args) } - Expr::Lambda(arg, body) => write!(f, "(lambda ({}) {}", arg, body), + Expr::Lambda(l) => write!(f, "(lambda ({}) {}", l.arg, l.body), Expr::JsonPointer(pointer) => write!(f, "${}", pointer.pointer), } } @@ -206,7 +270,7 @@ fn parse_primitive(input: Input<'_>) -> IResult, Primitive> { /// Parse an array. fn parse_array(input: Input<'_>) -> IResult, Expr> { let parser = tuple((Token::OpenBrace, many0(parse_primitive), Token::CloseBrace)); - let mut parser = map(parser, |(_, inner, _)| Expr::Array(inner)); + let mut parser = map(parser, |(_, inner, _)| Array::new(inner).into()); parser(input) } @@ -225,7 +289,7 @@ fn parse_function(input: Input<'_>) -> IResult, Expr> { Token::CloseParen, )); let mut parser = map(parser, |(_, ident, args, _)| { - Expr::Function(Ident(ident), args) + Function::new(Ident(ident), args).into() }); parser(input) } @@ -274,7 +338,7 @@ mod tests { fn func(name: &str, args: Vec) -> Expr { let args = args.into_iter().map(|arg| arg.into_expr()).collect(); - Expr::Function(Ident(String::from(name)), args) + Function::new(Ident(String::from(name)), args).into() } fn int(val: i64) -> Primitive { @@ -298,7 +362,7 @@ mod tests { } fn array(vals: Vec) -> Expr { - Expr::Array(vals) + Array::new(vals).into() } fn json_ptr(name: &str) -> Expr { diff --git a/hipcheck/src/policy_exprs/json_pointer.rs b/hipcheck/src/policy_exprs/json_pointer.rs index b655799f..fb1c79b2 100644 --- a/hipcheck/src/policy_exprs/json_pointer.rs +++ b/hipcheck/src/policy_exprs/json_pointer.rs @@ -3,7 +3,7 @@ use crate::policy_exprs::{ error, error::{Error, Result}, - expr::{Expr, Primitive}, + expr::{Array, Expr, Primitive}, }; use ordered_float::NotNan; use regex::{Captures, Regex, RegexBuilder}; @@ -116,7 +116,7 @@ fn json_to_policy_expr(val: &Value, pointer: &str, context: &Value) -> Result>>()?; // NOTE that no checking is done to confirm that all Primitives are the same type. // That would be a type error in the Policy Expr language. - Ok(Expr::Array(primitives)) + Ok(Array::new(primitives).into()) } // Strings cannot (currently) be represented in the Policy Expr language. Value::String(_) => Err(Error::JSONPointerUnrepresentableType { diff --git a/hipcheck/src/policy_exprs/mod.rs b/hipcheck/src/policy_exprs/mod.rs index d5a08994..017be102 100644 --- a/hipcheck/src/policy_exprs/mod.rs +++ b/hipcheck/src/policy_exprs/mod.rs @@ -7,13 +7,15 @@ mod env; mod error; pub mod expr; mod json_pointer; +mod pass; mod token; use crate::policy_exprs::env::Env; pub(crate) use crate::policy_exprs::{bridge::Tokens, expr::F64}; pub use crate::policy_exprs::{ error::{Error, Result}, - expr::{Expr, Ident}, + expr::{Array, Expr, Function, Ident, JsonPointer, Lambda}, + pass::{ExprMutator, ExprVisitor}, token::LexingError, }; use env::Binding; @@ -45,34 +47,27 @@ impl Executor { pub fn parse_and_eval(&self, raw_program: &str, context: &Value) -> Result { let processed_program = process_json_pointers(raw_program, context)?; let program = parse(&processed_program)?; - let expr = eval(&self.env, &program)?; + let expr = self.env.visit_expr(program)?; Ok(expr) } } - -/// Evaluate the `Expr`, returning a boolean. -pub(crate) fn eval(env: &Env, program: &Expr) -> Result { - let output = match program { - Expr::Primitive(primitive) => Ok(Expr::Primitive(primitive.resolve(env)?)), - Expr::Array(_) => Ok(program.clone()), - Expr::Function(name, args) => { - let binding = env - .get(name) - .ok_or_else(|| Error::UnknownFunction(name.deref().to_owned()))?; - - if let Binding::Fn(op) = binding { - op(env, args) - } else { - Err(Error::FoundVarExpectedFunc(name.deref().to_owned())) - } +impl ExprMutator for Env<'_> { + fn visit_primitive(&self, prim: Primitive) -> Result { + Ok(prim.resolve(self)?.into()) + } + fn visit_function(&self, f: Function) -> Result { + let binding = self + .get(&f.ident) + .ok_or_else(|| Error::UnknownFunction(f.ident.deref().to_owned()))?; + if let Binding::Fn(op) = binding { + (op)(self, &f.args) + } else { + Err(Error::FoundVarExpectedFunc(f.ident.deref().to_owned())) } - Expr::Lambda(_, body) => Ok((**body).clone()), - Expr::JsonPointer(_) => unreachable!(), - }; - - log::debug!("input: {program:?}, output: {output:?}"); - - output + } + fn visit_lambda(&self, l: Lambda) -> Result { + Ok((*l.body).clone()) + } } #[cfg(test)] @@ -177,7 +172,7 @@ mod tests { let program = "(eq 3 (count (filter (gt 8.0) [1.0 2.0 10.0 20.0 30.0])))"; let context = Value::Null; let result = Executor::std().parse_and_eval(program, &context).unwrap(); - assert_eq!(result, Expr::Primitive(Primitive::Bool(true))); + assert_eq!(result, Primitive::Bool(true).into()); } #[test] @@ -186,7 +181,7 @@ mod tests { "(eq 3 (count (filter (gt 8.0) (foreach (sub 1.0) [1.0 2.0 10.0 20.0 30.0]))))"; let context = Value::Null; let result = Executor::std().parse_and_eval(program, &context).unwrap(); - assert_eq!(result, Expr::Primitive(Primitive::Bool(true))); + assert_eq!(result, Primitive::Bool(true).into()); } #[test] @@ -196,11 +191,12 @@ mod tests { let result = Executor::std().parse_and_eval(program, &context).unwrap(); assert_eq!( result, - Expr::Array(vec![ + Array::new(vec![ Primitive::Int(0), Primitive::Int(0), Primitive::Int(0) ]) + .into() ); } diff --git a/hipcheck/src/policy_exprs/pass.rs b/hipcheck/src/policy_exprs/pass.rs new file mode 100644 index 00000000..9d44b177 --- /dev/null +++ b/hipcheck/src/policy_exprs/pass.rs @@ -0,0 +1,63 @@ +use crate::policy_exprs::{ + env::Env, + error::{Error, Result}, + expr::*, +}; + +pub trait ExprVisitor { + fn visit_primitive(&self, prim: &Primitive) -> T; + fn visit_array(&self, arr: &Array) -> T; + fn visit_function(&self, func: &Function) -> T; + fn visit_lambda(&self, func: &Lambda) -> T; + fn visit_json_pointer(&self, func: &JsonPointer) -> T; + fn visit_expr(&self, expr: &Expr) -> T { + match expr { + Expr::Primitive(a) => self.visit_primitive(a), + Expr::Array(a) => self.visit_array(a), + Expr::Function(a) => self.visit_function(a), + Expr::Lambda(a) => self.visit_lambda(a), + Expr::JsonPointer(a) => self.visit_json_pointer(a), + } + } + fn run(&self, expr: &Expr) -> T { + self.visit_expr(expr) + } +} + +pub trait ExprMutator { + fn visit_primitive(&self, prim: Primitive) -> Result { + Ok(prim.into()) + } + fn visit_array(&self, arr: Array) -> Result { + Ok(arr.into()) + } + fn visit_function(&self, func: Function) -> Result { + let mut func = func; + func.args = func + .args + .into_iter() + .map(|a| self.visit_expr(a)) + .collect::>>()?; + Ok(func.into()) + } + fn visit_lambda(&self, lamb: Lambda) -> Result { + let mut lamb = lamb; + lamb.body = Box::new(self.visit_expr(*lamb.body.clone())?); + Ok(lamb.into()) + } + fn visit_json_pointer(&self, jp: JsonPointer) -> Result { + Ok(jp.into()) + } + fn visit_expr(&self, expr: Expr) -> Result { + match expr { + Expr::Primitive(a) => self.visit_primitive(a), + Expr::Array(a) => self.visit_array(a), + Expr::Function(a) => self.visit_function(a), + Expr::Lambda(a) => self.visit_lambda(a), + Expr::JsonPointer(a) => self.visit_json_pointer(a), + } + } + fn run(&self, expr: Expr) -> Result { + self.visit_expr(expr) + } +}