From 9f5322a6b67068a47a4185eb0f931f27e7c81b4b Mon Sep 17 00:00:00 2001 From: Primoz Godec Date: Wed, 9 Feb 2022 09:48:25 +0100 Subject: [PATCH] server_embedder: modify callback to match others --- Orange/misc/server_embedder.py | 66 +++++++++++++++++------ Orange/misc/tests/test_server_embedder.py | 25 +++++++++ 2 files changed, 74 insertions(+), 17 deletions(-) diff --git a/Orange/misc/server_embedder.py b/Orange/misc/server_embedder.py index 52ff16b0eaf..7636a719e1f 100644 --- a/Orange/misc/server_embedder.py +++ b/Orange/misc/server_embedder.py @@ -3,13 +3,16 @@ import logging import random import uuid +import warnings from collections import namedtuple +from functools import partial from json import JSONDecodeError from os import getenv from typing import Any, Callable, List, Optional from AnyQt.QtCore import QSettings from httpx import AsyncClient, NetworkError, ReadTimeout, Response +from numpy import linspace from Orange.misc.utils.embedder_utils import ( EmbedderCache, @@ -17,6 +20,7 @@ EmbeddingConnectionError, get_proxies, ) +from Orange.util import dummy_callback log = logging.getLogger(__name__) TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats")) @@ -59,8 +63,7 @@ def __init__( self._model = model_name self.embedder_type = embedder_type - # attribute that offers support for cancelling the embedding - # if ran in another thread + # remove in 3.33 self._cancelled = False self.machine_id = None @@ -81,9 +84,11 @@ def __init__( self.content_type = None # need to be set in a class inheriting def embedd_data( - self, - data: List[Any], - processed_callback: Callable[[bool], None] = None, + self, + data: List[Any], + processed_callback: Optional[Callable] = None, + *, + callback: Callable = dummy_callback, ) -> List[Optional[List[float]]]: """ This function repeats calling embedding function until all items @@ -95,9 +100,12 @@ def embedd_data( data List with data that needs to be embedded. processed_callback + Deprecated: remove in 3.33 A function that is called after each item is embedded by either getting a successful response from the server, getting the result from cache or skipping the item. + callback + Callback for reporting the progress in share of embedded items Returns ------- @@ -119,7 +127,7 @@ def embedd_data( asyncio.set_event_loop(loop) try: embeddings = asyncio.get_event_loop().run_until_complete( - self.embedd_batch(data, processed_callback) + self.embedd_batch(data, processed_callback, callback=callback) ) finally: loop.close() @@ -127,7 +135,11 @@ def embedd_data( return embeddings async def embedd_batch( - self, data: List[Any], proc_callback: Callable[[bool], None] = None + self, + data: List[Any], + proc_callback: Optional[Callable] = None, + *, + callback: Callable = dummy_callback, ) -> List[Optional[List[float]]]: """ Function perform embedding of a batch of data items. @@ -136,10 +148,8 @@ async def embedd_batch( ---------- data A list of data that must be embedded. - proc_callback - A function that is called after each item is fully processed - by either getting a successful response from the server, - getting the result from cache or skipping the item. + callback + Callback for reporting the progress in share of embedded items Returns ------- @@ -151,6 +161,21 @@ async def embedd_batch( EmbeddingCancelledException: If cancelled attribute is set to True (default=False). """ + # in Orange 3.33 keep content of the if - remove if clause and complete else + if proc_callback is None: + progress_items = iter(linspace(0, 1, len(data))) + + def success_callback(): + """Callback called on every successful embedding""" + callback(next(progress_items)) + else: + warnings.warn( + "proc_callback is deprecated and will be removed in version 3.33, " + "use callback instead", + FutureWarning, + ) + success_callback = partial(proc_callback, True) + results = [None] * len(data) queue = asyncio.Queue() @@ -161,7 +186,7 @@ async def embedd_batch( async with AsyncClient( timeout=self.timeout, base_url=self.server_url, proxies=get_proxies() ) as client: - tasks = self._init_workers(client, queue, results, proc_callback) + tasks = self._init_workers(client, queue, results, success_callback) # wait for the queue to complete or one of workers to exit queue_complete = asyncio.create_task(queue.join()) @@ -203,6 +228,7 @@ async def _cancel_workers(tasks): await asyncio.gather(*tasks, return_exceptions=True) log.debug("All workers canceled") + # remove in 3.33 def __check_cancelled(self): if self._cancelled: raise EmbeddingCancelledException() @@ -230,11 +256,11 @@ async def _send_to_server( client: AsyncClient, queue: asyncio.Queue, results: List, - proc_callback: Callable[[bool], None] = None, + proc_callback: Callable, ): """ - Worker that embedds data. It is pulling items from the until the queue - is empty. It is canceled by embedd_batch all tasks are finished + Worker that embedds data. It is pulling items from the queue until + it is empty. It is runs until anything is in the queue, or it is canceled Parameters ---------- @@ -252,6 +278,7 @@ async def _send_to_server( getting the result from cache or skipping the item. """ while not queue.empty(): + # remove in 3.33 self.__check_cancelled() # get item from the queue @@ -283,8 +310,7 @@ async def _send_to_server( # store result if embedding is successful log.debug("Successfully embedded: %s", cache_key) results[i] = emb - if proc_callback: - proc_callback(emb is not None) + proc_callback() elif num_repeats < self.MAX_REPEATS: log.debug("Embedding unsuccessful - reading to queue: %s", cache_key) # if embedding not successful put the item to queue to be handled at @@ -379,5 +405,11 @@ def _parse_response(response: Response) -> Optional[List[float]]: def clear_cache(self): self._cache.clear_cache() + # remove in 3.33 def set_cancelled(self): + warnings.warn( + "set_cancelled is deprecated and will be removed in version 3.33, " + "the process can be canceled by raising Error in callback", + FutureWarning, + ) self._cancelled = True diff --git a/Orange/misc/tests/test_server_embedder.py b/Orange/misc/tests/test_server_embedder.py index fc001a2cdd1..c9a7d9a59cf 100644 --- a/Orange/misc/tests/test_server_embedder.py +++ b/Orange/misc/tests/test_server_embedder.py @@ -5,6 +5,7 @@ import numpy as np from httpx import ReadTimeout +import Orange from Orange.data import Domain, StringVariable, Table from Orange.misc.tests.example_embedder import ExampleServerEmbedder @@ -173,6 +174,30 @@ def test_retries(self, mock): self.embedder.embedd_data(self.test_data) self.assertEqual(len(self.test_data) * 3, mock.call_count) + @patch(_HTTPX_POST_METHOD, regular_dummy_sr) + def test_callback(self): + mock = MagicMock() + self.embedder.embedd_data(self.test_data, callback=mock) + + process_items = [call(x) for x in np.linspace(0, 1, len(self.test_data))] + mock.assert_has_calls(process_items) + + @patch(_HTTPX_POST_METHOD, regular_dummy_sr) + def test_deprecated(self): + """ + When this start to fail: + - remove process_callback parameter and marked places connected to this param + - remove set_canceled and marked places connected to this method + - this test + """ + self.assertGreaterEqual("3.33.0", Orange.__version__) + + mock = MagicMock() + self.embedder.embedd_data(self.test_data, processed_callback=mock) + + process_items = [call(True) for _ in range(len(self.test_data))] + mock.assert_has_calls(process_items) + if __name__ == "__main__": unittest.main()