use std::cell::RefCell;
use std::marker::PhantomData;
use std::pin::{Pin, pin};
use std::rc::Rc;
use std::{io, mem};

use async_fn_stream::try_stream;
use bound::Bound;
use derive_destructure::destructure;
use futures::channel::mpsc::{self, Receiver, Sender, channel};
use futures::channel::oneshot;
use futures::future::LocalBoxFuture;
use futures::lock::{Mutex, MutexGuard};
use futures::{
    AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, Stream, StreamExt, stream_select,
};
use hashbrown::HashMap;
use orchid_api_traits::{Decode, Encode, Request, UnderRoot};

use crate::localset::LocalSet;

// NOTE(review): in this copy of the file several generic argument lists look
// truncated (e.g. `self: Box` with no type argument, `io::Result + 'a>>>`,
// `type IoRef = Pin>;`), so it cannot compile as-is. Restore them from version
// control; the comments below describe the evident intent.

/// Token proving that a reply has been written, as returned by
/// [RepWriter::finish]. Request handlers return it to show they replied.
#[must_use = "Receipts indicate that a required action has been performed within a function. \
              Most likely this should be returned somewhere."]
pub struct Receipt<'a>(PhantomData<&'a mut ()>);
impl Receipt<'_> {
    /// Only call this function from a custom implementation of [RepWriter]
    pub fn _new() -> Self { Self(PhantomData) }
}

/// Write guard to outbound for the purpose of serializing a request. Only one
/// can exist at a time. Dropping this object should panic.
pub trait ReqWriter<'a> {
    /// Access to the underlying channel. This may be buffered.
    fn writer(&mut self) -> Pin<&mut dyn AsyncWrite>;
    /// Finalize the request, release the outbound channel, then queue for the
    /// reply on the inbound channel.
    fn send(self: Box) -> LocalBoxFuture<'a, io::Result + 'a>>>;
}

/// Write guard to inbound for the purpose of deserializing a reply. While held,
/// no inbound requests or other replies can be processed.
///
/// Dropping this object should panic even if [RepReader::finish] returns
/// synchronously, because the API isn't cancellation safe in general so it is a
/// programmer error in all cases to drop an object related to it without proper
/// cleanup.
pub trait RepReader<'a> {
    /// Access to the underlying channel. The length of the message is inferred
    /// from the number of bytes read so this must not be buffered.
    fn reader(&mut self) -> Pin<&mut dyn AsyncRead>;
    /// Finish reading the reply and release the inbound channel.
    fn finish(self: Box) -> LocalBoxFuture<'a, ()>;
}

/// Write guard to outbound for the purpose of serializing a notification.
///
/// Dropping this object should panic for the same reason [RepReader] panics.
pub trait MsgWriter<'a> {
    /// Access to the underlying channel. This may be buffered.
    fn writer(&mut self) -> Pin<&mut dyn AsyncWrite>;
    /// Send the notification.
    fn finish(self: Box) -> LocalBoxFuture<'a, io::Result<()>>;
}

/// For initiating outbound requests and notifications.
pub trait Client {
    /// Begin an outbound request; the returned writer receives the payload.
    fn start_request(&self) -> LocalBoxFuture<'_, io::Result + '_>>>;
    /// Begin an outbound notification; the returned writer receives the payload.
    fn start_notif(&self) -> LocalBoxFuture<'_, io::Result + '_>>>;
}

// Blanket impl: the typed helpers of [ClientExt] are available on every
// [Client] (the impl's generic parameter list appears truncated in this copy).
impl ClientExt for T {}

/// Extension trait with convenience methods that handle outbound request and
/// notif lifecycle and typing.
#[allow(async_fn_in_trait)]
pub trait ClientExt: Client {
    /// Encode `t` as a request, send it, and decode the typed response.
    ///
    /// The reply reader is always finished, even when decoding the response
    /// fails; the decode result is returned afterwards.
    async fn request>(&self, t: T) -> io::Result {
        let mut req = self.start_request().await?;
        t.into_root().encode(req.writer().as_mut()).await?;
        let mut rep = req.send().await?;
        let response = T::Response::decode(rep.reader()).await;
        // Release the inbound channel before reporting the decode outcome.
        rep.finish().await;
        response
    }
    /// Encode `t` and send it as a notification.
    async fn notify>(&self, t: T) -> io::Result<()> {
        let mut notif = self.start_notif().await?;
        t.into_root().encode(notif.writer().as_mut()).await?;
        notif.finish().await?;
        Ok(())
    }
}

