Pending a correct test of request cancellation

This commit is contained in:
2026-04-22 15:58:34 +00:00
parent 60c96964d9
commit 7f8c247d97
31 changed files with 1059 additions and 499 deletions

View File

@@ -0,0 +1,87 @@
use std::pin::Pin;
use std::task::{Context, Poll};
/// Future returned by [cancel_cleanup]
pub struct CancelCleanup<Fut: Future + Unpin, Fun: FnOnce(Fut)> {
/// Set to None when Ready
fut: Option<Fut>,
/// Only set to None in Drop
on_drop: Option<Fun>,
}
impl<Fut: Future + Unpin, Fun: FnOnce(Fut)> Future for CancelCleanup<Fut, Fun> {
type Output = Fut::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Self { fut, .. } = unsafe { self.get_unchecked_mut() };
if let Some(future) = fut {
let future = unsafe { Pin::new_unchecked(future) };
let poll = future.poll(cx);
if poll.is_ready() {
*fut = None;
}
poll
} else {
Poll::Pending
}
}
}
impl<Fut: Future + Unpin, Fun: FnOnce(Fut)> Drop for CancelCleanup<Fut, Fun> {
fn drop(&mut self) {
if let Some(fut) = self.fut.take() {
(self.on_drop.take().unwrap())(fut)
}
}
}
/// Handle a Future's Drop. The callback is only called if the future has not
/// yet returned and would be cancelled, and it receives the future as an
/// argument
pub fn cancel_cleanup<Fut: Future + Unpin, Fun: FnOnce(Fut)>(
fut: Fut,
on_drop: Fun,
) -> CancelCleanup<Fut, Fun> {
CancelCleanup { fut: Some(fut), on_drop: Some(on_drop) }
}
#[cfg(test)]
mod test {
use std::pin::pin;
use futures::channel::mpsc;
use futures::future::join;
use futures::{SinkExt, StreamExt};
use super::*;
use crate::debug::spin_on;
#[test]
fn called_on_drop() {
let mut called = false;
cancel_cleanup(pin!(async {}), |_| called = true);
assert!(called, "cleanup was called when the future was dropped");
}
#[test]
fn not_called_if_finished() {
spin_on(false, async {
let (mut req_in, mut req_out) = mpsc::channel(0);
let (mut rep_in, mut rep_out) = mpsc::channel(0);
join(
async {
req_out.next().await.unwrap();
rep_in.send(()).await.unwrap();
},
async {
cancel_cleanup(
pin!(async {
req_in.send(()).await.unwrap();
rep_out.next().await.unwrap();
}),
|_| panic!("Callback called on drop even though the future was finished"),
)
.await
},
)
.await
});
}
}

View File

