From 273b9faa5e4e6cbda2d0b877f56de984fe1c99d4 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Tue, 16 Apr 2024 15:23:07 +0000 Subject: [PATCH 01/24] rename CONTRIBUTING --- CONTRIBUTING => CONTRIBUTING.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename CONTRIBUTING => CONTRIBUTING.md (100%) diff --git a/CONTRIBUTING b/CONTRIBUTING.md similarity index 100% rename from CONTRIBUTING rename to CONTRIBUTING.md From 3f50d080e9d248506da5feac73949d9d656210ce Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Tue, 16 Apr 2024 15:23:57 +0000 Subject: [PATCH 02/24] add flush_state to readme example --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index ad775e2..1a97864 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,8 @@ fn fitness(dna: &MyAgentDNA) -> f32 { let above = n > 0.5; let res = agent.network.predict([n]); + agent.network.flush_state(); + let resi = res.iter().max_index(); if resi == 0 ^ above { From a94198a9bc2776a451bca27211d0c72291df5842 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Tue, 16 Apr 2024 16:14:42 +0000 Subject: [PATCH 03/24] create basic log test --- src/lib.rs | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index ee9f769..ac569b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,3 +23,67 @@ pub use topology::*; #[cfg(feature = "serde")] pub use nnt_serde::*; + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + + #[derive(RandomlyMutable, DivisionReproduction, Clone)] + struct AgentDNA { + network: NeuralNetworkTopology<2, 1>, + } + + impl Prunable for AgentDNA {} + + impl GenerateRandom for AgentDNA { + fn gen_random(rng: &mut impl Rng) -> Self { + Self { + network: NeuralNetworkTopology::new(0.01, 3, rng), + } + } + } + + #[test] + fn basic_test() { + let fitness = |g: &AgentDNA| { + let network = NeuralNetwork::from(&g.network); + let mut fitness = 0.; + let mut rng = rand::thread_rng(); + + for _ in 0..100 { + let n = rng.gen::() * 10000.; + let base = rng.gen::() * 10.; + let expected = n.log(base); + + let [answer] = network.predict([n, base]); + network.flush_state(); + + fitness += 5. / (answer - expected).abs(); + } + + fitness + }; + + let mut rng = rand::thread_rng(); + + let mut sim = GeneticSim::new( + Vec::gen_random(&mut rng, 100), + fitness, + division_pruning_nextgen, + ); + + for _ in 0..100 { + sim.next_generation(); + } + + let mut fits: Vec<_> = sim.genomes + .iter() + .map(fitness) + .collect(); + + fits.sort_by(|a, b| a.partial_cmp(&b).unwrap()); + + dbg!(fits); + } +} \ No newline at end of file From 339b90b3c7970a9b88b162db34551d190d947324 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Tue, 16 Apr 2024 18:05:13 +0000 Subject: [PATCH 04/24] create plotters example --- Cargo.lock | 744 ++++++++++++++++++++++++++++++++++++++++++++++- Cargo.toml | 1 + examples/plot.rs | 135 +++++++++ 3 files changed, 879 insertions(+), 1 deletion(-) create mode 100644 examples/plot.rs diff --git a/Cargo.lock b/Cargo.lock index be4d7b8..5c98cd1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,33 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "autocfg" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" + [[package]] name = "bincode" version = "1.3.3" @@ -11,18 +38,135 @@ dependencies = [ "serde", ] +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + +[[package]] +name = "bytemuck" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cc" +version = "1.0.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7" + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-targets", +] + +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + +[[package]] +name = "const-cstr" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed3d0b5ff30645a68f35ece8cea4556ca14ef8a1651455f789a099a0513532a6" + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "core-graphics" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2581bbab3b8ffc6fcbd550bf46c355135d16e9ff2a6ea032ad6b9bf1d7efe4fb" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "core-graphics-types", + "foreign-types", + "libc", +] + +[[package]] +name = "core-graphics-types" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + +[[package]] +name = "core-text" +version = "19.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d74ada66e07c1cefa18f8abfba765b486f250de2e4a999e5727fc0dd4b4a25" +dependencies = [ + "core-foundation", + "core-graphics", + "foreign-types", + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam-deque" version = "0.8.5" @@ -48,12 +192,140 @@ version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +[[package]] +name = "dirs-next" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" +dependencies = [ + "cfg-if", + "dirs-sys-next", +] + +[[package]] +name = "dirs-sys-next" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + +[[package]] +name = "dlib" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "330c60081dcc4c72131f8eb70510f1ac07223e5d4163db481a04a0befcffa412" +dependencies = [ + "libloading", +] + +[[package]] +name = "dwrote" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439a1c2ba5611ad3ed731280541d36d2e9c4ac5e7fb818a27b604bdc5a6aa65b" +dependencies = [ + "lazy_static", + "libc", + "winapi", + "wio", +] + [[package]] name = "either" version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "fdeflate" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645" +dependencies = [ + "simd-adler32", +] + +[[package]] +name = "flate2" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "float-ord" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bad48618fdb549078c333a7a8528acb57af271d0433bdecd523eb620628364e" + +[[package]] +name = "font-kit" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21fe28504d371085fae9ac7a3450f0b289ab71e07c8e57baa3fb68b9e57d6ce5" +dependencies = [ + "bitflags 1.3.2", + "byteorder", + "core-foundation", + "core-graphics", + "core-text", + "dirs-next", + "dwrote", + "float-ord", + "freetype", + "lazy_static", + "libc", + "log", + "pathfinder_geometry", + "pathfinder_simd", + "walkdir", + "winapi", + "yeslogic-fontconfig-sys", +] + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "freetype" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efc8599a3078adf8edeb86c71e9f8fa7d88af5ca31e806a867756081f90f5d83" +dependencies = [ + "freetype-sys", + "libc", +] + +[[package]] +name = "freetype-sys" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66ee28c39a43d89fbed8b4798fb4ba56722cfd2b5af81f9326c27614ba88ecd5" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "genetic-rs" version = "0.5.1" @@ -98,12 +370,74 @@ dependencies = [ "wasi", ] +[[package]] +name = "gif" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80792593675e051cf94a4b111980da2ba60d4a83e43e0048c5693baab3977045" +dependencies = [ + "color_quant", + "weezl", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "image" +version = "0.24.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5690139d2f55868e080017335e4b94cb7414274c74f1669c84fb5feba2c9f69d" +dependencies = [ + "bytemuck", + "byteorder", + "color_quant", + "jpeg-decoder", + "num-traits", + "png", +] + [[package]] name = "itoa" version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "jpeg-decoder" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" + +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -116,14 +450,51 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libloading" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" +dependencies = [ + "cfg-if", + "windows-targets", +] + +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags 2.5.0", + "libc", +] + +[[package]] +name = "log" +version = "0.4.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" + +[[package]] +name = "miniz_oxide" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" +dependencies = [ + "adler", + "simd-adler32", +] + [[package]] name = "neat" version = "0.5.1" dependencies = [ "bincode", - "bitflags", + "bitflags 2.5.0", "genetic-rs", "lazy_static", + "plotters", "rand", "rayon", "serde", @@ -131,6 +502,105 @@ dependencies = [ "serde_json", ] +[[package]] +name = "num-traits" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "pathfinder_geometry" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b7e7b4ea703700ce73ebf128e1450eb69c3a8329199ffbfb9b2a0418e5ad3" +dependencies = [ + "log", + "pathfinder_simd", +] + +[[package]] +name = "pathfinder_simd" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebf45976c56919841273f2a0fc684c28437e2f304e264557d9c72be5d5a718be" +dependencies = [ + "rustc_version", +] + +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + +[[package]] +name = "plotters" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +dependencies = [ + "chrono", + "font-kit", + "image", + "lazy_static", + "num-traits", + "pathfinder_geometry", + "plotters-backend", + "plotters-bitmap", + "plotters-svg", + "ttf-parser", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" + +[[package]] +name = "plotters-bitmap" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cebbe1f70205299abc69e8b295035bb52a6a70ee35474ad10011f0a4efb8543" +dependencies = [ + "gif", + "image", + "plotters-backend", +] + +[[package]] +name = "plotters-svg" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "png" +version = "0.17.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1" +dependencies = [ + "bitflags 1.3.2", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -205,18 +675,53 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_users" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + [[package]] name = "replace_with" version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3a8614ee435691de62bcffcf4a66d91b3594bf1428a5722e79103249a095690" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "ryu" version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "semver" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" + [[package]] name = "serde" version = "1.0.197" @@ -257,6 +762,12 @@ dependencies = [ "serde", ] +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + [[package]] name = "syn" version = "2.0.51" @@ -268,14 +779,245 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "thiserror" +version = "1.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "ttf-parser" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "375812fa44dab6df41c195cd2f7fecb488f6c09fbaafb62807488cefab642bff" + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] +name = "web-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "weezl" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" + +[[package]] +name = "wio" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d129932f4644ac2396cb456385cbf9e63b5b30c6e8dc4820bdca4eb082037a5" +dependencies = [ + "winapi", +] + +[[package]] +name = "yeslogic-fontconfig-sys" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2bbd69036d397ebbff671b1b8e4d918610c181c5a16073b96f984a38d08c386" +dependencies = [ + "const-cstr", + "dlib", + "once_cell", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 91247ad..f2f976e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,3 +37,4 @@ serde-big-array = { version = "0.5.1", optional = true } [dev-dependencies] bincode = "1.3.3" serde_json = "1.0.114" +plotters = "0.3.5" \ No newline at end of file diff --git a/examples/plot.rs b/examples/plot.rs new file mode 100644 index 0000000..48db937 --- /dev/null +++ b/examples/plot.rs @@ -0,0 +1,135 @@ +use std::{error::Error, sync::{Arc, Mutex}}; + +use neat::*; +use rand::prelude::*; +use plotters::prelude::*; + +#[derive(RandomlyMutable, DivisionReproduction, Clone)] +struct AgentDNA { + network: NeuralNetworkTopology<2, 1>, +} + +impl Prunable for AgentDNA {} + +impl GenerateRandom for AgentDNA { + fn gen_random(rng: &mut impl Rng) -> Self { + Self { + network: NeuralNetworkTopology::new(0.01, 3, rng), + } + } +} + +fn fitness(g: &AgentDNA) -> f32 { + let network = NeuralNetwork::from(&g.network); + let mut fitness = 0.; + let mut rng = rand::thread_rng(); + + for _ in 0..100 { + let n = rng.gen::() * 10000.; + let base = rng.gen::() * 10.; + let expected = n.log(base); + + let [answer] = network.predict([n, base]); + network.flush_state(); + + fitness += 5. / (answer - expected).abs(); + } + + fitness +} + +struct PlottingNG { + performance_stats: Arc>>, +} + +impl NextgenFn for PlottingNG { + fn next_gen(&self, fitness: Vec<(AgentDNA, f32)>) -> Vec { + let l = fitness.len(); + + let high = fitness[0].1; + + let median = fitness[l / 2].1; + + let low = fitness[l-1].1; + + let mut ps = self.performance_stats.lock().unwrap(); + ps.push(PerformanceStats { high, median, low }); + + division_pruning_nextgen(fitness) + } +} + +struct PerformanceStats { + high: f32, + median: f32, + low: f32, +} + +const OUTPUT_FILE_NAME: &'static str = "fitness-plot.png"; +const GENS: usize = 100; +fn main() -> Result<(), Box> { + let mut rng = rand::thread_rng(); + + let performance_stats = Arc::new(Mutex::new(Vec::with_capacity(GENS))); + let ng = PlottingNG { performance_stats: performance_stats.clone() }; + + let mut sim = GeneticSim::new( + Vec::gen_random(&mut rng, 100), + fitness, + ng, + ); + + println!("Training..."); + + for _ in 0..GENS { + sim.next_generation(); + } + + println!("Training complete, collecting data and building chart..."); + + let root = BitMapBackend::new(OUTPUT_FILE_NAME, (640, 480)).into_drawing_area(); + root.fill(&WHITE)?; + + let mut chart = ChartBuilder::on(&root) + .caption("agent fitness over gens", ("sans-serif", 50).into_font()) + .margin(5) + .x_label_area_size(30) + .y_label_area_size(30) + .build_cartesian_2d(0usize..100, 0f32..200.0)?; + + chart.configure_mesh().draw()?; + + let data: Vec<_> = Arc::into_inner(performance_stats).unwrap().into_inner().unwrap() + .into_iter() + .enumerate() + .collect(); + let highs = data + .iter() + .map(|(i, PerformanceStats { high, .. })| (*i, *high)); + + let medians = data + .iter() + .map(|(i, PerformanceStats { median, .. })| (*i, *median)); + + let lows = data + .iter() + .map(|(i, PerformanceStats { low, .. })| (*i, *low)); + + chart + .draw_series(LineSeries::new(highs, &GREEN))? + .label("high"); + + chart + .draw_series(LineSeries::new(medians, &YELLOW))? + .label("median"); + + chart + .draw_series(LineSeries::new(lows, &RED))? + .label("low"); + + root.present()?; + + println!("Complete"); + + Ok(()) +} \ No newline at end of file From 5cddae7b31b34c8a0f5d37398558f833f3e22560 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Tue, 16 Apr 2024 18:25:49 +0000 Subject: [PATCH 05/24] small changes --- examples/plot.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/plot.rs b/examples/plot.rs index 48db937..59e3a24 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -91,7 +91,7 @@ fn main() -> Result<(), Box> { root.fill(&WHITE)?; let mut chart = ChartBuilder::on(&root) - .caption("agent fitness over gens", ("sans-serif", 50).into_font()) + .caption("agent fitness values per generation", ("sans-serif", 50).into_font()) .margin(5) .x_label_area_size(30) .y_label_area_size(30) @@ -103,6 +103,7 @@ fn main() -> Result<(), Box> { .into_iter() .enumerate() .collect(); + let highs = data .iter() .map(|(i, PerformanceStats { high, .. })| (*i, *high)); From 728cbdeca4e6009399b1167bdd25a535da491d06 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Tue, 16 Apr 2024 18:29:56 +0000 Subject: [PATCH 06/24] more configuration --- examples/plot.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/plot.rs b/examples/plot.rs index 59e3a24..605ed69 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -103,7 +103,7 @@ fn main() -> Result<(), Box> { .into_iter() .enumerate() .collect(); - + let highs = data .iter() .map(|(i, PerformanceStats { high, .. })| (*i, *high)); @@ -128,6 +128,12 @@ fn main() -> Result<(), Box> { .draw_series(LineSeries::new(lows, &RED))? .label("low"); + chart + .configure_series_labels() + .background_style(&WHITE.mix(0.8)) + .border_style(&BLACK) + .draw()?; + root.present()?; println!("Complete"); From 91c3f9f1463096178cf011acc8e8aadd730f7f46 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 17 Apr 2024 11:31:35 +0000 Subject: [PATCH 07/24] make plotting ng more generic --- examples/plot.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/plot.rs b/examples/plot.rs index 605ed69..6cd555e 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -38,11 +38,12 @@ fn fitness(g: &AgentDNA) -> f32 { fitness } -struct PlottingNG { +struct PlottingNG> { performance_stats: Arc>>, + actual_ng: F, } -impl NextgenFn for PlottingNG { +impl> NextgenFn for PlottingNG { fn next_gen(&self, fitness: Vec<(AgentDNA, f32)>) -> Vec { let l = fitness.len(); @@ -55,7 +56,7 @@ impl NextgenFn for PlottingNG { let mut ps = self.performance_stats.lock().unwrap(); ps.push(PerformanceStats { high, median, low }); - division_pruning_nextgen(fitness) + self.actual_ng.next_gen(fitness) } } @@ -71,7 +72,7 @@ fn main() -> Result<(), Box> { let mut rng = rand::thread_rng(); let performance_stats = Arc::new(Mutex::new(Vec::with_capacity(GENS))); - let ng = PlottingNG { performance_stats: performance_stats.clone() }; + let ng = PlottingNG { performance_stats: performance_stats.clone(), actual_ng: division_pruning_nextgen }; let mut sim = GeneticSim::new( Vec::gen_random(&mut rng, 100), From f6d0df0493d2ec8b8cdc7ce8c978154470c449f5 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Thu, 18 Apr 2024 11:54:27 +0000 Subject: [PATCH 08/24] fix test rayon feature --- src/lib.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index ac569b0..0de19a1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,10 +65,16 @@ mod tests { fitness }; + #[cfg(not(feature = "rayon"))] let mut rng = rand::thread_rng(); let mut sim = GeneticSim::new( + #[cfg(not(feature = "rayon"))] Vec::gen_random(&mut rng, 100), + + #[cfg(feature = "rayon")] + Vec::gen_random(100), + fitness, division_pruning_nextgen, ); From cc88ebfc8497ec4583a4fbd43d5ad7e53ef8d9ba Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Thu, 18 Apr 2024 11:55:23 +0000 Subject: [PATCH 09/24] cargo fmt --- src/lib.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 0de19a1..0dd0b8c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,7 @@ mod tests { use rand::prelude::*; #[derive(RandomlyMutable, DivisionReproduction, Clone)] - struct AgentDNA { + struct AgentDNA { network: NeuralNetworkTopology<2, 1>, } @@ -71,10 +71,8 @@ mod tests { let mut sim = GeneticSim::new( #[cfg(not(feature = "rayon"))] Vec::gen_random(&mut rng, 100), - #[cfg(feature = "rayon")] Vec::gen_random(100), - fitness, division_pruning_nextgen, ); @@ -83,13 +81,10 @@ mod tests { sim.next_generation(); } - let mut fits: Vec<_> = sim.genomes - .iter() - .map(fitness) - .collect(); + let mut fits: Vec<_> = sim.genomes.iter().map(fitness).collect(); fits.sort_by(|a, b| a.partial_cmp(&b).unwrap()); dbg!(fits); } -} \ No newline at end of file +} From b95084dd4d615c3a685eebca83df4c16e7133910 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Thu, 18 Apr 2024 14:53:59 +0000 Subject: [PATCH 10/24] create custom activations example --- examples/custom_activation.rs | 92 +++++++++++++++++++++++++++++++++++ src/topology/activation.rs | 4 +- 2 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 examples/custom_activation.rs diff --git a/examples/custom_activation.rs b/examples/custom_activation.rs new file mode 100644 index 0000000..f52882b --- /dev/null +++ b/examples/custom_activation.rs @@ -0,0 +1,92 @@ +//! An example implementation of a custom activation function. + +use neat::*; +use rand::prelude::*; + +#[derive(DivisionReproduction, RandomlyMutable, Clone)] +struct AgentDNA { + network: NeuralNetworkTopology<2, 2>, +} + +impl Prunable for AgentDNA {} + +impl GenerateRandom for AgentDNA { + fn gen_random(rng: &mut impl Rng) -> Self { + Self { + network: NeuralNetworkTopology::new(0.01, 3, rng), + } + } +} + +fn fitness(g: &AgentDNA) -> f32 { + let network: NeuralNetwork<2, 2> = NeuralNetwork::from(&g.network); + let mut fitness = 0.; + let mut rng = rand::thread_rng(); + + for _ in 0..50 { + let n = rng.gen::(); + let n2 = rng.gen::(); + + let expected = if (n + n2) / 2. >= 0.5 { + 0 + } else { + 1 + }; + + let result = network.predict([n, n2]); + network.flush_state(); + + // partial_cmp chance of returning None in this smh + let result = result.iter().max_index(); + + if result == expected { + fitness += 1.; + } else { + fitness -= 1.; + } + } + + fitness +} + +#[cfg(feature = "serde")] +fn serde_nextgen(rewards: Vec<(AgentDNA, f32)>) -> Vec { + let max = rewards + .iter() + .max_by(|(_, ra), (_, rb)| ra.total_cmp(rb)) + .unwrap(); + + let ser = NNTSerde::from(&max.0.network); + let data = serde_json::to_string_pretty(&ser).unwrap(); + std::fs::write("best-agent.json", data).expect("Failed to write to file"); + + division_pruning_nextgen(rewards) +} + +fn main() { + let log_activation = activation_fn!(f32::log10); + register_activation(log_activation); + + #[cfg(not(feature = "rayon"))] + let mut rng = rand::thread_rng(); + + let mut sim = GeneticSim::new( + #[cfg(not(feature = "rayon"))] + Vec::gen_random(&mut rng, 100), + + #[cfg(feature = "rayon")] + Vec::gen_random(100), + + fitness, + + #[cfg(not(feature = "serde"))] + division_pruning_nextgen, + + #[cfg(feature = "serde")] + serde_nextgen, + ); + + for _ in 0..200 { + sim.next_generation(); + } +} \ No newline at end of file diff --git a/src/topology/activation.rs b/src/topology/activation.rs index a711851..5bf9540 100644 --- a/src/topology/activation.rs +++ b/src/topology/activation.rs @@ -15,11 +15,11 @@ use crate::NeuronLocation; #[macro_export] macro_rules! activation_fn { ($F: path) => { - ActivationFn::new(Arc::new($F), ActivationScope::default(), stringify!($F).into()) + ActivationFn::new(std::sync::Arc::new($F), ActivationScope::default(), stringify!($F).into()) }; ($F: path, $S: expr) => { - ActivationFn::new(Arc::new($F), $S, stringify!($F).into()) + ActivationFn::new(std::sync::Arc::new($F), $S, stringify!($F).into()) }; {$($F: path),*} => { From 35868795738bb443c31c09371dc4d76ee56fa6e5 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:01:37 +0000 Subject: [PATCH 11/24] fix opposite high and low --- examples/plot.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/plot.rs b/examples/plot.rs index 6cd555e..4fa4c51 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -47,11 +47,11 @@ impl> NextgenFn for PlottingNG { fn next_gen(&self, fitness: Vec<(AgentDNA, f32)>) -> Vec { let l = fitness.len(); - let high = fitness[0].1; + let high = fitness[l-1].1; let median = fitness[l / 2].1; - let low = fitness[l-1].1; + let low = fitness[0].1; let mut ps = self.performance_stats.lock().unwrap(); ps.push(PerformanceStats { high, median, low }); From 27e972af6f6fcaf885cbc96de7dc052405840897 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+HyperCodec@users.noreply.github.com> Date: Mon, 6 May 2024 10:19:37 -0400 Subject: [PATCH 12/24] Update Cargo.toml --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 91247ad..96767ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,9 +3,9 @@ name = "neat" description = "Crate for working with NEAT in rust" version = "0.5.1" edition = "2021" -authors = ["Inflectrix"] -repository = "https://github.com/inflectrix/neat" -homepage = "https://github.com/inflectrix/neat" +authors = ["HyperCodec"] +repository = "https://github.com/HyperCodec/neat" +homepage = "https://github.com/HyperCodec/neat" readme = "README.md" keywords = ["genetic", "machine-learning", "ai", "algorithm", "evolution"] categories = ["algorithms", "science", "simulation"] From 4b8cef0f7a2d226a6bab62d9e8fde1996978a018 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 11:42:36 +0000 Subject: [PATCH 13/24] use svgbackend (now it hangs for some reason) --- .gitignore | 3 ++- examples/plot.rs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 1b71596..a6e0cb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target/ -/.vscode/ \ No newline at end of file +/.vscode/ +best-agent.json \ No newline at end of file diff --git a/examples/plot.rs b/examples/plot.rs index 4fa4c51..2be99c9 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -88,7 +88,7 @@ fn main() -> Result<(), Box> { println!("Training complete, collecting data and building chart..."); - let root = BitMapBackend::new(OUTPUT_FILE_NAME, (640, 480)).into_drawing_area(); + let root = SVGBackend::new(OUTPUT_FILE_NAME, (640, 480)).into_drawing_area(); root.fill(&WHITE)?; let mut chart = ChartBuilder::on(&root) From 6a7090ace3522817ed9979c21efa8fdb42d53112 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 13:36:47 +0000 Subject: [PATCH 14/24] fix arc::into_inner failure --- .gitignore | 3 ++- examples/plot.rs | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index a6e0cb6..b2d8069 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target/ /.vscode/ -best-agent.json \ No newline at end of file +best-agent.json +fitness-plot.svg \ No newline at end of file diff --git a/examples/plot.rs b/examples/plot.rs index 2be99c9..967b3d0 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -66,7 +66,7 @@ struct PerformanceStats { low: f32, } -const OUTPUT_FILE_NAME: &'static str = "fitness-plot.png"; +const OUTPUT_FILE_NAME: &'static str = "fitness-plot.svg"; const GENS: usize = 100; fn main() -> Result<(), Box> { let mut rng = rand::thread_rng(); @@ -86,6 +86,9 @@ fn main() -> Result<(), Box> { sim.next_generation(); } + // prevent `Arc::into_inner` from failing + drop(sim); + println!("Training complete, collecting data and building chart..."); let root = SVGBackend::new(OUTPUT_FILE_NAME, (640, 480)).into_drawing_area(); From 945ea4a7b1a350d75b9260f2d903ddb1566fc848 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 13:43:33 +0000 Subject: [PATCH 15/24] fix data retrieval --- examples/plot.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/plot.rs b/examples/plot.rs index 967b3d0..33c032a 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -44,7 +44,10 @@ struct PlottingNG> { } impl> NextgenFn for PlottingNG { - fn next_gen(&self, fitness: Vec<(AgentDNA, f32)>) -> Vec { + fn next_gen(&self, mut fitness: Vec<(AgentDNA, f32)>) -> Vec { + // it's a bit slower because of sorting twice but I don't want to rewrite the nextgen. + fitness.sort_by(|(_, fa), (_, fb)| fa.partial_cmp(fb).unwrap()); + let l = fitness.len(); let high = fitness[l-1].1; From 0717843bfd615421f20b352f165b8a758822bacc Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 13:48:58 +0000 Subject: [PATCH 16/24] make compatible with other features --- examples/plot.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/plot.rs b/examples/plot.rs index 33c032a..ab0585a 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -71,14 +71,21 @@ struct PerformanceStats { const OUTPUT_FILE_NAME: &'static str = "fitness-plot.svg"; const GENS: usize = 100; + fn main() -> Result<(), Box> { + #[cfg(not(feature = "rayon"))] let mut rng = rand::thread_rng(); let performance_stats = Arc::new(Mutex::new(Vec::with_capacity(GENS))); let ng = PlottingNG { performance_stats: performance_stats.clone(), actual_ng: division_pruning_nextgen }; let mut sim = GeneticSim::new( + #[cfg(not(feature = "rayon"))] Vec::gen_random(&mut rng, 100), + + #[cfg(feature = "rayon")] + Vec::gen_random(100), + fitness, ng, ); From 6a98fb0d928f8922e6c535f3a74312a89a277d43 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 13:49:53 +0000 Subject: [PATCH 17/24] cargo fmt --- examples/plot.rs | 38 +++++++++++++++++++++++--------------- src/lib.rs | 9 +++------ 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/examples/plot.rs b/examples/plot.rs index ab0585a..2b6a851 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -1,11 +1,14 @@ -use std::{error::Error, sync::{Arc, Mutex}}; +use std::{ + error::Error, + sync::{Arc, Mutex}, +}; use neat::*; -use rand::prelude::*; use plotters::prelude::*; +use rand::prelude::*; #[derive(RandomlyMutable, DivisionReproduction, Clone)] -struct AgentDNA { +struct AgentDNA { network: NeuralNetworkTopology<2, 1>, } @@ -50,7 +53,7 @@ impl> NextgenFn for PlottingNG { let l = fitness.len(); - let high = fitness[l-1].1; + let high = fitness[l - 1].1; let median = fitness[l / 2].1; @@ -77,21 +80,22 @@ fn main() -> Result<(), Box> { let mut rng = rand::thread_rng(); let performance_stats = Arc::new(Mutex::new(Vec::with_capacity(GENS))); - let ng = PlottingNG { performance_stats: performance_stats.clone(), actual_ng: division_pruning_nextgen }; + let ng = PlottingNG { + performance_stats: performance_stats.clone(), + actual_ng: division_pruning_nextgen, + }; let mut sim = GeneticSim::new( #[cfg(not(feature = "rayon"))] Vec::gen_random(&mut rng, 100), - #[cfg(feature = "rayon")] Vec::gen_random(100), - fitness, ng, ); println!("Training..."); - + for _ in 0..GENS { sim.next_generation(); } @@ -105,7 +109,10 @@ fn main() -> Result<(), Box> { root.fill(&WHITE)?; let mut chart = ChartBuilder::on(&root) - .caption("agent fitness values per generation", ("sans-serif", 50).into_font()) + .caption( + "agent fitness values per generation", + ("sans-serif", 50).into_font(), + ) .margin(5) .x_label_area_size(30) .y_label_area_size(30) @@ -113,7 +120,10 @@ fn main() -> Result<(), Box> { chart.configure_mesh().draw()?; - let data: Vec<_> = Arc::into_inner(performance_stats).unwrap().into_inner().unwrap() + let data: Vec<_> = Arc::into_inner(performance_stats) + .unwrap() + .into_inner() + .unwrap() .into_iter() .enumerate() .collect(); @@ -138,9 +148,7 @@ fn main() -> Result<(), Box> { .draw_series(LineSeries::new(medians, &YELLOW))? .label("median"); - chart - .draw_series(LineSeries::new(lows, &RED))? - .label("low"); + chart.draw_series(LineSeries::new(lows, &RED))?.label("low"); chart .configure_series_labels() @@ -151,6 +159,6 @@ fn main() -> Result<(), Box> { root.present()?; println!("Complete"); - + Ok(()) -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index ac569b0..98429d6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,7 @@ mod tests { use rand::prelude::*; #[derive(RandomlyMutable, DivisionReproduction, Clone)] - struct AgentDNA { + struct AgentDNA { network: NeuralNetworkTopology<2, 1>, } @@ -77,13 +77,10 @@ mod tests { sim.next_generation(); } - let mut fits: Vec<_> = sim.genomes - .iter() - .map(fitness) - .collect(); + let mut fits: Vec<_> = sim.genomes.iter().map(fitness).collect(); fits.sort_by(|a, b| a.partial_cmp(&b).unwrap()); dbg!(fits); } -} \ No newline at end of file +} From 44b7fdbc37992f766d322f22c14e7133eab9a481 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 14:24:25 +0000 Subject: [PATCH 18/24] create progress bar for plotting example --- Cargo.lock | 69 ++++++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 4 ++- examples/plot.rs | 14 ++++++++-- 3 files changed, 84 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5c98cd1..700cf2f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,6 +100,19 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys", +] + [[package]] name = "const-cstr" version = "0.3.0" @@ -240,6 +253,12 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "fdeflate" version = "0.3.4" @@ -417,6 +436,28 @@ dependencies = [ "png", ] +[[package]] +name = "indicatif" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "unicode-width", +] + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", +] + [[package]] name = "itoa" version = "1.0.10" @@ -493,6 +534,7 @@ dependencies = [ "bincode", "bitflags 2.5.0", "genetic-rs", + "indicatif", "lazy_static", "plotters", "rand", @@ -511,6 +553,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "once_cell" version = "1.19.0" @@ -601,6 +649,12 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -811,6 +865,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-width" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" + [[package]] name = "walkdir" version = "2.5.0" @@ -937,6 +997,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-targets" version = "0.52.5" diff --git a/Cargo.toml b/Cargo.toml index 8ccd8ca..8305fe4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ serde = ["dep:serde", "dep:serde-big-array"] [dependencies] bitflags = "2.5.0" genetic-rs = { version = "0.5.1", features = ["derive"] } + lazy_static = "1.4.0" rand = "0.8.5" rayon = { version = "1.8.1", optional = true } @@ -37,4 +38,5 @@ serde-big-array = { version = "0.5.1", optional = true } [dev-dependencies] bincode = "1.3.3" serde_json = "1.0.114" -plotters = "0.3.5" \ No newline at end of file +plotters = "0.3.5" +indicatif = "0.17.8" diff --git a/examples/plot.rs b/examples/plot.rs index 2b6a851..af48b01 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -6,6 +6,7 @@ use std::{ use neat::*; use plotters::prelude::*; use rand::prelude::*; +use indicatif::{ProgressBar, ProgressStyle}; #[derive(RandomlyMutable, DivisionReproduction, Clone)] struct AgentDNA { @@ -73,7 +74,7 @@ struct PerformanceStats { } const OUTPUT_FILE_NAME: &'static str = "fitness-plot.svg"; -const GENS: usize = 100; +const GENS: usize = 1000; fn main() -> Result<(), Box> { #[cfg(not(feature = "rayon"))] @@ -94,12 +95,21 @@ fn main() -> Result<(), Box> { ng, ); + let pb = ProgressBar::new(GENS as u64) + .with_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}") + .unwrap()) + .with_message("gen"); + println!("Training..."); for _ in 0..GENS { sim.next_generation(); + + pb.inc(1); } + pb.finish(); + // prevent `Arc::into_inner` from failing drop(sim); @@ -116,7 +126,7 @@ fn main() -> Result<(), Box> { .margin(5) .x_label_area_size(30) .y_label_area_size(30) - .build_cartesian_2d(0usize..100, 0f32..200.0)?; + .build_cartesian_2d(0usize..GENS, 0f32..1000.0)?; chart.configure_mesh().draw()?; From 6d17ec6bf1682f549ec8533aaa01f97ca0956dfe Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 14:27:12 +0000 Subject: [PATCH 19/24] create progress bar for basic example --- examples/basic.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/basic.rs b/examples/basic.rs index 9ad0419..2aa640b 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -2,6 +2,7 @@ use neat::*; use rand::prelude::*; +use indicatif::{ProgressBar, ProgressStyle}; #[derive(PartialEq, Clone, Debug, DivisionReproduction, RandomlyMutable)] #[cfg_attr(feature = "crossover", derive(CrossoverReproduction))] @@ -103,10 +104,19 @@ fn main() { crossover_pruning_nextgen, ); - for _ in 0..100 { + const GENS: u64 = 1000; + let pb = ProgressBar::new(GENS) + .with_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}") + .unwrap()) + .with_message("gen"); + + for _ in 0..GENS { sim.next_generation(); + pb.inc(1); } + pb.finish(); + #[cfg(not(feature = "serde"))] let mut fits: Vec<_> = sim.genomes.iter().map(fitness).collect(); From d3a9c409f51c11e7069b414e85069a1d7d7d7a76 Mon Sep 17 00:00:00 2001 From: Tristan Murphy <72839119+inflectrix@users.noreply.github.com> Date: Wed, 15 May 2024 14:28:23 +0000 Subject: [PATCH 20/24] cargo fmt --- examples/basic.rs | 10 +++++++--- examples/plot.rs | 10 +++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/examples/basic.rs b/examples/basic.rs index 2aa640b..9bbb346 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -1,8 +1,8 @@ //! A basic example of NEAT with this crate. Enable the `crossover` feature for it to use crossover reproduction +use indicatif::{ProgressBar, ProgressStyle}; use neat::*; use rand::prelude::*; -use indicatif::{ProgressBar, ProgressStyle}; #[derive(PartialEq, Clone, Debug, DivisionReproduction, RandomlyMutable)] #[cfg_attr(feature = "crossover", derive(CrossoverReproduction))] @@ -106,8 +106,12 @@ fn main() { const GENS: u64 = 1000; let pb = ProgressBar::new(GENS) - .with_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}") - .unwrap()) + .with_style( + ProgressStyle::with_template( + "[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}", + ) + .unwrap(), + ) .with_message("gen"); for _ in 0..GENS { diff --git a/examples/plot.rs b/examples/plot.rs index af48b01..34fb391 100644 --- a/examples/plot.rs +++ b/examples/plot.rs @@ -3,10 +3,10 @@ use std::{ sync::{Arc, Mutex}, }; +use indicatif::{ProgressBar, ProgressStyle}; use neat::*; use plotters::prelude::*; use rand::prelude::*; -use indicatif::{ProgressBar, ProgressStyle}; #[derive(RandomlyMutable, DivisionReproduction, Clone)] struct AgentDNA { @@ -96,8 +96,12 @@ fn main() -> Result<(), Box> { ); let pb = ProgressBar::new(GENS as u64) - .with_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}") - .unwrap()) + .with_style( + ProgressStyle::with_template( + "[{elapsed_precise}] {bar:40.cyan/blue} | {msg} {pos}/{len}", + ) + .unwrap(), + ) .with_message("gen"); println!("Training..."); From e45908cacd6126997f9d26a4a1dffa99c9f25244 Mon Sep 17 00:00:00 2001 From: Tristan Murphy Date: Fri, 31 May 2024 21:40:06 -0400 Subject: [PATCH 21/24] add logic to prevent duplicated input neurons --- src/topology/mod.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/topology/mod.rs b/src/topology/mod.rs index 02ad296..dd246f2 100644 --- a/src/topology/mod.rs +++ b/src/topology/mod.rs @@ -121,6 +121,18 @@ impl NeuralNetworkTopology { return true; } + // check to make sure it isn't duplicate + { + let n = self.get_neuron(to); + let n2 = n.read().unwrap(); + + for (loc, _) in &n2.inputs { + if from == *loc { + return false; + } + } + } + let mut visited = HashSet::new(); self.dfs(from, to, &mut visited) } From 7c31f30f88bcd0221e628daa8b0d0530f5c46523 Mon Sep 17 00:00:00 2001 From: Tristan Murphy Date: Thu, 13 Jun 2024 10:54:22 -0400 Subject: [PATCH 22/24] cargo fmt --- examples/custom_activation.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/examples/custom_activation.rs b/examples/custom_activation.rs index f52882b..bc6aae2 100644 --- a/examples/custom_activation.rs +++ b/examples/custom_activation.rs @@ -27,11 +27,7 @@ fn fitness(g: &AgentDNA) -> f32 { let n = rng.gen::(); let n2 = rng.gen::(); - let expected = if (n + n2) / 2. >= 0.5 { - 0 - } else { - 1 - }; + let expected = if (n + n2) / 2. >= 0.5 { 0 } else { 1 }; let result = network.predict([n, n2]); network.flush_state(); @@ -73,15 +69,11 @@ fn main() { let mut sim = GeneticSim::new( #[cfg(not(feature = "rayon"))] Vec::gen_random(&mut rng, 100), - #[cfg(feature = "rayon")] Vec::gen_random(100), - fitness, - #[cfg(not(feature = "serde"))] division_pruning_nextgen, - #[cfg(feature = "serde")] serde_nextgen, ); @@ -89,4 +81,4 @@ fn main() { for _ in 0..200 { sim.next_generation(); } -} \ No newline at end of file +} From a32bfff0375835cc9f49ad2d4e42513a8cb55a5c Mon Sep 17 00:00:00 2001 From: Tristan Murphy Date: Thu, 13 Jun 2024 11:09:52 -0400 Subject: [PATCH 23/24] change activation function to one that doesn't return NaN --- examples/custom_activation.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/custom_activation.rs b/examples/custom_activation.rs index bc6aae2..7b37c02 100644 --- a/examples/custom_activation.rs +++ b/examples/custom_activation.rs @@ -60,8 +60,8 @@ fn serde_nextgen(rewards: Vec<(AgentDNA, f32)>) -> Vec { } fn main() { - let log_activation = activation_fn!(f32::log10); - register_activation(log_activation); + let sin_activation = activation_fn!(f32::sin); + register_activation(sin_activation); #[cfg(not(feature = "rayon"))] let mut rng = rand::thread_rng(); From 331feb2bd3d3eeea0b3be11ea9437cce1ca763b5 Mon Sep 17 00:00:00 2001 From: Tristan Murphy Date: Mon, 8 Jul 2024 10:24:07 -0400 Subject: [PATCH 24/24] begin implementing a new cache system --- src/runnable.rs | 161 +++++++++--------------------------------------- 1 file changed, 28 insertions(+), 133 deletions(-) diff --git a/src/runnable.rs b/src/runnable.rs index 5b28f54..96a020f 100644 --- a/src/runnable.rs +++ b/src/runnable.rs @@ -1,7 +1,7 @@ use crate::topology::*; #[cfg(not(feature = "rayon"))] -use std::{cell::RefCell, rc::Rc}; +use std::{cell::RefCell, collections::HashMap, rc::Rc}; #[cfg(feature = "rayon")] use rayon::prelude::*; @@ -11,35 +11,23 @@ use std::sync::{Arc, RwLock}; /// A runnable, stated Neural Network generated from a [NeuralNetworkTopology]. Use [`NeuralNetwork::from`] to go from stateles to runnable. /// Because this has state, you need to run [`NeuralNetwork::flush_state`] between [`NeuralNetwork::predict`] calls. #[derive(Debug)] -#[cfg(not(feature = "rayon"))] -pub struct NeuralNetwork { - input_layer: [Rc>; I], - hidden_layers: Vec>>, - output_layer: [Rc>; O], -} - -/// Parallelized version of the [`NeuralNetwork`] struct. -#[derive(Debug)] -#[cfg(feature = "rayon")] -pub struct NeuralNetwork { - input_layer: [Arc>; I], - hidden_layers: Vec>>, - output_layer: [Arc>; O], +pub struct NeuralNetwork<'a, const I: usize, const O: usize> { + topology: &'a NeuralNetworkTopology, } -impl NeuralNetwork { +impl NeuralNetwork<'_, I, O> { /// Predicts an output for the given inputs. #[cfg(not(feature = "rayon"))] pub fn predict(&self, inputs: [f32; I]) -> [f32; O] { + let mut state_cache = HashMap::new(); + for (i, v) in inputs.iter().enumerate() { - let mut nw = self.input_layer[i].borrow_mut(); - nw.state.value = *v; - nw.state.processed = true; + state_cache.insert(NeuronLocation::Input(i), *v); } (0..O) .map(NeuronLocation::Output) - .map(|loc| self.process_neuron(loc)) + .map(|loc| self.process_neuron(loc, &mut state_cache)) .collect::>() .try_into() .unwrap() @@ -65,26 +53,22 @@ impl NeuralNetwork { } #[cfg(not(feature = "rayon"))] - fn process_neuron(&self, loc: NeuronLocation) -> f32 { - let n = self.get_neuron(loc); - - { - let nr = n.borrow(); - - if nr.state.processed { - return nr.state.value; - } + fn process_neuron(&self, loc: NeuronLocation, cache: &mut HashMap) -> f32 { + if let Some(v) = cache.get(&loc) { + return *v; } - let mut n = n.borrow_mut(); - - for (l, w) in n.inputs.clone() { - n.state.value += self.process_neuron(l) * w; + let n = self.get_neuron(loc).unwrap(); + let mut v = 0.; + for (l, w) in &n.inputs { + v += self.process_neuron(*l, &mut cache) * w; } - n.activate(); + v = n.activate(v); + + cache.insert(loc, v); - n.state.value + v } #[cfg(feature = "rayon")] @@ -118,11 +102,11 @@ impl NeuralNetwork { } #[cfg(not(feature = "rayon"))] - fn get_neuron(&self, loc: NeuronLocation) -> Rc> { + fn get_neuron<'a>(&self, loc: NeuronLocation) -> Option<&'a NeuronTopology> { match loc { - NeuronLocation::Input(i) => self.input_layer[i].clone(), - NeuronLocation::Hidden(i) => self.hidden_layers[i].clone(), - NeuronLocation::Output(i) => self.output_layer[i].clone(), + NeuronLocation::Input(i) => self.topology.input_layer.get(i), + NeuronLocation::Hidden(i) => self.topology.hidden_layers.get(i), + NeuronLocation::Output(i) => self.topology.output_layer.get(i), } } @@ -135,22 +119,6 @@ impl NeuralNetwork { } } - /// Flushes the network's state after a [prediction][NeuralNetwork::predict]. - #[cfg(not(feature = "rayon"))] - pub fn flush_state(&self) { - for n in &self.input_layer { - n.borrow_mut().flush_state(); - } - - for n in &self.hidden_layers { - n.borrow_mut().flush_state(); - } - - for n in &self.output_layer { - n.borrow_mut().flush_state(); - } - } - /// Flushes the neural network's state. #[cfg(feature = "rayon")] pub fn flush_state(&self) { @@ -168,36 +136,12 @@ impl NeuralNetwork { } } -impl From<&NeuralNetworkTopology> for NeuralNetwork { +impl<'a, const I: usize, const O: usize> From<&'a NeuralNetworkTopology> + for NeuralNetwork<'a, I, O> +{ #[cfg(not(feature = "rayon"))] - fn from(value: &NeuralNetworkTopology) -> Self { - let input_layer = value - .input_layer - .iter() - .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone())))) - .collect::>() - .try_into() - .unwrap(); - - let hidden_layers = value - .hidden_layers - .iter() - .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone())))) - .collect(); - - let output_layer = value - .output_layer - .iter() - .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone())))) - .collect::>() - .try_into() - .unwrap(); - - Self { - input_layer, - hidden_layers, - output_layer, - } + fn from(topology: &'a NeuralNetworkTopology) -> Self { + Self { topology } } #[cfg(feature = "rayon")] @@ -232,55 +176,6 @@ impl From<&NeuralNetworkTopology> for Neur } } -/// A state-filled neuron. -#[derive(Clone, Debug)] -pub struct Neuron { - inputs: Vec<(NeuronLocation, f32)>, - bias: f32, - - /// The current state of the neuron. - pub state: NeuronState, - - /// The neuron's activation function - pub activation: ActivationFn, -} - -impl Neuron { - /// Flushes a neuron's state. Called by [`NeuralNetwork::flush_state`] - pub fn flush_state(&mut self) { - self.state.value = self.bias; - } - - /// Applies the activation function to the neuron - pub fn activate(&mut self) { - self.state.value = self.activation.func.activate(self.state.value); - } -} - -impl From<&NeuronTopology> for Neuron { - fn from(value: &NeuronTopology) -> Self { - Self { - inputs: value.inputs.clone(), - bias: value.bias, - state: NeuronState { - value: value.bias, - ..Default::default() - }, - activation: value.activation.clone(), - } - } -} - -/// A state used in [`Neuron`]s for cache. -#[derive(Clone, Debug, Default)] -pub struct NeuronState { - /// The current value of the neuron. Initialized to a neuron's bias when flushed. - pub value: f32, - - /// Whether or not [`value`][NeuronState::value] has finished processing. - pub processed: bool, -} - /// A blanket trait for iterators meant to help with interpreting the output of a [`NeuralNetwork`] #[cfg(feature = "max-index")] pub trait MaxIndex {