Skip to content

Commit

Permalink
add async interface
Browse files Browse the repository at this point in the history
  • Loading branch information
davschul committed Feb 16, 2022
1 parent 1c97c6d commit 7df8090
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 2 deletions.
35 changes: 35 additions & 0 deletions pylsp_jsonrpc/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import uuid
import sys
import asyncio

from concurrent import futures
from .exceptions import (JsonRpcException, JsonRpcRequestCancelled,
Expand All @@ -12,6 +13,7 @@
log = logging.getLogger(__name__)
JSONRPC_VERSION = '2.0'
CANCEL_METHOD = '$/cancelRequest'
EXIT_METHOD = 'exit'


class Endpoint:
Expand All @@ -35,9 +37,24 @@ def __init__(self, dispatcher, consumer, id_generator=lambda: str(uuid.uuid4()),
self._client_request_futures = {}
self._server_request_futures = {}
self._executor_service = futures.ThreadPoolExecutor(max_workers=max_workers)
self._cancelledRequests = set()
self._messageQueue = None
self._consume_task = None

def init_async(self):
self._messageQueue = asyncio.Queue()
self._consume_task = asyncio.create_task(self.consume_task())

async def consume_task(self):
while True:
message = await self._messageQueue.get()
await asyncio.to_thread(self.consume, message)
self._messageQueue.task_done()

def shutdown(self):
self._executor_service.shutdown()
if self._consume_task is not None:
self._consume_task.cancel()

def notify(self, method, params=None):
"""Send a JSON RPC notification to the client.
Expand Down Expand Up @@ -94,6 +111,21 @@ def callback(future):
future.set_exception(JsonRpcRequestCancelled())
return callback

async def consume_async(self, message):
"""Consume a JSON RPC message from the client and put it into a queue.
Args:
message (dict): The JSON RPC message sent by the client
"""
if message['method'] == CANCEL_METHOD:
self._cancelledRequests.add(message.get('params')['id'])

# The exit message needs to be handled directly since the stream cannot be closed asynchronously
if message['method'] == EXIT_METHOD:
self.consume(message)
else:
await self._messageQueue.put(message)

def consume(self, message):
"""Consume a JSON RPC message from the client.
Expand Down Expand Up @@ -182,6 +214,9 @@ def _handle_request(self, msg_id, method, params):
except KeyError as e:
raise JsonRpcMethodNotFound.of(method) from e

if msg_id in self._cancelledRequests:
raise JsonRpcRequestCancelled()

handler_result = handler(params)

if callable(handler_result):
Expand Down
25 changes: 25 additions & 0 deletions pylsp_jsonrpc/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import logging
import threading
import asyncio

try:
import ujson as json
Expand Down Expand Up @@ -65,6 +66,30 @@ def _read_message(self):
# Grab the body
return self._rfile.read(content_length)

async def listen_async(self, message_consumer):
"""Blocking call to listen for messages on the rfile.
Args:
message_consumer (fn): function that is passed each message as it is read off the socket.
"""

while not self._rfile.closed:
try:
request_str = await asyncio.to_thread(self._read_message)
except ValueError:
if self._rfile.closed:
return
log.exception("Failed to read from rfile")

if request_str is None:
break

try:
await message_consumer(json.loads(request_str.decode('utf-8')))
except ValueError:
log.exception("Failed to parse JSON message %s", request_str)
continue

@staticmethod
def _content_length(line):
"""Extract the content length from an input line."""
Expand Down
56 changes: 55 additions & 1 deletion test/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pylsp_jsonrpc.endpoint import Endpoint

MSG_ID = 'id'

EXIT_METHOD = 'exit'

@pytest.fixture()
def dispatcher():
Expand Down Expand Up @@ -319,6 +319,60 @@ def test_consume_request_cancel_unknown(endpoint):
})


@pytest.mark.asyncio
async def test_consume_async_request_cancel(endpoint, dispatcher, consumer):
def async_handler():
time.sleep(1)
handler = mock.Mock(return_value=async_handler)
dispatcher['methodName'] = handler

endpoint.init_async()

await endpoint.consume_async({
'jsonrpc': '2.0',
'method': 'methodName',
'params': {'key': 'value'}
})
await endpoint.consume_async({
'jsonrpc': '2.0',
'id': MSG_ID,
'method': 'methodName',
'params': {'key': 'value'}
})
await endpoint.consume_async({
'jsonrpc': '2.0',
'method': '$/cancelRequest',
'params': {'id': MSG_ID}
})

await endpoint._messageQueue.join()

consumer.assert_called_once_with({
'jsonrpc': '2.0',
'id': MSG_ID,
'error': exceptions.JsonRpcRequestCancelled().to_dict()
})

endpoint.shutdown()


@pytest.mark.asyncio
async def test_consume_async_exit(endpoint, dispatcher, consumer):
# verify that exit is still called synchronously
handler = mock.Mock()
dispatcher[EXIT_METHOD] = handler

endpoint.init_async()

await endpoint.consume_async({
'jsonrpc': '2.0',
'method': EXIT_METHOD
})

handler.assert_called_once_with(None)

endpoint.shutdown()

def assert_consumer_error(consumer_mock, exception):
"""Assert that the consumer mock has had once call with the given error message and code.
Expand Down
53 changes: 52 additions & 1 deletion test/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,53 @@ def test_reader_bad_json(rfile, reader):
consumer.assert_not_called()


@pytest.mark.asyncio
async def test_reader_async(rfile, reader):
rfile.write(
b'Content-Length: 49\r\n'
b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n'
b'\r\n'
b'{"id": "hello", "method": "method", "params": {}}'
)
rfile.seek(0)

consumer = mock.AsyncMock()
await reader.listen_async(consumer)

consumer.assert_called_once_with({
'id': 'hello',
'method': 'method',
'params': {}
})


@pytest.mark.asyncio
async def test_reader_bad_message_async(rfile, reader):
rfile.write(b'Hello world')
rfile.seek(0)

# Ensure the listener doesn't throw
consumer = mock.AsyncMock()
await reader.listen_async(consumer)
consumer.assert_not_called()


@pytest.mark.asyncio
async def test_reader_bad_json_async(rfile, reader):
rfile.write(
b'Content-Length: 8\r\n'
b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n'
b'\r\n'
b'{hello}}'
)
rfile.seek(0)

# Ensure the listener doesn't throw
consumer = mock.AsyncMock()
await reader.listen_async(consumer)
consumer.assert_not_called()


def test_writer(wfile, writer):
writer.write({
'id': 'hello',
Expand Down Expand Up @@ -124,5 +171,9 @@ def test_writer_bad_message(wfile, writer):
b'Content-Length: 10\r\n'
b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n'
b'\r\n'
b'1546322461'
b'1546322461',
b'Content-Length: 10\r\n'
b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n'
b'\r\n'
b'1546300861'
]

0 comments on commit 7df8090

Please sign in to comment.