@@ -143,9 +143,17 @@ pub fn eprint_stream_events<'a, S: Stream + 'a>(
)
}
struct SpinWaker(AtomicBool);
struct SpinWaker {
repeat: AtomicBool,
loud: bool,
}
impl Wake for SpinWaker {
fn wake(self: Arc<Self>) { self.0.store(true, Ordering::Relaxed); }
fn wake(self: Arc<Self>) {
self.repeat.store(true, Ordering::SeqCst);
if self.loud {
eprintln!("Triggered repeat for spin_on")
}
}
}
/// A dumb executor that keeps synchronously re-running the future as long as it
@@ -155,15 +163,15 @@ impl Wake for SpinWaker {
/// # Panics
///
/// If the future doesn't wake itself and doesn't settle.
pub fn spin_on<Fut: Future>(f: Fut) -> Fut::Output {
let repeat = Arc::new(SpinWaker(AtomicBool::new(false)));
pub fn spin_on<Fut: Future>(loud: bool, f: Fut) -> Fut::Output {
let spin_waker = Arc::new(SpinWaker { repeat: AtomicBool::new(false), loud });
let mut f = pin!(f);
let waker = repeat.clone().into();
let waker = spin_waker.clone().into();
let mut cx = Context::from_waker(&waker);
loop {
match f.as_mut().poll(&mut cx) {
Poll::Ready(t) => break t,
Poll::Pending if repeat.0.swap(false, Ordering::Relaxed) => (),
Poll::Pending if spin_waker.repeat.swap(false, Ordering::SeqCst) => (),
Poll::Pending => panic!("The future did not exit and did not call its waker."),
}
}

View File

@@ -1,5 +1,7 @@
pub mod debug;
mod cancel_cleanup;
pub use cancel_cleanup::*;
mod localset;
pub use localset::*;
mod task_future;
pub use task_future::*;
pub use task_future::*;

View File

@@ -1,21 +1,35 @@
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::Poll;
use futures::StreamExt;
use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded};
use futures::channel::mpsc::{SendError, UnboundedReceiver, UnboundedSender, unbounded};
use futures::future::LocalBoxFuture;
use futures::stream::FuturesUnordered;
use futures::{SinkExt, StreamExt};
pub struct LocalSet<'a, E> {
receiver: UnboundedReceiver<LocalBoxFuture<'a, Result<(), E>>>,
pending: VecDeque<LocalBoxFuture<'a, Result<(), E>>>,
pub struct LocalSetController<'a, E> {
sender: UnboundedSender<LocalBoxFuture<'a, Result<(), E>>>,
}
impl<'a, E> LocalSet<'a, E> {
pub fn new() -> (UnboundedSender<LocalBoxFuture<'a, Result<(), E>>>, Self) {
let (sender, receiver) = unbounded();
(sender, Self { receiver, pending: VecDeque::new() })
impl<'a, E> LocalSetController<'a, E> {
pub async fn spawn<F: Future<Output = Result<(), E>> + 'a>(
&mut self,
fut: F,
) -> Result<(), SendError> {
self.sender.send(Box::pin(fut)).await
}
}
pub fn local_set<'a, E: 'a>()
-> (LocalSetController<'a, E>, impl Future<Output = Result<(), E>> + 'a) {
let (sender, receiver) = unbounded();
let controller = LocalSetController { sender };
let set = LocalSet { receiver, pending: FuturesUnordered::new() };
(controller, set)
}
struct LocalSet<'a, E> {
receiver: UnboundedReceiver<LocalBoxFuture<'a, Result<(), E>>>,
pending: FuturesUnordered<LocalBoxFuture<'a, Result<(), E>>>,
}
impl<E> Future for LocalSet<'_, E> {
type Output = Result<(), E>;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
@@ -23,7 +37,7 @@ impl<E> Future for LocalSet<'_, E> {
let mut any_pending = false;
loop {
match this.receiver.poll_next_unpin(cx) {
Poll::Ready(Some(fut)) => this.pending.push_back(fut),
Poll::Ready(Some(fut)) => this.pending.push(fut),
Poll::Ready(None) => break,
Poll::Pending => {
any_pending = true;
@@ -31,15 +45,14 @@ impl<E> Future for LocalSet<'_, E> {
},
}
}
let count = this.pending.len();
for _ in 0..count {
let mut req = this.pending.pop_front().unwrap();
match req.as_mut().poll(cx) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
loop {
match this.pending.poll_next_unpin(cx) {
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
Poll::Ready(Some(Ok(()))) => continue,
Poll::Ready(None) => break,
Poll::Pending => {
any_pending = true;
this.pending.push_back(req)
break;
},
}
}

View File

@@ -1,10 +1,11 @@
use std::any::Any;
use std::cell::RefCell;
use std::marker::PhantomData;
use std::pin::Pin;
use std::pin::{Pin, pin};
use std::rc::Rc;
use std::task::{Context, Poll, Waker};
use futures::FutureExt;
use futures::channel::oneshot::{self, Canceled};
use futures::future::{FusedFuture, LocalBoxFuture};
struct State {
@@ -43,50 +44,56 @@ impl Future for Pollable {
}
}
pub struct JoinError;
/// An object that can be used to inspect the state of the task
pub struct Handle<T: 'static>(Rc<RefCell<State>>, PhantomData<T>);
pub struct Handle<T: 'static> {
send_abort: RefCell<Option<oneshot::Sender<()>>>,
ready: Rc<RefCell<bool>>,
recv_output: RefCell<oneshot::Receiver<T>>,
}
impl<T: 'static> Handle<T> {
/// Immediately stop working on this task, and return the result if it has
/// already finished
pub fn abort(&self) -> Option<T> {
let mut g = self.0.borrow_mut();
g.work.take();
match g.result.take() {
Some(val) => Some(*val.downcast().expect("Mismatch between type of future and handle")),
None => {
g.waker.wake_by_ref();
None
},
if let Some(abort) = self.send_abort.take() {
let _ = abort.send(());
}
self.recv_output.borrow_mut().try_recv().ok().flatten()
}
/// Determine if there's any more work to do on this task
pub fn is_finished(&self) -> bool {
let g = self.0.borrow();
g.result.is_some() || g.work.is_none()
}
pub fn is_finished(&self) -> bool { *self.ready.borrow() }
/// "finish" the freestanding task, and return the future instead
pub async fn join(self) -> T {
let work = {
let mut g = self.0.borrow_mut();
if let Some(val) = g.result.take() {
return *val.downcast().expect("Mistmatch between type of future and handle");
}
g.waker.wake_by_ref();
g.work.take().expect("Attempted to join task that was already aborted")
};
*work.await.downcast().expect("Mismatch between type of future and handle")
pub async fn join(self) -> Result<T, JoinError> {
self.recv_output.into_inner().await.map_err(|Canceled| JoinError)
}
}
/// Split a future into an object that can be polled and one that returns
/// information on its progress and its result. The first one can be passed to
/// an executor or localset, the second can be used to manage it
pub fn to_task<F: Future<Output: 'static> + 'static>(f: F) -> (Pollable, Handle<F::Output>) {
let dyn_future = Box::pin(async { Box::new(f.await) as Box<dyn Any> });
let state = Rc::new(RefCell::new(State {
result: None,
work: Some(dyn_future),
waker: Waker::noop().clone(),
}));
(Pollable(state.clone()), Handle(state, PhantomData))
pub fn to_task<'a, F: Future<Output: 'a> + 'a>(
f: F,
) -> (impl Future<Output = ()> + 'a, Handle<F::Output>) {
let (send_abort, mut on_abort) = oneshot::channel();
let (send_output, on_output) = oneshot::channel();
let ready = Rc::new(RefCell::new(false));
let ready2 = ready.clone();
let fut = async move {
let mut fut = pin!(f.fuse());
let output = futures::select_biased! {
res = on_abort => match res {
Ok(()) => return,
Err(_) => fut.await,
},
output = fut => output,
};
ready2.replace(true);
let _: Result<_, _> = send_output.send(output);
};
(fut, Handle {
ready,
recv_output: RefCell::new(on_output),
send_abort: RefCell::new(Some(send_abort)),
})
}