Skip to content

Commit

Permalink
Add example of using custom pipeline pools in Auto (#504)
Browse files Browse the repository at this point in the history
* Add example of using custom pipeline pools in Auto

* update changelog

* minor fix

---------

Co-authored-by: Egor Baturin <[email protected]>
  • Loading branch information
egoriyaa and Egor Baturin authored Nov 11, 2024
1 parent e65fa1b commit a1647bb
Show file tree
Hide file tree
Showing 5 changed files with 310 additions and 142 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
- Add `load_dataset` to public API ([#484](https://github.com/etna-team/etna/pull/484))
-
- Add example of using custom pipeline pools in `Auto` ([#504](https://github.com/etna-team/etna/pull/504))
-
-
-
Expand Down
7 changes: 5 additions & 2 deletions etna/auto/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from etna.auto.optuna import ConfigSampler
from etna.auto.optuna import Optuna
from etna.auto.pool import Pool
from etna.auto.pool import PoolGenerator
from etna.auto.runner import AbstractRunner
from etna.auto.runner import LocalRunner
from etna.auto.utils import config_hash
Expand Down Expand Up @@ -214,7 +215,7 @@ def __init__(
metric_aggregation: MetricAggregationStatistics = "mean",
backtest_params: Optional[dict] = None,
experiment_folder: Optional[str] = None,
pool: Union[Pool, List[BasePipeline]] = Pool.default,
pool: Union[Pool, PoolGenerator, List[BasePipeline]] = Pool.default,
runner: Optional[AbstractRunner] = None,
storage: Optional[BaseStorage] = None,
metrics: Optional[List[Metric]] = None,
Expand Down Expand Up @@ -265,9 +266,11 @@ def __init__(
self._pool_folder = f"{root_folder}pool"

@staticmethod
def _make_pool(pool: Union[Pool, List[BasePipeline]], horizon: int) -> List[BasePipeline]:
def _make_pool(pool: Union[Pool, PoolGenerator, List[BasePipeline]], horizon: int) -> List[BasePipeline]:
if isinstance(pool, Pool):
list_pool: List[BasePipeline] = list(pool.value.generate(horizon=horizon))
elif isinstance(pool, PoolGenerator):
list_pool = list(pool.generate(horizon=horizon))
else:
list_pool = list(pool)

Expand Down
407 changes: 273 additions & 134 deletions examples/205-automl.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ We have prepared a set of tutorials for an easy introduction:
- General AutoML
- How `Auto` works
- Example
- Using custom pipeline pool
- Summary

#### [Clustering](https://github.com/etna-team/etna/tree/master/examples/206-clustering.ipynb)
Expand Down Expand Up @@ -133,6 +134,7 @@ We have prepared a set of tutorials for an easy introduction:
- EmbeddingSegmentTransform
- EmbeddingWindowTransform
- Saving and loading models
- Loading external pretrained models

### Advanced

Expand Down
34 changes: 29 additions & 5 deletions tests/test_auto/test_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing_extensions import Literal

from etna.auto import Auto
from etna.auto.auto import PoolGenerator
from etna.auto.auto import _Callback
from etna.auto.auto import _Initializer
from etna.metrics import MAE
Expand All @@ -17,6 +18,29 @@
from etna.transforms import LagTransform


@pytest.fixture()
def pool_generator():
pool = [
{
"_target_": "etna.pipeline.Pipeline",
"horizon": "${__aux__.horizon}",
"model": {"_target_": "etna.models.MovingAverageModel", "window": "${mult:${horizon},1}"},
},
{
"_target_": "etna.pipeline.Pipeline",
"horizon": "${__aux__.horizon}",
"model": {"_target_": "etna.models.NaiveModel", "lag": 1},
},
]
pool_generator = PoolGenerator(pool)
return pool_generator


@pytest.fixture()
def pool_list():
return [Pipeline(MovingAverageModel(7), horizon=7), Pipeline(NaiveModel(1), horizon=7)]


def test_objective(
example_tsds,
target_metric=MAE(),
Expand Down Expand Up @@ -118,11 +142,9 @@ def test_init_optuna(
)


def test_fit_without_tuning(
example_tsds,
optuna_storage,
pool=(Pipeline(MovingAverageModel(5), horizon=7), Pipeline(NaiveModel(1), horizon=7)),
):
@pytest.mark.parametrize("pool", ["pool_list", "pool_generator"])
def test_fit_without_tuning_list(example_tsds, optuna_storage, pool, request):
pool = request.getfixturevalue(pool)
auto = Auto(
MAE(),
pool=pool,
Expand All @@ -136,6 +158,8 @@ def test_fit_without_tuning(
assert len(auto.summary()) == 2
assert len(auto.top_k(k=5)) == 2
assert len(auto.top_k(k=1)) == 1
if isinstance(pool, PoolGenerator):
pool = pool.generate(7)
assert auto.top_k(k=1)[0].to_dict() == pool[0].to_dict()


Expand Down

0 comments on commit a1647bb

Please sign in to comment.