use std::cell::RefCell;
use std::future::Future;
use std::io;
use std::num::NonZeroU64;
use std::pin::Pin;
use std::rc::{Rc, Weak};

use async_fn_stream::stream;
use derive_destructure::destructure;
use futures::channel::mpsc::{Sender, channel};
use futures::future::{join, join_all};
use futures::lock::Mutex;
use futures::{AsyncRead, AsyncWrite, AsyncWriteExt, SinkExt, StreamExt};
use hashbrown::{HashMap, HashSet};
use itertools::Itertools;
use orchid_api_traits::{Decode, Encode, Request};
use orchid_base::clone;
use orchid_base::format::{FmtCtxImpl, Format};
use orchid_base::interner::{IStr, IStrv, es, ev, is, iv};
use orchid_base::location::Pos;
use orchid_base::logging::logger;
use orchid_base::name::Sym;
use orchid_base::reqnot::{Client, ClientExt, MsgReaderExt, ReqHandleExt, ReqReaderExt, io_comm};
use orchid_base::stash::{stash, with_stash};
use orchid_base::tree::AtomRepr;

use crate::api;
use crate::atom::AtomHand;
use crate::ctx::{Ctx, JoinHandle};
use crate::dealias::{ChildError, ChildErrorKind, walk};
use crate::expr::{Expr, PathSetBuilder};
use crate::system::SystemCtor;
use crate::tree::MemberKind;

/// The pair of byte streams over which the host talks to one extension.
/// Consumed by `Extension::new`.
pub struct ExtPort {
  /// Stream the host writes to. During `Extension::new` the host header is
  /// encoded into it and flushed; afterwards it is handed to `io_comm`.
  pub input: Pin>,
  /// Stream the host reads from. During `Extension::new` the extension header
  /// is decoded from it; afterwards it is handed to `io_comm`.
  pub output: Pin>,
}

/// A request paired with the channel `Sender` on which its response is sent
/// back. The `SubLex` request handler uses it to pass a sub-lex request to the
/// pending `Extension::lex_req` call and to receive that call's answer.
pub struct ReqPair(R, Sender);

