// Forked from Orchid/orchid (Rust, 208 lines, 5.3 KiB).
use std::cell::Cell;
|
|
use std::future::poll_fn;
|
|
use std::marker::PhantomData;
|
|
use std::pin::Pin;
|
|
use std::ptr;
|
|
use std::task::{Context, Poll};
|
|
|
|
use futures::future::LocalBoxFuture;
|
|
use futures::{FutureExt, Stream};
|
|
|
|
/// Shared reference to the single-item cell used to hand a value from the
/// generator future to the stream that drives it.
type YieldSlot<'a, T> = &'a Cell<Option<T>>;
|
|
|
|
/// Emission handle passed to the generator function of a stream.
///
/// The handle is intended to be kept for as long as the generator wants to
/// produce items; per the original contract, once it is dropped the stream
/// ends and the generator is not polled again.
///
/// NOTE(review): no `Drop` impl is visible for this type, so the end of the
/// stream presumably comes from the generator future finishing rather than
/// from the drop itself. Confirm against the driving stream.
pub struct StreamCtx<'a, T>(
    // Hand-over slot shared with the stream that polls the generator.
    &'a Cell<Option<T>>,
    // Marker tying the handle to the `'a` borrow.
    PhantomData<&'a ()>,
);

impl<T> StreamCtx<'_, T> {
    /// Hands `value` over to the stream and returns a future that suspends the
    /// generator exactly once.
    ///
    /// The value is stored in the shared slot as soon as this method is
    /// called, not when the returned future is first polled. The future
    /// answers `Pending` on its first poll and `Ready(())` on every later
    /// poll. It ignores the task context, so it never registers a waker and
    /// relies on whoever drives the generator to poll it again.
    ///
    /// # Panics
    ///
    /// Panics with "Leftover value in stream" if the slot still holds a value
    /// that has not been taken out yet.
    pub fn emit(&mut self, value: T) -> impl Future<Output = ()> {
        let previous = self.0.replace(Some(value));
        assert!(previous.is_none(), "Leftover value in stream");
        let mut suspended = false;
        poll_fn(move |_| {
            if suspended {
                Poll::Ready(())
            } else {
                suspended = true;
                Poll::Pending
            }
        })
    }
}
|
|
|
|
enum FnOrFut<'a, T, O> {
|
|
Fn(Option<Box<dyn FnOnce(YieldSlot<'a, T>) -> LocalBoxFuture<'a, O> + 'a>>),
|
|
Fut(LocalBoxFuture<'a, O>),
|
|
}
|
|
|
|
struct AsyncFnStream<'a, T> {
|
|
driver: FnOrFut<'a, T, ()>,
|
|
output: Cell<Option<T>>,
|
|
}
|
|
impl<'a, T> Stream for AsyncFnStream<'a, T> {
|
|
type Item = T;
|
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
unsafe {
|
|
let self_mut = self.get_unchecked_mut();
|
|
let fut = match &mut self_mut.driver {
|
|
FnOrFut::Fut(fut) => fut,
|
|
FnOrFut::Fn(f) => {
|
|
// safety: the cell is held inline in self, which is pinned.
|
|
let cell = ptr::from_ref(&self_mut.output).as_ref().unwrap();
|
|
let fut = f.take().unwrap()(cell);
|
|
self_mut.driver = FnOrFut::Fut(fut);
|
|
return Pin::new_unchecked(self_mut).poll_next(cx);
|
|
},
|
|
};
|
|
match fut.as_mut().poll(cx) {
|
|
Poll::Ready(()) => Poll::Ready(None),
|
|
Poll::Pending => match self_mut.output.replace(None) {
|
|
None => Poll::Pending,
|
|
Some(t) => Poll::Ready(Some(t)),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
struct AsyncFnTryStream<'a, T, E> {
|
|
driver: FnOrFut<'a, T, Result<StreamCtx<'a, T>, E>>,
|
|
output: Cell<Option<T>>,
|
|
}
|
|
impl<'a, T, E> Stream for AsyncFnTryStream<'a, T, E> {
|
|
type Item = Result<T, E>;
|
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
unsafe {
|
|
let self_mut = self.get_unchecked_mut();
|
|
let fut = match &mut self_mut.driver {
|
|
FnOrFut::Fut(fut) => fut,
|
|
FnOrFut::Fn(f) => {
|
|
// safety: the cell is held inline in self, which is pinned.
|
|
let cell = ptr::from_ref(&self_mut.output).as_ref().unwrap();
|
|
let fut = f.take().unwrap()(cell);
|
|
self_mut.driver = FnOrFut::Fut(fut);
|
|
return Pin::new_unchecked(self_mut).poll_next(cx);
|
|
},
|
|
};
|
|
match fut.as_mut().poll(cx) {
|
|
Poll::Ready(Ok(_)) => Poll::Ready(None),
|
|
Poll::Ready(Err(ex)) => Poll::Ready(Some(Err(ex))),
|
|
Poll::Pending => match self_mut.output.replace(None) {
|
|
None => Poll::Pending,
|
|
Some(t) => Poll::Ready(Some(Ok(t))),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Create a stream from an async function acting as a coroutine
|
|
pub fn stream<'a, T: 'a>(
|
|
f: impl for<'b> AsyncFnOnce(StreamCtx<'b, T>) + 'a,
|
|
) -> impl Stream<Item = T> + 'a {
|
|
AsyncFnStream {
|
|
output: Cell::new(None),
|
|
driver: FnOrFut::Fn(Some(Box::new(|t| {
|
|
async { f(StreamCtx(t, PhantomData)).await }.boxed_local()
|
|
}))),
|
|
}
|
|
}
|
|
|
|
/// Create a stream of result from a fallible function.
|
|
pub fn try_stream<'a, T: 'a, E: 'a>(
|
|
f: impl for<'b> AsyncFnOnce(StreamCtx<'b, T>) -> Result<StreamCtx<'b, T>, E> + 'a,
|
|
) -> impl Stream<Item = Result<T, E>> + 'a {
|
|
AsyncFnTryStream {
|
|
output: Cell::new(None),
|
|
driver: FnOrFut::Fn(Some(Box::new(|t| {
|
|
async { f(StreamCtx(t, PhantomData)).await }.boxed_local()
|
|
}))),
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use std::task::Poll;
    use std::{future, pin};

    use futures::channel::mpsc::channel;
    use futures::{Stream, StreamExt, TryStreamExt};
    use test_executors::spin_on;

    use crate::{stream, try_stream};

    /// A generator that only ever suspends on `emit` yields every item in order.
    #[test]
    fn sync() {
        spin_on(async {
            let source = stream(async |mut cx| {
                for i in 0..5 {
                    cx.emit(i).await;
                }
            });
            let collected: Vec<_> = source.collect().await;
            assert_eq!(collected, [0, 1, 2, 3, 4]);
        })
    }

    /// The exact behaviour of the poll function under blocked use
    #[test]
    fn with_delay() {
        spin_on(async {
            let (mut send, mut recv) = channel(0);
            let mut s = pin::pin!(stream(async |mut cx| {
                for i in 0..2 {
                    cx.emit(i).await;
                }
                // Blocks until the polling side sends the unblock message.
                recv.next().await;
                for i in 2..5 {
                    cx.emit(i).await;
                }
            }));
            let mut log = String::new();
            // Poll by hand and record every observable step; a pending stream
            // is answered by unblocking the generator through the channel.
            let transcript = future::poll_fn(|cx| {
                match s.as_mut().poll_next(cx) {
                    Poll::Ready(Some(r)) => log.push_str(&format!("Found {r}\n")),
                    Poll::Ready(None) => return Poll::Ready(format!("{log}Ended")),
                    Poll::Pending => match send.try_send(()) {
                        Ok(()) => log.push_str("Unblocked\n"),
                        Err(err) => return Poll::Ready(format!("{log}Unblock err: {err}")),
                    },
                }
                Poll::Pending
            })
            .await;
            const EXPECTED: &str = "\
                Found 0\n\
                Found 1\n\
                Unblocked\n\
                Found 2\n\
                Found 3\n\
                Found 4\n\
                Ended";
            assert_eq!(transcript, EXPECTED);
        })
    }

    /// A fallible generator that returns `Ok` collects into `Ok` of all items.
    #[test]
    fn sync_try_all_ok() {
        spin_on(async {
            let source = try_stream::<_, ()>(async |mut cx| {
                for i in 0..5 {
                    cx.emit(i).await;
                }
                Ok(cx)
            });
            let outcome = source.try_collect::<Vec<_>>().await;
            assert_eq!(outcome, Ok(vec![0, 1, 2, 3, 4]));
        })
    }

    /// A fallible generator that returns `Err` makes `try_collect` fail.
    #[test]
    fn sync_try_err() {
        spin_on(async {
            let source = try_stream::<_, ()>(async |mut cx| {
                for i in 0..5 {
                    cx.emit(i).await;
                }
                Err(())
            });
            let outcome = source.try_collect::<Vec<_>>().await;
            assert_eq!(outcome, Err(()));
        })
    }
}
|