-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6b82f1f
commit 0eb8fd0
Showing
11 changed files
with
1,355 additions
and
0 deletions.
There are no files selected for viewing
898 changes: 898 additions & 0 deletions
898
docs/docs/examples/node_postprocessor/ColPaliRerank.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
153 changes: 153 additions & 0 deletions
153
llama-index-integrations/postprocessor/llama-index-postprocessor-colpali-rerank/.gitignore
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
llama_index/_static | ||
.DS_Store | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
bin/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
etc/ | ||
include/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
share/ | ||
var/ | ||
wheels/ | ||
pip-wheel-metadata/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
.ruff_cache | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
notebooks/ | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
pyvenv.cfg | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# Jetbrains | ||
.idea | ||
modules/ | ||
*.swp | ||
|
||
# VsCode | ||
.vscode | ||
|
||
# pipenv | ||
Pipfile | ||
Pipfile.lock | ||
|
||
# pyright | ||
pyrightconfig.json |
3 changes: 3 additions & 0 deletions
3
llama-index-integrations/postprocessor/llama-index-postprocessor-colpali-rerank/BUILD
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
poetry_requirements( | ||
name="poetry", | ||
) |
17 changes: 17 additions & 0 deletions
17
llama-index-integrations/postprocessor/llama-index-postprocessor-colpali-rerank/Makefile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
GIT_ROOT ?= $(shell git rev-parse --show-toplevel) | ||
|
||
help: ## Show all Makefile targets. | ||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' | ||
|
||
format: ## Run code autoformatters (black). | ||
pre-commit install | ||
git ls-files | xargs pre-commit run black --files | ||
|
||
lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy | ||
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files | ||
|
||
test: ## Run tests via pytest. | ||
pytest tests | ||
|
||
watch-docs: ## Build and watch documentation. | ||
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/ |
5 changes: 5 additions & 0 deletions
5
...x-integrations/postprocessor/llama-index-postprocessor-colpali-rerank/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# LlamaIndex Postprocessor Integration: ColPali Rerank | ||
|
||
[ColPali](https://huggingface.co/vidore/colpali-v1.2): ColPali it is a model based on a novel model architecture and training strategy based on Vision Language Models (VLMs), to efficiently index documents from their visual features. | ||
|
||
Please `pip install llama-index-postprocessor-colpali-rerank` to install ColPali Rerank package. |
1 change: 1 addition & 0 deletions
1
...r/llama-index-postprocessor-colpali-rerank/llama_index/postprocessor/colpali_rerank/BUILD
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
python_sources() |
3 changes: 3 additions & 0 deletions
3
...a-index-postprocessor-colpali-rerank/llama_index/postprocessor/colpali_rerank/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from llama_index.postprocessor.colpali_rerank.base import ColPaliRerank | ||
|
||
__all__ = ["ColPaliRerank"] |
135 changes: 135 additions & 0 deletions
135
...llama-index-postprocessor-colpali-rerank/llama_index/postprocessor/colpali_rerank/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import torch | ||
from colpali_engine.models import ColPali, ColPaliProcessor | ||
from PIL import Image | ||
from typing import Any, List, Optional | ||
|
||
from llama_index.core.bridge.pydantic import Field, PrivateAttr | ||
from llama_index.core.callbacks import CBEventType, EventPayload | ||
from llama_index.core.instrumentation import get_dispatcher | ||
from llama_index.core.instrumentation.events.rerank import ( | ||
ReRankEndEvent, | ||
ReRankStartEvent, | ||
) | ||
from llama_index.core.postprocessor.types import BaseNodePostprocessor | ||
from llama_index.core.schema import NodeWithScore, QueryBundle | ||
from llama_index.core.utils import infer_torch_device | ||
|
||
dispatcher = get_dispatcher(__name__) | ||
|
||
|
||
class ColPaliRerank(BaseNodePostprocessor): | ||
model: str = Field(description="Colpali model name.") | ||
top_n: int = Field(description="Number of nodes to return sorted by score.") | ||
device: str = Field( | ||
default="cuda", | ||
description="Device to use for the model.", | ||
) | ||
keep_retrieval_score: bool = Field( | ||
default=False, | ||
description="Whether to keep the retrieval score in metadata.", | ||
) | ||
_model: Any = PrivateAttr() | ||
_processor: Any = PrivateAttr() | ||
|
||
def __init__( | ||
self, | ||
top_n: int = 5, | ||
model: str = "vidore/colpali-v1.2", | ||
device: Optional[str] = None, | ||
keep_retrieval_score: Optional[bool] = False, | ||
): | ||
device = infer_torch_device() if device is None else device | ||
super().__init__( | ||
top_n=top_n, | ||
device=device, | ||
keep_retrieval_score=keep_retrieval_score, | ||
model=model, | ||
) | ||
|
||
self._model = ColPali.from_pretrained( | ||
model, torch_dtype=torch.bfloat16, device_map=device | ||
).eval() | ||
self._processor = ColPaliProcessor.from_pretrained(model) | ||
|
||
@classmethod | ||
def class_name(cls) -> str: | ||
return "ColPaliRerank" | ||
|
||
def get_image_paths(self, nodes: List[NodeWithScore]): | ||
image_paths = [] | ||
for node_ in nodes: | ||
image_paths.append(node_.node.metadata["file_path"]) | ||
|
||
return image_paths | ||
|
||
def load_image(self, image_path: str) -> Image.Image: | ||
return Image.open(image_path) | ||
|
||
def load_images(self, image_paths: List[str]) -> List[Image.Image]: | ||
images = [] | ||
for image_path in image_paths: | ||
images.append(self.load_image(image_path)) | ||
|
||
return images | ||
|
||
def _calculate_sim(self, query: str, images_paths: List[str]) -> List[float]: | ||
# Load the images | ||
images = self.load_images(images_paths) | ||
|
||
# Process the inputs | ||
batch_images = self._processor.process_images(images).to(self._model.device) | ||
batch_queries = self._processor.process_queries([query]).to(self._model.device) | ||
|
||
# Forward pass | ||
with torch.no_grad(): | ||
image_embeddings = self._model(**batch_images) | ||
querry_embeddings = self._model(**batch_queries) | ||
|
||
return self._processor.score_multi_vector(querry_embeddings, image_embeddings) | ||
|
||
def _postprocess_nodes( | ||
self, | ||
nodes: List[NodeWithScore], | ||
query_bundle: Optional[QueryBundle] = None, | ||
) -> List[NodeWithScore]: | ||
dispatcher.event( | ||
ReRankStartEvent( | ||
query=query_bundle, nodes=nodes, top_n=self.top_n, model_name=self.model | ||
) | ||
) | ||
|
||
if query_bundle is None: | ||
raise ValueError("Missing query bundle in extra info.") | ||
if len(nodes) == 0: | ||
return [] | ||
|
||
image_paths = self.get_image_paths(nodes) | ||
|
||
with self.callback_manager.event( | ||
CBEventType.RERANKING, | ||
payload={ | ||
EventPayload.NODES: nodes, | ||
EventPayload.MODEL_NAME: self.model, | ||
EventPayload.QUERY_STR: query_bundle.query_str, | ||
EventPayload.TOP_K: self.top_n, | ||
}, | ||
) as event: | ||
scores = self._calculate_sim(query_bundle.query_str, image_paths)[ | ||
0 | ||
].tolist() | ||
|
||
assert len(scores) == len(nodes) | ||
|
||
for node, score in zip(nodes, scores): | ||
if self.keep_retrieval_score: | ||
# keep the retrieval score in metadata | ||
node.node.metadata["retrieval_score"] = node.score | ||
node.score = float(score) | ||
|
||
reranked_nodes = sorted(nodes, key=lambda x: -x.score if x.score else 0)[ | ||
: self.top_n | ||
] | ||
event.on_end(payload={EventPayload.NODES: reranked_nodes}) | ||
|
||
dispatcher.event(ReRankEndEvent(nodes=reranked_nodes)) | ||
return reranked_nodes |
61 changes: 61 additions & 0 deletions
61
...-index-integrations/postprocessor/llama-index-postprocessor-colpali-rerank/pyproject.toml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
[build-system] | ||
build-backend = "poetry.core.masonry.api" | ||
requires = ["poetry-core"] | ||
|
||
[tool.codespell] | ||
check-filenames = true | ||
check-hidden = true | ||
# Feel free to un-skip examples, and experimental, you will just need to | ||
# work through many typos (--write-changes and --interactive will help) | ||
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" | ||
|
||
[tool.llamahub] | ||
contains_example = true | ||
import_path = "llama_index.postprocessor.colpali_rerank" | ||
|
||
[tool.llamahub.class_authors] | ||
ColPaliRerank = "ravitheja" | ||
|
||
[tool.mypy] | ||
disallow_untyped_defs = true | ||
# Remove venv skip when integrated with pre-commit | ||
exclude = ["_static", "build", "examples", "notebooks", "venv"] | ||
ignore_missing_imports = true | ||
python_version = "3.8" | ||
|
||
[tool.poetry] | ||
authors = ["Ravi Theja <[email protected]>"] | ||
description = "llama-index postprocessor colpali-rerank integration" | ||
exclude = ["**/BUILD"] | ||
license = "MIT" | ||
name = "llama-index-postprocessor-colpali-rerank" | ||
packages = [{include = "llama_index/"}] | ||
readme = "README.md" | ||
version = "0.1.0" | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.9,<4.0" | ||
torch = "*" | ||
transformers = ">=4.37.2" | ||
llama-index-core = "^0.11.0" | ||
colpali-engine = ">=0.3.0,<0.4.0" | ||
setuptools = "*" | ||
|
||
[tool.poetry.group.dev.dependencies] | ||
black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"} | ||
codespell = {extras = ["toml"], version = ">=v2.2.6"} | ||
ipython = "8.10.0" | ||
jupyter = "^1.0.0" | ||
mypy = "0.991" | ||
pre-commit = "3.2.0" | ||
pylint = "2.15.10" | ||
pytest = "7.2.1" | ||
pytest-mock = "3.11.1" | ||
ruff = "0.0.292" | ||
tree-sitter-languages = "^1.8.0" | ||
types-Deprecated = ">=0.1.0" | ||
types-PyYAML = "^6.0.12.12" | ||
types-protobuf = "^4.24.0.4" | ||
types-redis = "4.5.5.0" | ||
types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991 | ||
types-setuptools = "67.1.0.0" |
1 change: 1 addition & 0 deletions
1
llama-index-integrations/postprocessor/llama-index-postprocessor-colpali-rerank/tests/BUILD
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
python_tests() |
Oops, something went wrong.