-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_s18_aggregation.py
159 lines (123 loc) · 5.63 KB
/
test_s18_aggregation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import datetime
import typing
import time
import pytest
from pyflink.common import WatermarkStrategy
from pyflink.common.watermark_strategy import TimestampAssigner
from pyflink.datastream import StreamExecutionEnvironment
from models import UserStatistics
from helpers import build_flight, build_user_statistics
from s18_aggregation import define_workflow
@pytest.fixture(scope="module")
def env():
    """Provide the stream execution environment shared by this module's tests."""
    execution_env = StreamExecutionEnvironment.get_execution_environment()
    yield execution_env
@pytest.fixture(scope="module")
def default_watermark_strategy():
    """Build a monotonous-timestamp watermark strategy driven by the wall clock.

    Every record is stamped with the current time, in epoch milliseconds, at
    the moment its timestamp is extracted; the record's content is ignored.
    """

    class WallClockTimestampAssigner(TimestampAssigner):
        """Report the current wall-clock time for every record."""

        def extract_timestamp(self, value, record_timestamp):
            # time_ns() yields nanoseconds; dividing by 1e6 converts to milliseconds.
            nanos_now = time.time_ns()
            return int(nanos_now / 1000000)

    monotonous_strategy = WatermarkStrategy.for_monotonous_timestamps()
    assigner = WallClockTimestampAssigner()
    return monotonous_strategy.with_timestamp_assigner(assigner)
def test_user_statistics_should_create_statistics_using_flight_data():
    """Statistics built from one flight should carry its email, duration and a count of 1."""
    flight = build_flight()
    stats = UserStatistics.from_flight(flight)

    arrival = datetime.datetime.fromisoformat(flight.arrival_time)
    departure = datetime.datetime.fromisoformat(flight.departure_time)
    # Duration in whole minutes.
    # NOTE(review): `timedelta.seconds` excludes whole days, so a flight longer
    # than 24 hours would be under-counted here — confirm this matches what
    # UserStatistics.from_flight is meant to compute.
    elapsed = arrival - departure
    expected_duration = int(elapsed.seconds / 60)

    assert stats.email_address == flight.email_address
    assert stats.total_flight_duration == expected_duration
    assert stats.number_of_flights == 1
def test_user_statistics_should_merge_two_user_statistics():
    """Merging two statistics for one email should add durations and report 2 flights."""
    first = build_user_statistics()
    second = build_user_statistics(email_address=first.email_address)

    merged = UserStatistics.merge(first, second)

    combined_duration = first.total_flight_duration + second.total_flight_duration
    assert merged.email_address == first.email_address
    assert merged.total_flight_duration == combined_duration
    assert merged.number_of_flights == 2
def test_user_statistics_should_fail_for_different_email_address():
    """Merging statistics that belong to different emails should raise AssertionError."""
    ours = build_user_statistics()
    theirs = build_user_statistics(email_address="[email protected]")

    # Guard the precondition so the test cannot pass by accident on equal emails.
    assert ours.email_address != theirs.email_address

    with pytest.raises(AssertionError):
        UserStatistics.merge(ours, theirs)
def test_define_workflow_should_convert_flight_data_to_user_statistics(
    env, default_watermark_strategy
):
    """A single flight pushed through the workflow should yield its statistics.

    Args:
        env: Stream execution environment fixture.
        default_watermark_strategy: Watermark strategy fixture applied to the source.
    """
    flight_data = build_flight()
    flight_stream = env.from_collection(
        collection=[flight_data.to_row()]
    ).assign_timestamps_and_watermarks(default_watermark_strategy)

    elements: typing.List[UserStatistics] = list(
        define_workflow(flight_stream).execute_and_collect()
    )

    expected = UserStatistics.from_flight(flight_data)

    # Check the size first: an empty result then fails with a readable assertion
    # instead of StopIteration, and unexpected extra records are not ignored.
    assert len(elements) == 1
    actual = elements[0]
    assert expected.email_address == actual.email_address
    assert expected.total_flight_duration == actual.total_flight_duration
    assert expected.number_of_flights == actual.number_of_flights
