Skip to content

Commit

Permalink
implement rayon feature on runnable
Browse files Browse the repository at this point in the history
  • Loading branch information
HyperCodec committed Feb 8, 2024
1 parent 36e244b commit b28886a
Showing 1 changed file with 63 additions and 4 deletions.
67 changes: 63 additions & 4 deletions src/runnable.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use crate::topology::*;
use std::sync::{Arc, RwLock};

#[cfg(feature = "rayon")] use rayon::prelude::*;

/// A runnable feed-forward neural network, organized as three layers of
/// neurons (input, hidden, output).
///
/// Each neuron is shared behind `Arc<RwLock<..>>` so that `predict`,
/// `flush_state`, and `process_neuron` can read and mutate neuron state
/// concurrently (including from rayon worker threads when the `rayon`
/// feature is enabled).
pub struct NeuralNetwork {
    // Neurons whose `state.value` is seeded directly from `predict`'s inputs.
    input_layer: Vec<Arc<RwLock<Neuron>>>,
    // Intermediate neurons, reached transitively via `process_neuron`.
    hidden_layers: Vec<Arc<RwLock<Neuron>>>,
    // Neurons whose final values form `predict`'s return vector.
    output_layer: Vec<Arc<RwLock<Neuron>>>,
}

impl NeuralNetwork {
pub fn predict(&mut self, inputs: Vec<f32>) -> Vec<f32> {
#[cfg(not(feature = "rayon"))]
pub fn predict(&self, inputs: Vec<f32>) -> Vec<f32> {
if self.input_layer.len() != inputs.len() {
panic!("Invalid input layer specified. Expected {}, got {}", self.input_layer.len(), inputs.len());
}
Expand All @@ -23,6 +26,23 @@ impl NeuralNetwork {
.collect()
}

#[cfg(feature = "rayon")]
/// Runs the network on `inputs` and returns one value per output neuron.
///
/// Parallel (`rayon`) variant: input seeding and output evaluation are
/// both done with parallel iterators.
///
/// # Panics
/// Panics when `inputs.len()` does not match the input layer size.
pub fn predict(&self, inputs: Vec<f32>) -> Vec<f32> {
    if self.input_layer.len() != inputs.len() {
        panic!("Invalid input layer specified. Expected {}, got {}", self.input_layer.len(), inputs.len());
    }

    // Seed every input neuron with its corresponding input value.
    // Lengths are equal (checked above), so zip pairs them exactly.
    self.input_layer
        .par_iter()
        .zip(inputs.par_iter())
        .for_each(|(neuron, &value)| {
            neuron.write().unwrap().state.value = value;
        });

    // Evaluate each output neuron in parallel and gather the results
    // in output-layer order.
    (0..self.output_layer.len())
        .into_par_iter()
        .map(|i| self.process_neuron(NeuronLocation::Output(i)))
        .collect()
}

pub fn get_neuron(&self, loc: NeuronLocation) -> Arc<RwLock<Neuron>> {
match loc {
NeuronLocation::Input(i) => self.input_layer[i].clone(),
Expand All @@ -31,7 +51,8 @@ impl NeuralNetwork {
}
}

pub fn flush_state(&mut self) {
#[cfg(not(feature = "rayon"))]
pub fn flush_state(&self) {
for n in &self.input_layer {
n.write().unwrap().flush_state();
}
Expand All @@ -44,8 +65,18 @@ impl NeuralNetwork {
n.write().unwrap().flush_state();
}
}

#[cfg(feature = "rayon")]
/// Clears the cached evaluation state of every neuron in the network.
///
/// Parallel (`rayon`) variant: each layer is swept with a parallel
/// iterator. The per-layer logic is identical, so all three layers are
/// handled by one loop instead of three repeated statements.
pub fn flush_state(&self) {
    for layer in [&self.input_layer, &self.hidden_layers, &self.output_layer] {
        layer
            .par_iter()
            .for_each(|neuron| neuron.write().unwrap().flush_state());
    }
}

pub fn process_neuron(&mut self, loc: NeuronLocation) -> f32 {
#[cfg(not(feature = "rayon"))]
pub fn process_neuron(&self, loc: NeuronLocation) -> f32 {
let n = self.get_neuron(loc);

{
Expand All @@ -56,15 +87,43 @@ impl NeuralNetwork {
}
}

let mut n = n.write().unwrap();
let mut n = n.try_write().unwrap();

for (l, w) in n.inputs.clone() {
n.state.value += self.process_neuron(l) * w;
}

n.write().unwrap().sigmoid();

n.state.value
}

#[cfg(feature = "rayon")]
/// Recursively evaluates the neuron at `loc` and returns its activated value.
///
/// Parallel (`rayon`) variant: all of the neuron's inputs are processed
/// concurrently, each contributing `input_value * weight` to this
/// neuron's accumulator.
pub fn process_neuron(&self, loc: NeuronLocation) -> f32 {
    let n = self.get_neuron(loc);

    // Memoization: if this neuron was already evaluated, reuse its value.
    {
        let nr = n.read().unwrap();

        if nr.state.processed {
            return nr.state.value;
        }
    }

    // Clone the input list in its own `let` statement so the read guard
    // is dropped here. The original chained
    // `n.read().unwrap().inputs.clone().into_par_iter().for_each(..)` in
    // a single statement, which kept the read guard alive for the whole
    // parallel loop while the closure took write locks on the same
    // RwLock — a guaranteed deadlock.
    let inputs = n.read().unwrap().inputs.clone();

    inputs.into_par_iter().for_each(|(input_loc, weight)| {
        // Compute the upstream value first so the write lock below is
        // held only for the `+=`, not while recursing into other neurons.
        let processed = self.process_neuron(input_loc);
        n.write().unwrap().state.value += processed * weight;
    });

    {
        let mut nw = n.write().unwrap();
        nw.sigmoid();
        // Mark as processed so the memoization check above can fire; the
        // original never set this flag in the rayon path, so repeated
        // callers re-accumulated the inputs.
        // NOTE(review): assumes `flush_state` resets this flag between
        // predictions — confirm against `Neuron::flush_state`.
        nw.state.processed = true;
    }

    let nr = n.read().unwrap();
    nr.state.value
}

}

impl From<&NeuralNetworkTopology> for NeuralNetwork {
Expand Down

0 comments on commit b28886a

Please sign in to comment.