/// Data held about an Extension. This is refcounted within [Extension]. It's
/// important to only ever access parts of this struct through the [Rc] because
/// the components reference each other through [Weak]s of it, and will panic if
/// upgrading fails.
#[derive(destructure)]
pub struct ExtensionData {
  /// Name the extension reported in its `api::ExtensionHeader`.
  name: String,
  /// Host context this extension was created in.
  ctx: Ctx,
  /// Handle of the spawned task that drives the future returned by `io_comm`.
  /// It is `Some` from construction until `Drop` takes it.
  join_ext: Option>,
  /// Used to send notifications and requests to the extension.
  client: Rc,
  /// One constructor per system declared in the extension's header.
  systems: Vec,
  /// Most recently issued request ID; advanced by `Extension::next_pars`.
  next_pars: RefCell,
  /// Senders for lex requests currently in flight, keyed by lex ID. `SubLex`
  /// requests from the extension are routed through these to the matching
  /// `Extension::lex_req` call.
  lex_recur: Mutex>>>,
  /// Interned strings that passed through the interner request handlers for
  /// this extension. An entry is removed when the extension lists it in a
  /// `Sweeped` notification.
  /// NOTE(review): presumably held so they stay alive while the extension may
  /// still refer to them. Confirm.
  strings: RefCell>,
  /// Like `strings`, but for interned string vectors.
  string_vecs: RefCell>,
}
impl Drop for ExtensionData {
  /// Tells the extension to exit, then waits for its communication task to
  /// finish. `drop` cannot await, so this work is packaged as a future and
  /// passed to `stash` (presumably polled later by the host executor; confirm).
  fn drop(&mut self) {
    let client = self.client.clone();
    let join_ext = self.join_ext.take().expect("Only called once in Drop");
    stash(async move {
      // NOTE(review): this panics if the notification cannot be delivered,
      // for example if the extension has already exited on its own. Confirm
      // that this cannot happen.
      client.notify(api::HostExtNotif::Exit).await.unwrap();
      join_ext.join().await;
    })
  }
}
/// Shared handle to a connected extension. Cloning is cheap: all clones share
/// the same refcounted `ExtensionData`.
#[derive(Clone)]
pub struct Extension(Rc);
impl Extension {
  /// Performs the handshake with an extension over `init` and starts serving
  /// its messages.
  ///
  /// Steps:
  /// 1. Write an `api::HostHeader` to `init.input` and flush it.
  /// 2. Decode the extension's `api::ExtensionHeader` from `init.output`.
  /// 3. Hand both streams to `io_comm`, together with a notification handler
  ///    and a request handler.
  /// 4. Spawn the resulting communication future on `ctx`.
  ///
  /// # Panics
  /// Panics via `unwrap` on any I/O or decoding failure during the handshake.
  /// NOTE(review): despite the `io::Result` return type, this function never
  /// returns `Err` as written.
  pub async fn new(mut init: ExtPort, ctx: Ctx) -> io::Result {
    api::HostHeader { log_strategy: logger().strat(), msg_logs: ctx.msg_logs.strat() }
      .encode(init.input.as_mut())
      .await
      .unwrap();
    init.input.flush().await.unwrap();
    let header = api::ExtensionHeader::decode(init.output.as_mut()).await.unwrap();
    // `new_cyclic` gives the handlers and system constructors below a `Weak`
    // to the data before the data itself has been built.
    Ok(Self(Rc::new_cyclic(|weak: &Weak| {
      // context not needed because exit is extension-initiated
      let (client, _, future) = io_comm(
        Rc::new(Mutex::new(init.input)),
        Mutex::new(init.output),
        // Notification handler: invoked with a reader for each notification
        // the extension sends.
        clone!(weak; async move |reader| {
          with_stash(async {
            let this = Extension(weak.upgrade().unwrap());
            let notif = reader.read::().await.unwrap();
            // Log every notification except `Log` itself.
            if !matches!(notif, api::ExtHostNotif::Log(_)) {
              writeln!(logger(), "Host received notif {notif:?}");
            }
            match notif {
              api::ExtHostNotif::ExprNotif(api::ExprNotif::Acquire(acq)) => {
                // Look the expression up by ticket and hand it back to the
                // store (presumably recording the extension's new reference).
                // Confirm.
                let target = this.0.ctx.exprs.get_expr(acq.1).expect("Invalid ticket");
                this.0.ctx.exprs.give_expr(target)
              }
              api::ExtHostNotif::ExprNotif(api::ExprNotif::Release(rel)) => {
                // Releases are only honoured for systems owned by this
                // extension; any other system ID is just logged.
                if this.is_own_sys(rel.0).await {
                  this.0.ctx.exprs.take_expr(rel.1);
                } else {
                  writeln!(this.0.ctx.msg_logs, "Not our system {:?}", rel.0)
                }
              },
              api::ExtHostNotif::Log(api::Log(str)) => logger().log(str),
              api::ExtHostNotif::Sweeped(data) => {
                // The extension reports these tokens as swept; drop our
                // records of them.
                for i in join_all(data.strings.into_iter().map(es)).await {
                  this.0.strings.borrow_mut().remove(&i);
                }
                for i in join_all(data.vecs.into_iter().map(ev)).await {
                  this.0.string_vecs.borrow_mut().remove(&i);
                }
              },
            }
            Ok(())
          })
          .await
        }),
        // Request handler: invoked with a reader for each request the
        // extension sends; each arm replies through `handle`.
        {
          clone!(weak, ctx);
          async move |mut reader| {
            with_stash(async {
              let this = Self(weak.upgrade().unwrap());
              let req = reader.read_req::().await.unwrap();
              let handle = reader.finish().await;
              // Log every request except atom print requests.
              if !matches!(req, api::ExtHostReq::ExtAtomPrint(_)) {
                writeln!(logger(), "Host received request {req:?}");
              }
              match req {
                api::ExtHostReq::Ping(ping) => handle.reply(&ping, &()).await,
                // Interner requests. Every token that passes through is also
                // recorded in `strings` / `string_vecs`.
                api::ExtHostReq::IntReq(intreq) => match intreq {
                  api::IntReq::InternStr(s) => {
                    let i = is(&s.0).await;
                    this.0.strings.borrow_mut().insert(i.clone());
                    handle.reply(&s, &i.to_api()).await
                  },
                  api::IntReq::InternStrv(v) => {
                    let tokens = join_all(v.0.iter().map(|m| es(*m))).await;
                    this.0.strings.borrow_mut().extend(tokens.iter().cloned());
                    let i = iv(&tokens).await;
                    this.0.string_vecs.borrow_mut().insert(i.clone());
                    handle.reply(&v, &i.to_api()).await
                  },
                  api::IntReq::ExternStr(si) => {
                    let i = es(si.0).await;
                    this.0.strings.borrow_mut().insert(i.clone());
                    handle.reply(&si, &i.to_string()).await
                  },
                  api::IntReq::ExternStrv(vi) => {
                    let i = ev(vi.0).await;
                    this.0.strings.borrow_mut().extend(i.iter().cloned());
                    this.0.string_vecs.borrow_mut().insert(i.clone());
                    let markerv = i.iter().map(|t| t.to_api()).collect_vec();
                    handle.reply(&vi, &markerv).await
                  },
                },
                // Relay an atom request to the client of the system that owns
                // the atom (looked up via `atom.owner`).
                api::ExtHostReq::Fwd(ref fw @ api::Fwd(ref atom, ref key, ref body)) => {
                  let sys = ctx.system_inst(atom.owner).await.expect("owner of live atom dropped");
                  let client = sys.client();
                  let reply = client.request(api::Fwded(fw.0.clone(), *key, body.clone())).await.unwrap();
                  handle.reply(fw, &reply).await
                },
                // Relay a request to the system with ID `id`.
                api::ExtHostReq::SysFwd(ref fw @ api::SysFwd(id, ref body)) => {
                  let sys = ctx.system_inst(id).await.unwrap();
                  handle.reply(fw, &sys.request(body.clone()).await).await
                },
                // A nested lex request. Route it to the `lex_req` call that
                // registered this lex ID, then relay that call's answer.
                api::ExtHostReq::SubLex(sl) => {
                  let (rep_in, mut rep_out) = channel(0);
                  // Scoped so the `lex_recur` guard and the cloned sender are
                  // dropped before waiting for the response.
                  {
                    let lex_g = this.0.lex_recur.lock().await;
                    let mut req_in = lex_g.get(&sl.id).cloned().expect("Sublex for nonexistent lexid");
                    req_in.send(ReqPair(sl.clone(), rep_in)).await.unwrap();
                  }
                  handle.reply(&sl, &rep_out.next().await.unwrap()).await
                },
                api::ExtHostReq::ExprReq(expr_req) => match expr_req {
                  api::ExprReq::Inspect(ins @ api::Inspect { target }) => {
                    let expr = ctx.exprs.get_expr(target).expect("Invalid ticket");
                    handle
                      .reply(&ins, &api::Inspected {
                        // NOTE(review): `as u32` silently truncates counts
                        // above u32::MAX.
                        refcount: expr.strong_count() as u32,
                        location: expr.pos().to_api(),
                        kind: expr.to_api().await,
                      })
                      .await
                  },
                  api::ExprReq::Create(ref cre @ api::Create(ref expr)) => {
                    // Build the expression on the host, register it with the
                    // store, and reply with its ID.
                    let expr = Expr::from_api(expr, PathSetBuilder::new(), ctx.clone()).await;
                    let expr_id = expr.id();
                    ctx.exprs.give_expr(expr);
                    handle.reply(cre, &expr_id).await
                  },
                },
                // List a module's members. Walk failures become
                // `LsModuleError`s through the labelled block's early `break`.
                api::ExtHostReq::LsModule(ref ls @ api::LsModule(_sys, path)) => {
                  let reply: ::Response = 'reply: {
                    let path = ev(path).await;
                    let root = (ctx.root.read().await.upgrade())
                      .expect("LSModule called when root isn't in context");
                    let root_data = &*root.0.read().await;
                    let mut walk_ctx = (ctx.clone(), &root_data.consts);
                    // NOTE(review): the `false` argument presumably disables
                    // access checks, which is why `Private` is treated as
                    // unreachable below. Confirm against `walk`.
                    let module =
                      match walk(&root_data.root, false, path.iter().cloned(), &mut walk_ctx).await {
                        Ok(module) => module,
                        Err(ChildError { kind, .. }) => break 'reply Err(match kind {
                          ChildErrorKind::Private => panic!("Access checking was disabled"),
                          ChildErrorKind::Constant => api::LsModuleError::IsConstant,
                          ChildErrorKind::Missing => api::LsModuleError::InvalidPath,
                        }),
                      };
                    let mut members = std::collections::HashMap::new();
                    for (k, v) in &module.members {
                      let kind = match v.kind(ctx.clone(), &root_data.consts).await {
                        MemberKind::Const => api::MemberInfoKind::Constant,
                        MemberKind::Module(_) => api::MemberInfoKind::Module,
                      };
                      members.insert(k.to_api(), api::MemberInfo { public: v.public, kind });
                    }
                    Ok(api::ModuleInfo { members })
                  };
                  handle.reply(ls, &reply).await
                },
                api::ExtHostReq::ResolveNames(ref rn) => {
                  let api::ResolveNames { constid, names, sys } = rn;
                  // Scoped so the `systems` read guard is released before
                  // any names are resolved.
                  let mut resolver = {
                    let systems = ctx.systems.read().await;
                    let weak_sys = systems.get(sys).expect("ResolveNames for invalid sys");
                    let sys = weak_sys.upgrade().expect("ResolveNames after sys drop");
                    sys.name_resolver(*constid).await
                  };
                  // Names are resolved one at a time, in request order.
                  let responses = stream(async |mut cx| {
                    for name in names {
                      cx.emit(match resolver(&ev(*name).await[..]).await {
                        Ok(abs) => Ok(abs.to_sym().await.to_api()),
                        Err(e) => Err(e.to_api()),
                      })
                      .await
                    }
                  })
                  .collect()
                  .await;
                  handle.reply(rn, &responses).await
                },
                api::ExtHostReq::ExtAtomPrint(ref eap @ api::ExtAtomPrint(ref atom)) => {
                  let atom = AtomHand::from_api(atom, Pos::None, &mut ctx.clone()).await;
                  let unit = atom.print(&FmtCtxImpl::default()).await;
                  handle.reply(eap, &unit.to_api()).await
                },
              }
            })
            .await
          }
        },
      );
      // Drive the communication future on its own task. `Drop` joins this
      // handle.
      let join_ext = ctx.spawn(async {
        future.await.unwrap(); // extension exited successfully
      });
      ExtensionData {
        name: header.name.clone(),
        ctx: ctx.clone(),
        systems: (header.systems.iter().cloned())
          .map(|decl| SystemCtor { decl, ext: WeakExtension(weak.clone()) })
          .collect(),
        join_ext: Some(join_ext),
        next_pars: RefCell::new(NonZeroU64::new(1).unwrap()),
        lex_recur: Mutex::default(),
        client: Rc::new(client),
        strings: RefCell::default(),
        string_vecs: RefCell::default(),
      }
    })))
  }
  /// Name the extension reported in its header.
  pub fn name(&self) -> &String {
    &self.0.name
  }
  /// Client for sending notifications and requests to the extension.
  #[must_use]
  pub fn client(&self) -> &dyn Client {
    &*self.0.client
  }
  /// Host context this extension belongs to.
  #[must_use]
  pub fn ctx(&self) -> &Ctx {
    &self.0.ctx
  }
  /// Constructors for every system the extension declared in its header.
  pub fn system_ctors(&self) -> impl Iterator {
    self.0.systems.iter()
  }
  /// Returns whether the system with ID `id` exists and belongs to this
  /// extension. An unknown ID is logged and yields `false`.
  #[must_use]
  pub async fn is_own_sys(&self, id: api::SysId) -> bool {
    let Some(sys) = self.ctx().system_inst(id).await else {
      writeln!(logger(), "Invalid system ID {id:?}");
      return false;
    };
    Rc::ptr_eq(&self.0, &sys.ext().0)
  }
  /// Returns a fresh nonzero request ID.
  ///
  /// The counter is incremented before it is read, so the first ID handed out
  /// is 2. On overflow it wraps back to 1.
  #[must_use]
  pub fn next_pars(&self) -> NonZeroU64 {
    let mut next_pars = self.0.next_pars.borrow_mut();
    *next_pars = next_pars.checked_add(1).unwrap_or(NonZeroU64::new(1).unwrap());
    *next_pars
  }
  /// Asks the extension to lex `source` at position `pos` on behalf of system
  /// `sys`, where `src` is the source's name.
  ///
  /// While the request is pending, the extension may send `SubLex` requests
  /// carrying the same ID. Each one is answered by awaiting `r` with the
  /// sub-lex position. The extension's reply is converted with `transpose`
  /// into the returned `OrcResult`.
  ///
  /// # Panics
  /// Panics if the `LexExpr` request fails, or if a sub-lex response cannot be
  /// delivered.
  pub(crate) async fn lex_req>>(
    &self,
    source: IStr,
    src: Sym,
    pos: u32,
    sys: api::SysId,
    mut r: impl FnMut(u32) -> F,
  ) -> api::OrcResult> {
    // get unique lex ID
    let id = api::ParsId(self.next_pars());
    // create and register channel
    let (req_in, mut req_out) = channel(0);
    self.0.lex_recur.lock().await.insert(id, req_in); // lex_recur released
    // Both branches run concurrently:
    // - the first sends the lex request and waits for its reply;
    // - the second answers sub-lex requests until every sender feeding
    //   `req_out` has been dropped.
    let (ret, ()) = join(
      async {
        let res = (self.client())
          .request(api::LexExpr { id, pos, sys, src: src.to_api(), text: source.to_api() })
          .await
          .unwrap();
        // collect sender to unblock recursion handler branch before returning
        self.0.lex_recur.lock().await.remove(&id);
        res
      },
      async {
        while let Some(ReqPair(sublex, mut rep_in)) = req_out.next().await {
          (rep_in.send(r(sublex.pos).await).await)
            .expect("Response channel dropped while request pending")
        }
      },
    )
    .await;
    ret.transpose()
  }
  /// Tells the extension to drop system `id`, then removes the system from
  /// the host's system table. The work runs on a spawned task whose handle is
  /// discarded, so this function returns immediately.
  pub fn system_drop(&self, id: api::SysId) {
    let rc = self.clone();
    let _ = self.ctx().spawn(with_stash(async move {
      rc.client().request(api::SystemDrop(id)).await.unwrap();
      rc.ctx().systems.write().await.remove(&id);
    }));
  }
  /// Creates a [WeakExtension] that does not keep the extension alive.
  #[must_use]
  pub fn downgrade(&self) -> WeakExtension {
    WeakExtension(Rc::downgrade(&self.0))
  }
}
/// Non-owning handle to an [Extension], created by [Extension::downgrade].
#[derive(Clone)]
pub struct WeakExtension(Weak);
impl WeakExtension {
  /// Returns the [Extension] if it is still alive.
  #[must_use]
  pub fn upgrade(&self) -> Option {
    self.0.upgrade().map(Extension)
  }
}