diff --git a/Orange/misc/server_embedder.py b/Orange/misc/server_embedder.py index bdcf945952d..52ff16b0eaf 100644 --- a/Orange/misc/server_embedder.py +++ b/Orange/misc/server_embedder.py @@ -3,6 +3,7 @@ import logging import random import uuid +from collections import namedtuple from json import JSONDecodeError from os import getenv from typing import Any, Callable, List, Optional @@ -10,22 +11,22 @@ from AnyQt.QtCore import QSettings from httpx import AsyncClient, NetworkError, ReadTimeout, Response -from Orange.misc.utils.embedder_utils import (EmbedderCache, - EmbeddingCancelledException, - EmbeddingConnectionError, - get_proxies) +from Orange.misc.utils.embedder_utils import ( + EmbedderCache, + EmbeddingCancelledException, + EmbeddingConnectionError, + get_proxies, +) log = logging.getLogger(__name__) +TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats")) class ServerEmbedderCommunicator: """ This class needs to be inherited by the class which re-implements - _encode_data_instance and defines self.content_type. For sending a table - with data items use embedd_table function. This one is called with the - complete Orange data Table. Then _encode_data_instance needs to extract - data to be embedded from the RowInstance. For images, it takes the image - path from the table, load image, and transform it into bytes. + _encode_data_instance and defines self.content_type. For sending a list + with data items use embedd_table function. 
Attributes ---------- @@ -69,14 +70,14 @@ def __init__( ) or str(uuid.getnode()) except TypeError: self.machine_id = str(uuid.getnode()) - self.session_id = str(random.randint(1, 1e10)) + self.session_id = str(random.randint(1, int(1e10))) self._cache = EmbedderCache(model_name) # default embedding timeouts are too small we need to increase them self.timeout = 180 - self.num_parallel_requests = 0 - self.max_parallel = max_parallel_requests + self.max_parallel_requests = max_parallel_requests + self.content_type = None # needs to be set in the inheriting class def embedd_data( @@ -111,8 +112,7 @@ def embedd_data( EmbeddingCancelledException: If cancelled attribute is set to True (default=False). """ - # if there is less items than 10 connection error should be raised - # earlier + # if there are fewer than 10 items, connection error should be raised earlier self.max_errors = min(len(data) * self.MAX_REPEATS, 10) loop = asyncio.new_event_loop() @@ -121,11 +121,9 @@ def embedd_data( embeddings = asyncio.get_event_loop().run_until_complete( self.embedd_batch(data, processed_callback) ) - except Exception: + finally: loop.close() - raise - loop.close() return embeddings async def embedd_batch( @@ -153,32 +151,63 @@ async def embedd_batch( EmbeddingCancelledException: If cancelled attribute is set to True (default=False). 
""" - requests = [] + results = [None] * len(data) + queue = asyncio.Queue() + + # fill the queue with items to embedd + for i, item in enumerate(data): + queue.put_nowait(TaskItem(id=i, item=item, no_repeats=0)) + async with AsyncClient( - timeout=self.timeout, base_url=self.server_url, proxies=get_proxies() + timeout=self.timeout, base_url=self.server_url, proxies=get_proxies() ) as client: - for p in data: - if self._cancelled: - raise EmbeddingCancelledException() - requests.append(self._send_to_server(p, client, proc_callback)) + tasks = self._init_workers(client, queue, results, proc_callback) + + # wait for the queue to complete or one of workers to exit + queue_complete = asyncio.create_task(queue.join()) + await asyncio.wait( + [queue_complete, *tasks], return_when=asyncio.FIRST_COMPLETED + ) + + # Cancel worker tasks when done + queue_complete.cancel() + await self._cancel_workers(tasks) - embeddings = await asyncio.gather(*requests) self._cache.persist_cache() - assert self.num_parallel_requests == 0 + return results - return embeddings + def _init_workers(self, client, queue, results, callback): + """Init required number of workers""" + t = [ + asyncio.create_task(self._send_to_server(client, queue, results, callback)) + for _ in range(self.max_parallel_requests) + ] + log.debug("Created %d workers", self.max_parallel_requests) + return t - async def __wait_until_released(self) -> None: - while self.num_parallel_requests >= self.max_parallel: - await asyncio.sleep(0.1) + @staticmethod + async def _cancel_workers(tasks): + """Cancel worker at the end""" + log.debug("Canceling workers") + try: + # try to catch any potential exceptions + await asyncio.gather(*tasks) + except Exception as ex: + # raise exceptions gathered from an failed worker + raise ex + finally: + # cancel all tasks in both cases + for task in tasks: + task.cancel() + # Wait until all worker tasks are cancelled. 
+ await asyncio.gather(*tasks, return_exceptions=True) + log.debug("All workers canceled") def __check_cancelled(self): if self._cancelled: raise EmbeddingCancelledException() - async def _encode_data_instance( - self, data_instance: Any - ) -> Optional[bytes]: + async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]: """ The reimplementation of this function must implement the procedure to encode the data item in a string format that will be sent to the @@ -197,63 +226,74 @@ async def _encode_data_instance( raise NotImplementedError async def _send_to_server( - self, - data_instance: Any, - client: AsyncClient, - proc_callback: Callable[[bool], None] = None, - ) -> Optional[List[float]]: + self, + client: AsyncClient, + queue: asyncio.Queue, + results: List, + proc_callback: Callable[[bool], None] = None, + ): """ - Function get an data instance. It extract data from it and send them to - server and retrieve responses. + Worker that embeds data. It pulls items from the queue until the queue + is empty. It is canceled by embedd_batch when all tasks are finished. Parameters ---------- - data_instance - Single row of the input table. client HTTPX client that communicates with the server + queue + The queue with items of type TaskItem to be embedded + results + The list to store results in. Its length equals the number of all + items to embed. Each result needs to be inserted at the index + defined in its queue item. proc_callback A function that is called after each item is fully processed by either getting a successful response from the server, getting the result from cache or skipping the item. - - Returns - ------- - Embedding. For items that are not successfully embedded returns None. 
""" - await self.__wait_until_released() - self.__check_cancelled() - - self.num_parallel_requests += 1 - # load bytes - data_bytes = await self._encode_data_instance(data_instance) - if data_bytes is None: - self.num_parallel_requests -= 1 - return None - - # if data in cache return it - cache_key = self._cache.md5_hash(data_bytes) - emb = self._cache.get_cached_result_or_none(cache_key) - - if emb is None: - # in case that embedding not sucessfull resend it to the server - # maximally for MAX_REPEATS time - for i in range(1, self.MAX_REPEATS + 1): - self.__check_cancelled() + while not queue.empty(): + self.__check_cancelled() + + # get item from the queue + i, data_instance, num_repeats = await queue.get() + num_repeats += 1 + + # load bytes + data_bytes = await self._encode_data_instance(data_instance) + if data_bytes is None: + continue + + # retrieve embedded item from the local cache + cache_key = self._cache.md5_hash(data_bytes) + log.debug("Embedding %s", cache_key) + emb = self._cache.get_cached_result_or_none(cache_key) + + if emb is None: + # send the item to the server for embedding if not in the local cache + log.debug("Sending to the server: %s", cache_key) url = ( - f"/{self.embedder_type}/{self._model}?" 
- f"machine={self.machine_id}" - f"&session={self.session_id}&retry={i}" + f"/{self.embedder_type}/{self._model}?machine={self.machine_id}" + f"&session={self.session_id}&retry={num_repeats}" ) emb = await self._send_request(client, data_bytes, url) if emb is not None: self._cache.add(cache_key, emb) - break # repeat only when embedding None - if proc_callback: - proc_callback(emb is not None) - self.num_parallel_requests -= 1 - return emb + if emb is not None: + # store result if embedding is successful + log.debug("Successfully embedded: %s", cache_key) + results[i] = emb + if proc_callback: + proc_callback(emb is not None) + elif num_repeats < self.MAX_REPEATS: + log.debug("Embedding unsuccessful - reading to queue: %s", cache_key) + # if embedding is not successful, put the item back in the queue to be + # handled at the end - it is put at the end since the server may still + # be processing the request and the result will be in the cache later; + # repeating the request immediately may result in another failure when + # processing takes longer + queue.put_nowait(TaskItem(i, data_instance, no_repeats=num_repeats)) + queue.task_done() async def _send_request( self, client: AsyncClient, data: bytes, url: str @@ -284,27 +324,23 @@ async def _send_request( response = await client.post(url, headers=headers, data=data) except ReadTimeout as ex: log.debug("Read timeout", exc_info=True) - # it happens when server do not respond in 60 seconds, in - # this case we return None and items will be resend later + # it happens when the server does not respond within the configured timeout; + # return None and items will be resent later # if it happens more than in ten consecutive cases it means # sth is wrong with embedder we stop embedding self.count_read_errors += 1 - if self.count_read_errors >= self.max_errors: - self.num_parallel_requests = 0 # for safety reasons raise EmbeddingConnectionError from ex return None except (OSError, NetworkError) as ex: log.debug("Network error", 
exc_info=True) - # it happens when no connection and items cannot be sent to the - # server - # we count number of consecutive errors + # it happens when no connection and items cannot be sent to server + # if more than 10 consecutive errors it means there is no # connection so we stop embedding with EmbeddingConnectionError self.count_connection_errors += 1 if self.count_connection_errors >= self.max_errors: - self.num_parallel_requests = 0 # for safety reasons raise EmbeddingConnectionError from ex return None except Exception: diff --git a/Orange/misc/tests/test_server_embedder.py b/Orange/misc/tests/test_server_embedder.py index f5f9007c33e..fc001a2cdd1 100644 --- a/Orange/misc/tests/test_server_embedder.py +++ b/Orange/misc/tests/test_server_embedder.py @@ -167,3 +167,12 @@ def test_encode_data_instance(self): mocked_fun.assert_has_calls( [call(item) for item in self.test_data], any_order=True ) + + @patch(_HTTPX_POST_METHOD, return_value=DummyResponse(b''), new_callable=AsyncMock) + def test_retries(self, mock): + self.embedder.embedd_data(self.test_data) + self.assertEqual(len(self.test_data) * 3, mock.call_count) + + +if __name__ == "__main__": + unittest.main()