-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path jupyter_kernel_proxy.py
executable file
·443 lines (383 loc) · 17 KB
/
jupyter_kernel_proxy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
#!/usr/bin/env python
# -*- encoding: utf8 -*-
#
# Copyright (c) 2022 ESET spol. s r.o.
# Author: Marc-Etienne M.Léveillé <[email protected]>
# See LICENSE file for redistribution.
import sys
import json
import hmac
import uuid
import hashlib
import datetime
import glob
import os
import six
from collections import namedtuple, OrderedDict
from operator import attrgetter
from jupyter_core.paths import jupyter_runtime_dir, jupyter_data_dir
import zmq
from tornado import ioloop
from zmq.eventloop import zmqstream
# Static description of one kernel channel: its name, the ZMQ socket type
# used on the side that binds ("server"), the type used on the side that
# connects ("client"), and whether its messages carry an HMAC signature.
SocketInfo = namedtuple(
    "SocketInfo",
    ["name", "server_type", "client_type", "signed"],
)
# Every channel of a Jupyter kernel connection. Only the heartbeat ("hb")
# channel is unsigned. Order is significant: socket/stream groups are built
# positionally from this tuple and indexed in parallel with it.
KERNEL_SOCKETS = (
    SocketInfo(name="hb", server_type=zmq.REP, client_type=zmq.REQ, signed=False),
    SocketInfo(name="iopub", server_type=zmq.PUB, client_type=zmq.SUB, signed=True),
    SocketInfo(name="control", server_type=zmq.ROUTER, client_type=zmq.DEALER, signed=True),
    SocketInfo(name="stdin", server_type=zmq.ROUTER, client_type=zmq.DEALER, signed=True),
    SocketInfo(name="shell", server_type=zmq.ROUTER, client_type=zmq.DEALER, signed=True),
)
# Channel names, in the same order as KERNEL_SOCKETS.
KERNEL_SOCKETS_NAMES = tuple(info.name for info in KERNEL_SOCKETS)
# One slot per channel, addressable by name (``group.shell``) or by index.
SocketGroup = namedtuple("SocketGroup", KERNEL_SOCKETS_NAMES)
# Field layout of a Jupyter wire-protocol message, in wire order:
#   identities    -- list of byte strings
#   (the b"<IDS|MSG>" delimiter frame sits here on the wire; it is not stored)
#   signature     -- bytes
#   header, parent_header, metadata, content
#                 -- each a dict, or bytes holding its JSON encoding
#   buffers       -- list of byte strings
JupyterMessageTuple = namedtuple(
    "JupyterMessageTuple",
    "identities signature header parent_header metadata content buffers",
)
class JupyterMessage(JupyterMessageTuple):
    """One Jupyter wire-protocol message.

    Wire layout, as defined in
    https://jupyter-client.readthedocs.io/en/stable/messaging.html#the-wire-protocol ::

        identities..., b"<IDS|MSG>", signature,
        header, parent_header, metadata, content, buffers...

    The four JSON fields may be held either as dicts or as JSON-encoded
    bytes; the ``parsed`` and ``serialized`` properties convert between
    the two forms.
    """

    DELIMITER = b"<IDS|MSG>"

    @classmethod
    def parse(cls, parts, verify_using=None):
        """Build a message with decoded JSON fields from multipart frames.

        :param parts: sequence of byte-string frames as received from ZMQ.
        :param verify_using: optional HMAC key; when truthy, the signature
            frame is checked against it.
        :returns: a JupyterMessage whose JSON fields are decoded objects.
        :raises ValueError: if the delimiter frame is missing, fewer than a
            signature plus four JSON frames follow it, or the signature does
            not verify.
        """
        try:
            i = parts.index(cls.DELIMITER)
        except ValueError:
            raise ValueError("Not a Jupyter message: missing <IDS|MSG> delimiter")
        # The delimiter must be followed by the signature and 4 JSON frames.
        if len(parts) < i + 6:
            raise ValueError("Truncated Jupyter message: expected a signature "
                             "and 4 JSON frames after the delimiter")
        identities = list(parts[:i])
        signature = parts[i + 1]
        payloads = list(parts[i + 2:i + 6])
        buffers = list(parts[i + 6:])
        raw_msg = cls._make([identities, signature] + payloads + [buffers])
        if verify_using and not raw_msg.has_valid_signature(verify_using):
            raise ValueError("Signature verification failed")
        return raw_msg.parsed

    @property
    def _json_fields_slice(self):
        # header, parent_header, metadata, content
        return slice(2, 6)

    @property
    def json_fields(self):
        """The four JSON fields, in wire order."""
        return self[self._json_fields_slice]

    @property
    def json_field_names(self):
        """Names of the four JSON fields, in wire order."""
        return self._fields[self._json_fields_slice]

    @property
    def parsed(self):
        """Copy of this message with every bytes JSON field decoded."""
        def ensure_parsed(field):
            if isinstance(field, six.binary_type):
                return json.loads(field)
            return field
        parsed_fields = [ensure_parsed(f) for f in self.json_fields]
        return self._replace(**dict(zip(self.json_field_names, parsed_fields)))

    @property
    def serialized(self):
        """Copy of this message with every non-bytes JSON field encoded."""
        def ensure_serialized(field):
            if not isinstance(field, six.binary_type):
                return six.ensure_binary(json.dumps(field))
            return field
        serialized_fields = [ensure_serialized(f) for f in self.json_fields]
        return self._replace(**dict(zip(self.json_field_names, serialized_fields)))

    @property
    def parts(self):
        """The message as a list of multipart frames, ready to send."""
        # list() so tuples passed as identities/buffers also concatenate.
        return (list(self.identities)
                + [self.DELIMITER, self.signature]
                + list(self.serialized.json_fields)
                + list(self.buffers))

    def _compute_signature(self, key):
        """Hex HMAC-SHA256 (as bytes) over the JSON frames and buffers."""
        h = hmac.HMAC(six.ensure_binary(key), digestmod=hashlib.sha256)
        for f in list(self.serialized.json_fields) + list(self.buffers):
            h.update(six.ensure_binary(f))
        return six.ensure_binary(h.hexdigest())

    def has_valid_signature(self, key):
        """Return True if ``signature`` matches the HMAC computed with ``key``."""
        if not isinstance(self.signature, (six.binary_type, six.text_type)):
            return False
        expected = self._compute_signature(key)
        # Constant-time comparison, so response timing does not reveal how
        # much of a forged signature was correct.
        return hmac.compare_digest(six.ensure_binary(self.signature), expected)

    def sign_using(self, key):
        """Return a copy of this message signed with ``key``."""
        return self._replace(signature=self._compute_signature(key))
