From b95084dd4d615c3a685eebca83df4c16e7133910 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Thu, 18 Apr 2024 14:53:59 +0000 Subject: [PATCH 1/2] create custom activations example --- examples/custom_activation.rs | 92 +++++++++++++++++++++++++++++++++++ src/topology/activation.rs | 4 +- 2 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 examples/custom_activation.rs diff --git a/examples/custom_activation.rs b/examples/custom_activation.rs new file mode 100644 index 0000000..f52882b --- /dev/null +++ b/examples/custom_activation.rs @@ -0,0 +1,92 @@ +//! An example implementation of a custom activation function. + +use neat::*; +use rand::prelude::*; + +#[derive(DivisionReproduction, RandomlyMutable, Clone)] +struct AgentDNA { + network: NeuralNetworkTopology<2, 2>, +} + +impl Prunable for AgentDNA {} + +impl GenerateRandom for AgentDNA { + fn gen_random(rng: &mut impl Rng) -> Self { + Self { + network: NeuralNetworkTopology::new(0.01, 3, rng), + } + } +} + +fn fitness(g: &AgentDNA) -> f32 { + let network: NeuralNetwork<2, 2> = NeuralNetwork::from(&g.network); + let mut fitness = 0.; + let mut rng = rand::thread_rng(); + + for _ in 0..50 { + let n = rng.gen::(); + let n2 = rng.gen::(); + + let expected = if (n + n2) / 2. >= 0.5 { + 0 + } else { + 1 + }; + + let result = network.predict([n, n2]); + network.flush_state(); + + // partial_cmp chance of returning None in this smh + let result = result.iter().max_index(); + + if result == expected { + fitness += 1.; + } else { + fitness -= 1.; + } + } + + fitness +} + +#[cfg(feature = "serde")] +fn serde_nextgen(rewards: Vec<(AgentDNA, f32)>) -> Vec { + let max = rewards + .iter() + .max_by(|(_, ra), (_, rb)| ra.total_cmp(rb)) + .unwrap(); + + let ser = NNTSerde::from(&max.0.network); + let data = serde_json::to_string_pretty(&ser).unwrap(); + std::fs::write("best-agent.json", data).expect("Failed to write to file"); + + division_pruning_nextgen(rewards) +} + +fn main() { + let log_activation = activation_fn!(f32::log10); + register_activation(log_activation); + + #[cfg(not(feature = "rayon"))] + let mut rng = rand::thread_rng(); + + let mut sim = GeneticSim::new( + #[cfg(not(feature = "rayon"))] + Vec::gen_random(&mut rng, 100), + + #[cfg(feature = "rayon")] + Vec::gen_random(100), + + fitness, + + #[cfg(not(feature = "serde"))] + division_pruning_nextgen, + + #[cfg(feature = "serde")] + serde_nextgen, + ); + + for _ in 0..200 { + sim.next_generation(); + } +} \ No newline at end of file diff --git a/src/topology/activation.rs b/src/topology/activation.rs index a711851..5bf9540 100644 --- a/src/topology/activation.rs +++ b/src/topology/activation.rs @@ -15,11 +15,11 @@ use crate::NeuronLocation; #[macro_export] macro_rules! activation_fn { ($F: path) => { - ActivationFn::new(Arc::new($F), ActivationScope::default(), stringify!($F).into()) + ActivationFn::new(std::sync::Arc::new($F), ActivationScope::default(), stringify!($F).into()) }; ($F: path, $S: expr) => { - ActivationFn::new(Arc::new($F), $S, stringify!($F).into()) + ActivationFn::new(std::sync::Arc::new($F), $S, stringify!($F).into()) }; {$($F: path),*} => { From 7c31f30f88bcd0221e628daa8b0d0530f5c46523 Mon Sep 17 00:00:00 2001 From: Tristan Murphy Date: Thu, 13 Jun 2024 10:54:22 -0400 Subject: [PATCH 2/2] cargo fmt --- examples/custom_activation.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/examples/custom_activation.rs b/examples/custom_activation.rs index f52882b..bc6aae2 100644 --- a/examples/custom_activation.rs +++ b/examples/custom_activation.rs @@ -27,11 +27,7 @@ fn fitness(g: &AgentDNA) -> f32 { let n = rng.gen::(); let n2 = rng.gen::(); - let expected = if (n + n2) / 2. >= 0.5 { - 0 - } else { - 1 - }; + let expected = if (n + n2) / 2. >= 0.5 { 0 } else { 1 }; let result = network.predict([n, n2]); network.flush_state(); @@ -73,15 +69,11 @@ fn main() { let mut sim = GeneticSim::new( #[cfg(not(feature = "rayon"))] Vec::gen_random(&mut rng, 100), - #[cfg(feature = "rayon")] Vec::gen_random(100), - fitness, - #[cfg(not(feature = "serde"))] division_pruning_nextgen, - #[cfg(feature = "serde")] serde_nextgen, ); @@ -89,4 +81,4 @@ fn main() { for _ in 0..200 { sim.next_generation(); } -} \ No newline at end of file +}