cancellation implemented and first test passing

This commit is contained in:
2026-04-23 12:09:39 +00:00
parent 7f8c247d97
commit 0bc7097c88

View File

@@ -9,7 +9,7 @@ use bound::Bound;
use derive_destructure::destructure;
use futures::channel::mpsc::{self, Receiver, Sender, channel};
use futures::channel::oneshot;
use futures::future::LocalBoxFuture;
use futures::future::{LocalBoxFuture, join};
use futures::lock::{Mutex, MutexGuard};
use futures::{
AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, Stream, StreamExt, stream_select,
@@ -186,8 +186,11 @@ pub trait ClientExt: Client {
(common.take()).expect("If it was still borrowed in fut, it was not yet unset");
common.await.expect("IO error on stash").send()
};
req_wait.canceller.cancel().await;
let mut rep = req_wait.future.await.expect("IO error on stash");
let (mut rep, _) =
join(async { req_wait.future.await.expect("IO error on stash") }, async {
req_wait.canceller.cancel().await
})
.await;
let Some(reader) = rep.reader() else { return };
T::Response::decode(reader).await.expect("IO error on stash");
rep.finish().await;
@@ -419,6 +422,7 @@ impl CancelNotifier for IoReqCanceller {
Box::pin(async move {
let mut o = o.lock().await;
let _ = cancel_id.encode(o.as_mut()).await;
let _ = o.flush().await;
cancel_signal_drop_g.defuse();
})
}
@@ -441,28 +445,33 @@ impl ReqWriter for IoReqWriter {
mem::drop(w);
let reply_record = reply.await.expect("Client dropped before reply received");
drop_g.defuse();
Ok(Box::new(IoRepReader {
reply_record,
drop_g: assert_no_drop("Reply reader dropped without finishing"),
Ok(Box::new(match reply_record {
ReplyRecord::Cancelled => IoRepReader::Stub,
ReplyRecord::Ready(read) =>
IoRepReader::Active(read, assert_no_drop("Reply reader dropped without finishing")),
}) as Box<dyn RepReader>)
};
ReqWait { future: Box::pin(future), canceller: Box::new(canceller) }
}
}
struct IoRepReader {
reply_record: ReplyRecord,
drop_g: PanicOnDrop,
enum IoRepReader {
Active(IoGuard<dyn AsyncRead>, PanicOnDrop),
Stub,
}
impl RepReader for IoRepReader {
fn reader(&mut self) -> Option<Pin<&mut dyn AsyncRead>> {
match &mut self.reply_record {
ReplyRecord::Cancelled => None,
ReplyRecord::Ready(guard) => Some(guard.as_mut()),
match self {
Self::Stub => None,
Self::Active(guard, _) => Some(guard.as_mut()),
}
}
fn finish(self: Box<Self>) -> LocalBoxFuture<'static, ()> {
Box::pin(async { self.drop_g.defuse() })
Box::pin(async {
if let Self::Active(_, g) = *self {
g.defuse()
}
})
}
}
@@ -606,14 +615,22 @@ impl IoCommServer {
// then the reply was already sent
None => continue,
};
let (o, running_reqs) = (o.clone(), &running_requests);
task_pool
.spawn(async move {
// if the request starts writing back before our abort arrives, we only
// get this mutex once it's done
let mut write = o.lock().await;
// if the request is still in the store, the write didn't begin
let Some(_) = running_requests.borrow_mut().remove(&id) else { continue };
let Some(_) = running_reqs.borrow_mut().remove(&id) else { return Ok(()) };
id_bytes[0] = 0x03;
let cancel_code = u64::from_be_bytes(id_bytes);
cancel_code.encode(write.as_mut()).await?;
write.flush().await?;
Ok(())
})
.await
.unwrap();
},
// stub reply for cancelled request
0x03 => {
@@ -645,10 +662,9 @@ mod test {
use futures::channel::mpsc;
use futures::{FutureExt, SinkExt, StreamExt, join, select};
use never::Never;
use orchid_api_derive::{Coding, Hierarchy};
use orchid_api_traits::Request;
use orchid_async_utils::debug::spin_on;
use orchid_async_utils::debug::{spin_on, with_label};
use unsync_pipe::pipe;
use crate::comm::{ClientExt, MsgReaderExt, ReqReaderExt, io_comm};
@@ -747,7 +763,7 @@ mod test {
let reply_context = RefCell::new(Some(reply_context));
let (exit, onexit) = futures::channel::oneshot::channel::<()>();
join!(
async move {
with_label("reply", async move {
reply_server
.listen(
async |hand| {
@@ -765,8 +781,8 @@ mod test {
.unwrap();
exit.send(()).unwrap();
let _client = reply_client;
},
async move {
}),
with_label("client", async move {
req_server
.listen(
async |_| panic!("Only the other server expected notifs"),
@@ -775,7 +791,7 @@ mod test {
.await
.unwrap();
let _ctx = req_context;
},
}),
async move {
req_client.request(DummyRequest(0)).await.unwrap();
req_client.notify(TestNotif(0)).await.unwrap();
@@ -794,35 +810,41 @@ mod test {
let (_, srv_ctx, srv) = io_comm(Box::pin(in2), Box::pin(out2));
let (client, client_ctx, client_srv) = io_comm(Box::pin(in1), Box::pin(out1));
join!(
async {
with_label("server", async {
srv
.listen(
async |_| panic!("No notifs expected"),
async |mut req| {
let _ = req.read_req::<DummyRequest>().await?;
let _ = req.finish().await;
wait_in.clone().send(()).await.unwrap();
// TODO: verify cancellation
futures::future::pending::<Never>().await;
unreachable!("request should be cancelled before resume is triggered")
// This will never return, so if the cancellation does not work, it would block
// the loop
futures::future::pending().await
},
)
.await
.unwrap()
},
async {
.unwrap();
}),
with_label("client", async {
client_srv
.listen(
async |_| panic!("Not expecting ingress notif"),
async |_| panic!("Not expecting ingress req"),
)
.await
.unwrap()
},
.unwrap();
}),
with_stash(async {
with_stash(async {
select! {
_ = client.request(DummyRequest(5)).fuse() => panic!("This one should not run"),
_ = client.request(DummyRequest(5)).fuse() => {
panic!("This one should not run")
},
rep = wait_out.next() => rep.expect("something?"),
};
}
})
.await;
srv_ctx.exit().await.unwrap();
client_ctx.exit().await.unwrap();
})