From e1e048f8c7b292f73061fa76ab4a3f4a102d6216 Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Tue, 31 Oct 2023 16:02:45 -0700
Subject: [PATCH] cleaned up somewhat

---
 backend/danswer/configs/constants.py        | 26 -----------
 backend/danswer/configs/model_configs.py    | 45 +++++++------------
 backend/danswer/direct_qa/llm_utils.py      | 23 +++-------
 backend/danswer/direct_qa/qa_block.py       |  2 +-
 backend/danswer/llm/build.py                | 41 ++++-------------
 backend/danswer/llm/multi_llm.py            | 42 ++++++++++++-----
 backend/danswer/main.py                     | 14 +++---
 backend/danswer/server/manage.py            |  2 +-
 .../docker_compose/docker-compose.dev.yml   | 14 +-----
 deployment/docker_compose/env.prod.template | 11 +----
 10 files changed, 75 insertions(+), 145 deletions(-)

diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py
index dd22608ed97..24886c1e7ef 100644
--- a/backend/danswer/configs/constants.py
+++ b/backend/danswer/configs/constants.py
@@ -77,21 +77,6 @@ class DocumentIndexType(str, Enum):
     SPLIT = "split"  # Typesense + Qdrant
 
 
-class DanswerGenAIModel(str, Enum):
-    """This represents the internal Danswer GenAI model which determines the class that is used
-    to generate responses to the user query. Different models/services require different internal
-    handling, this allows for modularity of implementation within Danswer"""
-
-    OPENAI = "openai-completion"
-    OPENAI_CHAT = "openai-chat-completion"
-    GPT4ALL = "gpt4all-completion"
-    GPT4ALL_CHAT = "gpt4all-chat-completion"
-    HUGGINGFACE = "huggingface-client-completion"
-    HUGGINGFACE_CHAT = "huggingface-client-chat-completion"
-    REQUEST = "request-completion"
-    TRANSFORMERS = "transformers"
-
-
 class AuthType(str, Enum):
     DISABLED = "disabled"
     BASIC = "basic"
@@ -100,17 +85,6 @@ class AuthType(str, Enum):
     SAML = "saml"
 
 
-class ModelHostType(str, Enum):
-    """For GenAI models interfaced via requests, different services have different
-    expectations for what fields are included in the request"""
-
-    # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
-    HUGGINGFACE = "huggingface"  # HuggingFace test-generation Inference API
-    # https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
-    COLAB_DEMO = "colab-demo"
-    # TODO support for Azure, AWS, GCP GenAI model hosting
-
-
 class QAFeedbackType(str, Enum):
     LIKE = "like"  # User likes the answer, used for metrics
     DISLIKE = "dislike"  # User dislikes the answer, used for metrics

diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index 7088c12e42f..6e2563bdb9e 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -1,9 +1,5 @@
 import os
 
-from danswer.configs.constants import DanswerGenAIModel
-from danswer.configs.constants import ModelHostType
-
-
 #####
 # Embedding/Reranking Model Configs
 #####
@@ -62,36 +58,29 @@
 
 #####
 # Generative AI Model Configs
 #####
-# Sets the internal Danswer model class to use
-INTERNAL_MODEL_VERSION = os.environ.get(
-    "INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value
-)
-# If the Generative AI model requires an API key for access, otherwise can leave blank
-GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
+# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default,
+# be sure to use one that is LiteLLM compatible:
+# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables
+# The provider is the prefix before / in the model argument
+
+# Additionally Danswer supports GPT4All and custom request library based models
+# TODO re-enable GPT4ALL and request models
+GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
+GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
 
-# TODO fix this
-GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo")
+# If the Generative AI model requires an API key for access, otherwise can leave blank
+GEN_AI_API_KEY = (
+    os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None
+)
 
-# If the Generative Model is hosted to accept requests (DanswerGenAIModel.REQUEST) then
-# set the two below to specify
-# - Where to hit the endpoint
-# - How should the request be formed
-GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
-GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)
+# API Base, such as (for Azure): https://danswer.openai.azure.com/
+GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT") or None
+# API Version, such as (for Azure): 2023-09-15-preview
+GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
 
 # Set this to be enough for an answer + quotes. Also used for Chat
 GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
 # This next restriction is only used for chat ATM, used to expire old messages as needed
 GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
 GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)
-
-#####
-# OpenAI Azure
-#####
-# TODO CHECK ALL THESE, MAKE SURE STILL USEFUL
-API_BASE_OPENAI = os.environ.get("API_BASE_OPENAI", "")
-API_TYPE_OPENAI = os.environ.get("API_TYPE_OPENAI", "").lower()
-API_VERSION_OPENAI = os.environ.get("API_VERSION_OPENAI", "")
-# Deployment ID used interchangeably with "engine" parameter
-AZURE_DEPLOYMENT_ID = os.environ.get("AZURE_DEPLOYMENT_ID", "")

diff --git a/backend/danswer/direct_qa/llm_utils.py b/backend/danswer/direct_qa/llm_utils.py
index 091d6575798..8014e06d9a2 100644
--- a/backend/danswer/direct_qa/llm_utils.py
+++ b/backend/danswer/direct_qa/llm_utils.py
@@ -3,13 +3,10 @@
 from openai.error import AuthenticationError
 
 from danswer.configs.app_configs import QA_TIMEOUT
-from danswer.configs.constants import DanswerGenAIModel
 from danswer.configs.model_configs import GEN_AI_API_KEY
-from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_block import QABlock
 from danswer.direct_qa.qa_block import QAHandler
-from danswer.direct_qa.qa_block import SimpleChatQAHandler
 from danswer.direct_qa.qa_block import SingleMessageQAHandler
 from danswer.direct_qa.qa_block import SingleMessageScratchpadHandler
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
@@ -39,20 +36,16 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
     return False
 
 
-def get_default_qa_handler(model: str, real_time_flow: bool = True) -> QAHandler:
-    # TODO update this
-    if model == DanswerGenAIModel.OPENAI_CHAT.value:
-        return (
-            SingleMessageQAHandler()
-            if real_time_flow
-            else SingleMessageScratchpadHandler()
-        )
+# TODO introduce the prompt choice parameter
+def get_default_qa_handler(real_time_flow: bool = True) -> QAHandler:
+    return (
+        SingleMessageQAHandler() if real_time_flow else SingleMessageScratchpadHandler()
+    )
 
-    return SimpleChatQAHandler()
+    # return SimpleChatQAHandler()
 
 
 def get_default_qa_model(
-    internal_model: str = INTERNAL_MODEL_VERSION,
     api_key: str | None = GEN_AI_API_KEY,
     timeout: int = QA_TIMEOUT,
     real_time_flow: bool = True,
@@ -67,9 +60,7 @@ def get_default_qa_model(
     # un-used arguments will be ignored by the underlying `LLM` class
     # if any args are missing, a `TypeError` will be thrown
     llm = get_default_llm(timeout=timeout)
-    qa_handler = get_default_qa_handler(
-        model=internal_model, real_time_flow=real_time_flow
-    )
+    qa_handler = get_default_qa_handler(real_time_flow=real_time_flow)
 
     return QABlock(
         llm=llm,

diff --git a/backend/danswer/direct_qa/qa_block.py b/backend/danswer/direct_qa/qa_block.py
index 4d9f5b3b00f..b96736fba5d 100644
--- a/backend/danswer/direct_qa/qa_block.py
+++ b/backend/danswer/direct_qa/qa_block.py
@@ -28,7 +28,7 @@
 from danswer.direct_qa.qa_utils import process_answer
 from danswer.direct_qa.qa_utils import process_model_tokens
 from danswer.indexing.models import InferenceChunk
-from danswer.llm.llm import LLM
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import check_number_of_tokens
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import get_default_llm_tokenizer

diff --git a/backend/danswer/llm/build.py b/backend/danswer/llm/build.py
index 18259912cac..fa9032e094e 100644
--- a/backend/danswer/llm/build.py
+++ b/backend/danswer/llm/build.py
@@ -1,17 +1,5 @@
-from collections.abc import Mapping
-from typing import Any
-
 from danswer.configs.app_configs import QA_TIMEOUT
-from danswer.configs.constants import DanswerGenAIModel
-from danswer.configs.constants import ModelHostType
-from danswer.configs.model_configs import GEN_AI_ENDPOINT
-from danswer.configs.model_configs import GEN_AI_HOST_TYPE
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
-from danswer.configs.model_configs import GEN_AI_TEMPERATURE
-from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
-from danswer.llm.google_colab_demo import GoogleColabDemo
 from danswer.llm.interfaces import LLM
 from danswer.llm.multi_llm import DefaultMultiLLM
 
@@ -20,27 +8,16 @@ def get_default_llm(
     api_key: str | None = None,
     timeout: int = QA_TIMEOUT,
 ) -> LLM:
-    """NOTE: api_key/timeout must be a special args since we may want to check
-    if an API key is valid for the default model setup OR we may want to use the
-    default model with a different timeout specified."""
+    """A single place to fetch the configured LLM for Danswer
+    Also allows overriding certain LLM defaults"""
     if api_key is None:
         api_key = get_gen_ai_api_key()
 
-    model_args: Mapping[str, Any] = {
-        # provide a dummy key since LangChain will throw an exception if not
-        # given, which would prevent server startup
-        "api_key": api_key or "dummy_api_key",
-        "timeout": timeout,
-        "model_version": GEN_AI_MODEL_VERSION,
-        "endpoint": GEN_AI_ENDPOINT,
-        "max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS,
-        "temperature": GEN_AI_TEMPERATURE,
-    }
-
-    if (
-        INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
-        and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
-    ):
-        return GoogleColabDemo(**model_args)  # type: ignore
+    # TODO rework
+    # if (
+    #     INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
+    #     and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
+    # ):
+    #     return GoogleColabDemo(**model_args)  # type: ignore
 
-    return DefaultMultiLLM(**model_args)
+    return DefaultMultiLLM(api_key=api_key, timeout=timeout)

diff --git a/backend/danswer/llm/multi_llm.py b/backend/danswer/llm/multi_llm.py
index 9e64d8a8e17..b2cdb02b894 100644
--- a/backend/danswer/llm/multi_llm.py
+++ b/backend/danswer/llm/multi_llm.py
@@ -1,10 +1,11 @@
-from collections.abc import Mapping
-from typing import Any
-from typing import cast
-
 import litellm  # type:ignore
 from langchain.chat_models import ChatLiteLLM
 
+from danswer.configs.model_configs import GEN_AI_API_VERSION
+from danswer.configs.model_configs import GEN_AI_ENDPOINT
+from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.llm.interfaces import LangChainChatLLM
 from danswer.llm.utils import should_be_verbose
@@ -16,6 +17,22 @@
 litellm.telemetry = False
 
 
+def _get_model_str(
+    model_provider: str | None,
+    model_version: str | None,
+) -> str:
+    if model_provider and model_version:
+        return model_provider + "/" + model_version
+
+    if model_version:
+        # Litellm defaults to openai if no provider specified
+        # It's implicit so no need to specify here either
+        return model_version
+
+    # User specified something wrong, just use Danswer default
+    return GEN_AI_MODEL_VERSION
+
+
 class DefaultMultiLLM(LangChainChatLLM):
     """Uses Litellm library to allow easy configuration to use a multitude of LLMs
     See https://python.langchain.com/docs/integrations/chat/litellm"""
@@ -27,22 +44,23 @@ class DefaultMultiLLM(LangChainChatLLM):
 
     def __init__(
         self,
-        api_key: str,
-        max_output_tokens: int,
+        api_key: str | None,
         timeout: int,
-        model_version: str,
+        model_provider: str | None = GEN_AI_MODEL_PROVIDER,
+        model_version: str | None = GEN_AI_MODEL_VERSION,
+        api_base: str | None = GEN_AI_ENDPOINT,
+        api_version: str | None = GEN_AI_API_VERSION,
+        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
         temperature: float = GEN_AI_TEMPERATURE,
-        *args: list[Any],
-        **kwargs: Mapping[str, Any]
     ):
         # Litellm Langchain integration currently doesn't take in the api key param
         # Can place this in the call below once integration is in
         litellm.api_key = api_key
+        litellm.api_version = api_version
 
         self._llm = ChatLiteLLM(  # type: ignore
-            model=model_version,
-            # Prefer using None which is the default value, endpoint could be empty string
-            api_base=cast(str, kwargs.get("endpoint")) or None,
+            model=_get_model_str(model_provider, model_version),
+            api_base=api_base,
             max_tokens=max_output_tokens,
             temperature=temperature,
             request_timeout=timeout,

diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index cdb621709a0..eb67c3f1d77 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -22,13 +22,9 @@
 from danswer.configs.app_configs import SECRET
 from danswer.configs.app_configs import WEB_DOMAIN
 from danswer.configs.constants import AuthType
-from danswer.configs.model_configs import API_BASE_OPENAI
-from danswer.configs.model_configs import API_TYPE_OPENAI
 from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
-from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.configs.model_configs import SKIP_RERANKING
 from danswer.db.credentials import create_initial_public_credential
 from danswer.direct_qa.llm_utils import get_default_qa_model
@@ -155,10 +151,12 @@ def startup_event() -> None:
     if DISABLE_GENERATIVE_AI:
         logger.info("Generative AI Q&A disabled")
     else:
-        logger.info(f"Using Internal Model: {INTERNAL_MODEL_VERSION}")
-        logger.info(f"Actual LLM model version: {GEN_AI_MODEL_VERSION}")
-        if API_TYPE_OPENAI == "azure":
-            logger.info(f"Using Azure OpenAI with Endpoint: {API_BASE_OPENAI}")
+        pass
+        # TODO rework
+        # logger.info(f"Using Internal Model: {INTERNAL_MODEL_VERSION}")
+        # logger.info(f"Actual LLM model version: {GEN_AI_MODEL_VERSION}")
+        # if API_TYPE_OPENAI == "azure":
+        #     logger.info(f"Using Azure OpenAI with Endpoint: {API_BASE_OPENAI}")
 
     verify_auth = fetch_versioned_implementation(
         "danswer.auth.users", "verify_auth_setting"

diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py
index 4108d5aed94..c496da1c6e9 100644
--- a/backend/danswer/server/manage.py
+++ b/backend/danswer/server/manage.py
@@ -25,7 +25,7 @@
 from danswer.db.models import User
 from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
 from danswer.direct_qa.llm_utils import get_default_qa_model
-from danswer.direct_qa.open_ai import get_gen_ai_api_key
+from danswer.direct_qa.qa_utils import get_gen_ai_api_key
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.server.models import ApiKey

diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml
index ac2fd220fd6..e0951eac3fc 100644
--- a/deployment/docker_compose/docker-compose.dev.yml
+++ b/deployment/docker_compose/docker-compose.dev.yml
@@ -16,11 +16,10 @@ services:
     ports:
       - "8080:8080"
     environment:
-      - INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion}
+      - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
       - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
-      - GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
       - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
       - POSTGRES_HOST=relational_db
       - VESPA_HOST=index
@@ -30,10 +29,6 @@
       - GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
       - GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
       - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
-      - API_BASE_OPENAI=${API_BASE_OPENAI:-}
-      - API_TYPE_OPENAI=${API_TYPE_OPENAI:-}
-      - API_VERSION_OPENAI=${API_VERSION_OPENAI:-}
-      - AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-}
      - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
       - DISABLE_TIME_FILTER_EXTRACTION=${DISABLE_TIME_FILTER_EXTRACTION:-}
       # Don't change the NLP model configs unless you know what you're doing
@@ -63,17 +58,12 @@
       - index
     restart: always
     environment:
-      - INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion}
+      - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
       - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
-      - GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
       - POSTGRES_HOST=relational_db
       - VESPA_HOST=index
-      - API_BASE_OPENAI=${API_BASE_OPENAI:-}
-      - API_TYPE_OPENAI=${API_TYPE_OPENAI:-}
-      - API_VERSION_OPENAI=${API_VERSION_OPENAI:-}
-      - AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-}
       - NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
       # Connector Configs
       - CONTINUE_ON_CONNECTOR_FAILURE=${CONTINUE_ON_CONNECTOR_FAILURE:-}

diff --git a/deployment/docker_compose/env.prod.template b/deployment/docker_compose/env.prod.template
index a652d7e370c..0c5819da17c 100644
--- a/deployment/docker_compose/env.prod.template
+++ b/deployment/docker_compose/env.prod.template
@@ -6,16 +6,9 @@
 # Insert your OpenAI API key here If not provided here, UI will prompt on setup.
 # This env variable takes precedence over UI settings.
 GEN_AI_API_KEY=
-# Choose between "openai-chat-completion" and "openai-completion"
-INTERNAL_MODEL_VERSION=openai-chat-completion
-# Use a valid model for the choice above, consult https://platform.openai.com/docs/models/model-endpoint-compatibility
-GEN_AI_MODEL_VERSION=gpt-4
-# Neccessary environment variables for Azure OpenAI:
-API_BASE_OPENAI=
-API_TYPE_OPENAI=
-API_VERSION_OPENAI=
-AZURE_DEPLOYMENT_ID=
+GEN_AI_MODEL_PROVIDER=openai
+GEN_AI_MODEL_VERSION=gpt-4
 
 
 # Could be something like danswer.companyname.com
 WEB_DOMAIN=http://localhost:3000
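
Note on the new model-string convention: LiteLLM expects "<provider>/<model>" and treats a bare model name as an OpenAI model, which is exactly what _get_model_str in multi_llm.py implements. A minimal sketch of that resolution, using the same defaults this patch sets (the env values here are illustrative, not requirements):

# Sketch of the provider/version resolution performed by _get_model_str above.
import os

provider = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
version = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"

if provider and version:
    model_str = provider + "/" + version  # e.g. "openai/gpt-3.5-turbo"
elif version:
    # LiteLLM defaults to openai when only a model name is given
    model_str = version
else:
    # fall back to the Danswer default model version
    model_str = "gpt-3.5-turbo"

print(model_str)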
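
With the reworked constructor, callers normally pass only api_key and timeout and let everything else fall back to the GEN_AI_* configs, though each parameter can still be overridden. A usage sketch follows; the keyword names are taken from this diff, while the key values, endpoint, and deployment name are hypothetical placeholders:

from danswer.llm.multi_llm import DefaultMultiLLM

# Plain OpenAI setup; unspecified params fall back to the GEN_AI_* defaults.
llm = DefaultMultiLLM(
    api_key="sk-...",  # placeholder; None works for keyless providers
    timeout=30,
)

# Azure-style setup; "my-deployment" is a hypothetical deployment name that
# _get_model_str turns into the LiteLLM model string "azure/my-deployment".
azure_llm = DefaultMultiLLM(
    api_key="azure-key-placeholder",
    timeout=30,
    model_provider="azure",
    model_version="my-deployment",
    api_base="https://danswer.openai.azure.com/",
    api_version="2023-09-15-preview",
)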
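
Deployments that previously relied on the removed Azure variables get no documented migration in this patch; my reading of the diff is that they map onto the new GEN_AI_* variables roughly as below (a sketch only, with a hypothetical deployment name):

import os

# Former Azure-specific variables and their apparent replacements:
os.environ["GEN_AI_MODEL_PROVIDER"] = "azure"  # was API_TYPE_OPENAI=azure
os.environ["GEN_AI_MODEL_VERSION"] = "my-deployment"  # was AZURE_DEPLOYMENT_ID
os.environ["GEN_AI_ENDPOINT"] = "https://danswer.openai.azure.com/"  # was API_BASE_OPENAI
os.environ["GEN_AI_API_VERSION"] = "2023-09-15-preview"  # was API_VERSION_OPENAI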