implemented function evaluator

This commit is contained in:
Ishan Jain 2024-05-17 20:02:51 +05:30
parent 697ca392b3
commit 4ca2e61654
Signed by: ishan
GPG Key ID: 0506DB2A1CC75C27
4 changed files with 184 additions and 27 deletions

View File

@ -1,8 +1,9 @@
use { use {
crate::parser::ast::Node, crate::parser::ast::{BlockStatement, Identifier, Node},
itertools::Itertools,
std::{ std::{
collections::HashMap, collections::HashMap,
fmt::{Display, Formatter, Result as FmtResult}, fmt::{Display, Formatter, Result as FmtResult, Write},
}, },
}; };
pub mod tree_walker; pub mod tree_walker;
@ -11,33 +12,47 @@ pub trait Evaluator {
fn eval(&self, node: Node, env: &mut Environment) -> Option<Object>; fn eval(&self, node: Node, env: &mut Environment) -> Option<Object>;
} }
#[derive(Debug, Clone, PartialEq)]
pub struct Environment { pub struct Environment {
store: HashMap<String, Object>, store: HashMap<String, Object>,
outer: Option<Box<Environment>>,
} }
impl Environment { impl Environment {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
store: HashMap::new(), store: HashMap::new(),
outer: None,
} }
} }
pub fn get(&self, name: &str) -> Option<Object> { pub fn get(&self, name: &str) -> Option<Object> {
match self.store.get(name) { match self.store.get(name) {
Some(v) => Some(v.clone()), Some(v) => Some(v.clone()),
None => match &self.outer {
Some(outer) => outer.get(name),
None => None, None => None,
},
} }
} }
pub fn set(&mut self, name: String, val: Object) { pub fn set(&mut self, name: String, val: Object) {
self.store.insert(name, val); self.store.insert(name, val);
} }
pub fn new_enclosed(env: Environment) -> Self {
Self {
store: HashMap::new(),
outer: Some(Box::new(env)),
}
}
} }
#[derive(Debug, Clone, Eq, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum Object { pub enum Object {
Integer(i64), Integer(i64),
Boolean(bool), Boolean(bool),
ReturnValue(Box<Object>), ReturnValue(Box<Object>),
Error(String), Error(String),
Function(Function),
Null, Null,
} }
@ -53,6 +68,18 @@ impl Object {
Object::ReturnValue(ret) => ret.inspect(), Object::ReturnValue(ret) => ret.inspect(),
Object::Error(s) => s.to_string(), Object::Error(s) => s.to_string(),
Object::Null => "NULL".into(), Object::Null => "NULL".into(),
Object::Function(s) => {
let mut out = String::new();
out.write_fmt(format_args!(
"fn({}) {{ {} }}",
s.parameters.iter().map(|x| x.to_string()).join(", "),
s.body.to_string()
))
.unwrap();
out
}
} }
} }
} }
@ -64,13 +91,23 @@ impl Display for Object {
Object::Boolean(_) => "BOOLEAN", Object::Boolean(_) => "BOOLEAN",
Object::ReturnValue(_) => "RETURN_VALUE", Object::ReturnValue(_) => "RETURN_VALUE",
Object::Error(_) => "ERROR", Object::Error(_) => "ERROR",
Object::Function(_) => "FUNCTION",
Object::Null => "NULL", Object::Null => "NULL",
}) })
} }
} }
#[derive(Debug, Clone, PartialEq)]
pub struct Function {
parameters: Vec<Identifier>,
body: BlockStatement,
env: Environment,
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::assert_matches::assert_matches;
use crate::{ use crate::{
evaluator::{tree_walker::TreeWalker, Environment, Evaluator, Object, FALSE, NULL, TRUE}, evaluator::{tree_walker::TreeWalker, Environment, Evaluator, Object, FALSE, NULL, TRUE},
lexer::Lexer, lexer::Lexer,
@ -257,4 +294,56 @@ mod tests {
run_test_cases(&test_cases); run_test_cases(&test_cases);
} }
#[test]
fn test_function_object() {
let test_case = "fn(x) { x + 2;};";
let lexer = Lexer::new(&test_case);
let mut parser = Parser::new(lexer);
let program = parser.parse_program();
assert!(program.is_some());
let program = program.unwrap();
let evaluator = TreeWalker::new();
let mut env = Environment::new();
let eval = evaluator.eval(Node::Program(program), &mut env);
let node = eval.unwrap();
assert_matches!(node, Object::Function(_));
if let Object::Function(ref f) = node {
assert_eq!(f.parameters.len(), 1);
assert_eq!(f.parameters.first().unwrap().to_string(), "x");
assert_eq!(f.body.to_string(), "(x + 2)");
}
}
#[test]
fn test_function_application() {
let test_cases = [
(
"let identity = fn(x) { x; }; identity(5);",
Some(Object::Integer(5)),
),
(
"let identity = fn(x) { return x; }; identity(9);",
Some(Object::Integer(9)),
),
(
"let double = fn(x) { return x * 2; }; double(8);",
Some(Object::Integer(16)),
),
(
"let add = fn(x,y) { return x +y; }; add(10, 20);",
Some(Object::Integer(30)),
),
(
"let add = fn(x,y) { return x +y; }; add(20 + 25, add(3+1,2)));",
Some(Object::Integer(51)),
),
("fn(x) { x; }(5)", Some(Object::Integer(5))),
];
run_test_cases(&test_cases);
}
} }

