Skip to content

Commit

Permalink
Add ColPali as reranker (#16829)
Browse files Browse the repository at this point in the history
  • Loading branch information
ravi03071991 authored Nov 5, 2024
1 parent 6b82f1f commit 0eb8fd0
Show file tree
Hide file tree
Showing 11 changed files with 1,355 additions and 0 deletions.
898 changes: 898 additions & 0 deletions docs/docs/examples/node_postprocessor/ColPaliRerank.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
llama_index/_static
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
bin/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
etc/
include/
lib/
lib64/
parts/
sdist/
share/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints
notebooks/

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pyvenv.cfg

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Jetbrains
.idea
modules/
*.swp

# VsCode
.vscode

# pipenv
Pipfile
Pipfile.lock

# pyright
pyrightconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Pants `poetry_requirements` macro: registers the dependencies declared in
# this package's Poetry pyproject.toml under the target name "poetry".
poetry_requirements(name="poetry")
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Root of the enclosing git checkout; can be overridden from the environment.
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

# None of these targets produce a file named after themselves. Declaring them
# phony keeps a stray file or directory called `test`, `lint`, ... from making
# make believe the target is already up to date.
.PHONY: help format lint test watch-docs

# `help` lists every target that carries a trailing double-hash description.
help:	## Show all Makefile targets.
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format:	## Run code autoformatters (black).
	pre-commit install
	git ls-files | xargs pre-commit run black --files

lint:	## Run linters: pre-commit (black, ruff, codespell) and mypy
	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test:	## Run tests via pytest.
	pytest tests

watch-docs:	## Build and watch documentation.
	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# LlamaIndex Postprocessor Integration: ColPali Rerank

[ColPali](https://huggingface.co/vidore/colpali-v1.2) is a model built on a novel architecture and training strategy based on Vision Language Models (VLMs) that efficiently indexes documents from their visual features.

Please `pip install llama-index-postprocessor-colpali-rerank` to install the ColPali Rerank package.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Pants target covering the Python source files in this directory.
python_sources()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Package entry point: re-export the reranker from its implementation module
# so it can be imported directly from `llama_index.postprocessor.colpali_rerank`.
from llama_index.postprocessor.colpali_rerank.base import ColPaliRerank

__all__ = ["ColPaliRerank"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import torch
from colpali_engine.models import ColPali, ColPaliProcessor
from PIL import Image
from typing import Any, List, Optional

from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
ReRankEndEvent,
ReRankStartEvent,
)
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle
from llama_index.core.utils import infer_torch_device

# Module-level instrumentation dispatcher; the reranker below uses it to emit
# ReRankStartEvent / ReRankEndEvent around each reranking call.
dispatcher = get_dispatcher(__name__)


class ColPaliRerank(BaseNodePostprocessor):
    """Rerank image-backed nodes against a text query with a ColPali model.

    Every node must expose the path of its image under the ``file_path``
    metadata key. The images and the query are embedded with the ColPali
    model, scored with the processor's multi-vector scoring, and the
    ``top_n`` highest-scoring nodes are returned.
    """

    model: str = Field(description="Colpali model name.")
    top_n: int = Field(description="Number of nodes to return sorted by score.")
    device: str = Field(
        default="cuda",
        description="Device to use for the model.",
    )
    keep_retrieval_score: bool = Field(
        default=False,
        description="Whether to keep the retrieval score in metadata.",
    )
    # Loaded ColPali model and its processor; populated in __init__.
    _model: Any = PrivateAttr()
    _processor: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 5,
        model: str = "vidore/colpali-v1.2",
        device: Optional[str] = None,
        keep_retrieval_score: Optional[bool] = False,
    ) -> None:
        """Load the ColPali model and processor.

        Args:
            top_n: Number of nodes to keep after reranking.
            model: Name passed to ``from_pretrained`` for both the model and
                the processor.
            device: Device to place the model on. When ``None`` the device is
                chosen by ``infer_torch_device()``.
            keep_retrieval_score: When true, the pre-rerank score of each node
                is saved in its metadata. ``None`` is treated as ``False``.
        """
        device = infer_torch_device() if device is None else device
        # The signature allows None but the field is a plain bool, so
        # normalise None to the field's default instead of failing validation.
        if keep_retrieval_score is None:
            keep_retrieval_score = False
        super().__init__(
            top_n=top_n,
            device=device,
            keep_retrieval_score=keep_retrieval_score,
            model=model,
        )

        self._model = ColPali.from_pretrained(
            model, torch_dtype=torch.bfloat16, device_map=device
        ).eval()
        self._processor = ColPaliProcessor.from_pretrained(model)

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier ``"ColPaliRerank"``."""
        return "ColPaliRerank"

    def get_image_paths(self, nodes: List[NodeWithScore]) -> List[str]:
        """Return the ``file_path`` metadata entry of every node, in order.

        Raises:
            KeyError: If a node has no ``file_path`` key in its metadata.
        """
        return [scored.node.metadata["file_path"] for scored in nodes]

    def load_image(self, image_path: str) -> Image.Image:
        """Open the image at ``image_path`` and return it fully loaded.

        ``Image.open`` is lazy and keeps the file open until the pixel data
        is read. Loading inside the ``with`` block reads the data and then
        closes the file, so no file handle stays open behind the returned
        image.
        """
        with Image.open(image_path) as image:
            image.load()
            return image

    def load_images(self, image_paths: List[str]) -> List[Image.Image]:
        """Load every path in ``image_paths`` via ``load_image``, in order."""
        return [self.load_image(image_path) for image_path in image_paths]

    def _calculate_sim(self, query: str, images_paths: List[str]) -> torch.Tensor:
        """Score ``query`` against the images at ``images_paths``.

        Returns:
            The result of the processor's ``score_multi_vector``. The caller
            reads row 0 and converts it with ``tolist()``, so this is
            presumably a ``(num_queries, num_images)`` score tensor; confirm
            against the colpali-engine documentation.
        """
        # Load the images
        images = self.load_images(images_paths)

        # Process the inputs and move them to the device the model lives on
        target_device = self._model.device
        batch_images = self._processor.process_images(images).to(target_device)
        batch_queries = self._processor.process_queries([query]).to(target_device)

        # Forward pass; inference only, so no gradients are tracked
        with torch.no_grad():
            image_embeddings = self._model(**batch_images)
            query_embeddings = self._model(**batch_queries)

        return self._processor.score_multi_vector(query_embeddings, image_embeddings)

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rescore ``nodes`` against the query and return the best ``top_n``.

        The nodes are modified in place: ``score`` is overwritten with the
        ColPali score, and when ``keep_retrieval_score`` is set the previous
        score is stored under the ``retrieval_score`` metadata key.

        Args:
            nodes: Nodes to rerank; each needs ``file_path`` in its metadata.
            query_bundle: Query whose ``query_str`` is scored against the
                images.

        Returns:
            At most ``top_n`` nodes sorted by descending score, or an empty
            list when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None`` or the number of
                scores does not match the number of nodes.
        """
        dispatcher.event(
            ReRankStartEvent(
                query=query_bundle,
                nodes=nodes,
                top_n=self.top_n,
                model_name=self.model,
            )
        )

        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        image_paths = self.get_image_paths(nodes)

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            # Only one query is submitted, so its scores are in row 0.
            similarity = self._calculate_sim(query_bundle.query_str, image_paths)
            scores = similarity[0].tolist()

            # Explicit check instead of `assert`: assertions are stripped
            # under `python -O`, and zip() below would then silently truncate.
            if len(scores) != len(nodes):
                raise ValueError(
                    f"ColPali returned {len(scores)} scores for {len(nodes)} nodes."
                )

            for node, score in zip(nodes, scores):
                if self.keep_retrieval_score:
                    # keep the retrieval score in metadata
                    node.node.metadata["retrieval_score"] = node.score
                node.score = float(score)

            # Highest score first; a missing score ranks as 0. The sort is
            # stable, so ties keep their incoming order.
            reranked_nodes = sorted(
                nodes, key=lambda scored: scored.score or 0.0, reverse=True
            )[: self.top_n]
            event.on_end(payload={EventPayload.NODES: reranked_nodes})

        dispatcher.event(ReRankEndEvent(nodes=reranked_nodes))
        return reranked_nodes
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]

