Commit

Merge pull request #5835 from PrimozGodec/embedders-change-order
[ENH] Server embedder: use queue, handle unsuccessful requests at the end
markotoplak authored Feb 25, 2022
2 parents 05fc0df + f3a9dde commit b8928c2
Showing 2 changed files with 193 additions and 92 deletions.
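The heart of the change, in miniature: instead of creating one coroutine per item and throttling with a counter, the embedder now fills an asyncio.Queue and starts a fixed pool of workers; an item whose embedding fails is re-queued at the back and retried after all other items, up to a repeat limit. A minimal, self-contained sketch of the pattern the diff below adopts — TaskItem and MAX_REPEATS mirror the actual code, while embed_one, worker, embed_all, and max_parallel are illustrative stand-ins:

import asyncio
from collections import namedtuple

TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))
MAX_REPEATS = 3

async def embed_one(item):
    # stand-in for the real server request; returns None on failure
    await asyncio.sleep(0)
    return [float(item)]

async def worker(queue, results):
    while not queue.empty():
        i, item, no_repeats = await queue.get()
        emb = await embed_one(item)
        if emb is not None:
            results[i] = emb
        elif no_repeats + 1 < MAX_REPEATS:
            # failure: push to the back of the queue, retry at the very end
            queue.put_nowait(TaskItem(i, item, no_repeats + 1))
        queue.task_done()

async def embed_all(data, max_parallel=3):
    results = [None] * len(data)
    queue = asyncio.Queue()
    for i, item in enumerate(data):
        queue.put_nowait(TaskItem(i, item, 0))
    workers = [asyncio.create_task(worker(queue, results))
               for _ in range(max_parallel)]
    queue_complete = asyncio.create_task(queue.join())
    # stop when either everything is embedded or a worker died with an error
    await asyncio.wait([queue_complete, *workers],
                       return_when=asyncio.FIRST_COMPLETED)
    queue_complete.cancel()
    for w in workers:
        w.cancel()
    await asyncio.gather(*workers, return_exceptions=True)
    return results

print(asyncio.run(embed_all([1, 2, 3])))  # [[1.0], [2.0], [3.0]]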
251 changes: 159 additions & 92 deletions Orange/misc/server_embedder.py
@@ -3,29 +3,34 @@
import logging
import random
import uuid
import warnings
from collections import namedtuple
from functools import partial
from json import JSONDecodeError
from os import getenv
from typing import Any, Callable, List, Optional

from AnyQt.QtCore import QSettings
from httpx import AsyncClient, NetworkError, ReadTimeout, Response
from numpy import linspace

from Orange.misc.utils.embedder_utils import (EmbedderCache,
EmbeddingCancelledException,
EmbeddingConnectionError,
get_proxies)
from Orange.misc.utils.embedder_utils import (
EmbedderCache,
EmbeddingCancelledException,
EmbeddingConnectionError,
get_proxies,
)
from Orange.util import dummy_callback

log = logging.getLogger(__name__)
TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))


class ServerEmbedderCommunicator:
"""
This class needs to be inherited by a class that re-implements
_encode_data_instance and defines self.content_type. For sending a table
with data items use the embedd_table function. This one is called with the
complete Orange data Table. Then _encode_data_instance needs to extract
the data to be embedded from the RowInstance. For images, it takes the image
path from the table, loads the image, and transforms it into bytes.
_encode_data_instance and defines self.content_type. For sending a list
of data items, use the embedd_table function.
Attributes
----------
@@ -58,8 +63,7 @@ def __init__(
self._model = model_name
self.embedder_type = embedder_type

# attribute that offers support for cancelling the embedding
# if run in another thread
# remove in 3.33
self._cancelled = False

self.machine_id = None
@@ -69,20 +73,22 @@ def __init__(
) or str(uuid.getnode())
except TypeError:
self.machine_id = str(uuid.getnode())
self.session_id = str(random.randint(1, 1e10))
self.session_id = str(random.randint(1, int(1e10)))

self._cache = EmbedderCache(model_name)

# default embedding timeouts are too small, so we increase them
self.timeout = 180
self.num_parallel_requests = 0
self.max_parallel = max_parallel_requests
self.max_parallel_requests = max_parallel_requests

self.content_type = None # needs to be set in the inheriting class

def embedd_data(
self,
data: List[Any],
processed_callback: Callable[[bool], None] = None,
self,
data: List[Any],
processed_callback: Optional[Callable] = None,
*,
callback: Callable = dummy_callback,
) -> List[Optional[List[float]]]:
"""
This function repeatedly calls the embedding function until all items
@@ -94,9 +100,12 @@ def embedd_data(
data
List with data that needs to be embedded.
processed_callback
Deprecated: remove in 3.33
A function that is called after each item is embedded
by either getting a successful response from the server,
getting the result from the cache, or skipping the item.
callback
Callback for reporting the progress as the share of embedded items
Returns
-------
@@ -111,25 +120,26 @@ def embedd_data(
EmbeddingCancelledException:
If cancelled attribute is set to True (default=False).
"""
# if there are fewer than 10 items, the connection error should be raised
# earlier
# if there are fewer than 10 items, the connection error should be raised earlier
self.max_errors = min(len(data) * self.MAX_REPEATS, 10)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
embeddings = asyncio.get_event_loop().run_until_complete(
self.embedd_batch(data, processed_callback)
self.embedd_batch(data, processed_callback, callback=callback)
)
except Exception:
finally:
loop.close()
raise

loop.close()
return embeddings
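For callers the new signature is mostly transparent: processed_callback still works but now emits a FutureWarning, while the keyword-only callback is the supported path and receives the running share of embedded items as a float between 0 and 1. A hedged usage sketch — ImageEmbedder, its constructor arguments, and items are illustrative stand-ins for any concrete subclass and its data:

from Orange.misc.utils.embedder_utils import EmbeddingConnectionError

def report(share: float) -> None:
    # share is the fraction of items embedded so far, from 0.0 to 1.0
    print(f"embedded {share:.0%}")

embedder = ImageEmbedder(model_name="squeezenet")  # hypothetical subclass
try:
    embeddings = embedder.embedd_data(items, callback=report)
except EmbeddingConnectionError:
    embeddings = None  # too many consecutive timeouts or network errors
# embeddings[i] is a list of floats, or None if item i could not be embedded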

async def embedd_batch(
self, data: List[Any], proc_callback: Callable[[bool], None] = None
self,
data: List[Any],
proc_callback: Optional[Callable] = None,
*,
callback: Callable = dummy_callback,
) -> List[Optional[List[float]]]:
"""
Function perform embedding of a batch of data items.
@@ -138,10 +148,8 @@ async def embedd_batch(
----------
data
A list of data that must be embedded.
proc_callback
A function that is called after each item is fully processed
by either getting a successful response from the server,
getting the result from the cache, or skipping the item.
callback
Callback for reporting the progress as the share of embedded items
Returns
-------
@@ -153,32 +161,79 @@ async def embedd_batch(
EmbeddingCancelledException:
If cancelled attribute is set to True (default=False).
"""
requests = []
# in Orange 3.33 keep the content of the if branch - remove the if clause and the whole else branch
if proc_callback is None:
progress_items = iter(linspace(0, 1, len(data)))

def success_callback():
"""Callback called on every successful embedding"""
callback(next(progress_items))
else:
warnings.warn(
"proc_callback is deprecated and will be removed in version 3.33, "
"use callback instead",
FutureWarning,
)
success_callback = partial(proc_callback, True)

results = [None] * len(data)
queue = asyncio.Queue()

# fill the queue with items to embed
for i, item in enumerate(data):
queue.put_nowait(TaskItem(id=i, item=item, no_repeats=0))

async with AsyncClient(
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
) as client:
for p in data:
if self._cancelled:
raise EmbeddingCancelledException()
requests.append(self._send_to_server(p, client, proc_callback))
tasks = self._init_workers(client, queue, results, success_callback)

embeddings = await asyncio.gather(*requests)
self._cache.persist_cache()
assert self.num_parallel_requests == 0
# wait for the queue to complete or one of the workers to exit
queue_complete = asyncio.create_task(queue.join())
await asyncio.wait(
[queue_complete, *tasks], return_when=asyncio.FIRST_COMPLETED
)

return embeddings
# Cancel worker tasks when done
queue_complete.cancel()
await self._cancel_workers(tasks)

async def __wait_until_released(self) -> None:
while self.num_parallel_requests >= self.max_parallel:
await asyncio.sleep(0.1)
self._cache.persist_cache()
return results

def _init_workers(self, client, queue, results, callback):
"""Init required number of workers"""
t = [
asyncio.create_task(self._send_to_server(client, queue, results, callback))
for _ in range(self.max_parallel_requests)
]
log.debug("Created %d workers", self.max_parallel_requests)
return t

@staticmethod
async def _cancel_workers(tasks):
"""Cancel worker at the end"""
log.debug("Canceling workers")
try:
# try to catch any potential exceptions
await asyncio.gather(*tasks)
except Exception as ex:
# raise exceptions gathered from a failed worker
raise ex
finally:
# cancel all tasks in both cases
for task in tasks:
task.cancel()
# Wait until all worker tasks are cancelled.
await asyncio.gather(*tasks, return_exceptions=True)
log.debug("All workers canceled")

# remove in 3.33
def __check_cancelled(self):
if self._cancelled:
raise EmbeddingCancelledException()

async def _encode_data_instance(
self, data_instance: Any
) -> Optional[bytes]:
async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
"""
The reimplementation of this function must implement the procedure
to encode the data item in a string format that will be sent to the
Expand All @@ -197,63 +252,73 @@ async def _encode_data_instance(
raise NotImplementedError

async def _send_to_server(
self,
data_instance: Any,
client: AsyncClient,
proc_callback: Callable[[bool], None] = None,
) -> Optional[List[float]]:
self,
client: AsyncClient,
queue: asyncio.Queue,
results: List,
proc_callback: Callable,
):
"""
This function gets a data instance, extracts its data, sends it to the
server, and retrieves the response.
Worker that embeds data. It keeps pulling items from the queue while
anything is in it, or until it is cancelled.
Parameters
----------
data_instance
Single row of the input table.
client
HTTPX client that communicates with the server
queue
The queue with items of type TaskItem to be embedded
results
The list to store results in. Its length equals the number of
all items to embed. Each result must be inserted at the index
defined in the queue item.
proc_callback
A function that is called after each item is fully processed
by either getting a successful response from the server,
getting the result from the cache, or skipping the item.
Returns
-------
Embedding. For items that are not successfully embedded, returns None.
"""
await self.__wait_until_released()
self.__check_cancelled()

self.num_parallel_requests += 1
# load bytes
data_bytes = await self._encode_data_instance(data_instance)
if data_bytes is None:
self.num_parallel_requests -= 1
return None

# if data in cache return it
cache_key = self._cache.md5_hash(data_bytes)
emb = self._cache.get_cached_result_or_none(cache_key)

if emb is None:
# in case the embedding is not successful, resend it to the server
# at most MAX_REPEATS times
for i in range(1, self.MAX_REPEATS + 1):
self.__check_cancelled()
while not queue.empty():
# remove in 3.33
self.__check_cancelled()

# get item from the queue
i, data_instance, num_repeats = await queue.get()

# load bytes
data_bytes = await self._encode_data_instance(data_instance)
if data_bytes is None:
continue

# retrieve embedded item from the local cache
cache_key = self._cache.md5_hash(data_bytes)
log.debug("Embedding %s", cache_key)
emb = self._cache.get_cached_result_or_none(cache_key)

if emb is None:
# send the item to the server for embedding if not in the local cache
log.debug("Sending to the server: %s", cache_key)
url = (
f"/{self.embedder_type}/{self._model}?"
f"machine={self.machine_id}"
f"&session={self.session_id}&retry={i}"
f"/{self.embedder_type}/{self._model}?machine={self.machine_id}"
f"&session={self.session_id}&retry={num_repeats+1}"
)
emb = await self._send_request(client, data_bytes, url)
if emb is not None:
self._cache.add(cache_key, emb)
break # repeat only when embedding None
if proc_callback:
proc_callback(emb is not None)

self.num_parallel_requests -= 1
return emb
if emb is not None:
# store result if embedding is successful
log.debug("Successfully embedded: %s", cache_key)
results[i] = emb
proc_callback()
elif num_repeats+1 < self.MAX_REPEATS:
log.debug("Embedding unsuccessful - reading to queue: %s", cache_key)
# if embedding not successful put the item to queue to be handled at
# the end - the item is put to the end since it is possible that server
# still process the request and the result will be in the cache later
# repeating the request immediately may result in another fail when
# processing takes longer
queue.put_nowait(TaskItem(i, data_instance, no_repeats=num_repeats+1))
queue.task_done()

async def _send_request(
self, client: AsyncClient, data: bytes, url: str
@@ -284,27 +349,23 @@ async def _send_request(
response = await client.post(url, headers=headers, data=data)
except ReadTimeout as ex:
log.debug("Read timeout", exc_info=True)
# it happens when the server does not respond in 60 seconds; in
# this case we return None and the items will be re-sent later
# it happens when the server does not respond within the timeout;
# return None and the items will be re-sent later

# if it happens in more than ten consecutive cases, something is
# wrong with the embedder, so we stop embedding
self.count_read_errors += 1

if self.count_read_errors >= self.max_errors:
self.num_parallel_requests = 0 # for safety reasons
raise EmbeddingConnectionError from ex
return None
except (OSError, NetworkError) as ex:
log.debug("Network error", exc_info=True)
# it happens when there is no connection and items cannot be sent to
# the server
# we count the number of consecutive errors
# it happens when there is no connection and items cannot be sent to the server

# if there are more than 10 consecutive errors it means there is no
# connection, so we stop embedding with EmbeddingConnectionError
self.count_connection_errors += 1
if self.count_connection_errors >= self.max_errors:
self.num_parallel_requests = 0 # for safety reasons
raise EmbeddingConnectionError from ex
return None
except Exception:
@@ -343,5 +404,11 @@ def _parse_response(response: Response) -> Optional[List[float]]:
def clear_cache(self):
self._cache.clear_cache()

# remove in 3.33
def set_cancelled(self):
warnings.warn(
"set_cancelled is deprecated and will be removed in version 3.33, "
"the process can be canceled by raising Error in callback",
FutureWarning,
)
self._cancelled = True
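With set_cancelled deprecated, the supported way to stop a run is to raise from the progress callback: the exception escapes the worker, _cancel_workers re-raises it, and embedd_data surfaces it to the caller. A small sketch, where stop_requested, embedder, and items are illustrative:

from Orange.misc.utils.embedder_utils import EmbeddingCancelledException

def cancellable_callback(share: float) -> None:
    if stop_requested():  # e.g. a flag flipped from the GUI thread
        # any exception stops the run; this one keeps the old semantics
        raise EmbeddingCancelledException()
    print(f"embedded {share:.0%}")

embeddings = embedder.embedd_data(items, callback=cancellable_callback)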