cancellation implemented and first test passing

This commit is contained in:
2026-04-23 12:09:39 +00:00
parent 7f8c247d97
commit 0bc7097c88

View File

@@ -9,7 +9,7 @@ use bound::Bound;
use derive_destructure::destructure; use derive_destructure::destructure;
use futures::channel::mpsc::{self, Receiver, Sender, channel}; use futures::channel::mpsc::{self, Receiver, Sender, channel};
use futures::channel::oneshot; use futures::channel::oneshot;
use futures::future::LocalBoxFuture; use futures::future::{LocalBoxFuture, join};
use futures::lock::{Mutex, MutexGuard}; use futures::lock::{Mutex, MutexGuard};
use futures::{ use futures::{
AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, Stream, StreamExt, stream_select, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, Stream, StreamExt, stream_select,
@@ -186,8 +186,11 @@ pub trait ClientExt: Client {
(common.take()).expect("If it was still borrowed in fut, it was not yet unset"); (common.take()).expect("If it was still borrowed in fut, it was not yet unset");
common.await.expect("IO error on stash").send() common.await.expect("IO error on stash").send()
}; };
req_wait.canceller.cancel().await; let (mut rep, _) =
let mut rep = req_wait.future.await.expect("IO error on stash"); join(async { req_wait.future.await.expect("IO error on stash") }, async {
req_wait.canceller.cancel().await
})
.await;
let Some(reader) = rep.reader() else { return }; let Some(reader) = rep.reader() else { return };
T::Response::decode(reader).await.expect("IO error on stash"); T::Response::decode(reader).await.expect("IO error on stash");
rep.finish().await; rep.finish().await;
@@ -419,6 +422,7 @@ impl CancelNotifier for IoReqCanceller {
Box::pin(async move { Box::pin(async move {
let mut o = o.lock().await; let mut o = o.lock().await;
let _ = cancel_id.encode(o.as_mut()).await; let _ = cancel_id.encode(o.as_mut()).await;
let _ = o.flush().await;
cancel_signal_drop_g.defuse(); cancel_signal_drop_g.defuse();
}) })
} }
@@ -441,28 +445,33 @@ impl ReqWriter for IoReqWriter {
mem::drop(w); mem::drop(w);
let reply_record = reply.await.expect("Client dropped before reply received"); let reply_record = reply.await.expect("Client dropped before reply received");
drop_g.defuse(); drop_g.defuse();
Ok(Box::new(IoRepReader { Ok(Box::new(match reply_record {
reply_record, ReplyRecord::Cancelled => IoRepReader::Stub,
drop_g: assert_no_drop("Reply reader dropped without finishing"), ReplyRecord::Ready(read) =>
IoRepReader::Active(read, assert_no_drop("Reply reader dropped without finishing")),
}) as Box<dyn RepReader>) }) as Box<dyn RepReader>)
}; };
ReqWait { future: Box::pin(future), canceller: Box::new(canceller) } ReqWait { future: Box::pin(future), canceller: Box::new(canceller) }
} }
} }
struct IoRepReader { enum IoRepReader {
reply_record: ReplyRecord, Active(IoGuard<dyn AsyncRead>, PanicOnDrop),
drop_g: PanicOnDrop, Stub,
} }
impl RepReader for IoRepReader { impl RepReader for IoRepReader {
fn reader(&mut self) -> Option<Pin<&mut dyn AsyncRead>> { fn reader(&mut self) -> Option<Pin<&mut dyn AsyncRead>> {
match &mut self.reply_record { match self {
ReplyRecord::Cancelled => None, Self::Stub => None,
ReplyRecord::Ready(guard) => Some(guard.as_mut()), Self::Active(guard, _) => Some(guard.as_mut()),
} }
} }
fn finish(self: Box<Self>) -> LocalBoxFuture<'static, ()> { fn finish(self: Box<Self>) -> LocalBoxFuture<'static, ()> {
Box::pin(async { self.drop_g.defuse() }) Box::pin(async {
if let Self::Active(_, g) = *self {
g.defuse()
}
})
} }
} }
@@ -606,14 +615,22 @@ impl IoCommServer {
// then the reply was already sent // then the reply was already sent
None => continue, None => continue,
}; };
// if the request starts writing back before our abort arrives, we only let (o, running_reqs) = (o.clone(), &running_requests);
// get this mutex once it's done task_pool
let mut write = o.lock().await; .spawn(async move {
// if the request is still in the store, the write didn't begin // if the request starts writing back before our abort arrives, we only
let Some(_) = running_requests.borrow_mut().remove(&id) else { continue }; // get this mutex once it's done
id_bytes[0] = 0x03; let mut write = o.lock().await;
let cancel_code = u64::from_be_bytes(id_bytes); // if the request is still in the store, the write didn't begin
cancel_code.encode(write.as_mut()).await?; let Some(_) = running_reqs.borrow_mut().remove(&id) else { return Ok(()) };
id_bytes[0] = 0x03;
let cancel_code = u64::from_be_bytes(id_bytes);
cancel_code.encode(write.as_mut()).await?;
write.flush().await?;
Ok(())
})
.await
.unwrap();
}, },
// stub reply for cancelled request // stub reply for cancelled request
0x03 => { 0x03 => {
@@ -645,10 +662,9 @@ mod test {
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::{FutureExt, SinkExt, StreamExt, join, select}; use futures::{FutureExt, SinkExt, StreamExt, join, select};
use never::Never;
use orchid_api_derive::{Coding, Hierarchy}; use orchid_api_derive::{Coding, Hierarchy};
use orchid_api_traits::Request; use orchid_api_traits::Request;
use orchid_async_utils::debug::spin_on; use orchid_async_utils::debug::{spin_on, with_label};
use unsync_pipe::pipe; use unsync_pipe::pipe;
use crate::comm::{ClientExt, MsgReaderExt, ReqReaderExt, io_comm}; use crate::comm::{ClientExt, MsgReaderExt, ReqReaderExt, io_comm};
@@ -747,7 +763,7 @@ mod test {
let reply_context = RefCell::new(Some(reply_context)); let reply_context = RefCell::new(Some(reply_context));
let (exit, onexit) = futures::channel::oneshot::channel::<()>(); let (exit, onexit) = futures::channel::oneshot::channel::<()>();
join!( join!(
async move { with_label("reply", async move {
reply_server reply_server
.listen( .listen(
async |hand| { async |hand| {
@@ -765,8 +781,8 @@ mod test {
.unwrap(); .unwrap();
exit.send(()).unwrap(); exit.send(()).unwrap();
let _client = reply_client; let _client = reply_client;
}, }),
async move { with_label("client", async move {
req_server req_server
.listen( .listen(
async |_| panic!("Only the other server expected notifs"), async |_| panic!("Only the other server expected notifs"),
@@ -775,7 +791,7 @@ mod test {
.await .await
.unwrap(); .unwrap();
let _ctx = req_context; let _ctx = req_context;
}, }),
async move { async move {
req_client.request(DummyRequest(0)).await.unwrap(); req_client.request(DummyRequest(0)).await.unwrap();
req_client.notify(TestNotif(0)).await.unwrap(); req_client.notify(TestNotif(0)).await.unwrap();
@@ -794,35 +810,41 @@ mod test {
let (_, srv_ctx, srv) = io_comm(Box::pin(in2), Box::pin(out2)); let (_, srv_ctx, srv) = io_comm(Box::pin(in2), Box::pin(out2));
let (client, client_ctx, client_srv) = io_comm(Box::pin(in1), Box::pin(out1)); let (client, client_ctx, client_srv) = io_comm(Box::pin(in1), Box::pin(out1));
join!( join!(
async { with_label("server", async {
srv srv
.listen( .listen(
async |_| panic!("No notifs expected"), async |_| panic!("No notifs expected"),
async |mut req| { async |mut req| {
let _ = req.read_req::<DummyRequest>().await?; let _ = req.read_req::<DummyRequest>().await?;
let _ = req.finish().await;
wait_in.clone().send(()).await.unwrap(); wait_in.clone().send(()).await.unwrap();
// TODO: verify cancellation // This will never return, so if the cancellation does not work, it would block
futures::future::pending::<Never>().await; // the loop
unreachable!("request should be cancelled before resume is triggered") futures::future::pending().await
}, },
) )
.await .await
.unwrap() .unwrap();
}, }),
async { with_label("client", async {
client_srv client_srv
.listen( .listen(
async |_| panic!("Not expecting ingress notif"), async |_| panic!("Not expecting ingress notif"),
async |_| panic!("Not expecting ingress req"), async |_| panic!("Not expecting ingress req"),
) )
.await .await
.unwrap() .unwrap();
}, }),
with_stash(async { with_stash(async {
select! { with_stash(async {
_ = client.request(DummyRequest(5)).fuse() => panic!("This one should not run"), select! {
rep = wait_out.next() => rep.expect("something?"), _ = client.request(DummyRequest(5)).fuse() => {
}; panic!("This one should not run")
},
rep = wait_out.next() => rep.expect("something?"),
}
})
.await;
srv_ctx.exit().await.unwrap(); srv_ctx.exit().await.unwrap();
client_ctx.exit().await.unwrap(); client_ctx.exit().await.unwrap();
}) })