From dd08217ca5450325b73ef630985835c79964a339 Mon Sep 17 00:00:00 2001 From: microproofs Date: Sat, 11 Jan 2025 17:04:26 +0700 Subject: [PATCH] New optimization to split independent lam function applications to enable case constr to optimize further --- crates/aiken-lang/src/gen_uplc.rs | 10 +- ..._project__export__tests__basic_export.snap | 5 +- ...oject__export__tests__recursive_types.snap | 4 +- crates/uplc/src/optimize.rs | 2 +- crates/uplc/src/optimize/shrinker.rs | 173 ++++++++++++++++-- 5 files changed, 167 insertions(+), 27 deletions(-) diff --git a/crates/aiken-lang/src/gen_uplc.rs b/crates/aiken-lang/src/gen_uplc.rs index 2c4c27500..160ea970f 100644 --- a/crates/aiken-lang/src/gen_uplc.rs +++ b/crates/aiken-lang/src/gen_uplc.rs @@ -3712,7 +3712,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.clean_up(false).try_into().unwrap(); + program.clean_up_no_inlines().try_into().unwrap(); Some( eval_program @@ -3822,7 +3822,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.clean_up(false).try_into().unwrap(); + program.clean_up_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program .eval(ExBudget::default()) @@ -4364,7 +4364,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.clean_up(false).try_into().unwrap(); + program.clean_up_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program .eval(ExBudget::default()) @@ -4389,7 +4389,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.clean_up(false).try_into().unwrap(); + program.clean_up_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program .eval(ExBudget::default()) @@ -4802,7 +4802,7 @@ impl<'a> CodeGenerator<'a> { interner.program(&mut program); let eval_program: Program = - program.clean_up(false).try_into().unwrap(); + program.clean_up_no_inlines().try_into().unwrap(); let evaluated_term: Term = eval_program .eval(ExBudget::default()) diff --git a/crates/aiken-project/src/snapshots/aiken_project__export__tests__basic_export.snap b/crates/aiken-project/src/snapshots/aiken_project__export__tests__basic_export.snap index ed1fd9f89..46632a51d 100644 --- a/crates/aiken-project/src/snapshots/aiken_project__export__tests__basic_export.snap +++ b/crates/aiken-project/src/snapshots/aiken_project__export__tests__basic_export.snap @@ -1,7 +1,6 @@ --- source: crates/aiken-project/src/export.rs description: "Code:\n\npub fn add(a: Int, b: Int) -> Int {\n a + b\n}\n" -snapshot_kind: text --- { "name": "test_module.add", @@ -25,8 +24,8 @@ snapshot_kind: text "$ref": "#/definitions/Int" } }, - "compiledCode": "500101002322337000046eb4004dd68009", - "hash": "b8374597a772cef80d891b7f6a03588e10cc19b780251228ba4ce9c6", + "compiledCode": "500101002232337000026eb4008dd68011", + "hash": "e5951afb3263ef11acc0b4c88cd5f5b30b8621ce63fe024b3ea2bec8", "definitions": { "Int": { "dataType": "integer" diff --git a/crates/aiken-project/src/snapshots/aiken_project__export__tests__recursive_types.snap b/crates/aiken-project/src/snapshots/aiken_project__export__tests__recursive_types.snap index 440a96c30..8c4d40af9 100644 --- a/crates/aiken-project/src/snapshots/aiken_project__export__tests__recursive_types.snap +++ b/crates/aiken-project/src/snapshots/aiken_project__export__tests__recursive_types.snap @@ -24,8 +24,8 @@ description: "Code:\n\npub type Foo {\n Empty\n Bar(a, Foo)\n}\n\npub fn "$ref": "#/definitions/Int" } }, - "compiledCode": "5901870101009800aba2aba1aba0aab9eaab9dab9a488888888c8c8c8c8c966002600860126ea8006264b30013005300a375400314800226466e00dd698070009980226103d8798000300e300f001300b37540028048c030c034016264b30013370e900118051baa0018991919b80337006eb4c03c008dd6980780099802980798080011807980800098061baa002300b3754005132337006eb4c038004cc010c038c03c00530103d8798000300b37540048048c030c0340150081805802180080091119192cc004c018c02cdd5000c4c966002600e60186ea80062900044c8cdc01bad30100019800803d300103d879800098081808800a00e300d37540028058c038c03c00a264b30013370e900118061baa0018991919b80337006eb4c044008dd69808800cc00402260226024005301130120014020601c6ea8008c034dd500144c8cdc01bad30100019800803cc040c044006980103d8798000401c601a6ea800900b180718078012014300d0013300b0023300b0014bd701b8748000cc018008cc0180052f5c01", - "hash": "247535960781372d3b2097595ebd748bd61be7c8f2f264e460e095b3", + "compiledCode": "590186010100229800aba2aba1aba0aab9eaab9dab9a9b874800122222223322332259800980298039baa0018992cc004c018c020dd5000c5200089919b80375a60180026600898103d8798000300c300d001300937540028038c028c02c012264b30013370e900118041baa0018999119b80337006eb4c034008dd6980680099802980698070011806980700098049baa00230093754003132337006eb4c030004cc010c030c03400530103d8798000300937540048038c028c02c01100618008009804001198028049980280425eb80888c8c966002600c60106ea8006264b300130073009375400314800226466e00dd69806800cc00401e98103d879800098069807000a00e300a37540028040c02cc03000a264b30013370e900118049baa0018999119b80337006eb4c038008dd69807000cc004022601c601e005300e300f001402060146ea8008c028dd5000c4c8cdc01bad300d0019800803cc034c038006980103d8798000401c60146ea800900818059806001200e300a00133008002330080014bd701", + "hash": "dc9b9c2bbcfb1cb422534ed1c4d04f2e2b9b57a0a498175d055f83e8", "definitions": { "Int": { "dataType": "integer" diff --git a/crates/uplc/src/optimize.rs b/crates/uplc/src/optimize.rs index 0b171217e..bbc0a1c5e 100644 --- a/crates/uplc/src/optimize.rs +++ b/crates/uplc/src/optimize.rs @@ -38,5 +38,5 @@ pub fn aiken_optimize_and_intern(program: Program) -> Program { } } - prog.clean_up(true) + prog.clean_up_no_inlines().afterwards() } diff --git a/crates/uplc/src/optimize/shrinker.rs b/crates/uplc/src/optimize/shrinker.rs index 573f24975..40479531a 100644 --- a/crates/uplc/src/optimize/shrinker.rs +++ b/crates/uplc/src/optimize/shrinker.rs @@ -1409,6 +1409,135 @@ impl Term { } } + // The ultimate function when used in conjunction with case_constr_apply + // This splits [lam fun_name [lam fun_name2 rest ..] ..] into + // [[lam fun_name lam fun_name2 rest ..]..] thus + // allowing for some crazy gains from cast_constr_apply_reducer + fn split_body_lambda(&mut self) { + let mut arg_stack = vec![]; + let mut current_term = &mut std::mem::replace(self, Term::Error.force()); + let mut unsat_lams = vec![]; + + let mut function_groups: Vec, Term)>> = vec![vec![]]; + + loop { + match current_term { + Term::Apply { function, argument } => { + current_term = Rc::make_mut(function); + + let arg = Rc::make_mut(argument); + + arg.split_body_lambda(); + + arg_stack.push(std::mem::replace(arg, Term::Error.force())); + } + Term::Lambda { + parameter_name, + body, + } => { + current_term = Rc::make_mut(body); + + if let Some(arg) = arg_stack.pop() { + let names = arg.get_var_names(); + + let func = (parameter_name.clone(), arg); + + if let Some((position, _)) = + function_groups.iter().enumerate().rfind(|named_functions| { + named_functions + .1 + .iter() + .any(|(name, _)| names.contains(name)) + }) + { + let insert_position = position + 1; + if insert_position == function_groups.len() { + function_groups.push(vec![func]); + } else { + function_groups[insert_position].push(func); + } + } else { + function_groups[0].push(func); + } + } else { + unsat_lams.push(parameter_name.clone()); + } + } + Term::Delay(term) | Term::Force(term) => { + Rc::make_mut(term).split_body_lambda(); + break; + } + Term::Case { .. } => todo!(), + Term::Constr { .. } => todo!(), + _ => break, + } + } + let term_to_build_on = std::mem::replace(current_term, Term::Error.force()); + + // Replace args that weren't consumed + let term = arg_stack + .into_iter() + .rfold(term_to_build_on, |term, arg| term.apply(arg)); + + let term = function_groups.into_iter().rfold(term, |term, group| { + let term = group.iter().rfold(term, |term, (name, _)| Term::Lambda { + parameter_name: name.clone(), + body: term.into(), + }); + + group + .into_iter() + .fold(term, |term, (_, arg)| term.apply(arg)) + }); + + let term = unsat_lams + .into_iter() + .rfold(term, |term, name| Term::Lambda { + parameter_name: name.clone(), + body: term.into(), + }); + + *self = term; + } + + fn get_var_names(&self) -> Vec> { + let mut names = vec![]; + + let mut term = self; + + loop { + match term { + Term::Apply { function, argument } => { + let arg_names = argument.get_var_names(); + + names.extend(arg_names); + + term = function; + } + Term::Var(name) => { + names.push(name.clone()); + break; + } + Term::Delay(t) => { + term = t; + } + Term::Lambda { body, .. } => { + term = body; + } + Term::Constant(_) | Term::Error | Term::Builtin(_) => { + break; + } + Term::Force(t) => { + term = t; + } + Term::Constr { .. } => todo!(), + Term::Case { .. } => todo!(), + } + } + + names + } + // IMPORTANT: RUNS ONE TIME AND ONLY ON THE LAST PASS fn case_constr_apply_reducer( &mut self, @@ -2079,14 +2208,14 @@ impl Program { } // This runs the optimizations that are only done a single time pub fn run_once_pass(self) -> Self { - let program = self + // First pass is necessary to ensure fst_pair and snd_pair are inlined before + // builtin_force_reducer is run + let (program, context) = self .traverse_uplc_with(false, &mut |id, term, _arg_stack, scope, context| { term.inline_constr_ops(id, vec![], scope, context); }) - .0; - - let (program, context) = - program.traverse_uplc_with(false, &mut |id, term, arg_stack, scope, context| { + .0 + .traverse_uplc_with(false, &mut |id, term, arg_stack, scope, context| { term.bls381_compressor(id, vec![], scope, context); term.builtin_force_reducer(id, arg_stack, scope, context); term.remove_inlined_ids(id, vec![], scope, context); @@ -2193,20 +2322,26 @@ impl Program { program } - pub fn clean_up(self, case: bool) -> Self { - let (mut program, context) = self - .traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| { - term.remove_no_inlines(id, vec![], scope, context); - }) - .0 - .traverse_uplc_with(true, &mut |id, term, arg_stack, scope, context| { - term.write_bits_convert_arg(id, arg_stack, scope, context); + pub fn clean_up_no_inlines(self) -> Self { + self.traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| { + term.remove_no_inlines(id, vec![], scope, context); + }) + .0 + } - if case { - term.case_constr_apply_reducer(id, vec![], scope, context); - } + pub fn afterwards(self) -> Self { + let (mut program, context) = + self.traverse_uplc_with(true, &mut |id, term, arg_stack, scope, context| { + term.write_bits_convert_arg(id, arg_stack, scope, context); }); + program = program + .split_body_lambda_reducer() + .traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| { + term.case_constr_apply_reducer(id, vec![], scope, context); + }) + .0; + if context.write_bits_convert { program.term = program.term.data_list_to_integer_list(); } @@ -2451,6 +2586,12 @@ impl Program { step_b } + + pub fn split_body_lambda_reducer(mut self) -> Self { + self.term.split_body_lambda(); + + self + } } fn id_vec_function_to_var(func_name: &str, id_vec: &[usize]) -> String {