View File

@ -10,7 +10,7 @@ use crate::{
}, },
}; };
use super::Environment; use super::{Environment, Function};
pub struct TreeWalker; pub struct TreeWalker;
@ -34,12 +34,10 @@ impl Evaluator for TreeWalker {
Some(Object::ReturnValue(Box::new(ret_val))) Some(Object::ReturnValue(Box::new(ret_val)))
} }
Statement::Let(LetStatement { name, value }) => { Statement::Let(LetStatement { name, value }) => {
let value = self.eval(Node::Expression(value.unwrap()), env)?; let value = self.eval(Node::Expression(value?), env)?;
env.set(name.to_string(), value.clone()); env.set(name.to_string(), value.clone());
Some(value) Some(value)
} }
_ => None,
}, },
Node::Expression(expr) => match expr { Node::Expression(expr) => match expr {
Expression::Identifier(v) => self.eval_identifier(v, env), Expression::Identifier(v) => self.eval_identifier(v, env),
@ -54,6 +52,24 @@ impl Evaluator for TreeWalker {
let right = self.eval(Node::Expression(*ie.right), env)?; let right = self.eval(Node::Expression(*ie.right), env)?;
self.eval_infix_expression(left, ie.operator, right) self.eval_infix_expression(left, ie.operator, right)
} }
Expression::FunctionExpression(fnl) => {
return Some(Object::Function(Function {
body: fnl.body,
parameters: fnl.parameters,
env: env.clone(),
}))
}
Expression::CallExpression(v) => {
let function = self.eval(Node::Expression(*v.function), env)?;
// Resolve function arguments and update the environment
// before executing function body
let args = match self.eval_expression(v.arguments, env) {
Ok(v) => v,
Err(e) => return Some(e),
};
self.apply_function(function, args)
}
Expression::IfExpression(ie) => { Expression::IfExpression(ie) => {
let condition = self.eval(Node::Expression(*ie.condition), env)?; let condition = self.eval(Node::Expression(*ie.condition), env)?;
@ -199,4 +215,54 @@ impl TreeWalker {
node.to_string() node.to_string()
)))) ))))
} }
fn eval_expression(
&self,
exprs: Vec<Expression>,
env: &mut Environment,
) -> Result<Vec<Object>, Object> {
let mut out = vec![];
for expr in exprs {
match self.eval(Node::Expression(expr), env) {
Some(v @ Object::Error(_)) => return Err(v),
Some(v) => out.push(v),
None => {
break;
}
}
}
Ok(out)
}
fn apply_function(&self, function: Object, args: Vec<Object>) -> Option<Object> {
let function = match function {
Object::Function(f) => f,
v => return Some(Object::Error(format!("not a function: {}", v.to_string()))),
};
let mut enclosed_env = Environment::new_enclosed(function.env);
for (i, parameter) in function.parameters.iter().enumerate() {
if args.len() <= i {
println!("{:?} {}", args, i);
return Some(Object::Error(format!("incorrect number of arguments")));
}
enclosed_env.set(parameter.value.clone(), args[i].clone());
}
let resp = self.eval(
Node::Statement(Statement::BlockStatement(function.body)),
&mut enclosed_env,
);
// Unwrap return here to prevent it from bubbling up the stack
// and stopping execution elsewhere.
if let Some(Object::ReturnValue(v)) = resp.as_ref() {
return Some(*v.clone());
}
resp
}
} }

