Skip to content

Commit

Permalink
Server embedder: use queue, handle unsuccessful requests at the end
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Feb 9, 2022
1 parent c38b96f commit f7cd683
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 80 deletions.
196 changes: 116 additions & 80 deletions Orange/misc/server_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,30 @@
import logging
import random
import uuid
from collections import namedtuple
from json import JSONDecodeError
from os import getenv
from typing import Any, Callable, List, Optional

from AnyQt.QtCore import QSettings
from httpx import AsyncClient, NetworkError, ReadTimeout, Response

from Orange.misc.utils.embedder_utils import (EmbedderCache,
EmbeddingCancelledException,
EmbeddingConnectionError,
get_proxies)
from Orange.misc.utils.embedder_utils import (
EmbedderCache,
EmbeddingCancelledException,
EmbeddingConnectionError,
get_proxies,
)

log = logging.getLogger(__name__)
TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))


class ServerEmbedderCommunicator:
"""
This class needs to be inherited by the class which re-implements
_encode_data_instance and defines self.content_type. For sending a table
with data items use embedd_table function. This one is called with the
complete Orange data Table. Then _encode_data_instance needs to extract
data to be embedded from the RowInstance. For images, it takes the image
path from the table, load image, and transform it into bytes.
_encode_data_instance and defines self.content_type. For sending a list
with data items use embedd_table function.
Attributes
----------
Expand Down Expand Up @@ -69,14 +70,14 @@ def __init__(
) or str(uuid.getnode())
except TypeError:
self.machine_id = str(uuid.getnode())
self.session_id = str(random.randint(1, 1e10))
self.session_id = str(random.randint(1, int(1e10)))

self._cache = EmbedderCache(model_name)

# default embedding timeouts are too small we need to increase them
self.timeout = 180
self.num_parallel_requests = 0
self.max_parallel = max_parallel_requests
self.max_parallel_requests = max_parallel_requests

self.content_type = None # need to be set in a class inheriting

def embedd_data(
Expand Down Expand Up @@ -111,8 +112,7 @@ def embedd_data(
EmbeddingCancelledException:
If cancelled attribute is set to True (default=False).
"""
# if there is less items than 10 connection error should be raised
# earlier
# if there are fewer than 10 items, a connection error should be raised earlier
self.max_errors = min(len(data) * self.MAX_REPEATS, 10)

