From 5effc2654322bfae961b14726750a925df48490c Mon Sep 17 00:00:00 2001 From: "for { kys() }" Date: Wed, 20 Nov 2024 14:56:41 +0530 Subject: [PATCH] ch3 --- ast/ast.go | 6 + evaluator/evaluator.go | 243 ++++++++++++++++++++++++++++++++++-- evaluator/evaluator_test.go | 220 +++++++++++++++++++++++++++++++- object/environment.go | 30 +++++ object/object.go | 55 +++++++- parser/parser.go | 10 ++ repl/repl.go | 4 +- 7 files changed, 551 insertions(+), 17 deletions(-) create mode 100644 object/environment.go diff --git a/ast/ast.go b/ast/ast.go index 0ff1d8e..3ade9ee 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -93,6 +93,12 @@ type CallExpression struct { Arguments []Expression } +type FunctionalLiteral struct { + Token token.Token + Parameters []*Identifier + Body *BlockStatement +} + func (ls *LetStatement) statementNode() {} func (ls *LetStatement) TokenLiteral() string { return ls.Token.Literal } diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index aee98a9..f73714c 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -3,6 +3,7 @@ package evaluator import ( "flare/ast" "flare/object" + "fmt" ) var ( @@ -11,24 +12,244 @@ var ( FALSE = &object.Boolean{Value: false} ) -func Eval(node ast.Node) object.Object { +func Eval(node ast.Node, env *object.Environment) object.Object { switch node := node.(type) { case *ast.Program: - return evalStatements(node.Statements) + return evalProgram(node, env) case *ast.ExpressionStatement: - return Eval(node.Expression) + return Eval(node.Expression, env) case *ast.Boolean: return nativeBoolToBooleanObject(node.Value) case *ast.IntegerLiteral: return &object.Integer{Value: node.Value} case *ast.PrefixExpression: - right := Eval(node.Right) + right := Eval(node.Right, env) + if isError(right) { + return right + } return evalPrefixExpression(node.Operator, right) + case *ast.InfixExpression: + left := Eval(node.Left, env) + if isError(left) { + return left + } + right := Eval(node.Right, env) + if isError(right) { + return right + } + return evalInfixExpression(node.Operator, left, right) + case *ast.BlockStatement: + return evalBlockStatements(node, env) + case *ast.IfExpression: + return evalIfExpression(node, env) + case *ast.ReturnStatement: + val := Eval(node.ReturnValue, env) + if isError(val) { + return val + } + return &object.ReturnValue{Value: val} + case *ast.LetStatement: + val := Eval(node.Value, env) + if isError(val) { + return val + } + env.Set(node.Name.Value, val) + case *ast.Identifier: + return evalIdentifier(node, env) + case *ast.FunctionLiteral: + params := node.Parameters + body := node.Body + return &object.Function{Parameters: params, Env: env, Body: body} + case *ast.CallExpression: + function := Eval(node.Function, env) + if isError(function) { + return function + } + args := evalExpressions(node.Arguments, env) + if len(args) == 1 && isError(args[0]) { + return args[0] + } + return applyFunction(function, args) } return nil } +func applyFunction(fn object.Object, args []object.Object) object.Object { + function, ok := fn.(*object.Function) + if !ok { + return newError("not a function: %s", fn.Type()) + } + + extendedEnv := extendedFunctionEnv(function, args) + evaluated := Eval(function.Body, extendedEnv) + + return unwrapReturnValue(evaluated) +} + +func extendedFunctionEnv(fn *object.Function, args []object.Object) *object.Environment { + env := object.NewEnclosedEnvironment(fn.Env) + + for paramIndex, param := range fn.Parameters { + env.Set(param.Value, args[paramIndex]) + } + + return env +} + +func unwrapReturnValue(obj object.Object) object.Object { + if returnValue, ok := obj.(*object.ReturnValue); ok { + return returnValue.Value + } + + return obj +} + +func evalExpressions(exps []ast.Expression, env *object.Environment) []object.Object { + var result []object.Object + + for _, e := range exps { + evaluated := Eval(e, env) + if isError(evaluated) { + return []object.Object{evaluated} + } + result = append(result, evaluated) + } + + return result +} + +func isError(obj object.Object) bool { + if obj != nil { + return obj.Type() == object.ERROR_OBJ + } + return false +} + +func evalIdentifier(node *ast.Identifier, env *object.Environment) object.Object { + val, ok := env.Get(node.Value) + if !ok { + return newError("identifier not found: " + node.Value) + } + + return val +} + +func newError(format string, a ...interface{}) *object.Error { + return &object.Error{Message: fmt.Sprintf(format, a...)} +} + +func evalBlockStatements(block *ast.BlockStatement, env *object.Environment) object.Object { + var result object.Object + + for _, statement := range block.Statements { + result = Eval(statement, env) + + if result != nil { + rt := result.Type() + if rt == object.RETURN_VAL_OBJ || rt == object.ERROR_OBJ { + return result + } + } + } + + return result +} + +func evalProgram(program *ast.Program, env *object.Environment) object.Object { + var result object.Object + + for _, statement := range program.Statements { + result = Eval(statement, env) + + switch result := result.(type) { + case *object.ReturnValue: + return result.Value + case *object.Error: + return result + } + } + + return result +} + +func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Object { + condition := Eval(ie.Condition, env) + + if isError(condition) { + return condition + } + + if isTruthy(condition) { + return Eval(ie.Consequence, env) + } else if ie.Alternative != nil { + return Eval(ie.Alternative, env) + } else { + return NULL + } +} + +func isTruthy(obj object.Object) bool { + switch obj { + case NULL: + return false + case TRUE: + return true + case FALSE: + return false + default: + return true + } +} + +func evalInfixExpression(operator string, left, right object.Object) object.Object { + if left.Type() != right.Type() { + return newError("type mismatch: %s %s %s", left.Type(), operator, right.Type()) + } + + switch left.Type() { + case object.INTEGER_OBJ: + return evalIntegerInfixExpression(operator, left, right) + case object.BOOLEAN_OBJ: + switch operator { + case "==": + return nativeBoolToBooleanObject(left == right) + case "!=": + return nativeBoolToBooleanObject(left != right) + default: + return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) + } + default: + return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) + } +} + +func evalIntegerInfixExpression(operator string, left, right object.Object) object.Object { + leftVal := left.(*object.Integer).Value + rightVal := right.(*object.Integer).Value + + switch operator { + case "+": + return &object.Integer{Value: leftVal + rightVal} + case "-": + return &object.Integer{Value: leftVal - rightVal} + case "*": + return &object.Integer{Value: leftVal * rightVal} + case "/": + return &object.Integer{Value: leftVal / rightVal} + case "<": + return nativeBoolToBooleanObject(leftVal < rightVal) + case ">": + return nativeBoolToBooleanObject(leftVal > rightVal) + case "==": + return nativeBoolToBooleanObject(leftVal == rightVal) + case "!=": + return nativeBoolToBooleanObject(leftVal != rightVal) + default: + return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type()) + } +} + func evalPrefixExpression(operator string, right object.Object) object.Object { switch operator { case "!": @@ -36,13 +257,13 @@ func evalPrefixExpression(operator string, right object.Object) object.Object { case "-": return evalMinusPrefixOperatorExpression(right) default: - return NULL + return newError("unknown operator: %s%s", operator, right.Type()) } } func evalMinusPrefixOperatorExpression(right object.Object) object.Object { if right.Type() != object.INTEGER_OBJ { - return NULL + return newError("unknown operator: -%s", right.Type()) } value := right.(*object.Integer).Value @@ -70,11 +291,15 @@ func nativeBoolToBooleanObject(input bool) *object.Boolean { return FALSE } -func evalStatements(stmts []ast.Statement) object.Object { - var result object.Object +func evalStatements(stmts []ast.Statement, env *object.Environment) object.Object { + var result object.Object = NULL for _, statement := range stmts { - result = Eval(statement) + result = Eval(statement, env) + + if returnValue, ok := result.(*object.ReturnValue); ok { + return returnValue.Value + } } return result diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index 7f39d42..0dcef59 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -7,6 +7,201 @@ import ( "testing" ) +func TestFunctionApplication(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"let identity = fn(x) { x; }; identity(5);", 5}, + {"let identity = fn(x) { return x; }; identity(5);", 5}, + {"let double = fn(x) { x * 2; }; double(5);", 10}, + {"let add = fn(x, y) { x + y; }; add(5, 5);", 10}, + {"let add = fn(x, y) { x + y; }; add(5 + 5, add(5, 5));", 20}, + {"fn(x) { x; }(5)", 5}, + } + + for _, tt := range tests { + testIntegerObject(t, testEval(tt.input), tt.expected) + } +} + +func TestClosures(t *testing.T) { + input := ` +let newAdder = fn(x) { +fn(y) { x + y }; +}; +let addTwo = newAdder(2); +addTwo(2);` + + testIntegerObject(t, testEval(input), 4) +} + +func TestFunctionObject(t *testing.T) { + input := "fn(x) { x + 2; };" + evaluated := testEval(input) + + fn, ok := evaluated.(*object.Function) + if !ok { + t.Fatalf("object is not Function. got=%T (%+v)", evaluated, evaluated) + } + + if len(fn.Parameters) != 1 { + t.Fatalf("function has wrong parameters. Parameters=%+v", + fn.Parameters) + } + + if fn.Parameters[0].String() != "x" { + t.Fatalf("parameter is not 'x'. got=%q", fn.Parameters[0]) + } + + expectedBody := "(x + 2)" + if fn.Body.String() != expectedBody { + t.Fatalf("body is not %q. got=%q", expectedBody, fn.Body.String()) + } +} + +func TestLetStatements(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"let a = 5; a;", 5}, + {"let a = 5 * 5; a;", 25}, + {"let a = 5; let b = a; b;", 5}, + {"let a = 5; let b = a; let c = a + b + 5; c;", 15}, + } + + for _, tt := range tests { + testIntegerObject(t, testEval(tt.input), tt.expected) + } +} + +func TestErrorHandling(t *testing.T) { + tests := []struct { + input string + expectedMessage string + }{ + { + "foobar", + "identifier not found: foobar", + }, + { + "5 + true;", + "type mismatch: INTEGER + BOOLEAN", + }, + { + "5 + true; 5;", + "type mismatch: INTEGER + BOOLEAN", + }, + { + "-true", + "unknown operator: -BOOLEAN", + }, + { + "true + false;", + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + "5; true + false; 5", + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + "if (10 > 1) { true + false; }", + "unknown operator: BOOLEAN + BOOLEAN", + }, + { + ` +if (10 > 1) { +if (10 > 1) { +return true + false; +} +return 1; +} +`, + "unknown operator: BOOLEAN + BOOLEAN", + }, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + errObj, ok := evaluated.(*object.Error) + if !ok { + t.Errorf("no error object returned. got=%T(%+v)", + evaluated, evaluated) + continue + } + + if errObj.Message != tt.expectedMessage { + t.Errorf("wrong error message. expected=%q, got=%q", + tt.expectedMessage, errObj.Message) + } + } +} + +func TestReturnStatements(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"return 10;", 10}, + {"return 10; 9;", 10}, + {"return 2 * 5; 9;", 10}, + {"9; return 2 * 5; 9;", 10}, + { + ` +if (10 > 1) { + if (10 > 1) { + return 10; + } + + return 1; +} + `, + 10, + }, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + testIntegerObject(t, evaluated, tt.expected) + } +} + +func TestIfElseExpression(t *testing.T) { + tests := []struct { + input string + expected interface{} + }{ + {"if (true) { 10 }", 10}, + {"if (false) { 10 }", nil}, + {"if (1) { 10 }", 10}, + {"if (1 < 2) { 10 }", 10}, + {"if (1 > 2) { 10 }", nil}, + {"if (1 > 2) { 10 } else { 20 }", 20}, + {"if (1 < 2) { 10 } else { 20 }", 10}, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + + integer, ok := tt.expected.(int) + if ok { + testIntegerObject(t, evaluated, int64(integer)) + } else { + testNullObject(t, evaluated) + } + } +} + +func testNullObject(t *testing.T, obj object.Object) bool { + if obj != NULL { + t.Errorf("object is not NULL. got=%T (%+v)", obj, obj) + return false + } + + return true +} + func TestBangOperator(t *testing.T) { tests := []struct { input string @@ -33,6 +228,14 @@ func TestEvalBooleanExpression(t *testing.T) { }{ {"true", true}, {"false", false}, + {"1 < 2", true}, + {"1 > 2", false}, + {"1 < 1", false}, + {"1 > 1", false}, + {"1 == 1", true}, + {"1 != 1", false}, + {"1 == 2", false}, + {"1 != 2", true}, } for _, tt := range tests { @@ -49,8 +252,7 @@ func testBooleanObject(t *testing.T, obj object.Object, expected bool) bool { } if result.Value != expected { - t.Errorf("object has wrong value. got=%t, want=%t", - result.Value, expected) + t.Errorf("object has wrong value. got=%t, want=%t", result.Value, expected) return false } @@ -66,6 +268,17 @@ func TestEvalIntegerExpression(t *testing.T) { {"10", 10}, {"-5", -5}, {"-10", -10}, + {"5 + 5 + 5 + 5 - 10", 10}, + {"2 * 2 * 2 * 2 * 2", 32}, + {"-50 + 100 + -50", 0}, + {"5 * 2 + 10", 20}, + {"5 + 2 * 10", 25}, + {"20 + 2 * -10", 0}, + {"50 / 2 * 2 + 10", 60}, + {"2 * (5 + 10)", 30}, + {"3 * 3 * 3 + 10", 37}, + {"3 * (3 * 3) + 10", 37}, + {"(5 + 10 * 2 + 15 / 3) * 2 + -10", 50}, } for _, tt := range tests { @@ -78,8 +291,9 @@ func testEval(input string) object.Object { l := lexer.New(input) p := parser.New(l) program := p.ParseProgram() + env := object.NewEnvironment() - return Eval(program) + return Eval(program, env) } func testIntegerObject(t *testing.T, obj object.Object, expected int64) bool { diff --git a/object/environment.go b/object/environment.go new file mode 100644 index 0000000..f302865 --- /dev/null +++ b/object/environment.go @@ -0,0 +1,30 @@ +package object + +type Environment struct { + store map[string]Object + outer *Environment +} + +func NewEnclosedEnvironment(outer *Environment) *Environment { + env := NewEnvironment() + env.outer = outer + return env +} + +func NewEnvironment() *Environment { + s := make(map[string]Object) + return &Environment{store: s, outer: nil} +} + +func (e *Environment) Get(name string) (Object, bool) { + obj, ok := e.store[name] + if !ok && e.outer != nil { + obj, ok = e.outer.Get(name) + } + return obj, ok +} + +func (e *Environment) Set(name string, val Object) Object { + e.store[name] = val + return val +} diff --git a/object/object.go b/object/object.go index 5c9a336..c46db78 100644 --- a/object/object.go +++ b/object/object.go @@ -1,13 +1,21 @@ package object -import "fmt" +import ( + "bytes" + "flare/ast" + "fmt" + "strings" +) type ObjectType string const ( - INTEGER_OBJ = "INTEGER" - BOOLEAN_OBJ = "BOOLEAN" - NULL_OBJ = "NULL" + INTEGER_OBJ = "INTEGER" + BOOLEAN_OBJ = "BOOLEAN" + NULL_OBJ = "NULL" + RETURN_VAL_OBJ = "RETURN_VAL" + ERROR_OBJ = "ERROR" + FUNCTION_OBJ = "FUNCTION" ) type Object interface { @@ -25,6 +33,20 @@ type Boolean struct { type Null struct{} +type ReturnValue struct { + Value Object +} + +type Error struct { + Message string +} + +type Function struct { + Parameters []*ast.Identifier + Body *ast.BlockStatement + Env *Environment +} + func (i *Integer) Inspect() string { return fmt.Sprintf("%d", i.Value) } func (i *Integer) Type() ObjectType { return INTEGER_OBJ } @@ -33,3 +55,28 @@ func (b *Boolean) Inspect() string { return fmt.Sprintf("%t", b.Value) } func (n *Null) Type() ObjectType { return NULL_OBJ } func (n *Null) Inspect() string { return "null" } + +func (rv *ReturnValue) Type() ObjectType { return RETURN_VAL_OBJ } +func (rv *ReturnValue) Inspect() string { return rv.Value.Inspect() } + +func (e *Error) Type() ObjectType { return ERROR_OBJ } +func (e *Error) Inspect() string { return "ERROR: " + e.Message } + +func (f *Function) Type() ObjectType { return FUNCTION_OBJ } +func (f *Function) Inspect() string { + var out bytes.Buffer + + params := []string{} + for _, p := range f.Parameters { + params = append(params, p.String()) + } + + out.WriteString("fn") + out.WriteString("(") + out.WriteString(strings.Join(params, ", ")) + out.WriteString(") {\n") + out.WriteString(f.Body.String()) + out.WriteString("\n}") + + return out.String() +} diff --git a/parser/parser.go b/parser/parser.go index 87e4252..2ead88c 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -176,6 +176,16 @@ func (p *Parser) parseIfExpression() ast.Expression { expression.Consequence = p.parseBlockStatement() + if p.peekTokenIs(token.ELSE) { + p.nextToken() + + if !p.expectPeek(token.LBRACE) { + return nil + } + + expression.Alternative = p.parseBlockStatement() + } + return expression } diff --git a/repl/repl.go b/repl/repl.go index 6570bf0..5df30f3 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -4,6 +4,7 @@ import ( "bufio" "flare/evaluator" "flare/lexer" + "flare/object" "flare/parser" "fmt" "io" @@ -13,6 +14,7 @@ const PROOMPT = ">> " func Start(in io.Reader, out io.Writer) { scanner := bufio.NewScanner(in) + env := object.NewEnvironment() for { fmt.Print(PROOMPT) @@ -37,7 +39,7 @@ func Start(in io.Reader, out io.Writer) { continue } - evaluated := evaluator.Eval(program) + evaluated := evaluator.Eval(program, env) if evaluated != nil { io.WriteString(out, evaluated.Inspect()) io.WriteString(out, "\n")