#[cfg(feature = "value")] use exmex::{exerr, ExResult, Express, FlatExVal, Val}; #[cfg(feature = "value")] mod utils; #[test] #[cfg(feature = "value")] fn test_vars() -> ExResult<()> { let expr = exmex::parse_val::("x+5.3")?; utils::assert_float_eq_f64(expr.eval(&[Val::Float(3.4)])?.to_float()?, 8.7); let expr = exmex::parse_val::("-(x1 if x0 else x2)+5.3")?; let res = expr .eval(&[Val::Bool(true), Val::Float(3.4), Val::Int(3)])? .to_float()?; utils::assert_float_eq_f64(res, 1.9); let expr = exmex::parse_val::("-sin(x)+5.3")?; utils::assert_float_eq::( expr.eval(&[Val::Float(2.2)])?.to_float()?, -2.2f32.sin() + 5.3, 1e-6, 0.0, "", ); let expr = exmex::parse_val::("-sin(x) if y > 0 else z + 3")?; utils::assert_float_eq::( expr.eval(&[Val::Float(1.0), Val::Int(2), Val::Int(3)])? .to_float()?, -1f32.sin(), 1e-6, 0.0, "", ); assert_eq!( expr.eval(&[Val::Float(1.0), Val::Int(-1), Val::Int(3)])? .to_int()?, 6, ); let expr = exmex::parse_val::("z if false else 2")?; println!("{:#?}", expr); assert_eq!(expr.eval(&[Val::Int(-3)])?.to_int()?, 2,); Ok(()) } #[test] #[cfg(feature = "value")] #[cfg(feature = "partial")] fn test_value_partial() -> ExResult<()> { use exmex::Differentiate; use crate::utils::assert_float_eq_f64; let expr = exmex::parse_val::("x if x > 0 else (2*x if x >= -1 else -x)")?; assert_float_eq_f64(1.0, expr.eval(&[Val::Float(1.0)])?.to_float()?); assert_float_eq_f64(0.0, expr.eval(&[Val::Float(0.0)])?.to_float()?); assert_float_eq_f64(2.0, expr.eval(&[Val::Float(-2.0)])?.to_float()?); assert_float_eq_f64(-2.0, expr.eval(&[Val::Float(-1.0)])?.to_float()?); println!("{expr}"); let deri = expr.partial(0).unwrap(); for x in [-2.0, -1.5] { let res = deri.eval(&[Val::Float(x)])?.to_float()?; assert_float_eq_f64(res, -1.0); } for x in [1.0, 0.5, 3.0] { let res = deri.eval(&[Val::Float(x)])?.to_float()?; assert_float_eq_f64(res, 1.0); } for x in [-1.0, -0.5, 0.0] { let res = deri.eval(&[Val::Float(x)])?.to_float()?; assert_float_eq_f64(res, 2.0); } let sin = exmex::parse_val::("sin(x)")?; let cos = sin.partial(0).unwrap(); let res = cos.eval(&[Val::Float(34.0)])?.to_float()?; assert_float_eq_f64(res, 34.0f64.cos()); let sin = exmex::parse_val::("sin(x) if x < 0 else sin(x)")?; let cos = sin.partial(0).unwrap(); for x in [-1.0, -0.5, 0.0, 0.5, 1.0] { let res = cos.eval(&[Val::Float(x)])?.to_float()?; assert_float_eq_f64(res, x.cos()); } Ok(()) } #[test] #[cfg(feature = "value")] fn test_readme() -> ExResult<()> { let expr = exmex::parse_val::("0 if b < c else 1.2")?; let res = expr.eval(&[Val::Float(34.0), Val::Int(21)])?.to_float()?; assert!((res - 1.2).abs() < 1e-12); #[cfg(feature = "partial")] { use exmex::Differentiate; let expr = exmex::parse_val::("3*x if x > 1 else x^2")?; let deri = expr.partial(0)?; let res = deri.eval(&[Val::Float(1.0)])?.to_float()?; assert!((res - 2.0).abs() < 1e-12); let res = deri.eval(&[Val::Float(7.0)])?.to_float()?; assert!((res - 3.0).abs() < 1e-12); } Ok(()) } #[test] #[cfg(feature = "serde")] #[cfg(feature = "value")] fn test_serde_public() { use exmex::FlatExVal; let s = "{x}^3.0 if z < 0 else y"; // flatex let flatex = FlatExVal::::parse(s).unwrap(); let serialized = serde_json::to_string(&flatex).unwrap(); let deserialized = serde_json::from_str::>(serialized.as_str()).unwrap(); assert_eq!(deserialized.var_names().len(), 3); let res = deserialized .eval(&[Val::Float(2.0), Val::Bool(false), Val::Float(1.0)]) .unwrap(); assert_eq!(res.to_bool().unwrap(), false); let res = deserialized .eval(&[Val::Float(2.0), Val::Float(1.0), Val::Int(-1)]) .unwrap(); utils::assert_float_eq_f64(res.to_float().unwrap(), 8.0); assert_eq!(s, format!("{}", deserialized)); let s = "min({x}^3.0 if z < 0 else y, 1.0)"; // flatex let flatex = FlatExVal::::parse(s).unwrap(); let serialized = serde_json::to_string(&flatex).unwrap(); let deserialized = serde_json::from_str::>(serialized.as_str()).unwrap(); assert_eq!(deserialized.var_names().len(), 3); let res = deserialized .eval(&[Val::Float(2.0), Val::Float(1.0), Val::Float(-1.0)]) .unwrap(); assert_eq!(res.to_float().unwrap(), 1.0); assert_eq!(s, format!("{}", deserialized)); } #[cfg(feature = "value")] #[test] fn test_to() -> ExResult<()> { utils::assert_float_eq_f64( Val::::Float(std::f64::consts::TAU).to_float()?, std::f64::consts::TAU, ); assert_eq!(Val::::Int(123).to_int()?, 123); assert!(Val::::Bool(true).to_bool()?); assert_eq!(Val::::Bool(false).to_int()?, 0); assert_eq!(Val::::Bool(true).to_float()?, 1.0); utils::assert_float_eq_f64(Val::::Float(3.4).to_float()?, 3.4); assert_eq!(Val::::Int(34).to_int()?, 34); assert!(!Val::::Bool(false).to_bool()?); Ok(()) } #[cfg(feature = "value")] #[cfg(test)] use exmex::{DeepEx, ValMatcher, ValOpsFactory}; #[cfg(feature = "value")] #[cfg(test)] type Fx = FlatExVal; #[cfg(feature = "value")] #[cfg(test)] type Dx<'a> = DeepEx<'a, Val, ValOpsFactory, ValMatcher>; #[cfg(feature = "value")] #[test] fn test_no_vars() -> ExResult<()> { fn test_int(s: &str, reference: i32) -> ExResult<()> { fn test_<'a, EX>(s: &'a str, reference: i32) -> ExResult<()> where EX: Express<'a, Val>, { println!("=== testing\n{}", s); let res = exmex::parse_val::(s)?.eval(&[])?.to_int(); match res { Ok(i) => { assert_eq!(reference, i); } Err(e) => { println!("{:?}", e); unreachable!(); } } Ok(()) } test_::(s, reference)?; test_::(s, reference) } fn test_float(s: &str, reference: f64) -> ExResult<()> { fn test_<'a, EX>(s: &'a str, reference: f64) -> ExResult<()> where EX: Express<'a, Val>, { println!("=== testing\n{}", s); let expr = FlatExVal::::parse(s)?; utils::assert_float_eq_f64(reference, expr.eval(&[])?.to_float()?); Ok(()) } test_::(s, reference)?; test_::(s, reference) } fn test_bool(s: &str, reference: bool) -> ExResult<()> { println!("=== testing\n{}", s); fn test_<'a, EX>(s: &'a str, reference: bool) -> ExResult<()> where EX: Express<'a, Val>, { let expr = EX::parse(s)?; assert_eq!(reference, expr.eval(&[])?.to_bool()?); Ok(()) } test_::(s, reference)?; test_::(s, reference) } fn test_error(s: &str) -> ExResult<()> { fn test_<'a, EX>(s: &'a str) -> ExResult<()> where EX: Express<'a, Val>, { let expr = EX::parse(s); match expr { Ok(exp) => { let v = exp.eval(&[])?; match v { Val::Error(e) => { println!("found expected error {:?}", e); Ok(()) } _ => Err(exerr!("'{}' should fail but didn't", s)), } } Err(e) => { println!("found expected error {:?}", e); Ok(()) } } } test_::(&s)?; test_::(&s) } fn test_none(s: &str) -> ExResult<()> { let expr = FlatExVal::::parse(s)?; match expr.eval(&[])? { Val::None => Ok(()), _ => Err(exerr!("'{}' should return none but didn't", s)), } } test_error("if true else 2")?; test_int("1+2 if 1 > 0 else 2+4", 3)?; test_int("1+2 if 1 < 0 else 2+4", 6)?; test_error("929<<92")?; test_error("929<<32")?; test_error("929>>32")?; test_int("928<<31", 0)?; test_int("929>>31", 0)?; test_float("2.0^2", 4.0)?; test_int("2^4", 16)?; test_error("2^-4")?; test_int("2+4", 6)?; test_int("9+4", 13)?; test_int("9+4^2", 25)?; test_float("τ/TAU", 1.0)?; test_int("9/4", 2)?; test_int("9%4", 1)?; test_float("2.5+4.0^2", 18.5)?; test_float("2.5*4.0^2", 2.5 * 4.0 * 4.0)?; test_float("2.5-4.0^-2", 2.5 - 4.0f64.powi(-2))?; test_float("9.0/4.0", 9.0 / 4.0)?; test_float("sin(9.0)", 9.0f64.sin())?; test_float("cos(91.0)", 91.0f64.cos())?; test_float("ln(91.0)", 91.0f64.ln())?; test_float("log(91.0)", 91.0f64.ln())?; test_float("tan(913.0)", 913.0f64.tan())?; test_float("min(tan(913.0), 1)", 913.0f64.tan())?; test_float("max(tan(913.0), 1.0)", 1.0)?; test_float("max(tan(913.0), 1)", 1.0)?; test_float("sin(-π)", 0.0)?; test_float("sin(π)", 0.0)?; test_float("τ", std::f64::consts::PI * 2.0)?; test_float("sin(-τ)", 0.0)?; test_float("round(π)", 3.0)?; test_float("cos(π)", -1.0)?; test_float("cos(TAU)", 1.0)?; test_float("sin (1 if false else 2.0)", 2.0f64.sin())?; test_float("cbrt(27.0)", 3.0)?; test_int("1 if true else 2.0", 1)?; test_float("(9.0 if true else 2.0)", 9.0)?; test_int("1<<4-2", 4)?; test_int("4>>2", 1)?; test_int("signum(4>>1)", 1)?; test_float("signum(-123.12)", -1.0)?; test_float("abs(-123.12)", 123.12)?; test_int("fact(4)", 2 * 3 * 4)?; test_int("fact(0)", 1)?; test_error("fact(-1)")?; test_bool("1>2", false)?; test_bool("1<2", true)?; test_bool("1.4>=1.4", true)?; test_bool("true==true", true)?; test_bool("false==true", false)?; test_bool("1.5 != 1.5 + 2.0", true)?; test_float("1 + 1.0", 2.0)?; test_bool("1.0 == 1", true)?; test_bool("1 == 1", true)?; test_bool("2 == true", false)?; test_bool("1.5 < 1", false)?; test_bool("true == true", true)?; test_bool("false != true", true)?; test_bool("false != false", false)?; test_bool("1 > 0.5", true)?; test_error("to_float(10000000000000)")?; test_bool("true == 1", false)?; test_bool("true else 2", true)?; test_int("1 else 2", 1)?; test_none("2 if false")?; test_int("to_int(1)", 1)?; test_int("to_int(3.5)", 3)?; test_float("to_float(2)", 2.0)?; test_float("to_float(3.5)", 3.5)?; test_float("to_float(true)", 1.0)?; test_float("to_float(false)", 0.0)?; test_int("to_int(true)", 1)?; test_int("to_int(false)", 0)?; test_error("to_int(fact(-1))")?; test_error("to_float(5 if false)")?; test_error("0/0")?; test_bool("(5 if false) == (5 if false)", false)?; test_error("2^40")?; test_error("1000000000*1000000000")?; test_error("1500000000+1500000000")?; test_error("-1500000000-1500000000")?; test_error("0%0")?; test_int("1&&2", 1)?; test_bool("true&&false", false)?; test_bool("false || true", true)?; test_int("1&&2.0", 1)?; test_float("1||2.0", 2.0)?; test_float( "atanh(0.5)/asinh(-7.5)*acosh(2.3)", 0.5f64.atanh() / (-7.5f64).asinh() * 2.3f64.acosh(), )?; test_float("sin(atan2(1, 1.0 / 2.0))", (1.0f64.atan2(0.5)).sin())?; Ok(()) } #[cfg(feature = "value")] #[test] fn test() { use smallvec::smallvec; fn assert_arr(arr: Val, reference: &[f64]) { let arr = arr.to_array().unwrap(); for (a, b) in arr.iter().zip(reference.iter()) { utils::assert_float_eq_f64(*b, *a); } } // dot product let s = "dot(-[1.0, 2.0, -3.0], [0, 1, 0])"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[]).unwrap(); utils::assert_float_eq_f64(x.to_float().unwrap(), -2.0f64); let s = "dot(-[1.0, 2.0, -3.0], [0, 2, 0] - 1)"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[]).unwrap(); utils::assert_float_eq_f64(x.to_float().unwrap(), -4.0f64); let a1 = Val::Array(smallvec![0.0, 0.0, -3.0]); let a2 = Val::Array(smallvec![0.0, 3.0, 0.0]); let s = "dot(a1, a2)"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[a1, a2]).unwrap(); utils::assert_float_eq_f64(x.to_float().unwrap(), 0.0); // access components with .0, .1, .2 let s = "-[1.0, 2.0, -3.0].0"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[]).unwrap(); utils::assert_float_eq_f64(x.to_float().unwrap(), -1.0f64); let s = "-[1.0, 2.0, -3.0].1"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[]).unwrap(); utils::assert_float_eq_f64(x.to_float().unwrap(), -2.0f64); let s = "-[1.0, 2.0, -3.0].2"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[]).unwrap(); utils::assert_float_eq_f64(x.to_float().unwrap(), 3.0f64); // negate let s = "-[1.0, 2.0, 3.0]"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[]).unwrap(); let reference = [-1.0, -2.0, -3.0]; assert_arr(x, &reference); // compute with scalars let s = "-[1.0, 2.0, 3.0] * 1"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[]).unwrap(); let reference = [-1.0, -2.0, -3.0]; assert_arr(x, &reference); let s = "-[1.0, 2.0, 3.0] + 0"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[]).unwrap(); let reference = [-1.0, -2.0, -3.0]; assert_arr(x, &reference); let s = "-[1.0, 2.0, 3.0] + 1.5"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[]).unwrap(); let reference = [0.5, -0.5, -1.5]; assert_arr(x, &reference); let s = "-[1.0, 2.0, 3.0] / 2"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[]).unwrap(); let reference = [-0.5, -1.0, -1.5]; assert_arr(x, &reference); let s = "-[1.0, 2.0, 3.0] * 0"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[]).unwrap(); let reference = [0.0, 0.0, 0.0]; assert_arr(x, &reference); let a = Val::Array(smallvec![-1.0, -2.0, -3.0]); let s = "a - a"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[a]).unwrap(); let reference = [0.0, 0.0, 0.0]; assert_arr(x, &reference); // length let s = "length(-[1.0, 2.0, -3.0])"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[]).unwrap(); utils::assert_float_eq_f64(x.to_float().unwrap(), (14.0f64).sqrt()); let a = Val::Array(smallvec![4.0, -3.0]); let s = "length(a)"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[a]).unwrap(); utils::assert_float_eq_f64(x.to_float().unwrap(), 5.0); // cross let a1 = Val::Array(smallvec![0.0, 0.0, 1.0]); let a2 = Val::Array(smallvec![0.0, 1.0, 0.0]); let s = "cross(a1, a2)"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[a1, a2]).unwrap(); let reference = [-1.0, 0.0, 0.0]; assert_arr(x, &reference); let a1 = Val::Array(smallvec![0.0, 1.0, 0.0]); let a2 = Val::Array(smallvec![0.0, 0.0, 1.0]); let s = "cross(a1, a2)"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[a1, a2]).unwrap(); let reference = [1.0, 0.0, 0.0]; assert_arr(x, &reference); // min/max let a1 = Val::Array(smallvec![0.0, 0.0, 1.0]); let a2 = Val::Array(smallvec![0.0, 1.0, 0.0]); let s = "min(a1, a2)"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[a1, a2]).unwrap(); let reference = [0.0, 0.0, 0.0]; assert_arr(x, &reference); let a1 = Val::Array(smallvec![0.0, 1.0, 0.0]); let a2 = Val::Float(0.5); let s = "min(a1, a2)"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[a1, a2]).unwrap(); let reference = [0.0, 0.5, 0.0]; assert_arr(x, &reference); let a1 = Val::Array(smallvec![0.0, 1.0, 0.0]); let a2 = Val::Array(smallvec![0.0, 0.0, 1.0]); let s = "max(a1, a2)"; let expr = FlatExVal::::parse(s).unwrap(); let x = expr.eval(&[a1, a2]).unwrap(); let reference = [0.0, 1.0, 1.0]; assert_arr(x, &reference); } #[cfg(feature = "value")] #[cfg(feature = "serde")] #[test] fn test_serde() { use serde::{Deserialize, Serialize}; let s = "-1200 if (cb / ib) < 1 else -2400"; let expr = FlatExVal::::parse(s).unwrap(); #[derive(Serialize, Deserialize)] struct Tmp { expr: FlatExVal, } let tmp = Tmp { expr }; let ser = serde_json::to_string_pretty(&tmp) .unwrap() .replace("/", "\\/"); let _deser: Tmp = serde_json::from_str(&ser).unwrap(); } #[cfg(feature = "value")] #[test] fn test_fuzz() { let s = "ata---n-----0>>220>22--ata---n-----0>>220>22-------------tanh-------------------tanh--------6/π"; let expr = FlatExVal::::parse(s).unwrap(); let res = expr.eval(&[Val::Int(2), Val::Int(3)]).unwrap(); assert!(!res.to_bool().unwrap()); let s = "fact+82"; let expr = FlatExVal::::parse(s).unwrap(); assert!(expr.eval(&[]).unwrap().to_int().is_err()); assert!(expr.eval(&[]).unwrap().to_float().is_err()); assert!(expr.eval(&[]).unwrap().to_bool().is_err()); }