[tool.codespell]
check-filenames = true
check-hidden = true
# Feel free to un-skip examples, and experimental, you will just need to
# work through many typos (--write-changes and --interactive will help)
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"

[tool.llamahub]
contains_example = true
import_path = "llama_index.postprocessor.colpali_rerank"

[tool.llamahub.class_authors]
ColPaliRerank = "ravitheja"

[tool.mypy]
disallow_untyped_defs = true
# Remove venv skip when integrated with pre-commit
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
# Keep in sync with the minimum interpreter in [tool.poetry.dependencies]
# (python = ">=3.9,<4.0").
python_version = "3.9"

[tool.poetry]
authors = ["Ravi Theja <[email protected]>"]
description = "llama-index postprocessor colpali-rerank integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-postprocessor-colpali-rerank"
packages = [{include = "llama_index/"}]
readme = "README.md"
version = "0.1.0"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
torch = "*"
transformers = ">=4.37.2"
llama-index-core = "^0.11.0"
colpali-engine = ">=0.3.0,<0.4.0"
setuptools = "*"

[tool.poetry.group.dev.dependencies]
black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"}
codespell = {extras = ["toml"], version = ">=v2.2.6"}
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991
types-setuptools = "67.1.0.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Pants target covering the test files in this directory.
python_tests()
Loading

0 comments on commit 0eb8fd0

Please sign in to comment.