Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example of using custom pipeline pools in Auto #504

Merged
merged 4 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
- Add `load_dataset` to public API ([#484](https://github.com/etna-team/etna/pull/484))
-
- Add example of using custom pipeline pools in `Auto` ([#504](https://github.com/etna-team/etna/pull/504))
-
-
-
Expand Down
7 changes: 5 additions & 2 deletions etna/auto/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from etna.auto.optuna import ConfigSampler
from etna.auto.optuna import Optuna
from etna.auto.pool import Pool
from etna.auto.pool import PoolGenerator
from etna.auto.runner import AbstractRunner
from etna.auto.runner import LocalRunner
from etna.auto.utils import config_hash
Expand Down Expand Up @@ -214,7 +215,7 @@
metric_aggregation: MetricAggregationStatistics = "mean",
backtest_params: Optional[dict] = None,
experiment_folder: Optional[str] = None,
pool: Union[Pool, List[BasePipeline]] = Pool.default,
pool: Union[Pool, PoolGenerator, List[BasePipeline]] = Pool.default,
runner: Optional[AbstractRunner] = None,
storage: Optional[BaseStorage] = None,
metrics: Optional[List[Metric]] = None,
Expand Down Expand Up @@ -265,9 +266,11 @@
self._pool_folder = f"{root_folder}pool"

@staticmethod
def _make_pool(pool: Union[Pool, List[BasePipeline]], horizon: int) -> List[BasePipeline]:
def _make_pool(pool: Union[Pool, PoolGenerator, List[BasePipeline]], horizon: int) -> List[BasePipeline]:
if isinstance(pool, Pool):
list_pool: List[BasePipeline] = list(pool.value.generate(horizon=horizon))
elif isinstance(pool, PoolGenerator):
list_pool = list(pool.generate(horizon=horizon))

Check warning on line 273 in etna/auto/auto.py

View check run for this annotation

Codecov / codecov/patch

etna/auto/auto.py#L273

Added line #L273 was not covered by tests
else:
list_pool = list(pool)

Expand Down
407 changes: 273 additions & 134 deletions examples/205-automl.ipynb

Large diffs are not rendered by default.

34 changes: 29 additions & 5 deletions tests/test_auto/test_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing_extensions import Literal

from etna.auto import Auto
from etna.auto.auto import PoolGenerator
from etna.auto.auto import _Callback
from etna.auto.auto import _Initializer
from etna.metrics import MAE
Expand All @@ -17,6 +18,29 @@
from etna.transforms import LagTransform


@pytest.fixture()
def pool_generator():
pool = [
{
"_target_": "etna.pipeline.Pipeline",
"horizon": "${__aux__.horizon}",
"model": {"_target_": "etna.models.MovingAverageModel", "window": "${mult:${horizon},1}"},
},
{
"_target_": "etna.pipeline.Pipeline",
"horizon": "${__aux__.horizon}",
"model": {"_target_": "etna.models.NaiveModel", "lag": 1},
},
]
pool_generator = PoolGenerator(pool)
return pool_generator


@pytest.fixture()
def pool_list():
return [Pipeline(MovingAverageModel(7), horizon=7), Pipeline(NaiveModel(1), horizon=7)]


def test_objective(
example_tsds,
target_metric=MAE(),
Expand Down Expand Up @@ -118,11 +142,9 @@ def test_init_optuna(
)


def test_fit_without_tuning(
example_tsds,
optuna_storage,
pool=(Pipeline(MovingAverageModel(5), horizon=7), Pipeline(NaiveModel(1), horizon=7)),
):
@pytest.mark.parametrize("pool", ["pool_list", "pool_generator"])
def test_fit_without_tuning_list(example_tsds, optuna_storage, pool, request):
pool = request.getfixturevalue(pool)
auto = Auto(
MAE(),
pool=pool,
Expand All @@ -136,6 +158,8 @@ def test_fit_without_tuning(
assert len(auto.summary()) == 2
assert len(auto.top_k(k=5)) == 2
assert len(auto.top_k(k=1)) == 1
if isinstance(pool, PoolGenerator):
pool = pool.generate(7)
assert auto.top_k(k=1)[0].to_dict() == pool[0].to_dict()


Expand Down
Loading