Reenable option to run Danswer without Gen AI (#906)
yuhongsun96 authored Jan 4, 2024
1 parent 20441df commit 6b6b3da
Showing 20 changed files with 181 additions and 43 deletions.
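The thread running through every file below: a DISABLE_GENERATIVE_AI environment flag makes get_default_llm() raise GenAIDisabledException, and each call site catches the exception and degrades to search-only behavior instead of failing. A minimal sketch of the pattern, with simplified stand-ins for the real factory and LLM (the actual implementations are in the diff below):

import os


DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"


class GenAIDisabledException(Exception):
    def __init__(self, message: str = "Generative AI has been turned off") -> None:
        super().__init__(message)


def get_default_llm() -> object:
    # The factory refuses to construct an LLM when the admin has disabled Gen AI
    if DISABLE_GENERATIVE_AI:
        raise GenAIDisabledException()
    return object()  # placeholder for the configured LLM instance


# Call-site pattern used throughout the commit: catch the exception and fall back
# (set llm to None, return a sensible default, or stream a "Gen AI disabled" message)
try:
    llm = get_default_llm()
except GenAIDisabledException:
    llm = None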
18 changes: 16 additions & 2 deletions backend/danswer/chat/process_message.py
@@ -21,6 +21,7 @@
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import CHUNK_SIZE
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import DISABLED_GEN_AI_MSG
from danswer.configs.constants import MessageType
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -36,6 +37,7 @@
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_token_encode
@@ -61,10 +63,18 @@ def generate_ai_chat_response(
history: list[ChatMessage],
context_docs: list[LlmDoc],
doc_id_to_rank_map: dict[str, int],
llm: LLM,
llm: LLM | None,
llm_tokenizer: Callable,
all_doc_useful: bool,
) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
if llm is None:
try:
llm = get_default_llm()
except GenAIDisabledException:
# Not an error if it's a user configuration
yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
return

if query_message.prompt is None:
raise RuntimeError("No prompt received for generating Gen AI answer.")

@@ -171,7 +181,11 @@ def stream_chat_message(
"Must specify a set of documents for chat or specify search options"
)

llm = get_default_llm()
try:
llm = get_default_llm()
except GenAIDisabledException:
llm = None

llm_tokenizer = get_default_llm_token_encode()
document_index = get_default_document_index()

1 change: 0 additions & 1 deletion backend/danswer/configs/app_configs.py
@@ -20,7 +20,6 @@
#####
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day
# CURRENTLY DOES NOT FULLY WORK, DON'T USE THIS
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"


8 changes: 8 additions & 0 deletions backend/danswer/configs/constants.py
@@ -47,6 +47,14 @@
INDEX_SEPARATOR = "==="


# Messages
DISABLED_GEN_AI_MSG = (
"Your System Admin has disabled the Generative AI functionalities of Danswer.\n"
"Please contact them if you wish to have this enabled.\n"
"You can still use Danswer as a search engine."
)


class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
INGESTION_API = "ingestion_api"
9 changes: 8 additions & 1 deletion backend/danswer/danswerbot/slack/blocks.py
@@ -12,6 +12,7 @@
from slack_sdk.models.blocks import SectionBlock

from danswer.chat.models import DanswerQuote
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
@@ -106,8 +107,11 @@ def build_documents_blocks(
message_id: int | None,
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
) -> list[Block]:
header_text = (
"Retrieved Documents" if DISABLE_GENERATIVE_AI else "Reference Documents"
)
seen_docs_identifiers = set()
section_blocks: list[Block] = [HeaderBlock(text="Reference Documents")]
section_blocks: list[Block] = [HeaderBlock(text=header_text)]
included_docs = 0
for rank, d in enumerate(documents):
if d.document_id in seen_docs_identifiers:
@@ -208,6 +212,9 @@ def build_qa_response_blocks(
favor_recent: bool,
skip_quotes: bool = False,
) -> list[Block]:
if DISABLE_GENERATIVE_AI:
return []

quotes_blocks: list[Block] = []

ai_answer_header = HeaderBlock(text="AI Answer")
4 changes: 4 additions & 0 deletions backend/danswer/llm/exceptions.py
@@ -0,0 +1,4 @@
class GenAIDisabledException(Exception):
def __init__(self, message: str = "Generative AI has been turned off") -> None:
self.message = message
super().__init__(self.message)
5 changes: 5 additions & 0 deletions backend/danswer/llm/factory.py
@@ -1,9 +1,11 @@
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.custom_llm import CustomModelServer
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.gpt_4_all import DanswerGPT4All
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_gen_ai_api_key
@@ -18,6 +20,9 @@ def get_default_llm(
) -> LLM:
"""A single place to fetch the configured LLM for Danswer
Also allows overriding certain LLM defaults"""
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()

if gen_ai_model_version_override:
model_version = gen_ai_model_version_override
else:
6 changes: 3 additions & 3 deletions backend/danswer/main.py
@@ -231,6 +231,9 @@ def startup_event() -> None:
if GEN_AI_API_ENDPOINT:
logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")

# Any additional model configs logged here
get_default_llm().log_model_configs()

if MULTILINGUAL_QUERY_EXPANSION:
logger.info(
f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
@@ -258,9 +261,6 @@ def startup_event() -> None:
logger.info("GPU is not available")
logger.info(f"Torch Threads: {torch.get_num_threads()}")

# This is for the LLM, most LLMs will not need warming up
get_default_llm().log_model_configs()

logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
21 changes: 15 additions & 6 deletions backend/danswer/one_shot_answer/answer_question.py
@@ -29,6 +29,7 @@
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
from danswer.one_shot_answer.qa_block import no_gen_ai_response
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
@@ -191,8 +192,12 @@ def stream_answer_objects(
llm_version=llm_override,
)

full_prompt_str = qa_model.build_prompt(
query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
full_prompt_str = (
qa_model.build_prompt(
query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
)
if qa_model is not None
else "Gen AI Disabled"
)

# Create the first User query message
@@ -207,10 +212,14 @@
commit=True,
)

response_packets = qa_model.answer_question_stream(
prompt=full_prompt_str,
llm_context_docs=llm_chunks,
metrics_callback=llm_metrics_callback,
response_packets = (
qa_model.answer_question_stream(
prompt=full_prompt_str,
llm_context_docs=llm_chunks,
metrics_callback=llm_metrics_callback,
)
if qa_model is not None
else no_gen_ai_response()
)

# Capture outputs and errors
16 changes: 10 additions & 6 deletions backend/danswer/one_shot_answer/factory.py
@@ -1,6 +1,7 @@
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.db.models import Prompt
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.one_shot_answer.interfaces import QAModel
from danswer.one_shot_answer.qa_block import QABlock
@@ -19,18 +20,21 @@ def get_question_answer_model(
chain_of_thought: bool = False,
llm_version: str | None = None,
qa_model_version: str | None = QA_PROMPT_OVERRIDE,
) -> QAModel:
) -> QAModel | None:
if chain_of_thought:
raise NotImplementedError("COT has been disabled")

system_prompt = prompt.system_prompt if prompt is not None else None
task_prompt = prompt.task_prompt if prompt is not None else None

llm = get_default_llm(
api_key=api_key,
timeout=timeout,
gen_ai_model_version_override=llm_version,
)
try:
llm = get_default_llm(
api_key=api_key,
timeout=timeout,
gen_ai_model_version_override=llm_version,
)
except GenAIDisabledException:
return None

if qa_model_version == "weak":
qa_handler: QAHandler = WeakLLMQAHandler(
5 changes: 5 additions & 0 deletions backend/danswer/one_shot_answer/qa_block.py
@@ -13,6 +13,7 @@
from danswer.chat.models import LLMMetricsContainer
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.constants import DISABLED_GEN_AI_MSG
from danswer.indexing.models import InferenceChunk
from danswer.llm.interfaces import LLM
from danswer.llm.utils import check_number_of_tokens
@@ -252,6 +253,10 @@ def build_dummy_prompt(
).strip()


def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]:
yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)


class QABlock(QAModel):
def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
self._llm = llm
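The no_gen_ai_response helper added above keeps the streaming contract of a real answer, so consumers of the generator need no special casing when Gen AI is off. A rough, self-contained illustration of draining it, with DanswerAnswerPiece approximated by a stand-in dataclass (the real model lives in danswer.chat.models):

from collections.abc import Iterator
from dataclasses import dataclass


@dataclass
class DanswerAnswerPiece:  # stand-in for danswer.chat.models.DanswerAnswerPiece
    answer_piece: str


DISABLED_GEN_AI_MSG = "Your System Admin has disabled the Generative AI functionalities of Danswer."


def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]:
    yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)


# The same loop that streams a real LLM answer also handles the disabled-AI fallback
answer_text = "".join(packet.answer_piece for packet in no_gen_ai_response())
print(answer_text)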
8 changes: 7 additions & 1 deletion backend/danswer/secondary_llm_flows/answer_validation.py
@@ -1,3 +1,4 @@
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.prompts.answer_validation import ANSWER_VALIDITY_PROMPT
@@ -41,12 +42,17 @@ def _extract_validity(model_output: str) -> bool:
return False
return True # If something is wrong, let's not toss away the answer

try:
llm = get_default_llm()
except GenAIDisabledException:
return True

if not answer:
return False

messages = _get_answer_validation_messages(query, answer)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = get_default_llm().invoke(filled_llm_prompt)
model_output = llm.invoke(filled_llm_prompt)
logger.debug(model_output)

validity = _extract_validity(model_output)
8 changes: 7 additions & 1 deletion backend/danswer/secondary_llm_flows/chat_session_naming.py
@@ -1,5 +1,6 @@
from danswer.chat.chat_utils import combine_message_chain
from danswer.db.models import ChatMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -23,7 +24,12 @@ def get_chat_rename_messages(history_str: str) -> list[dict[str, str]]:
return messages

if llm is None:
llm = get_default_llm()
try:
llm = get_default_llm()
except GenAIDisabledException:
# This may be longer than what the LLM tends to produce but is the clearest
# thing we can do
return full_history[0].message

history_str = combine_message_chain(full_history)

Expand Down
12 changes: 9 additions & 3 deletions backend/danswer/secondary_llm_flows/choose_search.py
@@ -5,6 +5,7 @@
from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.db.models import ChatMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -68,15 +69,20 @@ def _get_search_messages(
if disable_llm_check:
return True

if llm is None:
try:
llm = get_default_llm()
except GenAIDisabledException:
# If Generative AI is turned off, always run Search as Danswer is being used
# as just a search engine
return True

history_str = combine_message_chain(history)

prompt_msgs = _get_search_messages(
question=query_message.message, history_str=history_str
)

if llm is None:
llm = get_default_llm()

filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
require_search_output = llm.invoke(filled_llm_prompt)

13 changes: 10 additions & 3 deletions backend/danswer/secondary_llm_flows/chunk_usefulness.py
@@ -1,5 +1,6 @@
from collections.abc import Callable

from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT
@@ -30,14 +31,20 @@ def _extract_usefulness(model_output: str) -> bool:
return False
return True

# If Gen AI is disabled, none of the messages are more "useful" than any other
# All are marked not useful (False) so that the "Gen AI liked this answer" icon
# is not shown for any result
try:
llm = get_default_llm(use_fast_llm=True, timeout=5)
except GenAIDisabledException:
return False

messages = _get_usefulness_messages()
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
# When running in a batch, it takes as long as the longest thread
# And when running a large batch, one may fail and take the whole timeout
# instead cap it to 5 seconds
model_output = get_default_llm(use_fast_llm=True, timeout=5).invoke(
filled_llm_prompt
)
model_output = llm.invoke(filled_llm_prompt)
logger.debug(model_output)

return _extract_usefulness(model_output)