class AbstractProxyKernel(object):
    """Common base for both ends of the proxy.

    Creates one ZMQ socket (wrapped in a ZMQStream) per channel in
    KERNEL_SOCKETS from a connection-file style ``config`` dict
    (``transport``, ``ip``, ``<channel>_port``, ``key``). With
    ``role="server"`` the sockets bind; with ``role="client"`` they connect.
    """

    def __init__(self, config, role, zmq_context=None):
        if role not in ("server", "client"):
            raise ValueError("role value must be 'server' or 'client'")
        self.role = role
        self.config = config.copy()
        # Look the shared context up per instance rather than once when the
        # module is imported.
        if zmq_context is None:
            zmq_context = zmq.Context.instance()
        self.zmq_context = zmq_context
        self._create_sockets("bind" if role == "server" else "connect")

    def _url_for_port(self, port):
        """Return the ZMQ endpoint URL for ``port`` under this config."""
        transport = self.config.get("transport", "tcp")
        ip = self.config.get("ip", "localhost")
        # tcp endpoints are "host:port"; other transports (ipc) name their
        # endpoint "<ip>-<port>", as jupyter_client does.
        separator = ":" if transport == "tcp" else "-"
        return "{:s}://{:s}{:s}{:d}".format(transport, ip, separator, port)

    def _create_sockets(self, bind_or_connect):
        """Create, bind/connect and wrap one socket per kernel channel."""
        if bind_or_connect not in ("bind", "connect"):
            raise ValueError("bind_or_connect must be 'bind' or 'connect'")
        ctx = self.zmq_context
        zmq_type_key = self.role + "_type"
        self.sockets = SocketGroup(*[ctx.socket(getattr(s, zmq_type_key)) for s in KERNEL_SOCKETS])
        for i, sock in enumerate(KERNEL_SOCKETS):
            sock_bind = getattr(self.sockets[i], bind_or_connect)
            sock_bind(self._url_for_port(self.config.get(sock.name + "_port", 0)))
            if getattr(sock, zmq_type_key) == zmq.SUB:
                # Subscribe to everything published on the channel.
                self.sockets[i].setsockopt(zmq.SUBSCRIBE, b'')
        self.streams = SocketGroup(*map(zmqstream.ZMQStream, self.sockets))

    def sign(self, message, key=None):
        """Return the hex HMAC-SHA256 (as bytes) of the frames in ``message``.

        ``key`` defaults to the connection key from the config.
        """
        if key is None:
            key = self.config.get("key")
        h = hmac.HMAC(six.ensure_binary(key), digestmod=hashlib.sha256)
        for m in message:
            h.update(six.ensure_binary(m))
        return six.ensure_binary(h.hexdigest())

    def make_multipart_message(self, msg_type, content=None, parent_header=None, metadata=None):
        """Build a new signed message and return it as multipart frames.

        :param msg_type: Jupyter message type, e.g. ``"status"``.
        :param content: message content dict (default empty).
        :param parent_header: header of the message being answered
            (default empty).
        :param metadata: metadata dict (default empty).
        :returns: list of frames (no routing identities), signed with the
            config's ``key``.
        """
        header = {
            "date": datetime.datetime.now().isoformat(),
            "msg_id": str(uuid.uuid4()),
            "username": "kernel",
            # session_id may be missing or still None (not yet learned);
            # fall back to a fresh id rather than emitting a null session.
            "session": getattr(self, "session_id", None) or str(uuid.uuid4()),
            "msg_type": msg_type,
            "version": "5.0",
        }
        msg = JupyterMessage(
            [], None, header,
            {} if parent_header is None else parent_header,
            {} if metadata is None else metadata,
            {} if content is None else content,
            [],
        )
        return msg.sign_using(self.config.get("key")).parts
