From 9f5322a6b67068a47a4185eb0f931f27e7c81b4b Mon Sep 17 00:00:00 2001
From: Primoz Godec
Date: Wed, 9 Feb 2022 09:48:25 +0100
Subject: [PATCH] server_embedder: modify callback to match others
---
Orange/misc/server_embedder.py | 66 +++++++++++++++++------
Orange/misc/tests/test_server_embedder.py | 25 +++++++++
2 files changed, 74 insertions(+), 17 deletions(-)
diff --git a/Orange/misc/server_embedder.py b/Orange/misc/server_embedder.py
index 52ff16b0eaf..7636a719e1f 100644
--- a/Orange/misc/server_embedder.py
+++ b/Orange/misc/server_embedder.py
@@ -3,13 +3,16 @@
import logging
import random
import uuid
+import warnings
from collections import namedtuple
+from functools import partial
from json import JSONDecodeError
from os import getenv
from typing import Any, Callable, List, Optional
from AnyQt.QtCore import QSettings
from httpx import AsyncClient, NetworkError, ReadTimeout, Response
+from numpy import linspace
from Orange.misc.utils.embedder_utils import (
EmbedderCache,
@@ -17,6 +20,7 @@
EmbeddingConnectionError,
get_proxies,
)
+from Orange.util import dummy_callback
log = logging.getLogger(__name__)
TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))
@@ -59,8 +63,7 @@ def __init__(
self._model = model_name
self.embedder_type = embedder_type
- # attribute that offers support for cancelling the embedding
- # if ran in another thread
+ # remove in 3.33
self._cancelled = False
self.machine_id = None
@@ -81,9 +84,11 @@ def __init__(
self.content_type = None # need to be set in a class inheriting
def embedd_data(
- self,
- data: List[Any],
- processed_callback: Callable[[bool], None] = None,
+ self,
+ data: List[Any],
+ processed_callback: Optional[Callable] = None,
+ *,
+ callback: Callable = dummy_callback,
) -> List[Optional[List[float]]]:
"""
This function repeats calling embedding function until all items
@@ -95,9 +100,12 @@ def embedd_data(
data
List with data that needs to be embedded.
processed_callback
+ Deprecated: remove in 3.33
A function that is called after each item is embedded
by either getting a successful response from the server,
getting the result from cache or skipping the item.
+ callback
+ Callback for reporting the progress in share of embedded items
Returns
-------
@@ -119,7 +127,7 @@ def embedd_data(
asyncio.set_event_loop(loop)
try:
embeddings = asyncio.get_event_loop().run_until_complete(
- self.embedd_batch(data, processed_callback)
+ self.embedd_batch(data, processed_callback, callback=callback)
)
finally:
loop.close()
@@ -127,7 +135,11 @@ def embedd_data(
return embeddings
async def embedd_batch(
- self, data: List[Any], proc_callback: Callable[[bool], None] = None
+ self,
+ data: List[Any],
+ proc_callback: Optional[Callable] = None,
+ *,
+ callback: Callable = dummy_callback,
) -> List[Optional[List[float]]]:
"""
Function perform embedding of a batch of data items.
@@ -136,10 +148,8 @@ async def embedd_batch(
----------
data
A list of data that must be embedded.
- proc_callback
- A function that is called after each item is fully processed
- by either getting a successful response from the server,
- getting the result from cache or skipping the item.
+ callback
+ Callback for reporting the progress in share of embedded items
Returns
-------
@@ -151,6 +161,21 @@ async def embedd_batch(
EmbeddingCancelledException:
If cancelled attribute is set to True (default=False).
"""
+ # in Orange 3.33 keep content of the if - remove if clause and complete else
+ if proc_callback is None:
+ progress_items = iter(linspace(0, 1, len(data)))
+
+ def success_callback():
+ """Callback called on every successful embedding"""
+ callback(next(progress_items))
+ else:
+ warnings.warn(
+ "proc_callback is deprecated and will be removed in version 3.33, "
+ "use callback instead",
+ FutureWarning,
+ )
+ success_callback = partial(proc_callback, True)
+
results = [None] * len(data)
queue = asyncio.Queue()
@@ -161,7 +186,7 @@ async def embedd_batch(
async with AsyncClient(
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
) as client:
- tasks = self._init_workers(client, queue, results, proc_callback)
+ tasks = self._init_workers(client, queue, results, success_callback)
# wait for the queue to complete or one of workers to exit
queue_complete = asyncio.create_task(queue.join())
@@ -203,6 +228,7 @@ async def _cancel_workers(tasks):
await asyncio.gather(*tasks, return_exceptions=True)
log.debug("All workers canceled")
+ # remove in 3.33
def __check_cancelled(self):
if self._cancelled:
raise EmbeddingCancelledException()
@@ -230,11 +256,11 @@ async def _send_to_server(
client: AsyncClient,
queue: asyncio.Queue,
results: List,
- proc_callback: Callable[[bool], None] = None,
+ proc_callback: Callable,
):
"""
- Worker that embedds data. It is pulling items from the until the queue
- is empty. It is canceled by embedd_batch all tasks are finished
+ Worker that embedds data. It is pulling items from the queue until
+ it is empty. It is runs until anything is in the queue, or it is canceled
Parameters
----------
@@ -252,6 +278,7 @@ async def _send_to_server(
getting the result from cache or skipping the item.
"""
while not queue.empty():
+ # remove in 3.33
self.__check_cancelled()
# get item from the queue
@@ -283,8 +310,7 @@ async def _send_to_server(
# store result if embedding is successful
log.debug("Successfully embedded: %s", cache_key)
results[i] = emb
- if proc_callback:
- proc_callback(emb is not None)
+ proc_callback()
elif num_repeats < self.MAX_REPEATS:
log.debug("Embedding unsuccessful - reading to queue: %s", cache_key)
# if embedding not successful put the item to queue to be handled at
@@ -379,5 +405,11 @@ def _parse_response(response: Response) -> Optional[List[float]]:
def clear_cache(self):
self._cache.clear_cache()
+ # remove in 3.33
def set_cancelled(self):
+ warnings.warn(
+ "set_cancelled is deprecated and will be removed in version 3.33, "
+ "the process can be canceled by raising Error in callback",
+ FutureWarning,
+ )
self._cancelled = True
diff --git a/Orange/misc/tests/test_server_embedder.py b/Orange/misc/tests/test_server_embedder.py
index fc001a2cdd1..c9a7d9a59cf 100644
--- a/Orange/misc/tests/test_server_embedder.py
+++ b/Orange/misc/tests/test_server_embedder.py
@@ -5,6 +5,7 @@
import numpy as np
from httpx import ReadTimeout
+import Orange
from Orange.data import Domain, StringVariable, Table
from Orange.misc.tests.example_embedder import ExampleServerEmbedder
@@ -173,6 +174,30 @@ def test_retries(self, mock):
self.embedder.embedd_data(self.test_data)
self.assertEqual(len(self.test_data) * 3, mock.call_count)
+ @patch(_HTTPX_POST_METHOD, regular_dummy_sr)
+ def test_callback(self):
+ mock = MagicMock()
+ self.embedder.embedd_data(self.test_data, callback=mock)
+
+ process_items = [call(x) for x in np.linspace(0, 1, len(self.test_data))]
+ mock.assert_has_calls(process_items)
+
+ @patch(_HTTPX_POST_METHOD, regular_dummy_sr)
+ def test_deprecated(self):
+ """
+ When this start to fail:
+ - remove process_callback parameter and marked places connected to this param
+ - remove set_canceled and marked places connected to this method
+ - this test
+ """
+ self.assertGreaterEqual("3.33.0", Orange.__version__)
+
+ mock = MagicMock()
+ self.embedder.embedd_data(self.test_data, processed_callback=mock)
+
+ process_items = [call(True) for _ in range(len(self.test_data))]
+ mock.assert_has_calls(process_items)
+
if __name__ == "__main__":
unittest.main()