Skip to content

Commit

Permalink
feat: functions (without args)
Browse files Browse the repository at this point in the history
  • Loading branch information
manosriram committed Jul 13, 2024
1 parent f7e5b41 commit 29f3225
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 25 deletions.
57 changes: 35 additions & 22 deletions lark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,32 @@ import (
"github.com/stretchr/testify/assert"
)

func evaluate(t *testing.T, sourceFile string) map[string]interface{} {
content, err := os.ReadFile(sourceFile)
assert.Equal(t, nil, err)

root := types.Compound{Children: []types.Node{}}
symbolTable = make(map[string]interface{})
tokens := token.Tokenize(string(content))
builder := ast.NewAstBuilder(tokens.Tokens)
var tree types.Node
for builder.CurrentTokenPointer < len(tokens.Tokens)-1 {
tree = builder.Parse()
if tree != nil {
root.Children = append(root.Children, tree)
}
}
evaluator := ast.Evaluator{
SymbolTable: symbolTable,
}

for _, node := range root.Children {
evaluator.Evaluate(node)
}

return symbolTable
}

func Test_Tokenize(t *testing.T) {
content, err := os.ReadFile("test_source_files/token.lark")
assert.Equal(t, nil, err)
Expand Down Expand Up @@ -91,9 +117,7 @@ func areValuesEqual(v1, v2 interface{}) bool {
}
}
func Test_Parser(t *testing.T) {
content, err := os.ReadFile("test_source_files/parse.lark")
assert.Equal(t, nil, err)

symbolTable := evaluate(t, "test_source_files/parse.lark")
expectedSymbolTableVars := map[string]interface{}{
"a": false,
"b": true,
Expand All @@ -103,27 +127,16 @@ func Test_Parser(t *testing.T) {
"d": "not_ok",
}

root := types.Compound{Children: []types.Node{}}
symbolTable = make(map[string]interface{})
tokens := token.Tokenize(string(content))
builder := ast.NewAstBuilder(tokens.Tokens)
var tree types.Node
for builder.CurrentTokenPointer < len(tokens.Tokens)-1 {
tree = builder.Parse()
if tree != nil {
root.Children = append(root.Children, tree)
}
}
evaluator := ast.Evaluator{
SymbolTable: symbolTable,
}

for _, node := range root.Children {
evaluator.Evaluate(node)
}

assert.Equal(t, len(expectedSymbolTableVars), len(symbolTable))
for key := range expectedSymbolTableVars {
assert.Equal(t, expectedSymbolTableVars[key], symbolTable[key])
}
}

func Test_Function(t *testing.T) {
symbolTable := evaluate(t, "test_source_files/function.lark")
assert.Equal(t, len(symbolTable), 4)
assert.Equal(t, 1000, symbolTable["fna"])
assert.Equal(t, 500, symbolTable["fnb"])
assert.Equal(t, 1500, symbolTable["fnval"])
}
34 changes: 33 additions & 1 deletion pkg/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,40 @@ func (a *AstBuilder) Expr() types.Node {
case types.ASSIGN:
a.eat(types.ASSIGN)
right := a.Expr()
a.eat(types.SEMICOLON)
switch right.(type) {
case types.FunctionCall:
break
default:
a.eat(types.SEMICOLON)
}
return types.Assign{Id: left, Value: right}
case types.FUNCTION_CALL:
fn := types.FunctionCall{Name: a.getCurrentToken().Value.(types.Literal).Value.(string)}
a.eat(types.FUNCTION_CALL)
a.eat(types.SEMICOLON)
return fn
case types.FUNCTION:
a.eat(types.FUNCTION)
functionName := a.getCurrentToken().Value.(types.Literal).Value
a.eat(types.ID)
function := types.Function{
Name: functionName.(string),
}
a.eat(types.FUNCTION_ARGUMENT_OPEN)
a.eat(types.FUNCTION_ARGUMENT_CLOSE)
a.eat(types.FUNCTION_OPEN)
for a.getCurrentToken().TokenType != types.FUNCTION_RETURN && a.getCurrentToken().TokenType != types.FUNCTION_CLOSE {
node := a.Expr()
function.Children = append(function.Children, node)
}
if a.getCurrentToken().TokenType == types.FUNCTION_RETURN {
a.eat(types.FUNCTION_RETURN)
function.ReturnExpression = a.Expr()
a.eat(types.SEMICOLON)
}
a.eat(types.FUNCTION_CLOSE)
return function

case types.SWAP:
a.eat(types.SWAP)
right := a.Expr()
Expand Down
7 changes: 7 additions & 0 deletions pkg/ast/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ func (e *Evaluator) Visit(node types.Node) interface{} {
right := e.Visit(n.Value)
e.SymbolTable[n.Id.(types.Id).Name] = right
return right
case types.Function:
e.SymbolTable[n.Name] = n
case types.FunctionCall:
for _, v := range e.SymbolTable[n.Name].(types.Function).Children {
e.Visit(v)
}
return e.Visit(e.SymbolTable[n.Name].(types.Function).ReturnExpression)
case types.Swap:
_, ok := e.SymbolTable[n.Left.String()]
if !ok {
Expand Down
33 changes: 31 additions & 2 deletions pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,18 @@ func Tokenize(source string) *Source {
default:
s.Tokens = append(s.Tokens, types.Token{TokenType: types.NOT, Value: types.Literal{Value: "!", Type: types.OPERATOR}, LineNumber: s.CurrentLineNumber})
}
case '[':
s.Tokens = append(s.Tokens, types.Token{TokenType: types.FUNCTION_ARGUMENT_OPEN, Value: types.Literal{Value: "[", Type: types.OPERATOR}, LineNumber: s.CurrentLineNumber})
s.eat()
case ']':
s.Tokens = append(s.Tokens, types.Token{TokenType: types.FUNCTION_ARGUMENT_CLOSE, Value: types.Literal{Value: "]", Type: types.OPERATOR}, LineNumber: s.CurrentLineNumber})
s.eat()
case '<':
s.eat()
switch s.getCurrentToken() {
case '<':
s.Tokens = append(s.Tokens, types.Token{TokenType: types.FUNCTION_OPEN, Value: types.Literal{Value: "<<", Type: types.OPERATOR}, LineNumber: s.CurrentLineNumber})
s.eat()
case '=':
s.Tokens = append(s.Tokens, types.Token{TokenType: types.LESSER_OR_EQUAL, Value: types.Literal{Value: "<=", Type: types.OPERATOR}, LineNumber: s.CurrentLineNumber})
s.eat()
Expand All @@ -190,6 +199,9 @@ func Tokenize(source string) *Source {
case '>':
s.eat()
switch s.getCurrentToken() {
case '>':
s.Tokens = append(s.Tokens, types.Token{TokenType: types.FUNCTION_CLOSE, Value: types.Literal{Value: ">>", Type: types.OPERATOR}, LineNumber: s.CurrentLineNumber})
s.eat()
case '=':
s.Tokens = append(s.Tokens, types.Token{TokenType: types.GREATER_OR_EQUAL, Value: types.Literal{Value: ">=", Type: types.OPERATOR}, LineNumber: s.CurrentLineNumber})
s.eat()
Expand Down Expand Up @@ -234,13 +246,30 @@ func Tokenize(source string) *Source {
s.Tokens = append(s.Tokens, types.Token{TokenType: types.ELSE, Value: types.Literal{Value: "if", Type: types.STATEMENT}, LineNumber: s.CurrentLineNumber})
case "if":
s.Tokens = append(s.Tokens, types.Token{TokenType: types.IF, Value: types.Literal{Value: "if", Type: types.STATEMENT}, LineNumber: s.CurrentLineNumber})
case "fn":
s.Tokens = append(s.Tokens, types.Token{TokenType: types.FUNCTION, Value: types.Literal{Value: "fn", Type: types.STATEMENT}, LineNumber: s.CurrentLineNumber})
case "ret":
s.Tokens = append(s.Tokens, types.Token{TokenType: types.FUNCTION_RETURN, Value: types.Literal{Value: "return", Type: types.STATEMENT}, LineNumber: s.CurrentLineNumber})
default:
s.Tokens = append(s.Tokens, types.Token{TokenType: types.ID, Value: types.Literal{Value: variable, Type: types.STRING}, LineNumber: s.CurrentLineNumber})
switch s.getCurrentToken() {
case '(':
s.eat()
if s.getCurrentToken() == ')' {
s.eat()
s.Tokens = append(s.Tokens, types.Token{TokenType: types.FUNCTION_CALL, Value: types.Literal{Value: variable, Type: types.STATEMENT}, LineNumber: s.CurrentLineNumber})
// s.eat()
} else {
log.Fatalf("expected ')'")
}
break
default:
s.Tokens = append(s.Tokens, types.Token{TokenType: types.ID, Value: types.Literal{Value: variable, Type: types.STRING}, LineNumber: s.CurrentLineNumber})
break
}
}
} else {
log.Fatalf("unsupported type %v at line %d\n", string(s.getCurrentToken()), s.CurrentLineNumber)
}

}
}

Expand Down
39 changes: 39 additions & 0 deletions pkg/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ const (

TRUE = "true"
FALSE = "false"

FUNCTION = "fn"
FUNCTION_ARGUMENT_OPEN = "["
FUNCTION_ARGUMENT_CLOSE = "]"
FUNCTION_RETURN = "return"
FUNCTION_ARGUMENT_SEPARATOR = ","
FUNCTION_OPEN = "<<"
FUNCTION_CLOSE = ">>"
FUNCTION_CALL_OPEN = "("
FUNCTION_CALL_CLOSE = ")"
FUNCTION_CALL = "()"
)

type Token struct {
Expand Down Expand Up @@ -122,6 +133,34 @@ func (e Expression) NodeType() string {
return "expression"
}

type FunctionCall struct {
Name string
Arguments []Id
}

func (f FunctionCall) NodeType() string {
return "functioncall"
}

func (f FunctionCall) String() string {
return "fncall"
}

type Function struct {
Name string
Arguments []Id
Children []Node
ReturnExpression Node
}

func (f Function) NodeType() string {
return "function"
}

func (f Function) String() string {
return "fn"
}

type IfElseStatement struct {
Condition Node
IfChildren []Node
Expand Down
8 changes: 8 additions & 0 deletions source.lark
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,11 @@ if (a==b) {
tt <- !false;
c <-> e;
a <-> c;

fn hi[] <<
a <- 1000;
bbb <- 123;
ret a+bbb;
>>

fnval <- hi();
7 changes: 7 additions & 0 deletions test_source_files/function.lark
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
fn addstatic[] <<
fna <- 1000;
fnb <- 500;
ret fna+fnb;
>>

fnval <- addstatic();
1 change: 1 addition & 0 deletions test_source_files/parse.lark
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ if (a==b) {
d <- "not_ok";
}
tt <- !false;

0 comments on commit 29f3225

Please sign in to comment.