diff --git a/poetry.lock b/poetry.lock index 49bd79e..be5137b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,41 @@ # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +[[package]] +name = "annotated-types" +version = "0.6.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.6.0-py3-none-any.whl", hash = "sha256:0641064de18ba7a25dee8f96403ebc39113d0cb953a01429249d5c7564666a43"}, + {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} + +[[package]] +name = "anyio" +version = "4.3.0" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +optional = false +python-versions = ">=3.8" +files = [ + {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"}, + {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"}, +] + +[package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +sniffio = ">=1.1" +typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} + +[package.extras] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.23)"] + [[package]] name = "attrs" version = "23.2.0" @@ -129,6 +165,20 @@ files = [ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] +[[package]] +name = "click" +version = "8.1.7" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + [[package]] name = "colorama" version = "0.4.6" @@ -186,6 +236,39 @@ files = [ {file = "docstring_parser-0.15.tar.gz", hash = "sha256:48ddc093e8b1865899956fcc03b03e66bb7240c310fac5af81814580c55bf682"}, ] +[[package]] +name = "exceptiongroup" +version = "1.2.0" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"}, + {file = "exceptiongroup-1.2.0.tar.gz", hash = "sha256:91f5c769735f051a4290d52edd0858999b57e5876e9f85937691bd4c9fa3ed68"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "fastapi" +version = "0.110.1" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastapi-0.110.1-py3-none-any.whl", hash = "sha256:5df913203c482f820d31f48e635e022f8cbfe7350e4830ef05a3163925b1addc"}, + {file = "fastapi-0.110.1.tar.gz", hash = "sha256:6feac43ec359dfe4f45b2c18ec8c94edb8dc2dfc461d417d9e626590c071baad"}, +] + +[package.dependencies] +pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" +starlette = ">=0.37.2,<0.38.0" +typing-extensions = ">=4.8.0" + +[package.extras] +all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] + [[package]] name = "filelock" version = "3.13.1" @@ -248,6 +331,17 @@ smb = ["smbprotocol"] ssh = ["paramiko"] tqdm = ["tqdm"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + [[package]] name = "huggingface-hub" version = "0.21.4" @@ -325,6 +419,17 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link perf = ["ipython"] testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -525,6 +630,21 @@ files = [ {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, ] +[[package]] +name = "pluggy" +version = "1.4.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, + {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + [[package]] name = "protobuf" version = "5.26.0" @@ -545,6 +665,116 @@ files = [ {file = "protobuf-5.26.0.tar.gz", hash = "sha256:82f5870d74c99addfe4152777bdf8168244b9cf0ac65f8eccf045ddfa9d80d9b"}, ] +[[package]] +name = "pydantic" +version = "2.6.4" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"}, + {file = "pydantic-2.6.4.tar.gz", hash = "sha256:b1704e0847db01817624a6b86766967f552dd9dbf3afba4004409f908dcc84e6"}, +] + +[package.dependencies] +annotated-types = ">=0.4.0" +pydantic-core = "2.16.3" +typing-extensions = ">=4.6.1" + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.16.3" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic_core-2.16.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:75b81e678d1c1ede0785c7f46690621e4c6e63ccd9192af1f0bd9d504bbb6bf4"}, + {file = "pydantic_core-2.16.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9c865a7ee6f93783bd5d781af5a4c43dadc37053a5b42f7d18dc019f8c9d2bd1"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:162e498303d2b1c036b957a1278fa0899d02b2842f1ff901b6395104c5554a45"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2f583bd01bbfbff4eaee0868e6fc607efdfcc2b03c1c766b06a707abbc856187"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b926dd38db1519ed3043a4de50214e0d600d404099c3392f098a7f9d75029ff8"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:716b542728d4c742353448765aa7cdaa519a7b82f9564130e2b3f6766018c9ec"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4ad7f7ee1a13d9cb49d8198cd7d7e3aa93e425f371a68235f784e99741561f"}, + {file = "pydantic_core-2.16.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bd87f48924f360e5d1c5f770d6155ce0e7d83f7b4e10c2f9ec001c73cf475c99"}, + {file = "pydantic_core-2.16.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0df446663464884297c793874573549229f9eca73b59360878f382a0fc085979"}, + {file = "pydantic_core-2.16.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4df8a199d9f6afc5ae9a65f8f95ee52cae389a8c6b20163762bde0426275b7db"}, + {file = "pydantic_core-2.16.3-cp310-none-win32.whl", hash = "sha256:456855f57b413f077dff513a5a28ed838dbbb15082ba00f80750377eed23d132"}, + {file = "pydantic_core-2.16.3-cp310-none-win_amd64.whl", hash = "sha256:732da3243e1b8d3eab8c6ae23ae6a58548849d2e4a4e03a1924c8ddf71a387cb"}, + {file = "pydantic_core-2.16.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:519ae0312616026bf4cedc0fe459e982734f3ca82ee8c7246c19b650b60a5ee4"}, + {file = "pydantic_core-2.16.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b3992a322a5617ded0a9f23fd06dbc1e4bd7cf39bc4ccf344b10f80af58beacd"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d62da299c6ecb04df729e4b5c52dc0d53f4f8430b4492b93aa8de1f541c4aac"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2acca2be4bb2f2147ada8cac612f8a98fc09f41c89f87add7256ad27332c2fda"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1b662180108c55dfbf1280d865b2d116633d436cfc0bba82323554873967b340"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e7c6ed0dc9d8e65f24f5824291550139fe6f37fac03788d4580da0d33bc00c97"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6b1bb0827f56654b4437955555dc3aeeebeddc47c2d7ed575477f082622c49e"}, + {file = "pydantic_core-2.16.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e56f8186d6210ac7ece503193ec84104da7ceb98f68ce18c07282fcc2452e76f"}, + {file = "pydantic_core-2.16.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:936e5db01dd49476fa8f4383c259b8b1303d5dd5fb34c97de194560698cc2c5e"}, + {file = "pydantic_core-2.16.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:33809aebac276089b78db106ee692bdc9044710e26f24a9a2eaa35a0f9fa70ba"}, + {file = "pydantic_core-2.16.3-cp311-none-win32.whl", hash = "sha256:ded1c35f15c9dea16ead9bffcde9bb5c7c031bff076355dc58dcb1cb436c4721"}, + {file = "pydantic_core-2.16.3-cp311-none-win_amd64.whl", hash = "sha256:d89ca19cdd0dd5f31606a9329e309d4fcbb3df860960acec32630297d61820df"}, + {file = "pydantic_core-2.16.3-cp311-none-win_arm64.whl", hash = "sha256:6162f8d2dc27ba21027f261e4fa26f8bcb3cf9784b7f9499466a311ac284b5b9"}, + {file = "pydantic_core-2.16.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:0f56ae86b60ea987ae8bcd6654a887238fd53d1384f9b222ac457070b7ac4cff"}, + {file = "pydantic_core-2.16.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9bd22a2a639e26171068f8ebb5400ce2c1bc7d17959f60a3b753ae13c632975"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4204e773b4b408062960e65468d5346bdfe139247ee5f1ca2a378983e11388a2"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f651dd19363c632f4abe3480a7c87a9773be27cfe1341aef06e8759599454120"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aaf09e615a0bf98d406657e0008e4a8701b11481840be7d31755dc9f97c44053"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8e47755d8152c1ab5b55928ab422a76e2e7b22b5ed8e90a7d584268dd49e9c6b"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:500960cb3a0543a724a81ba859da816e8cf01b0e6aaeedf2c3775d12ee49cade"}, + {file = "pydantic_core-2.16.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cf6204fe865da605285c34cf1172879d0314ff267b1c35ff59de7154f35fdc2e"}, + {file = "pydantic_core-2.16.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d33dd21f572545649f90c38c227cc8631268ba25c460b5569abebdd0ec5974ca"}, + {file = "pydantic_core-2.16.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:49d5d58abd4b83fb8ce763be7794d09b2f50f10aa65c0f0c1696c677edeb7cbf"}, + {file = "pydantic_core-2.16.3-cp312-none-win32.whl", hash = "sha256:f53aace168a2a10582e570b7736cc5bef12cae9cf21775e3eafac597e8551fbe"}, + {file = "pydantic_core-2.16.3-cp312-none-win_amd64.whl", hash = "sha256:0d32576b1de5a30d9a97f300cc6a3f4694c428d956adbc7e6e2f9cad279e45ed"}, + {file = "pydantic_core-2.16.3-cp312-none-win_arm64.whl", hash = "sha256:ec08be75bb268473677edb83ba71e7e74b43c008e4a7b1907c6d57e940bf34b6"}, + {file = "pydantic_core-2.16.3-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:b1f6f5938d63c6139860f044e2538baeee6f0b251a1816e7adb6cbce106a1f01"}, + {file = "pydantic_core-2.16.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2a1ef6a36fdbf71538142ed604ad19b82f67b05749512e47f247a6ddd06afdc7"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:704d35ecc7e9c31d48926150afada60401c55efa3b46cd1ded5a01bdffaf1d48"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d937653a696465677ed583124b94a4b2d79f5e30b2c46115a68e482c6a591c8a"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9803edf8e29bd825f43481f19c37f50d2b01899448273b3a7758441b512acf8"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:72282ad4892a9fb2da25defeac8c2e84352c108705c972db82ab121d15f14e6d"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f752826b5b8361193df55afcdf8ca6a57d0232653494ba473630a83ba50d8c9"}, + {file = "pydantic_core-2.16.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4384a8f68ddb31a0b0c3deae88765f5868a1b9148939c3f4121233314ad5532c"}, + {file = "pydantic_core-2.16.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a4b2bf78342c40b3dc830880106f54328928ff03e357935ad26c7128bbd66ce8"}, + {file = "pydantic_core-2.16.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:13dcc4802961b5f843a9385fc821a0b0135e8c07fc3d9949fd49627c1a5e6ae5"}, + {file = "pydantic_core-2.16.3-cp38-none-win32.whl", hash = "sha256:e3e70c94a0c3841e6aa831edab1619ad5c511199be94d0c11ba75fe06efe107a"}, + {file = "pydantic_core-2.16.3-cp38-none-win_amd64.whl", hash = "sha256:ecdf6bf5f578615f2e985a5e1f6572e23aa632c4bd1dc67f8f406d445ac115ed"}, + {file = "pydantic_core-2.16.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:bda1ee3e08252b8d41fa5537413ffdddd58fa73107171a126d3b9ff001b9b820"}, + {file = "pydantic_core-2.16.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:21b888c973e4f26b7a96491c0965a8a312e13be108022ee510248fe379a5fa23"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be0ec334369316fa73448cc8c982c01e5d2a81c95969d58b8f6e272884df0074"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b5b6079cc452a7c53dd378c6f881ac528246b3ac9aae0f8eef98498a75657805"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ee8d5f878dccb6d499ba4d30d757111847b6849ae07acdd1205fffa1fc1253c"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7233d65d9d651242a68801159763d09e9ec96e8a158dbf118dc090cd77a104c9"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6119dc90483a5cb50a1306adb8d52c66e447da88ea44f323e0ae1a5fcb14256"}, + {file = "pydantic_core-2.16.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:578114bc803a4c1ff9946d977c221e4376620a46cf78da267d946397dc9514a8"}, + {file = "pydantic_core-2.16.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d8f99b147ff3fcf6b3cc60cb0c39ea443884d5559a30b1481e92495f2310ff2b"}, + {file = "pydantic_core-2.16.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4ac6b4ce1e7283d715c4b729d8f9dab9627586dafce81d9eaa009dd7f25dd972"}, + {file = "pydantic_core-2.16.3-cp39-none-win32.whl", hash = "sha256:e7774b570e61cb998490c5235740d475413a1f6de823169b4cf94e2fe9e9f6b2"}, + {file = "pydantic_core-2.16.3-cp39-none-win_amd64.whl", hash = "sha256:9091632a25b8b87b9a605ec0e61f241c456e9248bfdcf7abdf344fdb169c81cf"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:36fa178aacbc277bc6b62a2c3da95226520da4f4e9e206fdf076484363895d2c"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:dcca5d2bf65c6fb591fff92da03f94cd4f315972f97c21975398bd4bd046854a"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a72fb9963cba4cd5793854fd12f4cfee731e86df140f59ff52a49b3552db241"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b60cc1a081f80a2105a59385b92d82278b15d80ebb3adb200542ae165cd7d183"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cbcc558401de90a746d02ef330c528f2e668c83350f045833543cd57ecead1ad"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:fee427241c2d9fb7192b658190f9f5fd6dfe41e02f3c1489d2ec1e6a5ab1e04a"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f4cb85f693044e0f71f394ff76c98ddc1bc0953e48c061725e540396d5c8a2e1"}, + {file = "pydantic_core-2.16.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:b29eeb887aa931c2fcef5aa515d9d176d25006794610c264ddc114c053bf96fe"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a425479ee40ff021f8216c9d07a6a3b54b31c8267c6e17aa88b70d7ebd0e5e5b"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:5c5cbc703168d1b7a838668998308018a2718c2130595e8e190220238addc96f"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99b6add4c0b39a513d323d3b93bc173dac663c27b99860dd5bf491b240d26137"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75f76ee558751746d6a38f89d60b6228fa174e5172d143886af0f85aa306fd89"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:00ee1c97b5364b84cb0bd82e9bbf645d5e2871fb8c58059d158412fee2d33d8a"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:287073c66748f624be4cef893ef9174e3eb88fe0b8a78dc22e88eca4bc357ca6"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ed25e1835c00a332cb10c683cd39da96a719ab1dfc08427d476bce41b92531fc"}, + {file = "pydantic_core-2.16.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:86b3d0033580bd6bbe07590152007275bd7af95f98eaa5bd36f3da219dcd93da"}, + {file = "pydantic_core-2.16.3.tar.gz", hash = "sha256:1cac689f80a3abab2d3c0048b29eea5751114054f032a941a32de4c852c59cad"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + [[package]] name = "pygments" version = "2.17.2" @@ -571,6 +801,28 @@ files = [ {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"}, ] +[[package]] +name = "pytest" +version = "8.1.1" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"}, + {file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=1.4,<2.0" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + [[package]] name = "pyyaml" version = "6.0.1" @@ -660,6 +912,35 @@ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9 [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "sniffio" +version = "1.3.1" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + +[[package]] +name = "starlette" +version = "0.37.2" +description = "The little ASGI library that shines." +optional = false +python-versions = ">=3.8" +files = [ + {file = "starlette-0.37.2-py3-none-any.whl", hash = "sha256:6fe59f29268538e5d0d182f2791a479a0c64638e6935d1c6989e63fb2699c6ee"}, + {file = "starlette-0.37.2.tar.gz", hash = "sha256:9af890290133b79fc3db55474ade20f6220a364a0402e0b556e7cd5e1e093823"}, +] + +[package.dependencies] +anyio = ">=3.4.0,<5" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] + [[package]] name = "sympy" version = "1.12" @@ -801,6 +1082,17 @@ dev = ["tokenizers[testing]"] docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"] testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + [[package]] name = "tqdm" version = "4.66.1" @@ -849,6 +1141,25 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "uvicorn" +version = "0.29.0" +description = "The lightning-fast ASGI server." +optional = false +python-versions = ">=3.8" +files = [ + {file = "uvicorn-0.29.0-py3-none-any.whl", hash = "sha256:2c2aac7ff4f4365c206fd773a39bf4ebd1047c238f8b8268ad996829323473de"}, + {file = "uvicorn-0.29.0.tar.gz", hash = "sha256:6a69214c0b6a087462412670b3ef21224fa48cae0e452b5883e8e8bdfdd11dd0"}, +] + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" +typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} + +[package.extras] +standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] + [[package]] name = "zipp" version = "3.18.1" @@ -867,4 +1178,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8, <4.0" -content-hash = "368fd40a50adf012f1e464302b2f2be230c8c321a3206c43e5c7d6190ed43bd2" +content-hash = "c24ac6e34ae58ff7106536bcc9be360c596f48401718e2b32ab1abbdb1f0c89c" diff --git a/pyproject.toml b/pyproject.toml index 7b81e8b..a8c5c81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "swiftrank" -version = "1.2.0" +version = "1.3.0" description = "Compact, ultra-fast SoTA reranker enhancing retrieval pipelines and terminal applications." authors = ["Harsh Verma "] license = "Apache Software License (Apache 2.0)" @@ -17,10 +17,17 @@ tqdm = "4.66.1" cyclopts = "2.1.2" pyyaml = "6.0.1" orjson = "3.9.10" +pydantic = "2.6.4" +fastapi = "0.110.1" +uvicorn = "0.29.0" [tool.poetry.scripts] -swiftrank = "swiftrank.cli:app.meta" -srank = "swiftrank.cli:app.meta" +swiftrank = "swiftrank.interface.cli:app.meta" +srank = "swiftrank.interface.cli:app.meta" + +[tool.poetry.group.dev.dependencies] +pytest = "8.1.1" +requests = "2.31.0" [build-system] requires = ["poetry-core"] diff --git a/readme.md b/readme.md index d71b116..67e59ee 100644 --- a/readme.md +++ b/readme.md @@ -40,6 +40,9 @@ ⌨️ **Terminal Integration**: - Pipe your output into `swiftrank` cli tool and get reranked output +🌐 **API Integration**: +- Deploy `swiftrank` as an API service for seamless integration into your workflow. + --- ### 🚀 Installation @@ -57,6 +60,7 @@ Rerank contexts provided on stdin. ╭─ Commands ─────────────────────────────────────────────────────╮ │ process STDIN processor. [ json | jsonl | yaml ] │ +│ serve Startup a swiftrank server │ │ --help,-h Display this message and exit. │ │ --version Display application version. │ ╰────────────────────────────────────────────────────────────────╯ @@ -180,6 +184,27 @@ STDIN processor. [ json | jsonl | yaml ] Monogatari Series: Second Season ``` +#### Startup a FastAPI server instance + +``` +Usage: swiftrank serve [OPTIONS] + +Startup a swiftrank server + +╭─ Parameters ──────────────────────────────╮ +│ --host Host name [default: 0.0.0.0] │ +│ --port Port number. [default: 12345] │ +╰───────────────────────────────────────────╯ +``` + +```sh +swiftrank serve +``` +``` +[GET] /models - List Models +[POST] /rerank - Rerank Endpoint +``` + ### Library Usage 🤗 - Build a `ReRankPipeline` instance @@ -311,4 +336,4 @@ url = {https://github.com/PrithivirajDamodaran/FlashRank}, version = {1.0.0}, year = {2023} } -``` +``` \ No newline at end of file diff --git a/swiftrank/interface/__init__.py b/swiftrank/interface/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/swiftrank/interface/api.py b/swiftrank/interface/api.py new file mode 100644 index 0000000..6f8bcd2 --- /dev/null +++ b/swiftrank/interface/api.py @@ -0,0 +1,112 @@ +from typing import Any, Optional + +from fastapi import FastAPI, status +from fastapi.responses import ORJSONResponse +from fastapi.exceptions import HTTPException +from pydantic import BaseModel, Field + +from .utils import ObjectCollection, api_object_parser +from ..settings import MODEL_MAP +from ..ranker import ReRankPipeline + +server = FastAPI() +pipeline_map: dict[str, ReRankPipeline] = {} + +def get_pipeline(__id: str): + if pipeline_map.get(__id) is None: + pipeline_map[__id] = ReRankPipeline.from_model_id(__id) + return pipeline_map[__id] + + +class SchemaContext(BaseModel): + pre: Optional[str] = Field(None, description="schema for pre-processing input.") + ctx: Optional[str] = Field(None, description="schema for extracting context.") + post: Optional[str] = Field(None, description="schema for extracting field after reranking.") + +class RerankContext(BaseModel): + model: str = Field("ms-marco-TinyBERT-L-2-v2", description="model to use for reranking.") + contexts: ObjectCollection = Field(..., description="contexts to rerank.") + query: str = Field(..., description="query for reranking evaluation.") + threshold: Optional[float] = Field(None, ge=0.0, le=1.0, description="filter contexts using threshold.") + map_score: bool = Field(False, description="map relevance score with context") + schema_: Optional[SchemaContext] = Field(default=None, alias='schema') + + +@server.get('/models', response_class=ORJSONResponse) +def list_models(): + return list(MODEL_MAP.keys()) + +@server.post('/rerank') +def rerank_endpoint(ctx: RerankContext): + if not ctx.contexts: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="contexts field cannot be an empty array or object" + ) + + if ctx.model not in MODEL_MAP: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"{ctx.model!r} model is not available" + ) + + schema = ctx.schema_ or SchemaContext() + if schema.pre is not None: + contexts = api_object_parser(ctx.contexts, schema=schema.pre) + if isinstance(contexts, list) and not contexts: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Empty array after pre-processing" + ) + no_list_err = "Pre-processing must result into an array of objects" + + else: + contexts = ctx.contexts + no_list_err = "Expected an array of string or object. 'pre' schema might help" + + if not isinstance(contexts, list): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=no_list_err + ) + + ctx_schema = schema.ctx or '.' + post_schema = schema.post or '.' + pipeline = get_pipeline(ctx.model) + try: + if ctx.map_score is False: + reranked = pipeline.invoke( + query=ctx.query, + contexts=contexts, + threshold=ctx.threshold, + key=lambda x: api_object_parser(x, ctx_schema) + ) + + return [api_object_parser(context, post_schema) for context in reranked] + else: + reranked_tup = pipeline.invoke_with_score( + query=ctx.query, + contexts=contexts, + threshold=ctx.threshold, + key=lambda x: api_object_parser(x, ctx_schema) + ) + + return [ + {'score': score, 'context': api_object_parser(context, post_schema)} + for (score, context) in reranked_tup + ] + except TypeError: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail='Context processing must result into string' + ) + +def _serve(host: str, port: int): + import uvicorn + try: + uvicorn.run(server, host=host, port=port) + except KeyboardInterrupt: + exit(0) + +if __name__ == "__main__": + _serve(host='0.0.0.0', port=12345) \ No newline at end of file diff --git a/swiftrank/cli/__init__.py b/swiftrank/interface/cli.py similarity index 61% rename from swiftrank/cli/__init__.py rename to swiftrank/interface/cli.py index 5ba915d..6c11eab 100644 --- a/swiftrank/cli/__init__.py +++ b/swiftrank/interface/cli.py @@ -1,6 +1,6 @@ from typing import Annotated -from cyclopts import App, Parameter +from cyclopts import App, Parameter, validators try: from signal import signal, SIGPIPE, SIG_DFL @@ -23,22 +23,25 @@ def build_processing_parameters( help="schema for extracting field after reranking.", show_default=False )] = None ): - from .utils import object_parser, print_and_exit + from .utils import cli_object_parser, print_and_exit def preprocessor(_input: str): if _input.startswith(('{', '[')): from orjson import loads, JSONDecodeError try: - return object_parser(loads(_input), pre) + return cli_object_parser(loads(_input), pre) except JSONDecodeError: - from io import StringIO - with StringIO(_input) as handler: - return list(map(loads, handler)) + try: + from io import StringIO + with StringIO(_input) as handler: + return list(map(loads, handler)) + except (JSONDecodeError, Exception): + print_and_exit("Input data format not valid.", code=1) except Exception: - print_and_exit("Malformed JSON object not parseable.", code=1) + print_and_exit("Input data format not valid.", code=1) import yaml try: - return object_parser(yaml.safe_load(_input), pre) + return cli_object_parser(yaml.safe_load(_input), pre) except yaml.MarkedYAMLError: return list(yaml.safe_load_all(_input)) except yaml.YAMLError: @@ -53,11 +56,11 @@ def __entry__( query: Annotated[str, Parameter( name=("-q", "--query"), help="query for reranking evaluation.")], threshold: Annotated[float, Parameter( - name=("-t", "--threshold"), help="filter contexts using threshold.")] = None, + name=("-t", "--threshold"), help="filter contexts using threshold.", validator=validators.Number(gte=0.0, lte=1.0))] = None, first: Annotated[bool, Parameter( name=("-f", "--first"), help="get most relevant context.", negative="", show_default=False)] = False, ): - from .utils import read_stdin, object_parser, print_and_exit + from .utils import read_stdin, cli_object_parser, print_and_exit processing_params: dict = {} if tokens: @@ -66,18 +69,17 @@ def __entry__( if not _input: return contexts = processing_params['preprocessor'](_input) - else: contexts = read_stdin(readlines=True) ctx_schema = processing_params.get('ctx_schema', '.') + post_schema = processing_params.get('post_schema') or ctx_schema + if not isinstance(contexts, list): - print_and_exit(object_parser(contexts, ctx_schema)) + print_and_exit(cli_object_parser(contexts, post_schema)) - if not all(contexts): - print_and_exit("No contexts found on stdin", code=1) - if len(contexts) == 1: - print_and_exit(contexts[0]) + if not contexts: + print_and_exit("Nothing to rerank!", code=1) from .. import settings from ..ranker import ReRankPipeline @@ -88,18 +90,32 @@ def __entry__( query=query, contexts=contexts, threshold=threshold, - key=lambda x: object_parser(x, ctx_schema) + key=lambda x: cli_object_parser(x, ctx_schema) ) + + if reranked and first: + print_and_exit( + cli_object_parser(reranked[0], post_schema) + ) + + for context in reranked: + print(cli_object_parser(context, post_schema)) + except TypeError: print_and_exit( - 'Context processing must result into string. Hint: `--ctx` flag might help.', code=1 + 'Context processing must result into string.', code=1 ) - post_schema = processing_params.get('post_schema') or ctx_schema - if reranked and first: - print_and_exit( - object_parser(reranked[0], post_schema) - ) - - for context in reranked: - print(object_parser(context, post_schema)) \ No newline at end of file +@app.meta.command(name="serve", help="Startup a swiftrank server") +def serve( + *, + host: Annotated[str, Parameter( + name=('--host'), help="Host name")] = '0.0.0.0', + port: Annotated[int, Parameter( + name=('--port',), help="Port number.")] = 12345 +): + from .api import _serve + _serve(host=host, port=port) + +if __name__ == "__main__": + app.meta() \ No newline at end of file diff --git a/swiftrank/cli/utils.py b/swiftrank/interface/utils.py similarity index 72% rename from swiftrank/cli/utils.py rename to swiftrank/interface/utils.py index bd5ffb4..179a6e9 100644 --- a/swiftrank/cli/utils.py +++ b/swiftrank/interface/utils.py @@ -1,22 +1,6 @@ import sys from typing import TypeAlias, Any -def print_and_exit(msg: str, code: int = 0): - stream = sys.stdout if not code else sys.stderr - print(msg, file=stream) - exit(code) - -def read_stdin(readlines: bool = False): - """Read values from standard input (stdin). """ - if sys.stdin.isatty(): - return - try: - if readlines is False: - return sys.stdin.read().rstrip('\n') - return [_.strip('\n') for _ in sys.stdin if _] - except KeyboardInterrupt: - return - ObjectCollection: TypeAlias = dict[str, Any] | list[Any] ObjectScalar: TypeAlias = bool | float | int | str @@ -33,7 +17,7 @@ def object_parser(obj: ObjectValue, schema: str) -> ObjectValue: if not re.match( pattern=r"^(?:(?:[.](?:[\w]+|\[\d?\]))+)$", string=usable_schema ): - print_and_exit(f'{schema!r} is not a valid schema.', code=1) + raise ValueError(f'{schema!r} is not a valid schema.') def __inner__(_in: ObjectValue, keys: list[str]): for idx, key in enumerate(keys): @@ -48,7 +32,7 @@ def __inner__(_in: ObjectValue, keys: list[str]): _in = _in[int(obj_idx)] continue except (KeyError, IndexError): - print_and_exit(f'{schema!r} schema not compatible with input data.', code=1) + raise ValueError(f'{schema!r} schema not compatible with input data.') _keys = keys[idx + 1:] if not _keys: @@ -58,4 +42,36 @@ def __inner__(_in: ObjectValue, keys: list[str]): return _in return __inner__( obj, [k for k, _ in groupby(usable_schema.lstrip('.').split('.'))] - ) \ No newline at end of file + ) + +def read_stdin(readlines: bool = False): + """Read values from standard input (stdin). """ + if sys.stdin.isatty(): + return + try: + if readlines is False: + return sys.stdin.read().rstrip('\n') + return [_.strip('\n') for _ in sys.stdin if _] + except KeyboardInterrupt: + return + +def print_and_exit(msg: str, code: int = 0): + stream = sys.stdout if not code else sys.stderr + print(msg, file=stream) + exit(code) + +def cli_object_parser(obj: ObjectValue, schema: str): + try: + return object_parser(obj=obj, schema=schema) + except ValueError as e: + print_and_exit(e.args[0], code=1) + +def api_object_parser(obj: ObjectValue, schema: str): + from fastapi import status, HTTPException + try: + return object_parser(obj=obj, schema=schema) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=e.args[0] + ) \ No newline at end of file diff --git a/swiftrank/ranker.py b/swiftrank/ranker.py index c152e3b..cd8ca9b 100644 --- a/swiftrank/ranker.py +++ b/swiftrank/ranker.py @@ -2,7 +2,7 @@ from pathlib import Path from collections import OrderedDict from typing import ( - overload, Any, Optional, Iterable, Callable, TypeVar + overload, cast, Any, Optional, Iterable, Callable, TypeVar ) import numpy as np @@ -66,9 +66,9 @@ def __load(self): tokenizer_config = self.__file_handler("tokenizer_config.json") tokens_map = self.__file_handler("special_tokens_map.json") - tokenizer: TokenizerLoader = TokenizerLoader.from_file(str( + tokenizer = cast(TokenizerLoader, TokenizerLoader.from_file(str( self.__file_handler("tokenizer.json", read_json=False) - )) + ))) tokenizer.enable_truncation(max_length=min(tokenizer_config["model_max_length"], self.max_length)) tokenizer.enable_padding(pad_id=config["pad_token_id"], pad_token=tokenizer_config["pad_token"]) @@ -88,18 +88,28 @@ def __load(self): class ReRankPipeline: """ Pipeline for reranking task. - :param ranker: `Ranker` class instance - :param tokenizer: `Tokenizer` class instance - - >>> from flashrank import ReRankPipeline - >>> pipeline = ReRankPipeline(ranker=ranker, tokenizer=tokenizer) - >>> pipeline.invoke( - ... query="", contexts=["", "", ...] - ... ) + + Example: + ```python + from swiftrank import ReRankPipeline + pipeline = ReRankPipeline(ranker=ranker, tokenizer=tokenizer) + pipeline.invoke( + query="", contexts=["", "", ...] + ) + ``` """ - def __init__(self, ranker: Ranker, tokenizer: Tokenizer) -> None: - self.ranker = ranker.instance - self.tokenizer = tokenizer.instance + def __init__( + self, + ranker: Optional[Ranker] = None, + tokenizer: Optional[Tokenizer] = None + ) -> None: + """ + Initialize a rerank pipeline + @param ranker: `Ranker` class instance + @param tokenizer: `Tokenizer` class instance + """ + self.ranker = (ranker or Ranker()).instance + self.tokenizer = (tokenizer or Tokenizer()).instance @classmethod def from_model_id(cls, __id: str, tk_max_length: int = 512): @@ -123,9 +133,9 @@ def invoke_with_score( ) -> list[tuple[float, str]]: """ Rerank contexts based on query. - :param query: The query to use for reranking evaluation. - :param contexts: The contexts to rerank. - :param threshold: Get contexts that are equal or higher than threshold value. + @param query: The query to use for reranking evaluation. + @param contexts: The contexts to rerank. + @param threshold: Get contexts that are equal or higher than threshold value. """ @overload @@ -134,10 +144,10 @@ def invoke_with_score( ) -> list[tuple[float, _T]]: """ Rerank contexts based on query. - :param query: The query to use for reranking evaluation. - :param contexts: The contexts object. - :param threshold: Get contexts that are equal or higher than threshold value. - :param key: callback to use for getting fields from contexts object. + @param query: The query to use for reranking evaluation. + @param contexts: The contexts object. + @param threshold: Get contexts that are equal or higher than threshold value. + @param key: callback to use for getting fields from contexts object. """ def invoke_with_score( @@ -167,8 +177,8 @@ def invoke_with_score( combined = sorted(zip(scores, contexts), key=lambda x: x[0], reverse=True) if threshold is None: - return [(sc, ctx) for sc, ctx in combined] - return [(sc, ctx) for sc, ctx in combined if sc >= threshold] + return [(float(sc), ctx) for sc, ctx in combined] + return [(float(sc), ctx) for sc, ctx in combined if sc >= threshold] @overload def invoke( @@ -176,9 +186,9 @@ def invoke( ) -> list[str]: """ Rerank contexts based on query. - :param query: The query to use for reranking evaluation. - :param contexts: The contexts to rerank. - :param threshold: Get contexts that are equal or higher than threshold value. + @param query: The query to use for reranking evaluation. + @param contexts: The contexts to rerank. + @param threshold: Get contexts that are equal or higher than threshold value. """ @overload @@ -187,10 +197,10 @@ def invoke( ) -> list[_T]: """ Rerank contexts based on query. - :param query: The query to use for reranking evaluation. - :param contexts: The contexts object. - :param threshold: Get contexts that are equal or higher than threshold value. - :param key: callback to use for getting fields from contexts object. + @param query: The query to use for reranking evaluation. + @param contexts: The contexts object. + @param threshold: Get contexts that are equal or higher than threshold value. + @param key: callback to use for getting fields from contexts object. """ def invoke( diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..c3059e5 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,117 @@ +import json +import requests +from pathlib import Path + +ENDPOINT = "http://127.0.0.1:12345/rerank" +BODY = { + "model": "ms-marco-TinyBERT-L-2-v2", + "contexts": {}, + "query": "string", + "threshold": None, + "schema": { + "pre": None, + "ctx": None, + "post": None + } +} +FINAL_OUTPUT = [ + "Jujutsu Kaisen 2nd Season", + "Jujutsu Kaisen 2nd Season Recaps", + "Jujutsu Kaisen", + "Jujutsu Kaisen Official PV", + "Jujutsu Kaisen 0 Movie", + "Shingeki no Kyojin Season 2", + "Shingeki no Kyojin Season 3 Part 2", + "Shingeki no Kyojin Season 3", + "Kimi ni Todoke 2nd Season", + "Shingeki no Kyojin: The Final Season" +] + +files_path = Path(__file__).parent.parent / 'files' + +def read_file_as_context_field(name: str, rl: bool = False): + content = (files_path / name).read_text() + if rl: + return content.split('\n') + return json.loads(content) + + +def test_http_exception_empty_array_or_object(): + response = requests.post(url=ENDPOINT, json=BODY) + assert response.status_code == 422 + assert response.json()['detail'] == "contexts field cannot be an empty array or object" + +def test_http_exception_model_not_available(): + response = requests.post( + url=ENDPOINT, json=BODY | {'model': 'no-model', 'contexts': ["non", "empty"]} + ) + assert response.status_code == 404 + assert response.json()['detail'] == "'no-model' model is not available" + +def test_http_exception_empty_array_after_pre_processing(): + response = requests.post( + url=ENDPOINT, json=BODY | {'contexts': { + "categories": [ + { + "type": "anime", + "items": [] + } + ] + }, 'schema': {'pre': '.categories[].items'}} + ) + assert response.status_code == 422 + assert response.json()['detail'] == "Empty array after pre-processing" + +def test_http_exception_pre_processing_must_result_into_array(): + response = requests.post( + url=ENDPOINT, json=BODY | {'contexts': { + "categories": [ + { + "type": "anime", + "items": ['non', 'empty'] + } + ] + }, 'schema': {'pre': '.categories[].type'}} + ) + assert response.status_code == 422 + assert response.json()['detail'] == "Pre-processing must result into an array of objects" + +def test_http_exception_expected_arrary_of_string_or_object(): + response = requests.post( + url=ENDPOINT, json=BODY | {'contexts': { + "categories": [ + { + "type": "anime", + "items": ['non', 'empty'] + } + ] + }} + ) + assert response.status_code == 422 + assert response.json()['detail'] == "Expected an array of string or object. 'pre' schema might help" + +def test_arrary_as_input(): + response = requests.post( + url=ENDPOINT, json=BODY | { + 'query': "Jujutsu Season 2", + 'contexts': read_file_as_context_field('contexts', rl=True)} + ) + + assert response.status_code == 200 + assert response.json() == FINAL_OUTPUT + +def test_object_as_input(): + response = requests.post( + url=ENDPOINT, json=BODY | { + 'query': "Jujutsu Season 2", + 'contexts': read_file_as_context_field('contexts.json'), + 'threshold': 0.9, + 'schema': { + 'pre': '.categories[].items', + 'ctx': '.name', + 'post': '.name' + } + } + ) + assert response.status_code == 200 + assert response.json() == FINAL_OUTPUT[0:3] diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..63210c8 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,101 @@ +import os +import sys +from pathlib import Path +from subprocess import Popen, PIPE + +os.environ["SWIFTRANK_MODEL"] = "ms-marco-TinyBERT-L-2-v2" +exec_args = [sys.executable, '-m', 'swiftrank.interface.cli'] +files_path = Path(__file__).parent.parent / 'files' + +def read_file_bytes(name: str): + return (files_path / name).read_bytes() + +def test_print_relevant_context(): + process = Popen( + [*exec_args, '-q', 'Jujutsu Kaisen: Season 2', '-f'], stdin=PIPE, stdout=PIPE, stderr=PIPE + ) + + process.stdin.write(read_file_bytes('contexts')) + process.stdin.close() + + stdout, stderr = process.communicate() + assert stdout.decode().strip() == "Jujutsu Kaisen 2nd Season" + assert stderr.decode().strip() == "" + +def test_filtering_using_threshold(): + process = Popen( + [*exec_args, '-q', 'Jujutsu Kaisen: Season 2', '-t', '0.98'], stdin=PIPE, stdout=PIPE, stderr=PIPE + ) + + process.stdin.write(read_file_bytes('contexts')) + process.stdin.close() + + stdout, stderr = process.communicate() + rlist = [i.strip() for i in stdout.decode().split('\n') if i] + assert rlist == ['Jujutsu Kaisen 2nd Season', 'Jujutsu Kaisen 2nd Season Recaps'] + assert stderr.decode().strip() == "" + +def test_handling_json(): + process = Popen( + [*exec_args, '-q', 'Jujutsu Kaisen: Season 2', 'process', '-r', '.categories[].items', '-c', '.name', '-t', '0.9'], stdin=PIPE, stdout=PIPE, stderr=PIPE + ) + + process.stdin.write(read_file_bytes('contexts.json')) + process.stdin.close() + + stdout, stderr = process.communicate() + rlist = [i.strip() for i in stdout.decode().split('\n') if i] + assert rlist == ['Jujutsu Kaisen 2nd Season', + 'Jujutsu Kaisen 2nd Season Recaps', + 'Jujutsu Kaisen', + 'Jujutsu Kaisen Official PV', + 'Jujutsu Kaisen 0 Movie'] + assert stderr.decode().strip() == "" + +def test_handling_json_with_post_processing(): + process = Popen( + [*exec_args, '-q', 'Jujutsu Kaisen: Season 2', 'process', '-r', '.categories[].items', '-c', '.name', '-p', '.url', '-f'], stdin=PIPE, stdout=PIPE, stderr=PIPE + ) + + process.stdin.write(read_file_bytes('contexts.json')) + process.stdin.close() + + stdout, stderr = process.communicate() + assert stdout.decode().strip() == "https://myanimelist.net/anime/51009/Jujutsu_Kaisen_2nd_Season" + assert stderr.decode().strip() == "" + +def test_handling_yaml_with_post_processing(): + process = Popen( + [*exec_args, '-q', 'Monogatari Series: Season 2', 'process', '-r', '.categories[].items', '-c', '.name', '-p', '.payload.status', '-f'], stdin=PIPE, stdout=PIPE, stderr=PIPE + ) + + process.stdin.write(read_file_bytes('contexts.yaml')) + process.stdin.close() + + stdout, stderr = process.communicate() + assert stdout.decode().strip() == "Finished Airing" + assert stderr.decode().strip() == "" + +def test_handling_jsonlines_with_post_processing(): + process = Popen( + [*exec_args, '-q', 'Monogatari Series: Season 2', 'process', '-c', '.name', '-p', '.payload.aired', '-f'], stdin=PIPE, stdout=PIPE, stderr=PIPE + ) + + process.stdin.write(read_file_bytes('contexts.jsonl')) + process.stdin.close() + + stdout, stderr = process.communicate() + assert stdout.decode().strip() == "Jul 7, 2013 to Dec 29, 2013" + assert stderr.decode().strip() == "" + +def test_handling_yamllines(): + process = Popen( + [*exec_args, '-q', 'Monogatari Series: Season 2', 'process', '-c', '.name', '-f'], stdin=PIPE, stdout=PIPE, stderr=PIPE + ) + + process.stdin.write(read_file_bytes('contextlines.yaml')) + process.stdin.close() + + stdout, stderr = process.communicate() + assert stdout.decode().strip() == "Monogatari Series: Second Season" + assert stderr.decode().strip() == "" \ No newline at end of file diff --git a/tests/test_methods.py b/tests/test_methods.py new file mode 100644 index 0000000..969eddb --- /dev/null +++ b/tests/test_methods.py @@ -0,0 +1,42 @@ +from swiftrank import ReRankPipeline + +PIPELINE = ReRankPipeline.from_model_id("ms-marco-TinyBERT-L-2-v2") + +QUERY = "Tricks to accelerate LLM inference" + +CONTEXTS = [ + "Introduce *lookahead decoding*: - a parallel decoding algo to accelerate LLM inference - w/o the need for a draft model or a data store - linearly decreases # decoding steps relative to log(FLOPs) used per decoding step.", + "LLM inference efficiency will be one of the most crucial topics for both industry and academia, simply because the more efficient you are, the more $$$ you will save. vllm project is a must-read for this direction, and now they have just released the paper", + "There are many ways to increase LLM inference throughput (tokens/second) and decrease memory footprint, sometimes at the same time. Here are a few methods I’ve found effective when working with Llama 2. These methods are all well-integrated with Hugging Face. This list is far from exhaustive; some of these techniques can be used in combination with each other and there are plenty of others to try. - Bettertransformer (Optimum Library): Simply call `model.to_bettertransformer()` on your Hugging Face model for a modest improvement in tokens per second. - Fp4 Mixed-Precision (Bitsandbytes): Requires minimal configuration and dramatically reduces the model's memory footprint. - AutoGPTQ: Time-consuming but leads to a much smaller model and faster inference. The quantization is a one-time cost that pays off in the long run.", + "Ever want to make your LLM inference go brrrrr but got stuck at implementing speculative decoding and finding the suitable draft model? No more pain! Thrilled to unveil Medusa, a simple framework that removes the annoying draft model while getting 2x speedup.", + "vLLM is a fast and easy-to-use library for LLM inference and serving. vLLM is fast with: State-of-the-art serving throughput Efficient management of attention key and value memory with PagedAttention Continuous batching of incoming requests Optimized CUDA kernels" +] + +RERANKED = [ + (0.9977508, 'Introduce *lookahead decoding*: - a parallel decoding algo to accelerate LLM inference - w/o the need for a draft model or a data store - linearly decreases # decoding steps relative to log(FLOPs) used per decoding step.',), + (0.9415497, "There are many ways to increase LLM inference throughput (tokens/second) and decrease memory footprint, sometimes at the same time. Here are a few methods I’ve found effective when working with Llama 2. These methods are all well-integrated with Hugging Face. This list is far from exhaustive; some of these techniques can be used in combination with each other and there are plenty of others to try. - Bettertransformer (Optimum Library): Simply call `model.to_bettertransformer()` on your Hugging Face model for a modest improvement in tokens per second. - Fp4 Mixed-Precision (Bitsandbytes): Requires minimal configuration and dramatically reduces the model's memory footprint. - AutoGPTQ: Time-consuming but leads to a much smaller model and faster inference. The quantization is a one-time cost that pays off in the long run.",), + (0.47455463, 'vLLM is a fast and easy-to-use library for LLM inference and serving. vLLM is fast with: State-of-the-art serving throughput Efficient management of attention key and value memory with PagedAttention Continuous batching of incoming requests Optimized CUDA kernels',), + (0.43783104, 'LLM inference efficiency will be one of the most crucial topics for both industry and academia, simply because the more efficient you are, the more $$$ you will save. vllm project is a must-read for this direction, and now they have just released the paper',), + (0.043041725, 'Ever want to make your LLM inference go brrrrr but got stuck at implementing speculative decoding and finding the suitable draft model? No more pain! Thrilled to unveil Medusa, a simple framework that removes the annoying draft model while getting 2x speedup.',) +] + +def test_invoke_with_score(): + output = PIPELINE.invoke_with_score(query=QUERY, contexts=CONTEXTS) + for idx in range(len(output)): + assert (f"{output[idx][0]:.5f}" == f"{RERANKED[idx][0]:.5f}") + +def test_invoke(): + output = PIPELINE.invoke(query=QUERY, contexts=CONTEXTS) + for idx in range(len(output)): + assert (output[idx] == RERANKED[idx][1]) + +def test_invoke_with_threshold_parameter(): + output = PIPELINE.invoke(query=QUERY, contexts=CONTEXTS, threshold=0.8) + for idx in range(len(output)): + assert (output[idx] == RERANKED[idx][1]) + +def test_invoke_with_key_parameter(): + context_map = [{'id': idx, 'content': content} for idx, content in enumerate(CONTEXTS)] + output = PIPELINE.invoke(query=QUERY, contexts=context_map, key=lambda x: x['content']) + for idx in range(len(output)): + assert (output[idx]['content'] == RERANKED[idx][1])