Binary Quantization Support for pgvector HNSW Algorithm (#389)
* Added binary quantization support in pgvector HNSW

* Parameterized search SQL queries.
  Added the distance operator used for reranking and a quantized-vector fetch limit to the CLI.

* Removed debug logs

* Updated the pgvectorhnsw command option name.

* Added the binary quantization option to the frontend for pgvectorhnsw

* Removed redundant code

* Refactored code

* Removed Hamming and Jaccard distance options for full vectors.
  Moved reranking_metric to the HNSW config class.

* Refactored code, removed duplicate code.

* Reverted code changes for float input type.
Sheharyar570 authored Oct 29, 2024
1 parent 369f3c6 commit d11330d
Showing 8 changed files with 246 additions and 73 deletions.
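Before the per-file diffs, a minimal NumPy sketch of the two-stage search this commit wires up: a coarse scan over bit-quantized vectors followed by full-precision reranking. Nothing below comes from the repository; it only illustrates the idea that pgvector executes inside Postgres.

import numpy as np

def binary_quantize(v: np.ndarray) -> np.ndarray:
    # Mirrors pgvector's binary_quantize(): 1 where a component is positive, else 0.
    return (v > 0).astype(np.uint8)

def hamming(a: np.ndarray, b: np.ndarray) -> int:
    # Number of differing bits, the coarse distance used on the quantized vectors.
    return int(np.count_nonzero(a != b))

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # Full-precision distance used for reranking (the default --reranking-metric).
    return 1.0 - float(a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))

def two_stage_search(query: np.ndarray, vectors: np.ndarray, k: int,
                     quantized_fetch_limit: int) -> list[int]:
    q_bits = binary_quantize(query)
    # Stage 1: rank everything by Hamming distance on the bit vectors (cheap, approximate).
    coarse = sorted(range(len(vectors)),
                    key=lambda i: hamming(binary_quantize(vectors[i]), q_bits))
    candidates = coarse[:quantized_fetch_limit]
    # Stage 2: rerank only the fetched candidates with the exact float metric.
    reranked = sorted(candidates, key=lambda i: cosine_distance(vectors[i], query))
    return reranked[:k]

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    data = rng.normal(size=(1000, 128)).astype(np.float32)
    print(two_stage_search(data[0], data, k=10, quantized_fetch_limit=100))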
2 changes: 1 addition & 1 deletion README.md
@@ -118,7 +118,7 @@ Options:
--m INTEGER hnsw m
--ef-construction INTEGER hnsw ef-construction
--ef-search INTEGER hnsw ef-search
--quantization-type [none|halfvec]
--quantization-type [none|bit|halfvec]
quantization type for vectors
--custom-case-name TEXT Custom case name i.e. PerformanceCase1536D50K
--custom-case-description TEXT Custom name description
2 changes: 2 additions & 0 deletions vectordb_bench/backend/clients/api.py
@@ -10,6 +10,8 @@ class MetricType(str, Enum):
L2 = "L2"
COSINE = "COSINE"
IP = "IP"
HAMMING = "HAMMING"
JACCARD = "JACCARD"


class IndexType(str, Enum):
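HAMMING and JACCARD only apply to bit vectors, so the CLI change below excludes them from the reranking choices (reranking always runs on the full-precision column). A standalone sketch of that filter, using the enum as it reads after this diff:

from enum import Enum

class MetricType(str, Enum):
    L2 = "L2"
    COSINE = "COSINE"
    IP = "IP"
    HAMMING = "HAMMING"
    JACCARD = "JACCARD"

# Same list comprehension as the --reranking-metric choices added in cli.py:
reranking_choices = [m.value for m in MetricType if m.value not in ("HAMMING", "JACCARD")]
assert reranking_choices == ["L2", "COSINE", "IP"]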
48 changes: 47 additions & 1 deletion vectordb_bench/backend/clients/pgvector/cli.py
@@ -4,6 +4,8 @@
import os
from pydantic import SecretStr

from vectordb_bench.backend.clients.api import MetricType

from ....cli.cli import (
CommonTypedDict,
HNSWFlavor1,
@@ -16,6 +18,13 @@
from vectordb_bench.backend.clients import DB



def set_default_quantized_fetch_limit(ctx, param, value):
if ctx.params.get("reranking") and value is None:
# quantized_fetch_limit defaults to ef_search, since the candidate count is bounded by ef_search.
return ctx.params["ef_search"]
return value

class PgVectorTypedDict(CommonTypedDict):
user_name: Annotated[
str, click.option("--user-name", type=str, help="Db username", required=True)
@@ -61,11 +70,45 @@ class PgVectorTypedDict(CommonTypedDict):
Optional[str],
click.option(
"--quantization-type",
type=click.Choice(["none", "halfvec"]),
type=click.Choice(["none", "bit", "halfvec"]),
help="quantization type for vectors",
required=False,
),
]
reranking: Annotated[
Optional[bool],
click.option(
"--reranking/--skip-reranking",
type=bool,
help="Enable reranking for HNSW search for binary quantization",
default=False,
),
]
reranking_metric: Annotated[
Optional[str],
click.option(
"--reranking-metric",
type=click.Choice(
[metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]]
),
help="Distance metric for reranking",
default="COSINE",
show_default=True,
),
]
quantized_fetch_limit: Annotated[
Optional[int],
click.option(
"--quantized-fetch-limit",
type=int,
help="Limit of fetching quantized vector ranked by distance for reranking \
-- bound by ef_search",
required=False,
callback=set_default_quantized_fetch_limit,
)
]



class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict):
...
@@ -126,6 +169,9 @@ def PgVectorHNSW(
maintenance_work_mem=parameters["maintenance_work_mem"],
max_parallel_workers=parameters["max_parallel_workers"],
quantization_type=parameters["quantization_type"],
reranking=parameters["reranking"],
reranking_metric=parameters["reranking_metric"],
quantized_fetch_limit=parameters["quantized_fetch_limit"],
),
**parameters,
)
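Worth noting: when --reranking is enabled and --quantized-fetch-limit is omitted, the callback above falls back to ef_search, because the HNSW scan never returns more than ef_search candidates. A condensed sketch of that behaviour without the click ctx/param plumbing (the function name here is illustrative):

from typing import Optional

def default_quantized_fetch_limit(params: dict, value: Optional[int]) -> Optional[int]:
    # Fall back to ef_search only when reranking is on and no explicit limit was given.
    if params.get("reranking") and value is None:
        return params["ef_search"]
    return value

assert default_quantized_fetch_limit({"reranking": True, "ef_search": 256}, None) == 256
assert default_quantized_fetch_limit({"reranking": True, "ef_search": 256}, 100) == 100
assert default_quantized_fetch_limit({"reranking": False}, None) is None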
33 changes: 28 additions & 5 deletions vectordb_bench/backend/clients/pgvector/config.py
@@ -65,6 +65,10 @@ def parse_metric(self) -> str:
elif self.metric_type == MetricType.IP:
return "halfvec_ip_ops"
return "halfvec_cosine_ops"
elif self.quantization_type == "bit":
if self.metric_type == MetricType.JACCARD:
return "bit_jaccard_ops"
return "bit_hamming_ops"
else:
if self.metric_type == MetricType.L2:
return "vector_l2_ops"
@@ -73,18 +77,31 @@
return "vector_cosine_ops"

def parse_metric_fun_op(self) -> LiteralString:
if self.metric_type == MetricType.L2:
return "<->"
elif self.metric_type == MetricType.IP:
return "<#>"
return "<=>"
if self.quantization_type == "bit":
if self.metric_type == MetricType.JACCARD:
return "<%>"
return "<~>"
else:
if self.metric_type == MetricType.L2:
return "<->"
elif self.metric_type == MetricType.IP:
return "<#>"
return "<=>"

def parse_metric_fun_str(self) -> str:
if self.metric_type == MetricType.L2:
return "l2_distance"
elif self.metric_type == MetricType.IP:
return "max_inner_product"
return "cosine_distance"

def parse_reranking_metric_fun_op(self) -> LiteralString:
if self.reranking_metric == MetricType.L2:
return "<->"
elif self.reranking_metric == MetricType.IP:
return "<#>"
return "<=>"


@abstractmethod
def index_param(self) -> PgVectorIndexParam:
@@ -195,6 +212,9 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
maintenance_work_mem: Optional[str] = None
max_parallel_workers: Optional[int] = None
quantization_type: Optional[str] = None
reranking: Optional[bool] = None
quantized_fetch_limit: Optional[int] = None
reranking_metric: Optional[str] = None

def index_param(self) -> PgVectorIndexParam:
index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
@@ -214,6 +234,9 @@ def index_param(self) -> PgVectorIndexParam:
def search_param(self) -> PgVectorSearchParam:
return {
"metric_fun_op": self.parse_metric_fun_op(),
"reranking": self.reranking,
"reranking_metric_fun_op": self.parse_reranking_metric_fun_op(),
"quantized_fetch_limit": self.quantized_fetch_limit,
}

def session_param(self) -> PgVectorSessionCommands:
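For reference, a condensed view of the operator selection the config methods above now perform. This is a sketch, not code from the commit; the operators are pgvector's documented distance operators (<~> Hamming and <%> Jaccard on bit vectors; <-> L2, <#> negative inner product, <=> cosine on float/half vectors):

BIT_OPS = {"JACCARD": "<%>", "HAMMING": "<~>"}           # bit-quantized index scan
FLOAT_OPS = {"L2": "<->", "IP": "<#>", "COSINE": "<=>"}  # full/half precision and reranking

def metric_fun_op(quantization_type: str, metric_type: str) -> str:
    # Same fallthrough as parse_metric_fun_op(): Hamming is the default for bit,
    # cosine is the default otherwise.
    if quantization_type == "bit":
        return BIT_OPS.get(metric_type, "<~>")
    return FLOAT_OPS.get(metric_type, "<=>")

assert metric_fun_op("bit", "JACCARD") == "<%>"
assert metric_fun_op("none", "L2") == "<->"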
177 changes: 113 additions & 64 deletions vectordb_bench/backend/clients/pgvector/pgvector.py
@@ -11,7 +11,7 @@
from psycopg import Connection, Cursor, sql

from ..api import VectorDB
from .config import PgVectorConfigDict, PgVectorIndexConfig
from .config import PgVectorConfigDict, PgVectorIndexConfig, PgVectorHNSWConfig

log = logging.getLogger(__name__)

@@ -87,6 +87,92 @@ def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
assert cursor is not None, "Cursor is not initialized"

return conn, cursor

def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
index_param = self.case_config.index_param()
reranking = self.case_config.search_param()["reranking"]
column_name = (
sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
if index_param["quantization_type"] == "bit"
else sql.SQL("embedding")
)
search_vector = (
sql.SQL("binary_quantize({0})").format(sql.Placeholder())
if index_param["quantization_type"] == "bit"
else sql.Placeholder()
)

# The following sections assume that the quantization_type value matches the quantization function name
if index_param["quantization_type"] != None:
if index_param["quantization_type"] == "bit" and reranking:
# Embeddings need to be passed to the binary_quantize function when quantization_type is bit
search_query = sql.Composed(
[
sql.SQL(
"""
SELECT i.id
FROM (
SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
FROM public.{table_name} {where_clause}
ORDER BY {column_name}::{quantization_type}({dim})
"""
).format(
table_name=sql.Identifier(self.table_name),
column_name=column_name,
reranking_metric_fun_op=sql.SQL(self.case_config.search_param()["reranking_metric_fun_op"]),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(
"""
{search_vector}
LIMIT {quantized_fetch_limit}
) i
ORDER BY i.distance
LIMIT %s::int
"""
).format(
search_vector=search_vector,
quantized_fetch_limit=sql.Literal(
self.case_config.search_param()["quantized_fetch_limit"]
),
),
]
)
else:
search_query = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) "
).format(
table_name=sql.Identifier(self.table_name),
column_name=column_name,
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" {search_vector} LIMIT %s::int").format(search_vector=search_vector),
]
)
else:
search_query = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding "
).format(
table_name=sql.Identifier(self.table_name),
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)

return search_query


@contextmanager
def init(self) -> Generator[None, None, None]:
@@ -112,63 +198,8 @@ def init(self) -> Generator[None, None, None]:
self.cursor.execute(command)
self.conn.commit()

index_param = self.case_config.index_param()
# The following sections assume that the quantization_type value matches the quantization function name
if index_param["quantization_type"] != None:
self._filtered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding::{quantization_type}({dim}) "
).format(
table_name=sql.Identifier(self.table_name),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
]
)
else:
self._filtered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
).format(table_name=sql.Identifier(self.table_name)),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)

if index_param["quantization_type"] != None:
self._unfiltered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} ORDER BY embedding::{quantization_type}({dim}) "
).format(
table_name=sql.Identifier(self.table_name),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
]
)
else:
self._unfiltered_search = sql.Composed(
[
sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
sql.Identifier(self.table_name)
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)
self._filtered_search = self._generate_search_query(filtered=True)
self._unfiltered_search = self._generate_search_query()

try:
yield
@@ -306,12 +337,17 @@ def _create_index(self):
if index_param["quantization_type"] != None:
index_create_sql = sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} ((embedding::{quantization_type}({dim})) {embedding_metric})
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} (({column_name}::{quantization_type}({dim})) {embedding_metric})
"""
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
column_name=(
sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
if index_param["quantization_type"] == "bit"
else sql.Identifier("embedding")
),
index_type=sql.Identifier(index_param["index_type"]),
# This assumes that the quantization_type value matches the quantization function name
quantization_type=sql.SQL(index_param["quantization_type"]),
@@ -406,15 +442,28 @@ def search_embedding(
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

index_param = self.case_config.index_param()
search_param = self.case_config.search_param()
q = np.asarray(query)
if filters:
gt = filters.get("id")
result = self.cursor.execute(
if index_param["quantization_type"] == "bit" and search_param["reranking"]:
result = self.cursor.execute(
self._filtered_search, (q, gt, q, k), prepare=True, binary=True
)
else:
result = self.cursor.execute(
self._filtered_search, (gt, q, k), prepare=True, binary=True
)
)

else:
result = self.cursor.execute(
if index_param["quantization_type"] == "bit" and search_param["reranking"]:
result = self.cursor.execute(
self._unfiltered_search, (q, q, k), prepare=True, binary=True
)
else:
result = self.cursor.execute(
self._unfiltered_search, (q, k), prepare=True, binary=True
)
)

return [int(i[0]) for i in result.fetchall()]
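To make the composed statement easier to read, here is an approximation of what _generate_search_query() renders for the bit + reranking case once the identifiers and literals are substituted. The table name, dimension, and limit below are hypothetical placeholders; the %s parameters are bound as (query, [gt,] query, k) by search_embedding():

# Approximate output of _generate_search_query() for quantization_type == "bit"
# with reranking enabled and no filter (hypothetical table name, dim, and limit).
RERANKING_QUERY_EXAMPLE = """
SELECT i.id
FROM (
    SELECT id, embedding <=> %s::vector AS distance                         -- reranking metric (cosine here)
    FROM public.pgvector_collection
    ORDER BY binary_quantize(embedding)::bit(1536) <~> binary_quantize(%s)  -- coarse Hamming scan
    LIMIT 200                                                               -- quantized_fetch_limit
) i
ORDER BY i.distance
LIMIT %s::int
"""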
6 changes: 6 additions & 0 deletions vectordb_bench/frontend/components/run_test/caseSelector.py
@@ -110,6 +110,12 @@ def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, active
value=config.inputConfig["value"],
help=config.inputHelp,
)
elif config.inputType == InputType.Bool:
caseConfig[config.label] = column.checkbox(
config.displayLabel if config.displayLabel else config.label.value,
value=config.inputConfig["value"],
help=config.inputHelp,
)
k += 1
if k == 0:
columns[1].write("Auto")