Merge pull request #35 from Fraunhofer-IIS/34-naming-of-variable-recurrent_cell_type-in-ecnn

34 naming of variable recurrent cell type in ecnn
bknico-iis authored Nov 13, 2024
2 parents 63d6dec + 3fa0001 commit c56f684
Showing 6 changed files with 39 additions and 39 deletions.
12 changes: 6 additions & 6 deletions docs/source/tutorials/ECNN.ipynb
@@ -377,7 +377,7 @@
"\n",
"Before initializing the model, we set some parameters.\n",
"The parameter ``n_state_neurons`` defines the number of the state neurons used in $s_t$.\n",
"As default we use ``'elman'`` as ``recurrent_cell_type``, which builds an Elman RNN with error correction. However, there is a simpler form of GRU variant 3 (Dey and Salem, 2017) implemented. By setting ``recurrent_cell_type`` to ``'gru_3_variant'``, this is used. The ``lstm`` and ``gru`` cell is also supported.\n",
"As default we use ``'elman'`` as ``cell_type``, which builds an Elman RNN with error correction. However, there is a simpler form of GRU variant 3 (Dey and Salem, 2017) implemented. By setting ``cell_type`` to ``'gru_3_variant'``, this is used. The ``lstm`` and ``gru`` cell is also supported.\n",
"The ECNN knows two ``approaches``, ``backward`` and ``forward``. If approach is set to backward, U at time t is used in the model to predict Y at time t. If approach is set to forward, U at time t is used only to predict Y at time t+1. This way, you can implement your belief in there being a delay in the impact of U on Y.\n",
"There is the possibility to set the initial state of the ECNN by giving a tensor of the right dimensions as ``init_state``. If none is provided, the hidden state is initialized randomly.\n",
"We can decide whether that initial state is fixed or it should be learned as a trainable parameter. This is done by setting ``learn_init_state``.\n",
@@ -386,13 +386,13 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "alien-patent",
"metadata": {},
"outputs": [],
"source": [
"n_state_neurons = 4\n",
"recurrent_cell_type = 'elman'\n",
"cell_type = 'elman'\n",
"approach = \"backward\"\n",
"init_state = torch.zeros(1, n_state_neurons)\n",
"learn_init_state = True\n",
@@ -403,7 +403,7 @@
" n_state_neurons=n_state_neurons,\n",
" past_horizon=past_horizon,\n",
" forecast_horizon=forecast_horizon,\n",
" recurrent_cell_type=recurrent_cell_type,\n",
" cell_type=cell_type,\n",
" approach=approach,\n",
" init_state=init_state,\n",
" learn_init_state=learn_init_state,\n",
@@ -680,7 +680,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "honey-mozambique",
"metadata": {},
"outputs": [],
@@ -691,7 +691,7 @@
" n_state_neurons=n_state_neurons,\n",
" past_horizon=past_horizon,\n",
" forecast_horizon=forecast_horizon,\n",
" recurrent_cell_type='elman',\n",
" cell_type='elman',\n",
" approach=\"backward\",\n",
" init_state=init_state,\n",
" learn_init_state=True,\n",
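For orientation, a minimal sketch of how the renamed keyword is passed when building the tutorial model. The import path and the feature-size arguments (`n_features_U`, `n_features_Y`) are assumptions based on the surrounding diff, not taken from this commit:

```python
import torch
from prosper_nn.models.ecnn import ECNN  # assumed import path

# Assumed toy dimensions; the tutorial defines its own values.
n_features_U, n_features_Y = 2, 1
n_state_neurons, past_horizon, forecast_horizon = 4, 10, 5

ecnn = ECNN(
    n_features_U=n_features_U,          # assumed to be required here
    n_state_neurons=n_state_neurons,
    past_horizon=past_horizon,
    forecast_horizon=forecast_horizon,
    cell_type="gru_3_variant",          # renamed from recurrent_cell_type
    approach="backward",
    init_state=torch.zeros(1, n_state_neurons),
    learn_init_state=True,
    n_features_Y=n_features_Y,
)
```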
34 changes: 17 additions & 17 deletions examples/case_study/models.py
@@ -24,15 +24,15 @@ def __init__(
n_state_neurons: int,
n_features_Y: int,
forecast_horizon: int,
recurrent_cell_type: str,
cell_type: str,
):
super(Benchmark_RNN, self).__init__()
self.multivariate = False

self.n_features_Y = n_features_Y
self.forecast_horizon = forecast_horizon
self.n_state_neurons = n_state_neurons
self.recurrent_cell_type = recurrent_cell_type
self.cell_type = cell_type

self.cell = self.get_recurrent_cell()
self.rnn = self.cell(input_size=n_features_U, hidden_size=n_state_neurons)
@@ -50,7 +50,7 @@ def forward(self, features_past: torch.Tensor) -> torch.Tensor:

def set_init_state(self) -> Union[nn.Parameter, Tuple[nn.Parameter, nn.Parameter]]:
dtype = torch.float64
if self.recurrent_cell_type == "lstm":
if self.cell_type == "lstm":
init_state = (
nn.Parameter(
torch.rand(1, self.n_state_neurons, dtype=dtype), requires_grad=True
@@ -66,22 +66,22 @@ def set_init_state(self) -> Union[nn.Parameter, Tuple[nn.Parameter, nn.Parameter
return init_state

def get_recurrent_cell(self) -> nn.Module:
if self.recurrent_cell_type == "elman":
if self.cell_type == "elman":
cell = nn.RNN
elif self.recurrent_cell_type == "gru":
elif self.cell_type == "gru":
cell = nn.GRU
elif self.recurrent_cell_type == "lstm":
elif self.cell_type == "lstm":
cell = nn.LSTM
else:
raise ValueError(
f"recurrent_cell_type {self.recurrent_cell_type} not available."
f"cell_type {self.cell_type} not available."
)
return cell

def repeat_init_state(
self, batchsize: int
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.recurrent_cell_type == "lstm":
if self.cell_type == "lstm":
return self.init_state[0].repeat(batchsize, 1).unsqueeze(
0
), self.init_state[1].repeat(batchsize, 1).unsqueeze(0)
@@ -117,7 +117,7 @@ def __init__(
n_state_neurons: int,
n_features_Y: int,
forecast_horizon: int,
recurrent_cell_type: str,
cell_type: str,
):
self.forecast_method = "direct"
self.output_size_linear_decoder = n_features_Y * forecast_horizon
@@ -126,7 +126,7 @@
n_state_neurons,
n_features_Y,
forecast_horizon,
recurrent_cell_type,
cell_type,
)

def forward(self, features_past: torch.Tensor) -> torch.Tensor:
@@ -161,7 +161,7 @@ def __init__(
n_state_neurons: int,
n_features_Y: int,
forecast_horizon: int,
recurrent_cell_type: str,
cell_type: str,
):
self.forecast_method = "recursive"
self.output_size_linear_decoder = n_features_Y
@@ -170,7 +170,7 @@
n_state_neurons,
n_features_Y,
forecast_horizon,
recurrent_cell_type,
cell_type,
)

def forward(self, features_past: torch.Tensor) -> torch.Tensor:
@@ -207,7 +207,7 @@ def __init__(
n_state_neurons: int,
n_features_Y: int,
forecast_horizon: int,
recurrent_cell_type: str,
cell_type: str,
):
self.forecast_method = "s2s"
self.output_size_linear_decoder = n_features_Y
@@ -216,7 +216,7 @@
n_state_neurons,
n_features_Y,
forecast_horizon,
recurrent_cell_type,
cell_type,
)
self.decoder = self.cell(input_size=n_features_U, hidden_size=n_state_neurons)

@@ -366,16 +366,16 @@ def init_models(

# Compare to further Recurrent Neural Networks
for forecast_module in [RNN_direct, RNN_recursive, RNN_S2S]:
for recurrent_cell_type in ["elman", "gru", "lstm"]:
for cell_type in ["elman", "gru", "lstm"]:
model = forecast_module(
n_features_U,
n_state_neurons,
n_features_Y,
forecast_horizon,
recurrent_cell_type,
cell_type,
)
ensemble = init_ensemble(model, n_models)
benchmark_models[f"{recurrent_cell_type}_{model.forecast_method}"] = (
benchmark_models[f"{cell_type}_{model.forecast_method}"] = (
ensemble
)

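To make the renamed dispatch concrete, here is a self-contained sketch of the same pattern using plain PyTorch modules; the helper below is illustrative and not the library code:

```python
import torch
import torch.nn as nn

def get_recurrent_cell(cell_type: str) -> type:
    # Same dispatch idea as Benchmark_RNN.get_recurrent_cell after the rename.
    cells = {"elman": nn.RNN, "gru": nn.GRU, "lstm": nn.LSTM}
    if cell_type not in cells:
        raise ValueError(f"cell_type {cell_type} not available.")
    return cells[cell_type]

seq_len, batchsize, n_features_U, n_state_neurons = 20, 8, 3, 4
features_past = torch.rand(seq_len, batchsize, n_features_U)

for cell_type in ["elman", "gru", "lstm"]:
    rnn = get_recurrent_cell(cell_type)(input_size=n_features_U, hidden_size=n_state_neurons)
    # nn.LSTM carries a (hidden, cell) pair; the other cells use a single hidden tensor.
    if cell_type == "lstm":
        init = (torch.zeros(1, batchsize, n_state_neurons),
                torch.zeros(1, batchsize, n_state_neurons))
    else:
        init = torch.zeros(1, batchsize, n_state_neurons)
    output, _ = rnn(features_past, init)
    print(cell_type, tuple(output.shape))  # (20, 8, 4) in all three cases
```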
2 changes: 1 addition & 1 deletion examples/ecnn_example.py
@@ -42,7 +42,7 @@
n_state_neurons,
past_horizon,
forecast_horizon,
recurrent_cell_type="gru",
cell_type="gru",
approach="backward",
learn_init_state=True,
n_features_Y=n_features_Y,
14 changes: 7 additions & 7 deletions prosper_nn/models/ecnn/ecnn.py
@@ -67,7 +67,7 @@ def __init__(
n_state_neurons: int,
past_horizon: int,
forecast_horizon: int = 1,
recurrent_cell_type: str = "elman",
cell_type: str = "elman",
kwargs_recurrent_cell: dict = {},
approach: str = "backward",
learn_init_state: bool = True,
@@ -88,7 +88,7 @@ def __init__(
prediction.
forecast_horizon: int
The forecast horizon.
recurrent_cell_type: str
cell_type: str
Possible choices: elman, lstm, gru or gru_3_variant.
kwargs_recurrent_cell: dict
Parameters for the recurrent cell. Activation function can be set here.
@@ -124,24 +124,24 @@ def __init__(
self.approach = approach
self.future_U = future_U
self.learn_init_state = learn_init_state
self.recurrent_cell_type = recurrent_cell_type
self.cell_type = cell_type

self._check_variables()

if recurrent_cell_type == 'lstm':
if cell_type == 'lstm':
self.state = ([(torch.tensor, torch.tensor)] * (past_horizon + forecast_horizon + 1))
else:
self.state = [torch.tensor] * (past_horizon + forecast_horizon + 1)

self.ECNNCell = ecnn_cell.ECNNCell(
n_features_U, n_state_neurons, n_features_Y, recurrent_cell_type, kwargs_recurrent_cell
n_features_U, n_state_neurons, n_features_Y, cell_type, kwargs_recurrent_cell
)

self.init_state = self.set_init_state()


def set_init_state(self):
if self.recurrent_cell_type == 'lstm':
if self.cell_type == 'lstm':
init_state = (nn.Parameter(
torch.rand(1, self.n_state_neurons), requires_grad=self.learn_init_state
), nn.Parameter(
@@ -250,7 +250,7 @@ def forward(self, U: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
return torch.cat((past_error, forecast), dim=0)

def repeat_init_state(self, batchsize):
if self.recurrent_cell_type == 'lstm':
if self.cell_type == 'lstm':
return self.init_state[0].repeat(batchsize, 1), self.init_state[1].repeat(batchsize, 1)
else:
return self.init_state.repeat(batchsize, 1)
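The `lstm` branches above exist because an LSTM state is a (hidden, cell) pair rather than a single tensor. A small sketch of the batching step with assumed toy dimensions:

```python
import torch
import torch.nn as nn

n_state_neurons, batchsize = 4, 8

# Non-LSTM cells: one learnable initial state, repeated over the batch.
init_state = nn.Parameter(torch.rand(1, n_state_neurons), requires_grad=True)
print(init_state.repeat(batchsize, 1).shape)  # torch.Size([8, 4])

# LSTM: hidden and cell state are repeated the same way, one tensor each.
init_state_lstm = (nn.Parameter(torch.rand(1, n_state_neurons)),
                   nn.Parameter(torch.rand(1, n_state_neurons)))
h0 = init_state_lstm[0].repeat(batchsize, 1)
c0 = init_state_lstm[1].repeat(batchsize, 1)
print(h0.shape, c0.shape)  # torch.Size([8, 4]) torch.Size([8, 4])
```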
14 changes: 7 additions & 7 deletions prosper_nn/models/ecnn/ecnn_cell.py
@@ -43,7 +43,7 @@ def __init__(
n_features_U: int,
n_state_neurons: int,
n_features_Y: int = 1,
recurrent_cell_type: str = "elman",
cell_type: str = "elman",
kwargs_recurrent_cell: dict = {},
):
"""
@@ -58,7 +58,7 @@
n_features_Y: int
The number of outputs, i.e. the number of elements of Y at each time
step. The default is 1.
recurrent_cell_type: str
cell_type: str
Select the cell for the state transition. The cells elman, lstm, gru
(all from pytorch) and gru_3_variant (from prosper_nn) are supported.
kwargs_recurrent_cell: dict
@@ -73,17 +73,17 @@

self.C = nn.Linear(n_state_neurons, n_features_Y, bias=False)

if recurrent_cell_type == "elman":
if cell_type == "elman":
self.recurrent_cell = nn.RNNCell
elif recurrent_cell_type == "lstm":
elif cell_type == "lstm":
self.recurrent_cell = nn.LSTMCell
elif recurrent_cell_type == "gru":
elif cell_type == "gru":
self.recurrent_cell = nn.GRUCell
elif recurrent_cell_type == "gru_3_variant":
elif cell_type == "gru_3_variant":
self.recurrent_cell = GRU_3_variant
else:
raise ValueError(
f"recurrent_cell_type: {recurrent_cell_type} is not known."
f"cell_type: {cell_type} is not known."
"Choose from elman, lstm, gru or gru_3_variant."
)
self.recurrent_cell = self.recurrent_cell(
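The cells selected here are PyTorch's per-step modules (`nn.RNNCell`, `nn.LSTMCell`, `nn.GRUCell`) plus prosper_nn's `GRU_3_variant`. Below is a generic sketch of one step of such a cell followed by a bias-free readout like `C`; the dimensions are assumptions, and the fragment does not reproduce the ECNN's error-correction input:

```python
import torch
import torch.nn as nn

n_features_U, n_state_neurons, n_features_Y, batchsize = 3, 4, 1, 8

gru_cell = nn.GRUCell(input_size=n_features_U, hidden_size=n_state_neurons)
C = nn.Linear(n_state_neurons, n_features_Y, bias=False)  # readout, as with C in ECNNCell

U_t = torch.rand(batchsize, n_features_U)        # external inputs at time t
state = torch.zeros(batchsize, n_state_neurons)  # previous state s_{t-1}

state = gru_cell(U_t, state)   # one recurrent step -> s_t
expectation_t = C(state)       # model output at time t
print(expectation_t.shape)     # torch.Size([8, 1])
```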
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

setuptools.setup(
name="prosper_nn", # Replace with your own username
version="0.3.1",
version="0.3.2",
author="Nico Beck, Julia Schemm",
author_email="[email protected]",
description="Package contains, in PyTorch implemented, neural networks with problem specific pre-structuring architectures and utils that help building and understanding models.",