//! Build [`Stream`](futures::Stream)s from async closures that act as coroutines,
//! emitting values through a [`StreamCtx`] handle.
use std::marker::PhantomData;

use futures::channel::mpsc;
use futures::channel::oneshot;
use futures::stream::{PollNext, select_with_strategy};
use futures::{FutureExt, SinkExt, Stream, StreamExt};
|
|
|
|
/// Handle that allows you to emit values on a stream. If you drop
|
|
/// this, the stream will end and you will not be polled again.
|
|
pub struct StreamCtx<'a, T>(mpsc::Sender<T>, PhantomData<&'a ()>);
|
|
impl<T> StreamCtx<'_, T> {
|
|
pub async fn emit(&mut self, value: T) {
|
|
(self.0.send(value).await)
|
|
.expect("Dropped a stream receiver without dropping the driving closure");
|
|
}
|
|
}
|
|
|
|
/// Strategy callback for [`select_with_strategy`] that always gives the left-hand
/// stream priority; the `()` state is unused.
fn left_strat(_: &mut ()) -> PollNext {
    PollNext::Left
}
|
|
|
|
/// Create a stream from an async function acting as a coroutine
|
|
pub fn stream<'a, T: 'a>(
|
|
f: impl for<'b> AsyncFnOnce(StreamCtx<'b, T>) + 'a,
|
|
) -> impl Stream<Item = T> + 'a {
|
|
let (send, recv) = mpsc::channel::<T>(1);
|
|
let fut = async { f(StreamCtx(send, PhantomData)).await };
|
|
// use options to ensure that the stream is driven to exhaustion
|
|
select_with_strategy(fut.into_stream().map(|()| None), recv.map(|t| Some(t)), left_strat)
|
|
.filter_map(async |opt| opt)
|
|
}
|
|
|
|
/// Create a stream of result from a fallible function.
|
|
pub fn try_stream<'a, T: 'a, E: 'a>(
|
|
f: impl for<'b> AsyncFnOnce(StreamCtx<'b, T>) -> Result<StreamCtx<'b, T>, E> + 'a,
|
|
) -> impl Stream<Item = Result<T, E>> + 'a {
|
|
let (send, recv) = mpsc::channel::<T>(1);
|
|
let fut = async { f(StreamCtx(send, PhantomData)).await };
|
|
select_with_strategy(
|
|
fut.into_stream().map(|res| if let Err(e) = res { Some(Err(e)) } else { None }),
|
|
recv.map(|t| Some(Ok(t))),
|
|
left_strat,
|
|
)
|
|
.filter_map(async |opt| opt)
|
|
}
|
|
|
|
// Tests for `stream` and `try_stream`, run on the `spin_on` executor from the
// `test_executors` crate.
#[cfg(test)]
mod test {
    use std::task::Poll;
    use std::{future, pin};

    use futures::channel::mpsc::channel;
    use futures::{Stream, StreamExt, TryStreamExt};
    use test_executors::spin_on;

    use crate::{stream, try_stream};

    /// Values emitted by a closure that never waits on anything else arrive in
    /// order, and the stream ends after the closure returns.
    #[test]
    fn sync() {
        spin_on(async {
            let v = stream(async |mut cx| {
                for i in 0..5 {
                    cx.emit(i).await
                }
            })
            .collect::<Vec<_>>()
            .await;
            assert_eq!(v, [0, 1, 2, 3, 4])
        })
    }

    #[test]
    /// The exact behaviour of the poll function under blocked use: pins the
    /// interleaving a hand-written poll loop observes when the driving closure
    /// emits 0 and 1, then suspends on something other than `emit` (a side
    /// channel) until the poll loop unblocks it, then emits 2, 3 and 4.
    fn with_delay() {
        spin_on(async {
            // Side channel through which the poll loop below unblocks the closure.
            let (mut send, mut recv) = channel(0);
            let mut s = pin::pin!(stream(async |mut cx| {
                for i in 0..2 {
                    cx.emit(i).await
                }
                // Suspend until the poll loop sends the unblock message.
                recv.next().await;
                for i in 2..5 {
                    cx.emit(i).await
                }
            }));
            let mut log = String::new();
            // The poll closure appends to the `log` buffer above; the awaited result
            // (the final transcript) shadows it.
            let log = future::poll_fn(|cx| {
                match s.as_mut().poll_next(cx) {
                    Poll::Ready(Some(r)) => log += &format!("Found {r}\n"),
                    Poll::Ready(None) => return Poll::Ready(format!("{log}Ended")),
                    // Nothing ready: try to unblock the closure. If `try_send` fails,
                    // stop with the error recorded so the assertion below reports it
                    // instead of the test hanging.
                    Poll::Pending => match send.try_send(()) {
                        Ok(()) => log += "Unblocked\n",
                        Err(err) => return Poll::Ready(format!("{log}Unblock err: {err}")),
                    },
                }
                // NOTE(review): after a `Ready` item this returns `Pending` without
                // arranging a wake-up itself; presumably `spin_on` re-polls regardless
                // — confirm.
                Poll::Pending
            })
            .await;
            const EXPECTED: &str = "\
Found 0\n\
Found 1\n\
Unblocked\n\
Found 2\n\
Found 3\n\
Found 4\n\
Ended";
            assert_eq!(log, EXPECTED)
        })
    }

    /// A fallible closure that returns its handle in `Ok` yields every emitted
    /// value as `Ok`.
    #[test]
    fn sync_try_all_ok() {
        spin_on(async {
            let v = try_stream::<_, ()>(async |mut cx| {
                for i in 0..5 {
                    cx.emit(i).await
                }
                Ok(cx)
            })
            .try_collect::<Vec<_>>()
            .await;
            assert_eq!(v, Ok(vec![0, 1, 2, 3, 4]))
        })
    }

    /// A fallible closure that returns `Err` surfaces that error to the consumer.
    ///
    /// NOTE(review): `try_collect` stops at the first `Err`, so this only checks
    /// that the error is surfaced; it does not pin where the error falls relative
    /// to the values emitted before it.
    #[test]
    fn sync_try_err() {
        spin_on(async {
            let v = try_stream::<_, ()>(async |mut cx| {
                for i in 0..5 {
                    cx.emit(i).await
                }
                Err(())
            })
            .try_collect::<Vec<_>>()
            .await;
            assert_eq!(v, Err(()))
        })
    }
}
|