
Add activation fn registry system #34

Merged · 14 commits · Apr 16, 2024
14 changes: 14 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 3 additions & 1 deletion Cargo.toml
@@ -26,12 +26,14 @@ serde = ["dep:serde", "dep:serde-big-array"]


[dependencies]
bitflags = "2.5.0"
genetic-rs = { version = "0.5.1", features = ["derive"] }
lazy_static = "1.4.0"
rand = "0.8.5"
rayon = { version = "1.8.1", optional = true }
serde = { version = "1.0.197", features = ["derive"], optional = true }
serde-big-array = { version = "0.5.1", optional = true }

[dev-dependencies]
bincode = "1.3.3"
serde_json = "1.0.114"
serde_json = "1.0.114"
2 changes: 1 addition & 1 deletion src/runnable.rs
@@ -253,7 +253,7 @@ impl Neuron {

/// Applies the activation function to the neuron
pub fn activate(&mut self) {
self.state.value = (self.activation.func)(self.state.value);
self.state.value = self.activation.func.activate(self.state.value);
}
}
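Context for this one-line change: `func` is now an `Arc<dyn Activation + Send + Sync>` rather than a bare `Arc<dyn Fn(f32) -> f32 + Send + Sync>`, so the call operator no longer applies and dispatch goes through the trait method. A self-contained sketch of the mechanism (types inlined for illustration, not the crate's actual definitions):

use std::sync::Arc;

trait Activation {
    fn activate(&self, n: f32) -> f32;
}

// blanket impl mirroring the one added in this PR
impl<F: Fn(f32) -> f32> Activation for F {
    fn activate(&self, n: f32) -> f32 {
        (self)(n)
    }
}

fn main() {
    let func: Arc<dyn Activation + Send + Sync> = Arc::new(|n: f32| n.max(0.));
    // the old `(func)(x)` call style does not compile for a dyn Activation;
    // the trait method is the call path now
    assert_eq!(func.activate(-2.), 0.);
}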

222 changes: 222 additions & 0 deletions src/topology/activation.rs
@@ -0,0 +1,222 @@
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use bitflags::bitflags;
use lazy_static::lazy_static;
use std::{
collections::HashMap,
fmt,
sync::{Arc, RwLock},
};

use crate::NeuronLocation;

/// Creates an [`ActivationFn`] object from a function
#[macro_export]
macro_rules! activation_fn {
($F: path) => {
ActivationFn::new(Arc::new($F), ActivationScope::default(), stringify!($F).into())
};

($F: path, $S: expr) => {
ActivationFn::new(Arc::new($F), $S, stringify!($F).into())
};

{$($F: path),*} => {
[$(activation_fn!($F)),*]
};

{$($F: path => $S: expr),*} => {
[$(activation_fn!($F, $S)),*]
}
}
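A minimal sketch of the two single-function arms, assuming `ActivationFn` and `ActivationScope` are imported from the crate (`softsign` is a hypothetical activation used only for illustration):

use std::sync::Arc;

// hypothetical custom activation, not part of this PR
fn softsign(n: f32) -> f32 {
    n / (1. + n.abs())
}

fn main() {
    // default scope (ActivationScope::HIDDEN); the name comes from stringify!
    let plain = activation_fn!(softsign);

    // explicit scope
    let scoped = activation_fn!(softsign, ActivationScope::HIDDEN | ActivationScope::OUTPUT);

    assert_eq!(plain, scoped); // PartialEq compares names only
}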

lazy_static! {
/// A static activation registry for use in deserialization.
pub(crate) static ref ACTIVATION_REGISTRY: Arc<RwLock<ActivationRegistry>> = Arc::new(RwLock::new(ActivationRegistry::default()));
}

/// Register an activation function to the registry.
pub fn register_activation(act: ActivationFn) {
let mut reg = ACTIVATION_REGISTRY.write().unwrap();
reg.register(act);
}

/// Registers multiple activation functions to the registry at once.
pub fn batch_register_activation(acts: impl IntoIterator<Item = ActivationFn>) {
let mut reg = ACTIVATION_REGISTRY.write().unwrap();
reg.batch_register(acts);
}
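A sketch of how a downstream crate might register a custom function before deserializing a network that references it (same hypothetical `softsign`, and assuming the registry items above are imported):

use std::sync::Arc;

fn softsign(n: f32) -> f32 {
    n / (1. + n.abs())
}

fn main() {
    // one at a time, under the default (HIDDEN) scope
    register_activation(activation_fn!(softsign));

    // or in bulk with explicit scopes; the braced macro arms expand to arrays
    batch_register_activation(activation_fn! {
        softsign => ActivationScope::HIDDEN | ActivationScope::OUTPUT
    });
}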

/// A registry of the different possible activation functions.
pub struct ActivationRegistry {
/// The currently-registered activation functions.
pub fns: HashMap<String, ActivationFn>,
}

impl ActivationRegistry {
/// Registers an activation function.
pub fn register(&mut self, activation: ActivationFn) {
self.fns.insert(activation.name.clone(), activation);
}

/// Registers multiple activation functions at once.
pub fn batch_register(&mut self, activations: impl IntoIterator<Item = ActivationFn>) {
for act in activations {
self.register(act);
}
}

/// Gets a `Vec` of all currently-registered activation functions.
pub fn activations(&self) -> Vec<ActivationFn> {
self.fns.values().cloned().collect()
}

/// Gets all activation functions that are valid for a scope.
pub fn activations_in_scope(&self, scope: ActivationScope) -> Vec<ActivationFn> {
let acts = self.activations();

acts.into_iter()
.filter(|a| !a.scope.contains(ActivationScope::NONE) && a.scope.intersects(scope))
.collect()
}
}
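Since `ACTIVATION_REGISTRY` is `pub(crate)`, scope queries like the following are internal to the crate; this is how the mutation code further down uses the registry (a sketch, not an exhaustive API):

fn hidden_activations() -> Vec<ActivationFn> {
    let reg = ACTIVATION_REGISTRY.read().unwrap();
    // with only the default registrations this yields sigmoid, relu,
    // linear_activation and f32::tanh
    reg.activations_in_scope(ActivationScope::HIDDEN)
}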

impl Default for ActivationRegistry {
fn default() -> Self {
let mut s = Self {
fns: HashMap::new(),
};

s.batch_register(activation_fn! {
sigmoid => ActivationScope::HIDDEN | ActivationScope::OUTPUT,
relu => ActivationScope::HIDDEN | ActivationScope::OUTPUT,
linear_activation => ActivationScope::INPUT | ActivationScope::HIDDEN | ActivationScope::OUTPUT,
f32::tanh => ActivationScope::HIDDEN | ActivationScope::OUTPUT
});

s
}
}

bitflags! {
/// Specifies where an activation function can occur.
#[derive(Copy, Clone)]
pub struct ActivationScope: u8 {
/// Whether the activation can be applied to the input layer.
const INPUT = 0b001;

/// Whether the activation can be applied to the hidden layer.
const HIDDEN = 0b010;

/// Whether the activation can be applied to the output layer.
const OUTPUT = 0b100;

/// If this flag is set, it overrides the rest and prevents the function from occurring naturally.
const NONE = 0b1000;
}
}

impl Default for ActivationScope {
fn default() -> Self {
Self::HIDDEN
}
}

impl From<&NeuronLocation> for ActivationScope {
fn from(value: &NeuronLocation) -> Self {
match value {
NeuronLocation::Input(_) => Self::INPUT,
NeuronLocation::Hidden(_) => Self::HIDDEN,
NeuronLocation::Output(_) => Self::OUTPUT,
}
}
}
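A quick sketch of how the flags compose and how the `From` impl maps neuron locations; note the struct derives only `Copy` and `Clone`, so comparisons go through `contains` rather than `==`:

fn main() {
    let scope = ActivationScope::HIDDEN | ActivationScope::OUTPUT;
    assert!(scope.contains(ActivationScope::HIDDEN));
    assert!(!scope.contains(ActivationScope::INPUT));

    // locations map onto the matching flag
    let from_loc = ActivationScope::from(&NeuronLocation::Hidden(3));
    assert!(from_loc.contains(ActivationScope::HIDDEN));
}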

/// A trait that represents an activation method.
pub trait Activation {
/// Applies the activation function to `n`.
fn activate(&self, n: f32) -> f32;
}

impl<F: Fn(f32) -> f32> Activation for F {
fn activate(&self, n: f32) -> f32 {
(self)(n)
}
}
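The blanket impl means any `Fn(f32) -> f32`, plain function or capture-free closure alike, already satisfies `Activation`; a two-line sketch assuming the trait is in scope:

fn main() {
    let double = |n: f32| n * 2.;
    // the closure picks up Activation::activate via the blanket impl
    assert_eq!(double.activate(3.), 6.);
}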

/// An activation function object that implements [`fmt::Debug`] and is [`Send`]
#[derive(Clone)]
pub struct ActivationFn {
/// The actual activation function.
pub func: Arc<dyn Activation + Send + Sync>,

/// The scope defining where the activation function can appear.
pub scope: ActivationScope,
pub(crate) name: String,
}

impl ActivationFn {
/// Creates a new ActivationFn object.
pub fn new(
func: Arc<dyn Activation + Send + Sync>,
scope: ActivationScope,
name: String,
) -> Self {
Self { func, name, scope }
}
}

impl fmt::Debug for ActivationFn {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "{}", self.name)
}
}

impl PartialEq for ActivationFn {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
}
}

#[cfg(feature = "serde")]
impl Serialize for ActivationFn {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(&self.name)
}
}

#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for ActivationFn {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'a>,
{
let name = String::deserialize(deserializer)?;

let reg = ACTIVATION_REGISTRY.read().unwrap();

reg.fns
.get(&name)
.cloned()
.ok_or_else(|| serde::de::Error::custom(format!("activation function `{name}` not found in registry")))
}
}
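A round-trip sketch under the `serde` feature, using `serde_json` (already a dev-dependency). `sigmoid` is registered by `ActivationRegistry::default()`, so the lookup succeeds:

use std::sync::Arc;

fn main() {
    let act = activation_fn!(sigmoid);

    // serializes to just the name
    let json = serde_json::to_string(&act).unwrap();
    assert_eq!(json, "\"sigmoid\"");

    // deserialization looks the name up in ACTIVATION_REGISTRY
    let back: ActivationFn = serde_json::from_str(&json).unwrap();
    assert_eq!(act, back); // PartialEq compares names
}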

/// The sigmoid activation function.
pub fn sigmoid(n: f32) -> f32 {
1. / (1. + std::f32::consts::E.powf(-n))
}

/// The ReLU activation function.
pub fn relu(n: f32) -> f32 {
n.max(0.)
}

/// Activation function that does nothing.
pub fn linear_activation(n: f32) -> f32 {
n
}
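Spot-check values for the built-ins (sigmoid(0) = 1 / (1 + e^0) = 0.5 exactly):

fn main() {
    assert_eq!(sigmoid(0.), 0.5);
    assert_eq!(relu(-1.), 0.);
    assert_eq!(relu(2.), 2.);
    assert_eq!(linear_activation(0.25), 0.25);
}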
196 changes: 23 additions & 173 deletions src/topology.rs → src/topology/mod.rs
@@ -1,106 +1,25 @@
/// Contains useful structs for serializing/deserializing a [`NeuralNetworkTopology`]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
#[cfg(feature = "serde")]
pub mod nnt_serde;

/// Contains structs and traits used for activation functions.
pub mod activation;

pub use activation::*;

use std::{
collections::HashSet,
fmt,
sync::{Arc, RwLock},
};

use genetic_rs::prelude::*;
use rand::prelude::*;

#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};

/// Contains useful structs for serializing/deserializing a [`NeuronTopology`]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
#[cfg(feature = "serde")]
pub mod nnt_serde {
use super::*;
use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;

/// A serializable wrapper for [`NeuronTopology`]. See [`NNTSerde::from`] for conversion.
#[derive(Serialize, Deserialize)]
pub struct NNTSerde<const I: usize, const O: usize> {
#[serde(with = "BigArray")]
pub(crate) input_layer: [NeuronTopology; I],

pub(crate) hidden_layers: Vec<NeuronTopology>,

#[serde(with = "BigArray")]
pub(crate) output_layer: [NeuronTopology; O],

pub(crate) mutation_rate: f32,
pub(crate) mutation_passes: usize,
}

impl<const I: usize, const O: usize> From<&NeuralNetworkTopology<I, O>> for NNTSerde<I, O> {
fn from(value: &NeuralNetworkTopology<I, O>) -> Self {
let input_layer = value
.input_layer
.iter()
.map(|n| n.read().unwrap().clone())
.collect::<Vec<_>>()
.try_into()
.unwrap();

let hidden_layers = value
.hidden_layers
.iter()
.map(|n| n.read().unwrap().clone())
.collect();

let output_layer = value
.output_layer
.iter()
.map(|n| n.read().unwrap().clone())
.collect::<Vec<_>>()
.try_into()
.unwrap();

Self {
input_layer,
hidden_layers,
output_layer,
mutation_rate: value.mutation_rate,
mutation_passes: value.mutation_passes,
}
}
}

#[cfg(test)]
#[test]
fn serde() {
let mut rng = rand::thread_rng();
let nnt = NeuralNetworkTopology::<10, 10>::new(0.1, 3, &mut rng);
let nnts = NNTSerde::from(&nnt);

let encoded = bincode::serialize(&nnts).unwrap();

if let Some(_) = option_env!("TEST_CREATEFILE") {
std::fs::write("serde-test.nn", &encoded).unwrap();
}

let decoded: NNTSerde<10, 10> = bincode::deserialize(&encoded).unwrap();
let nnt2: NeuralNetworkTopology<10, 10> = decoded.into();

dbg!(nnt, nnt2);
}
}

/// Creates an [`ActivationFn`] object from a function
#[macro_export]
macro_rules! activation_fn {
($F: path) => {
ActivationFn {
func: Arc::new($F),
name: String::from(stringify!($F)),
}
};
use serde::{Deserialize, Serialize};

{$($F: path),*} => {
[$(activation_fn!($F)),*]
};
}
use crate::activation_fn;

/// A stateless neural network topology.
/// This is the struct you want to use in your agent's inheritance.
@@ -116,7 +35,7 @@ pub struct NeuralNetworkTopology<const I: usize, const O: usize> {
/// The output layer of the neural network. Uses a fixed length of `O`.
pub output_layer: [Arc<RwLock<NeuronTopology>>; O],

/// The mutation rate used in [`NeuralNetworkTopology::mutate`].
/// The mutation rate used in [`NeuralNetworkTopology::mutate`] after crossover/division.
pub mutation_rate: f32,

/// The number of mutation passes (and thus, maximum number of possible mutations that can occur for each entity in the generation).
@@ -371,7 +290,7 @@ impl<const I: usize, const O: usize> RandomlyMutable for NeuralNetworkTopology<I

let loc3 = NeuronLocation::Hidden(self.hidden_layers.len());

let n3 = NeuronTopology::new(vec![loc], rng);
let n3 = NeuronTopology::new(vec![loc], ActivationScope::HIDDEN, rng);

self.hidden_layers.push(Arc::new(RwLock::new(n3)));

@@ -425,11 +344,8 @@ impl<const I: usize, const O: usize> RandomlyMutable for NeuralNetworkTopology<I

if rng.gen::<f32>() <= rate && !self.hidden_layers.is_empty() {
// mutate activation function
let activations = activation_fn! {
sigmoid,
relu,
f32::tanh
};
let reg = ACTIVATION_REGISTRY.read().unwrap();
let activations = reg.activations_in_scope(ActivationScope::HIDDEN);

let (mut n, mut loc) = self.rand_neuron(rng);

@@ -608,73 +524,6 @@ fn input_exists<const I: usize>(
}
}

/// An activation function object that implements [`fmt::Debug`] and is [`Send`]
#[derive(Clone)]
pub struct ActivationFn {
/// The actual activation function.
pub func: Arc<dyn Fn(f32) -> f32 + Send + Sync + 'static>,
name: String,
}

impl fmt::Debug for ActivationFn {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "{}", self.name)
}
}

impl PartialEq for ActivationFn {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
}
}

#[cfg(feature = "serde")]
impl Serialize for ActivationFn {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(&self.name)
}
}

#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for ActivationFn {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'a>,
{
let name = String::deserialize(deserializer)?;
let activations = activation_fn! {
sigmoid,
relu,
f32::tanh,
linear_activation
};

for a in activations {
if a.name == name {
return Ok(a);
}
}

// eventually will make an activation fn registry of sorts.
panic!("Custom activation functions currently not supported.") // TODO return error instead of raw panic
}
}

/// The sigmoid activation function.
pub fn sigmoid(n: f32) -> f32 {
1. / (1. + std::f32::consts::E.powf(-n))
}

/// The ReLU activation function.
pub fn relu(n: f32) -> f32 {
n.max(0.)
}

/// Activation function that does nothing.
pub fn linear_activation(n: f32) -> f32 {
n
}

/// A stateless version of [`Neuron`][crate::Neuron].
#[derive(PartialEq, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -691,12 +540,13 @@ pub struct NeuronTopology {

impl NeuronTopology {
/// Creates a new neuron with the given input locations.
pub fn new(inputs: Vec<NeuronLocation>, rng: &mut impl Rng) -> Self {
let activations = activation_fn! {
sigmoid,
relu,
f32::tanh
};
pub fn new(
inputs: Vec<NeuronLocation>,
current_scope: ActivationScope,
rng: &mut impl Rng,
) -> Self {
let reg = ACTIVATION_REGISTRY.read().unwrap();
let activations = reg.activations_in_scope(current_scope);

Self::new_with_activations(inputs, activations, rng)
}
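Call sites now pass the scope explicitly; a hedged sketch of the updated signature in use (the location and rng are illustrative, not taken from this PR):

fn main() {
    let mut rng = rand::thread_rng();
    // a hidden neuron reading from input 0; the scope decides which
    // registered activations it may be assigned
    let neuron = NeuronTopology::new(
        vec![NeuronLocation::Input(0)],
        ActivationScope::HIDDEN,
        &mut rng,
    );
    dbg!(neuron);
}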
71 changes: 71 additions & 0 deletions src/topology/nnt_serde.rs
@@ -0,0 +1,71 @@
use super::*;
use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;

/// A serializable wrapper for [`NeuralNetworkTopology`]. See [`NNTSerde::from`] for conversion.
#[derive(Serialize, Deserialize)]
pub struct NNTSerde<const I: usize, const O: usize> {
#[serde(with = "BigArray")]
pub(crate) input_layer: [NeuronTopology; I],

pub(crate) hidden_layers: Vec<NeuronTopology>,

#[serde(with = "BigArray")]
pub(crate) output_layer: [NeuronTopology; O],

pub(crate) mutation_rate: f32,
pub(crate) mutation_passes: usize,
}

impl<const I: usize, const O: usize> From<&NeuralNetworkTopology<I, O>> for NNTSerde<I, O> {
fn from(value: &NeuralNetworkTopology<I, O>) -> Self {
let input_layer = value
.input_layer
.iter()
.map(|n| n.read().unwrap().clone())
.collect::<Vec<_>>()
.try_into()
.unwrap();

let hidden_layers = value
.hidden_layers
.iter()
.map(|n| n.read().unwrap().clone())
.collect();

let output_layer = value
.output_layer
.iter()
.map(|n| n.read().unwrap().clone())
.collect::<Vec<_>>()
.try_into()
.unwrap();

Self {
input_layer,
hidden_layers,
output_layer,
mutation_rate: value.mutation_rate,
mutation_passes: value.mutation_passes,
}
}
}

#[cfg(test)]
#[test]
fn serde() {
let mut rng = rand::thread_rng();
let nnt = NeuralNetworkTopology::<10, 10>::new(0.1, 3, &mut rng);
let nnts = NNTSerde::from(&nnt);

let encoded = bincode::serialize(&nnts).unwrap();

if option_env!("TEST_CREATEFILE").is_some() {
std::fs::write("serde-test.nn", &encoded).unwrap();
}

let decoded: NNTSerde<10, 10> = bincode::deserialize(&encoded).unwrap();
let nnt2: NeuralNetworkTopology<10, 10> = decoded.into();

dbg!(nnt, nnt2);
}