Commit 33fbd65

Merge pull request #70 from ontime-re/65-find-a-correct-way-to-evaluate-multivariate-forecasting-eg-specific-metrics-error-normalization

65 find a correct way to evaluate multivariate forecasting eg specific metrics error normalization
ben-jy authored Nov 6, 2024
2 parents ec971dd + c9e3cd4 commit 33fbd65
Showing 5 changed files with 126 additions and 93 deletions.
195 changes: 107 additions & 88 deletions docs/user_guide/1_module/3-benchmarking/3.0-benchmarking.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/ontime/module/benchmarking/benchmark_metric.py
@@ -7,8 +7,8 @@ def __init__(self, name: str, metric_function, reduction=None):
         self.name = name
         self.reduction = reduction
 
-    def compute(self, target: TimeSeries, pred: TimeSeries):
+    def compute(self, target: TimeSeries, pred: TimeSeries, **kwargs):
         """
         Compute the metric on the target and predicted time series.
         """
-        return self.metric(target, pred, component_reduction=self.reduction)
+        return self.metric(target, pred, component_reduction=self.reduction, **kwargs)
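
Note, not part of the commit: the **kwargs pass-through above lets a caller forward metric-specific arguments (such as the insample series that darts' mase requires) without changing the interface for metrics that need none. A minimal sketch of both cases, assuming a darts version whose metrics accept component_reduction; the toy data and the np.nanmean reduction are illustrative assumptions.

import numpy as np
from darts import TimeSeries
from darts.metrics import mae, mase

from ontime.module.benchmarking import BenchmarkMetric

# Toy data: the first 25 points act as training history, the last 5 as the target.
series = TimeSeries.from_values(np.arange(30, dtype=float))
insample, target = series[:25], series[25:]
pred = target + 0.5  # a deliberately imperfect "forecast" over the same dates

mae_metric = BenchmarkMetric("MAE", mae, reduction=np.nanmean)
mase_metric = BenchmarkMetric("MASE", mase, reduction=np.nanmean)

print(mae_metric.compute(target, pred))                      # no extra kwargs needed
print(mase_metric.compute(target, pred, insample=insample))  # insample forwarded to darts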
16 changes: 15 additions & 1 deletion src/ontime/module/benchmarking/benchmark_model.py
@@ -1,8 +1,11 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Union
 from enum import Enum
 
+from numpy import ndarray
+from darts.metrics import mase
+
 from ontime.core.time_series.time_series import TimeSeries
 from ontime.module.benchmarking import BenchmarkMetric
@@ -61,3 +64,14 @@ def get_benchmark_mode(self) -> BenchmarkMode:
         Return the benchmark mode of the model.
         """
         pass
+
+    def _compute_metric(self, forecast: TimeSeries, label: TimeSeries, metric: BenchmarkMetric, **kwargs) -> Union[float, List[float], ndarray, List[ndarray]]:
+        """
+        Helper method to compute metric on given forecast and label TimeSeries. This method also handles any specific
+        conditions required by specific metrics.
+        """
+        if metric.metric == mase:
+            return metric.compute(label, forecast, insample=kwargs["input"])
+        else:
+            return metric.compute(label, forecast)
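
Background, not from the commit: mase is special-cased above because MASE scales the forecast error by the in-sample error of a naive forecaster on the training window, so it cannot be computed from the forecast and label alone; that is why the helper forwards the input series as insample. A rough numpy-only illustration of the standard definition, with made-up numbers and seasonality m=1:

import numpy as np

def mase_sketch(y_true, y_pred, y_insample, m=1):
    # Mean Absolute Scaled Error: forecast MAE divided by the MAE of an
    # m-step naive forecast evaluated on the in-sample (training) values.
    y_true, y_pred, insample = map(np.asarray, (y_true, y_pred, y_insample))
    mae = np.mean(np.abs(y_true - y_pred))
    scale = np.mean(np.abs(insample[m:] - insample[:-m]))
    return mae / scale

# 0.5 absolute error everywhere, in-sample naive error of 2.5 -> MASE = 0.2
print(mase_sketch([10.0, 11.0, 12.0], [10.5, 11.5, 12.5], [1.0, 2.0, 4.0, 7.0, 11.0]))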

@@ -69,7 +69,7 @@ def evaluate(
         for input, label in zip(dataset["input"], dataset["label"]):
             forecast = self.predict(input, horizon)
             for metric in metrics:
-                metrics_values[metric.name].append(metric.compute(forecast, label))
+                metrics_values[metric.name].append(self._compute_metric(forecast, label, metric, input=input))
         return {metric: np.mean(values) for metric, values in metrics_values.items()}
 
     def load_checkpoint(self, path: str) -> GlobalDartsBenchmarkModel:
2 changes: 1 addition & 1 deletion src/ontime/module/benchmarking/darts_models/local_model.py
@@ -71,7 +71,7 @@ def evaluate(
         for input, label in zip(dataset["input"], dataset["label"]):
             forecast = self._fit_predict(input, horizon)
             for metric in metrics:
-                metrics_values[metric.name].append(metric.compute(forecast, label))
+                metrics_values[metric.name].append(self._compute_metric(forecast, label, metric, input=input))
         return {metric: np.mean(values) for metric, values in metrics_values.items()}
 
     def _fit_predict(
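
Side note, not part of the commit: both evaluate loops iterate a dataset dict holding parallel "input" and "label" lists of TimeSeries. A hypothetical way to build such windows from a single series, for illustration only; the helper name and the sliding-window scheme are assumptions, not the project's API.

import numpy as np
from darts import TimeSeries

def make_windows(series: TimeSeries, input_length: int, horizon: int, stride: int = 1) -> dict:
    # Slide a window over the series: each step yields an input chunk and the
    # label chunk of length `horizon` that immediately follows it.
    dataset = {"input": [], "label": []}
    for start in range(0, len(series) - input_length - horizon + 1, stride):
        split = start + input_length
        dataset["input"].append(series[start:split])
        dataset["label"].append(series[split:split + horizon])
    return dataset

series = TimeSeries.from_values(np.arange(50, dtype=float))
dataset = make_windows(series, input_length=20, horizon=5, stride=5)
print(len(dataset["input"]), len(dataset["label"]))  # 6 evaluation windows each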