class ProxyKernelClient(AbstractProxyKernel):
    """Kernel-facing end of the proxy.

    With the default ``role="client"``, its sockets connect to the kernel
    described by ``config`` rather than binding.
    """

    def __init__(self, config, role="client", zmq_context=zmq.Context.instance()):
        super(ProxyKernelClient, self).__init__(config, role=role, zmq_context=zmq_context)
# A registered message hook: ``callback`` is invoked for messages whose
# header msg_type equals ``msg_type`` arriving on channel ``stream_type``.
InterceptionFilter = namedtuple("InterceptionFilter", "stream_type msg_type callback")
class ProxyKernelServer(AbstractProxyKernel):
    """Frontend-facing end of the proxy.

    Binds the channels described by ``config`` and relays every message to
    and from the current proxy target (a ProxyKernelClient). Signed messages
    are verified with the sender's key and re-signed with the receiver's
    key. Filters registered with ``intercept_message`` can rewrite or drop
    messages in transit.
    """

    def __init__(self, config, role="server", zmq_context=zmq.Context.instance()):
        super(ProxyKernelServer, self).__init__(config, role, zmq_context)
        self.filters = []         # InterceptionFilter entries, applied in order
        self.session_id = None    # learned from the first relayed reply
        self.proxy_target = None  # current ProxyKernelClient, if any

    def _proxy_to(self, other_stream, socktype=None, validate_using=None, resign_using=None):
        """Return an ``on_recv`` handler that relays frames to ``other_stream``.

        :param other_stream: stream to forward to. If it is one of this
            server's own streams, the traffic is a reply (kernel -> frontend);
            otherwise it is a request (frontend -> kernel).
        :param socktype: SocketInfo of the channel (required).
        :param validate_using: key used to verify incoming signatures;
            defaults to the sending side's connection key.
        :param resign_using: key used to re-sign forwarded messages;
            defaults to the receiving side's connection key.
        """
        # request
        #   Notebook -> ProxyServer -> ProxyClient -> Real kernel
        # reply
        #   Notebook <- ProxyServer <- ProxyClient <- Real kernel
        is_reply = other_stream in self.streams
        if is_reply:
            validate_using = validate_using or self.proxy_target.config.get("key")
            resign_using = resign_using or self.config.get("key")
        else:
            validate_using = validate_using or self.config.get("key")
            resign_using = resign_using or self.proxy_target.config.get("key")

        def handler(data):
            if socktype.signed:
                msg = JupyterMessage.parse(data, validate_using)
                if not self.session_id and is_reply:
                    # Capture the session ID here so that messages we inject
                    # via `make_multipart_message` carry the right one.
                    self.session_id = msg.header.get("session")
                for stream_type, msg_type, callback in self.filters:
                    if stream_type == socktype and msg_type == msg.header.get("msg_type"):
                        new_data = callback(self, other_stream, data)
                        if new_data is None:
                            # The filter consumed the message.
                            return
                        if new_data is not data:
                            # The filter rewrote the frames: re-parse them so
                            # re-signing below (and later filters) use the
                            # rewritten message, not the original one.
                            msg = JupyterMessage.parse(new_data)
                            data = new_data
                if resign_using:
                    data = msg.sign_using(resign_using).parts
            other_stream.send_multipart(data)
            other_stream.flush()
        return handler

    def set_proxy_target(self, proxy_client):
        """Relay traffic between this server and ``proxy_client``.

        Detaches the previous target and closes its streams when it is
        replaced by a different client. Then it installs relay handlers in
        both directions on every channel, except on PUB ends, which only
        send.
        """
        previous = self.proxy_target
        if previous is not None:
            for stream in previous.streams:
                stream.stop_on_recv()
            if previous is not proxy_client:
                # Nothing in this module closes a replaced client's sockets;
                # without this they stay open and connected to the old kernel.
                for stream in previous.streams:
                    stream.close()
        self.proxy_target = proxy_client
        for i, socktype in enumerate(KERNEL_SOCKETS):
            if socktype.server_type != zmq.PUB:
                self.streams[i].on_recv(
                    self._proxy_to(proxy_client.streams[i], socktype=socktype)
                )
            if socktype.client_type != zmq.PUB:
                proxy_client.streams[i].on_recv(
                    self._proxy_to(self.streams[i], socktype=socktype)
                )

    def intercept_message(self, stream_type=None, msg_type=None, callback=None):
        """Register ``callback`` for ``msg_type`` messages on ``stream_type``.

        ``stream_type`` is a channel name (e.g. ``"shell"``) or a SocketInfo.
        The callback is called as ``callback(server, target_stream, frames)``.
        It returns the frames to forward, or None to drop the message.

        :raises ValueError: for an unknown channel or a non-callable callback.
        """
        if stream_type in KERNEL_SOCKETS_NAMES:
            stream_type = KERNEL_SOCKETS[KERNEL_SOCKETS_NAMES.index(stream_type)]
        if stream_type not in KERNEL_SOCKETS:
            raise ValueError("stream_type should be one of " + ", ".join(KERNEL_SOCKETS_NAMES))
        if not callable(callback):
            raise ValueError("callback must be callable")
        self.filters.append(InterceptionFilter(stream_type, msg_type, callback))
