Also allow callable nonlinearities
wesselb committed Nov 8, 2023
1 parent 5b51756 commit 2508919
Showing 4 changed files with 29 additions and 18 deletions.
4 changes: 2 additions & 2 deletions neuralprocesses/architectures/agnp.py
@@ -29,8 +29,8 @@ def construct_agnp(*args, nps=nps, num_heads=8, **kw_args):
             `False`.
         num_dec_layers (int, optional): Number of layers in the decoder. Defaults to 6.
         width (int, optional): Widths of all intermediate MLPs. Defaults to 512.
-        nonlinearity (string, optional): Nonlinearity in the MLP layers. Must be one of
-            `"ReLU"`, and `"LeakyReLU"`. Defaults to `"ReLU"`.
+        nonlinearity (Callable or str, optional): Nonlinearity. Can also be specified
+            as a string: `"ReLU"` or `"LeakyReLU"`. Defaults to ReLUs.
         likelihood (str, optional): Likelihood. Must be one of `"het"` or `"lowrank"`.
             Defaults to `"lowrank"`.
         num_basis_functions (int, optional): Number of basis functions for the
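For orientation beyond the diff: after this change, `construct_agnp` accepts either one of the two supported strings or any callable activation. Below is a minimal usage sketch, not part of the commit, assuming the PyTorch backend (`neuralprocesses.torch`); `torch.nn.GELU` is only an illustrative callable.

import torch
import neuralprocesses.torch as nps

# String form: resolved case-insensitively to the backend's ReLU/LeakyReLU modules.
model = nps.construct_agnp(dim_x=1, dim_y=1, nonlinearity="LeakyReLU")

# Callable form, newly allowed by this commit: pass an activation module directly.
model = nps.construct_agnp(dim_x=1, dim_y=1, nonlinearity=torch.nn.GELU())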
4 changes: 2 additions & 2 deletions neuralprocesses/architectures/gnp.py
@@ -55,8 +55,8 @@ def construct_gnp(
             `False`.
         num_dec_layers (int, optional): Number of layers in the decoder. Defaults to 6.
         width (int, optional): Widths of all intermediate MLPs. Defaults to 512.
-        nonlinearity (string, optional): Nonlinearity in the MLP layers. Must be one of
-            `"ReLU"`, and `"LeakyReLU"`. Defaults to `"ReLU"`.
+        nonlinearity (Callable or str, optional): Nonlinearity. Can also be specified
+            as a string: `"ReLU"` or `"LeakyReLU"`. Defaults to ReLUs.
         likelihood (str, optional): Likelihood. Must be one of `"het"` or `"lowrank"`.
             Defaults to `"lowrank"`.
         num_basis_functions (int, optional): Number of basis functions for the
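`construct_gnp` receives the same docstring update, and `construct_agnp` forwards its keyword arguments to it (hence the matching text). A short sketch under the same assumptions as above (PyTorch backend; the error text follows the new code in `nn.py` below):

import torch
import neuralprocesses.torch as nps

# Any callable can now be used in the intermediate MLPs.
gnp = nps.construct_gnp(dim_x=2, dim_y=1, nonlinearity=torch.nn.Tanh())

# Strings other than "ReLU"/"LeakyReLU" are still rejected.
try:
    nps.construct_gnp(nonlinearity="GELU")
except ValueError as e:
    print(e)  # e.g. Nonlinearity `GELU` invalid. Must be one of `relu`, `leakyrelu`.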
4 changes: 2 additions & 2 deletions neuralprocesses/coders/attention.py
@@ -16,8 +16,8 @@ class Attention:
         dim_embedding (int): Dimensionality of the embedding.
         num_heads (int): Number of heads.
         num_enc_layers (int): Number of layers in the encoders.
-        nonlinearity (string, optional): Nonlinearity in the encoders. Must be
-            one of `"ReLU"`, and `"LeakyReLU"`. Defaults to `"ReLU"`.
+        nonlinearity (Callable or str, optional): Nonlinearity. Can also be specified
+            as a string: `"ReLU"` or `"LeakyReLU"`. Defaults to ReLUs.
         dtype (dtype, optional): Data type.
 
     Attributes:
35 changes: 23 additions & 12 deletions neuralprocesses/coders/nn.py
@@ -1,6 +1,6 @@
 import math
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
 
 import lab as B
 from plum import convert
@@ -29,6 +29,12 @@ def __init__(self, in_channels, out_channels, dtype):
         self.net = self.nn.Linear(in_channels, out_channels, dtype=dtype)
 
 
+_nonlinearity_name_map = {
+    "relu": "ReLU",
+    "leakyrelu": "LeakyReLU",
+}
+
+
 @register_module
 class MLP:
     """MLP.
@@ -39,8 +45,8 @@ class MLP:
         layers (tuple[int, ...], optional): Width of every hidden layer.
         num_layers (int, optional): Number of hidden layers.
         width (int, optional): Width of the hidden layers
-        nonlinearity (string, optional): Nonlinearity. Must be one of
-            `"ReLU"`, and `"LeakyReLU"`. Defaults to `"ReLU"`.
+        nonlinearity (Callable or str, optional): Nonlinearity. Can also be specified
+            as a string: `"ReLU"` or `"LeakyReLU"`. Defaults to ReLUs.
         dtype (dtype, optional): Data type.
 
     Attributes:
@@ -54,27 +60,32 @@ def __init__(
         layers: Optional[Tuple[int, ...]] = None,
         num_layers: Optional[int] = None,
         width: Optional[int] = None,
-        nonlinearity="ReLU",
+        nonlinearity: Union[Callable, str] = "ReLU",
         dtype=None,
     ):
         # Check that one of the two specifications is given.
         layers_given = layers is not None
         num_layers_given = num_layers is not None and width is not None
         if not (layers_given or num_layers_given):
             raise ValueError(
-                f"Must specify either `layers` or `num_layers` and `width`."
+                "Must specify either `layers` or `num_layers` and `width`."
             )
         # Make sure that `layers` is a tuple of various widths.
         if not layers_given and num_layers_given:
             layers = (width,) * num_layers
 
-        # Default to ReLUs.
-        if nonlinearity is None or str(nonlinearity).lower() == "relu":
-            nonlinearity = self.nn.ReLU()
-        elif str(nonlinearity).lower() == "leakyrelu":
-            nonlinearity = self.nn.LeakyReLU()
-        else:
-            raise ValueError("""'nonlinearity' must be either `ReLU`, or `LeakyReLU`""")
+        # Resolve string-form `nonlinearity`.
+        if isinstance(nonlinearity, str):
+            try:
+                resolved_name = _nonlinearity_name_map[nonlinearity.lower()]
+                nonlinearity = getattr(self.nn, resolved_name)()
+            except KeyError:
+                raise ValueError(
+                    f"Nonlinearity `{nonlinearity}` invalid. "
+                    f"Must be one of "
+                    + ", ".join(f"`{k}`" for k in _nonlinearity_name_map.keys())
+                    + "."
+                )
 
         # Build layers.
         if len(layers) == 0:
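To see the new resolution path in isolation, the sketch below reproduces the same pattern outside the class, with `torch.nn` standing in for the backend-provided `self.nn`. The helper name `resolve_nonlinearity` is made up for illustration and does not exist in the library; the error message reports the value the caller passed in.

from typing import Callable, Union

import torch

_nonlinearity_name_map = {
    "relu": "ReLU",
    "leakyrelu": "LeakyReLU",
}


def resolve_nonlinearity(nonlinearity: Union[Callable, str] = "ReLU") -> Callable:
    """Pass callables through unchanged; resolve strings case-insensitively."""
    if isinstance(nonlinearity, str):
        try:
            resolved_name = _nonlinearity_name_map[nonlinearity.lower()]
        except KeyError:
            # Report the value the caller passed in, not an unresolved name.
            raise ValueError(
                f"Nonlinearity `{nonlinearity}` invalid. Must be one of "
                + ", ".join(f"`{k}`" for k in _nonlinearity_name_map)
                + "."
            )
        return getattr(torch.nn, resolved_name)()
    # Callables, e.g. `torch.nn.GELU()`, are used as-is.
    return nonlinearity


print(resolve_nonlinearity("leakyrelu"))      # e.g. LeakyReLU(negative_slope=0.01)
print(resolve_nonlinearity(torch.nn.GELU()))  # the GELU instance, unchanged

Mapping lowercased strings to attribute names and then calling `getattr` keeps the lookup case-insensitive and backend-agnostic, which is why the commit introduces `_nonlinearity_name_map` rather than hard-coding the two branches as before.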
