/*
* Copyright (C) 2021 jessa0
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#![allow(dead_code)]
use rand_core::{RngCore, SeedableRng};
use rand_chacha::ChaChaRng;
use simple_raft::core::RaftState;
use simple_raft::log::mem::RaftLogMemory;
use simple_raft::message::{LogEntry, RaftMessage, RaftMessageDestination, Rpc, SendableRaftMessage, TermId};
use simple_raft::node::RaftConfig;
use std::cell::RefCell;
use std::collections::{BTreeSet, VecDeque};
/// Raft configuration shared by every node that `raft` constructs.
// NOTE(review): the heartbeat interval is one tick shorter than the election
// timeout — presumably so a live leader's heartbeats reach followers before
// they time out; confirm against the simple_raft `RaftConfig` documentation.
pub const CONFIG: RaftConfig = RaftConfig {
election_timeout_ticks: 10,
heartbeat_interval_ticks: 9,
replication_chunk_size: 1024,
};
/// Fixed seed that `init_random` feeds to the ChaCha generator.
const RANDOM_SEED: u64 = 0;
/// Maximum number of ticks `TestRaftGroup::run_until` simulates before it panics.
const MAX_TICKS: u32 = 100_000;
pub type TestRaft = RaftState;
pub struct TestRaftGroup {
pub nodes: Vec,
pub tick: u32,
pub config: TestRaftGroupConfig,
pub dropped_messages: Vec<(NodeId, SendableRaftMessage)>,
}
#[derive(Clone, Default)]
pub struct TestRaftGroupConfig {
pub drops: BTreeSet<(Option, Option)>,
pub down: BTreeSet,
}
/// Identifier of a node in the test cluster, wrapping a plain `u64`.
///
/// The derived `Display` implementation is configured to reuse the `Debug`
/// formatting of the value.
#[derive(Clone, Copy, Debug, derive_more::Display, Eq, derive_more::From, PartialEq, PartialOrd, Ord)]
#[display(fmt = "{:?}", self)]
pub struct NodeId(u64);
/// Unit type implementing `log::Log`; it prints each record to stderr,
/// prefixed with the current thread's tick and node id when those are set.
pub struct TestLogger;
pub struct TestLoggerContext {
node_id: Option,
tick: Option,
}
/// Returns one default-initialised instance of each of the four `Rpc`
/// variants, in the order: vote request, vote response, append request,
/// append response.
pub fn rpc_types() -> [Rpc; 4] {
    let vote_request = Rpc::VoteRequest(Default::default());
    let vote_response = Rpc::VoteResponse(Default::default());
    let append_request = Rpc::AppendRequest(Default::default());
    let append_response = Rpc::AppendResponse(Default::default());
    [vote_request, vote_response, append_request, append_response]
}
pub fn init_random() -> ChaChaRng {
ChaChaRng::seed_from_u64(RANDOM_SEED)
}
pub fn raft(node_id: u64, peers: Vec, log: Option, random: &mut impl RngCore) -> TestRaft {
TestLogger::init();
RaftState::new(
NodeId(node_id),
peers.into_iter().map(NodeId).collect(),
log.unwrap_or_else(|| RaftLogMemory::new_unbounded()),
ChaChaRng::seed_from_u64(random.next_u64()),
CONFIG,
)
}
pub fn config() -> TestRaftGroupConfig {
TestRaftGroupConfig::default()
}
pub fn send(raft: &mut TestRaft, from: u64, term: TermId, rpc: Rpc) -> Option> {
raft.receive(RaftMessage {
term,
rpc: Some(rpc),
}, NodeId(from))
}
pub fn append_entries<'a>(node: &'a mut TestRaft, peers: impl IntoIterator- + 'a) -> impl Iterator
- > + 'a {
let node_id = *node.node_id();
peers.into_iter().flat_map(move |append_to_node_id| {
if append_to_node_id != node_id {
node.append_entries(append_to_node_id)
} else {
None
}
})
}
pub fn run_group<'a>(
nodes: impl Iterator
- + ExactSizeIterator,
initial_messages: impl IntoIterator
- )>,
start_tick: u32,
ticks: Option,
config: &mut TestRaftGroupConfig,
dropped_messages: &mut Vec<(NodeId, SendableRaftMessage)>,
) {
let mut nodes: Vec<_> = nodes.collect();
let node_ids: Vec<_> = nodes.iter().map(|node| *node.node_id()).collect();
let mut messages = VecDeque::with_capacity(nodes.len() * nodes.len());
messages.extend(initial_messages.into_iter());
messages.extend(dropped_messages.drain(..));
for tick in 0..ticks.unwrap_or(1) {
TestLogger::set_tick(Some(start_tick + tick));
if ticks.is_some() {
for node in &mut nodes {
let node_id = *node.node_id();
if !config.is_node_down(node_id) {
TestLogger::set_node_id(Some(node_id));
messages.extend(node.timer_tick().map(|message| (node_id, message)));
messages.extend(append_entries(node, node_ids.iter().cloned()).map(|message| (node_id, message)));
}
}
}
while let Some((from, sendable)) = messages.pop_front() {
let (reply_to_node_id, to_node_count) = match sendable.dest {
RaftMessageDestination::Broadcast => (None, nodes.len().saturating_sub(1)),
RaftMessageDestination::To(to) => (Some(to), 1),
};
let to_nodes = nodes.iter_mut().filter(|node| match &reply_to_node_id {
Some(to_node_id) => node.node_id() == to_node_id,
None => node.node_id() != &from,
});
for (to_node, message) in Iterator::zip(to_nodes, itertools::repeat_n(sendable.message, to_node_count)) {
let to_node_id = *to_node.node_id();
TestLogger::set_node_id(Some(to_node_id));
if !config.should_drop(from, to_node_id) {
log::info!("<- {} {}", from, message);
messages.extend(to_node.receive(message, from).map(|message| (to_node_id, message)));
} else {
log::info!("<- {} DROPPED {}", from, message);
if let Some(reply_to_node_id) = reply_to_node_id {
dropped_messages.push((from, SendableRaftMessage { message, dest: RaftMessageDestination::To(reply_to_node_id) }));
}
}
messages.extend(append_entries(to_node, node_ids.iter().cloned()).map(|message| (to_node_id, message)));
}
}
}
TestLogger::set_tick(None);
TestLogger::set_node_id(None);
}
//
// TestRaftGroup impls
//
impl TestRaftGroup {
pub fn new(size: u64, random: &mut impl RngCore, config: TestRaftGroupConfig) -> Self {
let nodes: Vec = (0..size).collect();
Self {
nodes: nodes.iter().map(|node_id| raft(*node_id, nodes.clone(), None, random)).collect(),
tick: 0,
config,
dropped_messages: Default::default(),
}
}
pub fn run_until(&mut self, mut until_fun: impl FnMut(&mut Self) -> bool) -> &mut Self {
let mut ticks_remaining = MAX_TICKS;
while !until_fun(self) {
ticks_remaining = ticks_remaining.checked_sub(1).expect("condition failed after maximum simulation length");
self.tick += 1;
run_group(self.nodes.iter_mut(), None, self.tick, Some(1), &mut self.config, &mut self.dropped_messages);
}
self
}
pub fn run_until_commit(&mut self, mut until_fun: impl FnMut(&LogEntry) -> bool) -> &mut Self {
self.run_until(|group| {
let result = group.take_committed().any(|commit| !commit.data.is_empty() && until_fun(&commit));
group.take_committed().for_each(drop);
result
})
}
pub fn run_for(&mut self, ticks: u32) -> &mut Self {
self.run_for_inspect(ticks, |_| ())
}
pub fn run_for_inspect(&mut self, ticks: u32, mut fun: impl FnMut(&mut Self)) -> &mut Self {
let mut ticks_remaining = ticks;
while let Some(new_ticks_remaining) = ticks_remaining.checked_sub(1) {
ticks_remaining = new_ticks_remaining;
self.tick += 1;
run_group(self.nodes.iter_mut(), None, self.tick, Some(1), &mut self.config, &mut self.dropped_messages);
fun(self);
}
self
}
pub fn run_on_all(
&mut self,
mut fun: impl FnMut(&mut TestRaft) -> Option>,
) -> &mut Self {
let messages = self.nodes.iter_mut().flat_map(|node| fun(node).map(|message| (*node.node_id(), message))).collect::>();
run_group(self.nodes.iter_mut(), messages, self.tick, None, &mut self.config, &mut self.dropped_messages);
self
}
pub fn run_on_node(
&mut self,
node_idx: usize,
fun: impl FnOnce(&mut TestRaft) -> Option>,
) -> &mut Self {
let node_id = *self.nodes[node_idx].node_id();
let messages = fun(&mut self.nodes[node_idx]).map(|message| (node_id, message));
run_group(self.nodes.iter_mut(), messages, self.tick, None, &mut self.config, &mut self.dropped_messages);
self
}
pub fn inspect(&mut self, fun: impl FnOnce(&Self)) -> &mut Self {
fun(self);
self
}
pub fn modify(&mut self, fun: impl FnOnce(&mut Self)) -> &mut Self {
fun(self);
self
}
pub fn take_committed(&mut self) -> impl Iterator
- + '_ {
self.nodes.iter_mut().flat_map(|node| node.take_committed())
}
pub fn has_leader(&self) -> bool {
self.nodes.iter().any(|node| node.is_leader())
}
}
//
// TestRaftGroupConfig impls
//
impl TestRaftGroupConfig {
    /// Marks `node_id` as down.
    pub fn node_down(mut self, node_id: u64) -> Self {
        let id = NodeId(node_id);
        self.down.insert(id);
        self
    }

    /// Drops every message sent by or addressed to `node_id`.
    pub fn isolate(mut self, node_id: u64) -> Self {
        let isolated = Some(NodeId(node_id));
        self.drops.insert((isolated, None));
        self.drops.insert((None, isolated));
        self
    }

    /// Drops messages between `from` and `to`, in both directions.
    pub fn drop_between(mut self, from: u64, to: u64) -> Self {
        let (first, second) = (Some(NodeId(from)), Some(NodeId(to)));
        self.drops.insert((first, second));
        self.drops.insert((second, first));
        self
    }

    /// Drops every message addressed to `node_id`, whoever sent it.
    pub fn drop_to(mut self, node_id: u64) -> Self {
        let target = Some(NodeId(node_id));
        self.drops.insert((None, target));
        self
    }

    /// Returns `true` if `node_id` has been marked as down.
    pub fn is_node_down(&self, node_id: NodeId) -> bool {
        self.down.contains(&node_id)
    }

    /// Returns `true` if a message from `from` to `to` must not be delivered:
    /// either endpoint is down, or a drop rule (exact pair, any-destination,
    /// or any-sender) matches.
    pub fn should_drop(&self, from: NodeId, to: NodeId) -> bool {
        if self.is_node_down(from) || self.is_node_down(to) {
            return true;
        }
        let matching_rules = [
            (Some(from), Some(to)),
            (Some(from), None),
            (None, Some(to)),
        ];
        matching_rules.iter().any(|rule| self.drops.contains(rule))
    }
}
//
// TestLogger impls
//
thread_local! {
static LOGGER_CONTEXT: RefCell = RefCell::new(TestLoggerContext::new());
}
impl TestLogger {
pub fn init() {
let _ignore = log::set_logger(&Self);
log::set_max_level(log::LevelFilter::Debug);
}
pub fn set_node_id(node_id: Option) {
LOGGER_CONTEXT.with(|context| {
context.borrow_mut().node_id = node_id;
});
}
pub fn set_tick(tick: Option) {
LOGGER_CONTEXT.with(|context| {
context.borrow_mut().tick = tick;
});
}
}
impl log::Log for TestLogger {
    /// Every record is accepted, whatever its metadata.
    fn enabled(&self, _metadata: &log::Metadata) -> bool {
        true
    }

    /// Writes the record to stderr. With a node id set the line is prefixed
    /// with the tick (or `???` when no tick is set) and the node id; without
    /// a node id only the message itself is printed.
    fn log(&self, record: &log::Record) {
        LOGGER_CONTEXT.with(|cell| {
            let context = cell.borrow();
            match (context.node_id, context.tick) {
                (Some(node_id), Some(tick)) => {
                    eprintln!("tick {:03} {} {}", tick, node_id, record.args())
                }
                (Some(node_id), None) => {
                    eprintln!("tick ??? {} {}", node_id, record.args())
                }
                (None, _) => eprintln!("{}", record.args()),
            }
        })
    }

    /// Nothing is buffered, so there is nothing to flush.
    fn flush(&self) {}
}
//
// TestLoggerContext impls
//
impl TestLoggerContext {
const fn new() -> Self {
Self {
node_id: None,
tick: None,
}
}
}