Skip to content

Commit

Permalink
utilize activations registry in neural network topology
Browse files Browse the repository at this point in the history
  • Loading branch information
HyperCodec committed Apr 16, 2024
1 parent 2178dc3 commit 50a7947
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
12 changes: 12 additions & 0 deletions src/topology/activation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use std::{collections::HashMap, fmt, sync::{Arc, RwLock}};
use lazy_static::lazy_static;
use bitflags::bitflags;

use crate::NeuronLocation;

/// Creates an [`ActivationFn`] object from a function
#[macro_export]
macro_rules! activation_fn {
Expand Down Expand Up @@ -119,6 +121,16 @@ impl Default for ActivationScope {
}
}

impl From<&NeuronLocation> for ActivationScope {
fn from(value: &NeuronLocation) -> Self {
match value {
NeuronLocation::Input(_) => Self::INPUT,
NeuronLocation::Hidden(_) => Self::HIDDEN,
NeuronLocation::Output(_) => Self::OUTPUT,
}
}
}

/// A trait that represents an activation method.
pub trait Activation {
/// The activation function.
Expand Down
20 changes: 7 additions & 13 deletions src/topology/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ impl<const I: usize, const O: usize> RandomlyMutable for NeuralNetworkTopology<I

let loc3 = NeuronLocation::Hidden(self.hidden_layers.len());

let n3 = NeuronTopology::new(vec![loc], rng);
let n3 = NeuronTopology::new(vec![loc], ActivationScope::HIDDEN, rng);

self.hidden_layers.push(Arc::new(RwLock::new(n3)));

Expand Down Expand Up @@ -344,11 +344,8 @@ impl<const I: usize, const O: usize> RandomlyMutable for NeuralNetworkTopology<I

if rng.gen::<f32>() <= rate && !self.hidden_layers.is_empty() {
// mutate activation function
let activations = activation_fn! {
sigmoid,
relu,
f32::tanh
};
let reg = ACTIVATION_REGISTRY.read().unwrap();
let activations = reg.activations_in_scope(ActivationScope::HIDDEN);

let (mut n, mut loc) = self.rand_neuron(rng);

Expand Down Expand Up @@ -543,12 +540,9 @@ pub struct NeuronTopology {

impl NeuronTopology {
/// Creates a new neuron with the given input locations.
pub fn new(inputs: Vec<NeuronLocation>, rng: &mut impl Rng) -> Self {
let activations = activation_fn! {
sigmoid,
relu,
f32::tanh
};
pub fn new(inputs: Vec<NeuronLocation>, current_scope: ActivationScope, rng: &mut impl Rng) -> Self {
let reg = ACTIVATION_REGISTRY.read().unwrap();
let activations = reg.activations_in_scope(current_scope);

Self::new_with_activations(inputs, activations, rng)
}
Expand Down Expand Up @@ -625,4 +619,4 @@ impl NeuronLocation {
Self::Output(i) => *i,
}
}
}
}

0 comments on commit 50a7947

Please sign in to comment.