//! Single-threaded binary safe AsyncWrite/AsyncRead pair //! //! The main entry point is [pipe]. [Writer] and [Reader] can just be sent //! across binary boundaries. A change to the ABI constitutes a major version //! break. use std::alloc::{Layout, alloc, dealloc}; use std::pin::Pin; use std::process::abort; use std::ptr::{null, null_mut, slice_from_raw_parts}; use std::task::{Context, Poll, Waker}; use std::{io, mem}; use futures_io::{AsyncRead, AsyncWrite}; fn pipe_layout(bs: usize) -> Layout { Layout::from_size_align(bs, 1).expect("1-align is trivial") } /// Create a ringbuffer with the specified byte capacity. Once the buffer is /// exhausted, the writer will block. pub fn pipe(size: usize) -> (Writer, Reader) { assert!(0 < size, "cannot create async pipe without buffer"); // SAFETY: the let start = unsafe { alloc(pipe_layout(size)) }; extern "C" fn drop(val: *const ()) { let AsyncRingbuffer { start, size, mut read_waker, mut write_waker, reader_dropped, writer_dropped, // irrelevant if correctly dropped read_idx: _, write_idx: _, // data used to make this call drop: _, state: _, } = *unsafe { Box::from_raw(val as *mut AsyncRingbuffer) }; if !writer_dropped || !reader_dropped { eprintln!("Pipe dropped in err before reader or writer"); abort() } read_waker.drop(); write_waker.drop(); unsafe { dealloc(start, pipe_layout(size)) } } let state = Box::into_raw(Box::new(AsyncRingbuffer { start, size, state: null(), read_idx: 0, write_idx: 0, read_waker: Trigger::empty(), write_waker: Trigger::empty(), reader_dropped: false, writer_dropped: false, drop, })); let state_mut = unsafe { state.as_mut().unwrap() }; state_mut.state = state as *const (); (Writer(state_mut as *mut _), Reader(state_mut as *mut _)) } /// A single-fire empty event, to be distributed by value. Either one of the /// functions can be called exactly once. #[repr(C)] struct Trigger { state: *const (), invoke: extern "C" fn(*const ()), drop: extern "C" fn(*const ()), } impl Trigger { fn new(waker: Waker) -> Self { let state = Box::into_raw(Box::new(waker)) as *const (); extern "C" fn drop(state: *const ()) { unsafe { mem::drop(Box::from_raw(state as *mut Waker)) }; } extern "C" fn invoke(state: *const ()) { unsafe { Box::from_raw(state as *mut Waker) }.wake(); } Self { state, invoke, drop } } fn empty() -> Self { extern "C" fn empty_fn_ptr(_: *const ()) { abort() } Self { state: null(), drop: empty_fn_ptr, invoke: empty_fn_ptr } } fn is_empty(&self) -> bool { self.state.is_null() } fn invoke(&mut self) { if let Some(this) = self.take() { (this.invoke)(this.state) } } fn drop(&mut self) { if let Some(this) = self.take() { (this.drop)(this.state) } } fn take(&mut self) -> Option { (!self.is_empty()).then(|| std::mem::replace(self, Self::empty())) } } /// A ringbuffer for single-threaded synchronized communication. #[repr(C)] struct AsyncRingbuffer { state: *const (), start: *mut u8, size: usize, read_idx: usize, write_idx: usize, read_waker: Trigger, write_waker: Trigger, reader_dropped: bool, writer_dropped: bool, drop: extern "C" fn(*const ()), } impl AsyncRingbuffer { fn drop_writer(&mut self) { self.writer_dropped = true; if self.reader_dropped { (self.drop)(self.state) } } fn drop_reader(&mut self) { self.reader_dropped = true; if self.writer_dropped { (self.drop)(self.state) } } fn writer_wait(&mut self, waker: &Waker) -> Poll> { if self.reader_dropped { return Poll::Ready(Err(broken_pipe_error())); } self.read_waker.invoke(); self.write_waker.drop(); self.write_waker = Trigger::new(waker.clone()); Poll::Pending } fn reader_wait(&mut self, waker: &Waker) -> Poll> { if self.writer_dropped { return Poll::Ready(Err(broken_pipe_error())); } self.write_waker.invoke(); self.read_waker.drop(); self.read_waker = Trigger::new(waker.clone()); Poll::Pending } unsafe fn non_wrapping_write_unchecked(&mut self, buf: &[u8]) { let write_ptr = unsafe { self.start.add(self.write_idx) }; let slc = slice_from_raw_parts(write_ptr, buf.len()).cast_mut(); unsafe { &mut *slc }.copy_from_slice(buf); self.write_idx = (self.write_idx + buf.len()) % self.size; } unsafe fn non_wrapping_read_unchecked(&mut self, buf: &mut [u8]) { let read_ptr = unsafe { self.start.add(self.read_idx) }; let slc = slice_from_raw_parts(read_ptr, buf.len()).cast_mut(); buf.copy_from_slice(unsafe { &*slc }); self.read_idx = (self.read_idx + buf.len()) % self.size; } fn is_full(&self) -> bool { (self.write_idx + 1) % self.size == self.read_idx } fn is_empty(&self) -> bool { self.write_idx == self.read_idx } } fn already_closed_error() -> io::Error { io::Error::new(io::ErrorKind::BrokenPipe, "Pipe already closed from this end") } fn broken_pipe_error() -> io::Error { io::Error::new(io::ErrorKind::BrokenPipe, "Pipe already closed from other end") } /// A binary safe [AsyncWrite] implementor writing to a ringbuffer created by /// [pipe]. #[repr(C)] pub struct Writer(*mut AsyncRingbuffer); impl Writer { unsafe fn get_state(self: Pin<&mut Self>) -> io::Result<&mut AsyncRingbuffer> { match unsafe { self.0.as_mut() } { Some(data) => Ok(data), None => Err(already_closed_error()), } } } impl AsyncWrite for Writer { fn poll_close(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { unsafe { match self.as_mut().get_state() { Err(e) => return Poll::Ready(Err(e)), Ok(data) => { data.drop_writer(); }, } } self.0 = null_mut(); Poll::Ready(Ok(())) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { unsafe { let data = self.as_mut().get_state()?; if data.is_empty() { Poll::Ready(Ok(())) } else { data.writer_wait(cx.waker()) } } } fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { unsafe { let data = self.as_mut().get_state()?; let AsyncRingbuffer { write_idx, read_idx, size, .. } = *data; if !buf.is_empty() && data.is_empty() { data.read_waker.invoke(); } if !buf.is_empty() && data.is_full() { // Writer is blocked data.writer_wait(cx.waker()) } else if write_idx < read_idx { // Non-wrapping backside write w < r <= s let count = buf.len().min(read_idx - write_idx - 1); data.non_wrapping_write_unchecked(&buf[0..count]); Poll::Ready(Ok(count)) } else if data.write_idx + buf.len() < size { // Non-wrapping frontside write r <= w + b < s data.non_wrapping_write_unchecked(&buf[0..buf.len()]); Poll::Ready(Ok(buf.len())) } else if read_idx == 0 { // Frontside write up to origin r=0 < s < w + b data.non_wrapping_write_unchecked(&buf[0..size - write_idx - 1]); Poll::Ready(Ok(size - write_idx - 1)) } else { let (end, start) = buf.split_at(size - write_idx); // Wrapping write r < s < w + b data.non_wrapping_write_unchecked(end); let start_count = start.len().min(read_idx - 1); data.non_wrapping_write_unchecked(&start[0..start_count]); Poll::Ready(Ok(end.len() + start_count)) } } } } impl Drop for Writer { fn drop(&mut self) { unsafe { if let Some(data) = self.0.as_mut() { data.drop_writer(); } } } } /// A binary safe [AsyncRead] implementor reading from a ringbuffer created by /// [pipe] #[repr(C)] pub struct Reader(*mut AsyncRingbuffer); impl AsyncRead for Reader { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { unsafe { let data = self.0.as_mut().expect("Cannot be null"); let AsyncRingbuffer { read_idx, write_idx, size, .. } = *data; if !buf.is_empty() && data.is_full() { data.write_waker.invoke(); } if !buf.is_empty() && data.is_empty() { // Nothing to read, waiting... data.reader_wait(cx.waker()) } else if read_idx < write_idx { // Frontside non-wrapping read let count = buf.len().min(write_idx - read_idx); data.non_wrapping_read_unchecked(&mut buf[0..count]); Poll::Ready(Ok(count)) } else if read_idx + buf.len() < size { // Backside non-wrapping read data.non_wrapping_read_unchecked(buf); Poll::Ready(Ok(buf.len())) } else { // Wrapping read let (end, start) = buf.split_at_mut(size - read_idx); data.non_wrapping_read_unchecked(end); let start_count = start.len().min(write_idx); data.non_wrapping_read_unchecked(&mut start[0..start_count]); Poll::Ready(Ok(end.len() + start_count)) } } } } impl Drop for Reader { fn drop(&mut self) { unsafe { if let Some(data) = self.0.as_mut() { data.drop_reader(); } } } } #[cfg(test)] mod tests { use std::pin::pin; use futures::future::join; use futures::{AsyncReadExt, AsyncWriteExt}; use itertools::Itertools; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; use test_executors::spin_on; use super::*; #[test] fn basic_io() { let mut w_rng = ChaCha8Rng::seed_from_u64(2); let mut r_rng = ChaCha8Rng::seed_from_u64(1); spin_on(async { let (w, r) = pipe(1024); let test_length = 10_000_000; let data = (0u32..test_length).flat_map(|num| num.to_be_bytes()); let write_fut = async { let mut w = pin!(w); let mut source = data.clone(); let mut tally = 0; while tally < test_length * 4 { let values = source.by_ref().take(w_rng.random_range(0..200)).collect::>(); tally += values.len() as u32; w.write_all(&values).await.unwrap(); } w.flush().await.unwrap(); }; let read_fut = async { let mut r = pin!(r); let mut expected = data.clone(); let mut tally = 0; while tally < test_length * 4 { let expected_values = expected.by_ref().take(r_rng.random_range(0..200)).collect::>(); tally += expected_values.len() as u32; let mut values = vec![0; expected_values.len()]; r.read_exact(&mut values[..]).await.unwrap_or_else(|e| panic!("At {tally} bytes: {e}")); if values != expected_values { fn print_bytes(bytes: &[u8]) -> String { (bytes.iter().map(|s| format!("{s:>2x}")).chunks(32).into_iter()) .map(|c| c.into_iter().join(" ")) .join("\n") } panic!( "Difference in generated numbers\n{}\n{}", print_bytes(&values), print_bytes(&expected_values), ) } } }; join(write_fut, read_fut).await; }) } }