diff --git a/Cargo.lock b/Cargo.lock index 3641f73..3b62c5c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -80,12 +80,102 @@ name = "distributed-systems-flyio" version = "0.1.0" dependencies = [ "dashmap", + "futures", "rand", "serde", "serde_json", "tokio", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "getrandom" version = "0.2.11" @@ -217,6 +307,12 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -338,6 +434,15 @@ dependencies = [ "libc", ] +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + [[package]] name = "smallvec" version = "1.13.2" diff --git a/Cargo.toml b/Cargo.toml index a6d8872..06fbc42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] dashmap = "5.5.3" +futures = "0.3.30" rand = { version = "0.8.5", features = ["small_rng"] } serde = { version = "1.0.193", features = ["serde_derive"] } serde_json = "1.0.109" diff --git a/src/lib.rs b/src/lib.rs index 2ba28aa..1e3bbc5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,59 +2,73 @@ #![feature(hash_set_entry)] #![feature(trait_alias)] -pub mod seq_kv; pub mod types; -use crate::{seq_kv::MonotonicCounter, types::Message}; -use rand::{rngs::SmallRng, Rng, SeedableRng}; +use crate::types::{Message, MessageBody, SeqKvInput}; +use futures::future::BoxFuture; +use rand::{rngs::SmallRng, SeedableRng}; use serde_json::Value; use std::{ collections::HashMap, - io::{stdin, stdout, BufRead, Error as IoError, Stdin, Stdout, Write}, + future::Future, + io::{stdin, stdout, BufRead, Error as IoError, Stdout, Write}, + pin::Pin, sync::{Arc, Mutex, RwLock}, - time::{SystemTime, UNIX_EPOCH}, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; +use tokio::{ + runtime::Handle, + sync::oneshot::{self, Receiver, Sender}, }; -use tokio::runtime::Runtime; -pub trait Handler = Fn(Message, Arc>, Arc>) -> Result<(), String> - + Send - + Sync - + 'static; +trait Callback = FnOnce(Message) -> Result<(), String> + Send + Sync + 'static; -pub struct Malestorm { +type Handler = Arc< + dyn Fn( + Message, + Arc>, + Arc>, + ) -> BoxFuture<'static, Result<(), String>> + + Send + + Sync, +>; + +pub struct Maelstorm { + mutex: Mutex<()>, pub node: Node, - pub handlers: HashMap>, + pub handlers: HashMap, + callbacks: HashMap>, } pub struct Node { - id: String, - nodes: Vec, - counter: MonotonicCounter, + pub id: String, + pub nodes: Vec, rng: SmallRng, } -pub struct MalestormIo { +pub struct MaelstormIo { stdout: Stdout, } -impl Default for Malestorm { +impl Default for Maelstorm { fn default() -> Self { let seed = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("time went backwards??"); Self { + mutex: Mutex::new(()), node: Node { id: String::new(), nodes: Vec::new(), - counter: MonotonicCounter::new(), rng: SmallRng::seed_from_u64(seed.as_secs()), }, handlers: HashMap::new(), + callbacks: HashMap::new(), } } } -impl MalestormIo { +impl MaelstormIo { fn write(&self, buf: &[u8]) -> Result<(), IoError> { let mut writer = self.stdout.lock(); writer.write_all(buf)?; @@ -63,14 +77,14 @@ impl MalestormIo { } } -impl Default for MalestormIo { +impl Default for MaelstormIo { fn default() -> Self { Self { stdout: stdout() } } } -impl Malestorm { - pub fn run(self, runtime: Runtime, io: MalestormIo) { +impl Maelstorm { + pub async fn run(self, io: MaelstormIo) { let program = Arc::new(RwLock::new(self)); let io = Arc::new(Mutex::new(io)); let stdin = stdin(); @@ -101,12 +115,29 @@ impl Malestorm { } }; - let mtype = message.body.message_type.clone(); + if let Some(reply_msg_id) = message.body.in_reply_to { + let mut program = program.write().unwrap(); - let handler: Arc = { + let callback = match program.callbacks.remove(&reply_msg_id) { + Some(v) => v, + None => { + eprintln!("no callback for msg with reply id {}", reply_msg_id); + continue; + } + }; + + if let Err(e) = callback(message) + /* add await*/ + { + eprintln!("error in callback: {e}"); + } + continue; + } + + let handler = { let program = program.read().unwrap(); - match program.handlers.get(&mtype) { + match program.handlers.get(&message.body.message_type) { Some(v) => v.clone(), None => { //eprintln!("no handler found for {}", message.body.message_type); @@ -117,42 +148,24 @@ impl Malestorm { let pc = program.clone(); let io = io.clone(); - runtime.spawn(async move { - if let Err(e) = handler(message, pc, io) { + tokio::spawn(async move { + if let Err(e) = handler(message, pc, io).await { eprintln!("error in serving request: {}", e); } }); } } - pub fn set_node_id(&mut self, node_id: String) { - self.node.id = node_id; - } - - pub fn set_nodes(&mut self, nodes: Vec) { + pub fn init(&mut self, node_id: String, nodes: Vec) { self.node.nodes = nodes; - } - - pub fn get_nodes(&self) -> Vec { - self.node.nodes.clone() - } - - pub fn read_counter(&self, src: &str) -> u64 { - self.node.counter.read(src) - } - - pub fn write_counter(&mut self, src: &str, v: u64) { - self.node.counter.write(src, v); - } - pub fn add_counter(&mut self, src: &str, v: u64) { - self.node.counter.add(src, v); + self.node.id = node_id; } pub fn sync( &self, dest: &str, prev_msg: Message, - io: Arc>, + io: Arc>, ) -> Result<(), IoError> { let msg = Message { id: Some(prev_msg.id.unwrap_or(0) + 1), @@ -167,32 +180,91 @@ impl Malestorm { }; let out = serde_json::to_vec(&msg)?; - eprintln!("wrote = {:?}", String::from_utf8_lossy(&out)); let io = io.lock().unwrap(); io.write(&out)?; Ok(()) } - pub fn generate_client_id(&mut self) -> String { - let s2: u64 = self.node.rng.gen(); - let s1 = &self.node.id; - format!("{}_{}", s1, s2) + pub fn register(&mut self, name: &str, func: F) + where + F: Fn(Message, Arc>, Arc>) -> Fut + + Send + + Sync + + 'static, + Fut: Future> + Send + 'static, + { + self.handlers.insert( + name.to_string(), + Arc::new(move |a, b, c| Box::pin(func(a, b, c))), + ); } - pub fn register(&mut self, name: &str, func: impl Handler) { - self.handlers.insert(name.to_string(), Arc::new(func)); + pub async fn sync_rpc( + &mut self, + io: Arc>, + msg: Message, + ) -> Result { + let (tx, rx): (Sender, Receiver) = oneshot::channel(); + + self.rpc( + io, + msg, + Box::new(|msg| { + tx.send(msg) + .map_err(|e| format!("error in sending to tx chan: {:?}", e))?; + Ok(()) + }), + ) + .map_err(|e| e.to_string())?; + + match tokio::time::timeout(Duration::from_secs(2), rx).await { + Ok(result) => match result { + Ok(v) => Ok(v), + Err(e) => { + eprintln!("sync callback error: {}", e); + Err(e.to_string()) + } + }, + Err(e) => { + eprintln!("sync callback timeout: {}", e); + Err(e.to_string()) + } + } } - pub fn send(&self, io: Arc>, mut msg: Message) -> Result<(), IoError> { + fn rpc( + &mut self, + io: Arc>, + mut msg: Message, + handler: impl Callback, + ) -> Result<(), IoError> { + let next_msg_id = msg.body.msg_id.unwrap() + 1; + self.callbacks.insert(next_msg_id, Box::new(handler)); + + msg.body.msg_id = Some(next_msg_id); + + self.send(io, msg.clone(), &msg.src) + } + + pub fn reply(&mut self, io: Arc>, mut msg: Message) -> Result<(), IoError> { + msg.body.in_reply_to = msg.body.msg_id; + msg.body.msg_id = None; + self.send(io, msg.clone(), &msg.src) + } + + fn send( + &self, + io: Arc>, + mut msg: Message, + dst: &str, + ) -> Result<(), IoError> { // Before replying, Swap src / dst in original message // Add the correct value for in_reply_to - std::mem::swap(&mut msg.src, &mut msg.dest); + msg.dest = dst.to_string(); msg.src = self.node.id.clone(); - msg.body.in_reply_to = msg.body.msg_id; - let out = serde_json::to_vec(&msg)?; let io = io.lock().unwrap(); @@ -204,9 +276,62 @@ impl Malestorm { } } -async fn parse_request_and_handle( - program: Arc>, - io: Arc>, - buf: &[u8], -) { -} +//impl Maelstorm { +// pub fn read_counter( +// &mut self, +// io: Arc>, +// store: &str, +// node_to_read: String, +// ) -> Result { +// let msg = Message { +// id: None, +// src: self.node.id.clone(), +// dest: store.to_string(), +// body: MessageBody { +// msg_id: Some(0), +// in_reply_to: None, +// message_type: "read".to_string(), +// message_body: serde_json::to_value(SeqKvInput { +// key: node_to_read, +// value: None, +// }) +// .unwrap(), +// }, +// }; +// let mut result = self.sync_rpc(io, msg)?; +// +// let body: SeqKvInput = serde_json::from_value(result.body.message_body.take()) +// .map_err(|e| format!("error in parsing response body: {}", e))?; +// +// Ok(body.value.unwrap()) +// } +// +// pub fn write_counter( +// &mut self, +// io: Arc>, +// store: &str, +// val: u64, +// ) -> Result<(), String> { +// let msg = Message { +// id: None, +// src: self.node.id.clone(), +// dest: store.to_string(), +// body: MessageBody { +// msg_id: Some(0), +// in_reply_to: None, +// message_type: "write".to_string(), +// message_body: serde_json::to_value(SeqKvInput { +// key: self.node.id.clone(), +// value: Some(val), +// }) +// .unwrap(), +// }, +// }; +// let mut result = self.sync_rpc(io, msg)?; +// +// let body: SeqKvInput = serde_json::from_value(result.body.message_body.take()) +// .map_err(|e| format!("error in parsing response body: {}", e))?; +// +// Ok(()) +// } +//} diff --git a/src/main.rs b/src/main.rs index ed7ed00..e26d49a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,108 +2,100 @@ use distributed_systems_flyio::{ types::{GrowCounterReadMessage, GrowCounterUpdateMessage, InitMessage, Message}, - Malestorm, MalestormIo, + Maelstorm, MaelstormIo, }; use std::sync::{Arc, Mutex, RwLock}; -use tokio::runtime::Runtime; -fn main() { - let mut program = Malestorm::default(); +async fn test( + msg: Message, + program: Arc>, + io: Arc>, +) -> Result<(), String> { + Ok(()) +} + +#[tokio::main] +async fn main() { + let mut program = Maelstorm::default(); program.register( "init", // TODO: Replace error string with a type - |mut msg: Message, - program: Arc>, - io: Arc>| - -> Result<(), String> { + async |mut msg: Message, + program: Arc>, + io: Arc>| + -> Result<(), String> { let message_body: InitMessage = serde_json::from_value(msg.body.message_body.take()) .map_err(|e| format!("error in parsing response body: {}", e))?; let mut program = program.write().unwrap(); - program.set_node_id(message_body.node_id); - program.set_nodes(message_body.nodes); + program.init(message_body.node_id, message_body.nodes); msg.body.message_type = "init_ok".into(); program - .send(io, msg) + .reply(io, msg) .map_err(|e| format!("init: error in writing response: {}", e)) }, ); - program.register( - "add", - |mut msg: Message, - program: Arc>, - io: Arc>| - -> Result<(), String> { - let body: GrowCounterUpdateMessage = - serde_json::from_value(msg.body.message_body.take()) - .map_err(|e| format!("error in parsing response body: {}", e))?; - msg.body.message_type = "add_ok".into(); + // program.register( + // "add", + // async |mut msg: Message, + // program: Arc>, + // io: Arc>| + // -> Result<(), String> { + // let body: GrowCounterUpdateMessage = + // serde_json::from_value(msg.body.message_body.take()) + // .map_err(|e| format!("error in parsing response body: {}", e))?; - let mut program = program.write().unwrap(); - program.add_counter(&msg.src, body.delta); + // msg.body.message_type = "add_ok".into(); - program - .send(io, msg.clone()) - .map_err(|e| format!("add: error in writing response: {}", e)) - }, - ); + // let mut program = program.write().unwrap(); - program.register( - "read", - |mut msg: Message, - program: Arc>, - io: Arc>| - -> Result<(), String> { - msg.body.message_type = "read_ok".into(); + // let id = program.node.id.clone(); - let program = program.read().unwrap(); - let mut sum = 0; - for node in program.get_nodes() { - if node == msg.src { - sum += program.read_counter(&node); - continue; - } - // Sync first then add + // // let current = program + // // .read_counter(io.clone(), "seq-kv", id) + // // .expect("error in reading value"); - program.sync(&node, msg.clone(), io.clone()); + // // program.write_counter(io.clone(), "seq-kv", current + body.delta)?; - sum += program.read_counter(&msg.src); - } + // program + // .reply(io, msg) + // .map_err(|e| format!("add: error in writing response: {}", e)) + // }, + // ); - msg.body.message_body = - serde_json::to_value(GrowCounterReadMessage { value: sum }).unwrap(); + // program.register( + // "read", + // async |mut msg: Message, + // program: Arc>, + // io: Arc>| + // -> Result<(), String> { + // msg.body.message_type = "read_ok".into(); - program - .send(io, msg.clone()) - .map_err(|e| format!("read: error in writing response: {}", e)) - }, - ); + // let mut program = program.write().unwrap(); + // let mut sum = 0; + // for node in program.node.nodes.clone() { + // if *node == msg.src { + // sum += program.read_counter(io.clone(), "seq-kv", node)?; + // continue; + // } + // // Sync first then add - program.register( - "counter_sync", - |mut msg: Message, - program: Arc>, - io: Arc>| - -> Result<(), String> { - msg.body.message_type = "counter_sync_ok".into(); + // sum += program.read_counter(io.clone(), "seq-kv", node)?; + // } - let program = program.read().unwrap(); - msg.body.message_body = serde_json::to_value(GrowCounterReadMessage { - value: program.read_counter(&msg.src), - }) - .unwrap(); + // msg.body.message_body = + // serde_json::to_value(GrowCounterReadMessage { value: sum }).unwrap(); - program - .send(io, msg.clone()) - .map_err(|e| format!("read: error in writing response: {}", e)) - }, - ); + // program + // .reply(io, msg.clone()) + // .map_err(|e| format!("read: error in writing response: {}", e)) + // }, + // ); - let io = MalestormIo::default(); - let runtime = Runtime::new().unwrap(); + let io = MaelstormIo::default(); - program.run(runtime, io); + program.run(io).await; } diff --git a/src/seq_kv.rs b/src/seq_kv.rs deleted file mode 100644 index 17a7182..0000000 --- a/src/seq_kv.rs +++ /dev/null @@ -1,28 +0,0 @@ -use std::collections::HashMap; - -#[derive(Debug)] -pub struct MonotonicCounter { - counter: HashMap, -} - -impl MonotonicCounter { - pub fn new() -> Self { - Self { - counter: HashMap::new(), - } - } - - pub fn read(&self, node: &str) -> u64 { - *self.counter.get(node).unwrap_or(&0) - } - - pub fn write(&mut self, node: &str, v: u64) { - *self.counter.entry(node.to_string()).or_insert(0) = v; - } - - pub fn add(&mut self, node: &str, v: u64) { - eprintln!("{:?}", self.counter); - - *self.counter.entry(node.to_string()).or_insert(0) += v; - } -} diff --git a/src/types.rs b/src/types.rs index f4d4d45..415f51b 100644 --- a/src/types.rs +++ b/src/types.rs @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Clone, Deserialize)] #[serde(deny_unknown_fields)] pub struct Message { + #[serde(skip_serializing_if = "Option::is_none")] pub id: Option, pub src: String, pub dest: String, @@ -11,7 +12,9 @@ pub struct Message { #[derive(Debug, Serialize, Clone, Deserialize)] pub struct MessageBody { + #[serde(skip_serializing_if = "Option::is_none")] pub msg_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub in_reply_to: Option, #[serde(rename = "type")] @@ -41,50 +44,10 @@ pub struct GrowCounterReadMessage { pub value: u64, } -#[cfg(test)] -mod test { - - use super::*; - - pub fn extraneous_fields_fail() { - let body = "{ - \"src\": \"c1\", - \"dest\": \"n1\", - \"body\": { - \"type\": \"echo\", - \"msg_id\": 1, - \"echo\": \"Please echo 35\", - \"extraneous\": \"hi\" - } -}"; - - let resp: Result = serde_json::from_str(body); - - if resp.is_ok() { - eprintln!("successfully parsed into Message {:#?}", resp); - } - - assert!(resp.is_err()); - } - - #[test] - pub fn parse_echo() { - let body = "{ - \"src\": \"c1\", - \"dest\": \"n1\", - \"body\": { - \"type\": \"echo\", - \"msg_id\": 1, - \"echo\": \"Please echo 35\" - } -}"; - - let resp: Result = serde_json::from_str(body); - - if resp.is_err() { - eprintln!("failed to parsed into Echo Message {:#?}", resp); - } - - assert!(resp.is_ok()); - } +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct SeqKvInput { + pub key: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, }