Skip to content

Commit

Permalink
add batch_register fn
Browse files Browse the repository at this point in the history
  • Loading branch information
HyperCodec committed Apr 15, 2024
1 parent b85b9f6 commit 1464805
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions src/topology/activation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,18 @@ lazy_static! {
pub(crate) static ref ACTIVATION_REGISTRY: Arc<RwLock<ActivationRegistry>> = Arc::new(RwLock::new(ActivationRegistry::default()));
}

/// Register an activation function to the registry
/// Register an activation function to the registry.
pub fn register_activation(act: ActivationFn) {
let mut reg = ACTIVATION_REGISTRY.write().unwrap();
reg.register(act);
}

/// Registers multiple activation functions to the registry at once.
pub fn batch_register_activation(acts: impl IntoIterator<Item = ActivationFn>) {
let mut reg = ACTIVATION_REGISTRY.write().unwrap();
reg.batch_register(acts);
}

/// A registry of the different possible activation functions.
pub struct ActivationRegistry {
/// The currently-registered activation functions.
Expand All @@ -42,6 +48,13 @@ impl ActivationRegistry {
self.fns.insert(activation.name.clone(), activation);
}

/// Registers multiple activation functions at once.
pub fn batch_register(&mut self, activations: impl IntoIterator<Item = ActivationFn>) {
for act in activations {
self.register(act);
}
}

/// Gets a Vec of all the
pub fn activations(&self) -> Vec<ActivationFn> {
self.fns.values()
Expand All @@ -55,14 +68,12 @@ impl Default for ActivationRegistry {
fn default() -> Self {
let mut s = Self { fns: HashMap::new() };

activation_fn! {
s.batch_register(activation_fn! {
sigmoid,
relu,
linear_activation,
f32::tanh
}
.into_iter()
.for_each(|f| s.register(f));
});

s
}
Expand Down

0 comments on commit 1464805

Please sign in to comment.