use std::any::Any; use std::borrow::Cow; use std::cell::RefCell; use std::error::Error; use thiserror::Error; use partiql_catalog::call_defs::{CallDef, CallSpec, CallSpecArg}; use partiql_catalog::catalog::{Catalog, PartiqlCatalog}; use partiql_catalog::context::{SessionContext, SystemContext}; use partiql_catalog::extension::{Extension, ExtensionResultError}; use partiql_catalog::table_fn::{ BaseTableExpr, BaseTableExprResult, BaseTableFunctionInfo, TableFunction, }; use partiql_eval::env::basic::MapBindings; use partiql_eval::eval::BasicContext; use partiql_eval::plan::EvaluationMode; use partiql_parser::{Parsed, ParserResult}; use partiql_value::{bag, tuple, DateTime, Value}; use partiql_logical as logical; #[derive(Debug)] pub struct UserCtxTestExtension {} impl partiql_catalog::extension::Extension for UserCtxTestExtension { fn name(&self) -> String { "test_extension".into() } fn load(&self, catalog: &mut dyn Catalog) -> Result<(), Box> { match catalog .add_table_function(TableFunction::new(Box::new(TestUserContextFunction::new()))) { Ok(_) => Ok(()), Err(e) => Err(Box::new(e) as Box), } } } #[derive(Debug)] pub(crate) struct TestUserContextFunction { call_def: CallDef, } impl TestUserContextFunction { pub fn new() -> Self { TestUserContextFunction { call_def: CallDef { names: vec!["test_user_context"], overloads: vec![CallSpec { input: vec![CallSpecArg::Positional], output: Box::new(|args| { logical::ValueExpr::Call(logical::CallExpr { name: logical::CallName::ByName("test_user_context".to_string()), arguments: args, }) }), }], }, } } } impl BaseTableFunctionInfo for TestUserContextFunction { fn call_def(&self) -> &CallDef { &self.call_def } fn plan_eval(&self) -> Box { Box::new(EvalTestCtxTable {}) } } #[derive(Error, Debug)] #[non_exhaustive] pub enum UserCtxError { #[error("unknown error")] Unknown, } #[derive(Debug)] pub(crate) struct EvalTestCtxTable {} impl BaseTableExpr for EvalTestCtxTable { fn evaluate<'c>( &self, args: &[Cow<'_, Value>], ctx: &'c dyn SessionContext<'c>, ) -> BaseTableExprResult<'c> { if let Some(arg1) = args.first() { match arg1.as_ref() { Value::String(name) => generated_data(name.to_string(), ctx), _ => { let error = UserCtxError::Unknown; Err(Box::new(error) as ExtensionResultError) } } } else { let error = UserCtxError::Unknown; Err(Box::new(error) as ExtensionResultError) } } } struct TestDataGen<'a> { ctx: &'a dyn SessionContext<'a>, name: String, } impl<'a> Iterator for TestDataGen<'a> { type Item = Result; fn next(&mut self) -> Option { if let Some(cv) = self.ctx.user_context(&self.name) { if let Some(counter) = cv.downcast_ref::() { let mut n = counter.data.borrow_mut(); if *n > 0 { *n -= 1; let idx: u8 = (5 - *n) as u8; let id = format!("id_{idx}"); let m = idx % 2; return Some(Ok(tuple![("foo", m), ("bar", id)].into())); } } } None } } fn generated_data<'a>(name: String, ctx: &'a dyn SessionContext<'a>) -> BaseTableExprResult<'a> { Ok(Box::new(TestDataGen { ctx, name })) } #[derive(Debug)] pub struct Counter { data: RefCell, } #[track_caller] #[inline] pub(crate) fn parse(statement: &str) -> ParserResult { partiql_parser::Parser::default().parse(statement) } #[track_caller] #[inline] pub(crate) fn lower( catalog: &dyn Catalog, parsed: &Parsed<'_>, ) -> partiql_logical::LogicalPlan { let planner = partiql_logical_planner::LogicalPlanner::new(catalog); planner.lower(parsed).expect("lower") } #[track_caller] #[inline] pub(crate) fn evaluate( catalog: &dyn Catalog, logical: partiql_logical::LogicalPlan, bindings: MapBindings, ctx_vals: &[(String, &(dyn Any))], ) -> Value { let mut planner = partiql_eval::plan::EvaluatorPlanner::new(EvaluationMode::Permissive, catalog); let mut plan = planner.compile(&logical).expect("Expect no plan error"); let sys = SystemContext { now: DateTime::from_system_now_utc(), }; let mut ctx = BasicContext::new(bindings, sys); for (k, v) in ctx_vals { ctx.user.insert(k.as_str().into(), *v); } if let Ok(out) = plan.execute_mut(&ctx) { out.result } else { Value::Missing } } #[test] fn test_context() { let expected: Value = bag![ tuple![("foo", 1), ("bar", "id_1")], tuple![("foo", 0), ("bar", "id_2")], tuple![("foo", 1), ("bar", "id_3")], tuple![("foo", 0), ("bar", "id_4")], tuple![("foo", 1), ("bar", "id_5")], ] .into(); let query = "SELECT foo, bar from test_user_context('counter') as data"; let mut catalog = PartiqlCatalog::default(); let ext = UserCtxTestExtension {}; ext.load(&mut catalog).expect("extension load to succeed"); let parsed = parse(query); let lowered = lower(&catalog, &parsed.expect("parse")); let bindings = Default::default(); let counter = Counter { data: RefCell::new(5), }; let ctx: [(String, &dyn Any); 1] = [("counter".to_string(), &counter)]; let out = evaluate(&catalog, lowered, bindings, &ctx); assert!(out.is_bag()); assert_eq!(&out, &expected); assert_eq!(*counter.data.borrow(), 0); }