Skip to content

Commit

Permalink
vert
Browse files: browse the repository at this point in the history
  • Loading branch information
bowenliang123 committed Dec 25, 2024
1 parent bb022c1 commit c508489
Showing 1 changed file with 9 additions and 27 deletions.
36 changes: 9 additions & 27 deletions api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
Original file line number Diff line number Diff line change
@@ -4,10 +4,9 @@
 
 import sqlalchemy
 from pydantic import BaseModel, model_validator
-from sqlalchemy import JSON, TEXT, Column, DateTime, Float, Integer, String, Table, create_engine, insert
+from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
 from sqlalchemy.orm import Session, declarative_base
-from sqlalchemy.sql.expression import bindparam
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -204,32 +203,15 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
 tidb_dist_func = "VEC_COSINE_DISTANCE"
 
 with Session(self._engine) as session:
-select_statement = sql_text("""
-SELECT
-meta,
-text,
-:tidb_dist_func(vector, :query_vector_str) AS distance
-FROM :table_name
-WHERE distance <= :distance
-ORDER BY distance
-LIMIT :top_k
-""").bindparams(
-bindparam("query_vector_str", type_=String),
-bindparam("tidb_dist_func", type_=String),
-bindparam("table_name", type_=String),
-bindparam("top_k", type_=Integer),
-bindparam("distance", type_=Float),
-)
-res = session.execute(
-statement=select_statement,
-params={
-"query_vector_str": query_vector_str,
-"tidb_dist_func": tidb_dist_func,
-"table_name": self._collection_name,
-"top_k": top_k,
-"distance": distance,
-},
+select_statement = sql_text(
+f"""SELECT meta, text, distance FROM (
+SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
+FROM {self._collection_name}
+ORDER BY distance
+LIMIT {top_k}
+) t WHERE distance < {distance};"""
 )
+res = session.execute(select_statement)
 results = [(row[0], row[1], row[2]) for row in res]
 for meta, text, distance in results:
 metadata = json.loads(meta)
Expand Down

0 comments on commit c508489

Please sign in to comment.