diff --git a/orchid-base/src/comm.rs b/orchid-base/src/comm.rs index 81e5139..f840047 100644 --- a/orchid-base/src/comm.rs +++ b/orchid-base/src/comm.rs @@ -9,7 +9,7 @@ use bound::Bound; use derive_destructure::destructure; use futures::channel::mpsc::{self, Receiver, Sender, channel}; use futures::channel::oneshot; -use futures::future::LocalBoxFuture; +use futures::future::{LocalBoxFuture, join}; use futures::lock::{Mutex, MutexGuard}; use futures::{ AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, Stream, StreamExt, stream_select, @@ -186,8 +186,11 @@ pub trait ClientExt: Client { (common.take()).expect("If it was still borrowed in fut, it was not yet unset"); common.await.expect("IO error on stash").send() }; - req_wait.canceller.cancel().await; - let mut rep = req_wait.future.await.expect("IO error on stash"); + let (mut rep, _) = + join(async { req_wait.future.await.expect("IO error on stash") }, async { + req_wait.canceller.cancel().await + }) + .await; let Some(reader) = rep.reader() else { return }; T::Response::decode(reader).await.expect("IO error on stash"); rep.finish().await; @@ -419,6 +422,7 @@ impl CancelNotifier for IoReqCanceller { Box::pin(async move { let mut o = o.lock().await; let _ = cancel_id.encode(o.as_mut()).await; + let _ = o.flush().await; cancel_signal_drop_g.defuse(); }) } @@ -441,28 +445,33 @@ impl ReqWriter for IoReqWriter { mem::drop(w); let reply_record = reply.await.expect("Client dropped before reply received"); drop_g.defuse(); - Ok(Box::new(IoRepReader { - reply_record, - drop_g: assert_no_drop("Reply reader dropped without finishing"), + Ok(Box::new(match reply_record { + ReplyRecord::Cancelled => IoRepReader::Stub, + ReplyRecord::Ready(read) => + IoRepReader::Active(read, assert_no_drop("Reply reader dropped without finishing")), }) as Box) }; ReqWait { future: Box::pin(future), canceller: Box::new(canceller) } } } -struct IoRepReader { - reply_record: ReplyRecord, - drop_g: PanicOnDrop, +enum IoRepReader { + Active(IoGuard, PanicOnDrop), + Stub, } impl RepReader for IoRepReader { fn reader(&mut self) -> Option> { - match &mut self.reply_record { - ReplyRecord::Cancelled => None, - ReplyRecord::Ready(guard) => Some(guard.as_mut()), + match self { + Self::Stub => None, + Self::Active(guard, _) => Some(guard.as_mut()), } } fn finish(self: Box) -> LocalBoxFuture<'static, ()> { - Box::pin(async { self.drop_g.defuse() }) + Box::pin(async { + if let Self::Active(_, g) = *self { + g.defuse() + } + }) } } @@ -606,14 +615,22 @@ impl IoCommServer { // then the reply was already sent None => continue, }; - // if the request starts writing back before our abort arrives, we only - // get this mutex once it's done - let mut write = o.lock().await; - // if the request is still in the store, the write didn't begin - let Some(_) = running_requests.borrow_mut().remove(&id) else { continue }; - id_bytes[0] = 0x03; - let cancel_code = u64::from_be_bytes(id_bytes); - cancel_code.encode(write.as_mut()).await?; + let (o, running_reqs) = (o.clone(), &running_requests); + task_pool + .spawn(async move { + // if the request starts writing back before our abort arrives, we only + // get this mutex once it's done + let mut write = o.lock().await; + // if the request is still in the store, the write didn't begin + let Some(_) = running_reqs.borrow_mut().remove(&id) else { return Ok(()) }; + id_bytes[0] = 0x03; + let cancel_code = u64::from_be_bytes(id_bytes); + cancel_code.encode(write.as_mut()).await?; + write.flush().await?; + Ok(()) + }) + .await + .unwrap(); }, // stub reply for cancelled request 0x03 => { @@ -645,10 +662,9 @@ mod test { use futures::channel::mpsc; use futures::{FutureExt, SinkExt, StreamExt, join, select}; - use never::Never; use orchid_api_derive::{Coding, Hierarchy}; use orchid_api_traits::Request; - use orchid_async_utils::debug::spin_on; + use orchid_async_utils::debug::{spin_on, with_label}; use unsync_pipe::pipe; use crate::comm::{ClientExt, MsgReaderExt, ReqReaderExt, io_comm}; @@ -747,7 +763,7 @@ mod test { let reply_context = RefCell::new(Some(reply_context)); let (exit, onexit) = futures::channel::oneshot::channel::<()>(); join!( - async move { + with_label("reply", async move { reply_server .listen( async |hand| { @@ -765,8 +781,8 @@ mod test { .unwrap(); exit.send(()).unwrap(); let _client = reply_client; - }, - async move { + }), + with_label("client", async move { req_server .listen( async |_| panic!("Only the other server expected notifs"), @@ -775,7 +791,7 @@ mod test { .await .unwrap(); let _ctx = req_context; - }, + }), async move { req_client.request(DummyRequest(0)).await.unwrap(); req_client.notify(TestNotif(0)).await.unwrap(); @@ -794,35 +810,41 @@ mod test { let (_, srv_ctx, srv) = io_comm(Box::pin(in2), Box::pin(out2)); let (client, client_ctx, client_srv) = io_comm(Box::pin(in1), Box::pin(out1)); join!( - async { + with_label("server", async { srv .listen( async |_| panic!("No notifs expected"), async |mut req| { let _ = req.read_req::().await?; + let _ = req.finish().await; wait_in.clone().send(()).await.unwrap(); - // TODO: verify cancellation - futures::future::pending::().await; - unreachable!("request should be cancelled before resume is triggered") + // This will never return, so if the cancellation does not work, it would block + // the loop + futures::future::pending().await }, ) .await - .unwrap() - }, - async { + .unwrap(); + }), + with_label("client", async { client_srv .listen( async |_| panic!("Not expecting ingress notif"), async |_| panic!("Not expecting ingress req"), ) .await - .unwrap() - }, + .unwrap(); + }), with_stash(async { - select! { - _ = client.request(DummyRequest(5)).fuse() => panic!("This one should not run"), - rep = wait_out.next() => rep.expect("something?"), - }; + with_stash(async { + select! { + _ = client.request(DummyRequest(5)).fuse() => { + panic!("This one should not run") + }, + rep = wait_out.next() => rep.expect("something?"), + } + }) + .await; srv_ctx.exit().await.unwrap(); client_ctx.exit().await.unwrap(); })