-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8269d27
commit 9c0d741
Showing
9 changed files
with
589 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
name: Publish Python 🐍 Package 📦 | ||
|
||
on: | ||
release: | ||
types: | ||
- published | ||
|
||
jobs: | ||
test: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: [3.7, 3.8, 3.9] | ||
|
||
steps: | ||
- name: Checkout repository | ||
uses: actions/checkout@v2 | ||
|
||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install pytest | ||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi | ||
- name: Run tests | ||
run: | | ||
pytest tests/test_batche.py | ||
publish: | ||
needs: test | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Checkout repository | ||
uses: actions/checkout@v2 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: "3.x" | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install setuptools wheel twine | ||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi | ||
- name: Build package | ||
run: python setup.py sdist bdist_wheel | ||
|
||
- name: Publish to PyPI | ||
uses: pypa/[email protected] | ||
with: | ||
user: __token__ | ||
password: ${{ secrets.PYPI_API_TOKEN }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
.ruff_cache/ | ||
.vscode/ | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
[![Python Package Test](https://github.com/gautierdag/batche/actions/workflows/publish.yml/badge.svg)](https://github.com/gautierdag/batche/actions/workflows/publish.yml) | ||
|
||
# batche | ||
Batch cache decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .batche import cache_batch_variable, batche_cache | ||
|
||
__all__ = ["cache_batch_variable", "batche_cache"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import inspect | ||
from functools import wraps | ||
from typing import Callable, List, Optional, TypeVar, Any | ||
from collections import OrderedDict | ||
|
||
from pydantic import ValidationError, parse_obj_as | ||
|
||
R = TypeVar("R") | ||
|
||
batche_cache: OrderedDict[Any, List[R]] = OrderedDict() | ||
|
||
|
||
def is_list_annotation(annotation: Any): | ||
if annotation is list or ( | ||
hasattr(annotation, "__origin__") and issubclass(list, annotation.__origin__) | ||
): | ||
return True | ||
return False | ||
|
||
|
||
def cache_batch_variable( | ||
batch_variable_name: Optional[str] = None, max_size: Optional[int] = None | ||
): | ||
def internal_cache_batch_decorator( | ||
func: Callable[..., List[R]] | ||
) -> Callable[..., List[R]]: | ||
# validate batch_func | ||
function_function_argspec = inspect.getfullargspec(func) | ||
if batch_variable_name is not None: | ||
assert ( | ||
batch_variable_name in function_function_argspec.args | ||
), f"{batch_variable_name} must be a valid argument of the batch function" | ||
|
||
batch_arg_annotations = function_function_argspec.annotations.get( | ||
batch_variable_name | ||
) | ||
if batch_arg_annotations: | ||
assert is_list_annotation( | ||
batch_arg_annotations | ||
), f"{batch_variable_name} annotation must be a list of hashable objects" | ||
|
||
return_annotation = function_function_argspec.annotations.get("return") | ||
if return_annotation: | ||
assert is_list_annotation( | ||
return_annotation | ||
), "return annotation must be a list" | ||
|
||
@wraps(func) | ||
def batch_function_wrapper(*args, **kwargs): | ||
args = list(args) | ||
in_args = False | ||
for arg_index, arg in enumerate(args): | ||
try: | ||
batch = parse_obj_as( | ||
func.__annotations__.get(batch_variable_name), arg | ||
) | ||
in_args = True | ||
break | ||
except ValidationError: | ||
continue | ||
else: | ||
batch = kwargs.get(batch_variable_name) | ||
assert ( | ||
batch is not None | ||
), f"{batch_variable_name} must be a valid argument of the batch function" | ||
batch = parse_obj_as( | ||
func.__annotations__.get(batch_variable_name), batch | ||
) | ||
|
||
# new_batch will contain only the items that are not in cache | ||
new_batch, new_indices = [], [] | ||
|
||
# initialize predictions with empty lists | ||
predictions = [[] for _ in range(len(batch))] | ||
|
||
for i, batch_item in enumerate(batch): | ||
if batch_item in batche_cache: | ||
predictions[i] = batche_cache[batch_item] | ||
else: | ||
new_batch.append(batch_item) | ||
new_indices.append(i) | ||
|
||
# if all items are in cache, return | ||
if len(new_batch) == 0: | ||
return predictions | ||
|
||
if in_args: | ||
args[arg_index] = new_batch | ||
else: | ||
kwargs[batch_variable_name] = new_batch | ||
|
||
# call the function with the new batch | ||
out = func(*args, **kwargs) | ||
assert len(out) == len( | ||
new_batch | ||
), "batch function must return a list of predictions of the same length as the batch" | ||
|
||
for i, prediction in zip(new_indices, out): | ||
predictions[i] = prediction | ||
batche_cache[batch[i]] = prediction | ||
|
||
return predictions | ||
|
||
return batch_function_wrapper | ||
|
||
return internal_cache_batch_decorator |
Oops, something went wrong.