Skip to content

Commit

Permalink
Change AST representation of match cases
Browse files Browse the repository at this point in the history
Previously the match statement node contained a flat list of MatchCase
objects. Each of these had a condition and an optional body.
All consecutive entries, up to and including the first one with a body,
were supposed to be part of the same "branch". For example:
    match foo {
      Foo | Bar | Baz => print(...)
      Qux => 1
    }
would be
    (Foo, null), (Bar, null), (Baz, print(...)), (Qux, 1)
This was a little awkward to traverse. Now, each MatchCase stores
a list of all of its actual conditions (as MatchCond objects) and also
the body, which is a more natural way of storing the AST.
  • Loading branch information
mustafaquraish committed Oct 31, 2024
1 parent 1e347fc commit 824ab62
Show file tree
Hide file tree
Showing 8 changed files with 1,458 additions and 1,564 deletions.
2,514 changes: 1,185 additions & 1,329 deletions bootstrap/stage0.c

Large diffs are not rendered by default.

24 changes: 14 additions & 10 deletions compiler/ast/nodes.oc
Original file line number Diff line number Diff line change
Expand Up @@ -391,25 +391,29 @@ struct FormatString {
exprs: &Vector<&AST>
}

struct MatchCase {
cond: &AST
body: &AST
struct MatchCond {
expr: &AST
cmp_fn: &Function

// For value enums
args: &Vector<&Variable>
}

def MatchCase::new(cond: &AST, body: &AST): &MatchCase {
let _case = mem::alloc<MatchCase>()
_case.cond = cond
_case.body = body
return _case
// Allocates a new MatchCond and fills in its fields.
//
//   cond   - expression for this condition; stored in the `expr` field
//   args   - variables bound by this condition (may be null; see the
//            "For value enums" note on the struct)
//   cmp_fn - comparison function for this condition; defaults to null
//
// Returns the newly allocated MatchCond.
def MatchCond::new(cond: &AST, args: &Vector<&Variable>, cmp_fn: &Function = null): &MatchCond {
let mcond = mem::alloc<MatchCond>()
mcond.expr = cond
mcond.args = args
mcond.cmp_fn = cmp_fn
return mcond
}

// One branch of a match statement: the conditions that select the branch,
// together with the single body they share.
struct MatchCase {
// Every condition belonging to this branch
conds: &Vector<&MatchCond>
// Statement/expression for the branch
body: &AST
}

struct Match {
expr: &AST
cases: &Vector<&MatchCase>
cases: &Vector<MatchCase>
defolt: &AST
is_custom_match: bool

Expand Down
15 changes: 8 additions & 7 deletions compiler/lsp/finder.oc
Original file line number Diff line number Diff line change
Expand Up @@ -243,15 +243,16 @@ def Finder::find_in_expression(&this, node: &AST): bool {
Match => {
let stmt = &node.u.match_stmt
if .find_in_expression(stmt.expr) return true
for let i = 0; i < stmt.cases.size; i += 1 {
let case_ = stmt.cases.at(i) as &MatchCase
if .find_in_expression(case_.cond) return true
if case_.args? {
for arg in case_.args.iter() {
if .find_in_var(arg, node) return true
for _case in stmt.cases.iter() {
for cond in _case.conds.iter() {
if .find_in_expression(cond.expr) return true
if cond.args? {
for arg in cond.args.iter() {
if .find_in_var(arg, node) return true
}
}
}
if case_.body? and .find_in_statement(case_.body) return true
if _case.body? and .find_in_statement(_case.body) return true
}
if stmt.defolt? and .find_in_statement(stmt.defolt) return true
}
Expand Down
64 changes: 35 additions & 29 deletions compiler/parser.oc
Original file line number Diff line number Diff line change
Expand Up @@ -467,14 +467,41 @@ def Parser::parse_format_string(&this): &AST {
return node
}

// Parses the list of conditions that make up a single match case.
//
// Conditions are read until the current token is `end_type` (or EOF), or
// until a condition is not followed by a `Line` token. Each condition is an
// atom, optionally followed by a parenthesized, comma-separated list of
// identifiers that become the condition's bound variables.
//
// Does not consume `end_type` itself; that is left to the caller.
//
// Returns the vector of parsed conditions (empty if the current token is
// already `end_type` or EOF).
def Parser::parse_match_case_conds(&this, end_type: TokenType): &Vector<&MatchCond> {
let conds = Vector<&MatchCond>::new()
while not .token_is_eof_or(end_type) {
// NOTE(review): the `Line` argument is presumably the token at which
// the atom should stop -- confirm against parse_atom.
let expr = .parse_atom(TokenType::Line)

// `args` stays null unless this condition has a parenthesized list.
let args: &Vector<&Variable> = null
if .consume_if(OpenParen) {
args = Vector<&Variable>::new()
while not .token_is_eof_or(CloseParen) {
let name = .consume(Identifier)
// NOTE(review): null is presumably the not-yet-known type of the
// variable -- confirm against Variable::new.
let var = Variable::new(null)
var.sym = Symbol::from_local_variable(name.text, var, name.span)
args.push(var)

// No comma after this identifier means the list is finished.
if not .consume_if(Comma) {
break
}
}
.consume(CloseParen)
}
// cmp_fn is not passed here, so it takes MatchCond::new's default.
conds.push(MatchCond::new(expr, args))
// Another condition for this case follows only after a `Line` token.
if not .consume_if(Line) then break
}

return conds
}

def Parser::parse_match(&this): &AST {
let op = .consume(TokenType::Match)
let expr = .parse_expression(end_type: TokenType::OpenCurly)
let node = AST::new(Match, op.span.join(expr.span))
node.u.match_stmt.expr = expr
node.u.match_stmt.match_span = op.span

let cases = Vector<&MatchCase>::new()
let cases = Vector<MatchCase>::new()
node.u.match_stmt.cases = cases

if not .token_is(TokenType::OpenCurly) {
Expand All @@ -492,36 +519,15 @@ def Parser::parse_match(&this): &AST {
node.u.match_stmt.defolt = .parse_statement()

} else {
let cond = .parse_atom(TokenType::Line)

let args: &Vector<&Variable> = null
if .consume_if(OpenParen) {
args = Vector<&Variable>::new()
while not .token_is_eof_or(CloseParen) {
let name = .consume(Identifier)
let var = Variable::new(null)
var.sym = Symbol::from_local_variable(name.text, var, name.span)
args.push(var)

if not .consume_if(Comma) {
break
}
}
.consume(CloseParen)
let conds = .parse_match_case_conds(TokenType::FatArrow)
if not .consume_if(TokenType::FatArrow) {
.error(Error::new(.token().span, "Expected => after match case"))
}

let body = null as &AST
if not .consume_if(TokenType::Line) {
if not .consume_if(TokenType::FatArrow) {
.error(Error::new(.token().span, "Expected => after match case"))
}
body = .parse_statement()
if not .token_is(TokenType::CloseCurly) {
.consume_newline_or(TokenType::Comma)
}
let body = .parse_statement()
if not .token_is(TokenType::CloseCurly) {
.consume_newline_or(TokenType::Comma)
}
let _case = MatchCase::new(cond, body)
_case.args = args
let _case = MatchCase(conds, body)
cases.push(_case)
}

Expand Down
145 changes: 75 additions & 70 deletions compiler/passes/code_generator.oc
Original file line number Diff line number Diff line change
Expand Up @@ -732,34 +732,32 @@ def CodeGenerator::gen_custom_match(&this, node: &AST) {

let cases = stmt.cases
.gen_indent()
.out += "if ("
for let i = 0; i < cases.size; i += 1 {
let _case = cases.at(i)

if _case.cmp_fn? {
.out += _case.cmp_fn.sym.out_name()
.out += "("
.out += match_var
.out += ", "
.gen_expression(_case.cond)
.out += ")"
for let i = 0; i < cases.size; i++ {
let _case = cases[i]

} else {
.out += f"({match_var} == "
.gen_expression(_case.cond)
.out += ")"
}
.out += "if ("
let first = true
for cond in _case.conds.iter() {
if not first then .out += " || "
first = false
if cond.cmp_fn? {
.out += cond.cmp_fn.sym.out_name()
.out += "("
.out += match_var
.out += ", "
.gen_expression(cond.expr)
.out += ")"

if _case.body? {
.out += ")"
.gen_match_case_body(node, _case.body)
.out += " else "
if i != cases.size - 1 {
.out += "if ("
} else {
.out += f"({match_var} == "
.gen_expression(cond.expr)
.out += ")"
}
} else {
.out += " || "
}
.out += ")"
.gen_match_case_body(node, _case.body)
.out += " else "
}
if stmt.defolt? {
.gen_match_case_body(node, stmt.defolt)
Expand All @@ -783,18 +781,19 @@ def CodeGenerator::gen_match_venom(&this, node: &AST) {
.gen_expression(expr)
.out += ").tag) {\n"

let is_first_variant = true
let uid = .uid++
let branch_num = 0;

.indent += 1
for cas in match_stmt.cases.iter() {
if is_first_variant and cas.args? {
is_first_variant = false
for let i = 0; i < match_stmt.cases.size; i += 1 {
let _case = match_stmt.cases[i]

let has_args = _case.conds.size > 0 and _case.conds[0].args?

if has_args {
let args = _case.conds[0].args
.gen_indent()
.out += "{\n"
.indent += 1
for arg in cas.args.iter() {
for arg in args.iter() {
if not arg.sym.name.eq("_") {
.gen_indent()
.gen_type_and_name(arg.type, arg.sym.out_name())
Expand All @@ -803,54 +802,56 @@ def CodeGenerator::gen_match_venom(&this, node: &AST) {
}
}

let resolved = cas.cond.resolved_symbol
assert resolved? and resolved.type == ValueEnumVariant
let variant = resolved.u.venom_var
for let j = 0; j < _case.conds.size; j++ {
let cond = _case.conds[j]

.gen_indent()
.out += "case "
.out += variant.sym.out_name()
.out += ":\n"
let resolved = cond.expr.resolved_symbol
assert resolved? and resolved.type == ValueEnumVariant
let variant = resolved.u.venom_var

.indent += 1
if cas.args? {
let args = cas.args
for let i = 0; i < args.size; i += 1 {
let arg = args.at(i)
if not arg.sym.name.eq("_") {
.gen_indent()
.out += arg.sym.out_name()
.out += " = ("
.gen_expression(expr)
.out += ")."
.out += variant.sym.out_name()
.out <<= f".f{i}"
.out += ";\n"
.gen_indent()
.out += "case "
.out += variant.sym.out_name()
.out += ":\n"

.indent += 1
if cond.args? {
let args = cond.args
for let i = 0; i < args.size; i += 1 {
let arg = args.at(i)
if not arg.sym.name.eq("_") {
.gen_indent()
.out += arg.sym.out_name()
.out += " = ("
.gen_expression(expr)
.out += ")."
.out += variant.sym.out_name()
.out <<= f".f{i}"
.out += ";\n"
}
}
.gen_indent()
.out += f"goto m_{uid}_{i};\n"
}
.gen_indent()
.out += f"goto m_{uid}_{branch_num};\n"
.indent -= 1
}
.indent -= 1

if cas.body? {

if _case.body? {
.gen_indent()
.out += f"m_{uid}_{branch_num}:\n"
.out += f"m_{uid}_{i}:\n"
.indent += 1
.gen_indent()
.gen_control_body(node, cas.body)
.gen_control_body(node, _case.body)
.out += " break;\n"
.indent -= 1
is_first_variant = true
branch_num += 1

if cas.args? {
.indent -= 1
.gen_indent()
.out += "}\n"
}
}

if has_args {
.indent -= 1
.gen_indent()
.out += "}\n"
}
}
.indent -= 1
.gen_indent()
Expand All @@ -864,7 +865,9 @@ def CodeGenerator::gen_match_bool(&this, node: &AST) {
let true_case = stmt.cases[0]
let false_case = stmt.cases[1]

if not true_case.cond.u.bool_literal {
let true_expr = true_case.conds[0].expr
assert true_expr.type == BoolLiteral, "Expected a boolean literal in gen_match_bool"
if not true_expr.u.bool_literal {
let tmp = true_case
true_case = false_case
false_case = tmp
Expand Down Expand Up @@ -908,10 +911,12 @@ def CodeGenerator::gen_match(&this, node: &AST) {
let cases = stmt.cases
.indent += 1
for _case : cases.iter() {
.gen_indent()
.out += "case "
.gen_expression(_case.cond)
.out += ":"
for cond : _case.conds.iter() {
.gen_indent()
.out += "case "
.gen_expression(cond.expr)
.out += ":"
}
if _case.body? {
.gen_match_case_body(node, _case.body)
.out += " break;\n"
Expand Down
10 changes: 6 additions & 4 deletions compiler/passes/mark_dead_code.oc
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,13 @@ def MarkDeadCode::mark(&this, node: &AST) {
Match => {
.mark(node.u.match_stmt.expr)
for c : node.u.match_stmt.cases.iter() {
.mark(c.cond)
.mark(c.body)
if c.cmp_fn? {
.mark_function(c.cmp_fn)
for cond in c.conds.iter() {
.mark(cond.expr)
if cond.cmp_fn? {
.mark_function(cond.cmp_fn)
}
}
.mark(c.body)
}
.mark(node.u.match_stmt.defolt)
}
Expand Down
Loading

0 comments on commit 824ab62

Please sign in to comment.