def test_define_workflow_should_group_statistics_by_email_address(env, default_watermark_strategy):
    """Flights sharing an email should be merged; other emails stay separate.

    Three flights are sent through the workflow, the first and third with the
    same email address. The result must hold exactly one record per email.

    Args:
        env: Stream execution environment fixture.
        default_watermark_strategy: Watermark strategy fixture applied to the source.
    """
    flight_data_1 = build_flight()
    flight_data_2 = build_flight()
    flight_data_3 = build_flight()
    flight_data_3.email_address = flight_data_1.email_address

    flight_stream = env.from_collection(
        collection=[flight_data_1.to_row(), flight_data_2.to_row(), flight_data_3.to_row()]
    ).assign_timestamps_and_watermarks(default_watermark_strategy)

    elements: typing.List[UserStatistics] = list(
        define_workflow(flight_stream).execute_and_collect()
    )

    expected_by_email = {
        flight_data_1.email_address: UserStatistics.merge(
            UserStatistics.from_flight(flight_data_1), UserStatistics.from_flight(flight_data_3)
        ),
        flight_data_2.email_address: UserStatistics.from_flight(flight_data_2),
    }

    assert len(elements) == 2

    # Match results to expectations by email so that a record carrying an
    # unexpected address fails the test instead of falling into a catch-all branch.
    actual_by_email = {e.email_address: e for e in elements}
    assert set(actual_by_email) == set(expected_by_email)

    for email, expected in expected_by_email.items():
        actual = actual_by_email[email]
        assert actual.total_flight_duration == expected.total_flight_duration
        assert actual.number_of_flights == expected.number_of_flights
def test_define_workflow_should_window_statistics_by_minute(env):
    """Flights of one user that are a minute apart should be reported separately.

    Three flights share one email address. Two receive the same event timestamp
    and the third, marked with the airport code "LATE", is stamped 60 seconds
    later. The workflow is expected to emit one record for the first two flights
    and another for the late one.

    NOTE(review): this presumes define_workflow uses one-minute tumbling
    event-time windows — confirm against s18_aggregation.

    Args:
        env: Stream execution environment fixture.
    """
    flight_data_1 = build_flight()
    flight_data_2 = build_flight()
    flight_data_2.email_address = flight_data_1.email_address
    flight_data_3 = build_flight()
    flight_data_3.email_address = flight_data_1.email_address
    flight_data_3.departure_airport_code = "LATE"

    # Read the clock once. Stamping each record with a fresh wall-clock reading
    # lets the two "on time" flights straddle a window boundary now and then,
    # which splits them into separate windows and makes the test flaky.
    on_time_ms = int(time.time_ns() / 1000000)
    # Two timestamps exactly one window length apart can never share a
    # one-minute window, wherever the base falls inside its window. A smaller
    # offset only crosses the boundary when the base happens to sit close
    # enough to the window's end, so it would work only some of the time.
    late_ms = on_time_ms + 60000

    class CustomTimestampAssigner(TimestampAssigner):
        """Stamp "LATE" flights one minute after every other flight."""

        def extract_timestamp(self, value, record_timestamp):
            if value.departure_airport_code == "LATE":
                return late_ms
            return on_time_ms

    custom_watermark_strategy = (
        WatermarkStrategy.for_monotonous_timestamps().with_timestamp_assigner(
            CustomTimestampAssigner()
        )
    )

    flight_stream = env.from_collection(
        collection=[flight_data_1.to_row(), flight_data_2.to_row(), flight_data_3.to_row()]
    ).assign_timestamps_and_watermarks(custom_watermark_strategy)

    elements: typing.List[UserStatistics] = list(
        define_workflow(flight_stream).execute_and_collect()
    )

    expected_1 = UserStatistics.merge(
        UserStatistics.from_flight(flight_data_1), UserStatistics.from_flight(flight_data_2)
    )
    expected_2 = UserStatistics.from_flight(flight_data_3)

    assert len(elements) == 2
    for e in elements:
        # All records share one email, so the flight count tells the windows apart.
        if e.number_of_flights > 1:
            assert e.total_flight_duration == expected_1.total_flight_duration
            assert e.number_of_flights == expected_1.number_of_flights
        else:
            assert e.total_flight_duration == expected_2.total_flight_duration
            assert e.number_of_flights == expected_2.number_of_flights