From 1b6debc706df36d204ff2892d4bec2ed9f29017e Mon Sep 17 00:00:00 2001
From: Julia Schemm
Date: Fri, 16 Aug 2024 13:59:02 +0200
Subject: [PATCH 1/3] add first unit tests for hcnn and sensitivity analysis

---
 tests/__init__.py                        |   0
 tests/models/hcnn/test_hcnn.py           |  94 ++++++++++++++++++
 tests/models/hcnn/test_hcnn_cell.py      | 116 +++++++++++++++++++++++
 tests/utils/test_sensitivity_analysis.py |  83 ++++++++++++++++
 4 files changed, 293 insertions(+)
 create mode 100644 tests/__init__.py
 create mode 100644 tests/models/hcnn/test_hcnn.py
 create mode 100644 tests/models/hcnn/test_hcnn_cell.py
 create mode 100644 tests/utils/test_sensitivity_analysis.py

diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/models/hcnn/test_hcnn.py b/tests/models/hcnn/test_hcnn.py
new file mode 100644
index 0000000..6a7b444
--- /dev/null
+++ b/tests/models/hcnn/test_hcnn.py
@@ -0,0 +1,94 @@
+import torch
+from prosper_nn.models.hcnn import HCNN
+import pytest
+
+
+class TestHcnn:
+    @pytest.mark.parametrize(
+        "n_state_neurons, n_features_Y, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward",
+        [
+            (10, 4, 5, 2, 1, 0, 1, True, True),
+            (1, 1, 1, 1, 1, 0, 1, True, True),
+            (10, 4, 5, 2, 1, 0.5, 1, True, True),
+            (10, 4, 5, 2, 1, 0, 0.5, False, True),
+            (10, 4, 5, 2, 1, 0, 0.5, False, False),
+            (10, 4, 5, 2, 1, 0, 0.5, True, False),
+        ],
+    )
+    def test_forward(self, n_state_neurons, n_features_Y, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward):
+        hcnn = HCNN(
+            n_state_neurons=n_state_neurons,
+            n_features_Y=n_features_Y,
+            past_horizon=past_horizon,
+            forecast_horizon=forecast_horizon,
+            sparsity=sparsity,
+            teacher_forcing=teacher_forcing,
+            backward_full_Y=backward_full_Y,
+            ptf_in_backward=ptf_in_backward,
+        )
+        observation = torch.zeros(past_horizon, batchsize, n_features_Y)
+        output_ = hcnn(observation)
+
+        assert output_.shape == torch.Size((past_horizon + forecast_horizon, batchsize, n_features_Y))
+        assert isinstance(output_, torch.Tensor)
+        assert not (output_.isnan()).any()
+
+    @pytest.mark.parametrize(
+        "n_state_neurons, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward",
+        [
+            (5, 50, 5, 1, 0, 1, True, True),
+        ],
+    )
+    def test_train(self, n_state_neurons, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward):
+        n_features_Y = 1
+        n_epochs = 10000
+
+        hcnn = HCNN(
+            n_state_neurons=n_state_neurons,
+            n_features_Y=n_features_Y,
+            past_horizon=past_horizon,
+            forecast_horizon=forecast_horizon,
+            sparsity=sparsity,
+            teacher_forcing=teacher_forcing,
+            backward_full_Y=backward_full_Y,
+            ptf_in_backward=ptf_in_backward,
+        )
+        # A sine curve serves as a simple, learnable time series.
+        observation = torch.sin(torch.linspace(0.5, 10 * torch.pi, past_horizon + forecast_horizon))
+        observation = observation.unsqueeze(1).unsqueeze(1)
+
+        optimizer = torch.optim.Adam(hcnn.parameters(), lr=0.001)
+        # During the past horizon the HCNN outputs the deviation between
+        # expectation and observation, so the training target is zero.
+        target = torch.zeros_like(observation[:past_horizon])
+        loss_fct = torch.nn.MSELoss()
+
+        start_weight = hcnn.HCNNCell.A.weight.clone()
+
+        for epoch in range(n_epochs):
+            output_ = hcnn(observation[:past_horizon])
+            loss = loss_fct(output_[:past_horizon], target)
+            loss.backward()
+            assert hcnn.HCNNCell.A.weight.grad is not None
+            optimizer.step()
+            if epoch == 0:
+                start_loss = loss.detach()
+                assert (hcnn.HCNNCell.A.weight != start_weight).all()
+            hcnn.zero_grad()
+
+        forecast = hcnn(observation[:past_horizon])[past_horizon:]
+        assert loss < start_loss
+        assert torch.isclose(observation[past_horizon:], forecast, atol=1).all()
+
+    @pytest.mark.parametrize(
+        "teacher_forcing, decrease_teacher_forcing, result",
+        [(1, 0, 1), (1, 0.2, 0.8), (0, 0.1, 0)],
+    )
+    def test_adjust_teacher_forcing(self, teacher_forcing, decrease_teacher_forcing, result):
+        hcnn = HCNN(
+            n_state_neurons=10,
+            n_features_Y=2,
+            past_horizon=10,
+            forecast_horizon=5,
+            teacher_forcing=teacher_forcing,
+            decrease_teacher_forcing=decrease_teacher_forcing,
+        )
+        hcnn.adjust_teacher_forcing()
+        assert hcnn.HCNNCell.teacher_forcing == result
+        assert hcnn.teacher_forcing == result
diff --git a/tests/models/hcnn/test_hcnn_cell.py b/tests/models/hcnn/test_hcnn_cell.py
new file mode 100644
index 0000000..55f6b83
--- /dev/null
+++ b/tests/models/hcnn/test_hcnn_cell.py
@@ -0,0 +1,116 @@
+import torch
+from prosper_nn.models.hcnn.hcnn_cell import HCNNCell, PartialTeacherForcing
+
+
+class TestPartialTeacherForcing:
+    ptf = PartialTeacherForcing(p=0.5)
+
+    def test_evaluation(self):
+        self.ptf.eval()
+        input = torch.randn((20, 1, 100))
+
+        output = self.ptf(input)
+        # fill dropped nodes
+        output = torch.where(output == 0, input, output)
+        assert (output == input).all()
+
+    def test_train(self):
+        self.ptf.train()
+        input = torch.randn((20, 1, 100))
+
+        output = self.ptf(input)
+        # fill dropped nodes
+        output = torch.where(output == 0, input, output)
+        assert (output == input).all()
+
+
+class TestHcnnCell:
+    n_state_neurons = 10
+    n_features_Y = 5
+    batchsize = 7
+
+    hcnn_cell = HCNNCell(
+        n_state_neurons=n_state_neurons,
+        n_features_Y=n_features_Y,
+    )
+    hcnn_cell.A.weight = torch.nn.Parameter(torch.ones_like(hcnn_cell.A.weight))
+    state = 0.5 * torch.ones((batchsize, n_state_neurons))
+    expectation = state[..., :n_features_Y]
+    observation = torch.ones(batchsize, n_features_Y)
+
+    def test_get_teacher_forcing_full_Y(self):
+        self.hcnn_cell.ptf_dropout.p = 0
+        output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_full_Y(
+            self.observation, self.expectation
+        )
+        self.checks_get_teacher_forcing(output_, teacher_forcing_)
+
+        ### with partial teacher forcing
+        self.hcnn_cell.ptf_dropout.p = 0.5
+        output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_full_Y(
+            self.observation, self.expectation
+        )
+
+        # fill dropped nodes
+        teacher_forcing_[..., : self.n_features_Y] = torch.where(
+            teacher_forcing_[..., : self.n_features_Y] == 0,
+            -0.5,
+            teacher_forcing_[..., : self.n_features_Y],
+        )
+
+        self.checks_get_teacher_forcing(output_, teacher_forcing_)
+
+    def test_get_teacher_forcing_partial_Y(self):
+        self.hcnn_cell.ptf_dropout.p = 0
+        output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_partial_Y(
+            self.observation, self.expectation
+        )
+        self.checks_get_teacher_forcing(output_, teacher_forcing_)
+
+        ### with partial teacher forcing
+        self.hcnn_cell.ptf_dropout.p = 0.5
+        output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_partial_Y(
+            self.observation, self.expectation
+        )
+        # fill dropped nodes
+        teacher_forcing_[..., : self.n_features_Y] = torch.where(
+            teacher_forcing_[..., : self.n_features_Y] == 0,
+            -0.5,
+            teacher_forcing_[..., : self.n_features_Y],
+        )
+        output_ = torch.where(output_ == 0, -0.5, output_)
+        self.checks_get_teacher_forcing(output_, teacher_forcing_)
+
+    def checks_get_teacher_forcing(self, output_, teacher_forcing_):
+        # With expectation = 0.5 and observation = 1 everywhere,
+        # expectation - observation = -0.5.
+        assert (output_ == -0.5 * torch.ones(self.batchsize, self.n_features_Y)).all()
+        assert (teacher_forcing_[..., : self.n_features_Y] == -self.expectation).all()
+        assert (teacher_forcing_[..., self.n_features_Y :] == 0).all()
+        assert (
+            (self.expectation - teacher_forcing_[..., : self.n_features_Y])
+            == self.observation
+        ).all()
+
+    def test_forward(self):
+        state_, output_ = self.hcnn_cell.forward(self.state)
+        self.checks_forward(state_, output_)
+
+        state_, output_ = self.hcnn_cell.forward(self.state, self.observation)
+        self.checks_forward(state_, output_)
+
+    def test_forward_past_horizon(self):
+        state_, output_ = self.hcnn_cell.forward_past_horizon(
+            self.state, self.observation, self.expectation
+        )
+        self.checks_forward(state_, output_)
+
+    def test_forward_forecast_horizon(self):
+        state_, output_ = self.hcnn_cell.forward_forecast_horizon(
+            self.state, self.expectation
+        )
+        self.checks_forward(state_, output_)
+
+    def checks_forward(self, state_, output_):
+        assert state_.shape == torch.Size((self.batchsize, self.n_state_neurons))
+        assert output_.shape == torch.Size((self.batchsize, self.n_features_Y))
+        assert not (state_.isnan()).any()
+        assert not (output_.isnan()).any()
diff --git a/tests/utils/test_sensitivity_analysis.py b/tests/utils/test_sensitivity_analysis.py
new file mode 100644
index 0000000..fc5561b
--- /dev/null
+++ b/tests/utils/test_sensitivity_analysis.py
@@ -0,0 +1,83 @@
+import torch
+from prosper_nn.utils import sensitivity_analysis
+from prosper_nn.models.hcnn import HCNN
+
+
+def test_sensitivity_analysis():
+    in_features = 10
+    out_features = 5
+    batchsize = 3
+    n_batches = 4
+    model = torch.nn.Linear(in_features=in_features, out_features=out_features)
+    data = torch.randn(n_batches, batchsize, in_features)
+    sensi = sensitivity_analysis.sensitivity_analysis(
+        model, data=data, output_neuron=(slice(0, batchsize), 0), batchsize=batchsize
+    )
+    assert isinstance(sensi, torch.Tensor)
+    assert sensi.shape == torch.Size((batchsize * n_batches, in_features))
+
+
+def test_calculate_sensitivity_analysis():
+    in_features = 10
+    out_features = 5
+    batchsize = 3
+    n_batches = 4
+    model = torch.nn.Linear(in_features=in_features, out_features=out_features)
+    data = torch.randn(n_batches, batchsize, in_features)
+    sensi = sensitivity_analysis.calculate_sensitivity_analysis(
+        model, data, output_neuron=(slice(0, batchsize), 0), batchsize=batchsize
+    )
+    sensi = sensi.reshape((sensi.shape[0], -1))
+    assert isinstance(sensi, torch.Tensor)
+    assert sensi.shape == torch.Size((batchsize * n_batches, in_features))
+
+
+def test_plot_sensitivity_curve():
+    in_features = 5
+    samples = 10
+    sensi = torch.randn(samples, in_features)
+    sensitivity_analysis.plot_sensitivity_curve(sensi, output_neuron=1)
+
+
+def test_analyse_temporal_sensitivity():
+    n_features_Y = 3
+    n_state_neurons = 5
+    batchsize = 2
+    past_horizon = 4
+    forecast_horizon = 3
+    task_nodes = [0, 1]
+
+    model = HCNN(
+        n_features_Y=n_features_Y,
+        n_state_neurons=n_state_neurons,
+        past_horizon=past_horizon,
+        forecast_horizon=forecast_horizon,
+    )
+    data = torch.randn(past_horizon, batchsize, n_features_Y)
+
+    sensi = sensitivity_analysis.analyse_temporal_sensitivity(
+        model,
+        data=data,
+        task_nodes=task_nodes,
+        n_future_steps=forecast_horizon,
+        past_horizon=past_horizon,
+        n_features=input_size,
+    )
+    assert isinstance(sensi, torch.Tensor)
+    assert sensi.shape == torch.Size((len(task_nodes), forecast_horizon, n_features_Y))
+
+
+def test_plot_analyse_temporal_sensitivity():
+    n_features_Y = 3
+    n_target_vars = 2
+    forecast_horizon = 3
[f"target_var_{i}" for i in range(n_target_vars)] + features = [f"feat_{i}" for i in range(n_features_Y)] + + sensis = torch.randn(len(target_var), forecast_horizon, n_features_Y) + sensitivity_analysis.plot_analyse_temporal_sensitivity( + sensis, + target_var, + features, + n_future_steps=forecast_horizon, + ) From 75330ea79f003e5d0d31d3f2a635914b6b7da2f0 Mon Sep 17 00:00:00 2001 From: Julia Schemm Date: Fri, 16 Aug 2024 14:12:41 +0200 Subject: [PATCH 2/3] fix parameter --- tests/utils/test_sensitivity_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_sensitivity_analysis.py b/tests/utils/test_sensitivity_analysis.py index fc5561b..356935c 100644 --- a/tests/utils/test_sensitivity_analysis.py +++ b/tests/utils/test_sensitivity_analysis.py @@ -61,7 +61,7 @@ def test_analyse_temporal_sensitivity(): task_nodes=task_nodes, n_future_steps=forecast_horizon, past_horizon=past_horizon, - n_features=input_size, + n_features=n_features_Y, ) assert isinstance(sensi, torch.Tensor) assert sensi.shape == torch.Size((len(task_nodes), forecast_horizon, n_features_Y)) From 84c6115cbbf04fa72e9bc3c9e41d1de311c0185d Mon Sep 17 00:00:00 2001 From: Julia Schemm Date: Fri, 16 Aug 2024 14:18:59 +0200 Subject: [PATCH 3/3] add missing plot function --- prosper_nn/utils/sensitivity_analysis.py | 103 +++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/prosper_nn/utils/sensitivity_analysis.py b/prosper_nn/utils/sensitivity_analysis.py index b3e71b4..9e9211f 100644 --- a/prosper_nn/utils/sensitivity_analysis.py +++ b/prosper_nn/utils/sensitivity_analysis.py @@ -254,6 +254,109 @@ def analyse_temporal_sensitivity( return torch.stack(total_heat) +def plot_analyse_temporal_sensitivity( + sensis: torch.Tensor, + target_var: List[str], + features: List[str], + n_future_steps: int, + path: Optional[str] = None, + title: Optional[Union[dict, str]] = None, + xticks: Optional[Union[dict, str]] = None, + yticks: Optional[Union[dict, str]] = None, + xlabel: Optional[Union[dict, str]] = None, + ylabel: Optional[Union[dict, str]] = None, + figsize: List[float] = [12.4, 5.8], +) -> None: + """ + Plots a sensitivity analysis and creates a table with monotonie and total heat on the right side + for each task variable. 
+ """ + # Calculate total heat and monotony + total_heat = torch.sum(torch.abs(sensis), dim=2) + total_heat = (total_heat * 100).round() / 100 + monotonie = torch.sum(sensis, dim=2) / total_heat + monotonie = (monotonie * 100).round() / 100 + + plt.rcParams["figure.figsize"] = figsize + ### Temporal Sensitivity Heatmap ### + # plot a sensitivity matrix for every feature/target variable to be investigated + for i, node in enumerate(target_var): + # Set description + if not title: + title = "Influence of auxiliary variables on {}" + if not xlabel: + xlabel = "Weeks into future" + if not ylabel: + ylabel = "Auxiliary variables" + if not xticks: + xticks = { + "ticks": range(1, n_future_steps + 1), + "labels": [ + str(i) if i % 2 == 1 else None for i in range(1, n_future_steps + 1) + ], + "horizontalalignment": "right", + } + if not yticks: + yticks = { + "ticks": range(len(features)), + "labels": [feature.replace("_", " ") for feature in features], + "rotation": 0, + "va": "top", + "size": "large", + } + + sns.heatmap(sensis[i], + center=0, + cmap='coolwarm', + robust=True, + cbar_kws={'location':'right', 'pad': 0.22}, + ) + plt.ylabel(ylabel) + plt.xlabel(xlabel) + plt.xticks(**xticks) + plt.yticks(**yticks), + plt.title(title.format(node.replace("_", " ")), pad=25) + + # Fade out row name if total heat is not that strong + for j, ticklabel in enumerate(plt.gca().get_yticklabels()): + if j >= len(target_var): + alpha = float(0.5 + (total_heat[i][j] / torch.max(total_heat)) / 2) + ticklabel.set_color(color=[0, 0, 0, alpha]) + else: + ticklabel.set_color(color="C0") + plt.tight_layout() + + ### Table with total heat and monotonie ### + table_values = torch.stack((total_heat[i], monotonie[i])).T + + # Colour of cells + cell_colours = [ + ["#E1E3E3" for _ in range(table_values.shape[1])] + for _ in range(table_values.shape[0]) + ] + cell_colours[torch.argmax(table_values, dim=0)[0]][0] = "#179C7D" + cell_colours[torch.argmax(torch.abs(table_values), dim=0)[1]][1] = "#179C7D" + + # Plot table + plt.table( + table_values.numpy(), + loc='right', + colLabels=['Absolute', 'Monotony'], + colWidths=[0.2,0.2], + bbox=[1, 0, 0.3, 1.042], #[1, 0, 0.4, 1.042], + cellColours=cell_colours, + edges='BRT', + ) + plt.subplots_adjust(left=0.05, right=1.0) # creates space for table + + # Save and close + if path: + plt.savefig( + path + "sensi_analysis_{}.png".format(node), bbox_inches="tight" + ) + else: + plt.show() + plt.close() # %% Sensitivity for feed-forward models and other not-recurrent models