/// Read guard for the payload of an inbound request.
pub trait ReqReader<'a> {
    /// Access to the inbound channel positioned at the request payload.
    fn reader(&mut self) -> Pin<&mut dyn AsyncRead>;
    /// Stop reading the request and obtain a handle for replying to it.
    fn finish(self: Box) -> LocalBoxFuture<'a, Box + 'a>>;
}

impl<'a, T: ReqReader<'a> + ?Sized> ReqReaderExt<'a> for T {}

/// Typed convenience methods for [ReqReader].
#[allow(async_fn_in_trait)]
pub trait ReqReaderExt<'a>: ReqReader<'a> {
    /// Decode the request payload.
    async fn read_req(&mut self) -> io::Result { R::decode(self.reader()).await }
    /// Finish reading, then send `rep` as the complete reply. `req` is only
    /// evidence of the request type; see [ReqHandleExt::reply].
    async fn reply(
        self: Box,
        req: impl Evidence,
        rep: &R::Response,
    ) -> io::Result> {
        self.finish().await.reply(req, rep).await
    }
    /// Finish reading, then start writing a reply by hand.
    async fn start_reply(self: Box) -> io::Result + 'a>> {
        self.finish().await.start_reply().await
    }
}

/// An inbound request whose payload has been consumed and which still owes a
/// reply.
pub trait ReqHandle<'a> {
    /// Begin writing the reply.
    fn start_reply(self: Box) -> LocalBoxFuture<'a, io::Result + 'a>>>;
}

impl<'a, T: ReqHandle<'a> + ?Sized> ReqHandleExt<'a> for T {}

/// Typed convenience methods for [ReqHandle].
#[allow(async_fn_in_trait)]
pub trait ReqHandleExt<'a>: ReqHandle<'a> {
    /// Encode `rep` as the complete reply and return its [Receipt].
    ///
    /// The [Evidence] argument is discarded; it only ties the reply's type to
    /// the request it answers.
    async fn reply(
        self: Box,
        _: impl Evidence,
        rep: &Req::Response,
    ) -> io::Result> {
        let mut reply = self.start_reply().await?;
        rep.encode(reply.writer()).await?;
        reply.finish().await
    }
}

/// Write guard for an outbound reply.
pub trait RepWriter<'a> {
    /// Access to the outbound channel.
    fn writer(&mut self) -> Pin<&mut dyn AsyncWrite>;
    /// Complete the reply and obtain the [Receipt] proving it was sent.
    fn finish(self: Box) -> LocalBoxFuture<'a, io::Result>>;
}

/// Read guard for the payload of an inbound notification.
pub trait MsgReader<'a> {
    /// Access to the inbound channel positioned at the notification payload.
    fn reader(&mut self) -> Pin<&mut dyn AsyncRead>;
    /// Stop reading the notification.
    fn finish(self: Box) -> LocalBoxFuture<'a, ()>;
}

impl<'a, T: ?Sized + MsgReader<'a>> MsgReaderExt<'a> for T {}

/// Typed convenience methods for [MsgReader].
#[allow(async_fn_in_trait)]
pub trait MsgReaderExt<'a>: MsgReader<'a> {
    /// Decode the notification payload and finish the reader. The reader is
    /// finished even if decoding fails; the decode result is returned afterwards.
    async fn read(mut self: Box) -> io::Result {
        let n = N::decode(self.reader()).await;
        self.finish().await;
        n
    }
}

/// A form of [Evidence] that doesn't require the value to be kept around.
pub struct Witness(PhantomData);
impl Witness {
    /// Record the type of the given value without keeping the value.
    pub fn of(_: &T) -> Self { Self(PhantomData) }
}
// Presumably implemented by hand rather than derived so the impls don't
// require the witnessed type itself to be `Copy`/`Clone`; confirm once the
// generic parameter lists are restored.
impl Copy for Witness {}
impl Clone for Witness {
    fn clone(&self) -> Self { *self }
}

