Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gautierdag committed Apr 27, 2023
1 parent 8269d27 commit 9c0d741
Show file tree
Hide file tree
Showing 9 changed files with 589 additions and 25 deletions.
59 changes: 59 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
name: Publish Python 🐍 Package 📦

on:
release:
types:
- published

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]

steps:
- name: Checkout repository
uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Run tests
run: |
pytest tests/test_batche.py
publish:
needs: test
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.x"

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Build package
run: python setup.py sdist bdist_wheel

- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
.ruff_cache/
.vscode/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,3 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: mixed-line-ending
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: "v0.0.263"
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
[![Python Package Test](https://github.com/gautierdag/batche/actions/workflows/publish.yml/badge.svg)](https://github.com/gautierdag/batche/actions/workflows/publish.yml)

# batche
Batch cache decorator
3 changes: 3 additions & 0 deletions batche/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .batche import cache_batch_variable, batche_cache

__all__ = ["cache_batch_variable", "batche_cache"]
106 changes: 106 additions & 0 deletions batche/batche.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import inspect
from functools import wraps
from typing import Callable, List, Optional, TypeVar, Any
from collections import OrderedDict

from pydantic import ValidationError, parse_obj_as

# Type variable standing for a single cached prediction.
R = TypeVar("R")

# Module-wide cache mapping a batch item to the prediction computed for it.
# The annotation is written as a string so it is never evaluated at import
# time: ``collections.OrderedDict`` only became subscriptable in Python 3.9,
# and evaluating ``OrderedDict[...]`` on 3.7/3.8 raises ``TypeError``.
batche_cache: "OrderedDict[Any, List[R]]" = OrderedDict()


def is_list_annotation(annotation: Any) -> bool:
    """Return whether *annotation* describes a type that a ``list`` satisfies.

    The bare ``list`` class is accepted, as is any generic alias whose
    ``__origin__`` is a class that ``list`` is a subclass of (for example
    ``List[int]`` or ``Sequence[str]``).

    Args:
        annotation: The annotation object to inspect.

    Returns:
        ``True`` if the annotation is list-compatible, ``False`` otherwise.
    """
    if annotation is list:
        return True

    origin = getattr(annotation, "__origin__", None)
    # Special forms such as ``Union``/``Optional``/``Literal`` expose an
    # origin that is not a class; ``issubclass`` would raise ``TypeError`` on
    # those, so they are reported as "not a list" instead.
    return isinstance(origin, type) and issubclass(list, origin)


def cache_batch_variable(
    batch_variable_name: Optional[str] = None, max_size: Optional[int] = None
):
    """Build a decorator that caches per-item results of a batch function.

    The decorated function must accept a list (the *batch*) under the
    parameter called ``batch_variable_name`` and return a list with one
    prediction per batch item. On each call, items already present in the
    module-level ``batche_cache`` are answered from the cache and only the
    remaining items are forwarded to the wrapped function.

    Note that ``batche_cache`` is keyed by the batch item alone, so every
    function decorated with this decorator shares the same entries.

    Args:
        batch_variable_name: Name of the parameter holding the batch. It must
            be one of the function's positional-or-keyword parameters.
        max_size: Maximum number of entries kept in ``batche_cache``. When
            set, the least recently used entries are evicted after each call
            until the cache fits. ``None`` leaves the cache unbounded.

    Returns:
        A decorator that wraps a batch function with the caching behaviour.

    Raises:
        AssertionError: At decoration time if ``batch_variable_name`` is not
            a parameter of the function, or if the batch/return annotations
            are present but not list-like. At call time if the batch argument
            is missing or fails validation when passed positionally, or if
            the wrapped function returns a result of the wrong length.
    """
    missing_argument_message = (
        f"{batch_variable_name} must be a valid argument of the batch function"
    )

    def internal_cache_batch_decorator(
        func: Callable[..., List[R]]
    ) -> Callable[..., List[R]]:
        # Validate the batch function's signature once, at decoration time.
        # Explicit raises are used instead of ``assert`` so the checks
        # survive ``python -O``; the exception type is unchanged.
        argspec = inspect.getfullargspec(func)
        batch_annotation = argspec.annotations.get(batch_variable_name)
        batch_position: Optional[int] = None

        if batch_variable_name is not None:
            if batch_variable_name not in argspec.args:
                raise AssertionError(missing_argument_message)

            # ``getfullargspec`` still lists the bound first parameter of a
            # bound method, but callers never pass it, so skip it when
            # computing the positional index of the batch.
            positional_names = (
                argspec.args[1:] if inspect.ismethod(func) else argspec.args
            )
            if batch_variable_name in positional_names:
                batch_position = positional_names.index(batch_variable_name)

            if batch_annotation and not is_list_annotation(batch_annotation):
                raise AssertionError(
                    f"{batch_variable_name} annotation must be a list of hashable objects"
                )

        return_annotation = argspec.annotations.get("return")
        if return_annotation and not is_list_annotation(return_annotation):
            raise AssertionError("return annotation must be a list")

        @wraps(func)
        def batch_function_wrapper(*args, **kwargs):
            args = list(args)

            # Locate the batch by the parameter's position instead of
            # trial-parsing every positional argument, which could select a
            # different list-like argument that happens to come first.
            in_args = batch_position is not None and batch_position < len(args)
            if in_args:
                raw_batch = args[batch_position]
            else:
                raw_batch = kwargs.get(batch_variable_name)
                if raw_batch is None:
                    raise AssertionError(missing_argument_message)

            if batch_annotation is None:
                # No annotation to validate against: take the batch as given.
                batch = list(raw_batch)
            elif in_args:
                try:
                    batch = parse_obj_as(batch_annotation, raw_batch)
                except ValidationError as err:
                    # Keep the exception type callers previously saw when a
                    # positional batch failed validation.
                    raise AssertionError(missing_argument_message) from err
            else:
                batch = parse_obj_as(batch_annotation, raw_batch)

            # new_batch will contain only the items that are not in cache
            new_batch, new_indices = [], []

            # initialize predictions with empty lists
            predictions = [[] for _ in range(len(batch))]

            for i, batch_item in enumerate(batch):
                if batch_item in batche_cache:
                    predictions[i] = batche_cache[batch_item]
                    # Mark the entry as recently used for LRU eviction.
                    batche_cache.move_to_end(batch_item)
                else:
                    new_batch.append(batch_item)
                    new_indices.append(i)

            # if all items are in cache, return
            if not new_batch:
                return predictions

            if in_args:
                args[batch_position] = new_batch
            else:
                kwargs[batch_variable_name] = new_batch

            # call the function with the new batch
            out = func(*args, **kwargs)
            if len(out) != len(new_batch):
                raise AssertionError(
                    "batch function must return a list of predictions of the same length as the batch"
                )

            for i, prediction in zip(new_indices, out):
                predictions[i] = prediction
                batche_cache[batch[i]] = prediction

            # Enforce the size bound by dropping the oldest entries first.
            if max_size is not None:
                while batche_cache and len(batche_cache) > max_size:
                    batche_cache.popitem(last=False)

            return predictions

        return batch_function_wrapper

    return internal_cache_batch_decorator
Loading

0 comments on commit 9c0d741

Please sign in to comment.