use std::any::TypeId; use std::borrow::Cow; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::rc::Rc; use futures::future::LocalBoxFuture; use futures::lock::Mutex; use futures::{AsyncWrite, FutureExt}; use itertools::Itertools; use never::Never; use orchid_api_traits::Encode; use orchid_base::clone; use orchid_base::error::OrcRes; use orchid_base::format::{FmtCtx, FmtUnit}; use orchid_base::name::Sym; use trait_set::trait_set; use crate::atom::Atomic; use crate::atom_owned::{DeserializeCtx, OwnedAtom, OwnedVariant}; use crate::context::{SysCtxEntry, ctx, i}; use crate::conv::ToExpr; use crate::coroutine_exec::{ExecHandle, exec}; use crate::expr::Expr; use crate::gen_expr::GExpr; trait_set! { trait FunCB = Fn(Vec) -> LocalBoxFuture<'static, OrcRes> + 'static; } pub trait ExprFunc: Clone + 'static { fn argtyps() -> &'static [TypeId]; fn apply<'a>(&self, hand: ExecHandle<'a>, v: Vec) -> impl Future>; } #[derive(Default)] struct FunsCtx(Mutex>); impl SysCtxEntry for FunsCtx {} #[derive(Clone)] struct FunRecord { argtyps: &'static [TypeId], fun: Rc, } fn process_args>(f: F) -> FunRecord { let argtyps = F::argtyps(); let fun = Rc::new(move |v: Vec| { clone!(f, v mut); exec(async move |mut hand| { let mut norm_args = Vec::with_capacity(v.len()); for (expr, typ) in v.into_iter().zip(argtyps) { if *typ != TypeId::of::() { norm_args.push(hand.exec(expr).await?); } } f.apply(hand, norm_args).await }) .map(Ok) .boxed_local() }); FunRecord { argtyps, fun } } /// An Atom representing a partially applied named native function. These /// partial calls are serialized into the name of the native function and the /// argument list. 
///
/// See [Lambda] for the non-serializable variant
#[derive(Clone)]
pub(crate) struct Fun {
    /// Name under which the function is registered in `FunsCtx`.
    path: Sym,
    /// Arguments collected so far; always fewer than the arity.
    args: Vec<Expr>,
    record: FunRecord,
}
impl Fun {
    /// Create a `Fun` with no arguments applied.
    ///
    /// The first call for a given `path` registers `f` in the system's
    /// `FunsCtx`; later calls with the same `path` reuse the registered
    /// record and ignore `f`.
    pub async fn new<I, O, F: ExprFunc<I, O>>(path: Sym, f: F) -> Self {
        let ctx = ctx();
        let funs: &FunsCtx = ctx.get_or_default();
        let record = {
            let mut registry = funs.0.lock().await;
            registry.entry(path.clone()).or_insert_with(|| process_args(f)).clone()
        };
        Self { args: vec![], path, record }
    }
    /// Number of parameters the underlying function declares.
    ///
    /// The cast cannot truncate for functions produced by the derive macro in
    /// this file, which only covers arities up to 5.
    pub fn arity(&self) -> u8 { self.record.argtyps.len() as u8 }
}
impl Atomic for Fun {
    type Data = ();
    type Variant = OwnedVariant;
}
impl OwnedAtom for Fun {
    type Refs = Vec<Expr>;
    async fn val(&self) -> Cow<'_, Self::Data> { Cow::Owned(()) }
    /// Append `arg` to the collected arguments. Once the argument count
    /// reaches the arity the function is invoked; otherwise a new, longer
    /// partial application is returned.
    async fn call_ref(&self, arg: Expr) -> GExpr {
        let new_args = self.args.iter().cloned().chain([arg]).collect_vec();
        if new_args.len() == self.record.argtyps.len() {
            (self.record.fun)(new_args).await.to_gen().await
        } else {
            Self { args: new_args, record: self.record.clone(), path: self.path.clone() }.to_gen().await
        }
    }
    async fn call(self, arg: Expr) -> GExpr { self.call_ref(arg).await }
    /// Write the function path to `write` and return the collected arguments
    /// as the atom's references.
    async fn serialize(&self, write: Pin<&mut (impl AsyncWrite + ?Sized)>) -> Self::Refs {
        self.path.to_api().encode(write).await;
        self.args.clone()
    }
    /// Rebuild a `Fun` from a decoded path and its argument references.
    ///
    /// # Panics
    ///
    /// Panics if no function is registered under the decoded path.
    async fn deserialize(mut ds_cx: impl DeserializeCtx, args: Self::Refs) -> Self {
        let path = Sym::from_api(ds_cx.decode().await, &i()).await;
        let record = (ctx().get::<FunsCtx>().0.lock().await.get(&path))
            .expect("Function missing during deserialization")
            .clone();
        Self { args, path, record }
    }
    /// Render as `path:applied/arity`.
    async fn print_atom<'a>(&'a self, _: &'a (impl FmtCtx + ?Sized + 'a)) -> FmtUnit {
        format!("{}:{}/{}", self.path, self.args.len(), self.arity()).into()
    }
}