class KernelProxyManager(object):
    """Drive a ProxyKernelServer and implement the ``%proxy`` magic.

    On construction it connects to the most recently accessed running
    kernel other than the proxy itself. ``%proxy list`` shows the kernels
    found in the Jupyter runtime directory. ``%proxy connect <file>``
    switches the proxy target. If a tracked kernel_info_request is still
    unanswered after 3 seconds, a placeholder kernel_info_reply is sent.
    """

    def __init__(self, server):
        if isinstance(server, ProxyKernelServer):
            self.server = server
        else:
            # Anything else is treated as a connection config for a new server.
            self.server = ProxyKernelServer(server)
        self.server.intercept_message("shell", "execute_request", self._catch_proxy_magic_command)
        self.magic_command = "%proxy"
        self._kernel_info_requests = []  # msg_ids of unanswered kernel_info_requests
        self.server.intercept_message("shell", "kernel_info_request", self._on_kernel_info_request)
        self.server.intercept_message("shell", "kernel_info_reply", self._on_kernel_info_reply)
        self.connect_to_last()

    def _catch_proxy_magic_command(self, server, target_stream, data):
        """Filter for shell execute_request: run ``%proxy`` cells locally.

        Returns ``data`` unchanged for ordinary code, so it is forwarded to
        the target kernel. Once a magic command has been handled, it
        returns None so the request is not forwarded.
        """
        msg = JupyterMessage.parse(data)
        # A request without "code" is not a magic command; don't crash on it.
        code = msg.content.get("code") or ""
        if not code.startswith(self.magic_command):
            return data

        def publish(msg_type, content):
            server.streams.iopub.send_multipart(server.make_multipart_message(
                msg_type, content, parent_header=msg.header
            ))

        def send(text, stream="stdout"):
            publish("stream", {"name": stream, "text": text})

        def send_usage():
            send("Usage: {:s} [ list | connect <file>]".format(self.magic_command))

        publish("status", {"execution_state": "busy"})
        # Split on any whitespace so tabs/newlines also separate arguments.
        argv = code.split()
        if len(argv) < 2:
            send_usage()
        elif argv[1] == "list":
            self.update_running_kernels()
            send(self._formatted_kernel_list())
        elif argv[1] == "connect":
            if len(argv) > 2:
                try:
                    self.connect_to(argv[2], request_kernel_info=True)
                    send("Connecting to " + self.connected_kernel_name)
                except ValueError as e:
                    # Report the real reason (unknown kernel or loopback).
                    send(str(e), "stderr")
            else:
                send_usage()
        else:
            send("Unknown subcommand " + argv[1], "stderr")
            send_usage()
        publish("status", {"execution_state": "idle"})
        server.streams.shell.send_multipart(msg.identities +
            server.make_multipart_message(
                "execute_reply", {"status": "ok", "execution_count": 0},
                parent_header=msg.header
            ))
        return None

    def _formatted_kernel_list(self):
        """One line per known kernel; the connected one is marked with '*'."""
        return "\n".join(
            " {:s} {:s} ({:s})".format(
                "*" if filename == self.connected_kernel_name else " ",
                filename,
                config.get("kernel_name") or "no name"
            )
            for filename, config in self.kernels.items()
        )

    def update_running_kernels(self):
        """Refresh and return ``self.kernels``.

        ``self.kernels`` becomes an OrderedDict that maps connection-file
        names (``kernel-*.json`` in the Jupyter runtime directory) to their
        parsed JSON content. Entries are ordered most recently accessed
        first, and the proxy's own connection config is excluded.
        Unreadable or malformed files are skipped with a note on stderr.
        """
        def access_time(path):
            # A file can vanish between glob() and stat(); sort it last
            # instead of failing the whole listing.
            try:
                return os.stat(path).st_atime
            except OSError:
                return 0

        files = glob.glob(os.path.join(jupyter_runtime_dir(), "kernel-*.json"))
        self.kernels = OrderedDict()
        for path in reversed(sorted(files, key=access_time)):
            filename = os.path.basename(path)
            try:
                with open(path, "r") as f:
                    config = json.load(f)
            except (IOError, OSError, ValueError) as e:
                sys.stderr.write("Skipping kernel file {}: {}\n".format(path, e))
                continue
            if config != self.server.config:
                self.kernels[filename] = config
        return self.kernels

    def _on_kernel_info_request(self, server, target_stream, data):
        """Track the request and schedule the fallback reply in 3 seconds."""
        msg = JupyterMessage.parse(data)
        self._kernel_info_requests.append(msg.header.get("msg_id"))
        ioloop.IOLoop.current().call_later(3, self._send_proxy_kernel_info, data)
        return data

    def _on_kernel_info_reply(self, server, target_stream, data):
        """Mark the answered request as no longer pending.

        If the reply's parent msg_id is not tracked, the oldest pending
        request is dropped instead.
        """
        msg = JupyterMessage.parse(data)
        if msg.parent_header.get("msg_id") in self._kernel_info_requests:
            self._kernel_info_requests.remove(msg.parent_header.get("msg_id"))
        elif len(self._kernel_info_requests) > 0:
            self._kernel_info_requests.pop(0)
        return data

    def _send_proxy_kernel_info(self, request):
        """Fallback for an unanswered kernel_info_request.

        Sends a placeholder kernel_info_reply, a warning on iopub stderr and
        an idle status, then stops tracking the request. Does nothing if
        the request was answered in the meantime.
        """
        parent = JupyterMessage.parse(request)
        if parent.header.get("msg_id") not in self._kernel_info_requests:
            return
        msg = self.server.make_multipart_message("kernel_info_reply", {
            "status": "ok",
            "protocol_version": "5.3",
            "implementation": "proxy",
            "banner": "Jupyter kernel proxy. Not connected or connected to unresponsive kernel. Use %proxy to connect.",
            "language_info": {
                "name": "magic",
            },
        }, parent_header=parent.header)
        self.server.streams.shell.send_multipart(parent.identities + msg)
        self.server.streams.iopub.send_multipart(
            self.server.make_multipart_message("stream",
                {
                    "name": "stderr",
                    "text": "Target kernel did not reply. "
                            "Use `%proxy list` and `%proxy connect` to switch to "
                            "another kernel.",
                }
            )
        )
        self.server.streams.iopub.send_multipart(self.server.make_multipart_message(
            "status", {"execution_state": "idle"}, parent_header=parent.header
        ))
        self._kernel_info_requests.remove(parent.header.get("msg_id"))

    def connect_to_last(self):
        """Connect to the most recently accessed running kernel.

        :raises ValueError: if no other kernel is running.
        """
        self.update_running_kernels()
        self.connect_to(next(iter(self.kernels.keys()), "<no kernel running>"))

    def connect_to(self, kernel_file_name, request_kernel_info=False):
        """Make the kernel named (or partially named) by ``kernel_file_name`` the target.

        An exact connection-file name wins. Otherwise the first known name
        containing ``kernel_file_name`` is used. With
        ``request_kernel_info``, a tracked kernel_info_request is sent to
        the new kernel.

        :raises ValueError: if no known kernel matches, or the match is the
            proxy's own connection.
        """
        if kernel_file_name in self.kernels:
            matching = kernel_file_name
        else:
            matching = next((n for n in self.kernels if kernel_file_name in n), None)
        if matching is None:
            raise ValueError("Unknown kernel " + kernel_file_name)
        if self.kernels[matching] == self.server.config:
            raise ValueError("Refusing loopback connection")
        self.connected_kernel_name = matching
        self.connected_kernel = ProxyKernelClient(self.kernels[matching])
        self.server.set_proxy_target(self.connected_kernel)
        if request_kernel_info:
            req = self.connected_kernel.make_multipart_message("kernel_info_request")
            self._on_kernel_info_request(self.server, self.connected_kernel.streams.shell, req)
            self.connected_kernel.streams.shell.send_multipart(req)
