diff --git a/docs/user_guide/1_module/3-benchmarking/3.0-benchmarking.ipynb b/docs/user_guide/1_module/3-benchmarking/3.0-benchmarking.ipynb
index 19aa2d8..bebc5d9 100644
--- a/docs/user_guide/1_module/3-benchmarking/3.0-benchmarking.ipynb
+++ b/docs/user_guide/1_module/3-benchmarking/3.0-benchmarking.ipynb
@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 1,
"id": "27e27d12-d338-47d3-8fd3-b31ff80ac57c",
"metadata": {
"ExecuteTime": {
@@ -46,7 +46,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 2,
"id": "1177e73c",
"metadata": {},
"outputs": [],
@@ -76,7 +76,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": null,
"id": "71c4bc47-8479-46f1-a4c4-7ad5478020f0",
"metadata": {
"ExecuteTime": {
@@ -108,7 +108,7 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 4,
"id": "3bd6e05b-d86a-484a-bd28-7559b640c20f",
"metadata": {
"ExecuteTime": {
@@ -118,12 +118,13 @@
},
"outputs": [],
"source": [
- "import darts\n",
+ "import darts.metrics\n",
"\n",
"metrics = [\n",
- " BenchmarkMetric(name=\"RMSE\", metric_function=darts.metrics.metrics.coefficient_of_variation),\n",
+ " BenchmarkMetric(name=\"COV\", metric_function=darts.metrics.metrics.coefficient_of_variation),\n",
" BenchmarkMetric(name=\"MAE\", metric_function=darts.metrics.metrics.mae),\n",
" BenchmarkMetric(name=\"sMAPE\", metric_function=darts.metrics.metrics.smape),\n",
+ " BenchmarkMetric(name=\"MASE\", metric_function=darts.metrics.metrics.mase)\n",
"]"
]
},
@@ -137,7 +138,7 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 5,
"id": "b99abb03-c3ee-4e19-87a7-c86e1b418a6b",
"metadata": {
"ExecuteTime": {
@@ -162,7 +163,7 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 6,
"id": "005d3bae-1fdb-4edd-91ba-0df918fae5fe",
"metadata": {
"ExecuteTime": {
@@ -195,7 +196,7 @@
},
"outputs": [],
"source": [
- "benchmark.run(verbose=True, debug=False)"
+ "benchmark.run(verbose=True, debug=True)"
]
},
{
@@ -219,7 +220,7 @@
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 19,
"id": "35b9b732-bd80-426f-a8bb-134f06bf1a3b",
"metadata": {
"ExecuteTime": {
@@ -243,8 +244,8 @@
"validation set size: 585\n",
"training time: 0\n",
"test set size: 732\n",
- "testing time: 0.4616200923919678\n",
- "metrics: {'RMSE': 26.42522074783427, 'MAE': 2.400040758215776, 'sMAPE': 24.148347520589876}\n",
+ "testing time: 0.5555312633514404\n",
+ "metrics: {'COV': 29.264706049554235, 'MAE': 2.400040758215776, 'sMAPE': 24.148347520589876, 'MASE': 1.259049717016714}\n",
"Dataset ETTh1:\n",
"couldn't complete training on ETTh1\n",
"\n",
@@ -259,8 +260,8 @@
"validation set size: 585\n",
"training time: 0\n",
"test set size: 732\n",
- "testing time: 0.4651451110839844\n",
- "metrics: {'RMSE': 26.42522074783427, 'MAE': 2.400040758215776, 'sMAPE': 24.148347520589876}\n",
+ "testing time: 0.5387301445007324\n",
+ "metrics: {'COV': 29.264706049554235, 'MAE': 2.400040758215776, 'sMAPE': 24.148347520589876, 'MASE': 1.259049717016714}\n",
"Dataset ETTh1:\n",
"nb features: 7\n",
"target column: ['HUFL', 'HULL', 'MUFL', 'MULL', 'LUFL', 'LULL', 'OT']\n",
@@ -268,8 +269,8 @@
"validation set size: 2788\n",
"training time: 0\n",
"test set size: 3485\n",
- "testing time: 36.36065649986267\n",
- "metrics: {'RMSE': 38.39537113248332, 'MAE': 3.018217521048022, 'sMAPE': 60.098702109350356}\n",
+ "testing time: 30.63499927520752\n",
+ "metrics: {'COV': 124.74141924436206, 'MAE': 3.018217521048022, 'sMAPE': 60.098702109350356, 'MASE': 2.7583032695689242}\n",
"\n",
"\n",
"Model Temporal Convolutional Network:\n",
@@ -280,19 +281,19 @@
"target column: ['Daily minimum temperatures']\n",
"training set size: 2335\n",
"validation set size: 585\n",
- "training time: 8.28007984161377\n",
+ "training time: 7.1029579639434814\n",
"test set size: 732\n",
- "testing time: 2.099846839904785\n",
- "metrics: {'RMSE': 30.594475602620527, 'MAE': 2.3070876834375156, 'sMAPE': 23.384630950870676}\n",
+ "testing time: 1.9020538330078125\n",
+ "metrics: {'COV': 27.426134138893044, 'MAE': 2.3070876834375156, 'sMAPE': 23.384630950870676, 'MASE': 1.1896365209888038}\n",
"Dataset ETTh1:\n",
"nb features: 7\n",
"target column: ['HUFL', 'HULL', 'MUFL', 'MULL', 'LUFL', 'LULL', 'OT']\n",
"training set size: 11147\n",
"validation set size: 2788\n",
- "training time: 35.88447618484497\n",
+ "training time: 34.660467863082886\n",
"test set size: 3485\n",
- "testing time: 10.182210922241211\n",
- "metrics: {'RMSE': 209.61898756609506, 'MAE': 2.959527545350041, 'sMAPE': 74.96327944852848}\n",
+ "testing time: 11.778612852096558\n",
+ "metrics: {'COV': 135.06732685865998, 'MAE': 2.959527545350041, 'sMAPE': 74.96327944852848, 'MASE': 3.811851652408341}\n",
"\n"
]
}
@@ -311,7 +312,7 @@
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 21,
"id": "687a0b79-ff02-4bde-aff7-6d6ccd5e7d40",
"metadata": {
"ExecuteTime": {
@@ -386,7 +387,7 @@
"supports multivariate ✓ "
]
},
- "execution_count": 33,
+ "execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@@ -398,7 +399,7 @@
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": 22,
"id": "a113a1a8-6988-4440-9afc-c4e9b5e6f898",
"metadata": {},
"outputs": [
@@ -438,23 +439,23 @@
" \n",
"
\n",
" \n",
- " Daily temperature | \n",
+ " Daily temperature | \n",
" training time | \n",
" 0.000000 | \n",
" 0.000000 | \n",
- " 8.280080 | \n",
+ " 7.102958 | \n",
"
\n",
" \n",
" testing time | \n",
- " 0.461620 | \n",
- " 0.465145 | \n",
- " 2.099847 | \n",
+ " 0.555531 | \n",
+ " 0.538730 | \n",
+ " 1.902054 | \n",
"
\n",
" \n",
- " RMSE | \n",
- " 26.425221 | \n",
- " 26.425221 | \n",
- " 30.594476 | \n",
+ " COV | \n",
+ " 29.264706 | \n",
+ " 29.264706 | \n",
+ " 27.426134 | \n",
"
\n",
" \n",
" MAE | \n",
@@ -469,23 +470,29 @@
" 23.384631 | \n",
"
\n",
" \n",
- " ETTh1 | \n",
+ " MASE | \n",
+ " 1.259050 | \n",
+ " 1.259050 | \n",
+ " 1.189637 | \n",
+ "
\n",
+ " \n",
+ " ETTh1 | \n",
" training time | \n",
" NaN | \n",
" 0.000000 | \n",
- " 35.884476 | \n",
+ " 34.660468 | \n",
"
\n",
" \n",
" testing time | \n",
" NaN | \n",
- " 36.360656 | \n",
- " 10.182211 | \n",
+ " 30.634999 | \n",
+ " 11.778613 | \n",
"
\n",
" \n",
- " RMSE | \n",
+ " COV | \n",
" NaN | \n",
- " 38.395371 | \n",
- " 209.618988 | \n",
+ " 124.741419 | \n",
+ " 135.067327 | \n",
"
\n",
" \n",
" MAE | \n",
@@ -499,6 +506,12 @@
" 60.098702 | \n",
" 74.963279 | \n",
"
\n",
+ " \n",
+ " MASE | \n",
+ " NaN | \n",
+ " 2.758303 | \n",
+ " 3.811852 | \n",
+ "
\n",
" \n",
"\n",
""
@@ -507,44 +520,50 @@
" ExponentialSmoothingUnivariate \\\n",
"Dataset Metric \n",
"Daily temperature training time 0.000000 \n",
- " testing time 0.461620 \n",
- " RMSE 26.425221 \n",
+ " testing time 0.555531 \n",
+ " COV 29.264706 \n",
" MAE 2.400041 \n",
" sMAPE 24.148348 \n",
+ " MASE 1.259050 \n",
"ETTh1 training time NaN \n",
" testing time NaN \n",
- " RMSE NaN \n",
+ " COV NaN \n",
" MAE NaN \n",
" sMAPE NaN \n",
+ " MASE NaN \n",
"\n",
" ExponentialSmoothingMultivariate \\\n",
"Dataset Metric \n",
"Daily temperature training time 0.000000 \n",
- " testing time 0.465145 \n",
- " RMSE 26.425221 \n",
+ " testing time 0.538730 \n",
+ " COV 29.264706 \n",
" MAE 2.400041 \n",
" sMAPE 24.148348 \n",
+ " MASE 1.259050 \n",
"ETTh1 training time 0.000000 \n",
- " testing time 36.360656 \n",
- " RMSE 38.395371 \n",
+ " testing time 30.634999 \n",
+ " COV 124.741419 \n",
" MAE 3.018218 \n",
" sMAPE 60.098702 \n",
+ " MASE 2.758303 \n",
"\n",
" Temporal Convolutional Network \n",
"Dataset Metric \n",
- "Daily temperature training time 8.280080 \n",
- " testing time 2.099847 \n",
- " RMSE 30.594476 \n",
+ "Daily temperature training time 7.102958 \n",
+ " testing time 1.902054 \n",
+ " COV 27.426134 \n",
" MAE 2.307088 \n",
" sMAPE 23.384631 \n",
- "ETTh1 training time 35.884476 \n",
- " testing time 10.182211 \n",
- " RMSE 209.618988 \n",
+ " MASE 1.189637 \n",
+ "ETTh1 training time 34.660468 \n",
+ " testing time 11.778613 \n",
+ " COV 135.067327 \n",
" MAE 2.959528 \n",
- " sMAPE 74.963279 "
+ " sMAPE 74.963279 \n",
+ " MASE 3.811852 "
]
},
- "execution_count": 34,
+ "execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
@@ -565,7 +584,7 @@
},
{
"cell_type": "code",
- "execution_count": 35,
+ "execution_count": 11,
"id": "622c8342",
"metadata": {},
"outputs": [],
@@ -575,7 +594,7 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": 12,
"id": "6c87c037",
"metadata": {},
"outputs": [],
@@ -589,7 +608,7 @@
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": 13,
"id": "9e46a81e",
"metadata": {},
"outputs": [],
@@ -601,7 +620,7 @@
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": 14,
"id": "da7935d5",
"metadata": {},
"outputs": [
@@ -610,23 +629,23 @@
"text/html": [
"\n",
"\n",
- "\n",
+ "\n",
""
],
"text/plain": [
"alt.LayerChart(...)"
]
},
- "execution_count": 38,
+ "execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -696,7 +715,7 @@
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": 15,
"id": "ffecb8fb",
"metadata": {},
"outputs": [],
@@ -706,7 +725,7 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": 16,
"id": "e6c1b476",
"metadata": {},
"outputs": [
@@ -715,23 +734,23 @@
"text/html": [
"\n",
"\n",
- "\n",
+ "\n",
""
],
"text/plain": [
"alt.LayerChart(...)"
]
},
- "execution_count": 40,
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -801,7 +820,7 @@
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": 17,
"id": "4bc66495",
"metadata": {},
"outputs": [],
@@ -813,7 +832,7 @@
},
{
"cell_type": "code",
- "execution_count": 42,
+ "execution_count": 18,
"id": "6312ee42",
"metadata": {},
"outputs": [
@@ -822,23 +841,23 @@
"text/html": [
"\n",
"\n",
- "\n",
+ "\n",
""
],
"text/plain": [
"alt.LayerChart(...)"
]
},
- "execution_count": 42,
+ "execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
diff --git a/src/ontime/module/benchmarking/benchmark_metric.py b/src/ontime/module/benchmarking/benchmark_metric.py
index d753d67..68f2209 100644
--- a/src/ontime/module/benchmarking/benchmark_metric.py
+++ b/src/ontime/module/benchmarking/benchmark_metric.py
@@ -7,8 +7,8 @@ def __init__(self, name: str, metric_function, reduction=None):
self.name = name
self.reduction = reduction
- def compute(self, target: TimeSeries, pred: TimeSeries):
+ def compute(self, target: TimeSeries, pred: TimeSeries, **kwargs):
"""
Compute the metric on the target and predicted time series.
+
+ The wrapped metric function is called as ``metric(target, pred, ...)`` with
+ ``component_reduction`` set to this object's ``reduction``. Any extra keyword
+ arguments are passed through to the metric function unchanged.
"""
- return self.metric(target, pred, component_reduction=self.reduction)
+ return self.metric(target, pred, component_reduction=self.reduction, **kwargs)
diff --git a/src/ontime/module/benchmarking/benchmark_model.py b/src/ontime/module/benchmarking/benchmark_model.py
index 47a5abf..1f26637 100644
--- a/src/ontime/module/benchmarking/benchmark_model.py
+++ b/src/ontime/module/benchmarking/benchmark_model.py
@@ -1,8 +1,11 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Union
from enum import Enum
+from numpy import ndarray
+from darts.metrics import mase
+
from ontime.core.time_series.time_series import TimeSeries
from ontime.module.benchmarking import BenchmarkMetric
@@ -61,3 +64,14 @@ def get_benchmark_mode(self) -> BenchmarkMode:
Return the benchmark mode of the model.
"""
pass
+
+ def _compute_metric(
+     self,
+     forecast: TimeSeries,
+     label: TimeSeries,
+     metric: BenchmarkMetric,
+     **kwargs,
+ ) -> Union[float, List[float], ndarray, List[ndarray]]:
+     """
+     Helper method to compute a metric on the given forecast and label TimeSeries.
+
+     The label is handed to ``metric.compute`` as the target and the forecast as the
+     prediction. This method also handles the extra inputs required by specific
+     metrics: when the metric function is darts' ``mase``, the series given through
+     the ``input`` keyword argument is forwarded as ``insample``.
+
+     :param forecast: series predicted by the model.
+     :param label: ground-truth series the forecast is compared against.
+     :param metric: benchmark metric to evaluate.
+     :param kwargs: metric-specific extras; only ``input`` is read, and only for MASE.
+     :return: whatever ``metric.compute`` returns.
+     :raises KeyError: if the metric is MASE and no ``input`` keyword was supplied.
+     """
+     # Every metric except MASE only needs the (target, prediction) pair.
+     if metric.metric is not mase:
+         return metric.compute(label, forecast)
+     # Fail with an explanatory message instead of a bare KeyError('input').
+     if "input" not in kwargs:
+         raise KeyError(
+             "MASE needs the in-sample series; pass it as the `input` keyword argument"
+         )
+     return metric.compute(label, forecast, insample=kwargs["input"])
+
diff --git a/src/ontime/module/benchmarking/darts_models/global_model.py b/src/ontime/module/benchmarking/darts_models/global_model.py
index 9cbe77e..8b409f1 100644
--- a/src/ontime/module/benchmarking/darts_models/global_model.py
+++ b/src/ontime/module/benchmarking/darts_models/global_model.py
@@ -69,7 +69,7 @@ def evaluate(
for input, label in zip(dataset["input"], dataset["label"]):
forecast = self.predict(input, horizon)
for metric in metrics:
- metrics_values[metric.name].append(metric.compute(forecast, label))
+ metrics_values[metric.name].append(self._compute_metric(forecast, label, metric, input=input))
return {metric: np.mean(values) for metric, values in metrics_values.items()}
def load_checkpoint(self, path: str) -> GlobalDartsBenchmarkModel:
diff --git a/src/ontime/module/benchmarking/darts_models/local_model.py b/src/ontime/module/benchmarking/darts_models/local_model.py
index cafaebd..d5059f5 100644
--- a/src/ontime/module/benchmarking/darts_models/local_model.py
+++ b/src/ontime/module/benchmarking/darts_models/local_model.py
@@ -71,7 +71,7 @@ def evaluate(
for input, label in zip(dataset["input"], dataset["label"]):
forecast = self._fit_predict(input, horizon)
for metric in metrics:
- metrics_values[metric.name].append(metric.compute(forecast, label))
+ metrics_values[metric.name].append(self._compute_metric(forecast, label, metric, input=input))
return {metric: np.mean(values) for metric, values in metrics_values.items()}
def _fit_predict(