diff --git a/src/lib.rs b/src/lib.rs index c896bd5..437cf9b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,16 +4,18 @@ pub mod types; -use crate::types::{Message, MessageBody, SeqKvInput}; +use crate::types::{Message, MessageBody}; use futures::future::BoxFuture; use rand::{rngs::SmallRng, Rng, SeedableRng}; use std::{ collections::HashMap, future::Future, io::{stdin, stdout, BufRead, Error as IoError, Stdout, Write}, + ops::{Deref, DerefMut}, sync::Arc, - time::{Duration, Instant, SystemTime, UNIX_EPOCH}, + time::{Duration, SystemTime, UNIX_EPOCH}, }; +use types::NodeType; use tokio::sync::{ oneshot::{self, Receiver, Sender}, @@ -27,6 +29,7 @@ type Handler = Arc< Message, Arc>, Arc>, + Arc>, ) -> BoxFuture<'static, Result<(), String>> + Send + Sync, @@ -35,13 +38,11 @@ type Handler = Arc< pub struct Maelstorm { pub node: Node, pub handlers: HashMap, - callbacks: HashMap, } pub struct Node { pub id: String, pub counter: u64, - pub other_counter: u64, pub nodes: Vec, rng: SmallRng, } @@ -49,6 +50,27 @@ pub struct MaelstormIo { stdout: Stdout, } +pub struct MaelstormCallbacks(HashMap); + +impl Default for MaelstormCallbacks { + fn default() -> Self { + Self(HashMap::new()) + } +} + +impl Deref for MaelstormCallbacks { + type Target = HashMap; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for MaelstormCallbacks { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + impl Default for Maelstorm { fn default() -> Self { let seed = SystemTime::now() @@ -58,12 +80,10 @@ impl Default for Maelstorm { node: Node { id: String::new(), counter: 0, - other_counter: 0, nodes: Vec::new(), rng: SmallRng::seed_from_u64(seed.as_nanos() as u64), }, handlers: HashMap::new(), - callbacks: HashMap::new(), } } } @@ -84,8 +104,9 @@ impl Default for MaelstormIo { } impl Maelstorm { - pub async fn run(self, io: MaelstormIo) { + pub async fn run(self, io: MaelstormIo, callbacks: MaelstormCallbacks) { let program = Arc::new(RwLock::new(self)); + let callbacks = Arc::new(RwLock::new(callbacks)); let io = Arc::new(Mutex::new(io)); let stdin = stdin(); @@ -115,13 +136,9 @@ impl Maelstorm { } }; - eprintln!("INCOMING MESSAGE = {:?}", message); - if let Some(reply_msg_id) = message.body.in_reply_to { - eprintln!("REPLY ID FOR CALLBACK FOUND = {:?}", message); - let mut program = program.write().await; - - let callback = match program.callbacks.remove(&reply_msg_id) { + let mut callbacks = callbacks.write().await; + let callback = match callbacks.remove(&reply_msg_id) { Some(v) => v, None => { eprintln!("no callback for msg with reply id {}", reply_msg_id); @@ -131,8 +148,6 @@ impl Maelstorm { if let Err(e) = callback(message.clone()).await { eprintln!("error in callback: {e}"); - } else { - eprintln!("SUCCESSFULLY CALLED CALLBACK FOR = {:?}", message); } continue; @@ -152,8 +167,9 @@ impl Maelstorm { let pc = program.clone(); let io = io.clone(); + let c = callbacks.clone(); tokio::spawn(async move { - if let Err(e) = handler(message, pc, io).await { + if let Err(e) = handler(message, pc, io, c).await { eprintln!("error in serving request: {}", e); } }); @@ -167,7 +183,12 @@ impl Maelstorm { pub fn register(&mut self, name: &str, func: F) where - F: Fn(Message, Arc>, Arc>) -> Fut + F: Fn( + Message, + Arc>, + Arc>, + Arc>, + ) -> Fut + Send + Sync + 'static, @@ -175,55 +196,48 @@ impl Maelstorm { { self.handlers.insert( name.to_string(), - Arc::new(move |a, b, c| Box::pin(func(a, b, c))), + Arc::new(move |a, b, c, d| Box::pin(func(a, b, c, d))), ); } pub async fn sync_rpc( - &mut self, + &self, io: Arc>, + callbacks: Arc>, msg: Message, ) -> Result { let (tx, rx): (Sender, Receiver) = oneshot::channel(); let m = msg.clone(); - self.rpc(io, msg, async |msg: Message| match tx.send(msg.clone()) { - Ok(v) => { - eprintln!("SENT RESPONSE INTO CHANNEL = {:?}", msg); - - Ok(()) - } - Err(e) => { - format!("error in sending to tx chan: {:?}", e); - Ok(()) - } - }) + self.rpc( + io.clone(), + callbacks.clone(), + msg, + async |msg: Message| match tx.send(msg.clone()) { + Ok(_) => Ok(()), + Err(e) => { + format!("error in sending to tx chan: {:?}", e); + Ok(()) + } + }, + ) .await .map_err(|e| e.to_string())?; - let t = Instant::now(); match tokio::time::timeout(Duration::from_secs(1), rx).await { - Ok(result) => { - eprintln!( - "GOT RESULT BEFORE TIMEOUT = {:?} elapsed = {:?}", - result, - t.elapsed() - ); - - match result { - Ok(v) => Ok(v), - Err(e) => { - eprintln!("sync callback error: {}", e); - Err(e.to_string()) - } + Ok(result) => match result { + Ok(v) => Ok(v), + Err(e) => { + eprintln!("sync callback error: {}", e); + Err(e.to_string()) } - } + }, Err(e) => { eprintln!( "sync callback timeout: {} msg = {:?} callback queue {:?}", e, m, - self.callbacks.keys() + callbacks.read().await.keys() ); Err(e.to_string()) } @@ -231,8 +245,9 @@ impl Maelstorm { } async fn rpc( - &mut self, + &self, io: Arc>, + callbacks: Arc>, mut msg: Message, handler: F, ) -> Result<(), IoError> @@ -242,9 +257,8 @@ impl Maelstorm { { let next_msg_id = msg.body.msg_id.unwrap(); - eprintln!("INSERTING {} {:?}", next_msg_id, msg); - self.callbacks - .insert(next_msg_id, Box::new(move |a| Box::pin(handler(a)))); + let mut callbacks = callbacks.write().await; + callbacks.insert(next_msg_id, Box::new(move |a| Box::pin(handler(a)))); msg.body.msg_id = Some(next_msg_id); @@ -279,7 +293,7 @@ impl Maelstorm { io.write(&out).map(|_| Ok(()))? } - pub fn id(&self) -> String { + pub async fn id(&self) -> String { self.node.id.clone() } } @@ -288,10 +302,14 @@ impl Maelstorm { pub async fn read_counter( &mut self, io: Arc>, + callbacks: Arc>, node_to_read: String, ) -> Result { let msg_id: u64 = self.node.rng.gen(); + let mut value = HashMap::new(); + value.insert("key".to_string(), NodeType::String(node_to_read)); + let msg = Message { id: None, src: self.node.id.clone(), @@ -300,31 +318,31 @@ impl Maelstorm { msg_id: Some(msg_id), in_reply_to: None, message_type: "read".to_string(), - message_body: serde_json::to_value(SeqKvInput { - key: node_to_read, - value: None, - }) - .unwrap(), + message_body: value, }, }; - let mut result = match self.sync_rpc(io, msg).await { + let result = match self.sync_rpc(io, callbacks, msg).await { Ok(v) => v, Err(e) if e.to_string() == "deadline has elapsed" => return Ok(0), Err(e) => return Err(e), }; - let body: SeqKvInput = serde_json::from_value(result.body.message_body.take()) - .map_err(|e| format!("error in parsing response body: {}", e))?; - - Ok(body.value.unwrap()) + match result.body.message_body.get("value") { + None => Ok(0), + Some(v) => Ok(v.u64()), + } } pub async fn write_counter( &mut self, io: Arc>, + callbacks: Arc>, val: u64, ) -> Result<(), String> { let msg_id: u64 = self.node.rng.gen(); + let mut value = HashMap::new(); + value.insert("key".to_string(), NodeType::String(self.node.id.clone())); + value.insert("value".to_string(), NodeType::U64(val)); let msg = Message { id: None, @@ -334,23 +352,15 @@ impl Maelstorm { msg_id: Some(msg_id), in_reply_to: None, message_type: "write".to_string(), - message_body: serde_json::to_value(SeqKvInput { - key: self.node.id.clone(), - value: Some(val), - }) - .unwrap(), + message_body: value, }, }; - let mut result = match self.sync_rpc(io, msg).await { + let _result = match self.sync_rpc(io, callbacks, msg).await { Ok(v) => v, Err(e) if e.to_string() == "deadline has elapsed" => return Ok(()), Err(e) => return Err(e), }; - // TODO: could return the parsed confirmation from here - serde_json::from_value(result.body.message_body.take()) - .map_err(|e| format!("error in parsing response body: {}", e))?; - Ok(()) } } diff --git a/src/main.rs b/src/main.rs index 6dab5a5..151fee8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,10 @@ #![feature(async_closure)] use distributed_systems_flyio::{ - types::{GrowCounterReadMessage, GrowCounterUpdateMessage, InitMessage, Message}, - Maelstorm, MaelstormIo, + types::{Message, NodeType}, + Maelstorm, MaelstormCallbacks, MaelstormIo, }; -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use tokio::sync::{Mutex, RwLock}; #[tokio::main] @@ -16,13 +16,14 @@ async fn main() { // TODO: Replace error string with a type async |mut msg: Message, program: Arc>, - io: Arc>| + io: Arc>, + _: Arc>| -> Result<(), String> { - let message_body: InitMessage = serde_json::from_value(msg.body.message_body.take()) - .map_err(|e| format!("error in parsing response body: {}", e))?; - let mut program = program.write().await; - program.init(message_body.node_id, message_body.nodes); + + let node_id = msg.body.message_body.get("node_id").unwrap().string(); + let nodes = msg.body.message_body.get("node_ids").unwrap().vec_string(); + program.init(node_id, nodes); msg.body.message_type = "init_ok".into(); program @@ -36,31 +37,20 @@ async fn main() { "add", async |mut msg: Message, program: Arc>, - io: Arc>| + io: Arc>, + callbacks: Arc>| -> Result<(), String> { - let body: GrowCounterUpdateMessage = - serde_json::from_value(msg.body.message_body.take()) - .map_err(|e| format!("error in parsing response body: {}", e))?; - msg.body.message_type = "add_ok".into(); let mut program = program.write().await; - program.node.counter += body.delta; + program.node.counter += msg.body.message_body.get("delta").unwrap().u64(); let current = program.node.counter; - program.write_counter(io.clone(), current).await?; - - let mut other_counter = 0; - - for node in program.node.nodes.clone() { - if node == msg.dest { - continue; - } - let resp = program.read_counter(io.clone(), node).await?; - other_counter += resp; - } - program.node.other_counter = other_counter; + program + .write_counter(io.clone(), callbacks.clone(), current) + .await?; + msg.body.message_body = HashMap::new(); program .reply(io, msg) .await @@ -72,30 +62,32 @@ async fn main() { "read", async |mut msg: Message, program: Arc>, - io: Arc>| + io: Arc>, + callbacks: Arc>| -> Result<(), String> { msg.body.message_type = "read_ok".into(); - let mut program = program.write().await; let mut other_counter = 0; + let mut program = program.write().await; + for node in program.node.nodes.clone() { if node == msg.dest { continue; } - let resp = program.read_counter(io.clone(), node).await?; + let resp = program + .read_counter(io.clone(), callbacks.clone(), node) + .await?; other_counter += resp; } - program.node.other_counter = other_counter; - eprintln!( - "READ OP {} {}", - program.node.counter, program.node.other_counter + let mut resp = HashMap::new(); + resp.insert( + "value".to_string(), + NodeType::U64(program.node.counter + other_counter), ); - let resp = GrowCounterReadMessage { - value: program.node.counter + program.node.other_counter, - }; - msg.body.message_body = serde_json::to_value(resp).unwrap(); + + msg.body.message_body = resp; program .reply(io, msg) @@ -104,7 +96,8 @@ async fn main() { }, ); + let callbacks = MaelstormCallbacks::default(); let io = MaelstormIo::default(); - program.run(io).await; + program.run(io, callbacks).await; } diff --git a/src/types.rs b/src/types.rs index d0b002c..984e978 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Clone, Deserialize)] @@ -21,33 +23,39 @@ pub struct MessageBody { pub message_type: String, #[serde(flatten)] - pub message_body: serde_json::Value, + pub message_body: HashMap, } -#[derive(Debug, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct InitMessage { - pub node_id: String, - #[serde(rename = "node_ids")] - pub nodes: Vec, +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(untagged)] +pub enum NodeType { + String(String), + U64(u64), + VecString(Vec), } -#[derive(Debug, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct GrowCounterUpdateMessage { - pub delta: u64, -} +impl NodeType { + pub fn u64(&self) -> u64 { + if let NodeType::U64(v) = self { + return *v; + } else { + unreachable!() + } + } -#[derive(Debug, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct GrowCounterReadMessage { - pub value: u64, -} + pub fn string(&self) -> String { + if let NodeType::String(s) = self { + return s.clone(); + } else { + unreachable!() + } + } -#[derive(Debug, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct SeqKvInput { - pub key: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub value: Option, + pub fn vec_string(&self) -> Vec { + if let NodeType::VecString(s) = self { + return s.clone(); + } else { + unreachable!() + } + } }