Skip to content

Commit

Permalink
Closures: add lsp support + allow null for empty
Browse files Browse the repository at this point in the history
  • Loading branch information
mustafaquraish committed Nov 17, 2024
1 parent a13e2c3 commit d50f59e
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 20 deletions.
4 changes: 2 additions & 2 deletions compiler/ast/nodes.oc
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,15 @@ enum FunctionKind {

// FIXME: Can we do something about `alloca()` in codegen?
// FIXME: Codegen is super hacky, need to clean up and fix:
// FIXME: Have a "ground truth" function for getting names of closure fields/etc. Currently hardcoded in many places
// FIXME: Check to see if any existing fields clash with implicit closure fields, and throw an error/rename
// FIXME: Checking lifetimes: closures can't outlive their parent function
// FIXME: Allow capturing things other than variables
// FIXME: Allow nested closures (in resolve_scoped_identifier) we only ever look up one level
// FIXME: Don't treat global variables/functions as captured variables, they should be looked up normally
// FIXME: Almost ALL the details are hardcoded in codegen. Make type-checker aware of extra implicit fields/params etc.
// FIXME: - Essentially, "unroll" the closures into structs+fields at type-checker level
// FIXME: We do not correctly account for ordering of used types in closure types
// FIXME: Allow passing a normal function pointer to a closure (just wrap it in Closure struct with empty context)
// FIXME: More comprehensive postive / negative tests for closures.
struct Function {
kind: FunctionKind

Expand Down
12 changes: 12 additions & 0 deletions compiler/lsp/finder.oc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def Finder::find_in_call_args(&this, node: &AST, args: &Vector<&Argument>): bool
}

def Finder::find_in_expression(&this, node: &AST): bool {
if not node? return false
match node.type {
IntLiteral | FloatLiteral | BoolLiteral |
StringLiteral | CharLiteral |
Expand Down Expand Up @@ -200,6 +201,7 @@ def Finder::find_in_expression(&this, node: &AST): bool {
return .set_usage(node.resolved_symbol, node)
}
}
CreateClosure => return .find_in_function(node.u.closure)
NSLookup => {
// We actually want to point to the type, not the variable
if .find_in_expression(node.u.lookup.lhs) return true
Expand Down Expand Up @@ -385,6 +387,16 @@ def Finder::find_in_type(&this, type: &Type): bool {
if .find_in_type(ty) return true
}
}
FunctionPtr | Closure => {
let func = type.u.func
for param : func.params.iter() {
if .find_in_var(param, node: null) return true
}
if func.return_type? and .find_in_type(func.return_type) return true
if type.span.contains_loc(.loc) {
return .set_usage(type.sym, node: null) // FIXME: What should be the node here?
}
}
else => {
// FIXME: be more robust
if type.span.contains_loc(.loc) {
Expand Down
3 changes: 2 additions & 1 deletion compiler/parser.oc
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,7 @@ def Parser::parse_atom(&this, end_type: TokenType): &AST {
return AST::new(Error, .token().span)
}
TokenType::Line => {
let start_loc = .token().span.start
let closure_func = .parse_closure()
let node = AST::new(CreateClosure, closure_func.span)
node.u.closure = closure_func
Expand All @@ -795,7 +796,7 @@ def Parser::parse_atom(&this, end_type: TokenType): &AST {
.curr_func? => .curr_func.sym
else => .ns.sym
}
let sym = Symbol::new_with_parent(Closure, .ns, parent_sym, closure_name, closure_func.span)
let sym = Symbol::new_with_parent(Closure, .ns, parent_sym, closure_name, Span(start_loc, start_loc))
sym.u.func = closure_func
closure_func.sym = sym
return node
Expand Down
23 changes: 21 additions & 2 deletions compiler/passes/code_generator.oc
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,20 @@ def CodeGenerator::gen_expression(&this, node: &AST, is_top_level: bool = false)
IsNotNull => {
let expr = node.u.unary.expr
.out += "((bool)"
.gen_expression(expr)

let type = expr.etype
// FIXME: Is there a better place to do this? We generally want to treat
// the closure as the full struct when passing it around, but specifically
// for null-checks we want to only look at the function pointer.
if type? and type.base == Closure {
.gen_expression(expr)
.out += "."
.out += cls::fn_field_name

} else {
.gen_expression(expr)
}

.out += ")"
}
PreIncrement | PreDecrement => {
Expand Down Expand Up @@ -747,7 +760,13 @@ def CodeGenerator::gen_expression(&this, node: &AST, is_top_level: bool = false)
.gen_type(node.u.size_of_type)
.out += "))"
}
Null => .out += "NULL"
Null => match node.etype.base {
// Closures behave like a function pointer in ocen, but are actually a
// struct in the output C. We need special handling here.
// Zero-initialize the struct, including the pointer
Closure => .out <<= f"(({node.etype.sym.out_name()})\{0\})"
else => .out += "NULL"
}
BinaryOp => match node.u.binary.op {
Index => {
let lhs = node.u.binary.lhs
Expand Down
4 changes: 4 additions & 0 deletions compiler/passes/mark_dead_code.oc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def MarkDeadCode::free(&this) {

def MarkDeadCode::mark_sym(&this, sym: &Symbol) {
if not sym? return
sym.is_dead = false
match sym.type {
Function => .mark_function(sym.u.func)
Structure => .mark_struct(sym.u.struc)
Expand All @@ -63,6 +64,9 @@ def MarkDeadCode::mark_function(&this, f: &Function) {
for param : f.params.iter() {
.mark_type(param.type)
}

.mark_sym(f.type.sym)
.mark_type(f.return_type)
}

def MarkDeadCode::mark_type(&this, typ: &Type) {
Expand Down
18 changes: 14 additions & 4 deletions compiler/passes/typechecker.oc
Original file line number Diff line number Diff line change
Expand Up @@ -1367,8 +1367,10 @@ def TypeChecker::check_expression_helper(&this, node: &AST, hint: &Type): &Type
}
ASTType::Null => {
if hint? {
if hint.base == BaseType::Pointer return hint
if hint.base == BaseType::FunctionPtr return hint
match hint.base {
Pointer | FunctionPtr | Closure => return hint
else => {}
}
}

return .get_type_by_name("untyped_ptr", node.span)
Expand Down Expand Up @@ -1429,8 +1431,12 @@ def TypeChecker::check_expression_helper(&this, node: &AST, hint: &Type): &Type
let typ = .check_expression(node.u.unary.expr)
if not typ? return null
typ = typ.unaliased()
if typ.base != BaseType::Pointer {
.error(Error::new(node.span, `Can only use ? on pointer types, got {typ.str()}`))
match typ.base {
Pointer | FunctionPtr | Closure => {}
else => {
.error(Error::new(node.span, `Can only use ? on pointer types, got {typ.str()}`))
return null
}
}
return .get_base_type(BaseType::Bool, node.span)
}
Expand Down Expand Up @@ -1663,6 +1669,8 @@ def TypeChecker::check_expression_helper(&this, node: &AST, hint: &Type): &Type
}
ASTType::CreateClosure => {
let clos = node.u.closure
clos.scope = .o.ns().scope

clos.closure_scope = .o.scope()
clos.closed_vars = Map<str, &Symbol>::new()

Expand Down Expand Up @@ -3010,8 +3018,10 @@ def TypeChecker::run(program: &Program) {

pass.check_namespace(program.global)

pass.o.push_namespace(program.global)
while pass.unchecked_functions.size > 0 {
let func = pass.unchecked_functions.pop() as &Function
pass.check_function(func)
}
pass.o.pop_namespace()
}
8 changes: 4 additions & 4 deletions std/sort.oc
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import std::traits::compare

def sort<T>(data: &T, size: u32) => sort_by<T>(data, size, T::compare)
def sort<T>(data: &T, size: u32) => sort_by<T>(data, size, |a: T, b: T|: i8 => a.compare(b))

def sort_by<T>(data: &T, size: u32, cmp: fn(T,T): i8) {
def sort_by<T>(data: &T, size: u32, cmp: @fn(T,T): i8) {
if size <= 1 {
return
}
Expand Down Expand Up @@ -33,9 +33,9 @@ def sort_by<T>(data: &T, size: u32, cmp: fn(T,T): i8) {
sort_by<T>(data + i, size - i, cmp)
}

def nth_element<T>(data: &T, size: u32, n: u32) => nth_element_by<T>(data, size, n, T::compare)
def nth_element<T>(data: &T, size: u32, n: u32) => nth_element_by<T>(data, size, n, |a: T, b: T|: i8 => a.compare(b))

def nth_element_by<T>(data: &T, size: u32, n: u32, cmp: fn(T,T): i8): T {
def nth_element_by<T>(data: &T, size: u32, n: u32, cmp: @fn(T,T): i8): T {
if size <= 1 return data[0]

let pivot = data[size / 2]
Expand Down
2 changes: 1 addition & 1 deletion tests/bad/question_not_ptr.oc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/// fail: Can only use ? on pointer types
/// fail: Can only use ? on pointer types, got u32
def main() {
let x = 50
Expand Down
10 changes: 6 additions & 4 deletions tests/closure.oc
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
/// out: "Total: 145\n"
def test(callback: @fn(i32), x: i32) {
callback(x)
if callback? {
callback(x)
}
}

// def bar(x: i32): i32 => x
def bar(x: i32): i32 => x

def main() {
let x: i32 = 100
// let cb = |a: i32| => x += bar(a)
let cb = |a: i32| => x += a
let cb = |a: i32| => x += bar(a)

for let i = 0i32; i < 5; i++ {
test(cb, i)
Expand All @@ -23,5 +24,6 @@ def main() {
i
)
}
test(null, 10)
println(`Total: {x}`)
}
3 changes: 1 addition & 2 deletions tests/sorting.oc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def main() {
arr[i] = randu32() % 10000
}

nth_element_by<u32>(arr, 100, 25, u32::compare)
nth_element_by<u32>(arr, 100, 25, |a: u32, b: u32|: i8 => a.compare(b))
let n = arr[25]

for let i = 0; i < 100; i = i + 1 {
Expand All @@ -32,5 +32,4 @@ def main() {
assert arr[i] >= n
}
}

}

0 comments on commit d50f59e

Please sign in to comment.