From 06debb363656007fa2e18d99dfcd538c31f99b0f Mon Sep 17 00:00:00 2001 From: Lawrence Bethlenfalvy Date: Tue, 16 Dec 2025 00:02:45 +0100 Subject: [PATCH] Tests pass for reqnot --- Cargo.lock | 2 +- orchid-api-traits/src/relations.rs | 10 ++ orchid-base/Cargo.toml | 5 +- orchid-base/src/reqnot.rs | 179 +++++++++++++++-------------- 4 files changed, 105 insertions(+), 91 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5a036dc..4c7efee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1058,7 +1058,7 @@ dependencies = [ "rust-embed", "some_executor", "substack", - "test_executors 0.3.5", + "test_executors 0.4.0", "trait-set", "unsync-pipe", ] diff --git a/orchid-api-traits/src/relations.rs b/orchid-api-traits/src/relations.rs index 8cb925e..0c6e5e5 100644 --- a/orchid-api-traits/src/relations.rs +++ b/orchid-api-traits/src/relations.rs @@ -1,5 +1,7 @@ use core::fmt; +use never::Never; + use super::coding::Coding; use crate::helpers::enc_vec; @@ -13,8 +15,16 @@ pub trait Channel: 'static { type Req: Coding + Sized + 'static; type Notif: Coding + Sized + 'static; } +impl Channel for Never { + type Notif = Never; + type Req = Never; +} pub trait MsgSet: Sync + 'static { type In: Channel; type Out: Channel; } +impl MsgSet for Never { + type In = Never; + type Out = Never; +} diff --git a/orchid-base/Cargo.toml b/orchid-base/Cargo.toml index 8311b59..ab3bea7 100644 --- a/orchid-base/Cargo.toml +++ b/orchid-base/Cargo.toml @@ -26,5 +26,8 @@ regex = "1.11.2" rust-embed = "8.7.2" some_executor = "0.6.1" substack = "1.1.1" -test_executors = "0.3.5" trait-set = "0.3.0" + +[dev-dependencies] +futures = "0.3.31" +test_executors = "0.4.0" diff --git a/orchid-base/src/reqnot.rs b/orchid-base/src/reqnot.rs index 5dd9b9e..38a9b06 100644 --- a/orchid-base/src/reqnot.rs +++ b/orchid-base/src/reqnot.rs @@ -1,4 +1,3 @@ -use std::any::TypeId; use std::cell::RefCell; use std::collections::VecDeque; use std::future::Future; @@ -14,9 +13,11 @@ use futures::channel::mpsc::{self, Receiver, Sender, channel}; use futures::channel::oneshot; use futures::future::LocalBoxFuture; use futures::lock::{Mutex, MutexGuard}; -use futures::{AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt, stream, stream_select}; +use futures::{ + AsyncRead, AsyncWrite, AsyncWriteExt, SinkExt, Stream, StreamExt, stream, stream_select, +}; use hashbrown::HashMap; -use orchid_api_traits::{Decode, Encode, Request, UnderRoot}; +use orchid_api_traits::{Channel, Decode, Encode, Request, UnderRoot}; #[must_use = "Receipts indicate that a required action has been performed within a function. \ Most likely this should be returned somewhere."] @@ -59,8 +60,6 @@ pub trait MsgWriter { /// For initiating outbound requests and notifications pub trait Client { - fn root_req_tid(&self) -> TypeId; - fn root_notif_tid(&self) -> TypeId; fn start_request(&self) -> LocalBoxFuture<'_, Box>; fn start_notif(&self) -> LocalBoxFuture<'_, Box>; } @@ -68,9 +67,8 @@ pub trait Client { /// Extension trait with convenience methods that handle outbound request and /// notif lifecycle and typing #[allow(async_fn_in_trait)] -pub trait ClientExt: Client { - async fn request>(&self, t: T) -> T::Response { - assert_eq!(TypeId::of::<::Root>(), self.root_req_tid()); +pub trait ClientExt: Client { + async fn request>(&self, t: T) -> T::Response { let mut req = self.start_request().await; t.into_root().encode(req.writer().as_mut()).await; let mut rep = req.send().await; @@ -78,14 +76,13 @@ pub trait ClientExt: Client { rep.finish().await; response } - async fn notify + 'static>(&self, t: T) { - assert_eq!(TypeId::of::<::Root>(), self.root_notif_tid()); + async fn notify>(&self, t: T) { let mut notif = self.start_notif().await; t.into_root().encode(notif.writer().as_mut()).await; notif.finish().await; } } -impl ClientExt for T {} +impl ClientExt for T {} /// A form of [Evidence] that doesn't require the value to be kept around pub struct Witness(PhantomData); @@ -110,7 +107,7 @@ type IoGuard = Bound>>, IoLock>; /// An incoming request. This holds a lock on the ingress channel. pub struct ReqReader<'a> { id: u64, - read: MutexGuard<'a, IoRef>, + read: IoGuard, write: &'a Mutex>, } impl<'a> ReqReader<'a> { @@ -148,11 +145,15 @@ pub struct RepWriter<'a> { } impl<'a> RepWriter<'a> { pub fn writer(&mut self) -> Pin<&mut dyn AsyncWrite> { self.write.as_mut() } - pub async fn send(self) -> Receipt<'a> { Receipt(PhantomData) } + pub async fn send(mut self) -> Receipt<'a> { + self.writer().flush().await.unwrap(); + Receipt(PhantomData) + } } pub struct NotifReader<'a> { - read: MutexGuard<'a, IoRef>, + _pd: PhantomData<&'a mut ()>, + read: IoGuard, } impl<'a> NotifReader<'a> { pub fn reader(&mut self) -> Pin<&mut dyn AsyncRead> { self.read.as_mut() } @@ -164,6 +165,7 @@ impl<'a> NotifReader<'a> { pub async fn release(self) {} } +#[derive(Debug)] struct ReplySub { id: u64, ack: oneshot::Sender<()>, @@ -174,27 +176,17 @@ struct IoClient { output: IoLock, id: Rc>, subscribe: Rc>, - req_tid: TypeId, - notif_tid: TypeId, } impl IoClient { - fn new(output: IoLock) -> (Receiver, Self) { + fn new(output: IoLock) -> (Receiver, Self) { let (req, rep) = mpsc::channel(0); - (rep, Self { - output, - id: Rc::new(RefCell::new(0)), - req_tid: TypeId::of::(), - notif_tid: TypeId::of::(), - subscribe: Rc::new(req), - }) + (rep, Self { output, id: Rc::new(RefCell::new(0)), subscribe: Rc::new(req) }) } async fn lock_out(&self) -> IoGuard { Bound::async_new(self.output.clone(), async |o| o.lock().await).await } } impl Client for IoClient { - fn root_notif_tid(&self) -> TypeId { self.notif_tid } - fn root_req_tid(&self) -> TypeId { self.req_tid } fn start_notif(&self) -> LocalBoxFuture<'_, Box> { Box::pin(async { let mut o = self.lock_out().await; @@ -256,37 +248,41 @@ impl MsgWriter for IoNotifWriter { } pub struct CommCtx { - quit: Sender<()>, + exit: Sender<()>, } impl CommCtx { - pub async fn quit(self) { self.quit.clone().send(()).await.expect("quit channel dropped"); } + pub async fn exit(self) { self.exit.clone().send(()).await.expect("quit channel dropped"); } } /// Establish bidirectional request-notification communication over a duplex /// channel. The returned [IoClient] can be used for notifications immediately, /// but requests can only be received while the future is running. The future -/// will only resolve when [CommCtx::quit] is called. -pub fn io_comm( +/// will only resolve when [CommCtx::quit] is called. The generic type +/// parameters are associated with the client and serve to ensure with a runtime +/// check that the correct message families are sent in the correct directions +/// across the channel. +pub fn io_comm( o: Rc>>>, i: Mutex>>, - notif: impl for<'a> AsyncFn(&mut CommCtx, NotifReader<'a>), - req: impl for<'a> AsyncFn(&mut CommCtx, ReqReader<'a>) -> Receipt<'a>, -) -> (impl Client, impl Future) { + notif: impl for<'a> AsyncFn(NotifReader<'a>), + req: impl for<'a> AsyncFn(ReqReader<'a>) -> Receipt<'a>, +) -> (impl ClientExt, CommCtx, impl Future) { let i = Rc::new(i); - let (onsub, client) = IoClient::new::(o.clone()); - (client, async move { - let (exit, onexit) = channel(1); + let (onsub, client) = IoClient::new(o.clone()); + let (exit, onexit) = channel(1); + (client, CommCtx { exit }, async move { enum Event { - Input(u64), + Input(u64, IoGuard), Sub(ReplySub), Exit, } let exiting = RefCell::new(false); let input_stream = stream(async |mut h| { loop { - let id = u64::decode(i.lock().await.as_mut()).await; - h.emit(Event::Input(id)).await; + let mut g = Bound::async_new(i.clone(), async |i| i.lock().await).await; + let id = u64::decode(g.as_mut()).await; + h.emit(Event::Input(id, g)).await; } }); let pending_reqs = RefCell::new(VecDeque::>::new()); @@ -320,26 +316,22 @@ pub fn io_comm( pending_replies.insert(id, cb); ack.send(()).unwrap(); }, - Event::Input(0) => { - let (i, notif, exit) = (i.clone(), ¬if, exit.clone()); + Event::Input(0, read) => { + let notif = ¬if; pending_reqs.borrow_mut().push_back(Box::pin(async move { - let g = i.lock().await; - notif(&mut CommCtx { quit: exit.clone() }, NotifReader { read: g }).await + notif(NotifReader { _pd: PhantomData, read }).await })); }, // id.msb == 0 is a request, !id where id.msb == 1 is the equivalent response - Event::Input(id) => + Event::Input(id, read) => if (id & (1 << (u64::BITS - 1))) == 0 { - let (i, o, req, exit) = (i.clone(), o.clone(), &req, exit.clone()); + let (o, req) = (o.clone(), &req); pending_reqs.borrow_mut().push_back(Box::pin(async move { - let g = i.lock().await; - let _ = - req(&mut CommCtx { quit: exit.clone() }, ReqReader { id, read: g, write: &o }) - .await; + let _ = req(ReqReader { id, read, write: &o }).await; }) as LocalBoxFuture<()>); } else { let cb = pending_replies.remove(&!id).expect("Reply to unrecognized request"); - let _ = cb.send(Bound::async_new(i.clone(), |i| i.lock()).await); + cb.send(read).unwrap_or_else(|_| panic!("Failed to send reply")); }, } } @@ -350,14 +342,14 @@ pub fn io_comm( #[cfg(test)] mod test { - use std::cell::RefCell; use std::rc::Rc; - use futures::join; + use futures::channel::mpsc; use futures::lock::Mutex; + use futures::{SinkExt, StreamExt, join}; use never::Never; use orchid_api_derive::{Coding, Hierarchy}; - use orchid_api_traits::Request; + use orchid_api_traits::{Channel, Request}; use test_executors::spin_on; use unsync_pipe::pipe; @@ -367,35 +359,6 @@ mod test { #[extendable] struct TestNotif(u64); - #[test] - fn notification() { - spin_on(async { - let (in1, out2) = pipe(1024); - let (in2, out1) = pipe(1024); - let received = RefCell::new(None); - let (_, run_receiver) = io_comm::( - Rc::new(Mutex::new(Box::pin(in2))), - Mutex::new(Box::pin(out2)), - async |_, notif: NotifReader| { - *received.borrow_mut() = Some(notif.read::().await) - }, - async |_, _| panic!("Should receive notif, not request"), - ); - let (sender, _) = io_comm::( - Rc::new(Mutex::new(Box::pin(in1))), - Mutex::new(Box::pin(out1)), - async |_, _| panic!("Should not receive notif"), - async |_, _| panic!("Should not receive request"), - ); - join!(run_receiver, async { - sender.notify(TestNotif(3)).await; - assert_eq!(*received.borrow(), Some(TestNotif(3))); - sender.notify(TestNotif(4)).await; - assert_eq!(*received.borrow(), Some(TestNotif(4))); - }); - }) - } - #[derive(Clone, Debug, Coding, Hierarchy)] #[extendable] struct DummyRequest(u64); @@ -403,29 +366,67 @@ mod test { type Response = u64; } + struct TestChannel; + impl Channel for TestChannel { + type Notif = TestNotif; + type Req = DummyRequest; + } + + #[test] + fn notification() { + spin_on(async { + let (in1, out2) = pipe(1024); + let (in2, out1) = pipe(1024); + let (received, mut on_receive) = mpsc::channel(2); + let (_, recv_ctx, run_recv) = io_comm::( + Rc::new(Mutex::new(Box::pin(in2))), + Mutex::new(Box::pin(out2)), + async |notif: NotifReader| { + received.clone().send(notif.read::().await).await.unwrap(); + }, + async |_| panic!("Should receive notif, not request"), + ); + let (sender, ..) = io_comm::( + Rc::new(Mutex::new(Box::pin(in1))), + Mutex::new(Box::pin(out1)), + async |_| panic!("Should not receive notif"), + async |_| panic!("Should not receive request"), + ); + join!(run_recv, async { + sender.notify(TestNotif(3)).await; + assert_eq!(on_receive.next().await, Some(TestNotif(3))); + sender.notify(TestNotif(4)).await; + assert_eq!(on_receive.next().await, Some(TestNotif(4))); + recv_ctx.exit().await; + }); + }) + } + #[test] fn request() { spin_on(async { let (in1, out2) = pipe(1024); let (in2, out1) = pipe(1024); - let (_, run_server) = io_comm::( + let (_, srv_ctx, run_srv) = io_comm::( Rc::new(Mutex::new(Box::pin(in2))), Mutex::new(Box::pin(out2)), - async |_, _| panic!("No notifs expected"), - async |_, mut req| { + async |_| panic!("No notifs expected"), + async |mut req| { let val = req.read_req::().await; req.reply(&val, &(val.0 + 1)).await }, ); - let (client, run_client) = io_comm::( + let (client, client_ctx, run_client) = io_comm::( Rc::new(Mutex::new(Box::pin(in1))), Mutex::new(Box::pin(out1)), - async |_, _| panic!("Not expecting ingress notif"), - async |_, _| panic!("Not expecting ingress req"), + async |_| panic!("Not expecting ingress notif"), + async |_| panic!("Not expecting ingress req"), ); - join!(run_server, run_client, async { + join!(run_srv, run_client, async { let response = client.request(DummyRequest(5)).await; assert_eq!(response, 6); + srv_ctx.exit().await; + client_ctx.exit().await; }); }) }