diff --git a/Orange/misc/server_embedder.py b/Orange/misc/server_embedder.py index 52ff16b0eaf..cd05e41a10a 100644 --- a/Orange/misc/server_embedder.py +++ b/Orange/misc/server_embedder.py @@ -3,6 +3,7 @@ import logging import random import uuid +import warnings from collections import namedtuple from json import JSONDecodeError from os import getenv @@ -10,6 +11,7 @@ from AnyQt.QtCore import QSettings from httpx import AsyncClient, NetworkError, ReadTimeout, Response +from numpy import linspace from Orange.misc.utils.embedder_utils import ( EmbedderCache, @@ -17,6 +19,7 @@ EmbeddingConnectionError, get_proxies, ) +from Orange.util import dummy_callback log = logging.getLogger(__name__) TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats")) @@ -59,8 +62,7 @@ def __init__( self._model = model_name self.embedder_type = embedder_type - # attribute that offers support for cancelling the embedding - # if ran in another thread + # remove in 3.33 self._cancelled = False self.machine_id = None @@ -81,9 +83,10 @@ def __init__( self.content_type = None # need to be set in a class inheriting def embedd_data( - self, - data: List[Any], - processed_callback: Callable[[bool], None] = None, + self, + data: List[Any], + processed_callback: Optional[Callable] = None, + callback: Callable = dummy_callback, ) -> List[Optional[List[float]]]: """ This function repeats calling embedding function until all items @@ -95,9 +98,12 @@ def embedd_data( data List with data that needs to be embedded. processed_callback + Deprecated: remove in 3.33 A function that is called after each item is embedded by either getting a successful response from the server, getting the result from cache or skipping the item. 
+ callback Callback that receives the progress as the share (0 to 1) of embedded items Returns ------- @@ -119,7 +125,7 @@ def embedd_data( asyncio.set_event_loop(loop) try: embeddings = asyncio.get_event_loop().run_until_complete( - self.embedd_batch(data, processed_callback) + self.embedd_batch(data, processed_callback, callback) ) finally: loop.close() @@ -127,7 +133,10 @@ def embedd_data( return embeddings async def embedd_batch( - self, data: List[Any], proc_callback: Callable[[bool], None] = None + self, + data: List[Any], + processed_calback: Optional[Callable] = None, + callback: Callable = dummy_callback, ) -> List[Optional[List[float]]]: """ Function perform embedding of a batch of data items. @@ -136,10 +145,8 @@ async def embedd_batch( ---------- data A list of data that must be embedded. - proc_callback - A function that is called after each item is fully processed - by either getting a successful response from the server, - getting the result from cache or skipping the item. + callback + Callback that receives the progress as the share (0 to 1) of embedded items Returns ------- @@ -151,6 +158,22 @@ async def embedd_batch( EmbeddingCancelledException: If cancelled attribute is set to True (default=False). 
""" + # in Orange 3.33 keep only the body of the `if` branch: drop the `if` test and the entire `else` branch + if processed_calback is None: + progress_items = iter(linspace(0, 1, len(data))) + + def success_callback(): + """Callback called on every successful embedding""" + callback(next(progress_items)) + + else: + warnings.warn( + "process_callback is deprecated and will be removed in version 3.33, " + "use callback instead", + FutureWarning, + ) + success_callback = processed_calback + results = [None] * len(data) queue = asyncio.Queue() @@ -161,7 +184,7 @@ async def embedd_batch( async with AsyncClient( timeout=self.timeout, base_url=self.server_url, proxies=get_proxies() ) as client: - tasks = self._init_workers(client, queue, results, proc_callback) + tasks = self._init_workers(client, queue, results, success_callback) # wait for the queue to complete or one of workers to exit queue_complete = asyncio.create_task(queue.join()) @@ -203,6 +226,7 @@ async def _cancel_workers(tasks): await asyncio.gather(*tasks, return_exceptions=True) log.debug("All workers canceled") + # remove in 3.33 def __check_cancelled(self): if self._cancelled: raise EmbeddingCancelledException() @@ -252,6 +276,7 @@ async def _send_to_server( getting the result from cache or skipping the item. 
""" while not queue.empty(): + # remove in 3.33 self.__check_cancelled() # get item from the queue @@ -284,7 +309,7 @@ async def _send_to_server( log.debug("Successfully embedded: %s", cache_key) results[i] = emb if proc_callback: - proc_callback(emb is not None) + proc_callback() elif num_repeats < self.MAX_REPEATS: log.debug("Embedding unsuccessful - reading to queue: %s", cache_key) # if embedding not successful put the item to queue to be handled at @@ -379,5 +404,11 @@ def _parse_response(response: Response) -> Optional[List[float]]: def clear_cache(self): self._cache.clear_cache() + # remove in 3.33 def set_cancelled(self): + warnings.warn( + "set_cancelled is deprecated and will be removed in version 3.33, " + "the process can be canceled by raising Error in callback", + FutureWarning, + ) self._cancelled = True diff --git a/Orange/misc/tests/test_server_embedder.py b/Orange/misc/tests/test_server_embedder.py index fc001a2cdd1..40442aa3d13 100644 --- a/Orange/misc/tests/test_server_embedder.py +++ b/Orange/misc/tests/test_server_embedder.py @@ -5,6 +5,7 @@ import numpy as np from httpx import ReadTimeout +import Orange from Orange.data import Domain, StringVariable, Table from Orange.misc.tests.example_embedder import ExampleServerEmbedder @@ -173,6 +174,23 @@ def test_retries(self, mock): self.embedder.embedd_data(self.test_data) self.assertEqual(len(self.test_data) * 3, mock.call_count) + @patch(_HTTPX_POST_METHOD, regular_dummy_sr) + def test_callback(self): + mock = MagicMock() + self.embedder.embedd_data(self.test_data, callback=mock) + + process_items = [call(x) for x in np.linspace(0, 1, len(self.test_data))] + mock.assert_has_calls(process_items) + + def test_deprecated(self): + """ + When this starts to fail: + - remove the processed_callback parameter and the marked places connected to it + - remove set_cancelled and the marked places connected to it + - remove this test + """ + self.assertGreaterEqual("3.33.0", Orange.__version__) + if __name__ == 
"__main__": unittest.main()