use std::mem; use std::ops::{BitAnd, Deref}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::{sync_channel, SyncSender}; use std::sync::{Arc, Mutex}; use dyn_clone::{clone_box, DynClone}; use hashbrown::HashMap; use orchid_api_traits::{Coding, Decode, Encode, MsgSet, Request}; use trait_set::trait_set; trait_set! { pub trait SendFn = for<'a> FnMut(&'a [u8], ReqNot) + DynClone + Send + 'static; pub trait ReqFn = FnMut(RequestHandle) + Send + 'static; pub trait NotifFn = for<'a> FnMut(T::InNot, ReqNot) + Send + Sync + 'static; } fn get_id(message: &[u8]) -> (u64, &[u8]) { (u64::from_be_bytes(message[..8].to_vec().try_into().unwrap()), &message[8..]) } pub struct RequestHandle { id: u64, message: T::InReq, send: Box>, parent: ReqNot, fulfilled: AtomicBool, } impl RequestHandle { pub fn reqnot(&self) -> ReqNot { self.parent.clone() } pub fn req(&self) -> &MS::InReq { &self.message } fn respond(&self, response: &impl Encode) { assert!(!self.fulfilled.swap(true, Ordering::Relaxed), "Already responded"); let mut buf = (!self.id).to_be_bytes().to_vec(); response.encode(&mut buf); clone_box(&*self.send)(&buf, self.parent.clone()); } pub fn handle(&self, _: &T, rep: &T::Response) { self.respond(rep) } } impl Drop for RequestHandle { fn drop(&mut self) { debug_assert!(self.fulfilled.load(Ordering::Relaxed), "Request dropped without response") } } pub fn respond_with(r: &R, f: impl FnOnce(&R) -> R::Response) -> Vec { r.respond(f(r)) } pub struct ReqNotData { id: u64, send: Box>, notif: Box>, req: Box>, responses: HashMap>>, } pub struct RawReply(Vec); impl Deref for RawReply { type Target = [u8]; fn deref(&self) -> &Self::Target { get_id(&self.0[..]).1 } } pub struct ReqNot(Arc>>); impl ReqNot { pub fn new(send: impl SendFn, notif: impl NotifFn, req: impl ReqFn) -> Self { Self(Arc::new(Mutex::new(ReqNotData { id: 1, send: Box::new(send), notif: Box::new(notif), req: Box::new(req), responses: HashMap::new(), }))) } /// Can be called from a polling thread or dispatched in any other way pub fn receive(&self, message: Vec) { let mut g = self.0.lock().unwrap(); let (id, payload) = get_id(&message[..]); if id == 0 { (g.notif)(T::InNot::decode(&mut &payload[..]), self.clone()) } else if 0 < id.bitand(1 << 63) { let sender = g.responses.remove(&!id).expect("Received response for invalid message"); sender.send(message).unwrap(); } else { let send = clone_box(&*g.send); let message = T::InReq::decode(&mut &payload[..]); (g.req)(RequestHandle { id, message, send, fulfilled: false.into(), parent: self.clone() }) } } pub fn notify>(&self, notif: N) { let mut send = clone_box(&*self.0.lock().unwrap().send); let mut buf = vec![0; 8]; let msg: T::OutNot = notif.into(); msg.encode(&mut buf); send(&buf, self.clone()) } } pub struct MappedRequester<'a, T>(Box RawReply + Send + Sync + 'a>); impl<'a, T> MappedRequester<'a, T> { fn new(req: U) -> Self where T: Into { MappedRequester(Box::new(move |t| req.raw_request(t.into()))) } } impl<'a, T> DynRequester for MappedRequester<'a, T> { type Transfer = T; fn raw_request(&self, data: Self::Transfer) -> RawReply { self.0(data) } } impl DynRequester for ReqNot { type Transfer = T::OutReq; fn raw_request(&self, req: Self::Transfer) -> RawReply { let mut g = self.0.lock().unwrap(); let id = g.id; g.id += 1; let mut buf = id.to_be_bytes().to_vec(); req.encode(&mut buf); let (send, recv) = sync_channel(1); g.responses.insert(id, send); let mut send = clone_box(&*g.send); mem::drop(g); send(&buf, self.clone()); RawReply(recv.recv().unwrap()) } } pub trait DynRequester: Send + Sync { type Transfer; fn raw_request(&self, data: Self::Transfer) -> RawReply; } pub trait Requester: DynRequester { #[must_use = "These types are subject to change with protocol versions. \ If you don't want to use the return value, At a minimum, force the type."] fn request>(&self, data: R) -> R::Response; fn map<'a, U: Into>(self) -> MappedRequester<'a, U> where Self: Sized + 'a { MappedRequester::new(self) } } impl<'a, This: DynRequester + ?Sized + 'a> Requester for This { fn request>(&self, data: R) -> R::Response { R::Response::decode(&mut &self.raw_request(data.into())[..]) } } impl Clone for ReqNot { fn clone(&self) -> Self { Self(self.0.clone()) } } #[cfg(test)] mod test { use std::sync::{Arc, Mutex}; use orchid_api_derive::Coding; use orchid_api_traits::Request; use super::{MsgSet, ReqNot}; use crate::{clone, reqnot::Requester as _}; #[derive(Coding, Debug, PartialEq)] pub struct TestReq(u8); impl Request for TestReq { type Response = u8; } pub struct TestMsgSet; impl MsgSet for TestMsgSet { type InNot = u8; type InReq = TestReq; type OutNot = u8; type OutReq = TestReq; } #[test] fn notification() { let received = Arc::new(Mutex::new(None)); let receiver = ReqNot::::new( |_, _| panic!("Should not send anything"), clone!(received; move |notif, _| *received.lock().unwrap() = Some(notif)), |_| panic!("Not receiving a request"), ); let sender = ReqNot::::new( clone!(receiver; move |d, _| receiver.receive(d.to_vec())), |_, _| panic!("Should not receive notif"), |_| panic!("Should not receive request"), ); sender.notify(3); assert_eq!(*received.lock().unwrap(), Some(3)); sender.notify(4); assert_eq!(*received.lock().unwrap(), Some(4)); } #[test] fn request() { let receiver = Arc::new(Mutex::>>::new(None)); let sender = Arc::new(ReqNot::::new( { let receiver = receiver.clone(); move |d, _| receiver.lock().unwrap().as_ref().unwrap().receive(d.to_vec()) }, |_, _| panic!("Should not receive notif"), |_| panic!("Should not receive request"), )); *receiver.lock().unwrap() = Some(ReqNot::new( { let sender = sender.clone(); move |d, _| sender.receive(d.to_vec()) }, |_, _| panic!("Not receiving notifs"), |req| { assert_eq!(req.req(), &TestReq(5)); req.respond(&6u8); }, )); let response = sender.request(TestReq(5)); assert_eq!(response, 6); } }