Add support for Chronos ⚡ models #204

Open · wants to merge 24 commits into `main`
41 changes: 26 additions & 15 deletions README.md
@@ -17,7 +17,8 @@

## 🚀 News

- **27 June 2024**: 🚀 [Released datasets](https://huggingface.co/datasets/autogluon/chronos_datasets) used in the paper and an [evaluation script](./scripts/README.md#evaluating-chronos-models) to compute the WQL and MASE scores reported in the paper.
- **25 Nov 2024**: 🚀 Chronos⚡️ (read: Chronos-Bolt) models released [on HuggingFace](https://huggingface.co/collections/amazon/chronos-models-65f1791d630a8d57cb718444). Chronos⚡️ models are more accurate (5% lower error) and 100x faster than the original Chronos models!
- **27 Jun 2024**: 🚀 [Released datasets](https://huggingface.co/datasets/autogluon/chronos_datasets) used in the paper and an [evaluation script](./scripts/README.md#evaluating-chronos-models) to compute the WQL and MASE scores reported in the paper.
- **17 May 2024**: 🐛 Fixed an off-by-one error in bin indices in the `output_transform`. This simple fix significantly improves the overall performance of Chronos. We will update the results in the next revision on ArXiv.
- **10 May 2024**: 🚀 We added the code for pretraining and fine-tuning Chronos models. You can find it in [this folder](./scripts/training). We also added [a script](./scripts/kernel-synth.py) for generating synthetic time series data from Gaussian processes (KernelSynth; see Section 4.2 in the paper for details). Check out the [usage examples](./scripts/).
- **19 Apr 2024**: 🚀 Chronos is now supported on [AutoGluon-TimeSeries](https://auto.gluon.ai/stable/tutorials/timeseries/index.html), the powerful AutoML package for time series forecasting which enables model ensembles, cloud deployments, and much more. Get started with the [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html).
@@ -52,62 +53,72 @@ The models in this repository are based on the [T5 architecture](https://arxiv.o
| [**chronos-t5-small**](https://huggingface.co/amazon/chronos-t5-small) | 46M | [t5-efficient-small](https://huggingface.co/google/t5-efficient-small) |
| [**chronos-t5-base**](https://huggingface.co/amazon/chronos-t5-base) | 200M | [t5-efficient-base](https://huggingface.co/google/t5-efficient-base) |
| [**chronos-t5-large**](https://huggingface.co/amazon/chronos-t5-large) | 710M | [t5-efficient-large](https://huggingface.co/google/t5-efficient-large) |
| [**chronos-bolt-tiny**](https://huggingface.co/amazon/chronos-bolt-tiny) | 9M | [t5-efficient-tiny](https://huggingface.co/google/t5-efficient-tiny) |
| [**chronos-bolt-mini**](https://huggingface.co/amazon/chronos-bolt-mini) | 21M | [t5-efficient-mini](https://huggingface.co/google/t5-efficient-mini) |
| [**chronos-bolt-small**](https://huggingface.co/amazon/chronos-bolt-small) | 48M | [t5-efficient-small](https://huggingface.co/google/t5-efficient-small) |
| [**chronos-bolt-base**](https://huggingface.co/amazon/chronos-bolt-base) | 205M | [t5-efficient-base](https://huggingface.co/google/t5-efficient-base) |

</div>

### Zero-Shot Results

The following figure showcases the remarkable **zero-shot** performance of Chronos models on 27 datasets against local models, task-specific models and other pretrained models. For details on the evaluation setup and other results, please refer to [the paper](https://arxiv.org/abs/2403.07815).
The following figure showcases the remarkable **zero-shot** performance of Chronos and Chronos⚡️ models on 27 datasets against local models, task-specific models and other pretrained models. For details on the evaluation setup and other results, please refer to [the paper](https://arxiv.org/abs/2403.07815).

<p align="center">
<img src="figures/zero_shot-agg_scaled_score.png" width="80%">
<br />
<span>
Fig. 2: Performance of different models on Benchmark II, comprising 27 datasets <b>not seen</b> by Chronos models during training. This benchmark provides insights into the zero-shot performance of Chronos models against local statistical models, which fit parameters individually for each time series, task-specific models <i>trained on each task</i>, and pretrained models trained on a large corpus of time series. Pretrained Models (Other) indicates that some (or all) of the datasets in Benchmark II may have been in the training corpus of these models. The probabilistic (WQL) and point (MASE) forecasting metrics were normalized using the scores of the Seasonal Naive baseline and aggregated through a geometric mean to obtain the Agg. Relative WQL and MASE, respectively.
Fig. 2: Performance of different models on Benchmark II, comprising 27 datasets <b>not seen</b> by Chronos and Chronos⚡️ models during training. This benchmark provides insights into the zero-shot performance of Chronos and Chronos⚡️ models against local statistical models, which fit parameters individually for each time series, task-specific models <i>trained on each task</i>, and pretrained models trained on a large corpus of time series. Pretrained Models (Other) indicates that some (or all) of the datasets in Benchmark II may have been in the training corpus of these models. The probabilistic (WQL) and point (MASE) forecasting metrics were normalized using the scores of the Seasonal Naive baseline and aggregated through a geometric mean to obtain the Agg. Relative WQL and MASE, respectively.
</span>
</p>

## 📈 Usage

To perform inference with Chronos models, install this package by running:
To perform inference with Chronos or Chronos⚡️ models, install this package by running:

```
pip install git+https://github.com/amazon-science/chronos-forecasting.git
```
> [!TIP]
> The recommended way of using Chronos for production use cases is through [AutoGluon](https://auto.gluon.ai), which features ensembling with other statistical and machine learning models for time series forecasting as well as seamless deployments on AWS with SageMaker 🧠. Check out the AutoGluon Chronos [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html).
> This repository is intended for research purposes and provides a minimal interface to Chronos models. The recommended way of using Chronos for production use cases is through [AutoGluon](https://auto.gluon.ai), which features effortless fine-tuning, ensembling with other statistical and machine learning models for time series forecasting as well as seamless deployments on AWS with SageMaker 🧠. Check out the AutoGluon Chronos [tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html).
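
For orientation only (this snippet is not part of the diff, and the exact model aliases depend on the installed AutoGluon-TimeSeries version), using Chronos⚡️ through AutoGluon might look roughly like this:

```python
# Hedged sketch of the AutoGluon route; "train.csv" and the "bolt_small" alias are
# illustrative placeholders. Check the AutoGluon-TimeSeries docs for your version.
from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

# Long-format data with item_id, timestamp, and target columns
train_data = TimeSeriesDataFrame.from_path("train.csv")

predictor = TimeSeriesPredictor(prediction_length=48).fit(
    train_data,
    hyperparameters={"Chronos": {"model_path": "bolt_small"}},
)
predictions = predictor.predict(train_data)
```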

### Forecasting

A minimal example showing how to perform forecasting using Chronos models:
A minimal example showing how to perform forecasting using Chronos and Chronos⚡️ models:

```python
import pandas as pd # requires: pip install pandas
import torch
from chronos import ChronosPipeline
from chronos import BaseChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
"amazon/chronos-t5-small",
pipeline = BaseChronosPipeline.from_pretrained(
"amazon/chronos-t5-small", # use "amazon/chronos-bolt-small" for the corresponding Chronos⚡️ model
device_map="cuda", # use "cpu" for CPU inference and "mps" for Apple Silicon
torch_dtype=torch.bfloat16,
)

df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")
df = pd.read_csv(
"https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
)

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
# forecast shape: [num_series, num_samples, prediction_length]
# The original Chronos models generate forecast samples, so forecast has shape
# [num_series, num_samples, prediction_length].
# Chronos⚡️ models generate quantile forecasts, so forecast has shape
# [num_series, num_quantiles, prediction_length].
forecast = pipeline.predict(
    context=torch.tensor(df["#Passengers"]),
    prediction_length=12,
    num_samples=20,
    context=torch.tensor(df["#Passengers"]), prediction_length=12
)
```
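
For the original Chronos models, the sample dimension can be summarized into quantiles before plotting; a short sketch is shown below (the names `low`, `median`, and `high` are illustrative). For Chronos⚡️ models, the rows of `forecast[0]` already correspond to the levels in `pipeline.quantiles`, so this step is not needed.

```python
import numpy as np

# Summarize the forecast samples of the first (and only) series into the
# 10%, 50%, and 90% empirical quantiles.
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
```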

More options for `pipeline.predict` can be found with:

```python
print(ChronosPipeline.predict.__doc__)
from chronos import ChronosPipeline, ChronosBoltPipeline

print(ChronosPipeline.predict.__doc__) # for Chronos models
print(ChronosBoltPipeline.predict.__doc__) # for Chronos⚡️ models
```

We can now visualize the forecast:
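
The plotting code itself is collapsed in this diff view; a minimal sketch, assuming `low`, `median`, and `high` from the snippet above and matplotlib installed, could be:

```python
import matplotlib.pyplot as plt  # requires: pip install matplotlib

forecast_index = range(len(df), len(df) + 12)
plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(
    forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval"
)
plt.legend()
plt.grid()
plt.show()
```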
Binary file modified figures/zero_shot-agg_scaled_score.png
12 changes: 6 additions & 6 deletions pyproject.toml
@@ -1,19 +1,19 @@
[project]
name = "chronos"
version = "1.2.1"
version = "1.3.0"
requires-python = ">=3.8"
license = { file = "LICENSE" }
dependencies = [
"torch~=2.0", # package was tested on 2.2
"transformers~=4.30",
"accelerate",
"torch>=2.0,<2.6", # package was tested on 2.2
"transformers>=4.30,<4.48",
"accelerate>=0.32,<1",
]

[project.optional-dependencies]
test = ["pytest~=8.0", "numpy~=1.21"]
typecheck = ["mypy~=1.9"]
training = ["gluonts[pro]", "numpy", "tensorboard", "typer", "typer-config", "joblib", "scikit-learn"]
evaluation = ["gluonts[pro]", "datasets", "numpy", "typer"]
training = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer", "typer-config", "joblib", "scikit-learn", "tensorboard"]
evaluation = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer"]

[tool.mypy]
ignore_missing_imports = true
60 changes: 60 additions & 0 deletions scripts/evaluation/agg-relative-score.py
@@ -0,0 +1,60 @@
import pandas as pd
import typer
from scipy.stats import gmean
from pathlib import Path

app = typer.Typer(pretty_exceptions_enable=False)
DEFAULT_RESULTS_DIR = Path(__file__).parent / "results"


def agg_relative_score(model_csv: Path, baseline_csv: Path):
    model_df = pd.read_csv(model_csv).set_index("dataset")
    baseline_df = pd.read_csv(baseline_csv).set_index("dataset")
    relative_score = model_df.drop("model", axis="columns") / baseline_df.drop(
        "model", axis="columns"
    )
    return relative_score.agg(gmean)


@app.command()
def main(
    model_name: str,
    baseline_name: str = "seasonal-naive",
    results_dir: Path = DEFAULT_RESULTS_DIR,
):
    """
    Compute the aggregated relative score as reported in the Chronos paper.
    Results will be saved to {results_dir}/{model_name}-agg-rel-scores.csv

    Parameters
    ----------
    model_name : str
        Name of the model used in the CSV files. The in-domain and zero-shot CSVs
        are expected to be named {model_name}-in-domain.csv and {model_name}-zero-shot.csv.
    results_dir : Path, optional, default = results/
        Directory where results CSVs generated by evaluate.py are stored
    """

    in_domain_agg_score_df = agg_relative_score(
        results_dir / f"{model_name}-in-domain.csv",
        results_dir / f"{baseline_name}-in-domain.csv",
    )
    in_domain_agg_score_df.name = "value"
    in_domain_agg_score_df.index.name = "metric"

    zero_shot_agg_score_df = agg_relative_score(
        results_dir / f"{model_name}-zero-shot.csv",
        results_dir / f"{baseline_name}-zero-shot.csv",
    )
    zero_shot_agg_score_df.name = "value"
    zero_shot_agg_score_df.index.name = "metric"

    agg_score_df = pd.concat(
        {"in-domain": in_domain_agg_score_df, "zero-shot": zero_shot_agg_score_df},
        names=["benchmark"],
    )
    agg_score_df.to_csv(f"{results_dir}/{model_name}-agg-rel-scores.csv")


if __name__ == "__main__":
    app()
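
As a usage sketch (assuming evaluate.py has already written per-dataset CSVs for the model and the seasonal-naive baseline into the results directory), the helper can also be called directly from Python:

```python
# Illustrative only: file names follow the {model_name}-{benchmark}.csv convention
# that main() expects; adjust the paths to wherever evaluate.py wrote its output.
from pathlib import Path

results_dir = Path("scripts/evaluation/results")
rel_scores = agg_relative_score(
    results_dir / "chronos-bolt-base-in-domain.csv",
    results_dir / "seasonal-naive-in-domain.csv",
)
print(rel_scores)  # geometric mean of per-dataset MASE and WQL ratios vs. the baseline
```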
106 changes: 82 additions & 24 deletions scripts/evaluation/evaluate.py
@@ -12,10 +12,15 @@
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.itertools import batcher
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from gluonts.model.forecast import QuantileForecast, SampleForecast
from tqdm.auto import tqdm

from chronos import ChronosPipeline
from chronos import (
    BaseChronosPipeline,
    ChronosBoltPipeline,
    ChronosPipeline,
    ForecastType,
)

app = typer.Typer(pretty_exceptions_enable=False)

@@ -228,37 +233,45 @@ def load_and_split_dataset(backtest_config: dict):
    return test_data


def generate_sample_forecasts(
def generate_forecasts(
    test_data_input: Iterable,
    pipeline: ChronosPipeline,
    pipeline: BaseChronosPipeline,
    prediction_length: int,
    batch_size: int,
    num_samples: int,
    **predict_kwargs,
):
    # Generate forecast samples
    forecast_samples = []
    # Generate forecasts
    forecast_outputs = []
    for batch in tqdm(batcher(test_data_input, batch_size=batch_size)):
        context = [torch.tensor(entry["target"]) for entry in batch]
        forecast_samples.append(
        forecast_outputs.append(
            pipeline.predict(
                context,
                prediction_length=prediction_length,
                num_samples=num_samples,
                **predict_kwargs,
            ).numpy()
        )
    forecast_samples = np.concatenate(forecast_samples)
    forecast_outputs = np.concatenate(forecast_outputs)

    # Convert forecast samples into gluonts SampleForecast objects
    sample_forecasts = []
    for item, ts in zip(forecast_samples, test_data_input):
    # Convert forecast samples into gluonts Forecast objects
    forecasts = []
    for item, ts in zip(forecast_outputs, test_data_input):
        forecast_start_date = ts["start"] + len(ts["target"])
        sample_forecasts.append(
            SampleForecast(samples=item, start_date=forecast_start_date)
        )

    return sample_forecasts
        if pipeline.forecast_type == ForecastType.SAMPLES:
            forecasts.append(
                SampleForecast(samples=item, start_date=forecast_start_date)
            )
        elif pipeline.forecast_type == ForecastType.QUANTILES:
            forecasts.append(
                QuantileForecast(
                    forecast_arrays=item,
                    forecast_keys=list(map(str, pipeline.quantiles)),
                    start_date=forecast_start_date,
                )
            )

    return forecasts


@app.command()
@@ -274,17 +287,65 @@ def main(
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
):
"""Evaluate Chronos models.

Parameters
----------
config_path : Path
Path to the evaluation config. See ./configs/.
metrics_path : Path
Path to the CSV file where metrics will be saved.
chronos_model_id : str, optional, default = "amazon/chronos-t5-small"
HuggingFace ID of the Chronos model or local path
Available models on HuggingFace:
Chronos:
- amazon/chronos-t5-tiny
- amazon/chronos-t5-mini
- amazon/chronos-t5-small
- amazon/chronos-t5-base
- amazon/chronos-t5-large
Chronos-Bolt:
- amazon/chronos-bolt-tiny
- amazon/chronos-bolt-mini
- amazon/chronos-bolt-small
- amazon/chronos-bolt-base
device : str, optional, default = "cuda"
Device on which inference will be performed
torch_dtype : str, optional
Model's dtype, by default "bfloat16"
batch_size : int, optional, default = 32
Batch size for inference. For Chronos-Bolt models, significantly larger
batch sizes can be used
num_samples : int, optional, default = 20
Number of samples to draw when using the original Chronos models
temperature : Optional[float], optional, default = 1.0
Softmax temperature to used for the original Chronos models
top_k : Optional[int], optional, default = 50
Top-K sampling, by default None
top_p : Optional[float], optional, default = 1.0
Top-p sampling, by default None
"""
    if isinstance(torch_dtype, str):
        torch_dtype = getattr(torch, torch_dtype)
    assert isinstance(torch_dtype, torch.dtype)

    # Load Chronos
    pipeline = ChronosPipeline.from_pretrained(
    pipeline = BaseChronosPipeline.from_pretrained(
        chronos_model_id,
        device_map=device,
        torch_dtype=torch_dtype,
    )

    if isinstance(pipeline, ChronosPipeline):
        predict_kwargs = dict(
            num_samples=num_samples,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
    elif isinstance(pipeline, ChronosBoltPipeline):
        predict_kwargs = {}

    # Load backtest configs
    with open(config_path) as fp:
        backtest_configs = yaml.safe_load(fp)
@@ -301,21 +362,18 @@ def main(
f"Generating forecasts for {dataset_name} "
f"({len(test_data.input)} time series)"
)
sample_forecasts = generate_sample_forecasts(
forecasts = generate_forecasts(
test_data.input,
pipeline=pipeline,
prediction_length=prediction_length,
batch_size=batch_size,
num_samples=num_samples,
temperature=temperature,
top_k=top_k,
top_p=top_p,
**predict_kwargs,
)

logger.info(f"Evaluating forecasts for {dataset_name}")
metrics = (
evaluate_forecasts(
sample_forecasts,
forecasts,
test_data=test_data,
metrics=[
MASE(),
5 changes: 5 additions & 0 deletions scripts/evaluation/results/chronos-bolt-base-agg-rel-scores.csv
@@ -0,0 +1,5 @@
benchmark,metric,value
in-domain,MASE,0.6800133628315155
in-domain,WQL,0.5339263811489279
zero-shot,MASE,0.7914551113353537
zero-shot,WQL,0.6241424984163773
16 changes: 16 additions & 0 deletions scripts/evaluation/results/chronos-bolt-base-in-domain.csv
@@ -0,0 +1,16 @@
dataset,model,MASE,WQL
electricity_15min,autogluon/chronos-bolt-base,0.41069374835605243,0.0703533790998506
m4_daily,autogluon/chronos-bolt-base,3.205192517121196,0.02110308498174413
m4_hourly,autogluon/chronos-bolt-base,0.8350129849014075,0.025353803894164
m4_monthly,autogluon/chronos-bolt-base,0.9491758928362231,0.09382496106659234
m4_weekly,autogluon/chronos-bolt-base,2.0847827409162742,0.03816605075768161
monash_electricity_hourly,autogluon/chronos-bolt-base,1.254966217685461,0.09442192616975713
monash_electricity_weekly,autogluon/chronos-bolt-base,1.8391546050108039,0.06410971963960499
monash_kdd_cup_2018,autogluon/chronos-bolt-base,0.6405985809360102,0.2509172188706336
monash_london_smart_meters,autogluon/chronos-bolt-base,0.701398572604996,0.3218915088923906
monash_pedestrian_counts,autogluon/chronos-bolt-base,0.2646412642278343,0.18789459806066328
monash_rideshare,autogluon/chronos-bolt-base,0.7695376426829713,0.11637119433040358
monash_temperature_rain,autogluon/chronos-bolt-base,0.8983612698773724,0.6050555216496304
taxi_30min,autogluon/chronos-bolt-base,0.7688908266765317,0.2363178601205094
uber_tlc_daily,autogluon/chronos-bolt-base,0.8231767493519677,0.0926036406916842
uber_tlc_hourly,autogluon/chronos-bolt-base,0.6632193728217927,0.14987786887626975