add first unit tests for hcnn and sensitivity analysis #31

Merged: 3 commits on Sep 18, 2024
Changes from all commits
103 changes: 103 additions & 0 deletions prosper_nn/utils/sensitivity_analysis.py
@@ -254,6 +254,109 @@ def analyse_temporal_sensitivity(

return torch.stack(total_heat)

def plot_analyse_temporal_sensitivity(
    sensis: torch.Tensor,
    target_var: List[str],
    features: List[str],
    n_future_steps: int,
    path: Optional[str] = None,
    title: Optional[Union[dict, str]] = None,
    xticks: Optional[Union[dict, str]] = None,
    yticks: Optional[Union[dict, str]] = None,
    xlabel: Optional[Union[dict, str]] = None,
    ylabel: Optional[Union[dict, str]] = None,
    figsize: List[float] = [12.4, 5.8],
) -> None:
    """
    Plots the sensitivity analysis as a heatmap for each target variable and
    adds a table with the total heat and monotony on the right side.
    """
    # Calculate total heat and monotony
    total_heat = torch.sum(torch.abs(sensis), dim=2)
    total_heat = (total_heat * 100).round() / 100
    monotonie = torch.sum(sensis, dim=2) / total_heat
    monotonie = (monotonie * 100).round() / 100

    plt.rcParams["figure.figsize"] = figsize
    ### Temporal Sensitivity Heatmap ###
    # plot a sensitivity matrix for every feature/target variable to be investigated
    for i, node in enumerate(target_var):
        # Set description
        if not title:
            title = "Influence of auxiliary variables on {}"
        if not xlabel:
            xlabel = "Weeks into future"
        if not ylabel:
            ylabel = "Auxiliary variables"
        if not xticks:
            xticks = {
                "ticks": range(1, n_future_steps + 1),
                "labels": [
                    str(i) if i % 2 == 1 else None for i in range(1, n_future_steps + 1)
                ],
                "horizontalalignment": "right",
            }
        if not yticks:
            yticks = {
                "ticks": range(len(features)),
                "labels": [feature.replace("_", " ") for feature in features],
                "rotation": 0,
                "va": "top",
                "size": "large",
            }

        sns.heatmap(
            sensis[i],
            center=0,
            cmap='coolwarm',
            robust=True,
            cbar_kws={'location': 'right', 'pad': 0.22},
        )
        plt.ylabel(ylabel)
        plt.xlabel(xlabel)
        plt.xticks(**xticks)
        plt.yticks(**yticks)
        plt.title(title.format(node.replace("_", " ")), pad=25)

        # Fade out a row label if its total heat is comparatively weak
        for j, ticklabel in enumerate(plt.gca().get_yticklabels()):
            if j >= len(target_var):
                alpha = float(0.5 + (total_heat[i][j] / torch.max(total_heat)) / 2)
                ticklabel.set_color(color=[0, 0, 0, alpha])
            else:
                ticklabel.set_color(color="C0")
        plt.tight_layout()

        ### Table with total heat and monotony ###
        table_values = torch.stack((total_heat[i], monotonie[i])).T

        # Colour of cells
        cell_colours = [
            ["#E1E3E3" for _ in range(table_values.shape[1])]
            for _ in range(table_values.shape[0])
        ]
        cell_colours[torch.argmax(table_values, dim=0)[0]][0] = "#179C7D"
        cell_colours[torch.argmax(torch.abs(table_values), dim=0)[1]][1] = "#179C7D"

        # Plot table
        plt.table(
            table_values.numpy(),
            loc='right',
            colLabels=['Absolute', 'Monotony'],
            colWidths=[0.2, 0.2],
            bbox=[1, 0, 0.3, 1.042],
            cellColours=cell_colours,
            edges='BRT',
        )
        plt.subplots_adjust(left=0.05, right=1.0)  # creates space for the table

        # Save and close
        if path:
            plt.savefig(
                path + "sensi_analysis_{}.png".format(node), bbox_inches="tight"
            )
        else:
            plt.show()
        plt.close()

# %% Sensitivity for feed-forward models and other non-recurrent models

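For orientation, here is a minimal usage sketch of the new plotting helper (not part of this diff; the tensor shape and the variable names are assumptions inferred from the function body, which indexes sensis per target variable and sums over its last dimension):

import torch
from prosper_nn.utils.sensitivity_analysis import plot_analyse_temporal_sensitivity

# Assumed shape: (n_target_vars, n_features, n_future_steps); the target
# variables are listed first in `features`, matching the label fade-out logic.
sensis = 0.1 * torch.randn(2, 4, 6)
plot_analyse_temporal_sensitivity(
    sensis,
    target_var=["sales", "price"],  # hypothetical variable names
    features=["sales", "price", "temperature", "holiday"],
    n_future_steps=6,
)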
Empty file added tests/__init__.py
94 changes: 94 additions & 0 deletions tests/models/hcnn/test_hcnn.py
@@ -0,0 +1,94 @@
import torch
from prosper_nn.models.hcnn import HCNN
import pytest


class TestHcnn:
    @pytest.mark.parametrize(
        "n_state_neurons, n_features_Y, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward",
        [
            (10, 4, 5, 2, 1, 0, 1, True, True),
            (1, 1, 1, 1, 1, 0, 1, True, True),
            (10, 4, 5, 2, 1, 0.5, 1, True, True),
            (10, 4, 5, 2, 1, 0, 0.5, False, True),
            (10, 4, 5, 2, 1, 0, 0.5, False, False),
            (10, 4, 5, 2, 1, 0, 0.5, True, False),
            (10, 4, 5, 2, 1, 0, 1, True, True),
        ],
    )
    def test_forward(self, n_state_neurons, n_features_Y, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward):
        hcnn = HCNN(
            n_state_neurons=n_state_neurons,
            n_features_Y=n_features_Y,
            past_horizon=past_horizon,
            forecast_horizon=forecast_horizon,
            sparsity=sparsity,
            teacher_forcing=teacher_forcing,
            backward_full_Y=backward_full_Y,
            ptf_in_backward=ptf_in_backward,
        )
        observation = torch.zeros(past_horizon, batchsize, n_features_Y)
        output_ = hcnn(observation)

        assert output_.shape == torch.Size((past_horizon + forecast_horizon, batchsize, n_features_Y))
        assert isinstance(output_, torch.Tensor)
        assert not (output_.isnan()).any()

    @pytest.mark.parametrize(
        "n_state_neurons, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward",
        [
            (5, 50, 5, 1, 0, 1, True, True),
        ],
    )
    def test_train(self, n_state_neurons, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward):
        n_features_Y = 1
        n_epochs = 10000

        hcnn = HCNN(
            n_state_neurons=n_state_neurons,
            n_features_Y=n_features_Y,
            past_horizon=past_horizon,
            forecast_horizon=forecast_horizon,
            sparsity=sparsity,
            teacher_forcing=teacher_forcing,
            backward_full_Y=backward_full_Y,
            ptf_in_backward=ptf_in_backward,
        )
        observation = torch.sin(torch.linspace(0.5, 10 * torch.pi, past_horizon + forecast_horizon))
        observation = observation.unsqueeze(1).unsqueeze(1)

        optimizer = torch.optim.Adam(hcnn.parameters(), lr=0.001)
        target = torch.zeros_like(observation[:past_horizon])
        loss_fct = torch.nn.MSELoss()

        start_weight = hcnn.HCNNCell.A.weight.clone()

        for epoch in range(n_epochs):
            output_ = hcnn(observation[:past_horizon])
            loss = loss_fct(output_[:past_horizon], target)
            loss.backward()
            assert hcnn.HCNNCell.A.weight.grad is not None
            optimizer.step()
            if epoch == 1:
                start_loss = loss.detach()
                assert (hcnn.HCNNCell.A.weight != start_weight).all()
            hcnn.zero_grad()

        forecast = hcnn(observation[:past_horizon])[past_horizon:]
        assert loss < start_loss
        assert torch.isclose(observation[past_horizon:], forecast, atol=1).all()

    @pytest.mark.parametrize(
        "teacher_forcing, decrease_teacher_forcing, result",
        [(1, 0, 1), (1, 0.2, 0.8), (0, 0.1, 0)],
    )
    def test_adjust_teacher_forcing(self, teacher_forcing, decrease_teacher_forcing, result):
        hcnn = HCNN(
            n_state_neurons=10,
            n_features_Y=2,
            past_horizon=10,
            forecast_horizon=5,
            teacher_forcing=teacher_forcing,
            decrease_teacher_forcing=decrease_teacher_forcing,
        )
        hcnn.adjust_teacher_forcing()
        assert hcnn.HCNNCell.teacher_forcing == result
        assert hcnn.teacher_forcing == result
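The expected values in the last parametrization imply that adjust_teacher_forcing subtracts decrease_teacher_forcing from the current rate and clamps the result at zero. A minimal sketch of that update rule (an inference from the test cases, not the library source):

def adjust_teacher_forcing_sketch(teacher_forcing: float, decrease: float) -> float:
    # (1, 0.2) -> 0.8 and (0, 0.1) -> 0.0, matching the parametrized cases
    return max(0.0, teacher_forcing - decrease)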
116 changes: 116 additions & 0 deletions tests/models/hcnn/test_hcnn_cell.py
@@ -0,0 +1,116 @@
import torch
from prosper_nn.models.hcnn.hcnn_cell import HCNNCell, PartialTeacherForcing


class TestPartialTeacherForcing:
    ptf = PartialTeacherForcing(p=0.5)

    def test_evaluation(self):
        self.ptf.eval()
        input = torch.randn((20, 1, 100))

        output = self.ptf(input)
        # fill dropped nodes
        output = torch.where(output == 0, input, output)
        assert (output == input).all()

    def test_train(self):
        self.ptf.train()
        input = torch.randn((20, 1, 100))

        output = self.ptf(input)
        # fill dropped nodes
        output = torch.where(output == 0, input, output)
        assert (output == input).all()


class TestHcnnCell:
    n_state_neurons = 10
    n_features_Y = 5
    batchsize = 7

    hcnn_cell = HCNNCell(
        n_state_neurons=n_state_neurons,
        n_features_Y=n_features_Y,
    )
    hcnn_cell.A.weight = torch.nn.Parameter(torch.ones_like(hcnn_cell.A.weight))
    state = 0.5 * torch.ones((batchsize, n_state_neurons))
    expectation = state[..., :n_features_Y]
    observation = torch.ones(batchsize, n_features_Y)

    def test_get_teacher_forcing_full_Y(self):
        self.hcnn_cell.ptf_dropout.p = 0
        output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_full_Y(
            self.observation, self.expectation
        )
        self.checks_get_teacher_forcing(output_, teacher_forcing_)

        ### with partial teacher forcing
        self.hcnn_cell.ptf_dropout.p = 0.5
        output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_full_Y(
            self.observation, self.expectation
        )

        # fill dropped nodes
        teacher_forcing_[..., : self.n_features_Y] = torch.where(
            teacher_forcing_[..., : self.n_features_Y] == 0,
            -0.5,
            teacher_forcing_[..., : self.n_features_Y],
        )

        self.checks_get_teacher_forcing(output_, teacher_forcing_)

    def test_get_teacher_forcing_partial_Y(self):
        self.hcnn_cell.ptf_dropout.p = 0
        output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_partial_Y(
            self.observation, self.expectation
        )
        self.checks_get_teacher_forcing(output_, teacher_forcing_)

        ### with partial teacher forcing
        self.hcnn_cell.ptf_dropout.p = 0.5
        output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_partial_Y(
            self.observation, self.expectation
        )
        # fill dropped nodes
        teacher_forcing_[..., : self.n_features_Y] = torch.where(
            teacher_forcing_[..., : self.n_features_Y] == 0,
            -0.5,
            teacher_forcing_[..., : self.n_features_Y],
        )
        output_ = torch.where(output_ == 0, -0.5, output_)
        self.checks_get_teacher_forcing(output_, teacher_forcing_)

    def checks_get_teacher_forcing(self, output_, teacher_forcing_):
        assert (output_ == -0.5 * torch.ones(self.batchsize, self.n_features_Y)).all()
        assert (teacher_forcing_[..., : self.n_features_Y] == -self.expectation).all()
        assert (teacher_forcing_[..., self.n_features_Y :] == 0).all()
        assert (
            (self.expectation - teacher_forcing_[..., : self.n_features_Y])
            == self.observation
        ).all()

    def test_forward(self):
        state_, output_ = self.hcnn_cell.forward(self.state)
        self.checks_forward(state_, output_)

        state_, output_ = self.hcnn_cell.forward(self.state, self.observation)
        self.checks_forward(state_, output_)

    def test_forward_past_horizon(self):
        state_, output_ = self.hcnn_cell.forward_past_horizon(
            self.state, self.observation, self.expectation
        )
        self.checks_forward(state_, output_)

    def test_forward_forecast_horizon(self):
        state_, output_ = self.hcnn_cell.forward_forecast_horizon(
            self.state, self.expectation
        )
        self.checks_forward(state_, output_)

    def checks_forward(self, state_, output_):
        assert state_.shape == torch.Size((self.batchsize, self.n_state_neurons))
        assert output_.shape == torch.Size((self.batchsize, self.n_features_Y))
        assert not (state_.isnan()).any()
        assert not (output_.isnan()).any()
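Both PartialTeacherForcing tests recover the input exactly once the zeroed entries are filled back in, which implies the module zeroes a random subset of entries without applying dropout's usual 1/(1-p) rescaling to the rest. A minimal sketch of that behaviour (an inference from the tests, not the library source):

import torch

def partial_teacher_forcing_sketch(x: torch.Tensor, p: float) -> torch.Tensor:
    # Zero each entry independently with probability p and keep the rest
    # unscaled, so refilling the zeros from x reproduces x exactly.
    mask = (torch.rand_like(x) >= p).to(x.dtype)
    return x * mask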