Skip to content

Commit

Permalink
fix segment encoder tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor Baturin committed Nov 7, 2024
1 parent 5e864cb commit ce5c3ed
Showing 1 changed file with 9 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,54 +2,36 @@
import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.transforms import SegmentEncoderTransform
from tests.test_transforms.utils import assert_transformation_equals_loaded_original
from tests.utils import select_segments_subset


def test_segment_encoder_transform(simple_ts):
def test_segment_encoder_transform(mean_segment_encoder_ts):
transform = SegmentEncoderTransform()
transformed_df = transform.fit_transform(simple_ts).to_pandas()
transformed_df = transform.fit_transform(mean_segment_encoder_ts).to_pandas()
assert (
len(transformed_df.loc[:, pd.IndexSlice[:, "segment_code"]].columns) == 2
), "Number of columns not the same as segments"
assert len(simple_ts.to_pandas()) == len(transformed_df), "Row missing"
assert len(mean_segment_encoder_ts.to_pandas()) == len(transformed_df), "Row missing"
codes = set()
for segment in simple_ts.segments:
for segment in mean_segment_encoder_ts.segments:
column = transformed_df.loc[:, pd.IndexSlice[segment, "segment_code"]]
assert column.dtype == "category", "Column type is not category"
assert np.all(column == column.iloc[0]), "Values are not the same for the whole column"
codes.add(column.iloc[0])
assert codes == {0, 1}, "Codes are not 0 and 1"


def test_subset_segments(simple_ts):
train_ts = simple_ts
test_df = simple_ts.loc[:, pd.IndexSlice["Omsk", :]]
test_ts = TSDataset(df=test_df, freq=simple_ts.freq)
transform = SegmentEncoderTransform()

transform.fit(train_ts)
transformed_test_df = transform.transform(test_ts).to_pandas()

segments = sorted(transformed_test_df.columns.get_level_values("segment").unique())
features = sorted(transformed_test_df.columns.get_level_values("feature").unique())
assert segments == ["Omsk"]
assert features == ["exog", "segment_code", "target"]
values = transformed_test_df.loc[:, pd.IndexSlice[:, "segment_code"]]
assert np.all(values == values.iloc[0])


def test_not_fitted_error(simple_ts):
def test_not_fitted_error(mean_segment_encoder_ts):
encoder = SegmentEncoderTransform()
with pytest.raises(ValueError, match="The transform isn't fitted"):
encoder.transform(simple_ts)
encoder.transform(mean_segment_encoder_ts)


def test_new_segments_error(simple_ts):
train_ts = select_segments_subset(ts=simple_ts, segments=["Moscow"])
test_ts = select_segments_subset(ts=simple_ts, segments=["Omsk"])
def test_new_segments_error(mean_segment_encoder_ts):
train_ts = select_segments_subset(ts=mean_segment_encoder_ts, segments=["segment_0"])
test_ts = select_segments_subset(ts=mean_segment_encoder_ts, segments=["segment_1"])
transform = SegmentEncoderTransform()

transform.fit(train_ts)
Expand Down

0 comments on commit ce5c3ed

Please sign in to comment.