Skip to content

Commit

Permalink
Simplify built-in join fn by removing size param
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed Jul 3, 2024
1 parent 99e0416 commit 5b9c5be
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ pub enum StmtEnum<T> {
/// Binds an identifier to each value of an array expr, evaluating the body.
ForEachLoop(String, Expr<T>, Vec<Stmt<T>>),
/// Binds an identifier to each joined row of two tables, evaluating the body.
JoinLoop(String, usize, (Expr<T>, Expr<T>), Vec<Stmt<T>>),
JoinLoop(String, (Expr<T>, Expr<T>), Vec<Stmt<T>>),
/// An expression (all expressions are statements, but not all statements expressions).
Expr(Expr<T>),
}
Expand Down
40 changes: 9 additions & 31 deletions src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -720,32 +720,16 @@ impl UntypedStmt {
ast::StmtEnum::ForEachLoop(var, binding, body) => match &binding.inner {
ExprEnum::FnCall(identifier, args) if identifier == "join" => {
let mut errors = vec![];
if args.len() != 3 {
if args.len() != 2 {
let e = TypeErrorEnum::WrongNumberOfArgs {
expected: 3,
expected: 2,
actual: args.len(),
};
return Err(vec![Some(TypeError(e, meta))]);
}
let mut size = 0;
let mut arg_exprs: Vec<Expr<Type>> = vec![];
if let ExprEnum::NumUnsigned(n, UnsignedNumType::Usize) = args[0].inner {
size = n as usize;
arg_exprs.push(Expr {
inner: ExprEnum::NumUnsigned(n, UnsignedNumType::Usize),
meta: args[0].meta,
ty: Type::Unsigned(UnsignedNumType::Usize),
});
} else {
let e = TypeErrorEnum::UsizeNotLiteral;
errors.push(Some(TypeError(e, args[0].meta)));
}
let a = args[1].type_check(top_level_defs, env, fns, defs)?;
let a = args[0].type_check(top_level_defs, env, fns, defs)?;
let (ty_a, meta_a) = match &a.ty {
Type::Array(_, _) | Type::ArrayConst(_, _) => {
arg_exprs.push(a.clone());
(a.ty.clone(), a.meta)
}
Type::Array(_, _) | Type::ArrayConst(_, _) => (a.ty.clone(), a.meta),
ty => {
errors.push(Some(TypeError(
TypeErrorEnum::ExpectedArrayType(ty.clone()),
Expand All @@ -754,12 +738,9 @@ impl UntypedStmt {
(ty.clone(), a.meta)
}
};
let b = args[2].type_check(top_level_defs, env, fns, defs)?;
let b = args[1].type_check(top_level_defs, env, fns, defs)?;
let (ty_b, meta_b) = match &b.ty {
Type::Array(_, _) | Type::ArrayConst(_, _) => {
arg_exprs.push(b.clone());
(b.ty.clone(), b.meta)
}
Type::Array(_, _) | Type::ArrayConst(_, _) => (b.ty.clone(), b.meta),
ty => {
errors.push(Some(TypeError(
TypeErrorEnum::ExpectedArrayType(ty.clone()),
Expand All @@ -775,10 +756,7 @@ impl UntypedStmt {
let elem_ty_b = expect_array_type(&ty_b, meta_b)?;
let tuple_a = expect_tuple_type(&elem_ty_a, meta_a)?;
let tuple_b = expect_tuple_type(&elem_ty_b, meta_b)?;
if tuple_a.len() < size
|| tuple_b.len() < size
|| tuple_a[..size] != tuple_b[..size]
{
if tuple_a.is_empty() || tuple_b.is_empty() || tuple_a[0] != tuple_b[0] {
return Err(vec![Some(TypeError(
TypeErrorEnum::TypeMismatch(elem_ty_a, elem_ty_b),
meta,
Expand All @@ -793,7 +771,7 @@ impl UntypedStmt {
}
env.pop();
Ok(Stmt::new(
StmtEnum::JoinLoop(var.clone(), size, (a, b), body_typed),
StmtEnum::JoinLoop(var.clone(), (a, b), body_typed),
meta,
))
}
Expand All @@ -813,7 +791,7 @@ impl UntypedStmt {
))
}
},
ast::StmtEnum::JoinLoop(_, _, _, _) => {
ast::StmtEnum::JoinLoop(_, _, _) => {
unreachable!("Untyped expressions should never be join loops")
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ impl TypedStmt {
env.pop();
vec![]
}
StmtEnum::JoinLoop(_, _, _, _) => {
StmtEnum::JoinLoop(_, _, _) => {
todo!("compile join loop")
}
}
Expand Down
19 changes: 19 additions & 0 deletions tests/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,22 @@ pub fn main(arr1: [u8; 8], arr2: [u8; 8], choice: bool) -> [u8; 8] {
);
Ok(())
}

#[test]
fn optimize_mapped_arrays() -> Result<(), String> {
let prg = "
pub fn main(arr1: [(u16, u16, u32); 8]) -> [((u16, u16), u32); 8] {
let mut arr2 = [((0u16, 0u16), 0u32); 8];
let mut i = 0usize;
for elem in arr1 {
let (a, b, c) = elem;
arr2[i] = ((a, b), c);
i = i + 1usize;
}
arr2
}";
let compiled = compile(prg).map_err(|e| e.prettify(prg))?;
assert_eq!(compiled.circuit.and_gates(), 0);
assert_eq!(compiled.circuit.gates.len(), 2);
Ok(())
}
2 changes: 1 addition & 1 deletion tests/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2057,7 +2057,7 @@ fn compile_join_fn() -> Result<(), Error> {
let prg = "
pub fn main(rows1: [([u8; 3], u16); 4], rows2: [([u8; 3], u16, u16); 3]) -> u16 {
let mut result = 0u16;
for row in join(1usize, rows1, rows2) {
for row in join(rows1, rows2) {
let ((_, field1), (_, field2, field3)) = row;
result = result + field1 + field2 + field3;
}
Expand Down

0 comments on commit 5b9c5be

Please sign in to comment.