diff --git a/src/runnable.rs b/src/runnable.rs index fc9209f..9bf69ea 100644 --- a/src/runnable.rs +++ b/src/runnable.rs @@ -1,6 +1,8 @@ use crate::topology::*; use std::sync::{Arc, RwLock}; +#[cfg(feature = "rayon")] use rayon::prelude::*; + pub struct NeuralNetwork { input_layer: Vec>>, hidden_layers: Vec>>, @@ -8,7 +10,8 @@ pub struct NeuralNetwork { } impl NeuralNetwork { - pub fn predict(&mut self, inputs: Vec) -> Vec { + #[cfg(not(feature = "rayon"))] + pub fn predict(&self, inputs: Vec) -> Vec { if self.input_layer.len() != inputs.len() { panic!("Invalid input layer specified. Expected {}, got {}", self.input_layer.len(), inputs.len()); } @@ -23,6 +26,23 @@ impl NeuralNetwork { .collect() } + #[cfg(feature = "rayon")] + pub fn predict(&self, inputs: Vec) -> Vec { + if self.input_layer.len() != inputs.len() { + panic!("Invalid input layer specified. Expected {}, got {}", self.input_layer.len(), inputs.len()); + } + + inputs.par_iter().enumerate().for_each(|(i, v)| { + self.input_layer[i].write().unwrap().state.value = *v; + }); + + (0..self.output_layer.len()) + .into_par_iter() + .map(|i| NeuronLocation::Output(i)) + .map(|loc| self.process_neuron(loc)) + .collect() + } + pub fn get_neuron(&self, loc: NeuronLocation) -> Arc> { match loc { NeuronLocation::Input(i) => self.input_layer[i].clone(), @@ -31,7 +51,8 @@ impl NeuralNetwork { } } - pub fn flush_state(&mut self) { + #[cfg(not(feature = "rayon"))] + pub fn flush_state(&self) { for n in &self.input_layer { n.write().unwrap().flush_state(); } @@ -44,8 +65,18 @@ impl NeuralNetwork { n.write().unwrap().flush_state(); } } + + #[cfg(feature = "rayon")] + pub fn flush_state(&self) { + self.input_layer.par_iter().for_each(|n| n.write().unwrap().flush_state()); + + self.hidden_layers.par_iter().for_each(|n| n.write().unwrap().flush_state()); + + self.output_layer.par_iter().for_each(|n| n.write().unwrap().flush_state()); + } - pub fn process_neuron(&mut self, loc: NeuronLocation) -> f32 { + #[cfg(not(feature = "rayon"))] + pub fn process_neuron(&self, loc: NeuronLocation) -> f32 { let n = self.get_neuron(loc); { @@ -56,15 +87,43 @@ impl NeuralNetwork { } } - let mut n = n.write().unwrap(); + let mut n = n.try_write().unwrap(); for (l, w) in n.inputs.clone() { n.state.value += self.process_neuron(l) * w; } + n.write().unwrap().sigmoid(); + n.state.value } + #[cfg(feature = "rayon")] + pub fn process_neuron(&self, loc: NeuronLocation) -> f32 { + let n = self.get_neuron(loc); + + { + let nr = n.read().unwrap(); + + if nr.state.processed { + return nr.state.value; + } + } + + n.read().unwrap().inputs + .clone() + .into_par_iter() + .for_each(|(n2, w)| { + let processed = self.process_neuron(n2); // separate step so write lock doesnt block process_neuron on other threads + n.write().unwrap().state.value += processed * w + }); + + n.write().unwrap().sigmoid(); + + let nr = n.read().unwrap(); + nr.state.value + } + } impl From<&NeuralNetworkTopology> for NeuralNetwork {