use bitflags::bitflags; use clap::{value_t, value_t_or_exit, App, Arg}; use forrustts::ForrusttsError; use forrustts::IdType; use forrustts::Position; use forrustts::Time; use rand::rngs::StdRng; use rand::Rng; use rand::SeedableRng; use rand_distr::{Exp, Uniform}; // Some of the material below seems like a candidate for a public API, // but we need to decide here if this package should provide that. // If so, then many of these types should not be here, as they have nothing // to do with Wright-Fisher itself, and are instead more general. // Even though Position is an integer, we will use // an exponential distribution to get the distance to // the next crossover position. The reason for this is // that rand_distr::Geometric has really poor performance. type BreakpointFunction = Option>; #[derive(Copy, Clone)] struct Parent { index: usize, node0: IdType, node1: IdType, } struct Birth { index: usize, parent0: Parent, parent1: Parent, } type VecParent = Vec; type VecBirth = Vec; struct PopulationState { pub parents: VecParent, pub births: VecBirth, pub edge_buffer: forrustts::EdgeBuffer, pub tables: forrustts::TableCollection, } impl PopulationState { pub fn new(genome_length: Position) -> Self { PopulationState { parents: vec![], births: vec![], edge_buffer: forrustts::EdgeBuffer::new(), tables: forrustts::TableCollection::new(genome_length).unwrap(), } } } fn deaths_and_parents(psurvival: f64, rng: &mut StdRng, pop: &mut PopulationState) { pop.births.clear(); let random_parents = Uniform::new(0_usize, pop.parents.len() as usize); for i in 0..pop.parents.len() { let x: f64 = rng.gen(); match x.partial_cmp(&psurvival) { Some(std::cmp::Ordering::Greater) => { let parent0 = pop.parents[rng.sample(random_parents)]; let parent1 = pop.parents[rng.sample(random_parents)]; pop.births.push(Birth { index: i, parent0, parent1, }); } Some(_) => (), None => (), } } } fn mendel(pnodes: &mut (tskit::tsk_id_t, tskit::tsk_id_t), rng: &mut StdRng) { let x: f64 = rng.gen(); match x.partial_cmp(&0.5) { Some(std::cmp::Ordering::Less) => { std::mem::swap(&mut pnodes.0, &mut pnodes.1); } Some(_) => (), None => panic!("Unexpected None"), } } fn crossover_and_record_edges( parent: Parent, child: IdType, breakpoint: BreakpointFunction, recorder: &impl Fn( IdType, IdType, (Position, Position), &mut forrustts::TableCollection, &mut forrustts::EdgeBuffer, ), rng: &mut StdRng, tables: &mut forrustts::TableCollection, edge_buffer: &mut forrustts::EdgeBuffer, ) { let mut pnodes = (parent.node0, parent.node1); mendel(&mut pnodes, rng); let mut p0 = parent.node0; let mut p1 = parent.node1; if let Some(exp) = breakpoint { let mut current_pos: Position = 0; loop { // TODO: gotta justify the next line... let next_length = (rng.sample(exp) as Position) + 1; assert!(next_length > 0); if current_pos + next_length < tables.genome_length() { recorder( p0, child, (current_pos, current_pos + next_length), tables, edge_buffer, ); current_pos += next_length; std::mem::swap(&mut p0, &mut p1); } else { recorder( p0, child, (current_pos, tables.genome_length()), tables, edge_buffer, ); break; } } } else { recorder(p0, child, (0, tables.genome_length()), tables, edge_buffer); } } fn generate_births( breakpoint: BreakpointFunction, birth_time: Time, rng: &mut StdRng, pop: &mut PopulationState, recorder: &impl Fn( IdType, IdType, (Position, Position), &mut forrustts::TableCollection, &mut forrustts::EdgeBuffer, ), ) { for b in &pop.births { // Record 2 new nodes let new_node_0: IdType = pop.tables.add_node(birth_time, 0).unwrap(); let new_node_1: IdType = pop.tables.add_node(birth_time, 0).unwrap(); crossover_and_record_edges( b.parent0, new_node_0, breakpoint, recorder, rng, &mut pop.tables, &mut pop.edge_buffer, ); crossover_and_record_edges( b.parent1, new_node_1, breakpoint, recorder, rng, &mut pop.tables, &mut pop.edge_buffer, ); pop.parents[b.index].index = b.index; pop.parents[b.index].node0 = new_node_0; pop.parents[b.index].node1 = new_node_1; } } fn buffer_edges( parent: IdType, child: IdType, span: (Position, Position), _: &mut forrustts::TableCollection, buffer: &mut forrustts::EdgeBuffer, ) { buffer .extend(parent, forrustts::Segment::new(span.0, span.1, child)) .unwrap(); } fn record_edges( parent: IdType, child: IdType, span: (Position, Position), tables: &mut forrustts::TableCollection, _: &mut forrustts::EdgeBuffer, ) { tables.add_edge(span.0, span.1, parent, child).unwrap(); } fn fill_samples(parents: &[Parent], samples: &mut forrustts::SamplesInfo) { samples.samples.clear(); for p in parents { samples.samples.push(p.node0); samples.samples.push(p.node1); } } fn sort_and_simplify( flags: SimulationFlags, simplification_flags: forrustts::SimplificationFlags, samples: &forrustts::SamplesInfo, state: &mut forrustts::SimplificationBuffers, pop: &mut PopulationState, output: &mut forrustts::SimplificationOutput, ) { if !flags.contains(SimulationFlags::BUFFER_EDGES) { pop.tables .sort_tables(forrustts::TableSortingFlags::empty()); if flags.contains(SimulationFlags::USE_STATE) { forrustts::simplify_tables( samples, simplification_flags, state, &mut pop.tables, output, ) .unwrap(); } else { forrustts::simplify_tables_without_state( samples, simplification_flags, &mut pop.tables, output, ) .unwrap(); } } else { forrustts::simplify_from_edge_buffer( samples, simplification_flags, state, &mut pop.edge_buffer, &mut pop.tables, output, ) .unwrap(); } } fn simplify_and_remap_nodes( flags: SimulationFlags, simplification_flags: forrustts::SimplificationFlags, samples: &mut forrustts::SamplesInfo, state: &mut forrustts::SimplificationBuffers, pop: &mut PopulationState, output: &mut forrustts::SimplificationOutput, ) { fill_samples(&pop.parents, samples); sort_and_simplify(flags, simplification_flags, samples, state, pop, output); for p in &mut pop.parents { p.node0 = output.idmap[p.node0 as usize]; p.node1 = output.idmap[p.node1 as usize]; } if flags.contains(SimulationFlags::BUFFER_EDGES) { samples.edge_buffer_founder_nodes.clear(); for p in &pop.parents { samples.edge_buffer_founder_nodes.push(p.node0); samples.edge_buffer_founder_nodes.push(p.node1); } } } fn validate_simplification_interval(x: Time) -> Time { if x < 1 { panic!("simplification_interval must be None or >= 1"); } x } pub struct PopulationParams { pub size: u32, pub genome_length: Position, pub breakpoint: BreakpointFunction, pub psurvival: f64, } impl PopulationParams { pub fn new(size: u32, genome_length: Position, xovers: f64, psurvival: f64) -> Self { PopulationParams { size, genome_length, breakpoint: match xovers.partial_cmp(&0.0) { Some(std::cmp::Ordering::Greater) => { Some(Exp::new(xovers / genome_length as f64).unwrap()) } Some(_) => None, None => None, }, psurvival, } } } bitflags! { #[derive(Default)] pub struct SimulationFlags: u32 { // If set, and BUFFER_EDGES is not set, // then simplification will use a reusable set // of buffers for each call. Otherwise, // these buffers will be allocated each time // simplification happens. const USE_STATE = 1 << 0; // If set, edge buffering will be used. // If not set, then the standard "record // and sort" method will be used. const BUFFER_EDGES = 1 << 1; } } pub struct SimulationParams { pub simplification_interval: Option