Pending a correct test of request cancellation
This commit is contained in:
87
orchid-async-utils/src/cancel_cleanup.rs
Normal file
87
orchid-async-utils/src/cancel_cleanup.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// Future returned by [cancel_cleanup]
|
||||
pub struct CancelCleanup<Fut: Future + Unpin, Fun: FnOnce(Fut)> {
|
||||
/// Set to None when Ready
|
||||
fut: Option<Fut>,
|
||||
/// Only set to None in Drop
|
||||
on_drop: Option<Fun>,
|
||||
}
|
||||
impl<Fut: Future + Unpin, Fun: FnOnce(Fut)> Future for CancelCleanup<Fut, Fun> {
|
||||
type Output = Fut::Output;
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let Self { fut, .. } = unsafe { self.get_unchecked_mut() };
|
||||
if let Some(future) = fut {
|
||||
let future = unsafe { Pin::new_unchecked(future) };
|
||||
let poll = future.poll(cx);
|
||||
if poll.is_ready() {
|
||||
*fut = None;
|
||||
}
|
||||
poll
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<Fut: Future + Unpin, Fun: FnOnce(Fut)> Drop for CancelCleanup<Fut, Fun> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(fut) = self.fut.take() {
|
||||
(self.on_drop.take().unwrap())(fut)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a Future's Drop. The callback is only called if the future has not
|
||||
/// yet returned and would be cancelled, and it receives the future as an
|
||||
/// argument
|
||||
pub fn cancel_cleanup<Fut: Future + Unpin, Fun: FnOnce(Fut)>(
|
||||
fut: Fut,
|
||||
on_drop: Fun,
|
||||
) -> CancelCleanup<Fut, Fun> {
|
||||
CancelCleanup { fut: Some(fut), on_drop: Some(on_drop) }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::pin::pin;
|
||||
|
||||
use futures::channel::mpsc;
|
||||
use futures::future::join;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
|
||||
use super::*;
|
||||
use crate::debug::spin_on;
|
||||
|
||||
#[test]
|
||||
fn called_on_drop() {
|
||||
let mut called = false;
|
||||
cancel_cleanup(pin!(async {}), |_| called = true);
|
||||
assert!(called, "cleanup was called when the future was dropped");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_called_if_finished() {
|
||||
spin_on(false, async {
|
||||
let (mut req_in, mut req_out) = mpsc::channel(0);
|
||||
let (mut rep_in, mut rep_out) = mpsc::channel(0);
|
||||
join(
|
||||
async {
|
||||
req_out.next().await.unwrap();
|
||||
rep_in.send(()).await.unwrap();
|
||||
},
|
||||
async {
|
||||
cancel_cleanup(
|
||||
pin!(async {
|
||||
req_in.send(()).await.unwrap();
|
||||
rep_out.next().await.unwrap();
|
||||
}),
|
||||
|_| panic!("Callback called on drop even though the future was finished"),
|
||||
)
|
||||
.await
|
||||
},
|
||||
)
|
||||
.await
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -143,9 +143,17 @@ pub fn eprint_stream_events<'a, S: Stream + 'a>(
|
||||
)
|
||||
}
|
||||
|
||||
struct SpinWaker(AtomicBool);
|
||||
struct SpinWaker {
|
||||
repeat: AtomicBool,
|
||||
loud: bool,
|
||||
}
|
||||
impl Wake for SpinWaker {
|
||||
fn wake(self: Arc<Self>) { self.0.store(true, Ordering::Relaxed); }
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.repeat.store(true, Ordering::SeqCst);
|
||||
if self.loud {
|
||||
eprintln!("Triggered repeat for spin_on")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A dumb executor that keeps synchronously re-running the future as long as it
|
||||
@@ -155,15 +163,15 @@ impl Wake for SpinWaker {
|
||||
/// # Panics
|
||||
///
|
||||
/// If the future doesn't wake itself and doesn't settle.
|
||||
pub fn spin_on<Fut: Future>(f: Fut) -> Fut::Output {
|
||||
let repeat = Arc::new(SpinWaker(AtomicBool::new(false)));
|
||||
pub fn spin_on<Fut: Future>(loud: bool, f: Fut) -> Fut::Output {
|
||||
let spin_waker = Arc::new(SpinWaker { repeat: AtomicBool::new(false), loud });
|
||||
let mut f = pin!(f);
|
||||
let waker = repeat.clone().into();
|
||||
let waker = spin_waker.clone().into();
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
loop {
|
||||
match f.as_mut().poll(&mut cx) {
|
||||
Poll::Ready(t) => break t,
|
||||
Poll::Pending if repeat.0.swap(false, Ordering::Relaxed) => (),
|
||||
Poll::Pending if spin_waker.repeat.swap(false, Ordering::SeqCst) => (),
|
||||
Poll::Pending => panic!("The future did not exit and did not call its waker."),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
pub mod debug;
|
||||
mod cancel_cleanup;
|
||||
pub use cancel_cleanup::*;
|
||||
mod localset;
|
||||
pub use localset::*;
|
||||
mod task_future;
|
||||
pub use task_future::*;
|
||||
pub use task_future::*;
|
||||
@@ -1,21 +1,35 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::pin::Pin;
|
||||
use std::task::Poll;
|
||||
|
||||
use futures::StreamExt;
|
||||
use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded};
|
||||
use futures::channel::mpsc::{SendError, UnboundedReceiver, UnboundedSender, unbounded};
|
||||
use futures::future::LocalBoxFuture;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
|
||||
pub struct LocalSet<'a, E> {
|
||||
receiver: UnboundedReceiver<LocalBoxFuture<'a, Result<(), E>>>,
|
||||
pending: VecDeque<LocalBoxFuture<'a, Result<(), E>>>,
|
||||
pub struct LocalSetController<'a, E> {
|
||||
sender: UnboundedSender<LocalBoxFuture<'a, Result<(), E>>>,
|
||||
}
|
||||
impl<'a, E> LocalSet<'a, E> {
|
||||
pub fn new() -> (UnboundedSender<LocalBoxFuture<'a, Result<(), E>>>, Self) {
|
||||
let (sender, receiver) = unbounded();
|
||||
(sender, Self { receiver, pending: VecDeque::new() })
|
||||
impl<'a, E> LocalSetController<'a, E> {
|
||||
pub async fn spawn<F: Future<Output = Result<(), E>> + 'a>(
|
||||
&mut self,
|
||||
fut: F,
|
||||
) -> Result<(), SendError> {
|
||||
self.sender.send(Box::pin(fut)).await
|
||||
}
|
||||
}
|
||||
|
||||
pub fn local_set<'a, E: 'a>()
|
||||
-> (LocalSetController<'a, E>, impl Future<Output = Result<(), E>> + 'a) {
|
||||
let (sender, receiver) = unbounded();
|
||||
let controller = LocalSetController { sender };
|
||||
let set = LocalSet { receiver, pending: FuturesUnordered::new() };
|
||||
(controller, set)
|
||||
}
|
||||
|
||||
struct LocalSet<'a, E> {
|
||||
receiver: UnboundedReceiver<LocalBoxFuture<'a, Result<(), E>>>,
|
||||
pending: FuturesUnordered<LocalBoxFuture<'a, Result<(), E>>>,
|
||||
}
|
||||
impl<E> Future for LocalSet<'_, E> {
|
||||
type Output = Result<(), E>;
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
|
||||
@@ -23,7 +37,7 @@ impl<E> Future for LocalSet<'_, E> {
|
||||
let mut any_pending = false;
|
||||
loop {
|
||||
match this.receiver.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(fut)) => this.pending.push_back(fut),
|
||||
Poll::Ready(Some(fut)) => this.pending.push(fut),
|
||||
Poll::Ready(None) => break,
|
||||
Poll::Pending => {
|
||||
any_pending = true;
|
||||
@@ -31,15 +45,14 @@ impl<E> Future for LocalSet<'_, E> {
|
||||
},
|
||||
}
|
||||
}
|
||||
let count = this.pending.len();
|
||||
for _ in 0..count {
|
||||
let mut req = this.pending.pop_front().unwrap();
|
||||
match req.as_mut().poll(cx) {
|
||||
Poll::Ready(Ok(())) => (),
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
loop {
|
||||
match this.pending.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
|
||||
Poll::Ready(Some(Ok(()))) => continue,
|
||||
Poll::Ready(None) => break,
|
||||
Poll::Pending => {
|
||||
any_pending = true;
|
||||
this.pending.push_back(req)
|
||||
break;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use std::any::Any;
|
||||
use std::cell::RefCell;
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
use std::pin::{Pin, pin};
|
||||
use std::rc::Rc;
|
||||
use std::task::{Context, Poll, Waker};
|
||||
|
||||
use futures::FutureExt;
|
||||
use futures::channel::oneshot::{self, Canceled};
|
||||
use futures::future::{FusedFuture, LocalBoxFuture};
|
||||
|
||||
struct State {
|
||||
@@ -43,50 +44,56 @@ impl Future for Pollable {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct JoinError;
|
||||
|
||||
/// An object that can be used to inspect the state of the task
|
||||
pub struct Handle<T: 'static>(Rc<RefCell<State>>, PhantomData<T>);
|
||||
pub struct Handle<T: 'static> {
|
||||
send_abort: RefCell<Option<oneshot::Sender<()>>>,
|
||||
ready: Rc<RefCell<bool>>,
|
||||
recv_output: RefCell<oneshot::Receiver<T>>,
|
||||
}
|
||||
impl<T: 'static> Handle<T> {
|
||||
/// Immediately stop working on this task, and return the result if it has
|
||||
/// already finished
|
||||
pub fn abort(&self) -> Option<T> {
|
||||
let mut g = self.0.borrow_mut();
|
||||
g.work.take();
|
||||
match g.result.take() {
|
||||
Some(val) => Some(*val.downcast().expect("Mismatch between type of future and handle")),
|
||||
None => {
|
||||
g.waker.wake_by_ref();
|
||||
None
|
||||
},
|
||||
if let Some(abort) = self.send_abort.take() {
|
||||
let _ = abort.send(());
|
||||
}
|
||||
self.recv_output.borrow_mut().try_recv().ok().flatten()
|
||||
}
|
||||
/// Determine if there's any more work to do on this task
|
||||
pub fn is_finished(&self) -> bool {
|
||||
let g = self.0.borrow();
|
||||
g.result.is_some() || g.work.is_none()
|
||||
}
|
||||
pub fn is_finished(&self) -> bool { *self.ready.borrow() }
|
||||
/// "finish" the freestanding task, and return the future instead
|
||||
pub async fn join(self) -> T {
|
||||
let work = {
|
||||
let mut g = self.0.borrow_mut();
|
||||
if let Some(val) = g.result.take() {
|
||||
return *val.downcast().expect("Mistmatch between type of future and handle");
|
||||
}
|
||||
g.waker.wake_by_ref();
|
||||
g.work.take().expect("Attempted to join task that was already aborted")
|
||||
};
|
||||
*work.await.downcast().expect("Mismatch between type of future and handle")
|
||||
pub async fn join(self) -> Result<T, JoinError> {
|
||||
self.recv_output.into_inner().await.map_err(|Canceled| JoinError)
|
||||
}
|
||||
}
|
||||
|
||||
/// Split a future into an object that can be polled and one that returns
|
||||
/// information on its progress and its result. The first one can be passed to
|
||||
/// an executor or localset, the second can be used to manage it
|
||||
pub fn to_task<F: Future<Output: 'static> + 'static>(f: F) -> (Pollable, Handle<F::Output>) {
|
||||
let dyn_future = Box::pin(async { Box::new(f.await) as Box<dyn Any> });
|
||||
let state = Rc::new(RefCell::new(State {
|
||||
result: None,
|
||||
work: Some(dyn_future),
|
||||
waker: Waker::noop().clone(),
|
||||
}));
|
||||
(Pollable(state.clone()), Handle(state, PhantomData))
|
||||
pub fn to_task<'a, F: Future<Output: 'a> + 'a>(
|
||||
f: F,
|
||||
) -> (impl Future<Output = ()> + 'a, Handle<F::Output>) {
|
||||
let (send_abort, mut on_abort) = oneshot::channel();
|
||||
let (send_output, on_output) = oneshot::channel();
|
||||
let ready = Rc::new(RefCell::new(false));
|
||||
let ready2 = ready.clone();
|
||||
let fut = async move {
|
||||
let mut fut = pin!(f.fuse());
|
||||
let output = futures::select_biased! {
|
||||
res = on_abort => match res {
|
||||
Ok(()) => return,
|
||||
Err(_) => fut.await,
|
||||
},
|
||||
output = fut => output,
|
||||
};
|
||||
ready2.replace(true);
|
||||
let _: Result<_, _> = send_output.send(output);
|
||||
};
|
||||
(fut, Handle {
|
||||
ready,
|
||||
recv_output: RefCell::new(on_output),
|
||||
send_abort: RefCell::new(Some(send_abort)),
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user