cancellation implemented and first test passing
This commit is contained in:
@@ -9,7 +9,7 @@ use bound::Bound;
|
|||||||
use derive_destructure::destructure;
|
use derive_destructure::destructure;
|
||||||
use futures::channel::mpsc::{self, Receiver, Sender, channel};
|
use futures::channel::mpsc::{self, Receiver, Sender, channel};
|
||||||
use futures::channel::oneshot;
|
use futures::channel::oneshot;
|
||||||
use futures::future::LocalBoxFuture;
|
use futures::future::{LocalBoxFuture, join};
|
||||||
use futures::lock::{Mutex, MutexGuard};
|
use futures::lock::{Mutex, MutexGuard};
|
||||||
use futures::{
|
use futures::{
|
||||||
AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, Stream, StreamExt, stream_select,
|
AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, Stream, StreamExt, stream_select,
|
||||||
@@ -186,8 +186,11 @@ pub trait ClientExt: Client {
|
|||||||
(common.take()).expect("If it was still borrowed in fut, it was not yet unset");
|
(common.take()).expect("If it was still borrowed in fut, it was not yet unset");
|
||||||
common.await.expect("IO error on stash").send()
|
common.await.expect("IO error on stash").send()
|
||||||
};
|
};
|
||||||
req_wait.canceller.cancel().await;
|
let (mut rep, _) =
|
||||||
let mut rep = req_wait.future.await.expect("IO error on stash");
|
join(async { req_wait.future.await.expect("IO error on stash") }, async {
|
||||||
|
req_wait.canceller.cancel().await
|
||||||
|
})
|
||||||
|
.await;
|
||||||
let Some(reader) = rep.reader() else { return };
|
let Some(reader) = rep.reader() else { return };
|
||||||
T::Response::decode(reader).await.expect("IO error on stash");
|
T::Response::decode(reader).await.expect("IO error on stash");
|
||||||
rep.finish().await;
|
rep.finish().await;
|
||||||
@@ -419,6 +422,7 @@ impl CancelNotifier for IoReqCanceller {
|
|||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let mut o = o.lock().await;
|
let mut o = o.lock().await;
|
||||||
let _ = cancel_id.encode(o.as_mut()).await;
|
let _ = cancel_id.encode(o.as_mut()).await;
|
||||||
|
let _ = o.flush().await;
|
||||||
cancel_signal_drop_g.defuse();
|
cancel_signal_drop_g.defuse();
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -441,28 +445,33 @@ impl ReqWriter for IoReqWriter {
|
|||||||
mem::drop(w);
|
mem::drop(w);
|
||||||
let reply_record = reply.await.expect("Client dropped before reply received");
|
let reply_record = reply.await.expect("Client dropped before reply received");
|
||||||
drop_g.defuse();
|
drop_g.defuse();
|
||||||
Ok(Box::new(IoRepReader {
|
Ok(Box::new(match reply_record {
|
||||||
reply_record,
|
ReplyRecord::Cancelled => IoRepReader::Stub,
|
||||||
drop_g: assert_no_drop("Reply reader dropped without finishing"),
|
ReplyRecord::Ready(read) =>
|
||||||
|
IoRepReader::Active(read, assert_no_drop("Reply reader dropped without finishing")),
|
||||||
}) as Box<dyn RepReader>)
|
}) as Box<dyn RepReader>)
|
||||||
};
|
};
|
||||||
ReqWait { future: Box::pin(future), canceller: Box::new(canceller) }
|
ReqWait { future: Box::pin(future), canceller: Box::new(canceller) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct IoRepReader {
|
enum IoRepReader {
|
||||||
reply_record: ReplyRecord,
|
Active(IoGuard<dyn AsyncRead>, PanicOnDrop),
|
||||||
drop_g: PanicOnDrop,
|
Stub,
|
||||||
}
|
}
|
||||||
impl RepReader for IoRepReader {
|
impl RepReader for IoRepReader {
|
||||||
fn reader(&mut self) -> Option<Pin<&mut dyn AsyncRead>> {
|
fn reader(&mut self) -> Option<Pin<&mut dyn AsyncRead>> {
|
||||||
match &mut self.reply_record {
|
match self {
|
||||||
ReplyRecord::Cancelled => None,
|
Self::Stub => None,
|
||||||
ReplyRecord::Ready(guard) => Some(guard.as_mut()),
|
Self::Active(guard, _) => Some(guard.as_mut()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fn finish(self: Box<Self>) -> LocalBoxFuture<'static, ()> {
|
fn finish(self: Box<Self>) -> LocalBoxFuture<'static, ()> {
|
||||||
Box::pin(async { self.drop_g.defuse() })
|
Box::pin(async {
|
||||||
|
if let Self::Active(_, g) = *self {
|
||||||
|
g.defuse()
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -606,14 +615,22 @@ impl IoCommServer {
|
|||||||
// then the reply was already sent
|
// then the reply was already sent
|
||||||
None => continue,
|
None => continue,
|
||||||
};
|
};
|
||||||
|
let (o, running_reqs) = (o.clone(), &running_requests);
|
||||||
|
task_pool
|
||||||
|
.spawn(async move {
|
||||||
// if the request starts writing back before our abort arrives, we only
|
// if the request starts writing back before our abort arrives, we only
|
||||||
// get this mutex once it's done
|
// get this mutex once it's done
|
||||||
let mut write = o.lock().await;
|
let mut write = o.lock().await;
|
||||||
// if the request is still in the store, the write didn't begin
|
// if the request is still in the store, the write didn't begin
|
||||||
let Some(_) = running_requests.borrow_mut().remove(&id) else { continue };
|
let Some(_) = running_reqs.borrow_mut().remove(&id) else { return Ok(()) };
|
||||||
id_bytes[0] = 0x03;
|
id_bytes[0] = 0x03;
|
||||||
let cancel_code = u64::from_be_bytes(id_bytes);
|
let cancel_code = u64::from_be_bytes(id_bytes);
|
||||||
cancel_code.encode(write.as_mut()).await?;
|
cancel_code.encode(write.as_mut()).await?;
|
||||||
|
write.flush().await?;
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
},
|
},
|
||||||
// stub reply for cancelled request
|
// stub reply for cancelled request
|
||||||
0x03 => {
|
0x03 => {
|
||||||
@@ -645,10 +662,9 @@ mod test {
|
|||||||
|
|
||||||
use futures::channel::mpsc;
|
use futures::channel::mpsc;
|
||||||
use futures::{FutureExt, SinkExt, StreamExt, join, select};
|
use futures::{FutureExt, SinkExt, StreamExt, join, select};
|
||||||
use never::Never;
|
|
||||||
use orchid_api_derive::{Coding, Hierarchy};
|
use orchid_api_derive::{Coding, Hierarchy};
|
||||||
use orchid_api_traits::Request;
|
use orchid_api_traits::Request;
|
||||||
use orchid_async_utils::debug::spin_on;
|
use orchid_async_utils::debug::{spin_on, with_label};
|
||||||
use unsync_pipe::pipe;
|
use unsync_pipe::pipe;
|
||||||
|
|
||||||
use crate::comm::{ClientExt, MsgReaderExt, ReqReaderExt, io_comm};
|
use crate::comm::{ClientExt, MsgReaderExt, ReqReaderExt, io_comm};
|
||||||
@@ -747,7 +763,7 @@ mod test {
|
|||||||
let reply_context = RefCell::new(Some(reply_context));
|
let reply_context = RefCell::new(Some(reply_context));
|
||||||
let (exit, onexit) = futures::channel::oneshot::channel::<()>();
|
let (exit, onexit) = futures::channel::oneshot::channel::<()>();
|
||||||
join!(
|
join!(
|
||||||
async move {
|
with_label("reply", async move {
|
||||||
reply_server
|
reply_server
|
||||||
.listen(
|
.listen(
|
||||||
async |hand| {
|
async |hand| {
|
||||||
@@ -765,8 +781,8 @@ mod test {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
exit.send(()).unwrap();
|
exit.send(()).unwrap();
|
||||||
let _client = reply_client;
|
let _client = reply_client;
|
||||||
},
|
}),
|
||||||
async move {
|
with_label("client", async move {
|
||||||
req_server
|
req_server
|
||||||
.listen(
|
.listen(
|
||||||
async |_| panic!("Only the other server expected notifs"),
|
async |_| panic!("Only the other server expected notifs"),
|
||||||
@@ -775,7 +791,7 @@ mod test {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let _ctx = req_context;
|
let _ctx = req_context;
|
||||||
},
|
}),
|
||||||
async move {
|
async move {
|
||||||
req_client.request(DummyRequest(0)).await.unwrap();
|
req_client.request(DummyRequest(0)).await.unwrap();
|
||||||
req_client.notify(TestNotif(0)).await.unwrap();
|
req_client.notify(TestNotif(0)).await.unwrap();
|
||||||
@@ -794,35 +810,41 @@ mod test {
|
|||||||
let (_, srv_ctx, srv) = io_comm(Box::pin(in2), Box::pin(out2));
|
let (_, srv_ctx, srv) = io_comm(Box::pin(in2), Box::pin(out2));
|
||||||
let (client, client_ctx, client_srv) = io_comm(Box::pin(in1), Box::pin(out1));
|
let (client, client_ctx, client_srv) = io_comm(Box::pin(in1), Box::pin(out1));
|
||||||
join!(
|
join!(
|
||||||
async {
|
with_label("server", async {
|
||||||
srv
|
srv
|
||||||
.listen(
|
.listen(
|
||||||
async |_| panic!("No notifs expected"),
|
async |_| panic!("No notifs expected"),
|
||||||
async |mut req| {
|
async |mut req| {
|
||||||
let _ = req.read_req::<DummyRequest>().await?;
|
let _ = req.read_req::<DummyRequest>().await?;
|
||||||
|
let _ = req.finish().await;
|
||||||
wait_in.clone().send(()).await.unwrap();
|
wait_in.clone().send(()).await.unwrap();
|
||||||
// TODO: verify cancellation
|
// This will never return, so if the cancellation does not work, it would block
|
||||||
futures::future::pending::<Never>().await;
|
// the loop
|
||||||
unreachable!("request should be cancelled before resume is triggered")
|
futures::future::pending().await
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap();
|
||||||
},
|
}),
|
||||||
async {
|
with_label("client", async {
|
||||||
client_srv
|
client_srv
|
||||||
.listen(
|
.listen(
|
||||||
async |_| panic!("Not expecting ingress notif"),
|
async |_| panic!("Not expecting ingress notif"),
|
||||||
async |_| panic!("Not expecting ingress req"),
|
async |_| panic!("Not expecting ingress req"),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap();
|
||||||
},
|
}),
|
||||||
|
with_stash(async {
|
||||||
with_stash(async {
|
with_stash(async {
|
||||||
select! {
|
select! {
|
||||||
_ = client.request(DummyRequest(5)).fuse() => panic!("This one should not run"),
|
_ = client.request(DummyRequest(5)).fuse() => {
|
||||||
|
panic!("This one should not run")
|
||||||
|
},
|
||||||
rep = wait_out.next() => rep.expect("something?"),
|
rep = wait_out.next() => rep.expect("something?"),
|
||||||
};
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
srv_ctx.exit().await.unwrap();
|
srv_ctx.exit().await.unwrap();
|
||||||
client_ctx.exit().await.unwrap();
|
client_ctx.exit().await.unwrap();
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user