use clap::Parser; use rand::prelude::*; use triadic_memory::{ encoders::{ScalarEncoder, ScalarFloatEncoder}, sdr, sequence_mem::SequenceMemory, }; #[derive(Debug, Parser)] #[clap(author, version, about)] struct Cli { /// Size of the SDRs #[clap(long, short = 'N', default_value_t = sdr::SIZE)] size: usize, /// Number of active connections in the SDRs #[clap(long, short = 'P', default_value_t = sdr::POP)] pop: usize, /// Length of stored sequence #[clap(long, short, default_value_t = 100)] length: usize, /// How many items to feed each sequence in before trying to recall #[clap(long, short = 'p', default_value_t = 2)] pre_roll: usize, /// Check recall on the first ten SDRs at the end of the test. #[clap(long, action)] nostalgia: bool, /// If set, encode a linear sequence of values apart (cannot be used with '--min' or '--max') #[clap(long, action)] step: Option, /// Resolution of the encoder #[clap(long, short, default_value_t = 0.1)] resolution: f64, /// Lower bound for random values in a sequence (cannot be used with '--step') #[clap(long, conflicts_with = "step", default_value_t = 0.0)] min: f64, /// Upper bound for random values in a sequence (cannot be used with '--step') #[clap(long, conflicts_with = "step", default_value_t = 10.0)] max: f64, } fn main() { let cli = Cli::parse(); let size = cli.size; let pop = cli.pop; let len = cli.length; let pre_roll = cli.pre_roll; let resolution = cli.resolution; let mut rng = StdRng::from_entropy(); let mut mem = SequenceMemory::new(size, pop); let enc = ScalarFloatEncoder::new(size, pop, resolution); let mut data = Vec::with_capacity(len); if let Some(step) = cli.step { for i in 0..len { data.push(i as f64 * step); } } else { for _ in 0..len { let v = rng.gen_range(cli.min..cli.max); data.push(v); } } let mut sdrs = Vec::with_capacity(len); for p in data.iter() { let s = enc.encode(*p); sdrs.push(s); } let data = &data; let sdrs = &sdrs; mem.new_seq(); for sdr in sdrs.iter() { // first, add the sequence mem.add_to_sequence(sdr); } // now, check recall mem.new_seq(); for sdr in sdrs[0..pre_roll].iter() { mem.predict_next(sdr); } let mut prev = sdrs[pre_roll].clone(); for i in pre_roll..len { let y = data[i]; let original = &sdrs[i]; let overlap = original.overlap(&prev); let error = pop - overlap; prev = mem.predict_next(&prev); println!("{i} {y} {error}"); } if cli.nostalgia { mem.new_seq(); let len = (len / 10).max(10); for sdr in sdrs[0..pre_roll].iter() { mem.predict_next(sdr); } let mut prev = sdrs[pre_roll].clone(); for (i, sdr) in sdrs[pre_roll..len].iter().enumerate() { let overlap = sdr.overlap(&prev); let error = overlap - pop; prev = mem.predict_next(&prev); println!("NOSTALGIA SDR {i}: error: {error}"); } } }