From d50f59e126cd85f36a7a3dcd4884c34e1ac4efbc Mon Sep 17 00:00:00 2001 From: Mustafa Quraish Date: Sat, 16 Nov 2024 19:10:44 -0500 Subject: [PATCH] Closures: add lsp support + allow null for empty --- compiler/ast/nodes.oc | 4 ++-- compiler/lsp/finder.oc | 12 ++++++++++++ compiler/parser.oc | 3 ++- compiler/passes/code_generator.oc | 23 +++++++++++++++++++++-- compiler/passes/mark_dead_code.oc | 4 ++++ compiler/passes/typechecker.oc | 18 ++++++++++++++---- std/sort.oc | 8 ++++---- tests/bad/question_not_ptr.oc | 2 +- tests/closure.oc | 10 ++++++---- tests/sorting.oc | 3 +-- 10 files changed, 67 insertions(+), 20 deletions(-) diff --git a/compiler/ast/nodes.oc b/compiler/ast/nodes.oc index 24ffc63..c927405 100644 --- a/compiler/ast/nodes.oc +++ b/compiler/ast/nodes.oc @@ -187,7 +187,6 @@ enum FunctionKind { // FIXME: Can we do something about `alloca()` in codegen? // FIXME: Codegen is super hacky, need to clean up and fix: -// FIXME: Have a "ground truth" function for getting names of closure fields/etc. Currently hardcoded in many places // FIXME: Check to see if any existing fields clash with implicit closure fields, and throw an error/rename // FIXME: Checking lifetimes: closures can't outlive their parent function // FIXME: Allow capturing things other than variables @@ -195,7 +194,8 @@ enum FunctionKind { // FIXME: Don't treat global variables/functions as captured variables, they should be looked up normally // FIXME: Almost ALL the details are hardcoded in codegen. Make type-checker aware of extra implicit fields/params etc. // FIXME: - Essentially, "unroll" the closures into structs+fields at type-checker level -// FIXME: We do not correctly account for ordering of used types in closure types +// FIXME: Allow passing a normal function pointer to a closure (just wrap it in Closure struct with empty context) +// FIXME: More comprehensive positive / negative tests for closures. 
struct Function { kind: FunctionKind diff --git a/compiler/lsp/finder.oc b/compiler/lsp/finder.oc index db5dbfe..68e0328 100644 --- a/compiler/lsp/finder.oc +++ b/compiler/lsp/finder.oc @@ -173,6 +173,7 @@ def Finder::find_in_call_args(&this, node: &AST, args: &Vector<&Argument>): bool } def Finder::find_in_expression(&this, node: &AST): bool { + if not node? return false match node.type { IntLiteral | FloatLiteral | BoolLiteral | StringLiteral | CharLiteral | @@ -200,6 +201,7 @@ def Finder::find_in_expression(&this, node: &AST): bool { return .set_usage(node.resolved_symbol, node) } } + CreateClosure => return .find_in_function(node.u.closure) NSLookup => { // We actually want to point to the type, not the variable if .find_in_expression(node.u.lookup.lhs) return true @@ -385,6 +387,16 @@ def Finder::find_in_type(&this, type: &Type): bool { if .find_in_type(ty) return true } } + FunctionPtr | Closure => { + let func = type.u.func + for param : func.params.iter() { + if .find_in_var(param, node: null) return true + } + if func.return_type? and .find_in_type(func.return_type) return true + if type.span.contains_loc(.loc) { + return .set_usage(type.sym, node: null) // FIXME: What should be the node here? + } + } else => { // FIXME: be more robust if type.span.contains_loc(.loc) { diff --git a/compiler/parser.oc b/compiler/parser.oc index 52edea4..733dfa6 100644 --- a/compiler/parser.oc +++ b/compiler/parser.oc @@ -784,6 +784,7 @@ def Parser::parse_atom(&this, end_type: TokenType): &AST { return AST::new(Error, .token().span) } TokenType::Line => { + let start_loc = .token().span.start let closure_func = .parse_closure() let node = AST::new(CreateClosure, closure_func.span) node.u.closure = closure_func @@ -795,7 +796,7 @@ def Parser::parse_atom(&this, end_type: TokenType): &AST { .curr_func? 
=> .curr_func.sym else => .ns.sym } - let sym = Symbol::new_with_parent(Closure, .ns, parent_sym, closure_name, closure_func.span) + let sym = Symbol::new_with_parent(Closure, .ns, parent_sym, closure_name, Span(start_loc, start_loc)) sym.u.func = closure_func closure_func.sym = sym return node diff --git a/compiler/passes/code_generator.oc b/compiler/passes/code_generator.oc index 071affb..346b540 100644 --- a/compiler/passes/code_generator.oc +++ b/compiler/passes/code_generator.oc @@ -719,7 +719,20 @@ def CodeGenerator::gen_expression(&this, node: &AST, is_top_level: bool = false) IsNotNull => { let expr = node.u.unary.expr .out += "((bool)" - .gen_expression(expr) + + let type = expr.etype + // FIXME: Is there a better place to do this? We generally want to treat + // the closure as the full struct when passing it around, but specifically + // for null-checks we want to only look at the function pointer. + if type? and type.base == Closure { + .gen_expression(expr) + .out += "." + .out += cls::fn_field_name + + } else { + .gen_expression(expr) + } + .out += ")" } PreIncrement | PreDecrement => { @@ -747,7 +760,13 @@ def CodeGenerator::gen_expression(&this, node: &AST, is_top_level: bool = false) .gen_type(node.u.size_of_type) .out += "))" } - Null => .out += "NULL" + Null => match node.etype.base { + // Closures behave like a function pointer in ocen, but are actually a + // struct in the output C. We need special handling here. + // Zero-initialize the struct, including the pointer + Closure => .out <<= f"(({node.etype.sym.out_name()})\{0\})" + else => .out += "NULL" + } BinaryOp => match node.u.binary.op { Index => { let lhs = node.u.binary.lhs diff --git a/compiler/passes/mark_dead_code.oc b/compiler/passes/mark_dead_code.oc index ea1a444..1a8f0d6 100644 --- a/compiler/passes/mark_dead_code.oc +++ b/compiler/passes/mark_dead_code.oc @@ -37,6 +37,7 @@ def MarkDeadCode::free(&this) { def MarkDeadCode::mark_sym(&this, sym: &Symbol) { if not sym? 
return + sym.is_dead = false match sym.type { Function => .mark_function(sym.u.func) Structure => .mark_struct(sym.u.struc) @@ -63,6 +64,9 @@ def MarkDeadCode::mark_function(&this, f: &Function) { for param : f.params.iter() { .mark_type(param.type) } + + .mark_sym(f.type.sym) + .mark_type(f.return_type) } def MarkDeadCode::mark_type(&this, typ: &Type) { diff --git a/compiler/passes/typechecker.oc b/compiler/passes/typechecker.oc index deab4c1..b25d4ad 100644 --- a/compiler/passes/typechecker.oc +++ b/compiler/passes/typechecker.oc @@ -1367,8 +1367,10 @@ def TypeChecker::check_expression_helper(&this, node: &AST, hint: &Type): &Type } ASTType::Null => { if hint? { - if hint.base == BaseType::Pointer return hint - if hint.base == BaseType::FunctionPtr return hint + match hint.base { + Pointer | FunctionPtr | Closure => return hint + else => {} + } } return .get_type_by_name("untyped_ptr", node.span) @@ -1429,8 +1431,12 @@ def TypeChecker::check_expression_helper(&this, node: &AST, hint: &Type): &Type let typ = .check_expression(node.u.unary.expr) if not typ? return null typ = typ.unaliased() - if typ.base != BaseType::Pointer { - .error(Error::new(node.span, `Can only use ? on pointer types, got {typ.str()}`)) + match typ.base { + Pointer | FunctionPtr | Closure => {} + else => { + .error(Error::new(node.span, `Can only use ? 
on pointer types, got {typ.str()}`)) + return null + } } return .get_base_type(BaseType::Bool, node.span) } @@ -1663,6 +1669,8 @@ def TypeChecker::check_expression_helper(&this, node: &AST, hint: &Type): &Type } ASTType::CreateClosure => { let clos = node.u.closure + clos.scope = .o.ns().scope + clos.closure_scope = .o.scope() clos.closed_vars = Map::new() @@ -3010,8 +3018,10 @@ def TypeChecker::run(program: &Program) { pass.check_namespace(program.global) + pass.o.push_namespace(program.global) while pass.unchecked_functions.size > 0 { let func = pass.unchecked_functions.pop() as &Function pass.check_function(func) } + pass.o.pop_namespace() } diff --git a/std/sort.oc b/std/sort.oc index b4e3fc2..088a2f0 100644 --- a/std/sort.oc +++ b/std/sort.oc @@ -2,9 +2,9 @@ import std::traits::compare -def sort(data: &T, size: u32) => sort_by(data, size, T::compare) +def sort(data: &T, size: u32) => sort_by(data, size, |a: T, b: T|: i8 => a.compare(b)) -def sort_by(data: &T, size: u32, cmp: fn(T,T): i8) { +def sort_by(data: &T, size: u32, cmp: @fn(T,T): i8) { if size <= 1 { return } @@ -33,9 +33,9 @@ def sort_by(data: &T, size: u32, cmp: fn(T,T): i8) { sort_by(data + i, size - i, cmp) } -def nth_element(data: &T, size: u32, n: u32) => nth_element_by(data, size, n, T::compare) +def nth_element(data: &T, size: u32, n: u32) => nth_element_by(data, size, n, |a: T, b: T|: i8 => a.compare(b)) -def nth_element_by(data: &T, size: u32, n: u32, cmp: fn(T,T): i8): T { +def nth_element_by(data: &T, size: u32, n: u32, cmp: @fn(T,T): i8): T { if size <= 1 return data[0] let pivot = data[size / 2] diff --git a/tests/bad/question_not_ptr.oc b/tests/bad/question_not_ptr.oc index 27ac7af..a3f7644 100644 --- a/tests/bad/question_not_ptr.oc +++ b/tests/bad/question_not_ptr.oc @@ -1,4 +1,4 @@ -/// fail: Can only use ? on pointer types +/// fail: Can only use ? 
on pointer types, got u32 def main() { let x = 50 diff --git a/tests/closure.oc b/tests/closure.oc index f386624..49744c4 100644 --- a/tests/closure.oc +++ b/tests/closure.oc @@ -1,15 +1,16 @@ /// out: "Total: 145\n" def test(callback: @fn(i32), x: i32) { - callback(x) + if callback? { + callback(x) + } } -// def bar(x: i32): i32 => x +def bar(x: i32): i32 => x def main() { let x: i32 = 100 - // let cb = |a: i32| => x += bar(a) - let cb = |a: i32| => x += a + let cb = |a: i32| => x += bar(a) for let i = 0i32; i < 5; i++ { test(cb, i) @@ -23,5 +24,6 @@ def main() { i ) } + test(null, 10) println(`Total: {x}`) } \ No newline at end of file diff --git a/tests/sorting.oc b/tests/sorting.oc index b994e12..bc2aa29 100644 --- a/tests/sorting.oc +++ b/tests/sorting.oc @@ -22,7 +22,7 @@ def main() { arr[i] = randu32() % 10000 } - nth_element_by(arr, 100, 25, u32::compare) + nth_element_by(arr, 100, 25, |a: u32, b: u32|: i8 => a.compare(b)) let n = arr[25] for let i = 0; i < 100; i = i + 1 { @@ -32,5 +32,4 @@ def main() { assert arr[i] >= n } } - } \ No newline at end of file