diff --git a/src/evaluator/mod.rs b/src/evaluator/mod.rs index 24dabfc..2920add 100644 --- a/src/evaluator/mod.rs +++ b/src/evaluator/mod.rs @@ -1,14 +1,19 @@ -use crate::parser::ast::Node; +use { + crate::parser::ast::Node, + std::fmt::{Display, Formatter, Result as FmtResult}, +}; pub mod tree_walker; pub trait Evaluator { fn eval(&self, node: Node) -> Option; } -#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq)] pub enum Object { Integer(i64), Boolean(bool), + ReturnValue(Box), + Error(String), Null, } @@ -21,11 +26,25 @@ impl Object { match self { Object::Integer(v) => v.to_string(), Object::Boolean(v) => v.to_string(), + Object::ReturnValue(ret) => ret.inspect(), + Object::Error(s) => s.to_string(), Object::Null => "NULL".into(), } } } +impl Display for Object { + fn fmt(&self, f: &mut Formatter) -> FmtResult { + f.write_str(match self { + Object::Integer(_) => "INTEGER", + Object::Boolean(_) => "BOOLEAN", + Object::ReturnValue(_) => "RETURN_VALUE", + Object::Error(_) => "ERROR", + Object::Null => "NULL", + }) + } +} + #[cfg(test)] mod tests { use crate::{ @@ -34,13 +53,7 @@ mod tests { parser::{ast::Node, Parser}, }; - #[test] - fn eval_integer_expression() { - let test_cases = [ - ("5", Some(Object::Integer(5))), - ("10", Some(Object::Integer(10))), - ]; - + fn run_test_cases(test_cases: &[(&str, Option)]) { for test in test_cases.iter() { let lexer = Lexer::new(test.0); let mut parser = Parser::new(lexer); @@ -53,6 +66,16 @@ mod tests { } } + #[test] + fn eval_integer_expression() { + let test_cases = [ + ("5", Some(Object::Integer(5))), + ("10", Some(Object::Integer(10))), + ]; + + run_test_cases(&test_cases); + } + #[test] fn eval_boolean_expression() { let test_cases = [ @@ -67,17 +90,7 @@ mod tests { ("(1 > 2 ) == true", Some(FALSE)), ("(1 > 2 ) == false", Some(TRUE)), ]; - - for test in test_cases.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - assert!(program.is_some()); - let program = program.unwrap(); - let evaluator = TreeWalker::new(); - let eval = evaluator.eval(Node::Program(program)); - assert_eq!(eval, test.1); - } + run_test_cases(&test_cases); } #[test] @@ -90,16 +103,7 @@ mod tests { ("!true", Some(FALSE)), ("!false", Some(TRUE)), ]; - for test in test_cases.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - assert!(program.is_some()); - let program = program.unwrap(); - let evaluator = TreeWalker::new(); - let eval = evaluator.eval(Node::Program(program)); - assert_eq!(eval, test.1); - } + run_test_cases(&test_cases); } #[test] @@ -121,16 +125,7 @@ mod tests { ("(5 + 10 * 2 + 15 / 3) * 2 + -10", Some(Object::Integer(50))), ]; - for test in test_cases.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - assert!(program.is_some()); - let program = program.unwrap(); - let evaluator = TreeWalker::new(); - let eval = evaluator.eval(Node::Program(program)); - assert_eq!(eval, test.1); - } + run_test_cases(&test_cases); } #[test] @@ -154,15 +149,68 @@ mod tests { ("if (1 < 2) {10} else {20}", Some(Object::Integer(10))), ]; - for test in test_cases.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - assert!(program.is_some()); - let program = program.unwrap(); - let evaluator = TreeWalker::new(); - let eval = evaluator.eval(Node::Program(program)); - assert_eq!(eval, test.1); - } + run_test_cases(&test_cases); + } + + #[test] + fn eval_return_statements() { + let test_cases = [ + ("return 10; ", Some(Object::Integer(10))), + ("return 10; 9;", Some(Object::Integer(10))), + ("return 2 * 5; 9;", Some(Object::Integer(10))), + ("9; return 2 * 5; 9;", Some(Object::Integer(10))), + ( + "if(10 > 1) { + if(10 > 2) { + return 10; + } + return 1; + }", + Some(Object::Integer(10)), + ), + ]; + + run_test_cases(&test_cases); + } + + #[test] + fn error_handling() { + let test_cases = [ + ( + "-true", + Some(Object::Error("unknown operator: -BOOLEAN".into())), + ), + ( + "true + false;", + Some(Object::Error("unknown operator: BOOLEAN + BOOLEAN".into())), + ), + ( + "5 + true;", + Some(Object::Error("type mismatch: INTEGER + BOOLEAN".into())), + ), + ( + "5 + true; 5", + Some(Object::Error("type mismatch: INTEGER + BOOLEAN".into())), + ), + ( + "5; true + false; 5", + Some(Object::Error("unknown operator: BOOLEAN + BOOLEAN".into())), + ), + ( + "if (10>1){true + false;}", + Some(Object::Error("unknown operator: BOOLEAN + BOOLEAN".into())), + ), + ( + "if(10 > 1) { + if(10 > 2) { + return true + false; + } + return 1; + }", + Some(Object::Error("unknown operator: BOOLEAN + BOOLEAN".into())), + ), + ]; + + run_test_cases(&test_cases); } } diff --git a/src/evaluator/tree_walker/mod.rs b/src/evaluator/tree_walker/mod.rs index 3250cf7..31edfd0 100644 --- a/src/evaluator/tree_walker/mod.rs +++ b/src/evaluator/tree_walker/mod.rs @@ -1,7 +1,10 @@ +// TODO: This is all a mess. Almost certainly because right now, I don't know any better way to do this. +// It's just constantly unwrapping enums from one place and rewrapping it to some other enum(or even the same enum) and returning it +// The error handling story is pretty bad too use crate::{ evaluator::{Evaluator, Object, FALSE, NULL, TRUE}, lexer::TokenType, - parser::ast::{Expression, ExpressionStatement, Node, Statement}, + parser::ast::{BlockStatement, Expression, ExpressionStatement, Node, Program, Statement}, }; pub struct TreeWalker; @@ -15,12 +18,16 @@ impl TreeWalker { impl Evaluator for TreeWalker { fn eval(&self, node: Node) -> Option { match node { - Node::Program(p) => self.eval_statements(p.statements), + Node::Program(p) => self.eval_program(p), Node::Statement(stmt) => match stmt { Statement::ExpressionStatement(ExpressionStatement { expression, .. }) => { self.eval(Node::Expression(expression)) } - Statement::BlockStatement(bs) => self.eval_statements(bs.statements), + Statement::BlockStatement(bs) => self.eval_block_statement(bs), + Statement::Return(ret) => { + let ret_val = self.eval(Node::Expression(ret.value?))?; + Some(Object::ReturnValue(Box::new(ret_val))) + } _ => None, }, Node::Expression(expr) => match expr { @@ -53,10 +60,54 @@ impl Evaluator for TreeWalker { } impl TreeWalker { - fn eval_statements(&self, stmts: Vec) -> Option { + fn eval_program(&self, prg: Program) -> Option { let mut out: Option = None; - for stmt in stmts { - out = self.eval(Node::Statement(stmt)) + for stmt in prg.statements { + out = self.eval(Node::Statement(stmt)); + // No need to evaluate any more statements from a statements vector once we + // get a return keyword. nothing after in the block matters. + if let Some(out) = out.clone() { + match out { + Object::ReturnValue(v) => return Some(*v), + Object::Error(_) => return Some(out), + _ => {} + } + } + } + out + } + + fn eval_block_statement(&self, bs: BlockStatement) -> Option { + let mut out: Option = None; + + for stmt in bs.statements { + out = self.eval(Node::Statement(stmt)); + + // TODO: Find a nicer way to do this. :( + // The objective here is, + // If we encounter a node of type ReturnValue, Don't unwrap it. Return it as is + // So, It can evaluated again by the eval function. + // This is helpful when we have a nested structure with multiple return statments. + // something like, + // if (true) { + // if (true) { + // return 10; + // } + // return 1; + // } + // This will return 1 if we unwrap return right here and return the value within. + // But in reality that shouldn't happen. It should be returning 10 + // So, We don't unwrap the ReturnValue node when we encounter 10. Just return it as is and it is later eval-ed + // and the correct value is returned + if let Some(out) = out.clone() { + match out { + Object::ReturnValue(v) => { + return Some(Object::ReturnValue(v)); + } + Object::Error(_) => return Some(out), + _ => {} + } + } } out } @@ -65,7 +116,10 @@ impl TreeWalker { match operator { TokenType::Bang => Some(self.eval_bang_operator_expression(expr)), TokenType::Minus => Some(self.eval_minus_prefix_operator_expression(expr)), - _ => None, + _ => Some(Object::Error(format!( + "unknown operator: {}{}", + operator, expr + ))), } } @@ -75,7 +129,7 @@ impl TreeWalker { operator: TokenType, right: Object, ) -> Option { - Some(match (left, right) { + Some(match (left.clone(), right.clone()) { (Object::Integer(l), Object::Integer(r)) => match operator { TokenType::Plus => Object::Integer(l + r), TokenType::Minus => Object::Integer(l - r), @@ -85,12 +139,16 @@ impl TreeWalker { TokenType::NotEquals => Object::Boolean(l != r), TokenType::GreaterThan => Object::Boolean(l > r), TokenType::LessThan => Object::Boolean(l < r), - _ => NULL, + _ => Object::Error(format!("unknown operator: {} {} {}", l, operator, r)), }, + (o1 @ _, o2 @ _) if o1.to_string() != o2.to_string() => { + Object::Error(format!("type mismatch: {} {} {}", o1, operator, o2)) + } + _ => match operator { TokenType::Equals => Object::Boolean(left == right), TokenType::NotEquals => Object::Boolean(left != right), - _ => NULL, + _ => Object::Error(format!("unknown operator: {} {} {}", left, operator, right)), }, }) } @@ -107,7 +165,7 @@ impl TreeWalker { fn eval_minus_prefix_operator_expression(&self, expr: Object) -> Object { match expr { Object::Integer(v) => Object::Integer(-v), - _ => NULL, + v @ _ => Object::Error(format!("unknown operator: -{}", v)), } } diff --git a/src/parser/ast/mod.rs b/src/parser/ast/mod.rs index df9128b..1e35288 100644 --- a/src/parser/ast/mod.rs +++ b/src/parser/ast/mod.rs @@ -115,21 +115,19 @@ impl Display for LetStatement { #[derive(Debug, PartialEq)] pub struct ReturnStatement { - return_value: Option, + pub value: Option, } impl ReturnStatement { pub fn new(expr: Expression) -> Self { - ReturnStatement { - return_value: Some(expr), - } + ReturnStatement { value: Some(expr) } } fn parse(parser: &mut Parser) -> Option { let token = parser.lexer.next()?; let expr = Expression::parse(parser, token, ExpressionPriority::Lowest); parser.expect_peek(TokenType::Semicolon)?; - Some(ReturnStatement { return_value: expr }) + Some(ReturnStatement { value: expr }) } } @@ -137,7 +135,7 @@ impl Display for ReturnStatement { fn fmt(&self, f: &mut Formatter) -> FmtResult { let mut out: String = TokenType::Return.to_string(); - if let Some(v) = &self.return_value { + if let Some(v) = &self.value { out.push(' '); let a: String = v.into(); out.push_str(&a); @@ -660,12 +658,9 @@ mod tests { ))), }), Statement::Return(ReturnStatement { - return_value: Some(Expression::Identifier(Identifier::new( - TokenType::Int, - "5", - ))), + value: Some(Expression::Identifier(Identifier::new(TokenType::Int, "5"))), }), - Statement::Return(ReturnStatement { return_value: None }), + Statement::Return(ReturnStatement { value: None }), ], }; assert_eq!( diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 148f3e8..47c626e 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -181,6 +181,18 @@ mod tests { } } + fn check_test_cases(test_cases: &[(&str, Vec)]) { + for test in test_cases.iter() { + let lexer = Lexer::new(test.0); + let mut parser = Parser::new(lexer); + let program = parser.parse_program(); + check_parser_errors(&parser); + assert_eq!(parser.errors.len(), 0); + assert!(program.is_some()); + assert_eq!(program.unwrap().statements, test.1); + } + } + #[test] fn let_statements() { let test_cases = [( @@ -201,15 +213,7 @@ mod tests { ], )]; - for test in test_cases.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - check_parser_errors(&parser); - assert_eq!(parser.errors.len(), 0); - assert!(program.is_some()); - assert_eq!(program.unwrap().statements, test.1); - } + check_test_cases(&test_cases); let fail_case = "let x 5; let 10; let 83838383;"; @@ -241,30 +245,12 @@ mod tests { ], )]; - for test in test_cases.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - check_parser_errors(&parser); - assert_eq!(parser.errors.len(), 0); - assert!(program.is_some()); - assert_eq!(program.unwrap().statements, test.1); - } - - let lexer = Lexer::new("return 5; return 10; return add(10);"); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - - check_parser_errors(&parser); - assert_eq!(parser.errors.len(), 0); - assert!(program.is_some()); - let program = program.unwrap(); - assert_eq!(program.statements.len(), 3); - assert_eq!(parser.errors.len(), 0); + check_test_cases(&test_cases); } #[test] fn identifier_expression() { + // TODO: Add more tests for this let lexer = Lexer::new("foobar;"); let mut parser = Parser::new(lexer); let program = parser.parse_program(); @@ -303,7 +289,7 @@ mod tests { #[test] fn prefix_expressions() { - let prefix_tests = [ + let test_cases = [ ( "!5", vec![Statement::ExpressionStatement(ExpressionStatement::new( @@ -372,20 +358,12 @@ mod tests { ), ]; - for test in prefix_tests.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - check_parser_errors(&parser); - assert_eq!(parser.errors.len(), 0); - assert!(program.is_some()); - assert_eq!(program.unwrap().statements, test.1); - } + check_test_cases(&test_cases); } #[test] fn parsing_infix_expressions() { - let infix_tests = [ + let test_cases = [ ( "5 + 10;", vec![Statement::ExpressionStatement(ExpressionStatement::new( @@ -490,15 +468,7 @@ mod tests { ))], ), ]; - for test in infix_tests.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - check_parser_errors(&parser); - assert_eq!(parser.errors.len(), 0); - assert!(program.is_some()); - assert_eq!(program.unwrap().statements, test.1); - } + check_test_cases(&test_cases); } #[test] @@ -570,16 +540,7 @@ mod tests { ))], ), ]; - - for test in test_cases.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - check_parser_errors(&parser); - assert_eq!(parser.errors.len(), 0); - assert!(program.is_some()); - assert_eq!(program.unwrap().statements, test.1); - } + check_test_cases(&test_cases); } #[test] fn if_expression() { @@ -604,15 +565,7 @@ mod tests { ))], )]; - for test in test_cases.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - check_parser_errors(&parser); - assert_eq!(parser.errors.len(), 0); - assert!(program.is_some()); - assert_eq!(program.unwrap().statements, test.1); - } + check_test_cases(&test_cases); } #[test] fn if_else_expression() { @@ -641,16 +594,7 @@ mod tests { )), ))], )]; - - for test in test_cases.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - check_parser_errors(&parser); - assert_eq!(parser.errors.len(), 0); - assert!(program.is_some()); - assert_eq!(program.unwrap().statements, test.1); - } + check_test_cases(&test_cases); } #[test] @@ -762,15 +706,7 @@ mod tests { ), ]; - for test in test_cases.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - check_parser_errors(&parser); - assert_eq!(parser.errors.len(), 0); - assert!(program.is_some()); - assert_eq!(program.unwrap().statements, test.1); - } + check_test_cases(&test_cases); } #[test] fn call_expression_parsing() { @@ -910,15 +846,7 @@ mod tests { ), ]; - for test in test_cases.iter() { - let lexer = Lexer::new(test.0); - let mut parser = Parser::new(lexer); - let program = parser.parse_program(); - check_parser_errors(&parser); - assert_eq!(parser.errors.len(), 0); - assert!(program.is_some()); - assert_eq!(program.unwrap().statements, test.1); - } + check_test_cases(&test_cases); } #[test] fn call_expression_parsing_string() {