Skip to content

Commit

Permalink
Unlock experimental tests (#464)
Browse files Browse the repository at this point in the history
  • Loading branch information
d-a-bunin authored Sep 3, 2024
1 parent 9a03374 commit 7dbbf50
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ jobs:
- name: PyTest with sharding
run: |
poetry run pytest tests -v --shard-id=${{ matrix.shard-id }} --num-shards=3 --cov=etna --ignore=tests/test_experimental --cov-report=xml --durations=10
poetry run pytest tests -v --shard-id=${{ matrix.shard-id }} --num-shards=3 --cov=etna --cov-report=xml --durations=10
poetry run pytest etna -v --doctest-modules --ignore=etna/libs --durations=10
- name: Upload coverage
Expand Down
9 changes: 8 additions & 1 deletion etna/pipeline/autoregressive_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def _forecast(self, ts: TSDataset, return_components: bool) -> TSDataset:
freq=ts.freq,
df_exog=ts.df_exog,
known_future=ts.known_future,
hierarchical_structure=ts.hierarchical_structure,
)
with warnings.catch_warnings():
warnings.filterwarnings(
Expand Down Expand Up @@ -171,7 +172,13 @@ def _forecast(self, ts: TSDataset, return_components: bool) -> TSDataset:
prediction_df = prediction_df.combine_first(current_ts_future.to_pandas()[prediction_df.columns])

# construct dataset and add all features
prediction_ts = TSDataset(df=prediction_df, freq=ts.freq, df_exog=ts.df_exog, known_future=ts.known_future)
prediction_ts = TSDataset(
df=prediction_df,
freq=ts.freq,
df_exog=ts.df_exog,
known_future=ts.known_future,
hierarchical_structure=ts.hierarchical_structure,
)
prediction_ts.transform(self.transforms)
prediction_ts.inverse_transform(self.transforms)

Expand Down
14 changes: 3 additions & 11 deletions tests/test_experimental/test_prediction_intervals/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,24 +208,16 @@ def test_params_to_tune(pipeline, expected_params_to_tune):
horizon=1,
reconciliator=BottomUpReconciliator(target_level="market", source_level="product"),
),
DirectEnsemble(pipelines=[get_naive_pipeline(horizon=1), get_naive_pipeline_with_transforms(horizon=2)]),
VotingEnsemble(pipelines=[get_naive_pipeline(horizon=1), get_naive_pipeline_with_transforms(horizon=1)]),
StackingEnsemble(pipelines=[get_naive_pipeline(horizon=1), get_naive_pipeline_with_transforms(horizon=1)]),
),
)
def test_valid_params_sampling(product_level_constant_hierarchical_ts, pipeline):
intervals_pipeline = DummyPredictionIntervals(pipeline=pipeline)
assert_sampling_is_valid(intervals_pipeline=intervals_pipeline, ts=product_level_constant_hierarchical_ts)


@pytest.mark.parametrize(
"pipeline",
(VotingEnsemble(pipelines=[get_naive_pipeline(horizon=1), get_naive_pipeline_with_transforms(horizon=1)]),),
)
def test_default_params_to_tune_error(pipeline):
intervals_pipeline = DummyPredictionIntervals(pipeline=pipeline)

with pytest.raises(NotImplementedError, match=f"{pipeline.__class__.__name__} doesn't support"):
_ = intervals_pipeline.params_to_tune()


@pytest.mark.parametrize("load_ts", (True, False))
@pytest.mark.parametrize(
"pipeline",
Expand Down

0 comments on commit 7dbbf50

Please sign in to comment.