From 8e23177c79b13531d26ceec7feab4ae9883d80e9 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 14 Aug 2024 10:46:37 +0200 Subject: [PATCH] Allow destructuring via patterns in for-each/join loops --- src/ast.rs | 4 +-- src/check.rs | 10 +++---- src/compile.rs | 8 +++--- src/parse.rs | 6 ++-- tests/compile.rs | 75 ++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 89 insertions(+), 14 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 9b178a9..e7e4195 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -277,9 +277,9 @@ pub enum StmtEnum { /// Assignment of an index in a (mutable) array. ArrayAssign(String, Expr, Expr), /// Binds an identifier to each value of an array expr, evaluating the body. - ForEachLoop(String, Expr, Vec>), + ForEachLoop(Pattern, Expr, Vec>), /// Binds an identifier to each joined row of two tables, evaluating the body. - JoinLoop(String, T, (Expr, Expr), Vec>), + JoinLoop(Pattern, T, (Expr, Expr), Vec>), /// An expression (all expressions are statements, but not all statements expressions). Expr(Expr), } diff --git a/src/check.rs b/src/check.rs index 0bee093..d5a8319 100644 --- a/src/check.rs +++ b/src/check.rs @@ -717,7 +717,7 @@ impl UntypedStmt { ))]), } } - ast::StmtEnum::ForEachLoop(var, binding, body) => match &binding.inner { + ast::StmtEnum::ForEachLoop(pattern, binding, body) => match &binding.inner { ExprEnum::FnCall(identifier, args) if identifier == "join" => { let mut errors = vec![]; if args.len() != 2 { @@ -766,13 +766,13 @@ impl UntypedStmt { let elem_ty = Type::Tuple(vec![elem_ty_a, elem_ty_b]); let mut body_typed = Vec::with_capacity(body.len()); env.push(); - env.let_in_current_scope(var.clone(), (Some(elem_ty), Mutability::Immutable)); + let pattern = pattern.type_check(env, fns, defs, Some(elem_ty))?; for stmt in body { body_typed.push(stmt.type_check(top_level_defs, env, fns, defs)?); } env.pop(); Ok(Stmt::new( - StmtEnum::JoinLoop(var.clone(), join_ty, (a, b), body_typed), + StmtEnum::JoinLoop(pattern.clone(), join_ty, (a, b), body_typed), meta, )) } @@ -781,13 +781,13 @@ impl UntypedStmt { let elem_ty = expect_array_type(&binding.ty, meta)?; let mut body_typed = Vec::with_capacity(body.len()); env.push(); - env.let_in_current_scope(var.clone(), (Some(elem_ty), Mutability::Immutable)); + let pattern = pattern.type_check(env, fns, defs, Some(elem_ty))?; for stmt in body { body_typed.push(stmt.type_check(top_level_defs, env, fns, defs)?); } env.pop(); Ok(Stmt::new( - StmtEnum::ForEachLoop(var.clone(), binding, body_typed), + StmtEnum::ForEachLoop(pattern, binding, body_typed), meta, )) } diff --git a/src/compile.rs b/src/compile.rs index 40534ee..8ab0e3c 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -414,7 +414,7 @@ impl TypedStmt { env.assign_mut(identifier.clone(), array); vec![] } - StmtEnum::ForEachLoop(var, array, body) => { + StmtEnum::ForEachLoop(pattern, array, body) => { let elem_in_bits = match &array.ty { Type::Array(elem_ty, _) | Type::ArrayConst(elem_ty, _) => { elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()) @@ -427,7 +427,7 @@ impl TypedStmt { let mut i = 0; while i < array.len() { let binding = &array[i..i + elem_in_bits]; - env.let_in_current_scope(var.clone(), binding.to_vec()); + pattern.compile(binding, prg, env, circuit); for stmt in body { stmt.compile(prg, env, circuit); @@ -437,7 +437,7 @@ impl TypedStmt { env.pop(); vec![] } - StmtEnum::JoinLoop(var, join_ty, (a, b), body) => { + StmtEnum::JoinLoop(pattern, join_ty, (a, b), body) => { let (elem_bits_a, num_elems_a) = match &a.ty { Type::Array(elem_ty, size) => ( elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()), @@ -524,7 +524,7 @@ impl TypedStmt { let mut env_if_join = env.clone(); env_if_join.push(); - env_if_join.let_in_current_scope(var.clone(), binding.to_vec()); + pattern.compile(&binding, prg, &mut env_if_join, circuit); for stmt in body { stmt.compile(prg, &mut env_if_join, circuit); diff --git a/src/parse.rs b/src/parse.rs index be23571..ff2dd57 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -411,8 +411,8 @@ impl Parser { } } } else if let Some(meta) = self.next_matches(&TokenEnum::KeywordFor) { - // for in { } - let (var, _) = self.expect_identifier()?; + // for in { } + let pattern = self.parse_pattern()?; self.expect(&TokenEnum::KeywordIn)?; self.struct_literals_allowed = false; let binding = self.parse_expr()?; @@ -422,7 +422,7 @@ impl Parser { let meta_end = self.expect(&TokenEnum::RightBrace)?; let meta = join_meta(meta, meta_end); return Ok(Stmt::new( - StmtEnum::ForEachLoop(var, binding, loop_body), + StmtEnum::ForEachLoop(pattern, binding, loop_body), meta, )); } else { diff --git a/tests/compile.rs b/tests/compile.rs index 1fe68b1..5c05f0e 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -2197,3 +2197,78 @@ pub fn main(rows1: [(u8, u16); {a}], rows2: [(u8, u16, u16); {b}]) -> u16 {{ } Ok(()) } + +#[test] +fn compile_join_loop_destructuring() -> Result<(), Error> { + let prg = " +pub fn main(rows1: [(u8, u16); 3], rows2: [(u8, u16); 3]) -> u16 { + let mut result = 0u16; + for ((_, a), (_, b)) in join(rows1, rows2) { + result = result + a + b; + } + result +} +"; + let compiled = compile(prg).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.set_literal(Literal::Array(vec![ + Literal::Tuple(vec![ + Literal::NumUnsigned(1, UnsignedNumType::U8), + Literal::NumUnsigned(2, UnsignedNumType::U16), + ]), + Literal::Tuple(vec![ + Literal::NumUnsigned(2, UnsignedNumType::U8), + Literal::NumUnsigned(4, UnsignedNumType::U16), + ]), + Literal::Tuple(vec![ + Literal::NumUnsigned(3, UnsignedNumType::U8), + Literal::NumUnsigned(6, UnsignedNumType::U16), + ]), + ])) + .unwrap(); + eval.set_literal(Literal::Array(vec![ + Literal::Tuple(vec![ + Literal::NumUnsigned(1, UnsignedNumType::U8), + Literal::NumUnsigned(2, UnsignedNumType::U16), + ]), + Literal::Tuple(vec![ + Literal::NumUnsigned(2, UnsignedNumType::U8), + Literal::NumUnsigned(4, UnsignedNumType::U16), + ]), + Literal::Tuple(vec![ + Literal::NumUnsigned(4, UnsignedNumType::U8), + Literal::NumUnsigned(8, UnsignedNumType::U16), + ]), + ])) + .unwrap(); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + u16::try_from(output).map_err(|e| pretty_print(e, prg))?, + 2 + 2 + 4 + 4 + ); + Ok(()) +} + +#[test] +fn compile_for_loop_destructuring() -> Result<(), Error> { + let prg = " +pub fn main(_x: i32) -> i32 { + let mut sum = 0i32; + for (a, b) in [(2i32, 4i32), (6i32, 8i32)] { + sum = sum + a + b; + } + sum +} +"; + let compiled = compile(prg).map_err(|e| pretty_print(e, prg))?; + for x in 0..110 { + let mut eval = compiled.evaluator(); + eval.set_i32(x); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + i32::try_from(output).map_err(|e| pretty_print(e, prg))?, + 2 + 4 + 6 + 8, + ); + } + Ok(()) +}