use std::marker::PhantomData;

use futures::channel::mpsc;
use futures::stream::{PollNext, select_with_strategy};
use futures::{FutureExt, SinkExt, Stream, StreamExt};

/// Handle that allows the coroutine driving a stream created by [`stream`] or
/// [`try_stream`] to emit values on that stream.
///
/// Dropping the handle closes the underlying channel, so no further values can
/// be emitted. The stream itself only ends once every value already emitted
/// has been yielded *and* the driving closure has returned; a closure that
/// drops its handle but keeps running keeps the stream alive.
pub struct StreamCtx<'a, T>(mpsc::Sender<T>, PhantomData<&'a ()>);

impl<T> StreamCtx<'_, T> {
	/// Emit `value` on the stream, suspending until the channel accepts it.
	///
	/// # Panics
	///
	/// Panics if the receiving end of the channel has been dropped. The
	/// receiver and the driving closure are owned by the same stream, so this
	/// indicates a broken internal invariant rather than a user error.
	pub async fn emit(&mut self, value: T) {
		self.0
			.send(value)
			.await
			.expect("Dropped a stream receiver without dropping the driving closure");
	}
}

/// Selection strategy that always polls the driving future before the value
/// channel.
fn left_strat(_: &mut ()) -> PollNext { PollNext::Left }

/// Create a stream from an async function acting as a coroutine.
///
/// The closure receives a [`StreamCtx`] and emits items through
/// [`StreamCtx::emit`]. Items are yielded in emission order, and the stream
/// ends once the closure has returned and every emitted item has been yielded.
pub fn stream<'a, T: 'a>(
	f: impl for<'b> AsyncFnOnce(StreamCtx<'b, T>) + 'a,
) -> impl Stream<Item = T> + 'a {
	let (send, recv) = mpsc::channel::<T>(1);
	let fut = async { f(StreamCtx(send, PhantomData)).await };
	// Both sides are mapped into `Option` so the driving future can share one
	// stream with the values. `select_with_strategy` only finishes once *both*
	// sides are exhausted, so the closure keeps being driven (for as long as
	// the stream is polled) even after it has dropped its handle.
	select_with_strategy(fut.into_stream().map(|()| None), recv.map(Some), left_strat)
		.filter_map(async |opt| opt)
}

/// Create a stream of results from a fallible async function acting as a
/// coroutine.
pub fn try_stream<'a, T: 'a, E: 'a>( f: impl for<'b> AsyncFnOnce(StreamCtx<'b, T>) -> Result, E> + 'a, ) -> impl Stream> + 'a { let (send, recv) = mpsc::channel::(1); let fut = async { f(StreamCtx(send, PhantomData)).await }; select_with_strategy( fut.into_stream().map(|res| if let Err(e) = res { Some(Err(e)) } else { None }), recv.map(|t| Some(Ok(t))), left_strat, ) .filter_map(async |opt| opt) } #[cfg(test)] mod test { use std::task::Poll; use std::{future, pin}; use futures::channel::mpsc::channel; use futures::{Stream, StreamExt, TryStreamExt}; use test_executors::spin_on; use crate::{stream, try_stream}; #[test] fn sync() { spin_on(async { let v = stream(async |mut cx| { for i in 0..5 { cx.emit(i).await } }) .collect::>() .await; assert_eq!(v, [0, 1, 2, 3, 4]) }) } #[test] /// The exact behaviour of the poll function under blocked use fn with_delay() { spin_on(async { let (mut send, mut recv) = channel(0); let mut s = pin::pin!(stream(async |mut cx| { for i in 0..2 { cx.emit(i).await } recv.next().await; for i in 2..5 { cx.emit(i).await } })); let mut log = String::new(); let log = future::poll_fn(|cx| { match s.as_mut().poll_next(cx) { Poll::Ready(Some(r)) => log += &format!("Found {r}\n"), Poll::Ready(None) => return Poll::Ready(format!("{log}Ended")), Poll::Pending => match send.try_send(()) { Ok(()) => log += "Unblocked\n", Err(err) => return Poll::Ready(format!("{log}Unblock err: {err}")), }, } Poll::Pending }) .await; const EXPECTED: &str = "\ Found 0\n\ Found 1\n\ Unblocked\n\ Found 2\n\ Found 3\n\ Found 4\n\ Ended"; assert_eq!(log, EXPECTED) }) } #[test] fn sync_try_all_ok() { spin_on(async { let v = try_stream::<_, ()>(async |mut cx| { for i in 0..5 { cx.emit(i).await } Ok(cx) }) .try_collect::>() .await; assert_eq!(v, Ok(vec![0, 1, 2, 3, 4])) }) } #[test] fn sync_try_err() { spin_on(async { let v = try_stream::<_, ()>(async |mut cx| { for i in 0..5 { cx.emit(i).await } Err(()) }) .try_collect::>() .await; assert_eq!(v, Err(())) }) } }