use bigdecimal::BigDecimal; use sqlite3_ext::{function::*, *}; use std::{cmp::Ordering, str::FromStr}; // NULL maps to None // Valid BigDecimal maps to Some(x) // Otherwise Some(0) fn process_value(a: &mut ValueRef) -> Result> { match a.value_type() { ValueType::Null => Ok(None), _ => Ok(Some( BigDecimal::from_str(a.get_str()?).unwrap_or_else(|_| BigDecimal::default()), )), } } fn process_args(args: &mut [&mut ValueRef]) -> Result>> { args.into_iter().map(|x| process_value(*x)).collect() } macro_rules! scalar_method { ($name:ident as ( $a:ident, $b:ident ) -> $ty:ty => $ret:expr) => { #[sqlite3_ext_fn(n_args=2, risk_level=Innocuous, deterministic)] fn $name(ctx: &Context, args: &mut [&mut ValueRef]) -> Result<()> { let mut args = process_args(args)?.into_iter(); let a = args.next().unwrap_or(None); let b = args.next().unwrap_or(None); if let (Some($a), Some($b)) = (a, b) { ctx.set_result($ret).unwrap(); } Ok(()) } }; } scalar_method!(decimal_add as (a, b) -> String => format!("{}", (a + b).normalized())); scalar_method!(decimal_sub as (a, b) -> String => format!("{}", (a - b).normalized())); scalar_method!(decimal_mul as (a, b) -> String => format!("{}", (a * b).normalized())); scalar_method!(decimal_cmp as (a, b) -> i32 => { match a.cmp(&b) { Ordering::Less => -1, Ordering::Equal => 0, Ordering::Greater => 1, } }); #[derive(Default)] #[sqlite3_ext_fn(n_args=1, risk_level=Innocuous, deterministic)] struct Sum { cur: BigDecimal, } impl AggregateFunction<()> for Sum { fn default_value(_: &(), ctx: &Context) -> Result<()> { ctx.set_result(()) } fn step(&mut self, _: &Context, args: &mut [&mut ValueRef]) -> Result<()> { if let Some(x) = process_value(*args.first_mut().unwrap())? { self.cur += x; } Ok(()) } fn value(&self, ctx: &Context) -> Result<()> { ctx.set_result(format!("{}", self.cur.normalized())) } fn inverse(&mut self, _: &Context, args: &mut [&mut ValueRef]) -> Result<()> { if let Some(x) = process_value(*args.first_mut().unwrap())? { self.cur -= x; } Ok(()) } } fn decimal_collation(a: &str, b: &str) -> Ordering { let a = BigDecimal::from_str(a).unwrap_or_else(|_| BigDecimal::default()); let b = BigDecimal::from_str(b).unwrap_or_else(|_| BigDecimal::default()); a.cmp(&b) } #[sqlite3_ext_main] fn init(db: &Connection) -> Result<()> { db.create_scalar_function("decimal_add", &DECIMAL_ADD_OPTS, decimal_add)?; db.create_scalar_function("decimal_sub", &DECIMAL_SUB_OPTS, decimal_sub)?; db.create_scalar_function("decimal_mul", &DECIMAL_MUL_OPTS, decimal_mul)?; db.create_scalar_function("decimal_cmp", &DECIMAL_CMP_OPTS, decimal_cmp)?; db.create_aggregate_function::<_, Sum>("decimal_sum", &SUM_OPTS, ())?; db.create_collation("decimal", decimal_collation)?; Ok(()) } #[cfg(all(test, feature = "static"))] mod test { use super::*; fn setup() -> Result { let conn = Database::open(":memory:")?; init(&conn)?; Ok(conn) } fn case(data: Vec<(&str, Value)>) -> Result<()> { let conn = setup()?; let (sql, expected): (Vec<&str>, Vec) = data.into_iter().unzip(); let sql = format!("SELECT {}", sql.join(", ")); println!("{}", sql); let ret: Vec = conn.query_row(&sql, (), |r| { (0..expected.len()) .map(|i| r[i].to_owned()) .collect::>() })?; assert_eq!(ret, expected); Ok(()) } #[test] fn decimal_add() -> Result<()> { case(vec![ ( "decimal_add('1000000000000000', '0.0000000000000001')", Value::Text("1000000000000000.0000000000000001".to_owned()), ), ("decimal_add(NULL, '0')", Value::Null), ("decimal_add('0', NULL)", Value::Null), ("decimal_add(NULL, NULL)", Value::Null), ("decimal_add('invalid', 2)", Value::Text("2".to_owned())), ]) } #[test] fn decimal_sub() -> Result<()> { case(vec![ ( "decimal_sub('1000000000000000', '0.0000000000000001')", Value::Text("999999999999999.9999999999999999".to_owned()), ), ("decimal_sub(NULL, '0')", Value::Null), ("decimal_sub('0', NULL)", Value::Null), ("decimal_sub(NULL, NULL)", Value::Null), ("decimal_sub('invalid', 2)", Value::Text("-2".to_owned())), ]) } #[test] fn decimal_mul() -> Result<()> { case(vec![ ( "decimal_mul('1000000000000000', '0.0000000000000001')", Value::Text("0.1".to_owned()), ), ("decimal_mul(NULL, '0')", Value::Null), ("decimal_mul('0', NULL)", Value::Null), ("decimal_mul(NULL, NULL)", Value::Null), ("decimal_mul('invalid', 2)", Value::Text("0".to_owned())), ]) } #[test] fn decimal_cmp() -> Result<()> { case(vec![ ("decimal_cmp('1', '-1')", Value::Integer(1)), ("decimal_cmp('-1', '1')", Value::Integer(-1)), ("decimal_cmp('1', '1')", Value::Integer(0)), ("decimal_cmp(NULL, '0')", Value::Null), ("decimal_cmp('0', NULL)", Value::Null), ("decimal_cmp(NULL, NULL)", Value::Null), ]) } fn aggregate_case(expr: &str, data: Vec<&str>, expected: Vec) -> Result<()> { let conn = setup()?; let sql = format!( "SELECT {} FROM ( VALUES {} )", expr, data.iter() .map(|s| format!("({})", s)) .collect::>() .join(", ") ); println!("{}", sql); let ret: Vec = conn .prepare(&sql)? .query(())? .map(|r| r[0].to_owned()) .collect()?; assert_eq!(ret, expected); Ok(()) } #[test] fn decimal_sum() -> Result<()> { aggregate_case( "decimal_sum(column1)", vec!["1000000000000000", "0.0000000000000001", "1"], vec![Value::Text("1000000000000001.0000000000000001".to_owned())], )?; aggregate_case( "decimal_sum(column1)", vec!["1", "NULL"], vec![Value::Text("1".to_owned())], )?; aggregate_case( "decimal_sum(column1)", vec!["NULL"], vec![Value::Text("0".to_owned())], )?; case(vec![("decimal_sum(NULL)", Value::Text("0".to_owned()))])?; case(vec![( "decimal_sum('invalid')", Value::Text("0".to_owned()), )])?; case(vec![("decimal_sum(1) WHERE 1 = 0", Value::Null)])?; aggregate_case( "decimal_sum(column1) OVER ( ROWS 1 PRECEDING )", vec![ "1000000000000000", "0.0000000000000001", "NULL", "NULL", "1", ], vec![ Value::Text("1000000000000000".to_owned()), Value::Text("1000000000000000.0000000000000001".to_owned()), Value::Text("0.0000000000000001".to_owned()), Value::Text("0".to_owned()), Value::Text("1".to_owned()), ], )?; Ok(()) } #[test] fn collation() -> Result<()> { let conn = setup()?; let ret: Vec = conn .prepare( "SELECT column1 FROM ( VALUES (('1')), (('0100')), (('.1')) ) ORDER BY column1 COLLATE decimal", )? .query(())?.map(|row| Ok(row[0].get_str()?.to_owned())) .collect()?; assert_eq!( ret, vec![".1".to_owned(), "1".to_owned(), "0100".to_owned()] ); Ok(()) } }