use std::any::Any; use std::cell::RefCell; use std::future::Future; use std::marker::PhantomData; use std::mem; use std::ops::{BitAnd, Deref}; use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use derive_destructure::destructure; use dyn_clone::{DynClone, clone_box}; use futures::channel::mpsc; use futures::future::LocalBoxFuture; use futures::lock::Mutex; use futures::{SinkExt, StreamExt}; use hashbrown::HashMap; use orchid_api_traits::{Channel, Coding, Decode, Encode, MsgSet, Request}; use trait_set::trait_set; use crate::clone; use crate::logging::Logger; pub struct Receipt<'a>(PhantomData<&'a mut ()>); trait_set! { pub trait SendFn = for<'a> FnMut(&'a [u8], ReqNot) -> LocalBoxFuture<'a, ()> + DynClone + 'static; pub trait ReqFn = for<'a> FnMut(RequestHandle<'a, T>, ::Req) -> LocalBoxFuture<'a, Receipt<'a>> + DynClone + 'static; pub trait NotifFn = FnMut(::Notif, ReqNot) -> LocalBoxFuture<'static, ()> + DynClone + 'static; } fn get_id(message: &[u8]) -> (u64, &[u8]) { (u64::from_be_bytes(message[..8].to_vec().try_into().unwrap()), &message[8..]) } pub trait ReqHandlish { fn defer_drop(&self, val: impl Any + 'static) where Self: Sized { self.defer_drop_objsafe(Box::new(val)); } fn defer_drop_objsafe(&self, val: Box); } impl ReqHandlish for &'_ dyn ReqHandlish { fn defer_drop_objsafe(&self, val: Box) { (**self).defer_drop_objsafe(val) } } #[derive(destructure)] pub struct RequestHandle<'a, MS: MsgSet> { defer_drop: RefCell>>, fulfilled: AtomicBool, id: u64, _reqlt: PhantomData<&'a mut ()>, parent: ReqNot, } impl<'a, MS: MsgSet + 'static> RequestHandle<'a, MS> { fn new(parent: ReqNot, id: u64) -> Self { Self { defer_drop: RefCell::default(), fulfilled: false.into(), _reqlt: PhantomData, parent, id, } } pub fn reqnot(&self) -> ReqNot { self.parent.clone() } pub async fn handle(&self, _: &U, rep: &U::Response) -> Receipt<'a> { self.respond(rep).await } pub fn will_handle_as(&self, _: &U) -> ReqTypToken { ReqTypToken(PhantomData) } pub async fn handle_as(&self, _: ReqTypToken, rep: &U::Response) -> Receipt<'a> { self.respond(rep).await } pub async fn respond(&self, response: &impl Encode) -> Receipt<'a> { assert!(!self.fulfilled.swap(true, Ordering::Relaxed), "Already responded to {}", self.id); let mut buf = (!self.id).to_be_bytes().to_vec(); response.encode(Pin::new(&mut buf)).await; let mut send = clone_box(&*self.reqnot().0.lock().await.send); (send)(&buf, self.parent.clone()).await; Receipt(PhantomData) } } impl ReqHandlish for RequestHandle<'_, MS> { fn defer_drop_objsafe(&self, val: Box) { self.defer_drop.borrow_mut().push(val); } } impl Drop for RequestHandle<'_, MS> { fn drop(&mut self) { let done = self.fulfilled.load(Ordering::Relaxed); debug_assert!(done, "Request {} dropped without response", self.id) } } pub struct ReqTypToken(PhantomData); pub struct ReqNotData { id: u64, send: Box>, notif: Box>, req: Box>, responses: HashMap>>, } /// Wraps a raw message buffer to save on copying. /// Dereferences to the tail of the message buffer, cutting off the ID #[derive(Debug, Clone)] pub struct RawReply(Vec); impl Deref for RawReply { type Target = [u8]; fn deref(&self) -> &Self::Target { get_id(&self.0[..]).1 } } pub struct ReqNot(Arc>>, Logger); impl ReqNot { pub fn new( logger: Logger, send: impl SendFn, notif: impl NotifFn, req: impl ReqFn, ) -> Self { Self( Arc::new(Mutex::new(ReqNotData { id: 1, send: Box::new(send), notif: Box::new(notif), req: Box::new(req), responses: HashMap::new(), })), logger, ) } /// Can be called from a polling thread or dispatched in any other way pub async fn receive(&self, message: &[u8]) { let mut g = self.0.lock().await; let (id, payload) = get_id(message); if id == 0 { let mut notif_cb = clone_box(&*g.notif); mem::drop(g); let notif_val = ::Notif::decode(Pin::new(&mut &payload[..])).await; notif_cb(notif_val, self.clone()).await } else if 0 < id.bitand(1 << 63) { let mut sender = g.responses.remove(&!id).expect("Received response for invalid message"); sender.send(message.to_vec()).await.unwrap() } else { let message = ::Req::decode(Pin::new(&mut &payload[..])).await; let mut req_cb = clone_box(&*g.req); mem::drop(g); let rn = self.clone(); req_cb(RequestHandle::new(rn, id), message).await; } } pub async fn notify::Notif>>(&self, notif: N) { let mut send = clone_box(&*self.0.lock().await.send); let mut buf = vec![0; 8]; let msg: ::Notif = notif.into(); msg.encode(Pin::new(&mut buf)).await; send(&buf, self.clone()).await } } pub trait DynRequester { type Transfer; fn logger(&self) -> &Logger; /// Encode and send a request, then receive the response buffer. fn raw_request(&self, data: Self::Transfer) -> LocalBoxFuture<'_, RawReply>; } pub struct MappedRequester<'a, T: 'a>(Box LocalBoxFuture<'a, RawReply> + 'a>, Logger); impl<'a, T> MappedRequester<'a, T> { fn new U::Transfer + 'a>( req: U, cb: F, logger: Logger, ) -> Self { let req_arc = Arc::new(req); let cb_arc = Arc::new(cb); MappedRequester( Box::new(move |t| { Box::pin(clone!(req_arc, cb_arc; async move { req_arc.raw_request(cb_arc(t)).await})) }), logger, ) } } impl DynRequester for MappedRequester<'_, T> { type Transfer = T; fn logger(&self) -> &Logger { &self.1 } fn raw_request(&self, data: Self::Transfer) -> LocalBoxFuture<'_, RawReply> { self.0(data) } } impl DynRequester for ReqNot { type Transfer = ::Req; fn logger(&self) -> &Logger { &self.1 } fn raw_request(&self, req: Self::Transfer) -> LocalBoxFuture<'_, RawReply> { Box::pin(async move { let mut g = self.0.lock().await; let id = g.id; g.id += 1; let mut buf = id.to_be_bytes().to_vec(); req.encode(Pin::new(&mut buf)).await; let (send, mut recv) = mpsc::channel(1); g.responses.insert(id, send); let mut send = clone_box(&*g.send); mem::drop(g); let rn = self.clone(); send(&buf, rn).await; let items = recv.next().await; RawReply(items.unwrap()) }) } } pub trait Requester: DynRequester { #[must_use = "These types are subject to change with protocol versions. \ If you don't want to use the return value, At a minimum, force the type."] fn request>( &self, data: R, ) -> impl Future; fn map<'a, U>(self, cb: impl Fn(U) -> Self::Transfer + 'a) -> MappedRequester<'a, U> where Self: Sized + 'a { let logger = self.logger().clone(); MappedRequester::new(self, cb, logger) } } impl Requester for This { async fn request>(&self, data: R) -> R::Response { let req = format!("{data:?}"); let rep = R::Response::decode(Pin::new(&mut &self.raw_request(data.into()).await[..])).await; writeln!(self.logger(), "Request {req} got response {rep:?}"); rep } } impl Clone for ReqNot { fn clone(&self) -> Self { Self(self.0.clone(), self.1.clone()) } } #[cfg(test)] mod test { use std::rc::Rc; use std::sync::Arc; use futures::FutureExt; use futures::lock::Mutex; use orchid_api_derive::Coding; use orchid_api_traits::{Channel, Request}; use test_executors::spin_on; use super::{MsgSet, ReqNot}; use crate::logging::Logger; use crate::reqnot::Requester as _; use crate::{api, clone}; #[derive(Clone, Debug, Coding, PartialEq)] pub struct TestReq(u8); impl Request for TestReq { type Response = u8; } pub struct TestChan; impl Channel for TestChan { type Notif = u8; type Req = TestReq; } pub struct TestMsgSet; impl MsgSet for TestMsgSet { type In = TestChan; type Out = TestChan; } #[test] fn notification() { spin_on(async { let logger = Logger::new(api::LogStrategy::StdErr); let received = Arc::new(Mutex::new(None)); let receiver = ReqNot::::new( logger.clone(), |_, _| panic!("Should not send anything"), clone!(received; move |notif, _| clone!(received; async move { *received.lock().await = Some(notif); }.boxed_local())), |_, _| panic!("Not receiving a request"), ); let sender = ReqNot::::new( logger, clone!(receiver; move |d, _| clone!(receiver; Box::pin(async move { receiver.receive(d).await }))), |_, _| panic!("Should not receive notif"), |_, _| panic!("Should not receive request"), ); sender.notify(3).await; assert_eq!(*received.lock().await, Some(3)); sender.notify(4).await; assert_eq!(*received.lock().await, Some(4)); }) } #[test] fn request() { spin_on(async { let logger = Logger::new(api::LogStrategy::StdErr); let receiver = Rc::new(Mutex::>>::new(None)); let sender = Rc::new(ReqNot::::new( logger.clone(), clone!(receiver; move |d, _| clone!(receiver; Box::pin(async move { receiver.lock().await.as_ref().unwrap().receive(d).await }))), |_, _| panic!("Should not receive notif"), |_, _| panic!("Should not receive request"), )); *receiver.lock().await = Some(ReqNot::new( logger, clone!(sender; move |d, _| clone!(sender; Box::pin(async move { sender.receive(d).await }))), |_, _| panic!("Not receiving notifs"), |hand, req| { Box::pin(async move { assert_eq!(req, TestReq(5)); hand.respond(&6u8).await }) }, )); let response = sender.request(TestReq(5)).await; assert_eq!(response, 6); }) } }