diff --git a/orangecontrib/text/vectorization/sbert.py b/orangecontrib/text/vectorization/sbert.py
index 7019a2217..652c78a6f 100644
--- a/orangecontrib/text/vectorization/sbert.py
+++ b/orangecontrib/text/vectorization/sbert.py
@@ -1,8 +1,10 @@
+import asyncio
 import json
 import base64
 import warnings
 import zlib
 import sys
+from threading import Thread
 from typing import Any, List, Optional, Callable, Tuple
 
 import numpy as np
@@ -124,11 +126,81 @@ def clear_cache(self):
         self._server_communicator.clear_cache()
 
 
+class RunThread(Thread):
+    """Thread that runs a coroutine function in its own asyncio event loop.
+
+    ``run`` calls ``asyncio.run(func(*args, **kwargs))``. The value returned
+    by the coroutine is stored in ``result``; an exception raised by it is
+    stored in ``exception`` and re-raised by ``join`` in the joining thread,
+    so a failure is not mistaken for a ``None`` result.
+    """
+
+    def __init__(self, func, *args, **kwargs):
+        super().__init__()
+        self.func, self.args, self.kwargs = func, args, kwargs
+        self.result = None
+        self.exception = None
+
+    def run(self):
+        try:
+            self.result = asyncio.run(self.func(*self.args, **self.kwargs))
+        except BaseException as ex:
+            # Keep the error for the joining thread; threading would
+            # otherwise only report it and leave ``result`` as None.
+            self.exception = ex
+
+    def join(self, timeout=None):
+        """Wait for the thread; re-raise the coroutine's exception, if any."""
+        super().join(timeout)
+        if self.exception is not None:
+            raise self.exception
+
+
 class _ServerCommunicator(ServerEmbedderCommunicator):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.content_type = "application/json"
 
+    def embedd_data(
+        self,
+        data: List[Any],
+        callback: Callable = dummy_callback,
+    ) -> List[Optional[List[float]]]:
+        """Run ``embedd_batch`` on ``data`` to completion and return its result.
+
+        Parameters
+        ----------
+        data
+            Items passed on to ``embedd_batch``.
+        callback
+            Passed on to ``embedd_batch`` as the ``callback`` keyword.
+
+        Returns
+        -------
+        The value returned by ``embedd_batch``.
+        """
+        # if there are fewer than 10 items, a connection error should be
+        # raised earlier
+        self.max_errors = min(len(data) * self.MAX_REPEATS, 10)
+
+        # asyncio.run cannot be called while an event loop is running in the
+        # current thread (reported for the ontology widget, where a Qt
+        # QSelectorEventLoop already exists); in that case run the coroutine
+        # in a separate thread, which gets an event loop of its own.
+        try:
+            asyncio.get_running_loop()
+        except RuntimeError:
+            loop_running = False
+        else:
+            loop_running = True
+        if not loop_running:
+            return asyncio.run(self.embedd_batch(data, callback=callback))
+
+        thread = RunThread(self.embedd_batch, data, callback=callback)
+        thread.start()
+        thread.join()  # re-raises an exception raised by embedd_batch
+        return thread.result
+
     async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
         data = base64.b64encode(
             zlib.compress(data_instance.encode("utf-8", "replace"), level=-1)