//! Bindings around the z3 solver for Raptor IR //! //! The z3 solver only works over certain expression types. Expressions given to //! the z3 solver must not include references or dereferences. use crate::error::*; use crate::ir; use falcon_z3::{Ast, Check, Config, Context, Model, Optimize, Solver}; mod fast; pub use fast::FastSolver; fn return_solver_result( solver: &Solver, context: &Context, ast: &Ast, bits: usize, ) -> Result> { Ok(match solver.check() { Check::Unsat | Check::Unknown => None, Check::Sat => Model::new(context, solver) .and_then(|model| model.get_const_interp(ast)) .and_then(|constant_ast| constant_ast.get_numeral_decimal_string(context)) .and_then(|numeral_dec_str| { ir::Constant::from_decimal_string(&numeral_dec_str, bits).ok() }), }) } fn return_optimize_result( optimize: &Optimize, context: &Context, ast: &Ast, bits: usize, ) -> Result> { Ok(match optimize.check() { Check::Unsat | Check::Unknown => None, Check::Sat => Model::new_optimize(context, optimize) .and_then(|model| model.get_const_interp(ast)) .and_then(|constant_ast| constant_ast.get_numeral_decimal_string(context)) .and_then(|numeral_dec_str| { ir::Constant::from_decimal_string(&numeral_dec_str, bits).ok() }), }) } fn solver_init( solver: &Solver, context: &Context, constraints: &[ir::Expression], ) -> Result<()> { let sort = context.mk_bv_sort(1); let one = context.mk_numeral(1, &sort)?; for constraint in constraints { solver.assert(&context.eq(&one, &expression_to_ast(context, constraint)?)); } Ok(()) } fn optimize_init( optimize: &Optimize, context: &Context, constraints: &[ir::Expression], ) -> Result<()> { let sort = context.mk_bv_sort(1); let one = context.mk_numeral(1, &sort)?; for constraint in constraints { optimize.assert(&context.eq(&one, &expression_to_ast(context, constraint)?)); } Ok(()) } /// Maximize the given value by the given constraints. pub fn maximize( constraints: &[ir::Expression], value: &ir::Expression, ) -> Result> { let config = Config::new().enable_model(); let context = Context::new(config); let optimize = Optimize::new(&context); optimize_init(&optimize, &context, constraints)?; let optimize_result = context.mk_var("OPTIMIZE_RESULT", &context.mk_bv_sort(value.bits()))?; optimize.assert(&context.eq(&optimize_result, &expression_to_ast(&context, value)?)); optimize.maximize(&optimize_result); return_optimize_result(&optimize, &context, &optimize_result, value.bits()) } /// Minimize the given value by the given constraints. pub fn minimize( constraints: &[ir::Expression], value: &ir::Expression, ) -> Result> { let config = Config::new().enable_model(); let context = Context::new(config); let optimize = Optimize::new(&context); optimize_init(&optimize, &context, constraints)?; let optimize_result = context.mk_var("OPTIMIZE_RESULT", &context.mk_bv_sort(value.bits()))?; optimize.assert(&context.eq(&optimize_result, &expression_to_ast(&context, value)?)); optimize.minimize(&optimize_result); return_optimize_result(&optimize, &context, &optimize_result, value.bits()) } /// Solve for a possible solution to the given value with the given constraints. pub fn solve( constraints: &[ir::Expression], value: &ir::Expression, ) -> Result> { let config = Config::new().enable_model(); let context = Context::new(config); let solver = Solver::new(&context); solver_init(&solver, &context, constraints)?; let solver_result = context.mk_var("SOLVER_RESULT", &context.mk_bv_sort(value.bits()))?; solver.assert(&context.eq(&solver_result, &expression_to_ast(&context, value)?)); return_solver_result(&solver, &context, &solver_result, value.bits()) } fn expression_to_ast(context: &Context, expression: &ir::Expression) -> Result { Ok(match expression { ir::Expression::RValue(rvalue) => match rvalue.as_ref() { ir::RValue::Value(constant) => { if let Some(value) = constant.value_u64() { let sort = context.mk_bv_sort(constant.bits()); context.mk_numeral(value, &sort)? } else { let big_uint = constant.value(); let sort = context.mk_bv_sort(8); let bytes = big_uint.to_bytes_le(); let mut v = if bytes.is_empty() { context.mk_numeral(0, &sort)? } else { context.mk_numeral(bytes[0] as u64, &sort)? }; for i in 1..(constant.bits() / 8) { let numeral = if bytes.len() <= i { context.mk_numeral(0, &sort)? } else { context.mk_numeral(bytes[i] as u64, &sort)? }; v = context.concat(&numeral, &v); } v } } ir::RValue::Reference(_) => return Err(ErrorKind::SolverReference.into()), }, ir::Expression::LValue(lvalue) => match lvalue.as_ref() { ir::LValue::Variable(variable) => match variable { ir::Variable::Scalar(scalar) => { let sort = context.mk_bv_sort(scalar.bits()); context.mk_var(scalar.name(), &sort)? } ir::Variable::StackVariable(stack_variable) => { let sort = context.mk_bv_sort(stack_variable.bits()); context.mk_var(stack_variable.name(), &sort)? } }, ir::LValue::Dereference(_) => return Err(ErrorKind::SolverDereference.into()), }, ir::Expression::Add(ref lhs, ref rhs) => context.bvadd( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), ir::Expression::Sub(ref lhs, ref rhs) => context.bvsub( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), ir::Expression::Mul(ref lhs, ref rhs) => context.bvmul( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), ir::Expression::Divu(ref lhs, ref rhs) => context.bvudiv( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), ir::Expression::Modu(ref lhs, ref rhs) => context.bvurem( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), ir::Expression::Divs(ref lhs, ref rhs) => context.bvsdiv( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), ir::Expression::Mods(ref lhs, ref rhs) => context.bvsrem( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), ir::Expression::And(ref lhs, ref rhs) => context.bvand( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), ir::Expression::Or(ref lhs, ref rhs) => context.bvor( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), ir::Expression::Xor(ref lhs, ref rhs) => context.bvxor( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), ir::Expression::Shl(ref lhs, ref rhs) => context.bvshl( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), ir::Expression::Shr(ref lhs, ref rhs) => context.bvlshr( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), ir::Expression::Cmpeq(ref lhs, ref rhs) => { let sort = context.mk_bv_sort(1); context.ite( &context.eq( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), &context.mk_numeral(1, &sort)?, &context.mk_numeral(0, &sort)?, ) } ir::Expression::Cmpneq(ref lhs, ref rhs) => { let sort = context.mk_bv_sort(1); context.ite( &context.eq( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), &context.mk_numeral(0, &sort)?, &context.mk_numeral(1, &sort)?, ) } ir::Expression::Cmplts(ref lhs, ref rhs) => { let sort = context.mk_bv_sort(1); context.ite( &context.bvslt( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), &context.mk_numeral(1, &sort)?, &context.mk_numeral(0, &sort)?, ) } ir::Expression::Cmpltu(ref lhs, ref rhs) => { let sort = context.mk_bv_sort(1); context.ite( &context.bvult( &expression_to_ast(context, lhs)?, &expression_to_ast(context, rhs)?, ), &context.mk_numeral(1, &sort)?, &context.mk_numeral(0, &sort)?, ) } ir::Expression::Zext(bits, ref rhs) => context.zero_ext( (bits - rhs.bits()) as u32, &expression_to_ast(context, rhs)?, ), ir::Expression::Sext(bits, ref rhs) => context.sign_ext( (bits - rhs.bits()) as u32, &expression_to_ast(context, rhs)?, ), ir::Expression::Trun(bits, ref rhs) => { context.extract((bits - 1) as u32, 0, &expression_to_ast(context, rhs)?) } ir::Expression::Ite(ref cond, ref then, ref else_) => context.ite( &context.eq( &expression_to_ast(context, cond)?, &context.mk_numeral(1, &context.mk_bv_sort(1))?, ), &expression_to_ast(context, then)?, &expression_to_ast(context, else_)?, ), }) }