Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenliang123 committed Dec 25, 2024
1 parent 7f7b000 commit fb7cc26
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import sqlalchemy
from pydantic import BaseModel, model_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import JSON, TEXT, Column, DateTime, Float, Integer, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.sql.expression import bindparam

from configs import dify_config
from core.rag.datasource.vdb.field import Field
Expand Down Expand Up @@ -196,11 +197,11 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc

docs = []
if self._distance_func == "l2":
tidb_dist_func = "Vec_l2_distance"
tidb_dist_func = "VEC_L2_DISTANCE"
elif self._distance_func == "cosine":
tidb_dist_func = "Vec_Cosine_distance"
tidb_dist_func = "VEC_COSINE_DISTANCE"
else:
tidb_dist_func = "Vec_Cosine_distance"
tidb_dist_func = "VEC_COSINE_DISTANCE"

with Session(self._engine) as session:
select_statement = sql_text(f"""
Expand All @@ -212,7 +213,12 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
WHERE distance <= :distance
ORDER BY distance
LIMIT :top_k
""")
""").bindparams(
bindparam("query_vector_str", type_=String),
bindparam("table_name", type_=String),
bindparam("top_k", type_=Integer),
bindparam("distance", type_=Float),
)
res = session.execute(
statement=select_statement,
params={
Expand Down

0 comments on commit fb7cc26

Please sign in to comment.