use {
    itertools::Itertools,
    lru::LruCache,
    solana_sdk::pubkey::Pubkey,
    std::{cmp::Reverse, collections::HashMap},
};

// For each origin, tracks which nodes have sent messages from that origin and
// their respective score in terms of timeliness of delivered messages.
pub(crate) struct ReceivedCache(LruCache<Pubkey, ReceivedCacheEntry>);

#[derive(Clone, Default)]
struct ReceivedCacheEntry {
    nodes: HashMap<Pubkey, /*score:*/ usize>,
    num_upserts: usize,
}

impl ReceivedCache {
    // Minimum number of upserts before a cache entry can be pruned.
    const MIN_NUM_UPSERTS: usize = 20;

    pub(crate) fn new(capacity: usize) -> Self {
        Self(LruCache::new(capacity))
    }

    pub(crate) fn record(&mut self, origin: Pubkey, node: Pubkey, num_dups: usize) {
        match self.0.get_mut(&origin) {
            Some(entry) => entry.record(node, num_dups),
            None => {
                let mut entry = ReceivedCacheEntry::default();
                entry.record(node, num_dups);
                self.0.put(origin, entry);
            }
        }
    }

    pub(crate) fn prune(
        &mut self,
        pubkey: &Pubkey, // This node.
        origin: Pubkey,  // CRDS value owner.
        stake_threshold: f64,
        min_ingress_nodes: usize,
        stakes: &HashMap<Pubkey, u64>,
    ) -> impl Iterator<Item = Pubkey> {
        match self.0.peek_mut(&origin) {
            None => None,
            Some(entry) if entry.num_upserts < Self::MIN_NUM_UPSERTS => None,
            Some(entry) => Some(
                std::mem::take(entry)
                    .prune(pubkey, &origin, stake_threshold, min_ingress_nodes, stakes)
                    .filter(move |node| node != &origin),
            ),
        }
        .into_iter()
        .flatten()
    }

    #[cfg(test)]
    fn mock_clone(&self) -> Self {
        let mut cache = LruCache::new(self.0.cap());
        for (&origin, entry) in self.0.iter().rev() {
            cache.put(origin, entry.clone());
        }
        Self(cache)
    }
}

impl ReceivedCacheEntry {
    // Limit how big the cache can get if it is spammed
    // with old messages with random pubkeys.
    const CAPACITY: usize = 50;
    // Threshold for the number of duplicates below which a message
    // is counted as timely towards the node's score.
    const NUM_DUPS_THRESHOLD: usize = 2;

    fn record(&mut self, node: Pubkey, num_dups: usize) {
        if num_dups == 0 {
            self.num_upserts = self.num_upserts.saturating_add(1);
        }
        // If the message has been timely enough, increment the node's score.
        if num_dups < Self::NUM_DUPS_THRESHOLD {
            let score = self.nodes.entry(node).or_default();
            *score = score.saturating_add(1);
        } else if self.nodes.len() < Self::CAPACITY {
            // Ensure that the node is inserted into the cache for later
            // pruning. This intentionally does not negatively impact the
            // node's score, in order to prevent replayed messages with
            // spoofed addresses from forcing a good node to be pruned.
            let _ = self.nodes.entry(node).or_default();
        }
    }
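    // Returns the nodes to prune for this origin: nodes are visited in
    // decreasing order of (score, stake) while accumulating the stake of the
    // nodes retained so far; the first `min_ingress_nodes` are always kept,
    // and further nodes are kept until the retained stake reaches
    // `stake_threshold` times the smaller of this node's and the origin's
    // stake. Everything past that point is yielded for pruning.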
    fn prune(
        self,
        pubkey: &Pubkey, // This node.
        origin: &Pubkey, // CRDS value owner.
        stake_threshold: f64,
        min_ingress_nodes: usize,
        stakes: &HashMap<Pubkey, u64>,
    ) -> impl Iterator<Item = Pubkey> {
        debug_assert!((0.0..=1.0).contains(&stake_threshold));
        debug_assert!(self.num_upserts >= ReceivedCache::MIN_NUM_UPSERTS);
        // Enforce a minimum aggregate ingress stake; see:
        // https://github.com/solana-labs/solana/issues/3214
        let min_ingress_stake = {
            let stake = stakes.get(pubkey).min(stakes.get(origin));
            (stake.copied().unwrap_or_default() as f64 * stake_threshold) as u64
        };
        self.nodes
            .into_iter()
            .map(|(node, score)| {
                let stake = stakes.get(&node).copied().unwrap_or_default();
                (node, score, stake)
            })
            .sorted_unstable_by_key(|&(_, score, stake)| Reverse((score, stake)))
            .scan(0u64, |acc, (node, _score, stake)| {
                let old = *acc;
                *acc = acc.saturating_add(stake);
                Some((node, old))
            })
            .skip(min_ingress_nodes)
            .skip_while(move |&(_, stake)| stake < min_ingress_stake)
            .map(|(node, _stake)| node)
    }
}

#[cfg(test)]
mod tests {
    use {
        super::*,
        std::{collections::HashSet, iter::repeat_with},
    };

    #[test]
    fn test_received_cache() {
        let mut cache = ReceivedCache::new(/*capacity:*/ 100);
        let pubkey = Pubkey::new_unique();
        let origin = Pubkey::new_unique();
        let records = vec![
            vec![3, 1, 7, 5],
            vec![7, 6, 5, 2],
            vec![2, 0, 0, 2],
            vec![3, 5, 0, 6],
            vec![6, 2, 6, 2],
        ];
        let nodes: Vec<_> = repeat_with(Pubkey::new_unique)
            .take(records.len())
            .collect();
        for (node, records) in nodes.iter().zip(records) {
            for (num_dups, k) in records.into_iter().enumerate() {
                for _ in 0..k {
                    cache.record(origin, *node, num_dups);
                }
            }
        }
        assert_eq!(cache.0.get(&origin).unwrap().num_upserts, 21);
        let scores: HashMap<Pubkey, usize> = [
            (nodes[0], 4),
            (nodes[1], 13),
            (nodes[2], 2),
            (nodes[3], 8),
            (nodes[4], 8),
        ]
        .into_iter()
        .collect();
        assert_eq!(cache.0.get(&origin).unwrap().nodes, scores);
        let stakes = [
            (nodes[0], 6),
            (nodes[1], 1),
            (nodes[2], 5),
            (nodes[3], 3),
            (nodes[4], 7),
            (pubkey, 9),
            (origin, 9),
        ]
        .into_iter()
        .collect();
        let prunes: HashSet<Pubkey> = [nodes[0], nodes[2], nodes[3]].into_iter().collect();
        assert_eq!(
            cache
                .mock_clone()
                .prune(
                    &pubkey,
                    origin,
                    /*stake_threshold:*/ 0.5,
                    /*min_ingress_nodes:*/ 2,
                    &stakes,
                )
                .collect::<HashSet<_>>(),
            prunes
        );
        let prunes: HashSet<Pubkey> = [nodes[0], nodes[2]].into_iter().collect();
        assert_eq!(
            cache
                .prune(
                    &pubkey,
                    origin,
                    /*stake_threshold:*/ 1.0,
                    /*min_ingress_nodes:*/ 0,
                    &stakes,
                )
                .collect::<HashSet<_>>(),
            prunes
        );
    }
}
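// A minimal illustrative sketch, not part of the original test suite: it
// demonstrates the scoring rules of `ReceivedCacheEntry::record` described
// above. The module and test names here are hypothetical.
#[cfg(test)]
mod record_scoring_sketch {
    use super::*;

    #[test]
    fn test_record_scoring() {
        let mut entry = ReceivedCacheEntry::default();
        let node = Pubkey::new_unique();
        // First copy of a message (num_dups == 0) counts as an upsert and,
        // being timely, also earns a point.
        entry.record(node, /*num_dups:*/ 0);
        // Second copy (num_dups == 1) is still under NUM_DUPS_THRESHOLD and
        // earns another point, but is not an upsert.
        entry.record(node, /*num_dups:*/ 1);
        // A late copy (num_dups == 2) only ensures the node is tracked for
        // later pruning; the score is unchanged.
        entry.record(node, /*num_dups:*/ 2);
        assert_eq!(entry.num_upserts, 1);
        assert_eq!(entry.nodes[&node], 2);
    }
}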