Skip to content

Commit

Permalink
server_embedder: modify callback to match others
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Feb 9, 2022
1 parent f7cd683 commit 77db719
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 13 deletions.
57 changes: 44 additions & 13 deletions Orange/misc/server_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,23 @@
import logging
import random
import uuid
import warnings
from collections import namedtuple
from json import JSONDecodeError
from os import getenv
from typing import Any, Callable, List, Optional

from AnyQt.QtCore import QSettings
from httpx import AsyncClient, NetworkError, ReadTimeout, Response
from numpy import linspace

from Orange.misc.utils.embedder_utils import (
EmbedderCache,
EmbeddingCancelledException,
EmbeddingConnectionError,
get_proxies,
)
from Orange.util import dummy_callback

log = logging.getLogger(__name__)
TaskItem = namedtuple("TaskItem", ("id", "item", "no_repeats"))
Expand Down Expand Up @@ -59,8 +62,7 @@ def __init__(
self._model = model_name
self.embedder_type = embedder_type

# attribute that offers support for cancelling the embedding
# if ran in another thread
# remove in 3.33
self._cancelled = False

self.machine_id = None
Expand All @@ -81,9 +83,10 @@ def __init__(
self.content_type = None # need to be set in a class inheriting

def embedd_data(
self,
data: List[Any],
processed_callback: Callable[[bool], None] = None,
self,
data: List[Any],
processed_callback: Optional[Callable] = None,
callback: Callable = dummy_callback,
) -> List[Optional[List[float]]]:
"""
This function repeats calling embedding function until all items
Expand All @@ -95,9 +98,12 @@ def embedd_data(
data
List with data that needs to be embedded.
processed_callback
Deprecated: remove in 3.33
A function that is called after each item is embedded
by either getting a successful response from the server,
getting the result from cache or skipping the item.
callback
Callback for reporting the progress in share of embedded items
Returns
-------
Expand All @@ -119,15 +125,18 @@ def embedd_data(
asyncio.set_event_loop(loop)
try:
embeddings = asyncio.get_event_loop().run_until_complete(
self.embedd_batch(data, processed_callback)
self.embedd_batch(data, processed_callback, callback)
)
finally:
loop.close()

return embeddings

async def embedd_batch(
self, data: List[Any], proc_callback: Callable[[bool], None] = None
self,
data: List[Any],
processed_calback: Optional[Callable] = None,
callback: Callable = dummy_callback,
) -> List[Optional[List[float]]]:
"""
Function perform embedding of a batch of data items.
Expand All @@ -136,10 +145,8 @@ async def embedd_batch(
----------
data
A list of data that must be embedded.
proc_callback
A function that is called after each item is fully processed
by either getting a successful response from the server,
getting the result from cache or skipping the item.
callback
Callback for reporting the progress in share of embedded items
Returns
-------
Expand All @@ -151,6 +158,22 @@ async def embedd_batch(
EmbeddingCancelledException:
If cancelled attribute is set to True (default=False).
"""
# in Orange 3.33 keep content of the if - remove if clause and complete else
if processed_calback is None:
progress_items = iter(linspace(0, 1, len(data)))

def success_callback():
"""Callback called on every successful embedding"""
callback(next(progress_items))

else:
warnings.warn(
"process_callback is deprecated and will be removed in version 3.33, "
"use callback instead",
FutureWarning,
)
success_callback = processed_calback

results = [None] * len(data)
queue = asyncio.Queue()

Expand All @@ -161,7 +184,7 @@ async def embedd_batch(
async with AsyncClient(
timeout=self.timeout, base_url=self.server_url, proxies=get_proxies()
) as client:
tasks = self._init_workers(client, queue, results, proc_callback)
tasks = self._init_workers(client, queue, results, success_callback)

# wait for the queue to complete or one of workers to exit
queue_complete = asyncio.create_task(queue.join())
Expand Down Expand Up @@ -203,6 +226,7 @@ async def _cancel_workers(tasks):
await asyncio.gather(*tasks, return_exceptions=True)
log.debug("All workers canceled")

# remove in 3.33
def __check_cancelled(self):
if self._cancelled:
raise EmbeddingCancelledException()
Expand Down Expand Up @@ -252,6 +276,7 @@ async def _send_to_server(
getting the result from cache or skipping the item.
"""
while not queue.empty():
# remove in 3.33
self.__check_cancelled()

# get item from the queue
Expand Down Expand Up @@ -284,7 +309,7 @@ async def _send_to_server(
log.debug("Successfully embedded: %s", cache_key)
results[i] = emb
if proc_callback:
proc_callback(emb is not None)
proc_callback()
elif num_repeats < self.MAX_REPEATS:
log.debug("Embedding unsuccessful - reading to queue: %s", cache_key)
# if embedding not successful put the item to queue to be handled at
Expand Down Expand Up @@ -379,5 +404,11 @@ def _parse_response(response: Response) -> Optional[List[float]]:
def clear_cache(self):
self._cache.clear_cache()

# remove in 3.33
def set_cancelled(self):
warnings.warn(
"set_cancelled is deprecated and will be removed in version 3.33, "
"the process can be canceled by raising Error in callback",
FutureWarning,
)
self._cancelled = True
18 changes: 18 additions & 0 deletions Orange/misc/tests/test_server_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
from httpx import ReadTimeout

import Orange
from Orange.data import Domain, StringVariable, Table
from Orange.misc.tests.example_embedder import ExampleServerEmbedder

Expand Down Expand Up @@ -173,6 +174,23 @@ def test_retries(self, mock):
self.embedder.embedd_data(self.test_data)
self.assertEqual(len(self.test_data) * 3, mock.call_count)

@patch(_HTTPX_POST_METHOD, regular_dummy_sr)
def test_callback(self):
mock = MagicMock()
self.embedder.embedd_data(self.test_data, callback=mock)

process_items = [call(x) for x in np.linspace(0, 1, len(self.test_data))]
mock.assert_has_calls(process_items)

def test_deprecated(self):
"""
When this start to fail:
- remove process_callback parameter and marked places connected to this param
- remove set_canceled and marked places connected to this method
- this test
"""
self.assertGreaterEqual("3.33.0", Orange.__version__)


if __name__ == "__main__":
unittest.main()

0 comments on commit 77db719

Please sign in to comment.