diff --git a/CONTRIBUTING b/CONTRIBUTING
new file mode 100644
index 0000000..68c1a8c
--- /dev/null
+++ b/CONTRIBUTING
@@ -0,0 +1,6 @@
+Thanks for contributing to this project.
+
+To get started, check out the [issues page](https://github.com/inflectrix/neat). You can either find a feature/fix from there or start a new issue, then begin implementing it in your own fork of this repo.
+
+Once you are done making the changes you'd like the make, start a pull request to the [dev](https://github.com/inflectrix/neat/tree/dev) branch. State your changes and request a review. After all branch rules have been satisfied, someone with management permissions on this repository will merge it.
+
diff --git a/Cargo.lock b/Cargo.lock
index 9980c49..e77a241 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -92,6 +92,12 @@ dependencies = [
"wasi",
]
+[[package]]
+name = "itoa"
+version = "1.0.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
+
[[package]]
name = "libc"
version = "0.2.153"
@@ -108,6 +114,7 @@ dependencies = [
"rayon",
"serde",
"serde-big-array",
+ "serde_json",
]
[[package]]
@@ -190,6 +197,12 @@ version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3a8614ee435691de62bcffcf4a66d91b3594bf1428a5722e79103249a095690"
+[[package]]
+name = "ryu"
+version = "1.0.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1"
+
[[package]]
name = "serde"
version = "1.0.197"
@@ -219,6 +232,17 @@ dependencies = [
"syn",
]
+[[package]]
+name = "serde_json"
+version = "1.0.114"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0"
+dependencies = [
+ "itoa",
+ "ryu",
+ "serde",
+]
+
[[package]]
name = "syn"
version = "2.0.51"
diff --git a/Cargo.toml b/Cargo.toml
index fa5f9bc..285c041 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -33,4 +33,5 @@ serde = { version = "1.0.197", features = ["derive"], optional = true }
serde-big-array = { version = "0.5.1", optional = true }
[dev-dependencies]
-bincode = "1.3.3"
\ No newline at end of file
+bincode = "1.3.3"
+serde_json = "1.0.114"
\ No newline at end of file
diff --git a/README.md b/README.md
index 2502fae..ad775e2 100644
--- a/README.md
+++ b/README.md
@@ -3,12 +3,14 @@
[](https://crates.io/crates/neat)
[](https://docs.rs/neat)
-Implementation of the NEAT algorithm using `genetic-rs`
+Implementation of the NEAT algorithm using `genetic-rs`.
### Features
- rayon - Uses parallelization on the `NeuralNetwork` struct and adds the `rayon` feature to the `genetic-rs` re-export.
- serde - Adds the NNTSerde struct and allows for serialization of `NeuralNetworkTopology`
-- crossover - Implements the `CrossoverReproduction` trait on `NeuralNetworkTopology` and adds the `crossover` feature to the `genetic-rs re-export.
+- crossover - Implements the `CrossoverReproduction` trait on `NeuralNetworkTopology` and adds the `crossover` feature to the `genetic-rs` re-export.
+
+*Do you like this repo and want to support it? If so, leave a ⭐*
### How To Use
When working with this crate, you'll want to use the `NeuralNetworkTopology` struct in your agent's DNA and
@@ -21,36 +23,55 @@ use neat::*;
#[derive(Clone, RandomlyMutable, DivisionReproduction)]
struct MyAgentDNA {
network: NeuralNetworkTopology<1, 2>,
- other_stuff: Foo,
}
impl GenerateRandom for MyAgentDNA {
fn gen_random(rng: &mut impl rand::Rng) -> Self {
Self {
network: NeuralNetworkTopology::new(0.01, 3, rng),
- other_stuff: Foo::gen_random(rng),
}
}
}
struct MyAgent {
network: NeuralNetwork<1, 2>,
- some_other_state: Bar,
+ // ... other state
}
impl From<&MyAgentDNA> for MyAgent {
fn from(value: &MyAgentDNA) -> Self {
Self {
network: NeuralNetwork::from(&value.network),
- some_other_state: Bar::default(),
}
}
}
fn fitness(dna: &MyAgentDNA) -> f32 {
+ // agent will simply try to predict whether a number is greater than 0.5
let mut agent = MyAgent::from(dna);
+ let mut rng = rand::thread_rng();
+ let mut fitness = 0;
+
+ // use repeated tests to avoid situational bias and some local maximums, overall providing more accurate score
+ for _ in 0..10 {
+ let n = rng.gen::();
+ let above = n > 0.5;
+
+ let res = agent.network.predict([n]);
+ let resi = res.iter().max_index();
+
+ if resi == 0 ^ above {
+ // agent did not guess correctly, punish slightly (too much will hinder exploration)
+ fitness -= 0.5;
+
+ continue;
+ }
+
+ // agent guessed correctly, they become more fit.
+ fitness += 3.;
+ }
- // ... use agent.network.predict() and agent.network.flush() throughout multiple iterations
+ fitness
}
fn main() {
@@ -62,7 +83,18 @@ fn main() {
division_pruning_nextgen,
);
- // ... simulate generations, etc.
+ // simulate 100 generations
+ for _ in 0..100 {
+ sim.next_generation();
+ }
+
+ // display fitness results
+ let fits: Vec<_> = sim.entities
+ .iter()
+ .map(fitness)
+ .collect();
+
+ dbg!(&fits, fits.iter().max());
}
```
diff --git a/examples/basic.rs b/examples/basic.rs
index 32e99e6..bcd5d6d 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -143,14 +143,17 @@ fn main() {
sim.next_generation();
}
- let fits: Vec<_> = sim.genomes.iter().map(fitness).collect();
+ let mut fits: Vec<_> = sim.genomes.iter().map(|e| (e, fitness(e))).collect();
- let maxfit = fits
- .iter()
- .max_by(|a, b| a.partial_cmp(b).unwrap())
- .unwrap();
+ fits.sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap());
- dbg!(&fits, maxfit);
+ dbg!(&fits);
+
+ if cfg!(feature = "serde") {
+ let intermediate = NNTSerde::from(&fits[0].0.network);
+ let serialized = serde_json::to_string(&intermediate).unwrap();
+ println!("{}", serialized);
+ }
}
#[cfg(all(feature = "crossover", feature = "rayon"))]
@@ -161,12 +164,15 @@ fn main() {
sim.next_generation();
}
- let fits: Vec<_> = sim.genomes.iter().map(fitness).collect();
+ let mut fits: Vec<_> = sim.genomes.iter().map(|e| (e, fitness(e))).collect();
- let maxfit = fits
- .iter()
- .max_by(|a, b| a.partial_cmp(b).unwrap())
- .unwrap();
+ fits.sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap());
- dbg!(&fits, maxfit);
+ dbg!(&fits);
+
+ if cfg!(feature = "serde") {
+ let intermediate = NNTSerde::from(&fits[0].0.network);
+ let serialized = serde_json::to_string(&intermediate).unwrap();
+ println!("serialized: {}", serialized);
+ }
}