From 4ca2e61654357141ab186ac213d3b29033e24ea5 Mon Sep 17 00:00:00 2001 From: Ishan Jain Date: Fri, 17 May 2024 20:02:51 +0530 Subject: [PATCH] implemented function evaluator --- src/evaluator/mod.rs | 97 ++++++++++++++++++++++++++++++-- src/evaluator/tree_walker/mod.rs | 74 ++++++++++++++++++++++-- src/main.rs | 2 + src/parser/ast/mod.rs | 38 ++++++------- 4 files changed, 184 insertions(+), 27 deletions(-) diff --git a/src/evaluator/mod.rs b/src/evaluator/mod.rs index 6755d6e..7774a33 100644 --- a/src/evaluator/mod.rs +++ b/src/evaluator/mod.rs @@ -1,8 +1,9 @@ use { - crate::parser::ast::Node, + crate::parser::ast::{BlockStatement, Identifier, Node}, + itertools::Itertools, std::{ collections::HashMap, - fmt::{Display, Formatter, Result as FmtResult}, + fmt::{Display, Formatter, Result as FmtResult, Write}, }, }; pub mod tree_walker; @@ -11,33 +12,47 @@ pub trait Evaluator { fn eval(&self, node: Node, env: &mut Environment) -> Option; } +#[derive(Debug, Clone, PartialEq)] pub struct Environment { store: HashMap, + outer: Option>, } impl Environment { pub fn new() -> Self { Self { store: HashMap::new(), + outer: None, } } pub fn get(&self, name: &str) -> Option { match self.store.get(name) { Some(v) => Some(v.clone()), - None => None, + None => match &self.outer { + Some(outer) => outer.get(name), + None => None, + }, } } pub fn set(&mut self, name: String, val: Object) { self.store.insert(name, val); } + + pub fn new_enclosed(env: Environment) -> Self { + Self { + store: HashMap::new(), + outer: Some(Box::new(env)), + } + } } -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum Object { Integer(i64), Boolean(bool), ReturnValue(Box), Error(String), + Function(Function), Null, } @@ -53,6 +68,18 @@ impl Object { Object::ReturnValue(ret) => ret.inspect(), Object::Error(s) => s.to_string(), Object::Null => "NULL".into(), + Object::Function(s) => { + let mut out = String::new(); + + out.write_fmt(format_args!( + "fn({}) {{ {} }}", + s.parameters.iter().map(|x| x.to_string()).join(", "), + s.body.to_string() + )) + .unwrap(); + + out + } } } } @@ -64,13 +91,23 @@ impl Display for Object { Object::Boolean(_) => "BOOLEAN", Object::ReturnValue(_) => "RETURN_VALUE", Object::Error(_) => "ERROR", + Object::Function(_) => "FUNCTION", Object::Null => "NULL", }) } } +#[derive(Debug, Clone, PartialEq)] +pub struct Function { + parameters: Vec, + body: BlockStatement, + env: Environment, +} + #[cfg(test)] mod tests { + use std::assert_matches::assert_matches; + use crate::{ evaluator::{tree_walker::TreeWalker, Environment, Evaluator, Object, FALSE, NULL, TRUE}, lexer::Lexer, @@ -257,4 +294,56 @@ mod tests { run_test_cases(&test_cases); } + + #[test] + fn test_function_object() { + let test_case = "fn(x) { x + 2;};"; + + let lexer = Lexer::new(&test_case); + let mut parser = Parser::new(lexer); + let program = parser.parse_program(); + assert!(program.is_some()); + let program = program.unwrap(); + let evaluator = TreeWalker::new(); + let mut env = Environment::new(); + let eval = evaluator.eval(Node::Program(program), &mut env); + let node = eval.unwrap(); + + assert_matches!(node, Object::Function(_)); + + if let Object::Function(ref f) = node { + assert_eq!(f.parameters.len(), 1); + assert_eq!(f.parameters.first().unwrap().to_string(), "x"); + assert_eq!(f.body.to_string(), "(x + 2)"); + } + } + + #[test] + fn test_function_application() { + let test_cases = [ + ( + "let identity = fn(x) { x; }; identity(5);", + Some(Object::Integer(5)), + ), + ( + "let identity = fn(x) { return x; }; identity(9);", + Some(Object::Integer(9)), + ), + ( + "let double = fn(x) { return x * 2; }; double(8);", + Some(Object::Integer(16)), + ), + ( + "let add = fn(x,y) { return x +y; }; add(10, 20);", + Some(Object::Integer(30)), + ), + ( + "let add = fn(x,y) { return x +y; }; add(20 + 25, add(3+1,2)));", + Some(Object::Integer(51)), + ), + ("fn(x) { x; }(5)", Some(Object::Integer(5))), + ]; + + run_test_cases(&test_cases); + } } diff --git a/src/evaluator/tree_walker/mod.rs b/src/evaluator/tree_walker/mod.rs index 788d594..1039621 100644 --- a/src/evaluator/tree_walker/mod.rs +++ b/src/evaluator/tree_walker/mod.rs @@ -10,7 +10,7 @@ use crate::{ }, }; -use super::Environment; +use super::{Environment, Function}; pub struct TreeWalker; @@ -34,12 +34,10 @@ impl Evaluator for TreeWalker { Some(Object::ReturnValue(Box::new(ret_val))) } Statement::Let(LetStatement { name, value }) => { - let value = self.eval(Node::Expression(value.unwrap()), env)?; + let value = self.eval(Node::Expression(value?), env)?; env.set(name.to_string(), value.clone()); Some(value) } - - _ => None, }, Node::Expression(expr) => match expr { Expression::Identifier(v) => self.eval_identifier(v, env), @@ -54,6 +52,24 @@ impl Evaluator for TreeWalker { let right = self.eval(Node::Expression(*ie.right), env)?; self.eval_infix_expression(left, ie.operator, right) } + Expression::FunctionExpression(fnl) => { + return Some(Object::Function(Function { + body: fnl.body, + parameters: fnl.parameters, + env: env.clone(), + })) + } + Expression::CallExpression(v) => { + let function = self.eval(Node::Expression(*v.function), env)?; + // Resolve function arguments and update the environment + // before executing function body + let args = match self.eval_expression(v.arguments, env) { + Ok(v) => v, + Err(e) => return Some(e), + }; + + self.apply_function(function, args) + } Expression::IfExpression(ie) => { let condition = self.eval(Node::Expression(*ie.condition), env)?; @@ -199,4 +215,54 @@ impl TreeWalker { node.to_string() )))) } + + fn eval_expression( + &self, + exprs: Vec, + env: &mut Environment, + ) -> Result, Object> { + let mut out = vec![]; + + for expr in exprs { + match self.eval(Node::Expression(expr), env) { + Some(v @ Object::Error(_)) => return Err(v), + Some(v) => out.push(v), + None => { + break; + } + } + } + + Ok(out) + } + + fn apply_function(&self, function: Object, args: Vec) -> Option { + let function = match function { + Object::Function(f) => f, + v => return Some(Object::Error(format!("not a function: {}", v.to_string()))), + }; + + let mut enclosed_env = Environment::new_enclosed(function.env); + for (i, parameter) in function.parameters.iter().enumerate() { + if args.len() <= i { + println!("{:?} {}", args, i); + return Some(Object::Error(format!("incorrect number of arguments"))); + } + + enclosed_env.set(parameter.value.clone(), args[i].clone()); + } + + let resp = self.eval( + Node::Statement(Statement::BlockStatement(function.body)), + &mut enclosed_env, + ); + + // Unwrap return here to prevent it from bubbling up the stack + // and stopping execution elsewhere. + if let Some(Object::ReturnValue(v)) = resp.as_ref() { + return Some(*v.clone()); + } + + resp + } } diff --git a/src/main.rs b/src/main.rs index 0708c1e..3c4d77f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +#![feature(assert_matches)] + #[macro_use] extern crate lazy_static; diff --git a/src/parser/ast/mod.rs b/src/parser/ast/mod.rs index 331e7ea..be6cf9c 100644 --- a/src/parser/ast/mod.rs +++ b/src/parser/ast/mod.rs @@ -34,7 +34,7 @@ impl Display for Program { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum Statement { Let(LetStatement), Return(ReturnStatement), @@ -65,7 +65,7 @@ impl Display for Statement { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct LetStatement { // name field is to store the identifier of the binding pub name: Identifier, @@ -113,7 +113,7 @@ impl Display for LetStatement { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct ReturnStatement { pub value: Option, } @@ -145,7 +145,7 @@ impl Display for ReturnStatement { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct ExpressionStatement { token: Token, pub expression: Expression, @@ -184,7 +184,7 @@ pub enum ExpressionPriority { Call = 6, } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum Expression { Identifier(Identifier), IntegerLiteral(IntegerLiteral), @@ -263,10 +263,10 @@ impl From<&Expression> for String { // Identifier will be an expression // Identifier in a let statement like, let x = 5; where `x` is an identifier doesn't produce a value // but an identifier *can* produce value when used on rhs, e.g. let x = y; Here `y` is producing a value -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct Identifier { token: TokenType, - value: String, + pub value: String, } impl Identifier { @@ -290,7 +290,7 @@ impl Display for Identifier { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct IntegerLiteral { pub value: i64, } @@ -313,7 +313,7 @@ impl IntegerLiteral { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct PrefixExpression { pub operator: TokenType, pub right: Box, @@ -347,7 +347,7 @@ impl Display for PrefixExpression { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct InfixExpression { pub left: Box, pub operator: TokenType, @@ -383,7 +383,7 @@ impl Display for InfixExpression { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct BooleanExpression { token: TokenType, pub value: bool, @@ -409,7 +409,7 @@ impl Display for BooleanExpression { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct IfExpression { pub condition: Box, pub consequence: BlockStatement, @@ -461,7 +461,7 @@ impl Display for IfExpression { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct BlockStatement { pub statements: Vec, } @@ -501,11 +501,11 @@ impl Display for BlockStatement { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct FunctionLiteral { token: Token, - parameters: Vec, - body: BlockStatement, + pub parameters: Vec, + pub body: BlockStatement, } impl FunctionLiteral { @@ -567,10 +567,10 @@ impl Display for FunctionLiteral { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct CallExpression { - function: Box, - arguments: Vec, + pub function: Box, + pub arguments: Vec, } impl CallExpression {