Files
orchid/async-fn-stream/src/lib.rs

138 lines
3.4 KiB
Rust

use std::marker::PhantomData;
use futures::channel::{mpsc, oneshot};
use futures::stream::{PollNext, select_with_strategy};
use futures::{FutureExt, SinkExt, Stream, StreamExt};
/// Handle that lets the driving closure of [`stream`] or [`try_stream`] emit values.
///
/// Values travel over a bounded channel, so [`StreamCtx::emit`] applies backpressure: it
/// may wait until the consumer has taken earlier values. Dropping this handle means no
/// further values can be emitted. It does NOT stop the closure, though: the closure keeps
/// being polled until it completes. The stream only ends once the closure has completed
/// and every value it emitted has been yielded.
pub struct StreamCtx<'a, T>(mpsc::Sender<T>, PhantomData<&'a ()>);

impl<T> StreamCtx<'_, T> {
    /// Send `value` to the consumer of the stream, waiting for channel capacity if needed.
    ///
    /// # Panics
    ///
    /// Panics if the receiving half of the channel has been dropped while this handle is
    /// still being used.
    pub async fn emit(&mut self, value: T) {
        (self.0.send(value).await)
            .expect("Dropped a stream receiver without dropping the driving closure");
    }
}
/// Polling strategy for `select_with_strategy` that always tries the left stream first.
///
/// The `()` state is ignored; the decision is the same on every poll.
fn left_strat(_state: &mut ()) -> PollNext {
    PollNext::Left
}
/// Create a stream from an async function acting as a coroutine.
///
/// `f` receives a [`StreamCtx`]. Every value passed to [`StreamCtx::emit`] is yielded by
/// the returned stream, in order. The stream ends once `f` has completed and every value
/// it emitted has been yielded. Nothing runs until the stream is first polled.
pub fn stream<'a, T: 'a>(
    f: impl for<'b> AsyncFnOnce(StreamCtx<'b, T>) + 'a,
) -> impl Stream<Item = T> + 'a {
    let (send, recv) = mpsc::channel::<T>(1);
    // Wrapping the call in an `async` block defers even the call of `f` until first poll.
    let fut = async { f(StreamCtx(send, PhantomData)).await };
    // Drive the coroutine and the receiver side by side. The coroutine's completion is
    // mapped to `None` and filtered out: that side exists only so `f` keeps being polled
    // while values are read. `select_with_strategy` ends only after BOTH sides have ended,
    // so values still buffered when `f` returns are delivered as well. `left_strat` polls
    // `f` before the receiver on every poll.
    select_with_strategy(fut.into_stream().map(|()| None), recv.map(Some), left_strat)
        .filter_map(async |opt| opt)
}
/// Create a stream of result from a fallible function.
pub fn try_stream<'a, T: 'a, E: 'a>(
f: impl for<'b> AsyncFnOnce(StreamCtx<'b, T>) -> Result<StreamCtx<'b, T>, E> + 'a,
) -> impl Stream<Item = Result<T, E>> + 'a {
let (send, recv) = mpsc::channel::<T>(1);
let fut = async { f(StreamCtx(send, PhantomData)).await };
select_with_strategy(
fut.into_stream().map(|res| if let Err(e) = res { Some(Err(e)) } else { None }),
recv.map(|t| Some(Ok(t))),
left_strat,
)
.filter_map(async |opt| opt)
}
#[cfg(test)]
mod test {
    // Tests for `stream` and `try_stream`; each runs its async body via `spin_on` from
    // the `test_executors` crate.
    use std::task::Poll;
    use std::{future, pin};
    use futures::channel::mpsc::channel;
    use futures::{Stream, StreamExt, TryStreamExt};
    use test_executors::spin_on;
    use crate::{stream, try_stream};
    /// Values emitted by a closure that awaits nothing but `emit` come out in order.
    #[test]
    fn sync() {
        spin_on(async {
            let v = stream(async |mut cx| {
                for i in 0..5 {
                    cx.emit(i).await
                }
            })
            .collect::<Vec<_>>()
            .await;
            assert_eq!(v, [0, 1, 2, 3, 4])
        })
    }
    #[test]
    /// The exact behaviour of the poll function under blocked use
    ///
    /// The closure emits 0 and 1, blocks on `recv` until the test pushes a `()` into the
    /// `channel(0)`, and then emits 2, 3 and 4. The test polls the stream by hand and logs
    /// each outcome: every item, every `Pending` (answered by unblocking the closure), and
    /// the end of the stream. The expected log pins that `Pending` is observed exactly
    /// once, after 0 and 1 have been yielded.
    fn with_delay() {
        spin_on(async {
            let (mut send, mut recv) = channel(0);
            let mut s = pin::pin!(stream(async |mut cx| {
                for i in 0..2 {
                    cx.emit(i).await
                }
                // Blocks until the poll loop below sends `()`.
                recv.next().await;
                for i in 2..5 {
                    cx.emit(i).await
                }
            }));
            let mut log = String::new();
            // Shadows `log`: the closure appends to the outer `String`, and the final text
            // it returns becomes the new `log`.
            let log = future::poll_fn(|cx| {
                match s.as_mut().poll_next(cx) {
                    Poll::Ready(Some(r)) => log += &format!("Found {r}\n"),
                    Poll::Ready(None) => return Poll::Ready(format!("{log}Ended")),
                    // The stream is stuck: unblock the closure. If the `()` cannot be
                    // delivered, stop and put the error in the log so the assertion fails
                    // with a readable message.
                    Poll::Pending => match send.try_send(()) {
                        Ok(()) => log += "Unblocked\n",
                        Err(err) => return Poll::Ready(format!("{log}Unblock err: {err}")),
                    },
                }
                // NOTE(review): returning `Pending` without arranging a wake-up assumes the
                // executor re-polls on its own (presumably `spin_on` spins) — confirm.
                Poll::Pending
            })
            .await;
            const EXPECTED: &str = "\
                Found 0\n\
                Found 1\n\
                Unblocked\n\
                Found 2\n\
                Found 3\n\
                Found 4\n\
                Ended";
            assert_eq!(log, EXPECTED)
        })
    }
    /// A closure that returns `Ok(cx)` delivers every emitted value to `try_collect`.
    #[test]
    fn sync_try_all_ok() {
        spin_on(async {
            let v = try_stream::<_, ()>(async |mut cx| {
                for i in 0..5 {
                    cx.emit(i).await
                }
                Ok(cx)
            })
            .try_collect::<Vec<_>>()
            .await;
            assert_eq!(v, Ok(vec![0, 1, 2, 3, 4]))
        })
    }
    /// A closure that returns `Err` after emitting makes `try_collect` fail with that error.
    #[test]
    fn sync_try_err() {
        spin_on(async {
            let v = try_stream::<_, ()>(async |mut cx| {
                for i in 0..5 {
                    cx.emit(i).await
                }
                Err(())
            })
            .try_collect::<Vec<_>>()
            .await;
            assert_eq!(v, Err(()))
        })
    }
}