Added unsync-pipe with some tests

This commit is contained in:
2025-12-13 02:28:10 +01:00
parent 0f89cde246
commit 224c4ecca2
5 changed files with 546 additions and 330 deletions

13
unsync-pipe/Cargo.toml Normal file
View File

@@ -0,0 +1,13 @@
[package]
name = "unsync-pipe"
version = "0.1.0"
edition = "2024"
[dev-dependencies]
itertools = "0.14.0"
rand = "0.9.2"
rand_chacha = "0.9.0"
test_executors = "0.4.0"
[dependencies]
futures = "0.3.31"

374
unsync-pipe/src/lib.rs Normal file
View File

@@ -0,0 +1,374 @@
use std::alloc::{Layout, alloc, dealloc};
use std::pin::Pin;
use std::process::abort;
use std::ptr::{null, null_mut, slice_from_raw_parts};
use std::task::{Context, Poll, Waker};
use std::{io, mem};
use futures::{AsyncRead, AsyncWrite};
fn pipe_layout(bs: usize) -> Layout { Layout::from_size_align(bs, 1).expect("1-align is trivial") }
pub fn pipe(size: usize) -> (Writer, Reader) {
assert!(0 < size, "cannot create async pipe without buffer");
// SAFETY: the
let start = unsafe { alloc(pipe_layout(size)) };
extern "C" fn drop(val: *const ()) {
let AsyncRingbuffer {
start,
size,
mut read_waker,
mut write_waker,
reader_dropped,
writer_dropped,
// irrelevant if correctly dropped
read_idx: _,
write_idx: _,
// data used to make this call
drop: _,
state: _,
} = *unsafe { Box::from_raw(val as *mut AsyncRingbuffer) };
if !writer_dropped || !reader_dropped {
eprintln!("Pipe dropped in err before reader or writer");
abort()
}
read_waker.drop();
write_waker.drop();
unsafe { dealloc(start, pipe_layout(size)) }
}
let state = Box::into_raw(Box::new(AsyncRingbuffer {
start,
size,
state: null(),
read_idx: 0,
write_idx: 0,
read_waker: Trigger::empty(),
write_waker: Trigger::empty(),
reader_dropped: false,
writer_dropped: false,
drop,
}));
let state_mut = unsafe { state.as_mut().unwrap() };
state_mut.state = state as *const ();
(Writer(state_mut as *mut _), Reader(state_mut as *mut _))
}
/// A single-fire empty event, to be distributed by value. Either one of the
/// functions can be called exactly once.
#[repr(C)]
struct Trigger {
state: *const (),
invoke: extern "C" fn(*const ()),
drop: extern "C" fn(*const ()),
}
impl Trigger {
fn new(waker: Waker) -> Self {
let state = Box::into_raw(Box::new(waker)) as *const ();
extern "C" fn drop(state: *const ()) {
unsafe { mem::drop(Box::from_raw(state as *mut Waker)) };
}
extern "C" fn invoke(state: *const ()) { unsafe { Box::from_raw(state as *mut Waker) }.wake(); }
Self { state, invoke, drop }
}
fn empty() -> Self {
extern "C" fn empty_fn_ptr(_: *const ()) { abort() }
Self { state: null(), drop: empty_fn_ptr, invoke: empty_fn_ptr }
}
fn is_empty(&self) -> bool { self.state.is_null() }
fn invoke(&mut self) {
if let Some(this) = self.take() {
(this.invoke)(this.state)
}
}
fn drop(&mut self) {
if let Some(this) = self.take() {
(this.drop)(this.state)
}
}
fn take(&mut self) -> Option<Self> {
(!self.is_empty()).then(|| std::mem::replace(self, Self::empty()))
}
}
/// A ringbuffer for single-threaded synchronized communication.
#[repr(C)]
struct AsyncRingbuffer {
state: *const (),
start: *mut u8,
size: usize,
read_idx: usize,
write_idx: usize,
read_waker: Trigger,
write_waker: Trigger,
reader_dropped: bool,
writer_dropped: bool,
drop: extern "C" fn(*const ()),
}
impl AsyncRingbuffer {
fn drop_writer(&mut self) {
self.writer_dropped = true;
if self.reader_dropped {
(self.drop)(self.state)
}
}
fn drop_reader(&mut self) {
self.reader_dropped = true;
if self.writer_dropped {
(self.drop)(self.state)
}
}
fn writer_wait<T>(&mut self, waker: &Waker) -> Poll<io::Result<T>> {
if self.reader_dropped {
return Poll::Ready(Err(broken_pipe_error()));
}
self.read_waker.invoke();
self.write_waker.drop();
self.write_waker = Trigger::new(waker.clone());
Poll::Pending
}
fn reader_wait(&mut self, waker: &Waker) -> Poll<io::Result<usize>> {
if self.writer_dropped {
return Poll::Ready(Err(broken_pipe_error()));
}
self.write_waker.invoke();
self.read_waker.drop();
self.read_waker = Trigger::new(waker.clone());
Poll::Pending
}
unsafe fn non_wrapping_write_unchecked(&mut self, buf: &[u8]) {
let write_ptr = unsafe { self.start.add(self.write_idx) };
let slc = slice_from_raw_parts(write_ptr, buf.len()).cast_mut();
unsafe { &mut *slc }.copy_from_slice(buf);
self.write_idx = (self.write_idx + buf.len()) % self.size;
}
unsafe fn non_wrapping_read_unchecked(&mut self, buf: &mut [u8]) {
let read_ptr = unsafe { self.start.add(self.read_idx) };
let slc = slice_from_raw_parts(read_ptr, buf.len()).cast_mut();
buf.copy_from_slice(unsafe { &*slc });
self.read_idx = (self.read_idx + buf.len()) % self.size;
}
fn is_full(&self) -> bool { (self.write_idx + 1) % self.size == self.read_idx }
fn is_empty(&self) -> bool { self.write_idx == self.read_idx }
}
fn already_closed_error() -> io::Error {
io::Error::new(io::ErrorKind::BrokenPipe, "Pipe already closed from this end")
}
fn broken_pipe_error() -> io::Error {
io::Error::new(io::ErrorKind::BrokenPipe, "Pipe already closed from other end")
}
#[repr(C)]
pub struct Writer(*mut AsyncRingbuffer);
impl Writer {
unsafe fn get_state(self: Pin<&mut Self>) -> io::Result<&mut AsyncRingbuffer> {
match unsafe { self.0.as_mut() } {
Some(data) => Ok(data),
None => Err(already_closed_error()),
}
}
}
impl AsyncWrite for Writer {
fn poll_close(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
unsafe {
match self.as_mut().get_state() {
Err(e) => return Poll::Ready(Err(e)),
Ok(data) => {
data.drop_writer();
},
}
}
self.0 = null_mut();
Poll::Ready(Ok(()))
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unsafe {
let data = self.as_mut().get_state()?;
if data.is_empty() { Poll::Ready(Ok(())) } else { data.writer_wait(cx.waker()) }
}
}
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let w = unsafe {
let data = self.as_mut().get_state()?;
let AsyncRingbuffer { write_idx, read_idx, size, .. } = *data;
if !buf.is_empty() && data.is_empty() {
eprintln!("Wake reader");
data.read_waker.invoke();
}
if !buf.is_empty() && data.is_full() {
eprintln!("Writer is blocked, waiting");
data.writer_wait(cx.waker())
} else if write_idx < read_idx {
eprintln!("Non-wrapping backside write w={write_idx} < r={read_idx} <= s={size}");
let count = buf.len().min(read_idx - write_idx - 1);
data.non_wrapping_write_unchecked(&buf[0..count]);
Poll::Ready(Ok(count))
} else if data.write_idx + buf.len() < size {
eprintln!(
"Non-wrapping frontside write r={read_idx} <= w={write_idx} + b={} < s={size}",
buf.len()
);
data.non_wrapping_write_unchecked(&buf[0..buf.len()]);
Poll::Ready(Ok(buf.len()))
} else if read_idx == 0 {
eprintln!("Frontside write up to origin r=0 < s={size} < w={write_idx} + b={}", buf.len());
data.non_wrapping_write_unchecked(&buf[0..size - write_idx - 1]);
Poll::Ready(Ok(size - write_idx - 1))
} else {
let (end, start) = buf.split_at(size - write_idx);
eprintln!("Wrapping write r={read_idx} < s={size} < w={write_idx} + b={}", buf.len());
data.non_wrapping_write_unchecked(end);
let start_count = start.len().min(read_idx - 1);
data.non_wrapping_write_unchecked(&start[0..start_count]);
Poll::Ready(Ok(end.len() + start_count))
}
};
if let Poll::Ready(Ok(w)) = &w {
eprintln!("Wrote {w}")
}
w
}
}
impl Drop for Writer {
fn drop(&mut self) {
eprintln!("Dropping writer");
unsafe {
if let Some(data) = self.0.as_mut() {
data.drop_writer();
}
}
}
}
#[repr(C)]
pub struct Reader(*mut AsyncRingbuffer);
impl AsyncRead for Reader {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
eprintln!("Beginning read of {}", buf.len());
let r = unsafe {
let data = self.0.as_mut().expect("Cannot be null");
let AsyncRingbuffer { read_idx, write_idx, size, .. } = *data;
if !buf.is_empty() && data.is_full() {
eprintln!("Wake writer");
data.write_waker.invoke();
}
if !buf.is_empty() && data.is_empty() {
eprintln!("Nothing to read, waiting...");
data.reader_wait(cx.waker())
} else if read_idx < write_idx {
eprintln!("frontside non-wrapping read");
let count = buf.len().min(write_idx - read_idx);
data.non_wrapping_read_unchecked(&mut buf[0..count]);
Poll::Ready(Ok(count))
} else if read_idx + buf.len() < size {
eprintln!("backside non-wrapping read");
data.non_wrapping_read_unchecked(buf);
Poll::Ready(Ok(buf.len()))
} else {
eprintln!("wrapping read");
let (end, start) = buf.split_at_mut(size - read_idx);
data.non_wrapping_read_unchecked(end);
let start_count = start.len().min(write_idx);
data.non_wrapping_read_unchecked(&mut start[0..start_count]);
Poll::Ready(Ok(end.len() + start_count))
}
};
if let Poll::Ready(Ok(r)) = &r {
eprintln!("Read {r}")
}
r
}
}
impl Drop for Reader {
fn drop(&mut self) {
eprintln!("Dropping reader");
unsafe {
if let Some(data) = self.0.as_mut() {
data.drop_reader();
}
}
}
}
#[cfg(test)]
mod tests {
use std::io::Write;
use std::pin::pin;
use futures::future::join;
use futures::{AsyncReadExt, AsyncWriteExt};
use itertools::Itertools;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use test_executors::spin_on;
use super::*;
#[test]
fn basic_io() {
let mut w_rng = ChaCha8Rng::seed_from_u64(2);
let mut r_rng = ChaCha8Rng::seed_from_u64(1);
println!("Output check");
spin_on(async {
let (w, r) = pipe(1024);
let test_length = 100_000;
let data = (0u32..test_length).flat_map(|num| num.to_be_bytes());
let write_fut = async {
let mut w = pin!(w);
let mut source = data.clone();
let mut tally = 0;
while tally < test_length * 8 {
let values = source.by_ref().take(w_rng.random_range(0..200)).collect::<Vec<_>>();
tally += values.len() as u32;
w.write_all(&values).await.unwrap();
}
w.flush().await.unwrap();
};
let read_fut = async {
let mut r = pin!(r);
let mut expected = data.clone();
let mut tally = 0;
let mut aggregate = Vec::new();
let mut expected_aggregate = Vec::new();
let mut percentage = 0;
while tally < test_length * 8 {
let next_percentage = tally * 100 / (test_length * 8);
if percentage < next_percentage {
percentage = next_percentage;
println!("{percentage}%");
io::stdout().flush().unwrap();
}
let expected_values =
expected.by_ref().take(r_rng.random_range(0..200)).collect::<Vec<_>>();
tally += expected_values.len() as u32;
let mut values = vec![0; expected_values.len()];
r.read_exact(&mut values[..]).await.unwrap_or_else(|e| panic!("At {tally} bytes: {e}"));
aggregate.extend_from_slice(&values);
expected_aggregate.extend_from_slice(&expected_values);
if values != expected_values {
fn print_bytes(bytes: &[u8]) -> String {
(bytes.iter().map(|s| format!("{s:>2x}")).chunks(32).into_iter())
.map(|c| c.into_iter().join(" "))
.join("\n")
}
panic!(
"Difference in generated numbers\n{}\n{}",
print_bytes(&aggregate),
print_bytes(&expected_aggregate),
)
}
}
eprintln!("Read {tally} correct bytes")
};
join(write_fut, read_fut).await;
})
}
}