def install():
    """Install the "proxy" kernel spec for the current user.

    Writes ``kernels/proxy/kernel.json`` under the Jupyter data directory.
    The spec runs ``<current python> -m jupyter_kernel_proxy start
    {connection_file}`` and is displayed as "Existing session". Missing
    ``kernels`` and ``proxy`` directories are created with mode 0o700. An
    existing kernel.json is overwritten.
    """
    user_kernels_dir = os.path.join(jupyter_data_dir(), "kernels")
    proxy_kernel_dir = os.path.join(user_kernels_dir, "proxy")
    for directory in (user_kernels_dir, proxy_kernel_dir):
        if not os.path.exists(directory):
            # makedirs rather than mkdir: the Jupyter data directory itself
            # may not exist yet (e.g. on a fresh account), and mkdir would
            # then fail with an OSError.
            os.makedirs(directory, 0o700)
    with open(os.path.join(proxy_kernel_dir, "kernel.json"), "w") as f:
        json.dump({
            "argv": [
                sys.executable,
                "-m",
                "jupyter_kernel_proxy",
                "start",
                "{connection_file}"
            ],
            "display_name": "Existing session",
        }, f)
def start(connection_file):
    """Run the proxy kernel for ``connection_file`` until the loop stops.

    The JSON connection file is loaded and passed to KernelProxyManager.
    The current tornado IOLoop is then started, which blocks.
    """
    event_loop = ioloop.IOLoop.current()
    with open(connection_file) as handle:
        config = json.load(handle)
    # Held in a local so the manager stays referenced while the loop runs.
    manager = KernelProxyManager(config)
    event_loop.start()
def main():
    """Command-line entry point: ``install`` or ``start <connection_file>``."""
    args = sys.argv[1:]
    command = args[0] if args else None
    if command == "install":
        install()
    elif command == "start" and len(args) > 1:
        start(args[1])
    else:
        print("Usage: {:s} [install | start <connection_file>]".format(sys.argv[0]))
if __name__ == '__main__':
    # Executed as a script or via ``python -m jupyter_kernel_proxy``.
    main()