diff --git a/.github/workflows/docker-build-push-model-server-container-on-tag.yml b/.github/workflows/docker-build-push-model-server-container-on-tag.yml
new file mode 100644
index 00000000000..023147dbdca
--- /dev/null
+++ b/.github/workflows/docker-build-push-model-server-container-on-tag.yml
@@ -0,0 +1,36 @@
+name: Build and Push Backend Images on Tagging
+
+on:
+  push:
+    tags:
+      - '*'
+
+jobs:
+  build-and-push:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v2
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v1
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@v1
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Backend Image Docker Build and Push
+        uses: docker/build-push-action@v2
+        with:
+          context: ./backend
+          file: ./backend/Dockerfile.model_server
+          platforms: linux/amd64,linux/arm64
+          push: true
+          tags: |
+            danswer/danswer-model-server:${{ github.ref_name }}
+            danswer/danswer-model-server:latest
+          build-args: |
+            DANSWER_VERSION=${{ github.ref_name }}
diff --git a/backend/Dockerfile b/backend/Dockerfile
index a85f76e6ef6..4519c053599 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -45,6 +45,7 @@ RUN apt-get remove -y linux-libc-dev && \
 # Set up application files
 WORKDIR /app
 COPY ./danswer /app/danswer
+COPY ./shared_models /app/shared_models
 COPY ./alembic /app/alembic
 COPY ./alembic.ini /app/alembic.ini
 COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
diff --git a/backend/Dockerfile.model_server b/backend/Dockerfile.model_server
new file mode 100644
index 00000000000..1ff4b038eec
--- /dev/null
+++ b/backend/Dockerfile.model_server
@@ -0,0 +1,27 @@
+FROM python:3.11.4-slim-bookworm
+
+# Default DANSWER_VERSION, typically overridden during builds by GitHub Actions.
+ARG DANSWER_VERSION=0.2-dev
+ENV DANSWER_VERSION=${DANSWER_VERSION}
+
+COPY ./requirements/model_server.txt /tmp/requirements.txt
+RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
+
+WORKDIR /app
+# Needed for model configs and defaults
+COPY ./danswer/configs /app/danswer/configs
+# Utils used by model server
+COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
+COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py
+# Version information
+COPY ./danswer/__init__.py /app/danswer/__init__.py
+# Shared implementations for running NLP models locally
+COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py
+# Request/Response models
+COPY ./shared_models /app/shared_models
+# Model Server main code
+COPY ./model_server /app/model_server
+
+ENV PYTHONPATH /app
+
+CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]
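The container's CMD above can also be reproduced outside Docker for local development. A minimal sketch, assuming the backend/ directory is the current working directory (so imports resolve the same way PYTHONPATH=/app does in the image) and requirements/model_server.txt is installed:

```python
# Serve the model server app directly, mirroring the Dockerfile CMD.
import uvicorn

from model_server.main import app

uvicorn.run(app, host="0.0.0.0", port=9000)
```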
diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py
index 4bb7a174d8c..bbb6fee3835 100755
--- a/backend/danswer/background/update.py
+++ b/backend/danswer/background/update.py
@@ -13,6 +13,7 @@
 from danswer.background.indexing.job_client import SimpleJobClient
 from danswer.background.indexing.run_indexing import run_indexing_entrypoint
 from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
+from danswer.configs.app_configs import MODEL_SERVER_HOST
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
 from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
 from danswer.db.connector import fetch_connectors
@@ -290,7 +291,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
 
 
 if __name__ == "__main__":
-    logger.info("Warming up Embedding Model(s)")
-    warm_up_models(indexer_only=True)
+    if not MODEL_SERVER_HOST:
+        logger.info("Warming up Embedding Model(s)")
+        warm_up_models(indexer_only=True)
     logger.info("Starting Indexing Loop")
     update_loop()
diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py
index dca791da887..7ceeab46cb5 100644
--- a/backend/danswer/configs/app_configs.py
+++ b/backend/danswer/configs/app_configs.py
@@ -3,6 +3,7 @@
 from danswer.configs.constants import AuthType
 from danswer.configs.constants import DocumentIndexType
 
+
 #####
 # App Configs
 #####
@@ -19,6 +20,7 @@
 # Use this if you want to use Danswer as a search engine only without the LLM capabilities
 DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
 
+
 #####
 # Web Configs
 #####
@@ -56,7 +58,6 @@
     if _VALID_EMAIL_DOMAINS_STR
     else []
 )
-
 # OAuth Login Flow
 # Used for both Google OAuth2 and OIDC flows
 OAUTH_CLIENT_ID = (
@@ -200,12 +201,13 @@
 
 
 #####
-# Encoder Model Endpoint Configs (Currently unused, running the models in memory)
+# Model Server Configs
 #####
-BI_ENCODER_HOST = "localhost"
-BI_ENCODER_PORT = 9000
-CROSS_ENCODER_HOST = "localhost"
-CROSS_ENCODER_PORT = 9000
+# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to that server
+# via requests. Use a bare hostname (no scheme); "http://" is prepended when requests are built.
+MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
+MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_ALLOWED_HOST") or "0.0.0.0"
+MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
 
 
 #####
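How these settings are consumed is shown by the client wrappers in search_nlp_models.py below; the scheme and route are always added on the client side, so MODEL_SERVER_HOST should be a bare hostname. A sketch with hypothetical values:

```python
# The client builds the full endpoint itself; only host and port come from env.
host = "model_server"  # hypothetical MODEL_SERVER_HOST value; note: no "http://"
port = 9000            # the MODEL_SERVER_PORT default
endpoint = f"http://{host}:{port}/encoder/bi-encoder-embed"
```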
diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py
index 8aa3471f6b3..8d224aa10bb 100644
--- a/backend/danswer/indexing/embedder.py
+++ b/backend/danswer/indexing/embedder.py
@@ -1,16 +1,14 @@
-import numpy
 from sentence_transformers import SentenceTransformer  # type: ignore
 
 from danswer.configs.app_configs import ENABLE_MINI_CHUNK
 from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
 from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
-from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
 from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
 from danswer.search.models import Embedder
-from danswer.search.search_nlp_models import get_default_embedding_model
+from danswer.search.search_nlp_models import EmbeddingModel
 from danswer.utils.timing import log_function_time
 
 
@@ -24,7 +22,7 @@ def encode_chunks(
 ) -> list[IndexChunk]:
     embedded_chunks: list[IndexChunk] = []
     if embedding_model is None:
-        embedding_model = get_default_embedding_model()
+        embedding_model = EmbeddingModel()
 
     chunk_texts = []
     chunk_mini_chunks_count = {}
@@ -43,15 +41,10 @@ def encode_chunks(
         chunk_texts[i : i + batch_size]
         for i in range(0, len(chunk_texts), batch_size)
     ]
-    embeddings_np: list[numpy.ndarray] = []
+    embeddings: list[list[float]] = []
    for text_batch in text_batches:
         # Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss
-        embeddings_np.extend(
-            embedding_model.encode(
-                text_batch, normalize_embeddings=NORMALIZE_EMBEDDINGS
-            )
-        )
-    embeddings: list[list[float]] = [embedding.tolist() for embedding in embeddings_np]
+        embeddings.extend(embedding_model.encode(text_batch))
 
     embedding_ind_start = 0
     for chunk_ind, chunk in enumerate(chunks):
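A minimal usage sketch for the new embedding path above: EmbeddingModel.encode accepts a batch and already returns plain Python floats, whether it ran the model in-process or called out to the model server (local use downloads weights on first call):

```python
from danswer.search.search_nlp_models import EmbeddingModel

vectors = EmbeddingModel().encode(["chunk one", "chunk two"])  # list[list[float]]
assert len(vectors) == 2 and isinstance(vectors[0][0], float)
```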
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 634e3fb2a63..1c84beba294 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -1,4 +1,5 @@
 import nltk  # type:ignore
+import torch
 import uvicorn
 from fastapi import FastAPI
 from fastapi import Request
@@ -7,6 +8,7 @@
 from fastapi.responses import JSONResponse
 from httpx_oauth.clients.google import GoogleOAuth2
 
+from danswer import __version__
 from danswer.auth.schemas import UserCreate
 from danswer.auth.schemas import UserRead
 from danswer.auth.schemas import UserUpdate
@@ -17,6 +19,8 @@
 from danswer.configs.app_configs import APP_PORT
 from danswer.configs.app_configs import AUTH_TYPE
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
+from danswer.configs.app_configs import MODEL_SERVER_HOST
+from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.app_configs import OAUTH_CLIENT_ID
 from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
 from danswer.configs.app_configs import SECRET
@@ -72,7 +76,7 @@ def value_error_handler(_: Request, exc: ValueError) -> JSONResponse:
 
 
 def get_application() -> FastAPI:
-    application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
+    application = FastAPI(title="Danswer Backend", version=__version__)
     application.include_router(backend_router)
     application.include_router(chat_router)
     application.include_router(event_processing_router)
@@ -176,11 +180,23 @@ def startup_event() -> None:
         logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"')
         logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"')
 
-        logger.info("Warming up local NLP models.")
-        warm_up_models()
-        qa_model = get_default_qa_model()
+        if MODEL_SERVER_HOST:
+            logger.info(
+                f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
+            )
+        else:
+            logger.info("Warming up local NLP models.")
+            if torch.cuda.is_available():
+                logger.info("GPU is available")
+            else:
+                logger.info("GPU is not available")
+            logger.info(f"Torch Threads: {torch.get_num_threads()}")
+
+            warm_up_models()
+
         # This is for the LLM, most LLMs will not need warming up
-        qa_model.warm_up_model()
+        # It logs for itself
+        get_default_qa_model().warm_up_model()
 
         logger.info("Verifying query preprocessing (NLTK) data is downloaded")
         nltk.download("stopwords", quiet=True)
diff --git a/backend/danswer/search/danswer_helper.py b/backend/danswer/search/danswer_helper.py
index 7ac088d6d3b..216375bc5c7 100644
--- a/backend/danswer/search/danswer_helper.py
+++ b/backend/danswer/search/danswer_helper.py
@@ -1,12 +1,9 @@
-import numpy as np
-import tensorflow as tf  # type:ignore
 from transformers import AutoTokenizer  # type:ignore
 
 from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
-from danswer.search.search_nlp_models import get_default_intent_model
-from danswer.search.search_nlp_models import get_default_intent_model_tokenizer
 from danswer.search.search_nlp_models import get_default_tokenizer
+from danswer.search.search_nlp_models import IntentModel
 from danswer.search.search_runner import remove_stop_words
 from danswer.server.models import HelperResponse
 from danswer.utils.logger import setup_logger
@@ -28,15 +25,11 @@ def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
 
 @log_function_time()
 def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
-    tokenizer = get_default_intent_model_tokenizer()
-    intent_model = get_default_intent_model()
-    model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
-
-    predictions = intent_model(model_input)[0]
-    probabilities = tf.nn.softmax(predictions, axis=-1)
-    class_percentages = np.round(probabilities.numpy() * 100, 2)
-
-    keyword, semantic, qa = class_percentages.tolist()[0]
+    intent_model = IntentModel()
+    class_probs = intent_model.predict(query)
+    keyword = class_probs[0]
+    semantic = class_probs[1]
+    qa = class_probs[2]
 
     # Heavily bias towards QA, from user perspective, answering a statement is not as bad as not answering a question
     if qa > 20:
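A usage sketch for the refactored intent check: predict() returns percentage scores in (keyword, semantic, qa) order, matching the indexing in query_intent above:

```python
from danswer.search.search_nlp_models import IntentModel

keyword, semantic, qa = IntentModel().predict("How do I share a document?")
if qa > 20:
    ...  # heavily bias towards question-answering, as above
```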
diff --git a/backend/danswer/search/search_nlp_models.py b/backend/danswer/search/search_nlp_models.py
index 023fd9a55f8..e6d8a292bd8 100644
--- a/backend/danswer/search/search_nlp_models.py
+++ b/backend/danswer/search/search_nlp_models.py
@@ -1,15 +1,30 @@
+import numpy as np
+import requests
+import tensorflow as tf  # type: ignore
 from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
 from transformers import AutoTokenizer  # type: ignore
 from transformers import TFDistilBertForSequenceClassification  # type: ignore
 
+from danswer.configs.app_configs import MODEL_SERVER_HOST
+from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
 from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
 from danswer.configs.model_configs import INTENT_MODEL_VERSION
+from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
 from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
 from danswer.configs.model_configs import SKIP_RERANKING
+from danswer.utils.logger import setup_logger
+from shared_models.model_server_models import EmbedRequest
+from shared_models.model_server_models import EmbedResponse
+from shared_models.model_server_models import IntentRequest
+from shared_models.model_server_models import IntentResponse
+from shared_models.model_server_models import RerankRequest
+from shared_models.model_server_models import RerankResponse
+
+logger = setup_logger()
 
 
 _TOKENIZER: None | AutoTokenizer = None
@@ -26,39 +41,46 @@ def get_default_tokenizer() -> AutoTokenizer:
     return _TOKENIZER
 
 
-def get_default_embedding_model() -> SentenceTransformer:
+def get_local_embedding_model(
+    model_name: str = DOCUMENT_ENCODER_MODEL,
+    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+) -> SentenceTransformer:
     global _EMBED_MODEL
     if _EMBED_MODEL is None:
-        _EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
-        _EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
+        _EMBED_MODEL = SentenceTransformer(model_name)
+        _EMBED_MODEL.max_seq_length = max_context_length
     return _EMBED_MODEL
 
 
-def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
+def get_local_reranking_model_ensemble(
+    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
+    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
+) -> list[CrossEncoder]:
     global _RERANK_MODELS
     if _RERANK_MODELS is None:
-        _RERANK_MODELS = [
-            CrossEncoder(model_name) for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
-        ]
+        _RERANK_MODELS = [CrossEncoder(model_name) for model_name in model_names]
         for model in _RERANK_MODELS:
-            model.max_length = CROSS_EMBED_CONTEXT_SIZE
+            model.max_length = max_context_length
     return _RERANK_MODELS
 
 
-def get_default_intent_model_tokenizer() -> AutoTokenizer:
+def get_intent_model_tokenizer(model_name: str = INTENT_MODEL_VERSION) -> AutoTokenizer:
     global _INTENT_TOKENIZER
     if _INTENT_TOKENIZER is None:
-        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(INTENT_MODEL_VERSION)
+        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
     return _INTENT_TOKENIZER
 
 
-def get_default_intent_model() -> TFDistilBertForSequenceClassification:
+def get_local_intent_model(
+    model_name: str = INTENT_MODEL_VERSION,
+    max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
+) -> TFDistilBertForSequenceClassification:
     global _INTENT_MODEL
     if _INTENT_MODEL is None:
         _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
-            INTENT_MODEL_VERSION
+            model_name
         )
-        _INTENT_MODEL.max_seq_length = QUERY_MAX_CONTEXT_SIZE
+        _INTENT_MODEL.max_seq_length = max_context_length
     return _INTENT_MODEL
@@ -67,20 +89,183 @@ def warm_up_models(
 ) -> None:
     warm_up_str = "Danswer is amazing"
     get_default_tokenizer()(warm_up_str)
-    get_default_embedding_model().encode(warm_up_str)
+    get_local_embedding_model().encode(warm_up_str)
 
     if indexer_only:
         return
 
     if not skip_cross_encoders:
-        cross_encoders = get_default_reranking_model_ensemble()
+        cross_encoders = get_local_reranking_model_ensemble()
         [
             cross_encoder.predict((warm_up_str, warm_up_str))
             for cross_encoder in cross_encoders
         ]
 
-    intent_tokenizer = get_default_intent_model_tokenizer()
+    intent_tokenizer = get_intent_model_tokenizer()
     inputs = intent_tokenizer(
         warm_up_str, return_tensors="tf", truncation=True, padding=True
     )
-    get_default_intent_model()(inputs)
+    get_local_intent_model()(inputs)
+
+
+class EmbeddingModel:
+    def __init__(
+        self,
+        model_name: str = DOCUMENT_ENCODER_MODEL,
+        max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_port: int = MODEL_SERVER_PORT,
+    ) -> None:
+        self.model_name = model_name
+        self.max_seq_length = max_seq_length
+        self.embed_server_endpoint = (
+            f"http://{model_server_host}:{model_server_port}/encoder/bi-encoder-embed"
+            if model_server_host
+            else None
+        )
+
+    def load_model(self) -> SentenceTransformer | None:
+        if self.embed_server_endpoint:
+            return None
+
+        return get_local_embedding_model(
+            model_name=self.model_name, max_context_length=self.max_seq_length
+        )
+
+    def encode(
+        self, texts: list[str], normalize_embeddings: bool = NORMALIZE_EMBEDDINGS
+    ) -> list[list[float]]:
+        if self.embed_server_endpoint:
+            embed_request = EmbedRequest(texts=texts)
+
+            try:
+                response = requests.post(
+                    self.embed_server_endpoint, json=embed_request.dict()
+                )
+                response.raise_for_status()
+
+                return EmbedResponse(**response.json()).embeddings
+            except requests.RequestException as e:
+                logger.exception(f"Failed to get Embedding: {e}")
+                raise
+
+        local_model = self.load_model()
+
+        if local_model is None:
+            raise RuntimeError("Failed to load local Embedding Model")
+
+        return local_model.encode(
+            texts, normalize_embeddings=normalize_embeddings
+        ).tolist()
+
+
+class CrossEncoderEnsembleModel:
+    def __init__(
+        self,
+        model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
+        max_seq_length: int = CROSS_EMBED_CONTEXT_SIZE,
+        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_port: int = MODEL_SERVER_PORT,
+    ) -> None:
+        self.model_names = model_names
+        self.max_seq_length = max_seq_length
+        self.rerank_server_endpoint = (
+            f"http://{model_server_host}:{model_server_port}/encoder/cross-encoder-scores"
+            if model_server_host
+            else None
+        )
+
+    def load_model(self) -> list[CrossEncoder] | None:
+        if self.rerank_server_endpoint:
+            return None
+
+        return get_local_reranking_model_ensemble(
+            model_names=self.model_names, max_context_length=self.max_seq_length
+        )
+
+    def predict(self, query: str, passages: list[str]) -> list[list[float]]:
+        if self.rerank_server_endpoint:
+            rerank_request = RerankRequest(query=query, documents=passages)
+
+            try:
+                response = requests.post(
+                    self.rerank_server_endpoint, json=rerank_request.dict()
+                )
+                response.raise_for_status()
+
+                return RerankResponse(**response.json()).scores
+            except requests.RequestException as e:
+                logger.exception(f"Failed to get Reranking Scores: {e}")
+                raise
+
+        local_models = self.load_model()
+
+        if local_models is None:
+            raise RuntimeError("Failed to load local Reranking Model Ensemble")
+
+        scores = [
+            cross_encoder.predict([(query, passage) for passage in passages]).tolist()  # type: ignore
+            for cross_encoder in local_models
+        ]
+
+        return scores
+
+
+class IntentModel:
+    def __init__(
+        self,
+        model_name: str = INTENT_MODEL_VERSION,
+        max_seq_length: int = QUERY_MAX_CONTEXT_SIZE,
+        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_port: int = MODEL_SERVER_PORT,
+    ) -> None:
+        self.model_name = model_name
+        self.max_seq_length = max_seq_length
+        self.intent_server_endpoint = (
+            f"http://{model_server_host}:{model_server_port}/custom/intent-model"
+            if model_server_host
+            else None
+        )
+
+    def load_model(self) -> TFDistilBertForSequenceClassification | None:
+        if self.intent_server_endpoint:
+            return None
+
+        return get_local_intent_model(
+            model_name=self.model_name, max_context_length=self.max_seq_length
+        )
+
+    def predict(
+        self,
+        query: str,
+    ) -> list[float]:
+        if self.intent_server_endpoint:
+            intent_request = IntentRequest(query=query)
+
+            try:
+                response = requests.post(
+                    self.intent_server_endpoint, json=intent_request.dict()
+                )
+                response.raise_for_status()
+
+                return IntentResponse(**response.json()).class_probs
+            except requests.RequestException as e:
+                logger.exception(f"Failed to get Intent Prediction: {e}")
+                raise
+
+        tokenizer = get_intent_model_tokenizer()
+        local_model = self.load_model()
+
+        if local_model is None:
+            raise RuntimeError("Failed to load local Intent Model")
+
+        intent_model = local_model
+        model_input = tokenizer(
+            query, return_tensors="tf", truncation=True, padding=True
+        )
+
+        predictions = intent_model(model_input)[0]
+        probabilities = tf.nn.softmax(predictions, axis=-1)
+        class_percentages = np.round(probabilities.numpy() * 100, 2)
+
+        return list(class_percentages.tolist()[0])
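Mode selection for the three wrappers above, sketched: with no model server configured they lazily load weights in-process; with one configured, load_model() returns None and encode()/predict() POST to the server instead ("model_server" is a hypothetical hostname):

```python
from danswer.search.search_nlp_models import EmbeddingModel

local = EmbeddingModel(model_server_host=None)             # in-process weights
remote = EmbeddingModel(model_server_host="model_server")  # HTTP-backed
assert remote.load_model() is None
```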
diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py
index ba7f1eae366..4efcf1f949e 100644
--- a/backend/danswer/search/search_runner.py
+++ b/backend/danswer/search/search_runner.py
@@ -1,4 +1,5 @@
 from collections.abc import Callable
+from typing import cast
 
 import numpy
 from nltk.corpus import stopwords  # type:ignore
@@ -29,8 +30,8 @@
 from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.models import SearchQuery
 from danswer.search.models import SearchType
-from danswer.search.search_nlp_models import get_default_embedding_model
-from danswer.search.search_nlp_models import get_default_reranking_model_ensemble
+from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
+from danswer.search.search_nlp_models import EmbeddingModel
 from danswer.server.models import QuestionRequest
 from danswer.server.models import SearchDoc
 from danswer.utils.logger import setup_logger
@@ -67,14 +68,11 @@ def embed_query(
     prefix: str = ASYM_QUERY_PREFIX,
     normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
 ) -> list[float]:
-    model = embedding_model or get_default_embedding_model()
+    model = embedding_model or EmbeddingModel()
     prefixed_query = prefix + query
     query_embedding = model.encode(
-        prefixed_query, normalize_embeddings=normalize_embeddings
-    )
-
-    if not isinstance(query_embedding, list):
-        query_embedding = query_embedding.tolist()
+        [prefixed_query], normalize_embeddings=normalize_embeddings
+    )[0]
 
     return query_embedding
@@ -104,6 +102,31 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc
     return search_docs
 
 
+@log_function_time()
+def doc_index_retrieval(
+    query: SearchQuery, document_index: DocumentIndex
+) -> list[InferenceChunk]:
+    if query.search_type == SearchType.KEYWORD:
+        top_chunks = document_index.keyword_retrieval(
+            query.query, query.filters, query.favor_recent, query.num_hits
+        )
+
+    elif query.search_type == SearchType.SEMANTIC:
+        top_chunks = document_index.semantic_retrieval(
+            query.query, query.filters, query.favor_recent, query.num_hits
+        )
+
+    elif query.search_type == SearchType.HYBRID:
+        top_chunks = document_index.hybrid_retrieval(
+            query.query, query.filters, query.favor_recent, query.num_hits
+        )
+
+    else:
+        raise RuntimeError("Invalid Search Flow")
+
+    return top_chunks
+
+
 @log_function_time()
 def semantic_reranking(
@@ -112,13 +135,13 @@ def semantic_reranking(
     model_min: int = CROSS_ENCODER_RANGE_MIN,
     model_max: int = CROSS_ENCODER_RANGE_MAX,
 ) -> list[InferenceChunk]:
-    cross_encoders = get_default_reranking_model_ensemble()
-    sim_scores = [
-        encoder.predict([(query, chunk.content) for chunk in chunks])  # type: ignore
-        for encoder in cross_encoders
-    ]
+    cross_encoders = CrossEncoderEnsembleModel()
+    passages = [chunk.content for chunk in chunks]
+    sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
 
-    raw_sim_scores = sum(sim_scores) / len(sim_scores)
+    sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
+
+    raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))
 
     cross_models_min = numpy.min(sim_scores)
@@ -270,23 +293,7 @@ def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None
         ]
         logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
 
-    if query.search_type == SearchType.KEYWORD:
-        top_chunks = document_index.keyword_retrieval(
-            query.query, query.filters, query.favor_recent, query.num_hits
-        )
-
-    elif query.search_type == SearchType.SEMANTIC:
-        top_chunks = document_index.semantic_retrieval(
-            query.query, query.filters, query.favor_recent, query.num_hits
-        )
-
-    elif query.search_type == SearchType.HYBRID:
-        top_chunks = document_index.hybrid_retrieval(
-            query.query, query.filters, query.favor_recent, query.num_hits
-        )
-
-    else:
-        raise RuntimeError("Invalid Search Flow")
+    top_chunks = doc_index_retrieval(query=query, document_index=document_index)
 
     if not top_chunks:
         logger.info(
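The ensemble-averaging step in semantic_reranking above, sketched with toy numbers: one score array per cross-encoder, combined by an element-wise mean:

```python
import numpy

sim_scores = [numpy.array([0.9, 0.1, 0.4]), numpy.array([0.7, 0.3, 0.2])]
raw_sim_scores = sum(sim_scores) / len(sim_scores)  # array([0.8, 0.2, 0.3])
```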
diff --git a/backend/model_server/__init__.py b/backend/model_server/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backend/model_server/custom_models.py b/backend/model_server/custom_models.py
new file mode 100644
index 00000000000..fd2f3c964f9
--- /dev/null
+++ b/backend/model_server/custom_models.py
@@ -0,0 +1,40 @@
+import numpy as np
+import tensorflow as tf  # type:ignore
+from fastapi import APIRouter
+
+from danswer.search.search_nlp_models import get_intent_model_tokenizer
+from danswer.search.search_nlp_models import get_local_intent_model
+from danswer.utils.timing import log_function_time
+from shared_models.model_server_models import IntentRequest
+from shared_models.model_server_models import IntentResponse
+
+router = APIRouter(prefix="/custom")
+
+
+@log_function_time()
+def classify_intent(query: str) -> list[float]:
+    tokenizer = get_intent_model_tokenizer()
+    intent_model = get_local_intent_model()
+    model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
+
+    predictions = intent_model(model_input)[0]
+    probabilities = tf.nn.softmax(predictions, axis=-1)
+
+    class_percentages = np.round(probabilities.numpy() * 100, 2)
+    return list(class_percentages.tolist()[0])
+
+
+@router.post("/intent-model")
+def process_intent_request(
+    intent_request: IntentRequest,
+) -> IntentResponse:
+    class_percentages = classify_intent(intent_request.query)
+    return IntentResponse(class_probs=class_percentages)
+
+
+def warm_up_intent_model() -> None:
+    intent_tokenizer = get_intent_model_tokenizer()
+    inputs = intent_tokenizer(
+        "danswer", return_tensors="tf", truncation=True, padding=True
+    )
+    get_local_intent_model()(inputs)
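Because the route handler above is a plain function, the intent model can be smoke-tested in-process without HTTP. A sketch (weights are fetched on first call):

```python
from model_server.custom_models import process_intent_request
from shared_models.model_server_models import IntentRequest

response = process_intent_request(IntentRequest(query="What is Danswer?"))
keyword_pct, semantic_pct, qa_pct = response.class_probs
```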
diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py
new file mode 100644
index 00000000000..d7d7bdc6d81
--- /dev/null
+++ b/backend/model_server/encoders.py
@@ -0,0 +1,81 @@
+from fastapi import APIRouter
+from fastapi import HTTPException
+
+from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
+from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
+from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
+from danswer.search.search_nlp_models import get_local_embedding_model
+from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
+from danswer.utils.logger import setup_logger
+from danswer.utils.timing import log_function_time
+from shared_models.model_server_models import EmbedRequest
+from shared_models.model_server_models import EmbedResponse
+from shared_models.model_server_models import RerankRequest
+from shared_models.model_server_models import RerankResponse
+
+logger = setup_logger()
+
+WARM_UP_STRING = "Danswer is amazing"
+
+router = APIRouter(prefix="/encoder")
+
+
+@log_function_time()
+def embed_text(
+    texts: list[str],
+    normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
+) -> list[list[float]]:
+    model = get_local_embedding_model()
+    embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
+
+    if not isinstance(embeddings, list):
+        embeddings = embeddings.tolist()
+
+    return embeddings
+
+
+@log_function_time()
+def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
+    cross_encoders = get_local_reranking_model_ensemble()
+    sim_scores = [
+        encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore
+        for encoder in cross_encoders
+    ]
+    return sim_scores
+
+
+@router.post("/bi-encoder-embed")
+def process_embed_request(
+    embed_request: EmbedRequest,
+) -> EmbedResponse:
+    try:
+        embeddings = embed_text(texts=embed_request.texts)
+        return EmbedResponse(embeddings=embeddings)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.post("/cross-encoder-scores")
+def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
+    try:
+        sim_scores = calc_sim_scores(
+            query=embed_request.query, docs=embed_request.documents
+        )
+        return RerankResponse(scores=sim_scores)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+def warm_up_bi_encoder() -> None:
+    logger.info(f"Warming up Bi-Encoders: {DOCUMENT_ENCODER_MODEL}")
+    get_local_embedding_model().encode(WARM_UP_STRING)
+
+
+def warm_up_cross_encoders() -> None:
+    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
+
+    cross_encoders = get_local_reranking_model_ensemble()
+    [
+        cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
+        for cross_encoder in cross_encoders
+    ]
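Failure-mode sketch for the endpoints above: any exception in a handler is converted to an HTTP 500, which the client wrappers turn back into a raised exception via raise_for_status(). Against a hypothetical local server:

```python
import requests

response = requests.post(
    "http://localhost:9000/encoder/cross-encoder-scores",
    json={"query": "q", "documents": ["doc one", "doc two"]},
)
response.raise_for_status()  # raises requests.HTTPError if the server returned 500
scores = response.json()["scores"]  # one list of scores per cross-encoder
```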
diff --git a/backend/model_server/main.py b/backend/model_server/main.py
new file mode 100644
index 00000000000..3b7ed5747c6
--- /dev/null
+++ b/backend/model_server/main.py
@@ -0,0 +1,51 @@
+import torch
+import uvicorn
+from fastapi import FastAPI
+
+from danswer import __version__
+from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
+from danswer.configs.app_configs import MODEL_SERVER_PORT
+from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
+from danswer.utils.logger import setup_logger
+from model_server.custom_models import router as custom_models_router
+from model_server.custom_models import warm_up_intent_model
+from model_server.encoders import router as encoders_router
+from model_server.encoders import warm_up_bi_encoder
+from model_server.encoders import warm_up_cross_encoders
+
+
+logger = setup_logger()
+
+
+def get_model_app() -> FastAPI:
+    application = FastAPI(title="Danswer Model Server", version=__version__)
+
+    application.include_router(encoders_router)
+    application.include_router(custom_models_router)
+
+    @application.on_event("startup")
+    def startup_event() -> None:
+        if torch.cuda.is_available():
+            logger.info("GPU is available")
+        else:
+            logger.info("GPU is not available")
+
+        torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
+        logger.info(f"Torch Threads: {torch.get_num_threads()}")
+
+        warm_up_bi_encoder()
+        warm_up_cross_encoders()
+        warm_up_intent_model()
+
+    return application
+
+
+app = get_model_app()
+
+
+if __name__ == "__main__":
+    logger.info(
+        f"Starting Danswer Model Server on http://{MODEL_SERVER_ALLOWED_HOST}:{str(MODEL_SERVER_PORT)}/"
+    )
+    logger.info(f"Model Server Version: {__version__}")
+    uvicorn.run(app, host=MODEL_SERVER_ALLOWED_HOST, port=MODEL_SERVER_PORT)
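A quick in-process smoke test sketch using FastAPI's TestClient (assumes httpx, the TestClient dependency, is installed; entering the context manager runs the startup warm-up above, which downloads all model weights):

```python
from fastapi.testclient import TestClient

from model_server.main import app

with TestClient(app) as client:
    resp = client.post("/encoder/bi-encoder-embed", json={"texts": ["hello"]})
    assert resp.status_code == 200
```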
diff --git a/backend/requirements/model_server.txt b/backend/requirements/model_server.txt
new file mode 100644
index 00000000000..1b151a25b72
--- /dev/null
+++ b/backend/requirements/model_server.txt
@@ -0,0 +1,8 @@
+fastapi==0.103.0
+pydantic==1.10.7
+safetensors==0.3.1
+sentence-transformers==2.2.2
+tensorflow==2.13.0
+torch==2.0.1
+transformers==4.30.1
+uvicorn==0.21.1
diff --git a/backend/shared_models/__init__.py b/backend/shared_models/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backend/shared_models/model_server_models.py b/backend/shared_models/model_server_models.py
new file mode 100644
index 00000000000..263b2b1f5ea
--- /dev/null
+++ b/backend/shared_models/model_server_models.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel
+
+
+class EmbedRequest(BaseModel):
+    texts: list[str]
+
+
+class EmbedResponse(BaseModel):
+    embeddings: list[list[float]]
+
+
+class RerankRequest(BaseModel):
+    query: str
+    documents: list[str]
+
+
+class RerankResponse(BaseModel):
+    scores: list[list[float]]
+
+
+class IntentRequest(BaseModel):
+    query: str
+
+
+class IntentResponse(BaseModel):
+    class_probs: list[float]
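The same pydantic schemas serve both sides of the wire. A client-side round-trip sketch against a hypothetical server on localhost:9000 (pydantic v1 per the pinned requirements, hence .dict()):

```python
import requests

from shared_models.model_server_models import EmbedRequest, EmbedResponse

response = requests.post(
    "http://localhost:9000/encoder/bi-encoder-embed",
    json=EmbedRequest(texts=["Danswer is amazing"]).dict(),
)
response.raise_for_status()
embeddings = EmbedResponse(**response.json()).embeddings
```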
diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml
index cf455ae2c05..8f9f6b22185 100644
--- a/deployment/docker_compose/docker-compose.dev.yml
+++ b/deployment/docker_compose/docker-compose.dev.yml
@@ -42,6 +42,8 @@ services:
       - SKIP_RERANKING=${SKIP_RERANKING:-}
       - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
       - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
+      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
+      - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
       # Set to debug to get more fine-grained logs
       - LOG_LEVEL=${LOG_LEVEL:-info}
     volumes:
@@ -94,6 +96,8 @@ services:
       - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
       - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
       - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
+      - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
       # Set to debug to get more fine-grained logs
       - LOG_LEVEL=${LOG_LEVEL:-info}
     volumes:
@@ -157,6 +161,25 @@ services:
       /bin/sh -c "sleep 10 &&
         envsubst '$$\{DOMAIN\}' < /etc/nginx/conf.d/app.conf.template.dev > /etc/nginx/conf.d/app.conf &&
         while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\""
+  # Run with --profile model-server to bring up the danswer-model-server container
+  model_server:
+    image: danswer/danswer-model-server:latest
+    build:
+      context: ../../backend
+      dockerfile: Dockerfile.model_server
+    profiles:
+      - "model-server"
+    command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
+    restart: always
+    environment:
+      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
+      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
+      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
+    volumes:
+      - model_cache_torch:/root/.cache/torch/
+      - model_cache_huggingface:/root/.cache/huggingface/
 volumes:
   local_dynamic_storage:
   file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them
diff --git a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml
index 6a892f35c76..43f9f7adacb 100644
--- a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml
+++ b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml
@@ -101,6 +101,25 @@ services:
         while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\""
     env_file:
       - .env.nginx
+  # Run with --profile model-server to bring up the danswer-model-server container
+  model_server:
+    image: danswer/danswer-model-server:latest
+    build:
+      context: ../../backend
+      dockerfile: Dockerfile.model_server
+    profiles:
+      - "model-server"
+    command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
+    restart: always
+    environment:
+      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
+      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
+      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
+    volumes:
+      - model_cache_torch:/root/.cache/torch/
+      - model_cache_huggingface:/root/.cache/huggingface/
 volumes:
   local_dynamic_storage:
   file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them
diff --git a/deployment/docker_compose/docker-compose.prod.yml b/deployment/docker_compose/docker-compose.prod.yml
index 973effbc105..1caf7a52328 100644
--- a/deployment/docker_compose/docker-compose.prod.yml
+++ b/deployment/docker_compose/docker-compose.prod.yml
@@ -110,6 +110,25 @@ services:
       - ../data/certbot/conf:/etc/letsencrypt
       - ../data/certbot/www:/var/www/certbot
     entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'"
+  # Run with --profile model-server to bring up the danswer-model-server container
+  model_server:
+    image: danswer/danswer-model-server:latest
+    build:
+      context: ../../backend
+      dockerfile: Dockerfile.model_server
+    profiles:
+      - "model-server"
+    command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
+    restart: always
+    environment:
+      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
+      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
+      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
+    volumes:
+      - model_cache_torch:/root/.cache/torch/
+      - model_cache_huggingface:/root/.cache/huggingface/
 volumes:
   local_dynamic_storage:
   file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them