Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Server embedder: use queue, handle unsuccessful requests at the end #5835

Merged
merged 2 commits into from
Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 159 additions & 92 deletions Orange/misc/server_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,34 @@
import logging
import random
import uuid
import warnings
from collections import namedtuple
from functools import partial
from json import JSONDecodeError
from os import getenv
from typing import Any, Callable, List, Optional

from AnyQt.QtCore import QSettings
from httpx import AsyncClient, NetworkError, ReadTimeout, Response
from numpy import linspace

from Orange.misc.utils.embedder_utils import (EmbedderCache,
EmbeddingCancelledException,
EmbeddingConnectionError,
get_proxies)
from Orange.misc.utils.embedder_utils import (
EmbedderCache,
EmbeddingCancelledException,
EmbeddingConnectionError,
get_proxies,
)
from Orange.util import dummy_callback

log = logging.getLogger(__name__)
TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))


class ServerEmbedderCommunicator:
"""
This class needs to be inherited by the class which re-implements
_encode_data_instance and defines self.content_type. For sending a table
with data items use embedd_table function. This one is called with the
complete Orange data Table. Then _encode_data_instance needs to extract
data to be embedded from the RowInstance. For images, it takes the image
path from the table, load image, and transform it into bytes.
_encode_data_instance and defines self.content_type. For sending a list
with data items use embedd_table function.

Attributes
----------
Expand Down Expand Up @@ -58,8 +63,7 @@ def __init__(
self._model = model_name
self.embedder_type = embedder_type

# attribute that offers support for cancelling the embedding
# if ran in another thread
# remove in 3.33
self._cancelled = False

self.machine_id = None
Expand All @@ -69,20 +73,22 @@ def __init__(
) or str(uuid.getnode())
except TypeError:
self.machine_id = str(uuid.getnode())
self.session_id = str(random.randint(1, 1e10))
self.session_id = str(random.randint(1, int(1e10)))

self._cache = EmbedderCache(model_name)

# default embedding timeouts are too small we need to increase them
self.timeout = 180
self.num_parallel_requests = 0
self.max_parallel = max_parallel_requests
self.max_parallel_requests = max_parallel_requests

self.content_type = None # need to be set in a class inheriting

def embedd_data(
self,
data: List[Any],
processed_callback: Callable[[bool], None] = None,
self,
data: List[Any],
processed_callback: Optional[Callable] = None,
*,
callback: Callable = dummy_callback,
) -> List[Optional[List[float]]]:
"""
This function repeats calling embedding function until all items
Expand All @@ -94,9 +100,12 @@ def embedd_data(
data
List with data that needs to be embedded.
processed_callback
Deprecated: remove in 3.33
A function that is called after each item is embedded
by either getting a successful response from the server,
getting the result from cache or skipping the item.
callback
Callback for reporting the progress in share of embedded items

Returns
-------
Expand All @@ -111,25 +120,26 @@ def embedd_data(
EmbeddingCancelledException:
If cancelled attribute is set to True (default=False).
"""
# if there is less items than 10 connection error should be raised
# earlier
# if there is less items than 10 connection error should be raised earlier
self.max_errors = min(len(data) * self.MAX_REPEATS, 10)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
embeddings = asyncio.get_event_loop().run_until_complete(
self.embedd_batch(data, processed_callback)
self.embedd_batch(data, processed_callback, callback=callback)
)
except Exception:
finally:
loop.close()
raise

loop.close()
return embeddings

async def embedd_batch(
self, data: List[Any], proc_callback: Callable[[bool], None] = None
self,
data: List[Any],
proc_callback: Optional[Callable] = None,
*,
callback: Callable = dummy_callback,
markotoplak marked this conversation as resolved.
Show resolved Hide resolved
) -> List[Optional[List[float]]]:
"""
Function perform embedding of a batch of data items.
Expand All @@ -138,10 +148,8 @@ async def embedd_batch(
----------
data
A list of data that must be embedded.
proc_callback
A function that is called after each item is fully processed
by either getting a successful response from the server,
getting the result from cache or skipping the item.
callback
Callback for reporting the progress in share of embedded items

Returns
-------
Expand All @@ -153,32 +161,79 @@ async def embedd_batch(
EmbeddingCancelledException:
If cancelled attribute is set to True (default=False).
"""
requests = []
# in Orange 3.33 keep content of the if - remove if clause and complete else
if proc_callback is None:
progress_items = iter(linspace(0, 1, len(data)))

def success_callback():
"""Callback called on every successful embedding"""
callback(next(progress_items))
else:
warnings.warn(
"proc_callback is deprecated and will be removed in version 3.33, "
"use callback instead",
FutureWarning,
)
success_callback = partial(proc_callback, True)

results = [None] * len(data)
queue = asyncio.Queue()

# fill the queue with items to embedd
for i, item in enumerate(data):
queue.put_nowait(TaskItem(id=i, item=item, no_repeats=0))

async with AsyncClient(
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
) as client:
for p in data:
if self._cancelled:
raise EmbeddingCancelledException()
requests.append(self._send_to_server(p, client, proc_callback))
tasks = self._init_workers(client, queue, results, success_callback)

embeddings = await asyncio.gather(*requests)
self._cache.persist_cache()
assert self.num_parallel_requests == 0
# wait for the queue to complete or one of workers to exit
queue_complete = asyncio.create_task(queue.join())
await asyncio.wait(
markotoplak marked this conversation as resolved.
Show resolved Hide resolved
[queue_complete, *tasks], return_when=asyncio.FIRST_COMPLETED
)

return embeddings
# Cancel worker tasks when done
queue_complete.cancel()
await self._cancel_workers(tasks)

async def __wait_until_released(self) -> None:
while self.num_parallel_requests >= self.max_parallel:
await asyncio.sleep(0.1)
self._cache.persist_cache()
return results

def _init_workers(self, client, queue, results, callback):
"""Init required number of workers"""
t = [
asyncio.create_task(self._send_to_server(client, queue, results, callback))
for _ in range(self.max_parallel_requests)
]
log.debug("Created %d workers", self.max_parallel_requests)
return t

@staticmethod
async def _cancel_workers(tasks):
"""Cancel worker at the end"""
log.debug("Canceling workers")
try:
# try to catch any potential exceptions
await asyncio.gather(*tasks)
except Exception as ex:
# raise exceptions gathered from an failed worker
raise ex
finally:
# cancel all tasks in both cases
for task in tasks:
task.cancel()
# Wait until all worker tasks are cancelled.
await asyncio.gather(*tasks, return_exceptions=True)
log.debug("All workers canceled")

# remove in 3.33
def __check_cancelled(self):
if self._cancelled:
raise EmbeddingCancelledException()

async def _encode_data_instance(
self, data_instance: Any
) -> Optional[bytes]:
async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
"""
The reimplementation of this function must implement the procedure
to encode the data item in a string format that will be sent to the
Expand All @@ -197,63 +252,73 @@ async def _encode_data_instance(
raise NotImplementedError

async def _send_to_server(
self,
data_instance: Any,
client: AsyncClient,
proc_callback: Callable[[bool], None] = None,
) -> Optional[List[float]]:
self,
client: AsyncClient,
queue: asyncio.Queue,
results: List,
proc_callback: Callable,
):
"""
Function get an data instance. It extract data from it and send them to
server and retrieve responses.
Worker that embedds data. It is pulling items from the queue until
it is empty. It is runs until anything is in the queue, or it is canceled

Parameters
----------
data_instance
Single row of the input table.
client
HTTPX client that communicates with the server
queue
The queue with items of type TaskItem to be embedded
results
The list to append results in. The list has length equal to numbers
of all items to embedd. The result need to be inserted at the index
defined in queue items.
proc_callback
A function that is called after each item is fully processed
by either getting a successful response from the server,
getting the result from cache or skipping the item.

Returns
-------
Embedding. For items that are not successfully embedded returns None.
"""
await self.__wait_until_released()
self.__check_cancelled()

self.num_parallel_requests += 1
# load bytes
data_bytes = await self._encode_data_instance(data_instance)
if data_bytes is None:
self.num_parallel_requests -= 1
return None

# if data in cache return it
cache_key = self._cache.md5_hash(data_bytes)
emb = self._cache.get_cached_result_or_none(cache_key)

if emb is None:
# in case that embedding not sucessfull resend it to the server
# maximally for MAX_REPEATS time
for i in range(1, self.MAX_REPEATS + 1):
self.__check_cancelled()
while not queue.empty():
# remove in 3.33
self.__check_cancelled()

# get item from the queue
i, data_instance, num_repeats = await queue.get()

# load bytes
data_bytes = await self._encode_data_instance(data_instance)
if data_bytes is None:
continue

# retrieve embedded item from the local cache
cache_key = self._cache.md5_hash(data_bytes)
log.debug("Embedding %s", cache_key)
emb = self._cache.get_cached_result_or_none(cache_key)

if emb is None:
# send the item to the server for embedding if not in the local cache
log.debug("Sending to the server: %s", cache_key)
url = (
f"/{self.embedder_type}/{self._model}?"
f"machine={self.machine_id}"
f"&session={self.session_id}&retry={i}"
f"/{self.embedder_type}/{self._model}?machine={self.machine_id}"
f"&session={self.session_id}&retry={num_repeats+1}"
)
emb = await self._send_request(client, data_bytes, url)
if emb is not None:
self._cache.add(cache_key, emb)
break # repeat only when embedding None
if proc_callback:
proc_callback(emb is not None)

self.num_parallel_requests -= 1
return emb
if emb is not None:
# store result if embedding is successful
log.debug("Successfully embedded: %s", cache_key)
results[i] = emb
proc_callback()
elif num_repeats+1 < self.MAX_REPEATS:
log.debug("Embedding unsuccessful - reading to queue: %s", cache_key)
# if embedding not successful put the item to queue to be handled at
# the end - the item is put to the end since it is possible that server
# still process the request and the result will be in the cache later
# repeating the request immediately may result in another fail when
# processing takes longer
queue.put_nowait(TaskItem(i, data_instance, no_repeats=num_repeats+1))
queue.task_done()
markotoplak marked this conversation as resolved.
Show resolved Hide resolved

async def _send_request(
self, client: AsyncClient, data: bytes, url: str
Expand Down Expand Up @@ -284,27 +349,23 @@ async def _send_request(
response = await client.post(url, headers=headers, data=data)
except ReadTimeout as ex:
log.debug("Read timeout", exc_info=True)
# it happens when server do not respond in 60 seconds, in
# this case we return None and items will be resend later
# it happens when server do not respond in time defined by timeout
# return None and items will be resend later

# if it happens more than in ten consecutive cases it means
# sth is wrong with embedder we stop embedding
self.count_read_errors += 1

if self.count_read_errors >= self.max_errors:
self.num_parallel_requests = 0 # for safety reasons
raise EmbeddingConnectionError from ex
return None
except (OSError, NetworkError) as ex:
log.debug("Network error", exc_info=True)
# it happens when no connection and items cannot be sent to the
# server
# we count number of consecutive errors
# it happens when no connection and items cannot be sent to server

# if more than 10 consecutive errors it means there is no
# connection so we stop embedding with EmbeddingConnectionError
self.count_connection_errors += 1
if self.count_connection_errors >= self.max_errors:
self.num_parallel_requests = 0 # for safety reasons
raise EmbeddingConnectionError from ex
return None
except Exception:
Expand Down Expand Up @@ -343,5 +404,11 @@ def _parse_response(response: Response) -> Optional[List[float]]:
def clear_cache(self):
self._cache.clear_cache()

# remove in 3.33
def set_cancelled(self):
warnings.warn(
"set_cancelled is deprecated and will be removed in version 3.33, "
"the process can be canceled by raising Error in callback",
FutureWarning,
)
self._cancelled = True
Loading