From f7cd6839c86efe3a3278eeb7fecd90154bc31fe2 Mon Sep 17 00:00:00 2001
From: Primoz Godec
Date: Tue, 8 Feb 2022 15:49:34 +0100
Subject: [PATCH] Server embedder: use queue, handle unsuccessful requests at
the end
---
Orange/misc/server_embedder.py | 196 +++++++++++++---------
Orange/misc/tests/test_server_embedder.py | 9 +
2 files changed, 125 insertions(+), 80 deletions(-)
diff --git a/Orange/misc/server_embedder.py b/Orange/misc/server_embedder.py
index bdcf945952d..52ff16b0eaf 100644
--- a/Orange/misc/server_embedder.py
+++ b/Orange/misc/server_embedder.py
@@ -3,6 +3,7 @@
import logging
import random
import uuid
+from collections import namedtuple
from json import JSONDecodeError
from os import getenv
from typing import Any, Callable, List, Optional
@@ -10,22 +11,22 @@
from AnyQt.QtCore import QSettings
from httpx import AsyncClient, NetworkError, ReadTimeout, Response
-from Orange.misc.utils.embedder_utils import (EmbedderCache,
- EmbeddingCancelledException,
- EmbeddingConnectionError,
- get_proxies)
+from Orange.misc.utils.embedder_utils import (
+ EmbedderCache,
+ EmbeddingCancelledException,
+ EmbeddingConnectionError,
+ get_proxies,
+)
log = logging.getLogger(__name__)
+TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))
class ServerEmbedderCommunicator:
"""
This class needs to be inherited by the class which re-implements
- _encode_data_instance and defines self.content_type. For sending a table
- with data items use embedd_table function. This one is called with the
- complete Orange data Table. Then _encode_data_instance needs to extract
- data to be embedded from the RowInstance. For images, it takes the image
- path from the table, load image, and transform it into bytes.
+ _encode_data_instance and defines self.content_type. For sending a list
+ with data items use embedd_table function.
Attributes
----------
@@ -69,14 +70,14 @@ def __init__(
) or str(uuid.getnode())
except TypeError:
self.machine_id = str(uuid.getnode())
- self.session_id = str(random.randint(1, 1e10))
+ self.session_id = str(random.randint(1, int(1e10)))
self._cache = EmbedderCache(model_name)
# default embedding timeouts are too small we need to increase them
self.timeout = 180
- self.num_parallel_requests = 0
- self.max_parallel = max_parallel_requests
+ self.max_parallel_requests = max_parallel_requests
+
self.content_type = None # need to be set in a class inheriting
def embedd_data(
@@ -111,8 +112,7 @@ def embedd_data(
EmbeddingCancelledException:
If cancelled attribute is set to True (default=False).
"""
- # if there is less items than 10 connection error should be raised
- # earlier
+ # if there are fewer than 10 items, a connection error should be raised earlier
self.max_errors = min(len(data) * self.MAX_REPEATS, 10)
loop = asyncio.new_event_loop()
@@ -121,11 +121,9 @@ def embedd_data(
embeddings = asyncio.get_event_loop().run_until_complete(
self.embedd_batch(data, processed_callback)
)
- except Exception:
+ finally:
loop.close()
- raise
- loop.close()
return embeddings
async def embedd_batch(
@@ -153,32 +151,63 @@ async def embedd_batch(
EmbeddingCancelledException:
If cancelled attribute is set to True (default=False).
"""
- requests = []
+ results = [None] * len(data)
+ queue = asyncio.Queue()
+
+ # fill the queue with items to embedd
+ for i, item in enumerate(data):
+ queue.put_nowait(TaskItem(id=i, item=item, no_repeats=0))
+
async with AsyncClient(
- timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
+ timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
) as client:
- for p in data:
- if self._cancelled:
- raise EmbeddingCancelledException()
- requests.append(self._send_to_server(p, client, proc_callback))
+ tasks = self._init_workers(client, queue, results, proc_callback)
+
+ # wait for the queue to complete or one of workers to exit
+ queue_complete = asyncio.create_task(queue.join())
+ await asyncio.wait(
+ [queue_complete, *tasks], return_when=asyncio.FIRST_COMPLETED
+ )
+
+ # Cancel worker tasks when done
+ queue_complete.cancel()
+ await self._cancel_workers(tasks)
- embeddings = await asyncio.gather(*requests)
self._cache.persist_cache()
- assert self.num_parallel_requests == 0
+ return results
- return embeddings
+ def _init_workers(self, client, queue, results, callback):
+ """Init required number of workers"""
+ t = [
+ asyncio.create_task(self._send_to_server(client, queue, results, callback))
+ for _ in range(self.max_parallel_requests)
+ ]
+ log.debug("Created %d workers", self.max_parallel_requests)
+ return t
- async def __wait_until_released(self) -> None:
- while self.num_parallel_requests >= self.max_parallel:
- await asyncio.sleep(0.1)
+ @staticmethod
+ async def _cancel_workers(tasks):
+ """Cancel workers at the end"""
+ log.debug("Canceling workers")
+ try:
+ # try to catch any potential exceptions
+ await asyncio.gather(*tasks)
+ except Exception as ex:
+ # raise exceptions gathered from a failed worker
+ raise ex
+ finally:
+ # cancel all tasks in both cases
+ for task in tasks:
+ task.cancel()
+ # Wait until all worker tasks are cancelled.
+ await asyncio.gather(*tasks, return_exceptions=True)
+ log.debug("All workers canceled")
def __check_cancelled(self):
if self._cancelled:
raise EmbeddingCancelledException()
- async def _encode_data_instance(
- self, data_instance: Any
- ) -> Optional[bytes]:
+ async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
"""
The reimplementation of this function must implement the procedure
to encode the data item in a string format that will be sent to the
@@ -197,63 +226,74 @@ async def _encode_data_instance(
raise NotImplementedError
async def _send_to_server(
- self,
- data_instance: Any,
- client: AsyncClient,
- proc_callback: Callable[[bool], None] = None,
- ) -> Optional[List[float]]:
+ self,
+ client: AsyncClient,
+ queue: asyncio.Queue,
+ results: List,
+ proc_callback: Callable[[bool], None] = None,
+ ):
"""
- Function get an data instance. It extract data from it and send them to
- server and retrieve responses.
+ Worker that embeds data. It pulls items from the queue until the queue
+ is empty. It is canceled by embedd_batch when all tasks are finished.
Parameters
----------
- data_instance
- Single row of the input table.
client
HTTPX client that communicates with the server
+ queue
+ The queue with items of type TaskItem to be embedded
+ results
+ The list to store results in. Its length equals the number of all
+ items to embed. Each result needs to be inserted at the index
+ defined in its queue item.
proc_callback
A function that is called after each item is fully processed
by either getting a successful response from the server,
getting the result from cache or skipping the item.
-
- Returns
- -------
- Embedding. For items that are not successfully embedded returns None.
"""
- await self.__wait_until_released()
- self.__check_cancelled()
-
- self.num_parallel_requests += 1
- # load bytes
- data_bytes = await self._encode_data_instance(data_instance)
- if data_bytes is None:
- self.num_parallel_requests -= 1
- return None
-
- # if data in cache return it
- cache_key = self._cache.md5_hash(data_bytes)
- emb = self._cache.get_cached_result_or_none(cache_key)
-
- if emb is None:
- # in case that embedding not sucessfull resend it to the server
- # maximally for MAX_REPEATS time
- for i in range(1, self.MAX_REPEATS + 1):
- self.__check_cancelled()
+ while not queue.empty():
+ self.__check_cancelled()
+
+ # get item from the queue
+ i, data_instance, num_repeats = await queue.get()
+ num_repeats += 1
+
+ # load bytes
+ data_bytes = await self._encode_data_instance(data_instance)
+ if data_bytes is None:
+ continue
+
+ # retrieve embedded item from the local cache
+ cache_key = self._cache.md5_hash(data_bytes)
+ log.debug("Embedding %s", cache_key)
+ emb = self._cache.get_cached_result_or_none(cache_key)
+
+ if emb is None:
+ # send the item to the server for embedding if not in the local cache
+ log.debug("Sending to the server: %s", cache_key)
url = (
- f"/{self.embedder_type}/{self._model}?"
- f"machine={self.machine_id}"
- f"&session={self.session_id}&retry={i}"
+ f"/{self.embedder_type}/{self._model}?machine={self.machine_id}"
+ f"&session={self.session_id}&retry={num_repeats}"
)
emb = await self._send_request(client, data_bytes, url)
if emb is not None:
self._cache.add(cache_key, emb)
- break # repeat only when embedding None
- if proc_callback:
- proc_callback(emb is not None)
- self.num_parallel_requests -= 1
- return emb
+ if emb is not None:
+ # store result if embedding is successful
+ log.debug("Successfully embedded: %s", cache_key)
+ results[i] = emb
+ if proc_callback:
+ proc_callback(emb is not None)
+ elif num_repeats < self.MAX_REPEATS:
+ log.debug("Embedding unsuccessful - reading to queue: %s", cache_key)
+ # if embedding is not successful, put the item back in the queue to be
+ # handled at the end - the item is put at the end since it is possible that
+ # the server is still processing the request and the result will be in the
+ # cache later; repeating the request immediately may result in another
+ # failure when processing takes longer
+ queue.put_nowait(TaskItem(i, data_instance, no_repeats=num_repeats))
+ queue.task_done()
async def _send_request(
self, client: AsyncClient, data: bytes, url: str
@@ -284,27 +324,23 @@ async def _send_request(
response = await client.post(url, headers=headers, data=data)
except ReadTimeout as ex:
log.debug("Read timeout", exc_info=True)
- # it happens when server do not respond in 60 seconds, in
- # this case we return None and items will be resend later
+ # it happens when the server does not respond within the time defined by
+ # timeout; return None and the item will be resent later
# if it happens more than in ten consecutive cases it means
# sth is wrong with embedder we stop embedding
self.count_read_errors += 1
-
if self.count_read_errors >= self.max_errors:
- self.num_parallel_requests = 0 # for safety reasons
raise EmbeddingConnectionError from ex
return None
except (OSError, NetworkError) as ex:
log.debug("Network error", exc_info=True)
- # it happens when no connection and items cannot be sent to the
- # server
- # we count number of consecutive errors
+ # it happens when there is no connection and items cannot be sent to the server
+
# if more than 10 consecutive errors it means there is no
# connection so we stop embedding with EmbeddingConnectionError
self.count_connection_errors += 1
if self.count_connection_errors >= self.max_errors:
- self.num_parallel_requests = 0 # for safety reasons
raise EmbeddingConnectionError from ex
return None
except Exception:
diff --git a/Orange/misc/tests/test_server_embedder.py b/Orange/misc/tests/test_server_embedder.py
index f5f9007c33e..fc001a2cdd1 100644
--- a/Orange/misc/tests/test_server_embedder.py
+++ b/Orange/misc/tests/test_server_embedder.py
@@ -167,3 +167,12 @@ def test_encode_data_instance(self):
mocked_fun.assert_has_calls(
[call(item) for item in self.test_data], any_order=True
)
+
+ @patch(_HTTPX_POST_METHOD, return_value=DummyResponse(b''), new_callable=AsyncMock)
+ def test_retries(self, mock):
+ self.embedder.embedd_data(self.test_data)
+ self.assertEqual(len(self.test_data) * 3, mock.call_count)
+
+
+if __name__ == "__main__":
+ unittest.main()