loop = asyncio.new_event_loop()
Expand All @@ -121,11 +121,9 @@ def embedd_data(
embeddings = asyncio.get_event_loop().run_until_complete(
self.embedd_batch(data, processed_callback)
)
except Exception:
finally:
loop.close()
raise

loop.close()
return embeddings

async def embedd_batch(
Expand Down Expand Up @@ -153,32 +151,63 @@ async def embedd_batch(
EmbeddingCancelledException:
If cancelled attribute is set to True (default=False).
"""
requests = []
results = [None] * len(data)
queue = asyncio.Queue()

# fill the queue with items to embedd
for i, item in enumerate(data):
queue.put_nowait(TaskItem(id=i, item=item, no_repeats=0))

async with AsyncClient(
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
) as client:
for p in data:
if self._cancelled:
raise EmbeddingCancelledException()
requests.append(self._send_to_server(p, client, proc_callback))
tasks = self._init_workers(client, queue, results, proc_callback)

# wait for the queue to complete or one of workers to exit
queue_complete = asyncio.create_task(queue.join())
await asyncio.wait(
[queue_complete, *tasks], return_when=asyncio.FIRST_COMPLETED
)

# Cancel worker tasks when done
queue_complete.cancel()
await self._cancel_workers(tasks)

embeddings = await asyncio.gather(*requests)
self._cache.persist_cache()
assert self.num_parallel_requests == 0
return results

return embeddings
def _init_workers(self, client, queue, results, callback):
    """
    Start the worker tasks that consume the queue.

    One task running ``_send_to_server`` is created for each of the
    ``self.max_parallel_requests`` allowed parallel requests; all of them
    share the same client, queue, results list and callback.

    Returns
    -------
    List of the created asyncio tasks.
    """
    workers = []
    for _ in range(self.max_parallel_requests):
        worker = self._send_to_server(client, queue, results, callback)
        workers.append(asyncio.create_task(worker))
    log.debug("Created %d workers", self.max_parallel_requests)
    return workers

async def __wait_until_released(self) -> None:
while self.num_parallel_requests >= self.max_parallel:
await asyncio.sleep(0.1)
@staticmethod
async def _cancel_workers(tasks):
    """
    Wait for the worker tasks and make sure all of them end up cancelled.

    Parameters
    ----------
    tasks
        Worker tasks to await and then cancel.

    Raises
    ------
    Exception
        An exception raised inside a worker is propagated to the caller,
        but only after every task has been cancelled and awaited.
    """
    log.debug("Canceling workers")
    try:
        # gather re-raises the first exception coming from a failed worker;
        # no except clause is needed, the exception propagates on its own
        await asyncio.gather(*tasks)
    finally:
        # cancel all tasks in both cases (success and failure)
        for task in tasks:
            task.cancel()
        # wait until all worker tasks are finished; return_exceptions=True
        # collects cancellations/errors instead of raising them here
        await asyncio.gather(*tasks, return_exceptions=True)
        log.debug("All workers canceled")

def __check_cancelled(self):
    """Raise EmbeddingCancelledException if ``self._cancelled`` is set."""
    if not self._cancelled:
        return
    raise EmbeddingCancelledException()

async def _encode_data_instance(
self, data_instance: Any
) -> Optional[bytes]:
async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
"""
The reimplementation of this function must implement the procedure
to encode the data item in a string format that will be sent to the
Expand All @@ -197,63 +226,74 @@ async def _encode_data_instance(
raise NotImplementedError

async def _send_to_server(
self,
data_instance: Any,
client: AsyncClient,
proc_callback: Callable[[bool], None] = None,
) -> Optional[List[float]]:
self,
client: AsyncClient,
queue: asyncio.Queue,
results: List,
proc_callback: Callable[[bool], None] = None,
):
"""
Function get an data instance. It extract data from it and send them to
server and retrieve responses.
Worker that embeds data. It pulls items from the queue until the queue
is empty. It is cancelled by embedd_batch when all tasks are finished.
Parameters
----------
data_instance
Single row of the input table.
client
HTTPX client that communicates with the server
queue
The queue with items of type TaskItem to be embedded
results
The list to store results in. The list has a length equal to the number
of all items to embed. Each result needs to be inserted at the index
defined in the queue item.
proc_callback
A function that is called after each item is fully processed
by either getting a successful response from the server,
getting the result from cache or skipping the item.
Returns
-------
Embedding. For items that are not successfully embedded returns None.
"""
await self.__wait_until_released()
self.__check_cancelled()

self.num_parallel_requests += 1
# load bytes
data_bytes = await self._encode_data_instance(data_instance)
if data_bytes is None:
self.num_parallel_requests -= 1
return None

# if data in cache return it
cache_key = self._cache.md5_hash(data_bytes)
emb = self._cache.get_cached_result_or_none(cache_key)

if emb is None:
# in case that embedding not sucessfull resend it to the server
# maximally for MAX_REPEATS time
for i in range(1, self.MAX_REPEATS + 1):
self.__check_cancelled()
while not queue.empty():
self.__check_cancelled()

# get item from the queue
i, data_instance, num_repeats = await queue.get()
num_repeats += 1

# load bytes
data_bytes = await self._encode_data_instance(data_instance)
if data_bytes is None:
continue

# retrieve embedded item from the local cache
cache_key = self._cache.md5_hash(data_bytes)
log.debug("Embedding %s", cache_key)
emb = self._cache.get_cached_result_or_none(cache_key)

if emb is None:
# send the item to the server for embedding if not in the local cache
log.debug("Sending to the server: %s", cache_key)
url = (
f"/{self.embedder_type}/{self._model}?"
f"machine={self.machine_id}"
f"&session={self.session_id}&retry={i}"
f"/{self.embedder_type}/{self._model}?machine={self.machine_id}"
f"&session={self.session_id}&retry={num_repeats}"
)
emb = await self._send_request(client, data_bytes, url)
if emb is not None:
self._cache.add(cache_key, emb)
break # repeat only when embedding None
if proc_callback:
proc_callback(emb is not None)

self.num_parallel_requests -= 1
return emb
if emb is not None:
# store result if embedding is successful
log.debug("Successfully embedded: %s", cache_key)
results[i] = emb
if proc_callback:
proc_callback(emb is not None)
elif num_repeats < self.MAX_REPEATS:
log.debug("Embedding unsuccessful - reading to queue: %s", cache_key)
# If embedding is not successful, put the item back to the queue to be
# handled at the end. The item is put at the end since it is possible that
# the server is still processing the request and the result will be in the
# cache later; repeating the request immediately may result in another
# failure when processing takes longer.
queue.put_nowait(TaskItem(i, data_instance, no_repeats=num_repeats))
queue.task_done()

async def _send_request(
self, client: AsyncClient, data: bytes, url: str
Expand Down Expand Up @@ -284,27 +324,23 @@ async def _send_request(
response = await client.post(url, headers=headers, data=data)
except ReadTimeout as ex:
log.debug("Read timeout", exc_info=True)
# it happens when server do not respond in 60 seconds, in
# this case we return None and items will be resend later
# it happens when the server does not respond within the time defined by
# the timeout; we return None and the item will be resent later

# if it happens more than in ten consecutive cases it means
# sth is wrong with embedder we stop embedding
self.count_read_errors += 1

if self.count_read_errors >= self.max_errors:
self.num_parallel_requests = 0 # for safety reasons
raise EmbeddingConnectionError from ex
return None
except (OSError, NetworkError) as ex:
log.debug("Network error", exc_info=True)
# it happens when no connection and items cannot be sent to the
# server
# we count number of consecutive errors
# it happens when no connection and items cannot be sent to server

# if more than 10 consecutive errors it means there is no
# connection so we stop embedding with EmbeddingConnectionError
self.count_connection_errors += 1
if self.count_connection_errors >= self.max_errors:
self.num_parallel_requests = 0 # for safety reasons
raise EmbeddingConnectionError from ex
return None
except Exception:
Expand Down
9 changes: 9 additions & 0 deletions Orange/misc/tests/test_server_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,12 @@ def test_encode_data_instance(self):
mocked_fun.assert_has_calls(
[call(item) for item in self.test_data], any_order=True
)

# Every POST is mocked to return a response with an empty body, so presumably
# no item is embedded successfully and each one is retried - TODO confirm
# against DummyResponse / response parsing, which are not visible here.
@patch(_HTTPX_POST_METHOD, return_value=DummyResponse(b''), new_callable=AsyncMock)
def test_retries(self, mock):
self.embedder.embedd_data(self.test_data)
# each item is expected to be posted three times in total
# (presumably the embedder's MAX_REPEATS is 3 - verify)
self.assertEqual(len(self.test_data) * 3, mock.call_count)


# Run the tests when this module is executed directly.
if __name__ == "__main__":
unittest.main()

0 comments on commit f7cd683

Please sign in to comment.