/// A proxy for the type of a value either previously saved into a [Witness] or
/// still available.
pub trait Evidence {}
impl Evidence for &'_ T {}
impl Evidence for Witness {}

// Owned, pinned stream handle.
type IoRef = Pin>;
// Shared lock around a stream handle.
type IoLock = Rc>>>;
// Lock guard bundled (via `Bound`) with the `Rc` that owns the lock, so it can
// be held without borrowing from a local.
type IoGuard = Bound>>, IoLock>;

/// An incoming request. This holds a lock on the ingress channel.
// NOTE(review): several generic argument lists in this part of the file look
// truncated (e.g. `read: IoGuard` with no stream type, `self: Box` with no type
// argument); restore them from version control before building.
pub struct IoReqReader<'a> {
    // Header written at the start of the reply (the complement of the request id,
    // as built by the server loop).
    prefix: &'a [u8],
    // Held lock on the inbound stream; the request payload is read through it.
    read: IoGuard,
    // Outbound stream; only locked once the reply is started.
    write: &'a Mutex>,
}
impl<'a> ReqReader<'a> for IoReqReader<'a> {
    fn reader(&mut self) -> Pin<&mut dyn AsyncRead> { self.read.as_mut() }
    // Trades the reader for a reply handle. Only `prefix` and `write` are carried
    // over; the inbound guard in `read` is not, so it is released when `self` is
    // dropped.
    fn finish(self: Box) -> LocalBoxFuture<'a, Box + 'a>> {
        Box::pin(async {
            Box::new(IoReqHandle { prefix: self.prefix, write: self.write }) as Box>
        })
    }
}

/// A request that has been read but not yet answered.
pub struct IoReqHandle<'a> {
    // Reply header, see [IoReqReader].
    prefix: &'a [u8],
    write: &'a Mutex>,
}
impl<'a> ReqHandle<'a> for IoReqHandle<'a> {
    // Locks the outbound stream, writes the reply header, and hands the held
    // lock to the reply writer.
    fn start_reply(self: Box) -> LocalBoxFuture<'a, io::Result + 'a>>> {
        Box::pin(async move {
            let mut write = self.write.lock().await;
            write.as_mut().write_all(self.prefix).await?;
            Ok(Box::new(IoRepWriter { write }) as Box>)
        })
    }
}

/// Reply being written; holds the outbound lock until dropped.
pub struct IoRepWriter<'a> {
    write: MutexGuard<'a, IoRef>,
}
impl<'a> RepWriter<'a> for IoRepWriter<'a> {
    fn writer(&mut self) -> Pin<&mut dyn AsyncWrite> { self.write.as_mut() }
    // Flushes the reply; the outbound lock is released when `self` is dropped.
    fn finish(mut self: Box) -> LocalBoxFuture<'a, io::Result>> {
        Box::pin(async move {
            self.writer().flush().await?;
            Ok(Receipt(PhantomData))
        })
    }
}

/// Inbound notification being read; holds the inbound lock until dropped.
pub struct IoMsgReader<'a> {
    _pd: PhantomData<&'a mut ()>,
    read: IoGuard,
}
impl<'a> MsgReader<'a> for IoMsgReader<'a> {
    fn reader(&mut self) -> Pin<&mut dyn AsyncRead> { self.read.as_mut() }
    // Nothing to do: dropping `self` releases the inbound lock.
    fn finish(self: Box) -> LocalBoxFuture<'static, ()> { Box::pin(async {}) }
}

/// Registration of an outbound request with the server loop, sent by the
/// client before the request header is written.
#[derive(Debug)]
struct ReplySub {
    // Id of the request whose reply is awaited.
    id: u64,
    // Signalled by the server loop once the registration is recorded.
    ack: oneshot::Sender<()>,
    // Receives the inbound read guard when the reply arrives.
    cb: oneshot::Sender>,
}

