use alloc::vec::Vec; use core::cmp::Reverse; use core::marker::PhantomData; use itertools::Itertools; use p3_commit::Mmcs; use p3_field::{PackedField, PackedValue}; use p3_matrix::{Dimensions, Matrix}; use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction}; use p3_util::log2_ceil_usize; use serde::{Deserialize, Serialize}; use crate::FieldMerkleTree; /// A vector commitment scheme backed by a `FieldMerkleTree`. /// /// Generics: /// - `P`: a leaf value TODO /// - `H`: the leaf hasher /// - `C`: the digest compression function #[derive(Copy, Clone, Debug)] pub struct FieldMerkleTreeMmcs { hash: H, compress: C, _phantom: PhantomData<(P, PW)>, } impl FieldMerkleTreeMmcs { pub const fn new(hash: H, compress: C) -> Self { Self { hash, compress, _phantom: PhantomData, } } } impl Mmcs for FieldMerkleTreeMmcs where P: PackedField, PW: PackedValue, H: CryptographicHasher, H: CryptographicHasher, H: Sync, C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>, C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>, C: Sync, PW::Value: Eq, [PW::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>, { type Commitment = Hash; type Proof = Vec<[PW::Value; DIGEST_ELEMS]>; type Error = (); type ProverData = FieldMerkleTree; fn commit>( &self, inputs: Vec, ) -> (Self::Commitment, Self::ProverData) { let tree = FieldMerkleTree::new::(&self.hash, &self.compress, inputs); let root = tree.root(); (root, tree) } fn open_batch>( &self, index: usize, prover_data: &FieldMerkleTree, ) -> (Vec>, Vec<[PW::Value; DIGEST_ELEMS]>) { let max_height = self.get_max_height(prover_data); let log_max_height = log2_ceil_usize(max_height); let openings = prover_data .leaves .iter() .map(|matrix| { let log2_height = log2_ceil_usize(matrix.height()); let bits_reduced = log_max_height - log2_height; let reduced_index = index >> bits_reduced; matrix.row(reduced_index).collect() }) .collect_vec(); let proof = (0..log_max_height) .map(|i| prover_data.digest_layers[i][(index >> i) ^ 1]) .collect(); (openings, proof) } fn get_matrices<'a, M: Matrix>( &self, prover_data: &'a Self::ProverData, ) -> Vec<&'a M> { prover_data.leaves.iter().collect() } fn verify_batch( &self, commit: &Self::Commitment, dimensions: &[Dimensions], mut index: usize, opened_values: &[Vec], proof: &Self::Proof, ) -> Result<(), Self::Error> { let mut heights_tallest_first = dimensions .iter() .enumerate() .sorted_by_key(|(_, dims)| Reverse(dims.height)) .peekable(); let mut curr_height_padded = heights_tallest_first .peek() .unwrap() .1 .height .next_power_of_two(); let mut root = self.hash.hash_iter_slices( heights_tallest_first .peeking_take_while(|(_, dims)| { dims.height.next_power_of_two() == curr_height_padded }) .map(|(i, _)| opened_values[i].as_slice()), ); for &sibling in proof.iter() { let (left, right) = if index & 1 == 0 { (root, sibling) } else { (sibling, root) }; root = self.compress.compress([left, right]); index >>= 1; curr_height_padded >>= 1; let next_height = heights_tallest_first .peek() .map(|(_, dims)| dims.height) .filter(|h| h.next_power_of_two() == curr_height_padded); if let Some(next_height) = next_height { let next_height_openings_digest = self.hash.hash_iter_slices( heights_tallest_first .peeking_take_while(|(_, dims)| dims.height == next_height) .map(|(i, _)| opened_values[i].as_slice()), ); root = self.compress.compress([root, next_height_openings_digest]); } } if commit == &root { Ok(()) } else { Err(()) } } } #[cfg(test)] mod tests { use alloc::vec; use itertools::Itertools; use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear}; use p3_commit::Mmcs; use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::{Dimensions, Matrix}; use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral}; use p3_symmetric::{ CryptographicHasher, PaddingFreeSponge, PseudoCompressionFunction, TruncatedPermutation, }; use rand::thread_rng; use super::FieldMerkleTreeMmcs; type F = BabyBear; type Perm = Poseidon2; type MyHash = PaddingFreeSponge; type MyCompress = TruncatedPermutation; type MyMmcs = FieldMerkleTreeMmcs<::Packing, ::Packing, MyHash, MyCompress, 8>; #[test] fn commit_single_1x8() { let perm = Perm::new_from_rng_128( Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, &mut thread_rng(), ); let hash = MyHash::new(perm.clone()); let compress = MyCompress::new(perm); let mmcs = MyMmcs::new(hash.clone(), compress.clone()); // v = [2, 1, 2, 2, 0, 0, 1, 0] let v = vec![ F::two(), F::one(), F::two(), F::two(), F::zero(), F::zero(), F::one(), F::zero(), ]; let (commit, _) = mmcs.commit_vec(v.clone()); let expected_result = compress.compress([ compress.compress([ compress.compress([hash.hash_item(v[0]), hash.hash_item(v[1])]), compress.compress([hash.hash_item(v[2]), hash.hash_item(v[3])]), ]), compress.compress([ compress.compress([hash.hash_item(v[4]), hash.hash_item(v[5])]), compress.compress([hash.hash_item(v[6]), hash.hash_item(v[7])]), ]), ]); assert_eq!(commit, expected_result); } #[test] fn commit_single_2x2() { let perm = Perm::new_from_rng_128( Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, &mut thread_rng(), ); let hash = MyHash::new(perm.clone()); let compress = MyCompress::new(perm); let mmcs = MyMmcs::new(hash.clone(), compress.clone()); // mat = [ // 0 1 // 2 1 // ] let mat = RowMajorMatrix::new(vec![F::zero(), F::one(), F::two(), F::one()], 2); let (commit, _) = mmcs.commit(vec![mat]); let expected_result = compress.compress([ hash.hash_slice(&[F::zero(), F::one()]), hash.hash_slice(&[F::two(), F::one()]), ]); assert_eq!(commit, expected_result); } #[test] fn commit_single_2x3() { let perm = Perm::new_from_rng_128( Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, &mut thread_rng(), ); let hash = MyHash::new(perm.clone()); let compress = MyCompress::new(perm); let mmcs = MyMmcs::new(hash.clone(), compress.clone()); let default_digest = [F::zero(); 8]; // mat = [ // 0 1 // 2 1 // 2 2 // ] let mat = RowMajorMatrix::new( vec![F::zero(), F::one(), F::two(), F::one(), F::two(), F::two()], 2, ); let (commit, _) = mmcs.commit(vec![mat]); let expected_result = compress.compress([ compress.compress([ hash.hash_slice(&[F::zero(), F::one()]), hash.hash_slice(&[F::two(), F::one()]), ]), compress.compress([hash.hash_slice(&[F::two(), F::two()]), default_digest]), ]); assert_eq!(commit, expected_result); } #[test] fn commit_mixed() { let perm = Perm::new_from_rng_128( Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, &mut thread_rng(), ); let hash = MyHash::new(perm.clone()); let compress = MyCompress::new(perm); let mmcs = MyMmcs::new(hash.clone(), compress.clone()); let default_digest = [F::zero(); 8]; // mat_1 = [ // 0 1 // 2 1 // 2 2 // ] let mat_1 = RowMajorMatrix::new( vec![F::zero(), F::one(), F::two(), F::one(), F::two(), F::two()], 2, ); // mat_2 = [ // 1 2 1 // 0 2 2 // ] let mat_2 = RowMajorMatrix::new( vec![F::one(), F::two(), F::one(), F::zero(), F::two(), F::two()], 3, ); let (commit, prover_data) = mmcs.commit(vec![mat_1, mat_2]); let mat_1_leaf_hashes = [ hash.hash_slice(&[F::zero(), F::one()]), hash.hash_slice(&[F::two(), F::one()]), hash.hash_slice(&[F::two(), F::two()]), ]; let mat_2_leaf_hashes = [ hash.hash_slice(&[F::one(), F::two(), F::one()]), hash.hash_slice(&[F::zero(), F::two(), F::two()]), ]; let expected_result = compress.compress([ compress.compress([ compress.compress([mat_1_leaf_hashes[0], mat_1_leaf_hashes[1]]), mat_2_leaf_hashes[0], ]), compress.compress([ compress.compress([mat_1_leaf_hashes[2], default_digest]), mat_2_leaf_hashes[1], ]), ]); assert_eq!(commit, expected_result); let (opened_values, _proof) = mmcs.open_batch(2, &prover_data); assert_eq!( opened_values, vec![ vec![F::two(), F::two()], vec![F::zero(), F::two(), F::two()] ] ); } #[test] fn commit_either_order() { let mut rng = thread_rng(); let perm = Perm::new_from_rng_128( Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, &mut rng, ); let hash = MyHash::new(perm.clone()); let compress = MyCompress::new(perm); let mmcs = MyMmcs::new(hash, compress); let input_1 = RowMajorMatrix::::rand(&mut rng, 5, 8); let input_2 = RowMajorMatrix::::rand(&mut rng, 3, 16); let (commit_1_2, _) = mmcs.commit(vec![input_1.clone(), input_2.clone()]); let (commit_2_1, _) = mmcs.commit(vec![input_2, input_1]); assert_eq!(commit_1_2, commit_2_1); } #[test] #[should_panic] fn mismatched_heights() { let mut rng = thread_rng(); let perm = Perm::new_from_rng_128( Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, &mut rng, ); let hash = MyHash::new(perm.clone()); let compress = MyCompress::new(perm); let mmcs = MyMmcs::new(hash, compress); // attempt to commit to a mat with 8 rows and a mat with 7 rows. this should panic. let large_mat = RowMajorMatrix::new( [1, 2, 3, 4, 5, 6, 7, 8].map(F::from_canonical_u8).to_vec(), 1, ); let small_mat = RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7].map(F::from_canonical_u8).to_vec(), 1); let _ = mmcs.commit(vec![large_mat, small_mat]); } #[test] fn verify_tampered_proof_fails() { let mut rng = thread_rng(); let perm = Perm::new_from_rng_128( Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, &mut rng, ); let hash = MyHash::new(perm.clone()); let compress = MyCompress::new(perm); let mmcs = MyMmcs::new(hash, compress); // 4 8x1 matrixes, 4 8x2 matrixes let large_mats = (0..4).map(|_| RowMajorMatrix::::rand(&mut thread_rng(), 8, 1)); let large_mat_dims = (0..4).map(|_| Dimensions { height: 8, width: 1, }); let small_mats = (0..4).map(|_| RowMajorMatrix::::rand(&mut thread_rng(), 8, 2)); let small_mat_dims = (0..4).map(|_| Dimensions { height: 8, width: 2, }); let (commit, prover_data) = mmcs.commit(large_mats.chain(small_mats).collect_vec()); // open the 3rd row of each matrix, mess with proof, and verify let (opened_values, mut proof) = mmcs.open_batch(3, &prover_data); proof[0][0] += F::one(); mmcs.verify_batch( &commit, &large_mat_dims.chain(small_mat_dims).collect_vec(), 3, &opened_values, &proof, ) .expect_err("expected verification to fail"); } #[test] fn size_gaps() { let mut rng = thread_rng(); let perm = Perm::new_from_rng_128( Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, &mut rng, ); let hash = MyHash::new(perm.clone()); let compress = MyCompress::new(perm); let mmcs = MyMmcs::new(hash, compress); // 4 mats with 1000 rows, 8 columns let large_mats = (0..4).map(|_| RowMajorMatrix::::rand(&mut thread_rng(), 1000, 8)); let large_mat_dims = (0..4).map(|_| Dimensions { height: 1000, width: 8, }); // 5 mats with 70 rows, 8 columns let medium_mats = (0..5).map(|_| RowMajorMatrix::::rand(&mut thread_rng(), 70, 8)); let medium_mat_dims = (0..5).map(|_| Dimensions { height: 70, width: 8, }); // 6 mats with 8 rows, 8 columns let small_mats = (0..6).map(|_| RowMajorMatrix::::rand(&mut thread_rng(), 8, 8)); let small_mat_dims = (0..6).map(|_| Dimensions { height: 8, width: 8, }); let (commit, prover_data) = mmcs.commit( large_mats .chain(medium_mats) .chain(small_mats) .collect_vec(), ); // open the 6th row of each matrix and verify let (opened_values, proof) = mmcs.open_batch(6, &prover_data); mmcs.verify_batch( &commit, &large_mat_dims .chain(medium_mat_dims) .chain(small_mat_dims) .collect_vec(), 6, &opened_values, &proof, ) .expect("expected verification to succeed"); } #[test] fn different_widths() { let mut rng = thread_rng(); let perm = Perm::new_from_rng_128( Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, &mut rng, ); let hash = MyHash::new(perm.clone()); let compress = MyCompress::new(perm); let mmcs = MyMmcs::new(hash, compress); // 10 mats with 32 rows where the ith mat has i + 1 cols let mats = (0..10) .map(|i| RowMajorMatrix::::rand(&mut thread_rng(), 32, i + 1)) .collect_vec(); let dims = mats.iter().map(|m| m.dimensions()).collect_vec(); let (commit, prover_data) = mmcs.commit(mats); let (opened_values, proof) = mmcs.open_batch(17, &prover_data); mmcs.verify_batch(&commit, &dims, 17, &opened_values, &proof) .expect("expected verification to succeed"); } }