From 5b9c5be48b47b1e11291b8a45fba41deb2a7a9d7 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 3 Jul 2024 18:45:44 +0200 Subject: [PATCH] Simplify built-in join fn by removing size param --- src/ast.rs | 2 +- src/check.rs | 40 +++++++++------------------------------- src/compile.rs | 2 +- tests/circuit.rs | 19 +++++++++++++++++++ tests/compile.rs | 2 +- 5 files changed, 31 insertions(+), 34 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 0bb7983..447600f 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -279,7 +279,7 @@ pub enum StmtEnum { /// Binds an identifier to each value of an array expr, evaluating the body. ForEachLoop(String, Expr, Vec>), /// Binds an identifier to each joined row of two tables, evaluating the body. - JoinLoop(String, usize, (Expr, Expr), Vec>), + JoinLoop(String, (Expr, Expr), Vec>), /// An expression (all expressions are statements, but not all statements expressions). Expr(Expr), } diff --git a/src/check.rs b/src/check.rs index e626e00..5d17d41 100644 --- a/src/check.rs +++ b/src/check.rs @@ -720,32 +720,16 @@ impl UntypedStmt { ast::StmtEnum::ForEachLoop(var, binding, body) => match &binding.inner { ExprEnum::FnCall(identifier, args) if identifier == "join" => { let mut errors = vec![]; - if args.len() != 3 { + if args.len() != 2 { let e = TypeErrorEnum::WrongNumberOfArgs { - expected: 3, + expected: 2, actual: args.len(), }; return Err(vec![Some(TypeError(e, meta))]); } - let mut size = 0; - let mut arg_exprs: Vec> = vec![]; - if let ExprEnum::NumUnsigned(n, UnsignedNumType::Usize) = args[0].inner { - size = n as usize; - arg_exprs.push(Expr { - inner: ExprEnum::NumUnsigned(n, UnsignedNumType::Usize), - meta: args[0].meta, - ty: Type::Unsigned(UnsignedNumType::Usize), - }); - } else { - let e = TypeErrorEnum::UsizeNotLiteral; - errors.push(Some(TypeError(e, args[0].meta))); - } - let a = args[1].type_check(top_level_defs, env, fns, defs)?; + let a = args[0].type_check(top_level_defs, env, fns, defs)?; let (ty_a, meta_a) = match &a.ty { - Type::Array(_, _) | Type::ArrayConst(_, _) => { - arg_exprs.push(a.clone()); - (a.ty.clone(), a.meta) - } + Type::Array(_, _) | Type::ArrayConst(_, _) => (a.ty.clone(), a.meta), ty => { errors.push(Some(TypeError( TypeErrorEnum::ExpectedArrayType(ty.clone()), @@ -754,12 +738,9 @@ impl UntypedStmt { (ty.clone(), a.meta) } }; - let b = args[2].type_check(top_level_defs, env, fns, defs)?; + let b = args[1].type_check(top_level_defs, env, fns, defs)?; let (ty_b, meta_b) = match &b.ty { - Type::Array(_, _) | Type::ArrayConst(_, _) => { - arg_exprs.push(b.clone()); - (b.ty.clone(), b.meta) - } + Type::Array(_, _) | Type::ArrayConst(_, _) => (b.ty.clone(), b.meta), ty => { errors.push(Some(TypeError( TypeErrorEnum::ExpectedArrayType(ty.clone()), @@ -775,10 +756,7 @@ impl UntypedStmt { let elem_ty_b = expect_array_type(&ty_b, meta_b)?; let tuple_a = expect_tuple_type(&elem_ty_a, meta_a)?; let tuple_b = expect_tuple_type(&elem_ty_b, meta_b)?; - if tuple_a.len() < size - || tuple_b.len() < size - || tuple_a[..size] != tuple_b[..size] - { + if tuple_a.is_empty() || tuple_b.is_empty() || tuple_a[0] != tuple_b[0] { return Err(vec![Some(TypeError( TypeErrorEnum::TypeMismatch(elem_ty_a, elem_ty_b), meta, @@ -793,7 +771,7 @@ impl UntypedStmt { } env.pop(); Ok(Stmt::new( - StmtEnum::JoinLoop(var.clone(), size, (a, b), body_typed), + StmtEnum::JoinLoop(var.clone(), (a, b), body_typed), meta, )) } @@ -813,7 +791,7 @@ impl UntypedStmt { )) } }, - ast::StmtEnum::JoinLoop(_, _, _, _) => { + ast::StmtEnum::JoinLoop(_, _, _) => { unreachable!("Untyped expressions should never be join loops") } } diff --git a/src/compile.rs b/src/compile.rs index 23787d5..d04fb73 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -437,7 +437,7 @@ impl TypedStmt { env.pop(); vec![] } - StmtEnum::JoinLoop(_, _, _, _) => { + StmtEnum::JoinLoop(_, _, _) => { todo!("compile join loop") } } diff --git a/tests/circuit.rs b/tests/circuit.rs index 3960d12..bda189d 100644 --- a/tests/circuit.rs +++ b/tests/circuit.rs @@ -150,3 +150,22 @@ pub fn main(arr1: [u8; 8], arr2: [u8; 8], choice: bool) -> [u8; 8] { ); Ok(()) } + +#[test] +fn optimize_mapped_arrays() -> Result<(), String> { + let prg = " +pub fn main(arr1: [(u16, u16, u32); 8]) -> [((u16, u16), u32); 8] { + let mut arr2 = [((0u16, 0u16), 0u32); 8]; + let mut i = 0usize; + for elem in arr1 { + let (a, b, c) = elem; + arr2[i] = ((a, b), c); + i = i + 1usize; + } + arr2 +}"; + let compiled = compile(prg).map_err(|e| e.prettify(prg))?; + assert_eq!(compiled.circuit.and_gates(), 0); + assert_eq!(compiled.circuit.gates.len(), 2); + Ok(()) +} diff --git a/tests/compile.rs b/tests/compile.rs index 2e41aee..e820a44 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -2057,7 +2057,7 @@ fn compile_join_fn() -> Result<(), Error> { let prg = " pub fn main(rows1: [([u8; 3], u16); 4], rows2: [([u8; 3], u16, u16); 3]) -> u16 { let mut result = 0u16; - for row in join(1usize, rows1, rows2) { + for row in join(rows1, rows2) { let ((_, field1), (_, field2, field3)) = row; result = result + field1 + field2 + field3; }