/// An Atom representing a partially applied native lambda. These are not
/// serializable.
/// /// See [Fun] for the serializable variant #[derive(Clone)] pub struct Lambda { args: Vec, record: FunRecord, } impl Lambda { pub fn new>(f: F) -> Self { Self { args: vec![], record: process_args(f) } } } impl Atomic for Lambda { type Data = (); type Variant = OwnedVariant; } impl OwnedAtom for Lambda { type Refs = Never; async fn val(&self) -> Cow<'_, Self::Data> { Cow::Owned(()) } async fn call_ref(&self, arg: Expr) -> GExpr { let new_args = self.args.iter().cloned().chain([arg]).collect_vec(); if new_args.len() == self.record.argtyps.len() { (self.record.fun)(new_args).await.to_gen().await } else { Self { args: new_args, record: self.record.clone() }.to_gen().await } } async fn call(self, arg: Expr) -> GExpr { self.call_ref(arg).await } } mod expr_func_derives { use std::any::TypeId; use std::sync::OnceLock; use orchid_base::error::OrcRes; use super::ExprFunc; use crate::conv::{ToExpr, TryFromExpr}; use crate::func_atom::{ExecHandle, Expr}; use crate::gen_expr::GExpr; macro_rules! 
expr_func_derive { ($($t:ident),*) => { pastey::paste!{ impl< $($t: TryFromExpr + 'static, )* Out: ToExpr, Func: AsyncFn($($t,)*) -> Out + Clone + 'static > ExprFunc<($($t,)*), Out> for Func { fn argtyps() -> &'static [TypeId] { static STORE: OnceLock> = OnceLock::new(); &*STORE.get_or_init(|| vec![$(TypeId::of::<$t>()),*]) } async fn apply<'a>(&self, _: ExecHandle<'a>, v: Vec) -> OrcRes { assert_eq!(v.len(), Self::argtyps().len(), "Arity mismatch"); let [$([< $t:lower >],)*] = v.try_into().unwrap_or_else(|_| panic!("Checked above")); Ok(self($($t::try_from_expr([< $t:lower >]).await?,)*).await.to_gen().await) } } } }; } expr_func_derive!(A); expr_func_derive!(A, B); expr_func_derive!(A, B, C); expr_func_derive!(A, B, C, D); expr_func_derive!(A, B, C, D, E); // expr_func_derive!(A, B, C, D, E, F); // expr_func_derive!(A, B, C, D, E, F, G); // expr_func_derive!(A, B, C, D, E, F, G, H); // expr_func_derive!(A, B, C, D, E, F, G, H, I); // expr_func_derive!(A, B, C, D, E, F, G, H, I, J); // expr_func_derive!(A, B, C, D, E, F, G, H, I, J, K); // expr_func_derive!(A, B, C, D, E, F, G, H, I, J, K, L); // expr_func_derive!(A, B, C, D, E, F, G, H, I, J, K, L, M); // expr_func_derive!(A, B, C, D, E, F, G, H, I, J, K, L, M, N); }