Skip to content

Commit

Permalink
Allow destructuring via patterns in for-each/join loops
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed Aug 14, 2024
1 parent 1e52a4e commit 8e23177
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,9 @@ pub enum StmtEnum<T> {
/// Assignment of an index in a (mutable) array.
ArrayAssign(String, Expr<T>, Expr<T>),
/// Binds an identifier to each value of an array expr, evaluating the body.
ForEachLoop(String, Expr<T>, Vec<Stmt<T>>),
ForEachLoop(Pattern<T>, Expr<T>, Vec<Stmt<T>>),
/// Binds an identifier to each joined row of two tables, evaluating the body.
JoinLoop(String, T, (Expr<T>, Expr<T>), Vec<Stmt<T>>),
JoinLoop(Pattern<T>, T, (Expr<T>, Expr<T>), Vec<Stmt<T>>),
/// An expression (all expressions are statements, but not all statements expressions).
Expr(Expr<T>),
}
Expand Down
10 changes: 5 additions & 5 deletions src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ impl UntypedStmt {
))]),
}
}
ast::StmtEnum::ForEachLoop(var, binding, body) => match &binding.inner {
ast::StmtEnum::ForEachLoop(pattern, binding, body) => match &binding.inner {
ExprEnum::FnCall(identifier, args) if identifier == "join" => {
let mut errors = vec![];
if args.len() != 2 {
Expand Down Expand Up @@ -766,13 +766,13 @@ impl UntypedStmt {
let elem_ty = Type::Tuple(vec![elem_ty_a, elem_ty_b]);
let mut body_typed = Vec::with_capacity(body.len());
env.push();
env.let_in_current_scope(var.clone(), (Some(elem_ty), Mutability::Immutable));
let pattern = pattern.type_check(env, fns, defs, Some(elem_ty))?;
for stmt in body {
body_typed.push(stmt.type_check(top_level_defs, env, fns, defs)?);
}
env.pop();
Ok(Stmt::new(
StmtEnum::JoinLoop(var.clone(), join_ty, (a, b), body_typed),
StmtEnum::JoinLoop(pattern.clone(), join_ty, (a, b), body_typed),
meta,
))
}
Expand All @@ -781,13 +781,13 @@ impl UntypedStmt {
let elem_ty = expect_array_type(&binding.ty, meta)?;
let mut body_typed = Vec::with_capacity(body.len());
env.push();
env.let_in_current_scope(var.clone(), (Some(elem_ty), Mutability::Immutable));
let pattern = pattern.type_check(env, fns, defs, Some(elem_ty))?;
for stmt in body {
body_typed.push(stmt.type_check(top_level_defs, env, fns, defs)?);
}
env.pop();
Ok(Stmt::new(
StmtEnum::ForEachLoop(var.clone(), binding, body_typed),
StmtEnum::ForEachLoop(pattern, binding, body_typed),
meta,
))
}
Expand Down
8 changes: 4 additions & 4 deletions src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ impl TypedStmt {
env.assign_mut(identifier.clone(), array);
vec![]
}
StmtEnum::ForEachLoop(var, array, body) => {
StmtEnum::ForEachLoop(pattern, array, body) => {
let elem_in_bits = match &array.ty {
Type::Array(elem_ty, _) | Type::ArrayConst(elem_ty, _) => {
elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes())
Expand All @@ -427,7 +427,7 @@ impl TypedStmt {
let mut i = 0;
while i < array.len() {
let binding = &array[i..i + elem_in_bits];
env.let_in_current_scope(var.clone(), binding.to_vec());
pattern.compile(binding, prg, env, circuit);

for stmt in body {
stmt.compile(prg, env, circuit);
Expand All @@ -437,7 +437,7 @@ impl TypedStmt {
env.pop();
vec![]
}
StmtEnum::JoinLoop(var, join_ty, (a, b), body) => {
StmtEnum::JoinLoop(pattern, join_ty, (a, b), body) => {
let (elem_bits_a, num_elems_a) = match &a.ty {
Type::Array(elem_ty, size) => (
elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()),
Expand Down Expand Up @@ -524,7 +524,7 @@ impl TypedStmt {

let mut env_if_join = env.clone();
env_if_join.push();
env_if_join.let_in_current_scope(var.clone(), binding.to_vec());
pattern.compile(&binding, prg, &mut env_if_join, circuit);

for stmt in body {
stmt.compile(prg, &mut env_if_join, circuit);
Expand Down
6 changes: 3 additions & 3 deletions src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,8 @@ impl Parser {
}
}
} else if let Some(meta) = self.next_matches(&TokenEnum::KeywordFor) {
// for <var> in <binding> { <body> }
let (var, _) = self.expect_identifier()?;
// for <pattern> in <binding> { <body> }
let pattern = self.parse_pattern()?;
self.expect(&TokenEnum::KeywordIn)?;
self.struct_literals_allowed = false;
let binding = self.parse_expr()?;
Expand All @@ -422,7 +422,7 @@ impl Parser {
let meta_end = self.expect(&TokenEnum::RightBrace)?;
let meta = join_meta(meta, meta_end);
return Ok(Stmt::new(
StmtEnum::ForEachLoop(var, binding, loop_body),
StmtEnum::ForEachLoop(pattern, binding, loop_body),
meta,
));
} else {
Expand Down
75 changes: 75 additions & 0 deletions tests/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2197,3 +2197,78 @@ pub fn main(rows1: [(u8, u16); {a}], rows2: [(u8, u16, u16); {b}]) -> u16 {{
}
Ok(())
}

#[test]
fn compile_join_loop_destructuring() -> Result<(), Error> {
let prg = "
pub fn main(rows1: [(u8, u16); 3], rows2: [(u8, u16); 3]) -> u16 {
let mut result = 0u16;
for ((_, a), (_, b)) in join(rows1, rows2) {
result = result + a + b;
}
result
}
";
let compiled = compile(prg).map_err(|e| pretty_print(e, prg))?;
let mut eval = compiled.evaluator();
eval.set_literal(Literal::Array(vec![
Literal::Tuple(vec![
Literal::NumUnsigned(1, UnsignedNumType::U8),
Literal::NumUnsigned(2, UnsignedNumType::U16),
]),
Literal::Tuple(vec![
Literal::NumUnsigned(2, UnsignedNumType::U8),
Literal::NumUnsigned(4, UnsignedNumType::U16),
]),
Literal::Tuple(vec![
Literal::NumUnsigned(3, UnsignedNumType::U8),
Literal::NumUnsigned(6, UnsignedNumType::U16),
]),
]))
.unwrap();
eval.set_literal(Literal::Array(vec![
Literal::Tuple(vec![
Literal::NumUnsigned(1, UnsignedNumType::U8),
Literal::NumUnsigned(2, UnsignedNumType::U16),
]),
Literal::Tuple(vec![
Literal::NumUnsigned(2, UnsignedNumType::U8),
Literal::NumUnsigned(4, UnsignedNumType::U16),
]),
Literal::Tuple(vec![
Literal::NumUnsigned(4, UnsignedNumType::U8),
Literal::NumUnsigned(8, UnsignedNumType::U16),
]),
]))
.unwrap();
let output = eval.run().map_err(|e| pretty_print(e, prg))?;
assert_eq!(
u16::try_from(output).map_err(|e| pretty_print(e, prg))?,
2 + 2 + 4 + 4
);
Ok(())
}

#[test]
fn compile_for_loop_destructuring() -> Result<(), Error> {
let prg = "
pub fn main(_x: i32) -> i32 {
let mut sum = 0i32;
for (a, b) in [(2i32, 4i32), (6i32, 8i32)] {
sum = sum + a + b;
}
sum
}
";
let compiled = compile(prg).map_err(|e| pretty_print(e, prg))?;
for x in 0..110 {
let mut eval = compiled.evaluator();
eval.set_i32(x);
let output = eval.run().map_err(|e| pretty_print(e, prg))?;
assert_eq!(
i32::try_from(output).map_err(|e| pretty_print(e, prg))?,
2 + 4 + 6 + 8,
);
}
Ok(())
}

0 comments on commit 8e23177

Please sign in to comment.