use rsneat::genome::Genome; // use rsneat::network::Network; use rsneat::population::Population; use rsneat::Neat; #[test] fn xor_avg() { let mut total = 0; let mut count = 0; let mut less_count = 0; let less_limit = 300; for i in 0..100 { println!("Run: {}", i + 1); let result = xor(); if result < less_limit { less_count += 1; } total += result; count += 1; } println!( "Average for {} runs: {}", count, total as f64 / count as f64 ); println!( "Runs with less than {} gens to solve: {}", less_limit, less_count ); } fn xor() -> usize { let mut neat = Neat::new(); let mut founder = Genome::new(3, 1); founder.add_connection(0, 3, 0.0, &mut neat); founder.add_connection(1, 3, 0.0, &mut neat); founder.add_connection(2, 3, 0.0, &mut neat); // println!("{:#?}", founder); let mut pop = Population::clone_from(founder, &neat); loop { let mut correct = true; for (mut n, fitness) in pop.iter_fitness() { correct = true; let data = [ (vec![0.0, 0.0, 1.0], 0.0, false), (vec![0.0, 1.0, 1.0], 1.0, true), (vec![1.0, 0.0, 1.0], 1.0, true), (vec![1.0, 1.0, 1.0], 0.0, false), ]; let mut error = 0.0; *fitness = 0.0; for (input, output, c) in data.iter() { for _ in 0..5 { let _ = n.activate(input.clone().into_iter()); } let result = n.activate(input.clone().into_iter()).next().unwrap(); let result_error = (result - output).abs(); error += result_error; let result = result >= 0.5; if *c != result { correct = false; } n.reset(); } *fitness += 4.0 - error; *fitness *= *fitness; if correct { break; } } if correct { pop.champ().write(&mut std::fs::File::create("champ.neat").unwrap()).unwrap(); println!("Got a champ after {} gens!", pop.gen()); return pop.gen(); } pop.evaluate_generation(&mut neat); } }