// Copyright (c) Aptos // SPDX-License-Identifier: Apache-2.0 use anyhow::Result; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use rocksdb::DEFAULT_COLUMN_FAMILY_NAME; use sov_schema_db::schema::{KeyDecoder, KeyEncoder, Schema, ValueCodec}; use sov_schema_db::{define_schema, CodecError, SchemaIterator, SeekKeyEncoder, DB}; use tempfile::TempDir; define_schema!(TestSchema, TestKey, TestValue, "TestCF"); #[derive(Debug, Eq, PartialEq)] pub(crate) struct TestKey(u32, u32, u32); #[derive(Debug, Eq, PartialEq)] pub(crate) struct TestValue(u32); impl KeyEncoder for TestKey { fn encode_key(&self) -> Result, CodecError> { let mut bytes = vec![]; bytes .write_u32::(self.0) .map_err(|e| CodecError::Wrapped(e.into()))?; bytes .write_u32::(self.1) .map_err(|e| CodecError::Wrapped(e.into()))?; bytes .write_u32::(self.2) .map_err(|e| CodecError::Wrapped(e.into()))?; Ok(bytes) } } impl KeyDecoder for TestKey { fn decode_key(data: &[u8]) -> Result { let mut reader = std::io::Cursor::new(data); Ok(TestKey( reader .read_u32::() .map_err(|e| CodecError::Wrapped(e.into()))?, reader .read_u32::() .map_err(|e| CodecError::Wrapped(e.into()))?, reader .read_u32::() .map_err(|e| CodecError::Wrapped(e.into()))?, )) } } impl SeekKeyEncoder for TestKey { fn encode_seek_key(&self) -> sov_schema_db::schema::Result> { self.encode_key() } } impl ValueCodec for TestValue { fn encode_value(&self) -> Result, CodecError> { Ok(self.0.to_be_bytes().to_vec()) } fn decode_value(data: &[u8]) -> Result { let mut reader = std::io::Cursor::new(data); Ok(TestValue( reader .read_u32::() .map_err(|e| CodecError::Wrapped(e.into()))?, )) } } pub struct KeyPrefix1(u32); impl SeekKeyEncoder for KeyPrefix1 { fn encode_seek_key(&self) -> Result, CodecError> { Ok(self.0.to_be_bytes().to_vec()) } } pub struct KeyPrefix2(u32, u32); impl SeekKeyEncoder for KeyPrefix2 { fn encode_seek_key(&self) -> Result, CodecError> { let mut bytes = vec![]; bytes .write_u32::(self.0) .map_err(|e| CodecError::Wrapped(e.into()))?; bytes .write_u32::(self.1) .map_err(|e| CodecError::Wrapped(e.into()))?; Ok(bytes) } } fn collect_values(iter: SchemaIterator) -> Vec { iter.map(|row| (row.unwrap().1).0).collect() } struct TestDB { _tmpdir: TempDir, db: DB, } impl TestDB { fn new() -> Self { let tmpdir = tempfile::tempdir().unwrap(); let column_families = vec![DEFAULT_COLUMN_FAMILY_NAME, TestSchema::COLUMN_FAMILY_NAME]; let mut db_opts = rocksdb::Options::default(); db_opts.create_if_missing(true); db_opts.create_missing_column_families(true); let db = DB::open(tmpdir.path(), "test", column_families, &db_opts).unwrap(); db.put::(&TestKey(1, 0, 0), &TestValue(100)) .unwrap(); db.put::(&TestKey(1, 0, 2), &TestValue(102)) .unwrap(); db.put::(&TestKey(1, 0, 4), &TestValue(104)) .unwrap(); db.put::(&TestKey(1, 1, 0), &TestValue(110)) .unwrap(); db.put::(&TestKey(1, 1, 2), &TestValue(112)) .unwrap(); db.put::(&TestKey(1, 1, 4), &TestValue(114)) .unwrap(); db.put::(&TestKey(2, 0, 0), &TestValue(200)) .unwrap(); db.put::(&TestKey(2, 0, 2), &TestValue(202)) .unwrap(); TestDB { _tmpdir: tmpdir, db, } } } impl TestDB { fn iter(&self) -> SchemaIterator { self.db.iter().expect("Failed to create iterator.") } fn rev_iter(&self) -> SchemaIterator { self.db.rev_iter().expect("Failed to create iterator.") } } impl std::ops::Deref for TestDB { type Target = DB; fn deref(&self) -> &Self::Target { &self.db } } #[test] fn test_seek_to_first() { let db = TestDB::new(); let mut iter = db.iter(); iter.seek_to_first(); assert_eq!( collect_values(iter), [100, 102, 104, 110, 112, 114, 200, 202] ); let mut iter = db.rev_iter(); iter.seek_to_first(); assert_eq!(collect_values(iter), [100]); } #[test] fn test_seek_to_last() { let db = TestDB::new(); let mut iter = db.iter(); iter.seek_to_last(); assert_eq!(collect_values(iter), [202]); let mut iter = db.rev_iter(); iter.seek_to_last(); assert_eq!( collect_values(iter), [202, 200, 114, 112, 110, 104, 102, 100] ); } #[test] fn test_seek_by_existing_key() { let db = TestDB::new(); let mut iter = db.iter(); iter.seek(&TestKey(1, 1, 0)).unwrap(); assert_eq!(collect_values(iter), [110, 112, 114, 200, 202]); let mut iter = db.rev_iter(); iter.seek(&TestKey(1, 1, 0)).unwrap(); assert_eq!(collect_values(iter), [110, 104, 102, 100]); } #[test] fn test_seek_by_nonexistent_key() { let db = TestDB::new(); let mut iter = db.iter(); iter.seek(&TestKey(1, 1, 1)).unwrap(); assert_eq!(collect_values(iter), [112, 114, 200, 202]); let mut iter = db.rev_iter(); iter.seek(&TestKey(1, 1, 1)).unwrap(); assert_eq!(collect_values(iter), [112, 110, 104, 102, 100]); } #[test] fn test_seek_for_prev_by_existing_key() { let db = TestDB::new(); let mut iter = db.iter(); iter.seek_for_prev(&TestKey(1, 1, 0)).unwrap(); assert_eq!(collect_values(iter), [110, 112, 114, 200, 202]); let mut iter = db.rev_iter(); iter.seek_for_prev(&TestKey(1, 1, 0)).unwrap(); assert_eq!(collect_values(iter), [110, 104, 102, 100]); } #[test] fn test_seek_for_prev_by_nonexistent_key() { let db = TestDB::new(); let mut iter = db.iter(); iter.seek_for_prev(&TestKey(1, 1, 1)).unwrap(); assert_eq!(collect_values(iter), [110, 112, 114, 200, 202]); let mut iter = db.rev_iter(); iter.seek_for_prev(&TestKey(1, 1, 1)).unwrap(); assert_eq!(collect_values(iter), [110, 104, 102, 100]); } #[test] fn test_seek_by_1prefix() { let db = TestDB::new(); let mut iter = db.iter(); iter.seek(&KeyPrefix1(2)).unwrap(); assert_eq!(collect_values(iter), [200, 202]); let mut iter = db.rev_iter(); iter.seek(&KeyPrefix1(2)).unwrap(); assert_eq!(collect_values(iter), [200, 114, 112, 110, 104, 102, 100]); } #[test] fn test_seek_for_prev_by_1prefix() { let db = TestDB::new(); let mut iter = db.iter(); iter.seek_for_prev(&KeyPrefix1(2)).unwrap(); assert_eq!(collect_values(iter), [114, 200, 202]); let mut iter = db.rev_iter(); iter.seek_for_prev(&KeyPrefix1(2)).unwrap(); assert_eq!(collect_values(iter), [114, 112, 110, 104, 102, 100]); } #[test] fn test_seek_by_2prefix() { let db = TestDB::new(); let mut iter = db.iter(); iter.seek(&KeyPrefix2(2, 0)).unwrap(); assert_eq!(collect_values(iter), [200, 202]); let mut iter = db.rev_iter(); iter.seek(&KeyPrefix2(2, 0)).unwrap(); assert_eq!(collect_values(iter), [200, 114, 112, 110, 104, 102, 100]); } #[test] fn test_seek_for_prev_by_2prefix() { let db = TestDB::new(); let mut iter = db.iter(); iter.seek_for_prev(&KeyPrefix2(2, 0)).unwrap(); assert_eq!(collect_values(iter), [114, 200, 202]); let mut iter = db.rev_iter(); iter.seek_for_prev(&KeyPrefix2(2, 0)).unwrap(); assert_eq!(collect_values(iter), [114, 112, 110, 104, 102, 100]); }