/// [Client] implementation writing to a shared outbound stream.
struct IoClient {
    output: IoLock,
    // Last request id handed out. It is incremented before use, so ids start at
    // 1; id 0 is the notification header.
    id: Rc>,
    // Channel to the server loop for registering pending replies.
    subscribe: Rc>,
}
impl IoClient {
    // Returns the receiving end of the subscription channel (for the server
    // loop) together with the client.
    fn new(output: IoLock) -> (Receiver, Self) {
        let (req, rep) = mpsc::channel(0);
        (rep, Self { output, id: Rc::new(RefCell::new(0)), subscribe: Rc::new(req) })
    }
    // Lock the outbound stream; the guard owns a clone of the lock's `Rc`.
    async fn lock_out(&self) -> IoGuard {
        Bound::async_new(self.output.clone(), async |o| o.lock().await).await
    }
}
impl Client for IoClient {
    // Notifications are framed with the reserved header 0.
    fn start_notif(&self) -> LocalBoxFuture<'_, io::Result + '_>>> {
        Box::pin(async {
            let mut o = self.lock_out().await;
            0u64.encode(o.as_mut()).await?;
            Ok(Box::new(IoNotifWriter { o }) as Box)
        })
    }
    fn start_request(&self) -> LocalBoxFuture<'_, io::Result + '_>>> {
        Box::pin(async {
            // Allocate the next request id.
            let id = {
                let mut id_g = self.id.borrow_mut();
                *id_g += 1;
                *id_g
            };
            let (cb, reply) = oneshot::channel();
            let (ack, got_ack) = oneshot::channel();
            // Register the reply slot and wait for the server loop to acknowledge
            // it before writing anything, so the reply cannot arrive before its
            // slot exists. Panics if the server loop's receiver is gone.
            self.subscribe.as_ref().clone().send(ReplySub { id, ack, cb }).await.unwrap();
            got_ack.await.unwrap();
            let mut w = self.lock_out().await;
            id.encode(w.as_mut()).await?;
            Ok(Box::new(IoReqWriter { reply, w }) as Box)
        })
    }
}

/// Outbound request in progress: holds the outbound lock while the payload is
/// written, plus the slot its reply will be delivered to.
struct IoReqWriter {
    reply: oneshot::Receiver>,
    w: IoGuard,
}
impl<'a> ReqWriter<'a> for IoReqWriter {
    fn writer(&mut self) -> Pin<&mut dyn AsyncWrite> { self.w.as_mut() }
    // Flushes the request, releases the outbound lock, then waits for the
    // server loop to deliver the reply's inbound guard. Panics if the reply slot
    // is dropped without a reply.
    fn send(self: Box) -> LocalBoxFuture<'a, io::Result + 'a>>> {
        Box::pin(async {
            let Self { reply, mut w } = *self;
            w.flush().await?;
            mem::drop(w);
            let i = reply.await.expect("Client dropped before reply received");
            Ok(Box::new(IoRepReader { i }) as Box)
        })
    }
}

/// Reply being read; holds the inbound lock until dropped.
struct IoRepReader {
    i: IoGuard,
}
impl<'a> RepReader<'a> for IoRepReader {
    fn reader(&mut self) -> Pin<&mut dyn AsyncRead> { self.i.as_mut() }
    // Nothing to do: dropping `self` releases the inbound lock.
    fn finish(self: Box) -> LocalBoxFuture<'static, ()> { Box::pin(async {}) }
}

/// Outbound notification in progress; holds the outbound lock.
#[derive(destructure)]
struct IoNotifWriter {
    o: IoGuard,
}
impl<'a> MsgWriter<'a> for IoNotifWriter {
    fn writer(&mut self) -> Pin<&mut dyn AsyncWrite> { self.o.as_mut() }
    // Flushes the notification; the outbound lock is released with `self`.
    fn finish(mut self: Box) -> LocalBoxFuture<'static, io::Result<()>> {
        Box::pin(async move { self.o.flush().await })
    }
}

