From cdff37bc84e78982d1372c694bb2f182f7271a47 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Wed, 21 Aug 2024 10:26:09 +0545 Subject: [PATCH 01/27] chore: Add request_evicted_status to streaming loop to cancel requests --- src/litserve/server.py | 61 ++++++++++++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 4c06048b..07f52b25 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -217,7 +217,13 @@ def run_batched_loop( response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) -def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]): +def run_streaming_loop( + lit_api: LitAPI, + lit_spec: LitSpec, + request_queue: Queue, + response_queues: List[Queue], + request_evicted_status: Dict[str, bool], +): while True: try: response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) @@ -256,6 +262,8 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, y_gen, ) for y_enc in y_enc_gen: + if request_evicted_status.get(uid): + break y_enc = lit_api.format_encoded_response(y_enc) response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING))) @@ -338,6 +346,7 @@ def inference_worker( worker_id: int, request_queue: Queue, response_queues: List[Queue], + request_evicted_status: Dict[str, bool], max_batch_size: int, batch_timeout: float, stream: bool, @@ -357,7 +366,7 @@ def inference_worker( if max_batch_size > 1: run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout) else: - run_streaming_loop(lit_api, lit_spec, request_queue, response_queues) + run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, request_evicted_status) return if max_batch_size > 1: @@ -397,7 +406,7 @@ async def response_queue_to_buffer( await asyncio.sleep(0.0001) continue q, event = buffer[uid] - q.append(payload) + q.append((uid, payload)) event.set() else: @@ -499,6 +508,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int): manager = mp.Manager() self.workers_setup_status = manager.dict() self.request_queue = manager.Queue() + self.request_evicted_status = manager.dict() self.response_queues = [] for _ in range(num_uvicorn_servers): @@ -532,6 +542,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int): worker_id, self.request_queue, self.response_queues, + self.request_evicted_status, self.max_batch_size, self.batch_timeout, self.stream, @@ -570,25 +581,34 @@ def device_identifiers(self, accelerator, device): async def data_streamer(self, q: deque, data_available: asyncio.Event, send_status: bool = False): while True: - await data_available.wait() - while len(q) > 0: - data, status = q.popleft() - if status == LitAPIStatus.FINISH_STREAMING: - return - - if status == LitAPIStatus.ERROR: - logger.error( - "Error occurred while streaming outputs from the inference worker. " - "Please check the above traceback." - ) + try: + await data_available.wait() + while len(q) > 0: + uid, (data, status) = q.popleft() + if status == LitAPIStatus.FINISH_STREAMING: + return + + if status == LitAPIStatus.ERROR: + logger.error( + "Error occurred while streaming outputs from the inference worker. " + "Please check the above traceback." + ) + if send_status: + yield data, status + return if send_status: yield data, status - return - if send_status: - yield data, status - else: - yield data - data_available.clear() + else: + yield data + data_available.clear() + except asyncio.CancelledError: + self.request_evicted_status[uid] = True + logger.error("Request evicted for the uid=%s", uid) + break + except Exception as e: + # Handle other exceptions that might occur + logger.error(f"Exception occurred during streaming: {e}") + break def setup_server(self): workers_ready = False @@ -635,6 +655,7 @@ async def predict(request: self.request_type, background_tasks: BackgroundTasks) async def stream_predict(request: self.request_type, background_tasks: BackgroundTasks) -> self.response_type: response_queue_id = self.app.response_queue_id + print("response_queue_id=", response_queue_id) uid = uuid.uuid4() event = asyncio.Event() q = deque() From e5565a88e28c6eebb191928808e67b5508fde233 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 22 Aug 2024 00:22:43 +0545 Subject: [PATCH 02/27] fix failing test --- tests/test_lit_server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 1e5b4f45..c7314930 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -66,10 +66,10 @@ def test_device_identifiers(lifespan_mock, simple_litapi): @patch("litserve.server.run_batched_loop") @patch("litserve.server.run_single_loop") def test_inference_worker(mock_single_loop, mock_batched_loop): - inference_worker(*[MagicMock()] * 6, max_batch_size=2, batch_timeout=0, stream=False) + inference_worker(*[MagicMock()] * 7, max_batch_size=2, batch_timeout=0, stream=False) mock_batched_loop.assert_called_once() - inference_worker(*[MagicMock()] * 6, max_batch_size=1, batch_timeout=0, stream=False) + inference_worker(*[MagicMock()] * 7, max_batch_size=1, batch_timeout=0, stream=False) mock_single_loop.assert_called_once() @@ -175,11 +175,12 @@ def fake_encode(output): fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x) requests_queue = Queue() + request_evicted_status = {} requests_queue.put((0, "UUID-1234", time.monotonic(), {"prompt": "Hello"})) response_queues = [FakeStreamResponseQueue(num_streamed_outputs)] with pytest.raises(StopIteration, match="exit loop"): - run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues) + run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues, request_evicted_status) fake_stream_api.predict.assert_called_once_with("Hello") fake_stream_api.encode_response.assert_called_once() From 36429a5e77599cd402bbc58dfb095dd1d77c5719 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 22 Aug 2024 01:11:42 +0545 Subject: [PATCH 03/27] fixed: cannot access local variable 'uid' --- src/litserve/server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 07f52b25..f5b4af64 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -580,6 +580,7 @@ def device_identifiers(self, accelerator, device): return [f"{accelerator}:{device}"] async def data_streamer(self, q: deque, data_available: asyncio.Event, send_status: bool = False): + uid = None while True: try: await data_available.wait() @@ -602,8 +603,9 @@ async def data_streamer(self, q: deque, data_available: asyncio.Event, send_stat yield data data_available.clear() except asyncio.CancelledError: - self.request_evicted_status[uid] = True - logger.error("Request evicted for the uid=%s", uid) + if uid is not None: + self.request_evicted_status[uid] = True + logger.error("Request evicted for the uid=%s", uid) break except Exception as e: # Handle other exceptions that might occur From 9c087441f51ef25b4bdb94a7cb6d87303ae8ff10 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 22 Aug 2024 10:10:09 +0545 Subject: [PATCH 04/27] feat: adds test for `test_stream_client_disconnection` --- tests/test_lit_server.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index c7314930..b94b43f6 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -15,6 +15,7 @@ import inspect import pickle import re +import logging from asgi_lifespan import LifespanManager from litserve import LitAPI from fastapi import Request, Response, HTTPException @@ -120,6 +121,24 @@ async def test_stream(simple_stream_api): ), "Server returns input prompt and generated output which didn't match." +@pytest.mark.asyncio +async def test_stream_client_disconnection(simple_stream_api, caplog): + server = LitServer(simple_stream_api, stream=True, timeout=10) + + with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG): + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?" * 2}, timeout=10)) + await asyncio.sleep(0.4) + + # Simulate client disconnection by canceling the request + task.cancel() + + # Allow some time for the server to handle the cancellation + await asyncio.sleep(0.5) + + assert "Request evicted for the uid=" in caplog.text, "Server should log client disconnection" + + @pytest.mark.asyncio() async def test_batched_stream_server(simple_batched_stream_api): server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30) From f5522fa32913ecee84192dba8634efce368104d3 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 22 Aug 2024 10:10:45 +0545 Subject: [PATCH 05/27] ref: format imports using ruff --- tests/test_lit_server.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index b94b43f6..401cc01a 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -13,33 +13,33 @@ # limitations under the License. import asyncio import inspect +import logging import pickle import re -import logging -from asgi_lifespan import LifespanManager -from litserve import LitAPI -from fastapi import Request, Response, HTTPException import time -import torch -import torch.nn as nn from queue import Queue -from httpx import AsyncClient -from litserve.utils import wrap_litserve_start +from unittest.mock import MagicMock, patch -from unittest.mock import patch, MagicMock import pytest +import torch +import torch.nn as nn +from asgi_lifespan import LifespanManager +from fastapi import HTTPException, Request, Response +from fastapi.testclient import TestClient +from httpx import AsyncClient +import litserve as ls +from litserve import LitAPI from litserve.connector import _Connector from litserve.server import ( + LitAPIStatus, + LitServer, inference_worker, + run_batched_streaming_loop, run_single_loop, run_streaming_loop, - LitAPIStatus, - run_batched_streaming_loop, ) -from litserve.server import LitServer -import litserve as ls -from fastapi.testclient import TestClient +from litserve.utils import wrap_litserve_start def test_index(sync_testclient): From 2f46532c1ec6d2e58a36eba6bce21088cb737727 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 22 Aug 2024 10:11:48 +0545 Subject: [PATCH 06/27] fix lint warning for `@pytest.mark.asyncio` --- tests/test_lit_server.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 401cc01a..caf3719f 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -74,7 +74,7 @@ def test_inference_worker(mock_single_loop, mock_batched_loop): mock_single_loop.assert_called_once() -@pytest.fixture() +@pytest.fixture def loop_args(): requests_queue = Queue() requests_queue.put((0, "uuid-123", time.monotonic(), 1)) # response_queue_id, uid, timestamp, x_enc @@ -100,7 +100,7 @@ def test_single_loop(loop_args): run_single_loop(lit_api_mock, None, requests_queue, response_queues) -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_stream(simple_stream_api): server = LitServer(simple_stream_api, stream=True, timeout=10) expected_output1 = "prompt=Hello generated_output=LitServe is streaming output".lower().replace(" ", "") @@ -139,7 +139,7 @@ async def test_stream_client_disconnection(simple_stream_api, caplog): assert "Request evicted for the uid=" in caplog.text, "Server should log client disconnection" -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_batched_stream_server(simple_batched_stream_api): server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30) expected_output1 = "Hello LitServe is streaming output".lower().replace(" ", "") @@ -418,7 +418,7 @@ def encode_response(self, output, context): return {"output": input} -@pytest.mark.asyncio() +@pytest.mark.asyncio @patch("litserve.server.load_and_raise") async def test_inject_context(mocked_load_and_raise): def dummy_load_and_raise(resp): From 7aacee6ed2e821935a3a53186001fa571bf4b725 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 22 Aug 2024 10:13:39 +0545 Subject: [PATCH 07/27] adds a todo in the test for reminder --- tests/test_lit_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index caf3719f..7242f12f 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -135,9 +135,10 @@ async def test_stream_client_disconnection(simple_stream_api, caplog): # Allow some time for the server to handle the cancellation await asyncio.sleep(0.5) - assert "Request evicted for the uid=" in caplog.text, "Server should log client disconnection" + # TODO: also chec if the task actually stopped in the server + @pytest.mark.asyncio async def test_batched_stream_server(simple_batched_stream_api): From 4327e49899dc06993925c2e74d7ef1d3233d73b3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Aug 2024 04:28:53 +0000 Subject: [PATCH 08/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_lit_server.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 7242f12f..b1de2a2d 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -74,7 +74,7 @@ def test_inference_worker(mock_single_loop, mock_batched_loop): mock_single_loop.assert_called_once() -@pytest.fixture +@pytest.fixture() def loop_args(): requests_queue = Queue() requests_queue.put((0, "uuid-123", time.monotonic(), 1)) # response_queue_id, uid, timestamp, x_enc @@ -100,7 +100,7 @@ def test_single_loop(loop_args): run_single_loop(lit_api_mock, None, requests_queue, response_queues) -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_stream(simple_stream_api): server = LitServer(simple_stream_api, stream=True, timeout=10) expected_output1 = "prompt=Hello generated_output=LitServe is streaming output".lower().replace(" ", "") @@ -121,7 +121,7 @@ async def test_stream(simple_stream_api): ), "Server returns input prompt and generated output which didn't match." -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_stream_client_disconnection(simple_stream_api, caplog): server = LitServer(simple_stream_api, stream=True, timeout=10) @@ -140,7 +140,7 @@ async def test_stream_client_disconnection(simple_stream_api, caplog): # TODO: also chec if the task actually stopped in the server -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_batched_stream_server(simple_batched_stream_api): server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30) expected_output1 = "Hello LitServe is streaming output".lower().replace(" ", "") @@ -419,7 +419,7 @@ def encode_response(self, output, context): return {"output": input} -@pytest.mark.asyncio +@pytest.mark.asyncio() @patch("litserve.server.load_and_raise") async def test_inject_context(mocked_load_and_raise): def dummy_load_and_raise(resp): From 6eeb90ddaa3643edb1a0318e1ef640403b82725a Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 22 Aug 2024 10:16:24 +0545 Subject: [PATCH 09/27] adds cleanup for the dict to prevent leakage --- src/litserve/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/litserve/server.py b/src/litserve/server.py index f5b4af64..c5cc4794 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -263,6 +263,7 @@ def run_streaming_loop( ) for y_enc in y_enc_gen: if request_evicted_status.get(uid): + request_evicted_status.pop(uid) break y_enc = lit_api.format_encoded_response(y_enc) response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) From dc041d2a1cddad394cc376129bc72564477bb9b0 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 22 Aug 2024 10:18:59 +0545 Subject: [PATCH 10/27] chore: fix typo in test_lit_server.py --- tests/test_lit_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index b1de2a2d..79e9060d 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -137,7 +137,7 @@ async def test_stream_client_disconnection(simple_stream_api, caplog): await asyncio.sleep(0.5) assert "Request evicted for the uid=" in caplog.text, "Server should log client disconnection" - # TODO: also chec if the task actually stopped in the server + # TODO: also check if the task actually stopped in the server @pytest.mark.asyncio() From 18419f18989cff4ac6121b23fd30d948f38450f6 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 22 Aug 2024 10:26:14 +0545 Subject: [PATCH 11/27] updates the sleep time --- tests/test_lit_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 79e9060d..583efdc8 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -128,13 +128,13 @@ async def test_stream_client_disconnection(simple_stream_api, caplog): with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG): async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?" * 2}, timeout=10)) - await asyncio.sleep(0.4) + await asyncio.sleep(0.5) # Simulate client disconnection by canceling the request task.cancel() # Allow some time for the server to handle the cancellation - await asyncio.sleep(0.5) + await asyncio.sleep(1) assert "Request evicted for the uid=" in caplog.text, "Server should log client disconnection" # TODO: also check if the task actually stopped in the server From f6763e55bbee1839a71ff2c45a4946a1fd99edcc Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 22 Aug 2024 10:55:45 +0545 Subject: [PATCH 12/27] updated some time --- tests/test_lit_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 583efdc8..26428ba0 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -127,8 +127,8 @@ async def test_stream_client_disconnection(simple_stream_api, caplog): with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG): async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?" * 2}, timeout=10)) - await asyncio.sleep(0.5) + task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?" * 10}, timeout=10)) + await asyncio.sleep(1) # Simulate client disconnection by canceling the request task.cancel() From 6dc64542cbbc0fd29920ec90f8e14b1e5b6f9270 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 22 Aug 2024 10:57:04 +0545 Subject: [PATCH 13/27] updated prompt len --- tests/test_lit_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 26428ba0..a43efac6 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -127,7 +127,7 @@ async def test_stream_client_disconnection(simple_stream_api, caplog): with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG): async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?" * 10}, timeout=10)) + task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?" * 20}, timeout=10)) await asyncio.sleep(1) # Simulate client disconnection by canceling the request From e7b305931d0a615cfeded2ed11e5bb404ebc42ef Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 22 Aug 2024 11:17:06 +0545 Subject: [PATCH 14/27] chore: Remove print statement in stream_predict method --- src/litserve/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index c5cc4794..14d7208f 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -658,7 +658,6 @@ async def predict(request: self.request_type, background_tasks: BackgroundTasks) async def stream_predict(request: self.request_type, background_tasks: BackgroundTasks) -> self.response_type: response_queue_id = self.app.response_queue_id - print("response_queue_id=", response_queue_id) uid = uuid.uuid4() event = asyncio.Event() q = deque() From 34453e91fec1d006be7791235eb7300a8f8b387c Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 10:12:23 +0545 Subject: [PATCH 15/27] chore: Add delayed prediction support in LitAPI subclasses --- tests/conftest.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index f7d1c84d..d92cea54 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,6 +37,12 @@ def encode_response(self, output) -> Response: return {"output": output} +class SimpleDelayedLitAPI(SimpleLitAPI): + def predict(self, x): + time.sleep(0.5) + return self.model(x) + + class SimpleStreamAPI(LitAPI): def setup(self, device) -> None: self.sentence = "LitServe is streaming output" @@ -55,6 +61,14 @@ def encode_response(self, output: Generator) -> Generator: yield out.lower() +class SimpleDelayedStreamAPI(SimpleStreamAPI): + def encode_response(self, output: Generator) -> Generator: + delay = 0.2 + for out in output: + time.sleep(delay) + yield out.lower() + + class SimpleBatchedStreamAPI(LitAPI): def setup(self, device) -> None: self.sentence = "LitServe is streaming output" @@ -88,11 +102,21 @@ def simple_litapi(): return SimpleLitAPI() +@pytest.fixture() +def simple_delayed_litapi(): + return SimpleDelayedLitAPI() + + @pytest.fixture() def simple_stream_api(): return SimpleStreamAPI() +@pytest.fixture() +def simple_delayed_stream_api(): + return SimpleDelayedStreamAPI() + + @pytest.fixture() def simple_batched_stream_api(): return SimpleBatchedStreamAPI() From 0069b986322db9968cf536aa61f22137a687664c Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 10:12:44 +0545 Subject: [PATCH 16/27] updated stream test and added test for nonstream case --- tests/test_lit_server.py | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index a43efac6..d8532c44 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -95,9 +95,9 @@ def test_single_loop(loop_args): lit_api_mock, requests_queue = loop_args lit_api_mock.unbatch.side_effect = None response_queues = [FakeResponseQueue()] - + request_evicted_status = {} with pytest.raises(StopIteration, match="exit loop"): - run_single_loop(lit_api_mock, None, requests_queue, response_queues) + run_single_loop(lit_api_mock, None, requests_queue, response_queues, request_evicted_status) @pytest.mark.asyncio() @@ -122,23 +122,42 @@ async def test_stream(simple_stream_api): @pytest.mark.asyncio() -async def test_stream_client_disconnection(simple_stream_api, caplog): - server = LitServer(simple_stream_api, stream=True, timeout=10) +async def test_client_disconnection(simple_delayed_litapi, caplog): + server = LitServer(simple_delayed_litapi, timeout=10) with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG): async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?" * 20}, timeout=10)) + task = asyncio.create_task(ac.post("/predict", json={"input": 1.0}, timeout=10)) + await asyncio.sleep(0.2) + task.cancel() await asyncio.sleep(1) + assert "Client disconnected for the request uid" in caplog.text + # TODO: also check if the task actually stopped in the server - # Simulate client disconnection by canceling the request - task.cancel() + caplog.clear() + task = asyncio.create_task(ac.post("/predict", json={"input": 1.0}, timeout=10)) + await task + assert "Client disconnected for the request uid" not in caplog.text - # Allow some time for the server to handle the cancellation - await asyncio.sleep(1) - assert "Request evicted for the uid=" in caplog.text, "Server should log client disconnection" +@pytest.mark.asyncio() +async def test_stream_client_disconnection(simple_delayed_stream_api, caplog): + server = LitServer(simple_delayed_stream_api, stream=True, timeout=10) + + with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG): + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?" * 5}, timeout=10)) + await asyncio.sleep(1) + task.cancel() # simulate client disconnection + await asyncio.sleep(1) # wait for the task to stop + assert "Request evicted for the uid=" in caplog.text # TODO: also check if the task actually stopped in the server + caplog.clear() + task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?"}, timeout=10)) + await task + assert "Request evicted for the uid=" not in caplog.text + @pytest.mark.asyncio() async def test_batched_stream_server(simple_batched_stream_api): From f3d6bd2770a5b8aa9bdbce5f0fc3b9cb96dbf102 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 10:13:17 +0545 Subject: [PATCH 17/27] added logic to handle the client disconnection in predict --- src/litserve/server.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index d2357c9b..c521e8ae 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -110,7 +110,13 @@ def collate_requests( return payloads, timed_out_uids -def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]): +def run_single_loop( + lit_api: LitAPI, + lit_spec: LitSpec, + request_queue: Queue, + response_queues: List[Queue], + request_evicted_status: Dict[str, bool], +): while True: try: response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) @@ -146,6 +152,8 @@ def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, re lit_api.encode_response, y, ) + # TODO: Cancel the task if the client disconnects + response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) except Exception as e: logger.exception( @@ -378,6 +386,7 @@ def inference_worker( lit_spec, request_queue, response_queues, + request_evicted_status, ) @@ -648,8 +657,22 @@ async def predict(request: self.request_type, background_tasks: BackgroundTasks) self.request_queue.put_nowait((response_queue_id, uid, time.monotonic(), payload)) - await event.wait() - response, status = self.response_buffer.pop(uid) + async def wait_for_response(): + await event.wait() + return self.response_buffer.pop(uid) + + task = asyncio.create_task(wait_for_response()) + response, status = None, None + try: + while not task.done(): + await asyncio.sleep(0.1) + if await request.is_disconnected(): + task.cancel() + break + response, status = await task + except asyncio.CancelledError: + logger.error("Client disconnected for the request uid=%s", uid) + self.request_evicted_status[uid] = True if status == LitAPIStatus.ERROR: load_and_raise(response) From 60291656a4dae4598568b5e2c922177877363dbb Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 10:19:25 +0545 Subject: [PATCH 18/27] update sleep duration --- src/litserve/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index c521e8ae..31b2161f 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -665,7 +665,7 @@ async def wait_for_response(): response, status = None, None try: while not task.done(): - await asyncio.sleep(0.1) + await asyncio.sleep(0.5) if await request.is_disconnected(): task.cancel() break From 6e95b305a38b84877a69f7a0ff3218647ccafabf Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 10:26:57 +0545 Subject: [PATCH 19/27] Update sleep duration --- tests/test_lit_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index d8532c44..0f97e67f 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -147,7 +147,7 @@ async def test_stream_client_disconnection(simple_delayed_stream_api, caplog): with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG): async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?" * 5}, timeout=10)) - await asyncio.sleep(1) + await asyncio.sleep(2) task.cancel() # simulate client disconnection await asyncio.sleep(1) # wait for the task to stop assert "Request evicted for the uid=" in caplog.text From f6f3e4cbf7a2a8523fed41df1c3491a7149e61ba Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 10:31:18 +0545 Subject: [PATCH 20/27] update sleep time --- src/litserve/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 31b2161f..c521e8ae 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -665,7 +665,7 @@ async def wait_for_response(): response, status = None, None try: while not task.done(): - await asyncio.sleep(0.5) + await asyncio.sleep(0.1) if await request.is_disconnected(): task.cancel() break From 9d47245f6b6f9f322ddbc0639a0c2b0755cc5259 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 10:47:31 +0545 Subject: [PATCH 21/27] removed sleep --- src/litserve/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index c521e8ae..2ce54850 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -665,7 +665,6 @@ async def wait_for_response(): response, status = None, None try: while not task.done(): - await asyncio.sleep(0.1) if await request.is_disconnected(): task.cancel() break From 86ca3cecb2f2b66d93975be5a9652fb5edd62323 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 11:09:51 +0545 Subject: [PATCH 22/27] check if `is_disconnected` exists --- src/litserve/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 2ce54850..72b5409e 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -665,7 +665,7 @@ async def wait_for_response(): response, status = None, None try: while not task.done(): - if await request.is_disconnected(): + if hasattr(request, "is_disconnected") and await request.is_disconnected(): task.cancel() break response, status = await task From 154cc6cf17439369fd0e27baac37e7582aea0c2d Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 11:16:06 +0545 Subject: [PATCH 23/27] adds sleep --- src/litserve/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/litserve/server.py b/src/litserve/server.py index 72b5409e..4c818141 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -665,6 +665,7 @@ async def wait_for_response(): response, status = None, None try: while not task.done(): + asyncio.sleep(0.01) if hasattr(request, "is_disconnected") and await request.is_disconnected(): task.cancel() break From 39986bfdd09369d1d6830e4c9970e1c5e4e0fb00 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 12:50:00 +0545 Subject: [PATCH 24/27] chore: Update sleep duration --- src/litserve/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 4c818141..2ace0c25 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -665,7 +665,7 @@ async def wait_for_response(): response, status = None, None try: while not task.done(): - asyncio.sleep(0.01) + asyncio.sleep(0.2) if hasattr(request, "is_disconnected") and await request.is_disconnected(): task.cancel() break From 2c7633ab7788b396f60456d9a8ad4ac5e9425dfb Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 13:25:05 +0545 Subject: [PATCH 25/27] chore: Update sleep duration in LitServer --- src/litserve/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 2ace0c25..ea664284 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -665,7 +665,7 @@ async def wait_for_response(): response, status = None, None try: while not task.done(): - asyncio.sleep(0.2) + await asyncio.sleep(1) if hasattr(request, "is_disconnected") and await request.is_disconnected(): task.cancel() break From ccaeee95a4e3a7fbb95bc0fa741b57c13a1f491e Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 13:57:11 +0545 Subject: [PATCH 26/27] tried another approach to check & handle disconnection --- src/litserve/server.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index ea664284..c067e0eb 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -661,18 +661,25 @@ async def wait_for_response(): await event.wait() return self.response_buffer.pop(uid) - task = asyncio.create_task(wait_for_response()) - response, status = None, None - try: - while not task.done(): - await asyncio.sleep(1) + async def check_disconnection(): + while True: if hasattr(request, "is_disconnected") and await request.is_disconnected(): - task.cancel() - break - response, status = await task - except asyncio.CancelledError: - logger.error("Client disconnected for the request uid=%s", uid) + return True + await asyncio.sleep(1) # Check every second + + response_task = asyncio.create_task(wait_for_response()) + disconnection_task = asyncio.create_task(check_disconnection()) + + done, pending = await asyncio.wait([response_task, disconnection_task], return_when=asyncio.FIRST_COMPLETED) + + if response_task in done: + response, status = await response_task + disconnection_task.cancel() + else: + response_task.cancel() + logger.error(f"Client disconnected for the request uid={uid}") self.request_evicted_status[uid] = True + raise HTTPException(status_code=499, detail="Client closed request") if status == LitAPIStatus.ERROR: load_and_raise(response) From f0b19af88e557f0ede1457d698b0f9937edea581 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 23 Aug 2024 14:22:13 +0545 Subject: [PATCH 27/27] wrap in try catch --- src/litserve/server.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index c067e0eb..798d5636 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -670,15 +670,23 @@ async def check_disconnection(): response_task = asyncio.create_task(wait_for_response()) disconnection_task = asyncio.create_task(check_disconnection()) - done, pending = await asyncio.wait([response_task, disconnection_task], return_when=asyncio.FIRST_COMPLETED) - - if response_task in done: - response, status = await response_task - disconnection_task.cancel() - else: + try: + # Use asyncio.wait to handle both response and disconnection checks + done, pending = await asyncio.wait( + [response_task, disconnection_task], return_when=asyncio.FIRST_COMPLETED + ) + if response_task in done: + response, status = await response_task + disconnection_task.cancel() + else: + response_task.cancel() + logger.error(f"Client disconnected for the request uid={uid}") + self.request_evicted_status[uid] = True + raise HTTPException(status_code=499, detail="Client closed request") + except asyncio.CancelledError: response_task.cancel() + disconnection_task.cancel() logger.error(f"Client disconnected for the request uid={uid}") - self.request_evicted_status[uid] = True raise HTTPException(status_code=499, detail="Client closed request") if status == LitAPIStatus.ERROR: