Pending a correct test of request cancellation
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use std::cell::RefCell;
|
||||
use std::cell::{BorrowMutError, RefCell};
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::{Pin, pin};
|
||||
use std::rc::Rc;
|
||||
@@ -16,56 +16,95 @@ use futures::{
|
||||
};
|
||||
use hashbrown::HashMap;
|
||||
use orchid_api_traits::{Decode, Encode, Request, UnderRoot};
|
||||
use orchid_async_utils::LocalSet;
|
||||
use orchid_async_utils::debug::{PanicOnDrop, assert_no_drop};
|
||||
use orchid_async_utils::{cancel_cleanup, local_set, to_task};
|
||||
|
||||
use crate::{clone, finish_or_stash, stash, with_stash};
|
||||
|
||||
// TODO: revise error handling; error recovery is never partial, it always
|
||||
// requires dropping the server, client, and all requests
|
||||
|
||||
/// A token indicating that a reply to a request has been sent. Returned from
|
||||
/// [RepWriter::finish] which is the raw reply channel, or [ReqHandleExt::reply]
|
||||
/// or [ReqReaderExt::reply] which are type checked
|
||||
#[must_use = "Receipts indicate that a required action has been performed within a function. \
|
||||
Most likely this should be returned somewhere."]
|
||||
pub struct Receipt<'a>(PhantomData<&'a mut ()>);
|
||||
impl Receipt<'_> {
|
||||
/// Only call this function from a custom implementation of [RepWriter]
|
||||
pub fn _new() -> Self { Self(PhantomData) }
|
||||
pub struct Receipt;
|
||||
impl Receipt {
|
||||
/// Only ever call this function from a custom implementation of
|
||||
/// [RepWriter::finish]
|
||||
pub fn _new() -> Self { Self }
|
||||
}
|
||||
|
||||
/// Return data while waiting for the response to a request. [Self::future] must
|
||||
/// be awaited in order to ensure that progress is being made
|
||||
pub struct ReqWait {
|
||||
/// Future representing waiting for a request. This must be steadily polled.
|
||||
pub future: LocalBoxFuture<'static, io::Result<Box<dyn RepReader>>>,
|
||||
/// Since the [Self::future] must be awaited which exclusively borrows it,
|
||||
/// this separate handle can be used for cancellation.
|
||||
pub canceller: Box<dyn CancelNotifier>,
|
||||
}
|
||||
|
||||
/// Write guard to outbound for the purpose of serializing a request. Only one
|
||||
/// can exist at a time. Dropping this object should panic.
|
||||
pub trait ReqWriter<'a> {
|
||||
pub trait ReqWriter {
|
||||
/// Access to the underlying channel. This may be buffered.
|
||||
fn writer(&mut self) -> Pin<&mut dyn AsyncWrite>;
|
||||
/// Finalize the request, release the outbound channel, then queue for the
|
||||
/// reply on the inbound channel.
|
||||
fn send(self: Box<Self>) -> LocalBoxFuture<'a, io::Result<Box<dyn RepReader<'a> + 'a>>>;
|
||||
fn send(self: Box<Self>) -> ReqWait;
|
||||
}
|
||||
|
||||
/// Write guard to inbound for the purpose of deserializing a reply. While held,
|
||||
/// no inbound requests or other replies can be processed.
|
||||
///
|
||||
/// Dropping this object should panic even if [RepReader::finish] returns
|
||||
/// synchronously, because the API isn't cancellation safe in general so it is a
|
||||
/// programmer error in all cases to drop an object related to it without proper
|
||||
/// cleanup.
|
||||
pub trait RepReader<'a> {
|
||||
/// # Cancellation
|
||||
///
|
||||
/// If the request has been cancelled and the server has accepted the
|
||||
/// cancellation instead of writing a reply (which is never guaranteed), then
|
||||
/// this object is inert and should be dropped.
|
||||
///
|
||||
/// Dropping this object if [Self::reader] returns [Some] should panic even if
|
||||
/// [RepReader::finish] returns synchronously, because the API isn't
|
||||
/// cancellation safe in general so it is a programmer error to drop an object
|
||||
/// related to it without proper cleanup.
|
||||
pub trait RepReader {
|
||||
/// Access to the underlying channel. The length of the message is inferred
|
||||
/// from the number of bytes read so this must not be buffered.
|
||||
fn reader(&mut self) -> Pin<&mut dyn AsyncRead>;
|
||||
/// from the number of bytes read so this must not be buffered and a full
|
||||
/// reply must always be read from it if available
|
||||
///
|
||||
/// This returns None if the request has successfully been cancelled, in which
|
||||
/// case this object can be dropped without calling [Self::finish]
|
||||
fn reader(&mut self) -> Option<Pin<&mut dyn AsyncRead>>;
|
||||
/// Finish reading the request
|
||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'a, ()>;
|
||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'static, ()>;
|
||||
}
|
||||
|
||||
/// A handle for cancelling in-flight requests without a reference to
|
||||
/// the wait future (which would be mutably borrowed by an await at this point)
|
||||
pub trait CancelNotifier {
|
||||
/// Upon cancellation the future may resolve to a stub version of [RepReader]
|
||||
/// with no reader access, but since the cancellation is not synchronized
|
||||
/// with the server, a full reply may still be received, and if it is, the
|
||||
/// original reply must still be read from it.
|
||||
fn cancel(self: Box<Self>) -> LocalBoxFuture<'static, ()>;
|
||||
}
|
||||
|
||||
/// Write guard to outbound for the purpose of serializing a notification.
|
||||
///
|
||||
/// Dropping this object should panic for the same reason [RepReader] panics
|
||||
pub trait MsgWriter<'a> {
|
||||
pub trait MsgWriter {
|
||||
/// Access to the underlying channel. This may be buffered.
|
||||
fn writer(&mut self) -> Pin<&mut dyn AsyncWrite>;
|
||||
/// Send the notification
|
||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'a, io::Result<()>>;
|
||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'static, io::Result<()>>;
|
||||
}
|
||||
|
||||
/// For initiating outbound requests and notifications
|
||||
pub trait Client {
|
||||
fn start_request(&self) -> LocalBoxFuture<'_, io::Result<Box<dyn ReqWriter<'_> + '_>>>;
|
||||
fn start_notif(&self) -> LocalBoxFuture<'_, io::Result<Box<dyn MsgWriter<'_> + '_>>>;
|
||||
fn start_request(&self) -> LocalBoxFuture<'static, io::Result<Box<dyn ReqWriter>>>;
|
||||
fn start_notif(&self) -> LocalBoxFuture<'static, io::Result<Box<dyn MsgWriter>>>;
|
||||
}
|
||||
|
||||
impl<T: Client + ?Sized> ClientExt for T {}
|
||||
@@ -73,62 +112,146 @@ impl<T: Client + ?Sized> ClientExt for T {}
|
||||
/// notif lifecycle and typing
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait ClientExt: Client {
|
||||
#[allow(
|
||||
clippy::await_holding_refcell_ref,
|
||||
reason = "Must bypass a future return point by sharing the common path"
|
||||
)]
|
||||
async fn request<T: Request + UnderRoot<Root: Encode>>(&self, t: T) -> io::Result<T::Response> {
|
||||
let mut req = self.start_request().await?;
|
||||
t.into_root().encode(req.writer().as_mut()).await?;
|
||||
let mut rep = req.send().await?;
|
||||
let response = T::Response::decode(rep.reader()).await;
|
||||
rep.finish().await;
|
||||
response
|
||||
let start_req = self.start_request();
|
||||
// This section must finish if it has started, and the returned writer's `send`
|
||||
// must be called as well.
|
||||
let common = Rc::new(RefCell::new(Some(Box::pin(async move {
|
||||
let mut writer = start_req.await?;
|
||||
t.into_root().encode(writer.writer().as_mut()).await?;
|
||||
io::Result::Ok(writer)
|
||||
}))));
|
||||
// Initialized in the cancelable section if common returns here. If set, the
|
||||
// future inside must be finished on stash after the notification is sent
|
||||
// to ensure that the server acknowledges the cancellation, or to decode the
|
||||
// result if the cancellation was in fact too late.
|
||||
let req_wait_rc = Rc::new(RefCell::new(None));
|
||||
// If both this and common are None, that means the cancelable section is
|
||||
// already past its last interruptible point, and must be finished on stash
|
||||
cancel_cleanup(
|
||||
clone!(req_wait_rc, common; Box::pin(async move {
|
||||
let req_wait;
|
||||
{
|
||||
let mut common_g = common.try_borrow_mut().expect("cancel will drop us before locking");
|
||||
let common = (common_g.as_mut())
|
||||
.expect("Only unset by us below or by cancel after dropping us");
|
||||
// cancel handler may take over here
|
||||
req_wait = common.await?.send();
|
||||
common_g.take();
|
||||
}
|
||||
let mut rep;
|
||||
{
|
||||
let mut req_wait_g = (req_wait_rc.try_borrow_mut())
|
||||
.expect("We are the first ones to access this");
|
||||
*req_wait_g = Some(req_wait);
|
||||
let req_wait = req_wait_g.as_mut().expect("Initialized right above");
|
||||
// cancel handler may take over here
|
||||
rep = req_wait.future.as_mut().await?;
|
||||
req_wait_g.take();
|
||||
};
|
||||
// cancel handler will not interrupt if we've gotten this far
|
||||
let reader = rep.reader().expect("Not been cancelled thus far");
|
||||
let result = T::Response::decode(reader).await;
|
||||
rep.finish().await;
|
||||
result
|
||||
})),
|
||||
|fut| {
|
||||
stash(async move {
|
||||
// TODO: strategy for IO errors on stash
|
||||
let req_wait = if common.try_borrow_mut().is_ok_and(|r| r.is_none()) {
|
||||
// fut was already past common
|
||||
match req_wait_rc.try_borrow_mut() {
|
||||
Ok(mut opt) => {
|
||||
let Some(req_wait) = opt.take() else {
|
||||
// fut was already reading, finish that read and exit
|
||||
fut.await.expect("IO error on stash");
|
||||
return;
|
||||
};
|
||||
req_wait
|
||||
},
|
||||
Err(BorrowMutError { .. }) => {
|
||||
// fut was in waiting, take over and do our own thing
|
||||
std::mem::drop(fut);
|
||||
req_wait_rc.take().expect("If it was borrowed then it was still set")
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// fut was still in common, take over and finish common
|
||||
std::mem::drop(fut);
|
||||
let common =
|
||||
(common.take()).expect("If it was still borrowed in fut, it was not yet unset");
|
||||
common.await.expect("IO error on stash").send()
|
||||
};
|
||||
req_wait.canceller.cancel().await;
|
||||
let mut rep = req_wait.future.await.expect("IO error on stash");
|
||||
let Some(reader) = rep.reader() else { return };
|
||||
T::Response::decode(reader).await.expect("IO error on stash");
|
||||
rep.finish().await;
|
||||
})
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
async fn notify<T: UnderRoot<Root: Encode>>(&self, t: T) -> io::Result<()> {
|
||||
let mut notif = self.start_notif().await?;
|
||||
t.into_root().encode(notif.writer().as_mut()).await?;
|
||||
notif.finish().await?;
|
||||
Ok(())
|
||||
async fn notify<T: UnderRoot<Root: Encode> + 'static>(&self, t: T) -> io::Result<()> {
|
||||
let start_notif = self.start_notif();
|
||||
finish_or_stash(Box::pin(async {
|
||||
let mut notif = start_notif.await?;
|
||||
t.into_root().encode(notif.writer().as_mut()).await?;
|
||||
notif.finish().await?;
|
||||
Ok(())
|
||||
}))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ReqReader<'a> {
|
||||
pub trait ReqReader {
|
||||
fn reader(&mut self) -> Pin<&mut dyn AsyncRead>;
|
||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'a, Box<dyn ReqHandle<'a> + 'a>>;
|
||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'static, Box<dyn ReqHandle>>;
|
||||
}
|
||||
impl<'a, T: ReqReader<'a> + ?Sized> ReqReaderExt<'a> for T {}
|
||||
impl<T: ReqReader + ?Sized> ReqReaderExt for T {}
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait ReqReaderExt<'a>: ReqReader<'a> {
|
||||
pub trait ReqReaderExt: ReqReader {
|
||||
async fn read_req<R: Decode>(&mut self) -> io::Result<R> { R::decode(self.reader()).await }
|
||||
async fn reply<R: Request>(
|
||||
self: Box<Self>,
|
||||
req: impl Evidence<R>,
|
||||
rep: &R::Response,
|
||||
) -> io::Result<Receipt<'a>> {
|
||||
rep: R::Response,
|
||||
) -> io::Result<Receipt> {
|
||||
self.finish().await.reply(req, rep).await
|
||||
}
|
||||
async fn start_reply(self: Box<Self>) -> io::Result<Box<dyn RepWriter<'a> + 'a>> {
|
||||
async fn start_reply(self: Box<Self>) -> io::Result<Box<dyn RepWriter>> {
|
||||
self.finish().await.start_reply().await
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ReqHandle<'a> {
|
||||
fn start_reply(self: Box<Self>) -> LocalBoxFuture<'a, io::Result<Box<dyn RepWriter<'a> + 'a>>>;
|
||||
pub trait ReqHandle {
|
||||
fn start_reply(self: Box<Self>) -> LocalBoxFuture<'static, io::Result<Box<dyn RepWriter>>>;
|
||||
}
|
||||
impl<'a, T: ReqHandle<'a> + ?Sized> ReqHandleExt<'a> for T {}
|
||||
impl<T: ReqHandle + ?Sized> ReqHandleExt for T {}
|
||||
#[allow(async_fn_in_trait)]
|
||||
pub trait ReqHandleExt<'a>: ReqHandle<'a> {
|
||||
pub trait ReqHandleExt: ReqHandle {
|
||||
async fn reply<Req: Request>(
|
||||
self: Box<Self>,
|
||||
_: impl Evidence<Req>,
|
||||
rep: &Req::Response,
|
||||
) -> io::Result<Receipt<'a>> {
|
||||
let mut reply = self.start_reply().await?;
|
||||
rep.encode(reply.writer()).await?;
|
||||
reply.finish().await
|
||||
rep: Req::Response,
|
||||
) -> io::Result<Receipt> {
|
||||
let start_reply = self.start_reply();
|
||||
finish_or_stash(Box::pin(async move {
|
||||
let mut reply = start_reply.await?;
|
||||
rep.encode(reply.writer()).await?;
|
||||
reply.finish().await
|
||||
}))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pub trait RepWriter<'a> {
|
||||
pub trait RepWriter {
|
||||
fn writer(&mut self) -> Pin<&mut dyn AsyncWrite>;
|
||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'a, io::Result<Receipt<'a>>>;
|
||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'static, io::Result<Receipt>>;
|
||||
}
|
||||
|
||||
pub trait MsgReader<'a> {
|
||||
@@ -166,41 +289,43 @@ type IoLock<T> = Rc<Mutex<Pin<Box<T>>>>;
|
||||
type IoGuard<T> = Bound<MutexGuard<'static, Pin<Box<T>>>, IoLock<T>>;
|
||||
|
||||
/// An incoming request. This holds a lock on the ingress channel.
|
||||
pub struct IoReqReader<'a> {
|
||||
prefix: &'a [u8],
|
||||
pub struct IoReqReader {
|
||||
prefix: u64,
|
||||
read: IoGuard<dyn AsyncRead>,
|
||||
write: &'a Mutex<IoRef<dyn AsyncWrite>>,
|
||||
o: Rc<Mutex<IoRef<dyn AsyncWrite>>>,
|
||||
}
|
||||
impl<'a> ReqReader<'a> for IoReqReader<'a> {
|
||||
impl ReqReader for IoReqReader {
|
||||
fn reader(&mut self) -> Pin<&mut dyn AsyncRead> { self.read.as_mut() }
|
||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'a, Box<dyn ReqHandle<'a> + 'a>> {
|
||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'static, Box<dyn ReqHandle>> {
|
||||
Box::pin(async {
|
||||
Box::new(IoReqHandle { prefix: self.prefix, write: self.write }) as Box<dyn ReqHandle<'a>>
|
||||
Box::new(IoReqHandle { prefix: self.prefix, write: self.o }) as Box<dyn ReqHandle>
|
||||
})
|
||||
}
|
||||
}
|
||||
pub struct IoReqHandle<'a> {
|
||||
prefix: &'a [u8],
|
||||
write: &'a Mutex<IoRef<dyn AsyncWrite>>,
|
||||
|
||||
pub struct IoReqHandle {
|
||||
prefix: u64,
|
||||
write: IoLock<dyn AsyncWrite>,
|
||||
}
|
||||
impl<'a> ReqHandle<'a> for IoReqHandle<'a> {
|
||||
fn start_reply(self: Box<Self>) -> LocalBoxFuture<'a, io::Result<Box<dyn RepWriter<'a> + 'a>>> {
|
||||
impl ReqHandle for IoReqHandle {
|
||||
fn start_reply(self: Box<Self>) -> LocalBoxFuture<'static, io::Result<Box<dyn RepWriter>>> {
|
||||
let write = self.write.clone();
|
||||
Box::pin(async move {
|
||||
let mut write = self.write.lock().await;
|
||||
write.as_mut().write_all(self.prefix).await?;
|
||||
Ok(Box::new(IoRepWriter { write }) as Box<dyn RepWriter<'a>>)
|
||||
let mut write = Bound::async_new(write, |l| l.lock()).await;
|
||||
self.prefix.encode(write.as_mut()).await?;
|
||||
Ok(Box::new(IoRepWriter { write }) as Box<dyn RepWriter>)
|
||||
})
|
||||
}
|
||||
}
|
||||
pub struct IoRepWriter<'a> {
|
||||
write: MutexGuard<'a, IoRef<dyn AsyncWrite>>,
|
||||
pub struct IoRepWriter {
|
||||
write: IoGuard<dyn AsyncWrite>,
|
||||
}
|
||||
impl<'a> RepWriter<'a> for IoRepWriter<'a> {
|
||||
impl RepWriter for IoRepWriter {
|
||||
fn writer(&mut self) -> Pin<&mut dyn AsyncWrite> { self.write.as_mut() }
|
||||
fn finish(mut self: Box<Self>) -> LocalBoxFuture<'a, io::Result<Receipt<'a>>> {
|
||||
fn finish(mut self: Box<Self>) -> LocalBoxFuture<'static, io::Result<Receipt>> {
|
||||
Box::pin(async move {
|
||||
self.writer().flush().await?;
|
||||
Ok(Receipt(PhantomData))
|
||||
Ok(Receipt)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -214,11 +339,16 @@ impl<'a> MsgReader<'a> for IoMsgReader<'a> {
|
||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'static, ()> { Box::pin(async {}) }
|
||||
}
|
||||
|
||||
pub enum ReplyRecord {
|
||||
Cancelled,
|
||||
Ready(IoGuard<dyn AsyncRead>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ReplySub {
|
||||
id: u64,
|
||||
ack: oneshot::Sender<()>,
|
||||
cb: oneshot::Sender<IoGuard<dyn AsyncRead>>,
|
||||
cb: oneshot::Sender<ReplyRecord>,
|
||||
}
|
||||
|
||||
struct IoClient {
|
||||
@@ -231,37 +361,42 @@ impl IoClient {
|
||||
let (req, rep) = mpsc::channel(0);
|
||||
(rep, Self { output, id: Rc::new(RefCell::new(0)), subscribe: Rc::new(req) })
|
||||
}
|
||||
async fn lock_out(&self) -> IoGuard<dyn AsyncWrite> {
|
||||
Bound::async_new(self.output.clone(), async |o| o.lock().await).await
|
||||
}
|
||||
}
|
||||
impl Client for IoClient {
|
||||
fn start_notif(&self) -> LocalBoxFuture<'_, io::Result<Box<dyn MsgWriter<'_> + '_>>> {
|
||||
fn start_notif(&self) -> LocalBoxFuture<'static, io::Result<Box<dyn MsgWriter>>> {
|
||||
let output = self.output.clone();
|
||||
Box::pin(async {
|
||||
let drop_g = assert_no_drop("Notif future dropped");
|
||||
let mut o = self.lock_out().await;
|
||||
let mut o = Bound::async_new(output, |o| o.lock()).await;
|
||||
0u64.encode(o.as_mut()).await?;
|
||||
drop_g.defuse();
|
||||
Ok(Box::new(IoNotifWriter { o, drop_g: assert_no_drop("Notif writer dropped") })
|
||||
as Box<dyn MsgWriter>)
|
||||
})
|
||||
}
|
||||
fn start_request(&self) -> LocalBoxFuture<'_, io::Result<Box<dyn ReqWriter<'_> + '_>>> {
|
||||
Box::pin(async {
|
||||
let id = {
|
||||
let mut id_g = self.id.borrow_mut();
|
||||
*id_g += 1;
|
||||
*id_g
|
||||
};
|
||||
let (cb, reply) = oneshot::channel();
|
||||
let (ack, got_ack) = oneshot::channel();
|
||||
let drop_g = assert_no_drop("Request future dropped");
|
||||
self.subscribe.as_ref().clone().send(ReplySub { id, ack, cb }).await.unwrap();
|
||||
fn start_request(&self) -> LocalBoxFuture<'static, io::Result<Box<dyn ReqWriter>>> {
|
||||
let output = self.output.clone();
|
||||
let id = {
|
||||
let mut id_g = self.id.borrow_mut();
|
||||
*id_g += 1;
|
||||
*id_g
|
||||
};
|
||||
let (cb, reply) = oneshot::channel();
|
||||
let (ack, got_ack) = oneshot::channel();
|
||||
let mut subscribe = self.subscribe.as_ref().clone();
|
||||
let start_req_drop_g = assert_no_drop("Request future dropped");
|
||||
Box::pin(async move {
|
||||
subscribe.send(ReplySub { id, ack, cb }).await.unwrap();
|
||||
got_ack.await.unwrap();
|
||||
let mut w = self.lock_out().await;
|
||||
id.encode(w.as_mut()).await?;
|
||||
drop_g.defuse();
|
||||
let mut xfer_bytes = id.to_be_bytes();
|
||||
xfer_bytes[0] = 0x00;
|
||||
let req_prefix = u64::from_be_bytes(xfer_bytes);
|
||||
let mut w = Bound::async_new(output.clone(), |o| o.lock()).await;
|
||||
req_prefix.encode(w.as_mut()).await?;
|
||||
start_req_drop_g.defuse();
|
||||
Ok(Box::new(IoReqWriter {
|
||||
id,
|
||||
output,
|
||||
reply,
|
||||
w,
|
||||
drop_g: assert_no_drop("Request reader dropped without reply"),
|
||||
@@ -270,34 +405,62 @@ impl Client for IoClient {
|
||||
}
|
||||
}
|
||||
|
||||
struct IoReqWriter {
|
||||
reply: oneshot::Receiver<IoGuard<dyn AsyncRead>>,
|
||||
w: IoGuard<dyn AsyncWrite>,
|
||||
drop_g: PanicOnDrop,
|
||||
struct IoReqCanceller {
|
||||
id: u64,
|
||||
output: IoLock<dyn AsyncWrite>,
|
||||
}
|
||||
impl<'a> ReqWriter<'a> for IoReqWriter {
|
||||
fn writer(&mut self) -> Pin<&mut dyn AsyncWrite> { self.w.as_mut() }
|
||||
fn send(self: Box<Self>) -> LocalBoxFuture<'a, io::Result<Box<dyn RepReader<'a> + 'a>>> {
|
||||
Box::pin(async {
|
||||
let Self { reply, mut w, drop_g } = *self;
|
||||
w.flush().await?;
|
||||
mem::drop(w);
|
||||
let i = reply.await.expect("Client dropped before reply received");
|
||||
drop_g.defuse();
|
||||
Ok(Box::new(IoRepReader {
|
||||
i,
|
||||
drop_g: assert_no_drop("Reply reader dropped without finishing"),
|
||||
}) as Box<dyn RepReader>)
|
||||
impl CancelNotifier for IoReqCanceller {
|
||||
fn cancel(self: Box<Self>) -> LocalBoxFuture<'static, ()> {
|
||||
let mut xfer_bytes = self.id.to_be_bytes();
|
||||
xfer_bytes[0] = 0x02;
|
||||
let cancel_id = u64::from_be_bytes(xfer_bytes);
|
||||
let cancel_signal_drop_g = assert_no_drop("Cannot cancel the sending of a cancellation");
|
||||
let o = self.output.clone();
|
||||
Box::pin(async move {
|
||||
let mut o = o.lock().await;
|
||||
let _ = cancel_id.encode(o.as_mut()).await;
|
||||
cancel_signal_drop_g.defuse();
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct IoRepReader {
|
||||
i: IoGuard<dyn AsyncRead>,
|
||||
struct IoReqWriter {
|
||||
id: u64,
|
||||
reply: oneshot::Receiver<ReplyRecord>,
|
||||
output: IoLock<dyn AsyncWrite>,
|
||||
w: IoGuard<dyn AsyncWrite>,
|
||||
drop_g: PanicOnDrop,
|
||||
}
|
||||
impl<'a> RepReader<'a> for IoRepReader {
|
||||
fn reader(&mut self) -> Pin<&mut dyn AsyncRead> { self.i.as_mut() }
|
||||
impl ReqWriter for IoReqWriter {
|
||||
fn writer(&mut self) -> Pin<&mut dyn AsyncWrite> { self.w.as_mut() }
|
||||
fn send(self: Box<Self>) -> ReqWait {
|
||||
let Self { id, output, reply, mut w, drop_g } = *self;
|
||||
let canceller = IoReqCanceller { id, output };
|
||||
let future = async {
|
||||
w.flush().await?;
|
||||
mem::drop(w);
|
||||
let reply_record = reply.await.expect("Client dropped before reply received");
|
||||
drop_g.defuse();
|
||||
Ok(Box::new(IoRepReader {
|
||||
reply_record,
|
||||
drop_g: assert_no_drop("Reply reader dropped without finishing"),
|
||||
}) as Box<dyn RepReader>)
|
||||
};
|
||||
ReqWait { future: Box::pin(future), canceller: Box::new(canceller) }
|
||||
}
|
||||
}
|
||||
|
||||
struct IoRepReader {
|
||||
reply_record: ReplyRecord,
|
||||
drop_g: PanicOnDrop,
|
||||
}
|
||||
impl RepReader for IoRepReader {
|
||||
fn reader(&mut self) -> Option<Pin<&mut dyn AsyncRead>> {
|
||||
match &mut self.reply_record {
|
||||
ReplyRecord::Cancelled => None,
|
||||
ReplyRecord::Ready(guard) => Some(guard.as_mut()),
|
||||
}
|
||||
}
|
||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'static, ()> {
|
||||
Box::pin(async { self.drop_g.defuse() })
|
||||
}
|
||||
@@ -308,7 +471,7 @@ struct IoNotifWriter {
|
||||
o: IoGuard<dyn AsyncWrite>,
|
||||
drop_g: PanicOnDrop,
|
||||
}
|
||||
impl<'a> MsgWriter<'a> for IoNotifWriter {
|
||||
impl MsgWriter for IoNotifWriter {
|
||||
fn writer(&mut self) -> Pin<&mut dyn AsyncWrite> { self.o.as_mut() }
|
||||
fn finish(mut self: Box<Self>) -> LocalBoxFuture<'static, io::Result<()>> {
|
||||
Box::pin(async move {
|
||||
@@ -333,10 +496,7 @@ impl CommCtx {
|
||||
/// Establish bidirectional request-notification communication over a duplex
|
||||
/// channel. The returned [IoClient] can be used for notifications immediately,
|
||||
/// but requests can only be received while the future is running. The future
|
||||
/// will only resolve when [CommCtx::quit] is called. The generic type
|
||||
/// parameters are associated with the client and serve to ensure with a runtime
|
||||
/// check that the correct message families are sent in the correct directions
|
||||
/// across the channel.
|
||||
/// will only resolve when [CommCtx::exit] is called.
|
||||
pub fn io_comm(
|
||||
o: Pin<Box<dyn AsyncWrite>>,
|
||||
i: Pin<Box<dyn AsyncRead>>,
|
||||
@@ -356,8 +516,8 @@ pub struct IoCommServer {
|
||||
impl IoCommServer {
|
||||
pub async fn listen(
|
||||
self,
|
||||
notif: impl for<'a> AsyncFn(Box<dyn MsgReader<'a> + 'a>) -> io::Result<()>,
|
||||
req: impl for<'a> AsyncFn(Box<dyn ReqReader<'a> + 'a>) -> io::Result<Receipt<'a>>,
|
||||
notif: impl AsyncFn(Box<dyn MsgReader>) -> io::Result<()>,
|
||||
req: impl AsyncFn(Box<dyn ReqReader>) -> io::Result<Receipt>,
|
||||
) -> io::Result<()> {
|
||||
let Self { o, i, onexit, onsub } = self;
|
||||
enum Event {
|
||||
@@ -379,7 +539,9 @@ impl IoCommServer {
|
||||
}
|
||||
}
|
||||
});
|
||||
let (mut add_pending_req, fork_future) = LocalSet::new();
|
||||
|
||||
let running_requests = RefCell::new(HashMap::new());
|
||||
let (mut task_pool, fork_future) = local_set();
|
||||
let mut fork_stream = pin!(fork_future.into_stream());
|
||||
let mut pending_replies = HashMap::new();
|
||||
'body: {
|
||||
@@ -400,32 +562,73 @@ impl IoCommServer {
|
||||
// this is detected and logged on client
|
||||
let _ = ack.send(());
|
||||
},
|
||||
// ID 0 is reserved for single-fire notifications
|
||||
Ok(Event::Input(0, read)) => {
|
||||
let notif = ¬if;
|
||||
let notif_job =
|
||||
async move { notif(Box::new(IoMsgReader { _pd: PhantomData, read })).await };
|
||||
add_pending_req.send(Box::pin(notif_job)).await.unwrap();
|
||||
},
|
||||
// MSB == 0 is a request, !id where MSB == 1 is the corresponding response
|
||||
Ok(Event::Input(id, read)) if (id & (1 << (u64::BITS - 1))) == 0 => {
|
||||
let (o, req) = (o.clone(), &req);
|
||||
let req_job = async move {
|
||||
let mut prefix = Vec::new();
|
||||
(!id).encode_vec(&mut prefix);
|
||||
let _ = req(Box::new(IoReqReader { prefix: &pin!(prefix), read, write: &o })).await;
|
||||
Ok(())
|
||||
};
|
||||
add_pending_req.send(Box::pin(req_job)).await.unwrap();
|
||||
task_pool.spawn(notif(Box::new(IoMsgReader { _pd: PhantomData, read }))).await.unwrap();
|
||||
},
|
||||
// non-zero IDs are associated with requests
|
||||
Ok(Event::Input(id, read)) => {
|
||||
let cb = pending_replies.remove(&!id).expect("Reply to unrecognized request");
|
||||
cb.send(read).unwrap_or_else(|_| panic!("Failed to send reply"));
|
||||
// the MSb decides what kind of message this is
|
||||
let mut id_bytes = id.to_be_bytes();
|
||||
let discr = std::mem::replace(&mut id_bytes[0], 0x00);
|
||||
let id = u64::from_be_bytes(id_bytes);
|
||||
match discr {
|
||||
// request
|
||||
0x00 => {
|
||||
let (o, req, reqs) = (o.clone(), &req, &running_requests);
|
||||
task_pool
|
||||
.spawn(async move {
|
||||
id_bytes[0] = 0x01;
|
||||
let prefix = u64::from_be_bytes(id_bytes);
|
||||
let reader = Box::new(IoReqReader { prefix, read, o });
|
||||
let (fut, handle) = to_task(async { req(reader).await.map(|Receipt| ()) });
|
||||
reqs.borrow_mut().insert(id, handle);
|
||||
with_stash(fut).await;
|
||||
// during this await the read guard is released and thus we may receive a
|
||||
// cancel notification from below
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
},
|
||||
// response
|
||||
0x01 => {
|
||||
let cb = pending_replies.remove(&id).expect("Reply to unrecognized request");
|
||||
cb.send(ReplyRecord::Ready(read))
|
||||
.unwrap_or_else(|_| panic!("Failed to send reply"));
|
||||
},
|
||||
// cancellation
|
||||
0x02 => {
|
||||
match running_requests.borrow().get(&id) {
|
||||
Some(handle) => handle.abort(),
|
||||
// assuming that the client is correct, if there is no record
|
||||
// then the reply was already sent
|
||||
None => continue,
|
||||
};
|
||||
// if the request starts writing back before our abort arrives, we only
|
||||
// get this mutex once it's done
|
||||
let mut write = o.lock().await;
|
||||
// if the request is still in the store, the write didn't begin
|
||||
let Some(_) = running_requests.borrow_mut().remove(&id) else { continue };
|
||||
id_bytes[0] = 0x03;
|
||||
let cancel_code = u64::from_be_bytes(id_bytes);
|
||||
cancel_code.encode(write.as_mut()).await?;
|
||||
},
|
||||
// stub reply for cancelled request
|
||||
0x03 => {
|
||||
let cb = pending_replies.remove(&id).expect("Cancelling unrecognized request");
|
||||
cb.send(ReplyRecord::Cancelled)
|
||||
.unwrap_or_else(|_| panic!("Failed to send reply cancellation"))
|
||||
},
|
||||
n => panic!("Unrecognized message type code {n}"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}?;
|
||||
mem::drop(add_pending_req);
|
||||
mem::drop(task_pool);
|
||||
while let Some(next) = fork_stream.next().await {
|
||||
next?
|
||||
}
|
||||
@@ -441,13 +644,15 @@ mod test {
|
||||
use std::cell::RefCell;
|
||||
|
||||
use futures::channel::mpsc;
|
||||
use futures::{SinkExt, StreamExt, join};
|
||||
use futures::{FutureExt, SinkExt, StreamExt, join, select};
|
||||
use never::Never;
|
||||
use orchid_api_derive::{Coding, Hierarchy};
|
||||
use orchid_api_traits::Request;
|
||||
use orchid_async_utils::debug::spin_on;
|
||||
use unsync_pipe::pipe;
|
||||
|
||||
use crate::comm::{ClientExt, MsgReaderExt, ReqReaderExt, io_comm};
|
||||
use crate::with_stash;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Coding, Hierarchy)]
|
||||
#[extendable]
|
||||
@@ -455,7 +660,7 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn notification() {
|
||||
spin_on(async {
|
||||
spin_on(false, async {
|
||||
let (in1, out2) = pipe(1024);
|
||||
let (in2, out1) = pipe(1024);
|
||||
let (received, mut on_receive) = mpsc::channel(2);
|
||||
@@ -494,7 +699,7 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn request() {
|
||||
spin_on(async {
|
||||
spin_on(false, async {
|
||||
let (in1, out2) = pipe(1024);
|
||||
let (in2, out1) = pipe(1024);
|
||||
let (_, srv_ctx, srv) = io_comm(Box::pin(in2), Box::pin(out2));
|
||||
@@ -506,7 +711,7 @@ mod test {
|
||||
async |_| panic!("No notifs expected"),
|
||||
async |mut req| {
|
||||
let val = req.read_req::<DummyRequest>().await?;
|
||||
req.reply(&val, &(val.0 + 1)).await
|
||||
req.reply(&val, val.0 + 1).await
|
||||
},
|
||||
)
|
||||
.await
|
||||
@@ -533,7 +738,7 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn exit() {
|
||||
spin_on(async {
|
||||
spin_on(false, async {
|
||||
let (input1, output1) = pipe(1024);
|
||||
let (input2, output2) = pipe(1024);
|
||||
let (reply_client, reply_context, reply_server) =
|
||||
@@ -553,7 +758,7 @@ mod test {
|
||||
},
|
||||
async |mut hand| {
|
||||
let req = hand.read_req::<DummyRequest>().await?;
|
||||
hand.reply(&req, &(req.0 + 1)).await
|
||||
hand.reply(&req, req.0 + 1).await
|
||||
},
|
||||
)
|
||||
.await
|
||||
@@ -579,4 +784,49 @@ mod test {
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn timely_cancel() {
|
||||
spin_on(false, async {
|
||||
let (in1, out2) = pipe(1024);
|
||||
let (in2, out1) = pipe(1024);
|
||||
let (wait_in, mut wait_out) = mpsc::channel(0);
|
||||
let (_, srv_ctx, srv) = io_comm(Box::pin(in2), Box::pin(out2));
|
||||
let (client, client_ctx, client_srv) = io_comm(Box::pin(in1), Box::pin(out1));
|
||||
join!(
|
||||
async {
|
||||
srv
|
||||
.listen(
|
||||
async |_| panic!("No notifs expected"),
|
||||
async |mut req| {
|
||||
let _ = req.read_req::<DummyRequest>().await?;
|
||||
wait_in.clone().send(()).await.unwrap();
|
||||
// TODO: verify cancellation
|
||||
futures::future::pending::<Never>().await;
|
||||
unreachable!("request should be cancelled before resume is triggered")
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
},
|
||||
async {
|
||||
client_srv
|
||||
.listen(
|
||||
async |_| panic!("Not expecting ingress notif"),
|
||||
async |_| panic!("Not expecting ingress req"),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
},
|
||||
with_stash(async {
|
||||
select! {
|
||||
_ = client.request(DummyRequest(5)).fuse() => panic!("This one should not run"),
|
||||
rep = wait_out.next() => rep.expect("something?"),
|
||||
};
|
||||
srv_ctx.exit().await.unwrap();
|
||||
client_ctx.exit().await.unwrap();
|
||||
})
|
||||
);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,44 +1,188 @@
|
||||
//! A pattern for running async code from sync destructors and other
|
||||
//! unfortunately sync callbacks
|
||||
//! unfortunately sync callbacks, and for ensuring that these futures finish in
|
||||
//! a timely fashion
|
||||
//!
|
||||
//! We create a task_local vecdeque which is moved into a thread_local whenever
|
||||
//! the task is being polled. A call to [stash] pushes the future onto this
|
||||
//! deque. Before [with_stash] returns, it pops everything from the deque
|
||||
//! individually and awaits each of them, pushing any additionally stashed
|
||||
//! futures onto the back of the same deque.
|
||||
//! deque. Before [with_stash] returns, it awaits everything stashed up to that
|
||||
//! point or inside the stashed futures.
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::collections::VecDeque;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use task_local::task_local;
|
||||
use futures::StreamExt;
|
||||
use futures::future::LocalBoxFuture;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use orchid_async_utils::cancel_cleanup;
|
||||
|
||||
#[derive(Default)]
|
||||
struct StashedFutures {
|
||||
queue: RefCell<VecDeque<Pin<Box<dyn Future<Output = ()>>>>>,
|
||||
}
|
||||
|
||||
task_local! {
|
||||
static STASHED_FUTURES: StashedFutures;
|
||||
thread_local! {
|
||||
/// # Invariant
|
||||
///
|
||||
/// Any function that changes the value of this thread_local must restore it before returning
|
||||
static CURRENT_STASH: RefCell<Option<Vec<LocalBoxFuture<'static, ()>>>> = RefCell::default();
|
||||
}
|
||||
|
||||
/// Complete the argument future, and any futures spawned from it via [stash].
|
||||
/// This is useful mostly to guarantee that messaging destructors have run.
|
||||
pub async fn with_stash<F: Future>(fut: F) -> F::Output {
|
||||
STASHED_FUTURES
|
||||
.scope(StashedFutures::default(), async {
|
||||
let val = fut.await;
|
||||
while let Some(fut) = STASHED_FUTURES.with(|sf| sf.queue.borrow_mut().pop_front()) {
|
||||
fut.await;
|
||||
}
|
||||
val
|
||||
})
|
||||
.await
|
||||
///
|
||||
/// # Cancellation
|
||||
///
|
||||
/// To ensure that stashed futures run, the returned future re-stashes them a
|
||||
/// layer above when dropped. Therefore cancelling `with_stash` is only safe
|
||||
/// within an enclosing `with_stash` outside of a panic.
|
||||
pub fn with_stash<F: Future>(fut: F) -> impl Future<Output = F::Output> {
|
||||
WithStash { stash: FuturesUnordered::new(), state: WithStashState::Main(fut) }
|
||||
}
|
||||
|
||||
/// Schedule a future to be run before the next [with_stash] guard ends. This is
|
||||
/// most useful for sending messages from destructors.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If no enclosing stash is found, this function panics, unless we are already
|
||||
/// panicking. The assumption is that a panic is a vis-major where proper
|
||||
/// cleanup is secondary to avoiding an abort.
|
||||
pub fn stash<F: Future<Output = ()> + 'static>(fut: F) {
|
||||
(STASHED_FUTURES.try_with(|sf| sf.queue.borrow_mut().push_back(Box::pin(fut))))
|
||||
.expect("No stash! Timely completion cannot be guaranteed")
|
||||
CURRENT_STASH.with(|stash| {
|
||||
let mut g = stash.borrow_mut();
|
||||
let Some(stash) = g.as_mut() else {
|
||||
if !std::thread::panicking() {
|
||||
panic!("No stash! Timely completion cannot be guaranteed");
|
||||
}
|
||||
return;
|
||||
};
|
||||
stash.push(Box::pin(fut))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn finish_or_stash<F: Future + Unpin + 'static>(
|
||||
fut: F,
|
||||
) -> impl Future<Output = F::Output> + Unpin + 'static {
|
||||
cancel_cleanup(fut, |fut| {
|
||||
stash(async {
|
||||
fut.await;
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
enum WithStashState<F: Future> {
|
||||
Main(F),
|
||||
Stash {
|
||||
/// Optional to simplify state management, but only ever null on a very
|
||||
/// short stretch
|
||||
output: Option<F::Output>,
|
||||
},
|
||||
}
|
||||
|
||||
struct WithStash<F: Future> {
|
||||
stash: FuturesUnordered<LocalBoxFuture<'static, ()>>,
|
||||
state: WithStashState<F>,
|
||||
}
|
||||
impl<F: Future> Future for WithStash<F> {
|
||||
type Output = F::Output;
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
// SAFETY: the only non-Unpin item is Main#main, and it's pinned right back
|
||||
let Self { state, stash } = unsafe { Pin::get_unchecked_mut(self) };
|
||||
if let WithStashState::Main(main) = state {
|
||||
// SAFETY: this comes from the pin we break on the line above
|
||||
let main = unsafe { Pin::new_unchecked(main) };
|
||||
let prev = CURRENT_STASH.with_borrow_mut(|key| key.replace(Vec::new()));
|
||||
let poll = main.poll(cx);
|
||||
let stash_init = CURRENT_STASH
|
||||
.with_borrow_mut(|key| std::mem::replace(key, prev))
|
||||
.expect("We put a Some() in here and CURRENT_STASH demands restoration");
|
||||
stash.extend(stash_init);
|
||||
if let Poll::Ready(o) = poll {
|
||||
// skip this branch from this point onwards
|
||||
*state = WithStashState::Stash { output: Some(o) };
|
||||
}
|
||||
}
|
||||
match state {
|
||||
WithStashState::Main(_) | WithStashState::Stash { output: None, .. } => Poll::Pending,
|
||||
WithStashState::Stash { output: output @ Some(_) } => loop {
|
||||
// if the queue has new elements, poll_next_unpin has to be called in the next
|
||||
// loop to ensure that wake-ups are triggered for them too, and if
|
||||
// poll_next_unpin is called, the queue may get yet more elements synchronously,
|
||||
// hence the loop
|
||||
let prev = CURRENT_STASH.with_borrow_mut(|key| key.replace(Vec::new()));
|
||||
let poll = stash.poll_next_unpin(cx);
|
||||
let stash_new = CURRENT_STASH
|
||||
.with_borrow_mut(|key| std::mem::replace(key, prev))
|
||||
.expect("We put a Some() in here and CURRENT_STASH demands restoration");
|
||||
stash.extend(stash_new);
|
||||
match poll {
|
||||
Poll::Ready(None) if stash.is_empty() => {
|
||||
let output = output.take().expect("Checked in branching");
|
||||
break Poll::Ready(output);
|
||||
},
|
||||
Poll::Pending => {
|
||||
break Poll::Pending;
|
||||
},
|
||||
Poll::Ready(_) => continue,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<F: Future> Drop for WithStash<F> {
|
||||
fn drop(&mut self) {
|
||||
if std::thread::panicking() {
|
||||
eprintln!("Panicking through with_stash may silently drop stashed cleanup work")
|
||||
}
|
||||
for future in std::mem::take(&mut self.stash) {
|
||||
stash(future);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use futures::SinkExt;
|
||||
use futures::channel::mpsc;
|
||||
use futures::future::join;
|
||||
use orchid_async_utils::debug::spin_on;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn run_stashed_future() {
|
||||
let (mut send, recv) = mpsc::channel(0);
|
||||
spin_on(
|
||||
false,
|
||||
join(
|
||||
with_stash(async {
|
||||
let mut send1 = send.clone();
|
||||
stash(async move {
|
||||
send1.send(1).await.unwrap();
|
||||
});
|
||||
let mut send1 = send.clone();
|
||||
stash(async move {
|
||||
let mut send2 = send1.clone();
|
||||
stash(async move {
|
||||
send2.send(2).await.unwrap();
|
||||
});
|
||||
send1.send(3).await.unwrap();
|
||||
stash(async move {
|
||||
send1.send(4).await.unwrap();
|
||||
})
|
||||
});
|
||||
let mut send1 = send.clone();
|
||||
stash(async move {
|
||||
send1.send(5).await.unwrap();
|
||||
});
|
||||
send.send(6).await.unwrap();
|
||||
}),
|
||||
async {
|
||||
let mut results = recv.take(6).collect::<Vec<_>>().await;
|
||||
results.sort();
|
||||
assert_eq!(
|
||||
&results,
|
||||
&[1, 2, 3, 4, 5, 6],
|
||||
"all variations completed in unspecified order"
|
||||
);
|
||||
},
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user