use std::time::Instant; use criterion::{criterion_group, criterion_main, Criterion}; use sunscreen::{ types::zkp::{ConstrainCmp, Field, FieldSpec, IntoProgramNode, ProgramNode}, *, }; use sunscreen_zkp_backend::{bulletproofs::BulletproofsBackend, BigInt}; type BPField = Field<::Field>; fn to_field_element( bits: &[ProgramNode>], twos_complement: bool, ) -> ProgramNode> { let powers = (0..bits.len()) .map(|x| { let power = Field::::from(BigInt::from(BigInt::ONE.shl_vartime(x))); let msb = bits.len() - 1; if twos_complement && x == msb { -(power.into_program_node()) } else { power.into_program_node() } }) .collect::>(); let mut val = Field::from(0u8).into_program_node(); for (i, bit) in bits.iter().enumerate() { val = val + *bit * powers[i]; } val } fn get_coeffs(x: &[[ProgramNode>; 8]]) -> Vec>> { x.iter().map(|x| to_field_element(x, true)).collect() } /** * Takes an [`i8`] and returns the value encoded as 8 binary [`NativeField`] elements. */ fn encode(val: i8) -> [BPField; 8] { let as_u8 = val.to_le_bytes()[0]; (0..8) .map(|x| Field::from((as_u8 >> x) & 0x1)) .collect::>() .try_into() .unwrap() } fn make_fractional_value(bits: &[i8]) -> [[BPField; 8]; 64] { assert!(bits.len() <= 64); let remain = 64 - bits.len(); let a = bits .iter() .cloned() .chain((0..remain).map(|_| 0i8)) .collect::>(); a.iter() .map(|x| encode(*x)) .collect::>() .try_into() .unwrap() } /// In this scenario, we have an unshielded tx coming from an encrypted /// balance. We need to prove that the public shielded value is less than /// equal your account balance. /// /// # Remarks /// Doing this for real, we would need to scale the public tx amount by /// the number of decimal places in the fractional amount. This is /// basically free, so we don't need to time it here. fn unshield_tx_fractional_range_proof(_c: &mut Criterion) { #[zkp_program] /** * Proves the 0 < a <= b and a == c */ fn in_range(balance: [[Field; 8]; 64], #[constant] unshielded: Field) { println!("Running unshield proof..."); let balance_coeffs = get_coeffs(&balance); let balance_val = to_field_element(&balance_coeffs, false); unshielded.constrain_le_bounded(balance_val, 8); } let app = Compiler::new() .zkp_backend::() .zkp_program(in_range) .compile() .unwrap(); let prog = app.get_zkp_program(in_range).unwrap(); // Create a carryless binary value. // a is 3 in the 1s place, 2 in the 2s place. // 3 * 1 + 2 * 2 = 7 let balance = make_fractional_value(&[3, 2]); let runtime = Runtime::new_zkp(BulletproofsBackend::new()).unwrap(); let prover_time = Instant::now(); let tx_input: Vec = vec![BPField::from(4).into()]; let balance_input: Vec = vec![balance.into()]; let proof = runtime .prove(prog, balance_input, vec![], tx_input.clone()) .unwrap(); println!("Prover time {}s", prover_time.elapsed().as_secs_f64()); let verifier_time = Instant::now(); runtime.verify(prog, &proof, vec![], tx_input).unwrap(); println!("Verifier time {}s", verifier_time.elapsed().as_secs_f64()); let proof_ser = bincode::serialize(&proof).unwrap(); println!("Proof size {}B", proof_ser.len()); } /// In this scenario, we have parts of 3 messages from encrypted ciphertexts. These /// are the upper and lower 32 coefficients of 3 /// [`Fractional`](sunscreen::types::bfv::Fractional) encoded values. We assume we've /// revealed all the other digits (i.e the low order fractional parts) to be zero /// in the short discrete log proof, so they don't contribute to the value. /// /// The short discrete log proof gives us the message coefficients in 2's complement /// binary, from which we reconstruct the coefficients. We then multiply the /// coefficients by powers of 2 (treating the entire fractional value as an integer) /// and prove that 0 < a <= b and a == c, where /// * a is the tx amount in the message encrypted under Alice's key /// * b is Alice's balance message encrypted under her key. /// * c is the tx amount in the message encrypted under Bob's key. /// /// # Remarks /// Since we require c be binary (to prevent overflowing another user's balance), we /// *could* reveal all but the least significant bit of each coefficient to be zero /// in the SDLP, which reduces the number of circuit inputs. However, this proof is /// orders of magnitude faster than the SDLP so 🤷‍♀️. fn private_tx_fractional_range_proof(_c: &mut Criterion) { #[zkp_program] /** * Proves the 0 < a <= b and a == c */ fn in_range( a: [[Field; 8]; 64], b: [[Field; 8]; 64], c: [[Field; 8]; 64], ) { println!("Running private_tx_fractional_range_proof..."); let a_coeffs = get_coeffs(&a); let b_coeffs = get_coeffs(&b); let c_coeffs = get_coeffs(&c); let a_val = to_field_element(&a_coeffs, false); let b_val = to_field_element(&b_coeffs, false); let c_val = to_field_element(&c_coeffs, false); a_val.constrain_gt_bounded(Field::::from(0).into_program_node(), 8); a_val.constrain_le_bounded(b_val, 8); a_val.constrain_eq(c_val); } let app = Compiler::new() .zkp_backend::() .zkp_program(in_range) .compile() .unwrap(); let prog = app.get_zkp_program(in_range).unwrap(); // Create a carryless binary value. // a is 3 in the 1s place, 2 in the 2s place. // 3 * 1 + 2 * 2 = 7 let a = make_fractional_value(&[3, 2]); // 4 * 1 + 16 * 2 = 36 let b = make_fractional_value(&[4, 16]); // This value should equal a. But we're going to test equality with a // different representation of 7. // 1 * 1 + 1 * 2 + 1 * 4 = 7 let c = make_fractional_value(&[1, 1, 1]); let runtime = Runtime::new_zkp(BulletproofsBackend::new()).unwrap(); let prover_time = Instant::now(); let proof = runtime.prove(prog, vec![a, b, c], vec![], vec![]).unwrap(); println!("Prover time {}s", prover_time.elapsed().as_secs_f64()); let verifier_time = Instant::now(); runtime .verify(prog, &proof, Vec::::new(), vec![]) .unwrap(); println!("Verifier time {}s", verifier_time.elapsed().as_secs_f64()); let proof_ser = bincode::serialize(&proof).unwrap(); println!("Proof size {}B", proof_ser.len()); } /// Imagine a multi-party computation where each user submits a value between /// zero and a maximum encrypted under their key. The MPC computes the mean and /// variance of their inputs. /// /// This proof is similar to [`private_tx_fractional_range_proof`] except we're only /// proving 0 < a <= b where /// /// * a is the submitted value under a given user's key. /// * b is the maximum value encrypted under the same user's key. fn mean_variance_fractional_range_proof(_c: &mut Criterion) { #[zkp_program] /** * Proves the 0 < a <= b and a == c */ fn in_range(a: [[Field; 8]; 64], b: [[Field; 8]; 64]) { println!("Running mean_variance_fractional_range_proof..."); let a_coeffs = get_coeffs(&a); let b_coeffs = get_coeffs(&b); let a_val = to_field_element(&a_coeffs, false); let b_val = to_field_element(&b_coeffs, false); a_val.constrain_ge_bounded(Field::::from(0).into_program_node(), 8); a_val.constrain_le_bounded(b_val, 8); } let app = Compiler::new() .zkp_backend::() .zkp_program(in_range) .compile() .unwrap(); let prog = app.get_zkp_program(in_range).unwrap(); // Create a carryless binary value. // a is 3 in the 1s place, 2 in the 2s place. // 3 * 1 + 2 * 2 = 7 let a = make_fractional_value(&[3, 2]); // 4 * 1 + 16 * 2 = 36 let b = make_fractional_value(&[4, 16]); let runtime = Runtime::new_zkp(BulletproofsBackend::new()).unwrap(); let prover_time = Instant::now(); let proof = runtime.prove(prog, vec![a, b], vec![], vec![]).unwrap(); println!("Prover time {}s", prover_time.elapsed().as_secs_f64()); let verifier_time = Instant::now(); runtime .verify(prog, &proof, Vec::::new(), vec![]) .unwrap(); println!("Verifier time {}s", verifier_time.elapsed().as_secs_f64()); let proof_ser = bincode::serialize(&proof).unwrap(); println!("Proof size {}B", proof_ser.len()); } /// Suppose we're adding ZKPs on the inputs to our chi squared example (see /// examples/chi_sq/main.rs). We want to prove our ciphertexts are well-formed and /// that the contained [`Signed`](sunscreen::types::bfv::Signed) messages are between /// 0 and 12. /// /// This proof is similar to [`mean_variance_fractional_range_proof`] except we're /// proving 3 values are between 0 <= a_i <= 12 where /// /// * a_0, a_1, a_2 are submitted value under a given user's key. fn chi_sq_fractional_range_proof(_c: &mut Criterion) { #[zkp_program] /** * Proves the 0 < a <= b and a == c */ fn in_range( a_0: [[Field; 8]; 64], a_1: [[Field; 8]; 64], a_2: [[Field; 8]; 64], #[constant] n: Field, ) { println!("Running chi_sq_fractional_range_proof..."); let a_0_coeffs = get_coeffs(&a_0); let a_1_coeffs = get_coeffs(&a_1); let a_2_coeffs = get_coeffs(&a_2); let a_0_val = to_field_element(&a_0_coeffs, false); let a_1_val = to_field_element(&a_1_coeffs, false); let a_2_val = to_field_element(&a_2_coeffs, false); (a_0_val + a_1_val + a_2_val).constrain_eq(n); } let app = Compiler::new() .zkp_backend::() .zkp_program(in_range) .compile() .unwrap(); let prog = app.get_zkp_program(in_range).unwrap(); // Create a carryless binary value. // a is 3 in the 1s place, 2 in the 2s place. // 3 * 1 + 2 * 2 = 7 let a_0 = make_fractional_value(&[3, 2]); // 4 * 1 + 1 * 2 = 6 let a_1 = make_fractional_value(&[4, 1]); // 3 * 1 + 2 * 2 = 7 let a_2 = make_fractional_value(&[3, 2]); let n = BPField::from(7 + 6 + 7); let priv_inputs: Vec = vec![a_0.into(), a_1.into(), a_2.into()]; let const_inputs: Vec = vec![n.into()]; let runtime = Runtime::new_zkp(BulletproofsBackend::new()).unwrap(); let prover_time = Instant::now(); let proof = runtime .prove(prog, priv_inputs, vec![], const_inputs.clone()) .unwrap(); println!("Prover time {}s", prover_time.elapsed().as_secs_f64()); let verifier_time = Instant::now(); runtime.verify(prog, &proof, vec![], const_inputs).unwrap(); println!("Verifier time {}s", verifier_time.elapsed().as_secs_f64()); let proof_ser = bincode::serialize(&proof).unwrap(); println!("Proof size {}B", proof_ser.len()); } criterion_group!( benches, private_tx_fractional_range_proof, unshield_tx_fractional_range_proof, mean_variance_fractional_range_proof, chi_sq_fractional_range_proof ); criterion_main!(benches);