use std::cell::Cell; use std::future::poll_fn; use std::marker::PhantomData; use std::pin::Pin; use std::ptr; use std::task::{Context, Poll}; use futures::future::LocalBoxFuture; use futures::{FutureExt, Stream}; type YieldSlot<'a, T> = &'a Cell>; /// Handle that allows you to emit values on a stream. If you drop /// this, the stream will end and you will not be polled again. pub struct StreamCtx<'a, T>(&'a Cell>, PhantomData<&'a ()>); impl StreamCtx<'_, T> { pub fn emit(&mut self, value: T) -> impl Future { assert!(self.0.replace(Some(value)).is_none(), "Leftover value in stream"); let mut state = Poll::Pending; poll_fn(move |_| std::mem::replace(&mut state, Poll::Ready(()))) } } enum FnOrFut<'a, T, O> { Fn(Option) -> LocalBoxFuture<'a, O> + 'a>>), Fut(LocalBoxFuture<'a, O>), } struct AsyncFnStream<'a, T> { driver: FnOrFut<'a, T, ()>, output: Cell>, } impl<'a, T> Stream for AsyncFnStream<'a, T> { type Item = T; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { unsafe { let self_mut = self.get_unchecked_mut(); let fut = match &mut self_mut.driver { FnOrFut::Fut(fut) => fut, FnOrFut::Fn(f) => { // safety: the cell is held inline in self, which is pinned. let cell = ptr::from_ref(&self_mut.output).as_ref().unwrap(); let fut = f.take().unwrap()(cell); self_mut.driver = FnOrFut::Fut(fut); return Pin::new_unchecked(self_mut).poll_next(cx); }, }; match fut.as_mut().poll(cx) { Poll::Ready(()) => Poll::Ready(None), Poll::Pending => match self_mut.output.replace(None) { None => Poll::Pending, Some(t) => Poll::Ready(Some(t)), }, } } } } struct AsyncFnTryStream<'a, T, E> { driver: FnOrFut<'a, T, Result, E>>, output: Cell>, } impl<'a, T, E> Stream for AsyncFnTryStream<'a, T, E> { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { unsafe { let self_mut = self.get_unchecked_mut(); let fut = match &mut self_mut.driver { FnOrFut::Fut(fut) => fut, FnOrFut::Fn(f) => { // safety: the cell is held inline in self, which is pinned. let cell = ptr::from_ref(&self_mut.output).as_ref().unwrap(); let fut = f.take().unwrap()(cell); self_mut.driver = FnOrFut::Fut(fut); return Pin::new_unchecked(self_mut).poll_next(cx); }, }; match fut.as_mut().poll(cx) { Poll::Ready(Ok(_)) => Poll::Ready(None), Poll::Ready(Err(ex)) => Poll::Ready(Some(Err(ex))), Poll::Pending => match self_mut.output.replace(None) { None => Poll::Pending, Some(t) => Poll::Ready(Some(Ok(t))), }, } } } } /// Create a stream from an async function acting as a coroutine pub fn stream<'a, T: 'a>( f: impl for<'b> AsyncFnOnce(StreamCtx<'b, T>) + 'a, ) -> impl Stream + 'a { AsyncFnStream { output: Cell::new(None), driver: FnOrFut::Fn(Some(Box::new(|t| { async { f(StreamCtx(t, PhantomData)).await }.boxed_local() }))), } } /// Create a stream of result from a fallible function. pub fn try_stream<'a, T: 'a, E: 'a>( f: impl for<'b> AsyncFnOnce(StreamCtx<'b, T>) -> Result, E> + 'a, ) -> impl Stream> + 'a { AsyncFnTryStream { output: Cell::new(None), driver: FnOrFut::Fn(Some(Box::new(|t| { async { f(StreamCtx(t, PhantomData)).await }.boxed_local() }))), } } #[cfg(test)] mod test { use std::task::Poll; use std::{future, pin}; use futures::channel::mpsc::channel; use futures::{Stream, StreamExt, TryStreamExt}; use test_executors::spin_on; use crate::{stream, try_stream}; #[test] fn sync() { spin_on(async { let v = stream(async |mut cx| { for i in 0..5 { cx.emit(i).await } }) .collect::>() .await; assert_eq!(v, [0, 1, 2, 3, 4]) }) } #[test] /// The exact behaviour of the poll function under blocked use fn with_delay() { spin_on(async { let (mut send, mut recv) = channel(0); let mut s = pin::pin!(stream(async |mut cx| { for i in 0..2 { cx.emit(i).await } recv.next().await; for i in 2..5 { cx.emit(i).await } })); let mut log = String::new(); let log = future::poll_fn(|cx| { match s.as_mut().poll_next(cx) { Poll::Ready(Some(r)) => log += &format!("Found {r}\n"), Poll::Ready(None) => return Poll::Ready(format!("{log}Ended")), Poll::Pending => match send.try_send(()) { Ok(()) => log += "Unblocked\n", Err(err) => return Poll::Ready(format!("{log}Unblock err: {err}")), }, } Poll::Pending }) .await; const EXPECTED: &str = "\ Found 0\n\ Found 1\n\ Unblocked\n\ Found 2\n\ Found 3\n\ Found 4\n\ Ended"; assert_eq!(log, EXPECTED) }) } #[test] fn sync_try_all_ok() { spin_on(async { let v = try_stream::<_, ()>(async |mut cx| { for i in 0..5 { cx.emit(i).await } Ok(cx) }) .try_collect::>() .await; assert_eq!(v, Ok(vec![0, 1, 2, 3, 4])) }) } #[test] fn sync_try_err() { spin_on(async { let v = try_stream::<_, ()>(async |mut cx| { for i in 0..5 { cx.emit(i).await } Err(()) }) .try_collect::>() .await; assert_eq!(v, Err(())) }) } }