From 2b74a643edd5ffea62b7e1d30e21d3eb3a90b5b4 Mon Sep 17 00:00:00 2001 From: jlanson Date: Fri, 27 Sep 2024 16:52:05 -0400 Subject: [PATCH] chore(refactor): introduce visitor pattern for policy expr execution --- hipcheck/src/policy_exprs/env.rs | 35 ++++----- hipcheck/src/policy_exprs/expr.rs | 88 +++++++++++++++++++---- hipcheck/src/policy_exprs/json_pointer.rs | 4 +- hipcheck/src/policy_exprs/mod.rs | 73 +++++++++++-------- 4 files changed, 142 insertions(+), 58 deletions(-) diff --git a/hipcheck/src/policy_exprs/env.rs b/hipcheck/src/policy_exprs/env.rs index 8d36c457..a02af6f7 100644 --- a/hipcheck/src/policy_exprs/env.rs +++ b/hipcheck/src/policy_exprs/env.rs @@ -1,6 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 -use crate::policy_exprs::{eval, Error, Expr, Ident, Primitive, Result, F64}; +use crate::policy_exprs::{ + Array as StructArray, Error, Expr, ExprVisitor, Function as StructFunction, Ident, + Lambda as StructLambda, Primitive, Result, F64, +}; use itertools::Itertools as _; use jiff::{Span, Zoned}; use std::{cmp::Ordering, collections::HashMap, ops::Not as _}; @@ -137,8 +140,8 @@ fn partially_evaluate(fn_name: &'static str, arg: Expr) -> Result { let var_name = "x"; let var = Ident(String::from(var_name)); let func = Ident(String::from(fn_name)); - let op = Function(func, vec![Primitive(Identifier(var.clone())), arg]); - let lambda = Lambda(var, Box::new(op)); + let op = StructFunction::new(func, vec![Primitive(Identifier(var.clone())), arg]).into(); + let lambda = StructLambda::new(var, Box::new(op)).into(); Ok(lambda) } @@ -157,12 +160,12 @@ where check_num_args(name, args, 2)?; - let arg_1 = match eval(env, &args[0])? { + let arg_1 = match env.visit_expr(&args[0])? { Primitive(p) => p, _ => return Err(Error::BadType(name)), }; - let arg_2 = match eval(env, &args[1])? { + let arg_2 = match env.visit_expr(&args[1])? { Primitive(p) => p, _ => return Err(Error::BadType(name)), }; @@ -183,7 +186,7 @@ where { check_num_args(name, args, 1)?; - let primitive = match eval(env, &args[0])? { + let primitive = match env.visit_expr(&args[0])? { Primitive(arg) => arg, _ => return Err(Error::BadType(name)), }; @@ -198,8 +201,8 @@ where { check_num_args(name, args, 1)?; - let arr = match eval(env, &args[0])? { - Array(arg) => array_type(&arg[..])?, + let arr = match env.visit_expr(&args[0])? { + Array(a) => array_type(&a.elts[..])?, _ => return Err(Error::BadType(name)), }; @@ -213,13 +216,13 @@ where { check_num_args(name, args, 2)?; - let (ident, body) = match eval(env, &args[0])? { - Lambda(ident, body) => (ident, body), + let (ident, body) = match env.visit_expr(&args[0])? { + Lambda(l) => (l.arg, l.body), _ => return Err(Error::BadType(name)), }; - let arr = match eval(env, &args[1])? { - Array(arr) => array_type(&arr[..])?, + let arr = match env.visit_expr(&args[1])? { + Array(a) => array_type(&a.elts[..])?, _ => return Err(Error::BadType(name)), }; @@ -322,7 +325,7 @@ fn eval_lambda(env: &Env, ident: &Ident, val: Primitive, body: Expr) -> Result Result { ArrayType::Empty => Vec::new(), }; - Ok(Array(arr)) + Ok(StructArray::new(arr).into()) }; higher_order_array_op(name, env, args, op) @@ -1003,7 +1006,7 @@ fn foreach(env: &Env, args: &[Expr]) -> Result { ArrayType::Empty => Vec::new(), }; - Ok(Array(arr)) + Ok(StructArray::new(arr).into()) }; higher_order_array_op(name, env, args, op) @@ -1013,7 +1016,7 @@ fn dbg(env: &Env, args: &[Expr]) -> Result { let name = "dbg"; check_num_args(name, args, 1)?; let arg = &args[0]; - let result = eval(env, arg)?; + let result = env.visit_expr(arg)?; log::debug!("{arg} = {result}"); Ok(result) } diff --git a/hipcheck/src/policy_exprs/expr.rs b/hipcheck/src/policy_exprs/expr.rs index 41f9adbb..ba824d2b 100644 --- a/hipcheck/src/policy_exprs/expr.rs +++ b/hipcheck/src/policy_exprs/expr.rs @@ -27,18 +27,68 @@ pub enum Expr { Primitive(Primitive), /// An array of primitive data. - Array(Vec), + Array(Array), /// Stores the name of the function, followed by the args. - Function(Ident, Vec), + Function(Function), /// Stores the name of the input variable, followed by the lambda body. - Lambda(Ident, Box), + Lambda(Lambda), /// Stores a late-binding for a JSON value. JsonPointer(JsonPointer), } +/// An array of primitives. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Array { + pub elts: Vec, +} +impl Array { + pub fn new(elts: Vec) -> Self { + Array { elts } + } +} +impl From for Expr { + fn from(value: Array) -> Self { + Expr::Array(value) + } +} + +/// A `deke` function to evaluate. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Function { + pub ident: Ident, + pub args: Vec, +} +impl Function { + pub fn new(ident: Ident, args: Vec) -> Self { + Function { ident, args } + } +} +impl From for Expr { + fn from(value: Function) -> Self { + Expr::Function(value) + } +} + +/// Stores the name of the input variable, followed by the lambda body. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Lambda { + pub arg: Ident, + pub body: Box, +} +impl Lambda { + pub fn new(arg: Ident, body: Box) -> Self { + Lambda { arg, body } + } +} +impl From for Expr { + fn from(value: Lambda) -> Self { + Expr::Lambda(value) + } +} + /// Primitive data. #[derive(Debug, PartialEq, Eq, Clone)] pub enum Primitive { @@ -76,6 +126,11 @@ pub enum Primitive { /// "P1w1dT1h1m1.1s" Span(Span), } +impl From for Expr { + fn from(value: Primitive) -> Self { + Expr::Primitive(value) + } +} /// A variable or function identifier. #[derive(Debug, Clone, PartialEq, Eq)] @@ -87,6 +142,11 @@ pub struct JsonPointer { pointer: String, value: Option, } +impl From for Expr { + fn from(value: JsonPointer) -> Self { + Expr::JsonPointer(value) + } +} /// A non-NaN 64-bit floating point number. pub type F64 = NotNan; @@ -96,13 +156,17 @@ impl Display for Expr { match self { Expr::Primitive(primitive) => write!(f, "{}", primitive), Expr::Array(array) => { - write!(f, "[{}]", array.iter().map(ToString::to_string).join(" ")) + write!( + f, + "[{}]", + array.elts.iter().map(ToString::to_string).join(" ") + ) } - Expr::Function(ident, args) => { - let args = args.iter().map(ToString::to_string).join(" "); - write!(f, "({} {})", ident, args) + Expr::Function(func) => { + let args = func.args.iter().map(ToString::to_string).join(" "); + write!(f, "({} {})", func.ident, args) } - Expr::Lambda(arg, body) => write!(f, "(lambda ({}) {}", arg, body), + Expr::Lambda(l) => write!(f, "(lambda ({}) {}", l.arg, l.body), Expr::JsonPointer(pointer) => write!(f, "${}", pointer.pointer), } } @@ -206,7 +270,7 @@ fn parse_primitive(input: Input<'_>) -> IResult, Primitive> { /// Parse an array. fn parse_array(input: Input<'_>) -> IResult, Expr> { let parser = tuple((Token::OpenBrace, many0(parse_primitive), Token::CloseBrace)); - let mut parser = map(parser, |(_, inner, _)| Expr::Array(inner)); + let mut parser = map(parser, |(_, inner, _)| Array::new(inner).into()); parser(input) } @@ -225,7 +289,7 @@ fn parse_function(input: Input<'_>) -> IResult, Expr> { Token::CloseParen, )); let mut parser = map(parser, |(_, ident, args, _)| { - Expr::Function(Ident(ident), args) + Function::new(Ident(ident), args).into() }); parser(input) } @@ -274,7 +338,7 @@ mod tests { fn func(name: &str, args: Vec) -> Expr { let args = args.into_iter().map(|arg| arg.into_expr()).collect(); - Expr::Function(Ident(String::from(name)), args) + Function::new(Ident(String::from(name)), args).into() } fn int(val: i64) -> Primitive { @@ -298,7 +362,7 @@ mod tests { } fn array(vals: Vec) -> Expr { - Expr::Array(vals) + Array::new(vals).into() } fn json_ptr(name: &str) -> Expr { diff --git a/hipcheck/src/policy_exprs/json_pointer.rs b/hipcheck/src/policy_exprs/json_pointer.rs index b655799f..fb1c79b2 100644 --- a/hipcheck/src/policy_exprs/json_pointer.rs +++ b/hipcheck/src/policy_exprs/json_pointer.rs @@ -3,7 +3,7 @@ use crate::policy_exprs::{ error, error::{Error, Result}, - expr::{Expr, Primitive}, + expr::{Array, Expr, Primitive}, }; use ordered_float::NotNan; use regex::{Captures, Regex, RegexBuilder}; @@ -116,7 +116,7 @@ fn json_to_policy_expr(val: &Value, pointer: &str, context: &Value) -> Result>>()?; // NOTE that no checking is done to confirm that all Primitives are the same type. // That would be a type error in the Policy Expr language. - Ok(Expr::Array(primitives)) + Ok(Array::new(primitives).into()) } // Strings cannot (currently) be represented in the Policy Expr language. Value::String(_) => Err(Error::JSONPointerUnrepresentableType { diff --git a/hipcheck/src/policy_exprs/mod.rs b/hipcheck/src/policy_exprs/mod.rs index d5a08994..03308cf1 100644 --- a/hipcheck/src/policy_exprs/mod.rs +++ b/hipcheck/src/policy_exprs/mod.rs @@ -13,7 +13,7 @@ use crate::policy_exprs::env::Env; pub(crate) use crate::policy_exprs::{bridge::Tokens, expr::F64}; pub use crate::policy_exprs::{ error::{Error, Result}, - expr::{Expr, Ident}, + expr::{Array, Expr, Function, Ident, JsonPointer, Lambda}, token::LexingError, }; use env::Binding; @@ -22,6 +22,23 @@ use json_pointer::process_json_pointers; use serde_json::Value; use std::ops::Deref; +pub trait ExprVisitor { + fn visit_primitive(&self, prim: &Primitive) -> T; + fn visit_array(&self, arr: &Array) -> T; + fn visit_function(&self, func: &Function) -> T; + fn visit_lambda(&self, func: &Lambda) -> T; + fn visit_json_pointer(&self, func: &JsonPointer) -> T; + fn visit_expr(&self, expr: &Expr) -> T { + match expr { + Expr::Primitive(a) => self.visit_primitive(a), + Expr::Array(a) => self.visit_array(a), + Expr::Function(a) => self.visit_function(a), + Expr::Lambda(a) => self.visit_lambda(a), + Expr::JsonPointer(a) => self.visit_json_pointer(a), + } + } +} + /// Evaluates `deke` expressions. pub struct Executor { env: Env<'static>, @@ -45,34 +62,33 @@ impl Executor { pub fn parse_and_eval(&self, raw_program: &str, context: &Value) -> Result { let processed_program = process_json_pointers(raw_program, context)?; let program = parse(&processed_program)?; - let expr = eval(&self.env, &program)?; + let expr = self.env.visit_expr(&program)?; Ok(expr) } } - -/// Evaluate the `Expr`, returning a boolean. -pub(crate) fn eval(env: &Env, program: &Expr) -> Result { - let output = match program { - Expr::Primitive(primitive) => Ok(Expr::Primitive(primitive.resolve(env)?)), - Expr::Array(_) => Ok(program.clone()), - Expr::Function(name, args) => { - let binding = env - .get(name) - .ok_or_else(|| Error::UnknownFunction(name.deref().to_owned()))?; - - if let Binding::Fn(op) = binding { - op(env, args) - } else { - Err(Error::FoundVarExpectedFunc(name.deref().to_owned())) - } +impl ExprVisitor> for Env<'_> { + fn visit_primitive(&self, prim: &Primitive) -> Result { + Ok(prim.resolve(self)?.into()) + } + fn visit_array(&self, array: &Array) -> Result { + Ok(array.clone().into()) + } + fn visit_function(&self, f: &Function) -> Result { + let binding = self + .get(&f.ident) + .ok_or_else(|| Error::UnknownFunction(f.ident.deref().to_owned()))?; + if let Binding::Fn(op) = binding { + op(self, &f.args) + } else { + Err(Error::FoundVarExpectedFunc(f.ident.deref().to_owned())) } - Expr::Lambda(_, body) => Ok((**body).clone()), - Expr::JsonPointer(_) => unreachable!(), - }; - - log::debug!("input: {program:?}, output: {output:?}"); - - output + } + fn visit_lambda(&self, l: &Lambda) -> Result { + Ok((*l.body).clone()) + } + fn visit_json_pointer(&self, json_pointer: &JsonPointer) -> Result { + Ok(json_pointer.clone().into()) + } } #[cfg(test)] @@ -177,7 +193,7 @@ mod tests { let program = "(eq 3 (count (filter (gt 8.0) [1.0 2.0 10.0 20.0 30.0])))"; let context = Value::Null; let result = Executor::std().parse_and_eval(program, &context).unwrap(); - assert_eq!(result, Expr::Primitive(Primitive::Bool(true))); + assert_eq!(result, Primitive::Bool(true).into()); } #[test] @@ -186,7 +202,7 @@ mod tests { "(eq 3 (count (filter (gt 8.0) (foreach (sub 1.0) [1.0 2.0 10.0 20.0 30.0]))))"; let context = Value::Null; let result = Executor::std().parse_and_eval(program, &context).unwrap(); - assert_eq!(result, Expr::Primitive(Primitive::Bool(true))); + assert_eq!(result, Primitive::Bool(true).into()); } #[test] @@ -196,11 +212,12 @@ mod tests { let result = Executor::std().parse_and_eval(program, &context).unwrap(); assert_eq!( result, - Expr::Array(vec![ + Array::new(vec![ Primitive::Int(0), Primitive::Int(0), Primitive::Int(0) ]) + .into() ); }