Skip to content

Commit

Permalink
cargo fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
HyperCodec committed Feb 20, 2024
1 parent 8dd5b68 commit 7a8396a
Showing 1 changed file with 36 additions and 18 deletions.
54 changes: 36 additions & 18 deletions src/topology.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::{
collections::HashSet, fmt, sync::{Arc, RwLock}
collections::HashSet,
fmt,
sync::{Arc, RwLock},
};

use genetic_rs::prelude::*;
Expand Down Expand Up @@ -95,17 +97,26 @@ impl<const I: usize, const O: usize> NeuralNetworkTopology<I, O> {
/// Creates a new connection between the neurons.
/// If the connection is cyclic, it does not add a connection and returns false.
/// Otherwise, it returns true.
pub fn add_connection(&mut self, from: NeuronLocation, to: NeuronLocation, weight: f32) -> bool {
pub fn add_connection(
&mut self,
from: NeuronLocation,
to: NeuronLocation,
weight: f32,
) -> bool {
if self.is_connection_cyclic(from, to) {
return false;
}

// Add the connection since it is not cyclic
self.get_neuron(to).write().unwrap().inputs.push((from, weight));

self.get_neuron(to)
.write()
.unwrap()
.inputs
.push((from, weight));

true
}

fn is_connection_cyclic(&self, from: NeuronLocation, to: NeuronLocation) -> bool {
if to.is_input() || from.is_output() {
return true;
Expand All @@ -114,24 +125,29 @@ impl<const I: usize, const O: usize> NeuralNetworkTopology<I, O> {
let mut visited = HashSet::new();
self.dfs(from, to, &mut visited)
}

// TODO rayon implementation
fn dfs(&self, current: NeuronLocation, target: NeuronLocation, visited: &mut HashSet<NeuronLocation>) -> bool {
fn dfs(
&self,
current: NeuronLocation,
target: NeuronLocation,
visited: &mut HashSet<NeuronLocation>,
) -> bool {
if current == target {
return true;
}

visited.insert(current);

let n = self.get_neuron(current);
let nr = n.read().unwrap();

for &(input, _) in &nr.inputs {
if !visited.contains(&input) && self.dfs(input, target, visited) {
return true;
}
}

visited.remove(&current);
false
}
Expand Down Expand Up @@ -168,14 +184,15 @@ impl<const I: usize, const O: usize> NeuralNetworkTopology<I, O> {
if !loc.is_hidden() {
panic!("Invalid neuron deletion");
}

let index = loc.unwrap();
let neuron = Arc::into_inner(self.hidden_layers.remove(index)).unwrap();

for n in &self.hidden_layers {
let mut nw = n.write().unwrap();

nw.inputs = nw.inputs
nw.inputs = nw
.inputs
.iter()
.filter_map(|&(input_loc, w)| {
if !input_loc.is_hidden() {
Expand All @@ -194,10 +211,11 @@ impl<const I: usize, const O: usize> NeuralNetworkTopology<I, O> {
})
.collect();
}

for n2 in &self.output_layer {
let mut nw = n2.write().unwrap();
nw.inputs = nw.inputs
nw.inputs = nw
.inputs
.iter()
.filter_map(|&(input_loc, w)| {
if !input_loc.is_hidden() {
Expand All @@ -216,7 +234,7 @@ impl<const I: usize, const O: usize> NeuralNetworkTopology<I, O> {
})
.collect();
}

neuron.into_inner().unwrap()
}
}
Expand Down Expand Up @@ -298,7 +316,7 @@ impl<const I: usize, const O: usize> RandomlyMutable for NeuralNetworkTopology<I
while !loc.is_hidden() {
(_, loc) = self.rand_neuron(rng);
}

// delete the neuron
self.delete_neuron(loc);
}
Expand Down

0 comments on commit 7a8396a

Please sign in to comment.