s20_manage_state.py
import os
import logging
from typing import Iterable

from pyflink.common import WatermarkStrategy
from pyflink.datastream import DataStream
from pyflink.datastream import StreamExecutionEnvironment, RuntimeExecutionMode
from pyflink.datastream.window import TumblingEventTimeWindows, Time
from pyflink.datastream.state import ValueStateDescriptor
from pyflink.datastream.functions import ProcessWindowFunction, RuntimeContext
from pyflink.datastream.connectors.kafka import (
    KafkaSource,
    KafkaOffsetsInitializer,
    KafkaSink,
    KafkaRecordSerializationSchema,
    DeliveryGuarantee,
)
from pyflink.datastream.formats.json import JsonRowSerializationSchema, JsonRowDeserializationSchema

from models import FlightData, UserStatistics

RUNTIME_ENV = os.getenv("RUNTIME_ENV", "local")
BOOTSTRAP_SERVERS = os.getenv("BOOTSTRAP_SERVERS", "localhost:29092")
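

# Window function that keeps the running user statistics per key in global keyed
# state, so each window firing emits the accumulated totals across all windows
# rather than only the current window's result.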
class ProcessUserStatisticsFunction(ProcessWindowFunction):
    def __init__(self):
        self.state_descriptor = None

    def open(self, context: RuntimeContext):
        self.state_descriptor = ValueStateDescriptor(
            "User Statistics", UserStatistics.get_value_type_info()
        )

    def process(
        self, key: str, context: "ProcessWindowFunction.Context", elements: Iterable[UserStatistics]
    ) -> Iterable:
        state = context.global_state().get_state(self.state_descriptor)
        accumulated_stats = state.value()
        for new_stats in elements:
            # the state value cannot hold the Python object directly, so store it as a row
            if accumulated_stats is None:
                accumulated_stats = new_stats.to_row()
            else:
                accumulated_stats = UserStatistics.merge(
                    UserStatistics.from_row(accumulated_stats), new_stats
                ).to_row()
        state.update(accumulated_stats)
        # convert back to the Python class before emitting downstream
        yield UserStatistics.from_row(accumulated_stats)
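

# Map flight records to user statistics, key by email address, and aggregate over
# one-minute tumbling event-time windows; reduce() pre-aggregates incrementally
# while ProcessUserStatisticsFunction merges each result with the global state.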
def define_workflow(flight_data_stream: DataStream):
    return (
        flight_data_stream.map(FlightData.to_user_statistics_data)
        .key_by(lambda s: s.email_address)
        .window(TumblingEventTimeWindows.of(Time.minutes(1)))
        .reduce(UserStatistics.merge, window_function=ProcessUserStatisticsFunction())
    )


if __name__ == "__main__":
    """
    ## local execution
    ## launching a local cluster can take long, so it is usually better to submit the job to a cluster
    python src/s20_manage_state.py

    ## cluster execution
    docker exec jobmanager /opt/flink/bin/flink run \
        --python /tmp/src/s20_manage_state.py \
        --pyFiles file:///tmp/src/models.py,file:///tmp/src/utils.py \
        -d
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s:%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    logging.info(f"RUNTIME_ENV - {RUNTIME_ENV}, BOOTSTRAP_SERVERS - {BOOTSTRAP_SERVERS}")

    env = StreamExecutionEnvironment.get_execution_environment()
    env.set_runtime_mode(RuntimeExecutionMode.STREAMING)
    env.set_parallelism(1)
    if RUNTIME_ENV != "docker":
        # running outside Docker: register the local Kafka connector jar with the environment
        CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
        jar_files = ["flink-sql-connector-kafka-1.17.1.jar"]
        jar_paths = tuple(
            [f"file:///{os.path.join(CURRENT_DIR, 'jars', name)}" for name in jar_files]
        )
        logging.info(f"adding local jars - {', '.join(jar_files)}")
        env.add_jars(*jar_paths)
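
    # Kafka source: consume flight records from the flightdata topic as JSON rows,
    # starting from the latest offsets.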
    flight_data_source = (
        KafkaSource.builder()
        .set_bootstrap_servers(BOOTSTRAP_SERVERS)
        .set_topics("flightdata")
        .set_group_id("group.flightdata")
        .set_starting_offsets(KafkaOffsetsInitializer.latest())
        .set_value_only_deserializer(
            JsonRowDeserializationSchema.builder()
            .type_info(FlightData.get_value_type_info())
            .build()
        )
        .build()
    )

    flight_data_stream = env.from_source(
        flight_data_source, WatermarkStrategy.for_monotonous_timestamps(), "flight_data_source"
    )
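
    # Kafka sink: write the accumulated user statistics to the userstatistics topic,
    # serialising both key and value as JSON, with at-least-once delivery.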
    stats_sink = (
        KafkaSink.builder()
        .set_bootstrap_servers(BOOTSTRAP_SERVERS)
        .set_record_serializer(
            KafkaRecordSerializationSchema.builder()
            .set_topic("userstatistics")
            .set_key_serialization_schema(
                JsonRowSerializationSchema.builder()
                .with_type_info(UserStatistics.get_key_type_info())
                .build()
            )
            .set_value_serialization_schema(
                JsonRowSerializationSchema.builder()
                .with_type_info(UserStatistics.get_value_type_info())
                .build()
            )
            .build()
        )
        .set_delivery_guarantee(DeliveryGuarantee.AT_LEAST_ONCE)
        .build()
    )
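
    # The JSON serializer expects rows, so convert the Python objects back before
    # sinking, then name the sink operator and give it a stable uid.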
    define_workflow(flight_data_stream).map(
        lambda d: d.to_row(), output_type=UserStatistics.get_value_type_info()
    ).sink_to(stats_sink).name("userstatistics_sink").uid("userstatistics_sink")

    env.execute("user_statistics")