View File

@ -1,3 +1,5 @@
#![feature(assert_matches)]
#[macro_use] #[macro_use]
extern crate lazy_static; extern crate lazy_static;

View File

@ -34,7 +34,7 @@ impl Display for Program {
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub enum Statement { pub enum Statement {
Let(LetStatement), Let(LetStatement),
Return(ReturnStatement), Return(ReturnStatement),
@ -65,7 +65,7 @@ impl Display for Statement {
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct LetStatement { pub struct LetStatement {
// name field is to store the identifier of the binding // name field is to store the identifier of the binding
pub name: Identifier, pub name: Identifier,
@ -113,7 +113,7 @@ impl Display for LetStatement {
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct ReturnStatement { pub struct ReturnStatement {
pub value: Option<Expression>, pub value: Option<Expression>,
} }
@ -145,7 +145,7 @@ impl Display for ReturnStatement {
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct ExpressionStatement { pub struct ExpressionStatement {
token: Token, token: Token,
pub expression: Expression, pub expression: Expression,
@ -184,7 +184,7 @@ pub enum ExpressionPriority {
Call = 6, Call = 6,
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub enum Expression { pub enum Expression {
Identifier(Identifier), Identifier(Identifier),
IntegerLiteral(IntegerLiteral), IntegerLiteral(IntegerLiteral),
@ -263,10 +263,10 @@ impl From<&Expression> for String {
// Identifier will be an expression // Identifier will be an expression
// Identifier in a let statement like, let x = 5; where `x` is an identifier doesn't produce a value // Identifier in a let statement like, let x = 5; where `x` is an identifier doesn't produce a value
// but an identifier *can* produce value when used on rhs, e.g. let x = y; Here `y` is producing a value // but an identifier *can* produce value when used on rhs, e.g. let x = y; Here `y` is producing a value
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct Identifier { pub struct Identifier {
token: TokenType, token: TokenType,
value: String, pub value: String,
} }
impl Identifier { impl Identifier {
@ -290,7 +290,7 @@ impl Display for Identifier {
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct IntegerLiteral { pub struct IntegerLiteral {
pub value: i64, pub value: i64,
} }
@ -313,7 +313,7 @@ impl IntegerLiteral {
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct PrefixExpression { pub struct PrefixExpression {
pub operator: TokenType, pub operator: TokenType,
pub right: Box<Expression>, pub right: Box<Expression>,
@ -347,7 +347,7 @@ impl Display for PrefixExpression {
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct InfixExpression { pub struct InfixExpression {
pub left: Box<Expression>, pub left: Box<Expression>,
pub operator: TokenType, pub operator: TokenType,
@ -383,7 +383,7 @@ impl Display for InfixExpression {
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct BooleanExpression { pub struct BooleanExpression {
token: TokenType, token: TokenType,
pub value: bool, pub value: bool,
@ -409,7 +409,7 @@ impl Display for BooleanExpression {
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct IfExpression { pub struct IfExpression {
pub condition: Box<Expression>, pub condition: Box<Expression>,
pub consequence: BlockStatement, pub consequence: BlockStatement,
@ -461,7 +461,7 @@ impl Display for IfExpression {
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct BlockStatement { pub struct BlockStatement {
pub statements: Vec<Statement>, pub statements: Vec<Statement>,
} }
@ -501,11 +501,11 @@ impl Display for BlockStatement {
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct FunctionLiteral { pub struct FunctionLiteral {
token: Token, token: Token,
parameters: Vec<Identifier>, pub parameters: Vec<Identifier>,
body: BlockStatement, pub body: BlockStatement,
} }
impl FunctionLiteral { impl FunctionLiteral {
@ -567,10 +567,10 @@ impl Display for FunctionLiteral {
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Eq, Clone)]
pub struct CallExpression { pub struct CallExpression {
function: Box<Expression>, pub function: Box<Expression>,
arguments: Vec<Expression>, pub arguments: Vec<Expression>,
} }
impl CallExpression { impl CallExpression {