diff --git a/.cargo/config.toml b/.cargo/config.toml index 063a563..5dcce32 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -5,7 +5,7 @@ orcxdb = "xtask orcxdb" [env] CARGO_WORKSPACE_DIR = { value = "", relative = true } -ORCHID_EXTENSIONS = "target/debug/orchid-std-dbg" +ORCHID_EXTENSIONS = "target/debug/orchid_std" ORCHID_DEFAULT_SYSTEMS = "orchid::std;orchid::macros" ORCHID_LOG_BUFFERS = "true" RUST_BACKTRACE = "1" diff --git a/Cargo.lock b/Cargo.lock index 91c655c..7d35fff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -891,7 +891,6 @@ dependencies = [ "rust-embed", "substack", "task-local", - "test_executors", "trait-set", "unsync-pipe", ] diff --git a/orchid-base/Cargo.toml b/orchid-base/Cargo.toml index f777d6b..6385954 100644 --- a/orchid-base/Cargo.toml +++ b/orchid-base/Cargo.toml @@ -30,4 +30,3 @@ task-local = "0.1.0" [dev-dependencies] futures = "0.3.31" -test_executors = "0.4.1" diff --git a/orchid-base/src/binary.rs b/orchid-base/src/binary.rs index d3e1a62..3474201 100644 --- a/orchid-base/src/binary.rs +++ b/orchid-base/src/binary.rs @@ -3,13 +3,13 @@ use std::pin::Pin; use std::rc::Rc; use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; -use orchid_api::binary::{FutureBin, FutureContextBin, OwnedWakerBin, UnitPoll}; +use crate::api; type WideBox = Box>; static OWNED_VTABLE: RawWakerVTable = RawWakerVTable::new( |data| { - let data = unsafe { Rc::::from_raw(data as *const _) }; + let data = unsafe { Rc::::from_raw(data as *const _) }; let val = RawWaker::new(Rc::into_raw(data.clone()) as *const (), &OWNED_VTABLE); // Clone must create a duplicate of the Rc, so it has to be un-leaked, cloned, // then leaked again. @@ -19,19 +19,19 @@ static OWNED_VTABLE: RawWakerVTable = RawWakerVTable::new( |data| { // Wake must awaken the task and then clean up the state, so the waker must be // un-leaked - let data = unsafe { Rc::::from_raw(data as *const _) }; + let data = unsafe { Rc::::from_raw(data as *const _) }; (data.wake)(data.data); mem::drop(data); }, |data| { // Wake-by-ref must awaken the task while preserving the future, so the Rc is // untouched - let data = unsafe { (data as *const OwnedWakerBin).as_ref() }.unwrap(); + let data = unsafe { (data as *const api::binary::OwnedWakerBin).as_ref() }.unwrap(); (data.wake_ref)(data.data); }, |data| { // Drop must clean up the state, so the waker must be un-leaked - let data = unsafe { Rc::::from_raw(data as *const _) }; + let data = unsafe { Rc::::from_raw(data as *const _) }; (data.drop)(data.data); mem::drop(data); }, @@ -39,12 +39,12 @@ static OWNED_VTABLE: RawWakerVTable = RawWakerVTable::new( struct BorrowedWakerData<'a> { go_around: &'a mut bool, - cx: FutureContextBin, + cx: api::binary::FutureContextBin, } static BORROWED_VTABLE: RawWakerVTable = RawWakerVTable::new( |data| { let data = unsafe { (data as *mut BorrowedWakerData).as_mut() }.unwrap(); - let owned_data = Rc::::new((data.cx.waker)(data.cx.data)); + let owned_data = Rc::::new((data.cx.waker)(data.cx.data)); RawWaker::new(Rc::into_raw(owned_data) as *const (), &OWNED_VTABLE) }, |data| *unsafe { (data as *mut BorrowedWakerData).as_mut() }.unwrap().go_around = true, @@ -54,13 +54,13 @@ static BORROWED_VTABLE: RawWakerVTable = RawWakerVTable::new( /// Convert a future to a binary-compatible format that can be sent across /// dynamic library boundaries -pub fn future_to_vt + 'static>(fut: Fut) -> FutureBin { +pub fn future_to_vt + 'static>(fut: Fut) -> api::binary::FutureBin { let wide_box = Box::new(fut) as WideBox; let data = Box::into_raw(Box::new(wide_box)); extern "C" fn drop(raw: *const ()) { mem::drop(unsafe { Box::::from_raw(raw as *mut _) }) } - extern "C" fn poll(raw: *const (), cx: FutureContextBin) -> UnitPoll { + extern "C" fn poll(raw: *const (), cx: api::binary::FutureContextBin) -> api::binary::UnitPoll { let mut this = unsafe { Pin::new_unchecked(&mut **(raw as *mut WideBox).as_mut().unwrap()) }; loop { let mut go_around = false; @@ -73,27 +73,27 @@ pub fn future_to_vt + 'static>(fut: Fut) -> FutureBin { let mut ctx = Context::from_waker(&borrowed_waker); let result = this.as_mut().poll(&mut ctx); if matches!(result, Poll::Ready(())) { - break UnitPoll::Ready; + break api::binary::UnitPoll::Ready; } if !go_around { - break UnitPoll::Pending; + break api::binary::UnitPoll::Pending; } } } - FutureBin { data: data as *const _, drop, poll } + api::binary::FutureBin { data: data as *const _, drop, poll } } struct VirtualFuture { - vt: FutureBin, + vt: api::binary::FutureBin, } impl Unpin for VirtualFuture {} impl Future for VirtualFuture { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - extern "C" fn waker(raw: *const ()) -> OwnedWakerBin { + extern "C" fn waker(raw: *const ()) -> api::binary::OwnedWakerBin { let waker = unsafe { (raw as *mut Context).as_mut() }.unwrap().waker().clone(); let data = Box::into_raw(Box::::new(waker)) as *const (); - return OwnedWakerBin { data, drop, wake, wake_ref }; + return api::binary::OwnedWakerBin { data, drop, wake, wake_ref }; extern "C" fn drop(raw: *const ()) { mem::drop(unsafe { Box::::from_raw(raw as *mut Waker) }) } @@ -104,11 +104,11 @@ impl Future for VirtualFuture { unsafe { (raw as *mut Waker).as_mut() }.unwrap().wake_by_ref(); } } - let cx = FutureContextBin { data: cx as *mut Context as *const (), waker }; + let cx = api::binary::FutureContextBin { data: cx as *mut Context as *const (), waker }; let result = (self.vt.poll)(self.vt.data, cx); match result { - UnitPoll::Pending => Poll::Pending, - UnitPoll::Ready => Poll::Ready(()), + api::binary::UnitPoll::Pending => Poll::Pending, + api::binary::UnitPoll::Ready => Poll::Ready(()), } } } @@ -118,4 +118,4 @@ impl Drop for VirtualFuture { /// Receive a future sent across dynamic library boundaries and convert it into /// an owned object -pub fn vt_to_future(vt: FutureBin) -> impl Future { VirtualFuture { vt } } +pub fn vt_to_future(vt: api::binary::FutureBin) -> impl Future { VirtualFuture { vt } } diff --git a/orchid-base/src/future_debug.rs b/orchid-base/src/future_debug.rs new file mode 100644 index 0000000..a5f27f9 --- /dev/null +++ b/orchid-base/src/future_debug.rs @@ -0,0 +1,157 @@ +use std::cell::RefCell; +use std::fmt::Display; +use std::pin::pin; +use std::rc::Rc; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::task::{Context, Poll, Wake, Waker}; + +use futures::Stream; +use itertools::Itertools; +use task_local::task_local; + +struct OnPollWaker(Waker, F); +impl Wake for OnPollWaker { + fn wake(self: Arc) { + (self.1)(); + self.0.wake_by_ref() + } +} + +/// Attach a callback to the [Future] protocol for testing and debugging. Note +/// that this function is safe and simple in order to facilitate debugging +/// without adding more points of failure, but it's not fast; it performs a heap +/// allocation on each poll of the returned future. +pub async fn on_wake( + f: F, + wake: impl Fn() + Clone + Send + Sync + 'static, +) -> F::Output { + let mut f = pin!(f); + futures::future::poll_fn(|cx| { + let waker = Arc::new(OnPollWaker(cx.waker().clone(), wake.clone())).into(); + f.as_mut().poll(&mut Context::from_waker(&waker)) + }) + .await +} + +/// Respond to [Future::poll] with a callback. For maximum flexibility and state +/// control, your callback receives the actual poll job as a callback function. +/// Failure to call this function will result in an immediate panic. +pub async fn wrap_poll( + f: Fut, + mut cb: impl FnMut(Box bool + '_>), +) -> Fut::Output { + let mut f = pin!(f); + futures::future::poll_fn(|cx| { + let poll = RefCell::new(None); + cb(Box::new(|| { + let poll1 = f.as_mut().poll(cx); + let ret = poll1.is_ready(); + *poll.borrow_mut() = Some(poll1); + ret + })); + poll.into_inner().expect("Callback to on_poll failed to call its argument") + }) + .await +} + +pub fn wrap_poll_next<'a, S: Stream + 'a>( + s: S, + mut cb: impl FnMut(Box bool + '_>) + 'a, +) -> impl Stream + 'a { + let mut s = Box::pin(s); + futures::stream::poll_fn(move |cx| { + let poll = RefCell::new(None); + cb(Box::new(|| { + let poll1 = s.as_mut().poll_next(cx); + let ret = poll1.is_ready(); + *poll.borrow_mut() = Some(poll1); + ret + })); + poll.into_inner().expect("Callback to on_poll failed to call its argument") + }) +} + +pub fn on_stream_wake<'a, S: Stream + 'a>( + s: S, + wake: impl Fn() + Clone + Send + Sync + 'static, +) -> impl Stream { + let mut s = Box::pin(s); + futures::stream::poll_fn(move |cx| { + let waker = Arc::new(OnPollWaker(cx.waker().clone(), wake.clone())).into(); + s.as_mut().poll_next(&mut Context::from_waker(&waker)) + }) +} + +task_local! { + static LABEL_STATE: Vec> +} + +pub async fn with_label(label: &str, f: Fut) -> Fut::Output { + let mut new_lbl = LABEL_STATE.try_with(|lbl| lbl.clone()).unwrap_or_default(); + new_lbl.push(Rc::new(label.to_string())); + LABEL_STATE.scope(new_lbl, f).await +} + +pub fn label() -> impl Display + Clone + Send + Sync + 'static { + LABEL_STATE.try_with(|lbl| lbl.iter().join("/")).unwrap_or("".to_string()) +} + +pub struct Label; +impl Display for Label { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", label()) } +} + +pub async fn eprint_events(note: &str, f: Fut) -> Fut::Output { + let label = label(); + let note1 = note.to_string(); + on_wake( + wrap_poll(f, |cb| { + eprintln!("{Label} polling {note}"); + eprintln!("{Label} polled {note} (ready? {})", cb()) + }), + move || eprintln!("{label} woke {note1}"), + ) + .await +} + +pub fn eprint_stream_events<'a, S: Stream + 'a>( + note: &'a str, + s: S, +) -> impl Stream + 'a { + let label = label(); + let note1 = note.to_string(); + on_stream_wake( + wrap_poll_next(s, move |cb| { + eprintln!("{Label} polling {note}"); + eprintln!("{Label} polled {note} (ready? {})", cb()) + }), + move || eprintln!("{label} woke {note1}"), + ) +} + +struct SpinWaker(AtomicBool); +impl Wake for SpinWaker { + fn wake(self: Arc) { self.0.store(true, Ordering::Relaxed); } +} + +/// A dumb executor that keeps synchronously re-running the future as long as it +/// keeps synchronously waking itself. +/// +/// # Panics +/// +/// If the future doesn't wake itself and doesn't settle. This is useful for +/// deterministic tests that don't contain side effects or threading. +pub fn spin_on(f: Fut) -> Fut::Output { + let repeat = Arc::new(SpinWaker(AtomicBool::new(false))); + let mut f = pin!(f); + let waker = repeat.clone().into(); + let mut cx = Context::from_waker(&waker); + loop { + match f.as_mut().poll(&mut cx) { + Poll::Ready(t) => break t, + Poll::Pending if repeat.0.swap(false, Ordering::Relaxed) => (), + Poll::Pending => panic!("The future did not exit and did not call its waker."), + } + } +} diff --git a/orchid-base/src/lib.rs b/orchid-base/src/lib.rs index ec6c970..44755c5 100644 --- a/orchid-base/src/lib.rs +++ b/orchid-base/src/lib.rs @@ -10,6 +10,7 @@ pub mod combine; pub mod error; pub mod event; pub mod format; +pub mod future_debug; pub mod id_store; pub mod interner; pub mod iter_utils; diff --git a/orchid-base/src/localset.rs b/orchid-base/src/localset.rs index ca0f238..e86a2d6 100644 --- a/orchid-base/src/localset.rs +++ b/orchid-base/src/localset.rs @@ -23,12 +23,12 @@ impl Future for LocalSet<'_, E> { let mut any_pending = false; loop { match this.receiver.poll_next_unpin(cx) { + Poll::Ready(Some(fut)) => this.pending.push_back(fut), + Poll::Ready(None) => break, Poll::Pending => { any_pending = true; break; }, - Poll::Ready(None) => break, - Poll::Ready(Some(fut)) => this.pending.push_back(fut), } } let count = this.pending.len(); diff --git a/orchid-base/src/logging.rs b/orchid-base/src/logging.rs index f1ba853..334234f 100644 --- a/orchid-base/src/logging.rs +++ b/orchid-base/src/logging.rs @@ -71,4 +71,7 @@ pub mod test { Self(Rc::new(move |s| clone!(f; Box::pin(async move { f(s).await })))) } } + impl Default for TestLogger { + fn default() -> Self { TestLogger::new(async |s| eprint!("{s}")) } + } } diff --git a/orchid-base/src/reqnot.rs b/orchid-base/src/reqnot.rs index 1c73c71..97e3b3d 100644 --- a/orchid-base/src/reqnot.rs +++ b/orchid-base/src/reqnot.rs @@ -299,12 +299,10 @@ impl<'a> MsgWriter<'a> for IoNotifWriter { pub struct CommCtx { exit: Sender<()>, - o: Rc>>>, } impl CommCtx { pub async fn exit(self) -> io::Result<()> { - self.o.lock().await.as_mut().close().await?; self.exit.clone().send(()).await.expect("quit channel dropped"); Ok(()) } @@ -325,7 +323,7 @@ pub fn io_comm( let o = Rc::new(Mutex::new(o)); let (onsub, client) = IoClient::new(o.clone()); let (exit, onexit) = channel(1); - (client, CommCtx { exit, o: o.clone() }, IoCommServer { o, i, onsub, onexit }) + (client, CommCtx { exit }, IoCommServer { o, i, onsub, onexit }) } pub struct IoCommServer { o: Rc>>>, @@ -345,7 +343,6 @@ impl IoCommServer { Sub(ReplySub), Exit, } - let exiting = RefCell::new(false); let input_stream = try_stream(async |mut h| { loop { let mut g = Bound::async_new(i.clone(), async |i| i.lock().await).await; @@ -361,27 +358,21 @@ impl IoCommServer { } }); let (mut add_pending_req, fork_future) = LocalSet::new(); - let mut fork_stream = pin!(fork_future.fuse().into_stream()); + let mut fork_stream = pin!(fork_future.into_stream()); let mut pending_replies = HashMap::new(); 'body: { - let mut shared = pin!(stream_select!( + let mut shared = stream_select! { pin!(input_stream) as Pin<&mut dyn Stream>>, onsub.map(|sub| Ok(Event::Sub(sub))), fork_stream.as_mut().map(|res| { - res.map(|()| panic!("this substream cannot exit while the loop is running")) + res.map(|()| panic!("this substream cannot exit while the loop is running") as Event) }), onexit.map(|()| Ok(Event::Exit)), - )); + }; while let Some(next) = shared.next().await { match next { Err(e) => break 'body Err(e), - Ok(Event::Exit) => { - *exiting.borrow_mut() = true; - let mut out = o.lock().await; - out.as_mut().flush().await?; - out.as_mut().close().await?; - break; - }, + Ok(Event::Exit) => break, Ok(Event::Sub(ReplySub { id, ack, cb })) => { pending_replies.insert(id, cb); ack.send(()).unwrap(); @@ -415,6 +406,9 @@ impl IoCommServer { while let Some(next) = fork_stream.next().await { next? } + let mut out = o.lock().await; + out.as_mut().flush().await?; + out.as_mut().close().await?; Ok(()) } } @@ -427,11 +421,9 @@ mod test { use futures::{SinkExt, StreamExt, join}; use orchid_api_derive::{Coding, Hierarchy}; use orchid_api_traits::Request; - use test_executors::spin_on; use unsync_pipe::pipe; - use crate::logging::test::TestLogger; - use crate::logging::with_logger; + use crate::future_debug::spin_on; use crate::reqnot::{ClientExt, MsgReaderExt, ReqReaderExt, io_comm}; #[derive(Clone, Debug, PartialEq, Coding, Hierarchy)] @@ -440,8 +432,7 @@ mod test { #[test] fn notification() { - let logger = TestLogger::new(async |s| eprint!("{s}")); - spin_on(with_logger(logger, async { + spin_on(async { let (in1, out2) = pipe(1024); let (in2, out1) = pipe(1024); let (received, mut on_receive) = mpsc::channel(2); @@ -468,7 +459,7 @@ mod test { recv_ctx.exit().await.unwrap(); } ); - })) + }) } #[derive(Clone, Debug, Coding, Hierarchy)] @@ -480,8 +471,7 @@ mod test { #[test] fn request() { - let logger = TestLogger::new(async |s| eprint!("{s}")); - spin_on(with_logger(logger, async { + spin_on(async { let (in1, out2) = pipe(1024); let (in2, out1) = pipe(1024); let (_, srv_ctx, srv) = io_comm(Box::pin(in2), Box::pin(out2)); @@ -515,13 +505,12 @@ mod test { client_ctx.exit().await.unwrap(); } ); - })) + }) } #[test] fn exit() { - let logger = TestLogger::new(async |s| eprint!("{s}")); - spin_on(with_logger(logger, async { + spin_on(async { let (input1, output1) = pipe(1024); let (input2, output2) = pipe(1024); let (reply_client, reply_context, reply_server) = @@ -565,6 +554,6 @@ mod test { onexit.await.unwrap(); } ) - })); + }); } } diff --git a/orchid-extension/src/binary.rs b/orchid-extension/src/binary.rs index 416b1bf..786cbf4 100644 --- a/orchid-extension/src/binary.rs +++ b/orchid-extension/src/binary.rs @@ -34,6 +34,10 @@ pub fn orchid_extension_main_body(cx: ExtCx, builder: ExtensionBuilder) { /// # Usage /// /// ``` +/// #[macro_use] +/// use orchid_extension::dylib_main; +/// use orchid_extension::entrypoint::ExtensionBuilder; +/// /// dylib_main! { /// ExtensionBuilder::new("orchid-std::main") /// } diff --git a/orchid-extension/src/entrypoint.rs b/orchid-extension/src/entrypoint.rs index ccb5ccf..16d083c 100644 --- a/orchid-extension/src/entrypoint.rs +++ b/orchid-extension/src/entrypoint.rs @@ -9,7 +9,6 @@ use futures::future::{LocalBoxFuture, join_all}; use futures::{AsyncRead, AsyncWrite, AsyncWriteExt, StreamExt, stream}; use hashbrown::HashMap; use itertools::Itertools; -use orchid_api::{ExtHostNotif, ExtHostReq}; use orchid_api_traits::{Decode, Encode, Request, UnderRoot, enc_vec}; use orchid_base::char_filter::{char_filter_match, char_filter_union, mk_char_filter}; use orchid_base::error::try_with_reporter; @@ -63,7 +62,7 @@ task_local! { } /// Send a request through the global client's [ClientExt::request] -pub async fn request>(t: T) -> T::Response { +pub async fn request>(t: T) -> T::Response { let response = get_client().request(t).await.unwrap(); if MUTE_REPLY.try_with(|b| *b).is_err() { writeln!(log("msg"), "Got response {response:?}").await; @@ -72,7 +71,7 @@ pub async fn request>(t: T) -> T::Resp } /// Send a notification through the global client's [ClientExt::notify] -pub async fn notify>(t: T) { +pub async fn notify>(t: T) { get_client().notify(t).await.unwrap() } diff --git a/orchid-std/src/macros/match_macros.rs b/orchid-std/src/macros/match_macros.rs index 73105d4..169e535 100644 --- a/orchid-std/src/macros/match_macros.rs +++ b/orchid-std/src/macros/match_macros.rs @@ -4,7 +4,6 @@ use async_fn_stream::stream; use futures::future::join_all; use futures::{Stream, StreamExt, stream}; use never::Never; -use orchid_api::ExprTicket; use orchid_api_derive::Coding; use orchid_base::error::{OrcRes, mk_errv}; use orchid_base::format::fmt; @@ -28,7 +27,7 @@ use crate::{HomoTpl, MacTok, MacTree, OrcOpt, Tpl, UntypedTuple, api}; #[derive(Clone, Coding)] pub struct MatcherData { keys: Vec, - matcher: ExprTicket, + matcher: api::ExprTicket, } impl MatcherData { async fn matcher(&self) -> Expr { Expr::from_handle(ExprHandle::from_ticket(self.matcher).await) } diff --git a/orchid-std/src/std/reflection/sym_atom.rs b/orchid-std/src/std/reflection/sym_atom.rs index 9bb8c4e..ee31e00 100644 --- a/orchid-std/src/std/reflection/sym_atom.rs +++ b/orchid-std/src/std/reflection/sym_atom.rs @@ -1,6 +1,5 @@ use std::borrow::Cow; -use orchid_api::TStrv; use orchid_api_derive::{Coding, Hierarchy}; use orchid_api_traits::Request; use orchid_base::error::mk_errv; @@ -37,7 +36,7 @@ impl Supports for SymAtom { #[derive(Clone, Debug, Coding, Hierarchy)] #[extends(StdReq)] -pub struct CreateSymAtom(pub TStrv); +pub struct CreateSymAtom(pub api::TStrv); impl Request for CreateSymAtom { type Response = api::ExprTicket; } diff --git a/orchid-std/src/std/tuple.rs b/orchid-std/src/std/tuple.rs index 1a9a90d..f6e21ca 100644 --- a/orchid-std/src/std/tuple.rs +++ b/orchid-std/src/std/tuple.rs @@ -6,7 +6,6 @@ use std::rc::Rc; use futures::AsyncWrite; use futures::future::join_all; use never::Never; -use orchid_api::ExprTicket; use orchid_api_derive::{Coding, Hierarchy}; use orchid_api_traits::Request; use orchid_base::error::{OrcRes, mk_errv}; @@ -27,7 +26,7 @@ use crate::{Int, StdSystem, api}; pub struct Tuple(pub(super) Rc>); impl Atomic for Tuple { - type Data = Vec; + type Data = Vec; type Variant = OwnedVariant; } diff --git a/unsync-pipe/src/lib.rs b/unsync-pipe/src/lib.rs index 9594893..2b8224e 100644 --- a/unsync-pipe/src/lib.rs +++ b/unsync-pipe/src/lib.rs @@ -27,7 +27,6 @@ pub fn pipe(size: usize) -> (Writer, Reader) { size, mut read_waker, mut write_waker, - mut flush_waker, reader_dropped, writer_dropped, // irrelevant if correctly dropped @@ -42,7 +41,6 @@ pub fn pipe(size: usize) -> (Writer, Reader) { } read_waker.drop(); write_waker.drop(); - flush_waker.drop(); unsafe { dealloc(start, pipe_layout(size)) } } let state = Box::into_raw(Box::new(AsyncRingbuffer { @@ -53,7 +51,6 @@ pub fn pipe(size: usize) -> (Writer, Reader) { write_idx: 0, read_waker: Trigger::empty(), write_waker: Trigger::empty(), - flush_waker: Trigger::empty(), reader_dropped: false, writer_dropped: false, drop, @@ -110,21 +107,22 @@ struct AsyncRingbuffer { write_idx: usize, read_waker: Trigger, write_waker: Trigger, - flush_waker: Trigger, reader_dropped: bool, writer_dropped: bool, drop: extern "C" fn(*const ()), } impl AsyncRingbuffer { + fn wake_reader(&mut self) { self.read_waker.invoke() } + fn wake_writer(&mut self) { self.write_waker.invoke() } fn drop_writer(&mut self) { - self.read_waker.invoke(); + self.wake_reader(); self.writer_dropped = true; if self.reader_dropped { (self.drop)(self.state) } } fn drop_reader(&mut self) { - self.write_waker.invoke(); + self.wake_writer(); self.reader_dropped = true; if self.writer_dropped { (self.drop)(self.state) @@ -134,25 +132,14 @@ impl AsyncRingbuffer { if self.reader_dropped { return Poll::Ready(Err(broken_pipe_error())); } - self.read_waker.invoke(); self.write_waker.drop(); self.write_waker = Trigger::new(waker.clone()); Poll::Pending } - fn flush_wait(&mut self, waker: &Waker) -> Poll> { - if self.reader_dropped { - return Poll::Ready(Err(broken_pipe_error())); - } - self.read_waker.invoke(); - self.flush_waker.drop(); - self.flush_waker = Trigger::new(waker.clone()); - Poll::Pending - } fn reader_wait(&mut self, waker: &Waker) -> Poll> { if self.writer_dropped { return Poll::Ready(Err(broken_pipe_error())); } - self.write_waker.invoke(); self.read_waker.drop(); self.read_waker = Trigger::new(waker.clone()); Poll::Pending @@ -234,7 +221,7 @@ impl Writer { return Err(SyncWriteError::BufferFull); } state.wrapping_write_unchecked(data); - state.write_waker.invoke(); + state.wake_reader(); Ok(()) } } @@ -255,7 +242,7 @@ impl AsyncWrite for Writer { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { unsafe { let data = self.as_mut().get_state()?; - if data.is_empty() { Poll::Ready(Ok(())) } else { data.flush_wait(cx.waker()) } + if data.is_empty() { Poll::Ready(Ok(())) } else { data.writer_wait(cx.waker()) } } } fn poll_write( @@ -263,12 +250,15 @@ impl AsyncWrite for Writer { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } unsafe { let data = self.as_mut().get_state()?; - if !buf.is_empty() && data.is_empty() { - data.read_waker.invoke(); + if data.is_empty() { + data.wake_reader(); } - if !buf.is_empty() && data.is_full() { + if data.is_full() { data.writer_wait(cx.waker()) } else { Poll::Ready(Ok(data.wrapping_write_unchecked(buf))) @@ -300,7 +290,7 @@ impl AsyncRead for Reader { let data = self.0.as_mut().expect("Cannot be null"); let AsyncRingbuffer { read_idx, write_idx, size, .. } = *data; if !buf.is_empty() && data.is_full() { - data.write_waker.invoke(); + data.wake_writer(); } let poll = if !buf.is_empty() && data.is_empty() { // Nothing to read, waiting... @@ -322,8 +312,8 @@ impl AsyncRead for Reader { data.non_wrapping_read_unchecked(&mut start[0..start_count]); Poll::Ready(Ok(end.len() + start_count)) }; - if !buf.is_empty() && data.is_empty() { - data.flush_waker.invoke(); + if data.is_empty() { + data.wake_writer(); } poll }