use std::any::TypeId; use std::borrow::Cow; use std::cell::RefCell; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::rc::Rc; use futures::future::{LocalBoxFuture, join_all}; use futures::{AsyncWrite, FutureExt}; use itertools::Itertools; use never::Never; use orchid_api_traits::Encode; use orchid_base::{FmtCtx, FmtUnit, OrcRes, Pos, Sym, clone}; use task_local::task_local; use trait_set::trait_set; use crate::api; use crate::atom::Atomic; use crate::atom_owned::{DeserializeCtx, OwnedAtom, OwnedVariant}; use crate::conv::ToExpr; use crate::coroutine_exec::{ExecHandle, exec}; use crate::expr::Expr; use crate::gen_expr::{GExpr, new_atom}; use crate::system::sys_id; trait_set! { trait FunCB = Fn(Vec) -> LocalBoxFuture<'static, OrcRes> + 'static; } task_local! { static ARGV: Vec; } pub fn get_arg(idx: usize) -> Expr { ARGV .try_with(|argv| { (argv.get(idx).cloned()) .unwrap_or_else(|| panic!("Cannot read argument ##{idx}, only have {}", argv.len())) }) .expect("get_arg called outside ExprFunc") } pub fn get_argc() -> usize { ARGV.try_with(|argv| argv.len()).expect("get_arg called outside ExprFunc") } pub async fn get_arg_posv(idxes: impl IntoIterator) -> impl Iterator { let args = (ARGV.try_with(|argv| idxes.into_iter().map(|i| &argv[i]).cloned().collect_vec())) .expect("get_arg_posv called outside ExprFunc"); join_all(args.iter().map(|expr| expr.pos())).await.into_iter() } pub trait ExprFunc: Clone + 'static { fn argtyps() -> &'static [TypeId]; fn apply<'a>(&self, hand: ExecHandle<'a>, v: Vec) -> impl Future>; } task_local! { static FUNS_CTX: RefCell>; } pub(crate) fn with_funs_ctx<'a>(fut: LocalBoxFuture<'a, ()>) -> LocalBoxFuture<'a, ()> { Box::pin(FUNS_CTX.scope(RefCell::default(), fut)) } #[derive(Clone)] struct FunRecord { argtyps: &'static [TypeId], fun: Rc, } fn process_args>(f: F) -> FunRecord { let argtyps = F::argtyps(); let fun = Rc::new(move |v: Vec| { clone!(f, v mut); exec(async move |mut hand| { let mut norm_args = Vec::with_capacity(v.len()); for (expr, typ) in v.into_iter().zip(argtyps) { if *typ == TypeId::of::() { norm_args.push(expr); } else { norm_args.push(hand.exec(expr).await?); } } f.apply(hand, norm_args).await }) .map(Ok) .boxed_local() }); FunRecord { argtyps, fun } } /// An Atom representing a partially applied named native function. These /// partial calls are serialized into the name of the native function and the /// argument list. /// /// See [Lambda] for the non-serializable variant #[derive(Clone)] pub(crate) struct Fun { path: Sym, args: Vec, record: FunRecord, } impl Fun { pub async fn new>(path: Sym, f: F) -> Self { FUNS_CTX.with(|cx| { let mut fung = cx.borrow_mut(); let record = if let Some(record) = fung.get(&(sys_id(), path.clone())) { record.clone() } else { let record = process_args(f); fung.insert((sys_id(), path.clone()), record.clone()); record }; Self { args: vec![], path, record } }) } pub fn arity(&self) -> u8 { self.record.argtyps.len() as u8 } } impl Atomic for Fun { type Data = (); type Variant = OwnedVariant; } impl OwnedAtom for Fun { type Refs = Vec; async fn val(&self) -> Cow<'_, Self::Data> { Cow::Owned(()) } async fn call_ref(&self, arg: Expr) -> impl ToExpr { let new_args = self.args.iter().cloned().chain([arg]).collect_vec(); if new_args.len() == self.record.argtyps.len() { (self.record.fun)(new_args).await.to_gen().await } else { new_atom(Self { args: new_args, record: self.record.clone(), path: self.path.clone() }) } } async fn serialize(&self, write: Pin<&mut (impl AsyncWrite + ?Sized)>) -> Self::Refs { self.path.to_api().encode(write).await.unwrap(); self.args.clone() } async fn deserialize(mut ds_cx: impl DeserializeCtx, args: Self::Refs) -> Self { let path = Sym::from_api(ds_cx.decode().await).await; let record = (FUNS_CTX.with(|funs| funs.borrow().get(&(sys_id(), path.clone())).cloned())) .expect("Function missing during deserialization") .clone(); Self { args, path, record } } async fn print_atom<'a>(&'a self, _: &'a (impl FmtCtx + ?Sized + 'a)) -> FmtUnit { format!("{}:{}/{}", self.path, self.args.len(), self.arity()).into() } } /// An Atom representing a partially applied native lambda. These are not /// serializable. /// /// See [Fun] for the serializable variant #[derive(Clone)] pub struct Lambda { args: Vec, record: FunRecord, } impl Lambda { pub fn new>(f: F) -> Self { Self { args: vec![], record: process_args(f) } } } impl Atomic for Lambda { type Data = (); type Variant = OwnedVariant; } impl OwnedAtom for Lambda { type Refs = Never; async fn val(&self) -> Cow<'_, Self::Data> { Cow::Owned(()) } async fn call_ref(&self, arg: Expr) -> impl ToExpr { let new_args = self.args.iter().cloned().chain([arg]).collect_vec(); if new_args.len() == self.record.argtyps.len() { (self.record.fun)(new_args).await.to_gen().await } else { new_atom(Self { args: new_args, record: self.record.clone() }) } } } mod expr_func_derives { use std::any::TypeId; use std::sync::OnceLock; use orchid_base::OrcRes; use super::{ARGV, ExprFunc}; use crate::conv::{ToExpr, TryFromExpr}; use crate::func_atom::{ExecHandle, Expr}; use crate::gen_expr::GExpr; macro_rules! expr_func_derive { ($($t:ident),*) => { pastey::paste!{ impl< $($t: TryFromExpr + 'static, )* Out: ToExpr, Func: AsyncFn($($t,)*) -> Out + Clone + 'static > ExprFunc<($($t,)*), Out> for Func { fn argtyps() -> &'static [TypeId] { static STORE: OnceLock> = OnceLock::new(); &*STORE.get_or_init(|| vec![$(TypeId::of::<$t>()),*]) } async fn apply<'a>(&self, _: ExecHandle<'a>, v: Vec) -> OrcRes { assert_eq!(v.len(), Self::argtyps().len(), "Arity mismatch"); let argv = v.clone(); let [$([< $t:lower >],)*] = v.try_into().unwrap_or_else(|_| panic!("Checked above")); Ok(ARGV.scope(argv, self($($t::try_from_expr([< $t:lower >]).await?,)*)).await.to_gen().await) } } } }; } expr_func_derive!(A); expr_func_derive!(A, B); expr_func_derive!(A, B, C); expr_func_derive!(A, B, C, D); expr_func_derive!(A, B, C, D, E); expr_func_derive!(A, B, C, D, E, F); // expr_func_derive!(A, B, C, D, E, F, G); // expr_func_derive!(A, B, C, D, E, F, G, H); // expr_func_derive!(A, B, C, D, E, F, G, H, I); // expr_func_derive!(A, B, C, D, E, F, G, H, I, J); // expr_func_derive!(A, B, C, D, E, F, G, H, I, J, K); // expr_func_derive!(A, B, C, D, E, F, G, H, I, J, K, L); // expr_func_derive!(A, B, C, D, E, F, G, H, I, J, K, L, M); // expr_func_derive!(A, B, C, D, E, F, G, H, I, J, K, L, M, N); }