/// Handle used to shut a connection down.
pub struct CommCtx {
    exit: Sender<()>,
    o: Rc>>>,
}
impl CommCtx {
    /// Close the outbound stream, then signal the server loop to stop.
    ///
    /// # Panics
    /// If the server loop's exit receiver has been dropped.
    pub async fn exit(self) -> io::Result<()> {
        self.o.lock().await.as_mut().close().await?;
        self.exit.clone().send(()).await.expect("quit channel dropped");
        Ok(())
    }
}

/// Establish bidirectional request-notification communication over a duplex
/// channel. The returned client can send notifications immediately, but its
/// requests (and all inbound traffic) are only processed while
/// [IoCommServer::listen] is running. That future resolves after
/// [CommCtx::exit] is called or the inbound channel is closed.
The generic type /// parameters are associated with the client and serve to ensure with a runtime /// check that the correct message families are sent in the correct directions /// across the channel. pub fn io_comm( o: Pin>, i: Pin>, ) -> (impl Client + 'static, CommCtx, IoCommServer) { let i = Rc::new(Mutex::new(i)); let o = Rc::new(Mutex::new(o)); let (onsub, client) = IoClient::new(o.clone()); let (exit, onexit) = channel(1); (client, CommCtx { exit, o: o.clone() }, IoCommServer { o, i, onsub, onexit }) } pub struct IoCommServer { o: Rc>>>, i: Rc>>>, onsub: Receiver, onexit: Receiver<()>, } impl IoCommServer { pub async fn listen( self, notif: impl for<'a> AsyncFn(Box + 'a>) -> io::Result<()>, req: impl for<'a> AsyncFn(Box + 'a>) -> io::Result>, ) -> io::Result<()> { let Self { o, i, onexit, onsub } = self; enum Event { Input(u64, IoGuard), Sub(ReplySub), Exit, } let exiting = RefCell::new(false); let input_stream = try_stream(async |mut h| { loop { let mut g = Bound::async_new(i.clone(), async |i| i.lock().await).await; match u64::decode(g.as_mut()).await { Ok(id) => h.emit(Event::Input(id, g)).await, Err(e) => match e.kind() { io::ErrorKind::BrokenPipe | io::ErrorKind::ConnectionAborted | io::ErrorKind::UnexpectedEof => h.emit(Event::Exit).await, _ => return Err(e), }, } } }); let (mut add_pending_req, fork_future) = LocalSet::new(); let mut fork_stream = pin!(fork_future.fuse().into_stream()); let mut pending_replies = HashMap::new(); 'body: { let mut shared = pin!(stream_select!( pin!(input_stream) as Pin<&mut dyn Stream>>, onsub.map(|sub| Ok(Event::Sub(sub))), fork_stream.as_mut().map(|res| { res.map(|()| panic!("this substream cannot exit while the loop is running")) }), onexit.map(|()| Ok(Event::Exit)), )); while let Some(next) = shared.next().await { match next { Err(e) => break 'body Err(e), Ok(Event::Exit) => { *exiting.borrow_mut() = true; let mut out = o.lock().await; out.as_mut().flush().await?; out.as_mut().close().await?; break; }, 
Ok(Event::Sub(ReplySub { id, ack, cb })) => { pending_replies.insert(id, cb); ack.send(()).unwrap(); }, Ok(Event::Input(0, read)) => { let notif = ¬if; let notif_job = async move { notif(Box::new(IoMsgReader { _pd: PhantomData, read })).await }; add_pending_req.send(Box::pin(notif_job)).await.unwrap(); }, // MSB == 0 is a request, !id where MSB == 1 is the corresponding response Ok(Event::Input(id, read)) if (id & (1 << (u64::BITS - 1))) == 0 => { let (o, req) = (o.clone(), &req); let req_job = async move { let mut prefix = Vec::new(); (!id).encode_vec(&mut prefix); let _ = req(Box::new(IoReqReader { prefix: &pin!(prefix), read, write: &o })).await; Ok(()) }; add_pending_req.send(Box::pin(req_job)).await.unwrap(); }, Ok(Event::Input(id, read)) => { let cb = pending_replies.remove(&!id).expect("Reply to unrecognized request"); cb.send(read).unwrap_or_else(|_| panic!("Failed to send reply")); }, } } Ok(()) }?; mem::drop(add_pending_req); while let Some(next) = fork_stream.next().await { next? } Ok(()) } } #[cfg(test)] mod test { use std::cell::RefCell; use futures::channel::mpsc; use futures::{SinkExt, StreamExt, join}; use orchid_api_derive::{Coding, Hierarchy}; use orchid_api_traits::Request; use test_executors::spin_on; use unsync_pipe::pipe; use crate::logging::test::TestLogger; use crate::logging::with_logger; use crate::reqnot::{ClientExt, MsgReaderExt, ReqReaderExt, io_comm}; #[derive(Clone, Debug, PartialEq, Coding, Hierarchy)] #[extendable] struct TestNotif(u64); #[test] fn notification() { let logger = TestLogger::new(async |s| eprint!("{s}")); spin_on(with_logger(logger, async { let (in1, out2) = pipe(1024); let (in2, out1) = pipe(1024); let (received, mut on_receive) = mpsc::channel(2); let (_, recv_ctx, recv_srv) = io_comm(Box::pin(in2), Box::pin(out2)); let (sender, ..) 
= io_comm(Box::pin(in1), Box::pin(out1)); join!( async { recv_srv .listen( async |notif| { received.clone().send(notif.read::().await?).await.unwrap(); Ok(()) }, async |_| panic!("Should receive notif, not request"), ) .await .unwrap() }, async { sender.notify(TestNotif(3)).await.unwrap(); assert_eq!(on_receive.next().await, Some(TestNotif(3))); sender.notify(TestNotif(4)).await.unwrap(); assert_eq!(on_receive.next().await, Some(TestNotif(4))); recv_ctx.exit().await.unwrap(); } ); })) } #[derive(Clone, Debug, Coding, Hierarchy)] #[extendable] struct DummyRequest(u64); impl Request for DummyRequest { type Response = u64; } #[test] fn request() { let logger = TestLogger::new(async |s| eprint!("{s}")); spin_on(with_logger(logger, async { let (in1, out2) = pipe(1024); let (in2, out1) = pipe(1024); let (_, srv_ctx, srv) = io_comm(Box::pin(in2), Box::pin(out2)); let (client, client_ctx, client_srv) = io_comm(Box::pin(in1), Box::pin(out1)); join!( async { srv .listen( async |_| panic!("No notifs expected"), async |mut req| { let val = req.read_req::().await?; req.reply(&val, &(val.0 + 1)).await }, ) .await .unwrap() }, async { client_srv .listen( async |_| panic!("Not expecting ingress notif"), async |_| panic!("Not expecting ingress req"), ) .await .unwrap() }, async { let response = client.request(DummyRequest(5)).await.unwrap(); assert_eq!(response, 6); srv_ctx.exit().await.unwrap(); client_ctx.exit().await.unwrap(); } ); })) } #[test] fn exit() { let logger = TestLogger::new(async |s| eprint!("{s}")); spin_on(with_logger(logger, async { let (input1, output1) = pipe(1024); let (input2, output2) = pipe(1024); let (reply_client, reply_context, reply_server) = io_comm(Box::pin(input1), Box::pin(output2)); let (req_client, req_context, req_server) = io_comm(Box::pin(input2), Box::pin(output1)); let reply_context = RefCell::new(Some(reply_context)); let (exit, onexit) = futures::channel::oneshot::channel::<()>(); join!( async move { reply_server .listen( async |hand| { let 
_notif = hand.read::().await.unwrap(); let context = reply_context.borrow_mut().take().unwrap(); context.exit().await?; Ok(()) }, async |mut hand| { let req = hand.read_req::().await?; hand.reply(&req, &(req.0 + 1)).await }, ) .await .unwrap(); exit.send(()).unwrap(); let _client = reply_client; }, async move { req_server .listen( async |_| panic!("Only the other server expected notifs"), async |_| panic!("Only the other server expected requests"), ) .await .unwrap(); let _ctx = req_context; }, async move { req_client.request(DummyRequest(0)).await.unwrap(); req_client.notify(TestNotif(0)).await.unwrap(); onexit.await.unwrap(); } ) })); } }