cleaned up somewhat
yuhongsun96 committed Oct 31, 2023
1 parent 6624b18 commit e1e048f
Showing 10 changed files with 75 additions and 145 deletions.
26 changes: 0 additions & 26 deletions backend/danswer/configs/constants.py
@@ -77,21 +77,6 @@ class DocumentIndexType(str, Enum):
SPLIT = "split" # Typesense + Qdrant


class DanswerGenAIModel(str, Enum):
"""This represents the internal Danswer GenAI model which determines the class that is used
to generate responses to the user query. Different models/services require different internal
handling, this allows for modularity of implementation within Danswer"""

OPENAI = "openai-completion"
OPENAI_CHAT = "openai-chat-completion"
GPT4ALL = "gpt4all-completion"
GPT4ALL_CHAT = "gpt4all-chat-completion"
HUGGINGFACE = "huggingface-client-completion"
HUGGINGFACE_CHAT = "huggingface-client-chat-completion"
REQUEST = "request-completion"
TRANSFORMERS = "transformers"


class AuthType(str, Enum):
DISABLED = "disabled"
BASIC = "basic"
@@ -100,17 +85,6 @@ class AuthType(str, Enum):
SAML = "saml"


class ModelHostType(str, Enum):
"""For GenAI models interfaced via requests, different services have different
expectations for what fields are included in the request"""

# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
HUGGINGFACE = "huggingface" # HuggingFace text-generation Inference API
# https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
COLAB_DEMO = "colab-demo"
# TODO support for Azure, AWS, GCP GenAI model hosting


class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics
45 changes: 17 additions & 28 deletions backend/danswer/configs/model_configs.py
@@ -1,9 +1,5 @@
import os

from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType


#####
# Embedding/Reranking Model Configs
#####
@@ -62,36 +58,29 @@
#####
# Generative AI Model Configs
#####
# Sets the internal Danswer model class to use
INTERNAL_MODEL_VERSION = os.environ.get(
"INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value
)

# If the Generative AI model requires an API key for access, otherwise can leave blank
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default,
# be sure to use one that is LiteLLM compatible:
# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables
# The provider is the prefix before / in the model argument

# Additionally Danswer supports GPT4All and custom request library based models
# TODO re-enable GPT4ALL and request models
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"

# TODO fix this
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo")
# If the Generative AI model requires an API key for access, otherwise can leave blank
GEN_AI_API_KEY = (
os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None
)

# If the Generative Model is hosted to accept requests (DanswerGenAIModel.REQUEST) then
# set the two below to specify
# - Where to hit the endpoint
# - How should the request be formed
GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)
# API Base, such as (for Azure): https://danswer.openai.azure.com/
GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT") or None
# API Version, such as (for Azure): 2023-09-15-preview
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None

# Set this to be enough for an answer + quotes. Also used for Chat
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
# This next restriction is only used for chat ATM, used to expire old messages as needed
GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)

#####
# OpenAI Azure
#####
# TODO CHECK ALL THESE, MAKE SURE STILL USEFUL
API_BASE_OPENAI = os.environ.get("API_BASE_OPENAI", "")
API_TYPE_OPENAI = os.environ.get("API_TYPE_OPENAI", "").lower()
API_VERSION_OPENAI = os.environ.get("API_VERSION_OPENAI", "")
# Deployment ID used interchangeably with "engine" parameter
AZURE_DEPLOYMENT_ID = os.environ.get("AZURE_DEPLOYMENT_ID", "")
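
A side note on the two env-var lookup styles visible in this hunk: os.environ.get(key, default) and os.environ.get(key) or default behave differently when a variable is set but empty, which is exactly what docker-compose's ${VAR:-} expansion produces. A minimal standalone sketch (not Danswer code):

import os

os.environ["GEN_AI_MODEL_VERSION"] = ""  # what ${GEN_AI_MODEL_VERSION:-} yields in compose

# Two-argument get() falls back only when the key is missing entirely,
# so a present-but-empty value passes through unchanged:
print(os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo"))  # prints ""

# The `or` form also falls back on empty (falsy) strings, which is why
# the new config lines above prefer it:
print(os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo")  # prints "gpt-3.5-turbo"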
23 changes: 7 additions & 16 deletions backend/danswer/direct_qa/llm_utils.py
@@ -3,13 +3,10 @@
from openai.error import AuthenticationError

from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_block import QABlock
from danswer.direct_qa.qa_block import QAHandler
from danswer.direct_qa.qa_block import SimpleChatQAHandler
from danswer.direct_qa.qa_block import SingleMessageQAHandler
from danswer.direct_qa.qa_block import SingleMessageScratchpadHandler
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
@@ -39,20 +36,16 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
return False


def get_default_qa_handler(model: str, real_time_flow: bool = True) -> QAHandler:
# TODO update this
if model == DanswerGenAIModel.OPENAI_CHAT.value:
return (
SingleMessageQAHandler()
if real_time_flow
else SingleMessageScratchpadHandler()
)
# TODO introduce the prompt choice parameter
def get_default_qa_handler(real_time_flow: bool = True) -> QAHandler:
return (
SingleMessageQAHandler() if real_time_flow else SingleMessageScratchpadHandler()
)

return SimpleChatQAHandler()
# return SimpleChatQAHandler()


def get_default_qa_model(
internal_model: str = INTERNAL_MODEL_VERSION,
api_key: str | None = GEN_AI_API_KEY,
timeout: int = QA_TIMEOUT,
real_time_flow: bool = True,
@@ -67,9 +60,7 @@ def get_default_qa_model(
# un-used arguments will be ignored by the underlying `LLM` class
# if any args are missing, a `TypeError` will be thrown
llm = get_default_llm(timeout=timeout)
qa_handler = get_default_qa_handler(
model=internal_model, real_time_flow=real_time_flow
)
qa_handler = get_default_qa_handler(real_time_flow=real_time_flow)

return QABlock(
llm=llm,
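
After this change the handler no longer depends on the model name, only on real_time_flow. A hypothetical call site (argument values are illustrative, not from this commit):

from danswer.direct_qa.llm_utils import get_default_qa_handler, get_default_qa_model

# True -> SingleMessageQAHandler; False -> SingleMessageScratchpadHandler,
# per the function body above
handler = get_default_qa_handler(real_time_flow=True)

# The LLM itself is now chosen via the GEN_AI_MODEL_PROVIDER and
# GEN_AI_MODEL_VERSION env vars, so no internal_model argument is passed anymore:
qa_model = get_default_qa_model(real_time_flow=True)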
2 changes: 1 addition & 1 deletion backend/danswer/direct_qa/qa_block.py
@@ -28,7 +28,7 @@
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.indexing.models import InferenceChunk
from danswer.llm.llm import LLM
from danswer.llm.interfaces import LLM
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import get_default_llm_tokenizer
41 changes: 9 additions & 32 deletions backend/danswer/llm/build.py
@@ -1,17 +1,5 @@
from collections.abc import Mapping
from typing import Any

from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.llm.google_colab_demo import GoogleColabDemo
from danswer.llm.interfaces import LLM
from danswer.llm.multi_llm import DefaultMultiLLM

@@ -20,27 +8,16 @@ def get_default_llm(
api_key: str | None = None,
timeout: int = QA_TIMEOUT,
) -> LLM:
"""NOTE: api_key/timeout must be a special args since we may want to check
if an API key is valid for the default model setup OR we may want to use the
default model with a different timeout specified."""
"""A single place to fetch the configured LLM for Danswer
Also allows overriding certain LLM defaults"""
if api_key is None:
api_key = get_gen_ai_api_key()

model_args: Mapping[str, Any] = {
# provide a dummy key since LangChain will throw an exception if not
# given, which would prevent server startup
"api_key": api_key or "dummy_api_key",
"timeout": timeout,
"model_version": GEN_AI_MODEL_VERSION,
"endpoint": GEN_AI_ENDPOINT,
"max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS,
"temperature": GEN_AI_TEMPERATURE,
}

if (
INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
):
return GoogleColabDemo(**model_args) # type: ignore
# TODO rework
# if (
# INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
# and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
# ):
# return GoogleColabDemo(**model_args) # type: ignore

return DefaultMultiLLM(**model_args)
return DefaultMultiLLM(api_key=api_key, timeout=timeout)
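
With the request-model branching commented out, get_default_llm is a thin wrapper. A hypothetical call (values illustrative):

from danswer.llm.build import get_default_llm

# api_key=None triggers the get_gen_ai_api_key() fallback shown above; provider,
# model version, endpoint, and temperature all come from DefaultMultiLLM's
# env-driven defaults.
llm = get_default_llm(timeout=30)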
42 changes: 30 additions & 12 deletions backend/danswer/llm/multi_llm.py
@@ -1,10 +1,11 @@
from collections.abc import Mapping
from typing import Any
from typing import cast

import litellm # type:ignore
from langchain.chat_models import ChatLiteLLM

from danswer.configs.model_configs import GEN_AI_API_VERSION
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LangChainChatLLM
from danswer.llm.utils import should_be_verbose
@@ -16,6 +17,22 @@
litellm.telemetry = False


def _get_model_str(
model_provider: str | None,
model_version: str | None,
) -> str:
if model_provider and model_version:
return model_provider + "/" + model_version

if model_version:
# Litellm defaults to openai if no provider specified
# It's implicit so no need to specify here either
return model_version

# User specified something wrong, just use Danswer default
return GEN_AI_MODEL_VERSION


class DefaultMultiLLM(LangChainChatLLM):
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
See https://python.langchain.com/docs/integrations/chat/litellm"""
@@ -27,22 +44,23 @@ class DefaultMultiLLM(LangChainChatLLM):

def __init__(
self,
api_key: str,
max_output_tokens: int,
api_key: str | None,
timeout: int,
model_version: str,
model_provider: str | None = GEN_AI_MODEL_PROVIDER,
model_version: str | None = GEN_AI_MODEL_VERSION,
api_base: str | None = GEN_AI_ENDPOINT,
api_version: str | None = GEN_AI_API_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
temperature: float = GEN_AI_TEMPERATURE,
*args: list[Any],
**kwargs: Mapping[str, Any]
):
# Litellm Langchain integration currently doesn't take in the api key param
# Can place this in the call below once integration is in
litellm.api_key = api_key
litellm.api_version = api_version

self._llm = ChatLiteLLM( # type: ignore
model=model_version,
# Prefer using None which is the default value, endpoint could be empty string
api_base=cast(str, kwargs.get("endpoint")) or None,
model=_get_model_str(model_provider, model_version),
api_base=api_base,
max_tokens=max_output_tokens,
temperature=temperature,
request_timeout=timeout,
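
To make LiteLLM's provider/model convention concrete, here is what the _get_model_str helper above returns for a few inputs (outputs as comments; the Azure deployment name is a placeholder):

_get_model_str("azure", "danswer-gpt-35")  # "azure/danswer-gpt-35"
_get_model_str(None, "gpt-3.5-turbo")      # "gpt-3.5-turbo" (LiteLLM assumes OpenAI)
_get_model_str("openai", None)             # falls back to GEN_AI_MODEL_VERSION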
14 changes: 6 additions & 8 deletions backend/danswer/main.py
@@ -22,13 +22,9 @@
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.llm_utils import get_default_qa_model
@@ -155,10 +151,12 @@ def startup_event() -> None:
if DISABLE_GENERATIVE_AI:
logger.info("Generative AI Q&A disabled")
else:
logger.info(f"Using Internal Model: {INTERNAL_MODEL_VERSION}")
logger.info(f"Actual LLM model version: {GEN_AI_MODEL_VERSION}")
if API_TYPE_OPENAI == "azure":
logger.info(f"Using Azure OpenAI with Endpoint: {API_BASE_OPENAI}")
pass
# TODO rework
# logger.info(f"Using Internal Model: {INTERNAL_MODEL_VERSION}")
# logger.info(f"Actual LLM model version: {GEN_AI_MODEL_VERSION}")
# if API_TYPE_OPENAI == "azure":
# logger.info(f"Using Azure OpenAI with Endpoint: {API_BASE_OPENAI}")

verify_auth = fetch_versioned_implementation(
"danswer.auth.users", "verify_auth_setting"
2 changes: 1 addition & 1 deletion backend/danswer/server/manage.py
@@ -25,7 +25,7 @@
from danswer.db.models import User
from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.open_ai import get_gen_ai_api_key
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.models import ApiKey
14 changes: 2 additions & 12 deletions deployment/docker_compose/docker-compose.dev.yml
@@ -16,11 +16,10 @@ services:
ports:
- "8080:8080"
environment:
- INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion}
- GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
- GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
- GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
- NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
- POSTGRES_HOST=relational_db
- VESPA_HOST=index
@@ -30,10 +29,6 @@
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
- API_BASE_OPENAI=${API_BASE_OPENAI:-}
- API_TYPE_OPENAI=${API_TYPE_OPENAI:-}
- API_VERSION_OPENAI=${API_VERSION_OPENAI:-}
- AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-}
- NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
- DISABLE_TIME_FILTER_EXTRACTION=${DISABLE_TIME_FILTER_EXTRACTION:-}
# Don't change the NLP model configs unless you know what you're doing
@@ -63,17 +58,12 @@ services:
- index
restart: always
environment:
- INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion}
- GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
- GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
- GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
- POSTGRES_HOST=relational_db
- VESPA_HOST=index
- API_BASE_OPENAI=${API_BASE_OPENAI:-}
- API_TYPE_OPENAI=${API_TYPE_OPENAI:-}
- API_VERSION_OPENAI=${API_VERSION_OPENAI:-}
- AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-}
- NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
# Connector Configs
- CONTINUE_ON_CONNECTOR_FAILURE=${CONTINUE_ON_CONNECTOR_FAILURE:-}
11 changes: 2 additions & 9 deletions deployment/docker_compose/env.prod.template
@@ -6,16 +6,9 @@
# Insert your OpenAI API key here. If not provided here, UI will prompt on setup.
# This env variable takes precedence over UI settings.
GEN_AI_API_KEY=
# Choose between "openai-chat-completion" and "openai-completion"
INTERNAL_MODEL_VERSION=openai-chat-completion
# Use a valid model for the choice above, consult https://platform.openai.com/docs/models/model-endpoint-compatibility
GEN_AI_MODEL_VERSION=gpt-4

# Necessary environment variables for Azure OpenAI:
API_BASE_OPENAI=
API_TYPE_OPENAI=
API_VERSION_OPENAI=
AZURE_DEPLOYMENT_ID=
GEN_AI_MODEL_PROVIDER=openai
GEN_AI_MODEL_VERSION=gpt-4

# Could be something like danswer.companyname.com
WEB_DOMAIN=http://localhost:3000
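
For reference, a hypothetical Azure OpenAI configuration under the new scheme (placeholder values, composed from the GEN_AI_* variables and the Azure hints in model_configs.py above; with LiteLLM the model version is the Azure deployment name):

GEN_AI_MODEL_PROVIDER=azure
GEN_AI_MODEL_VERSION=your-deployment-name
GEN_AI_ENDPOINT=https://danswer.openai.azure.com/
GEN_AI_API_VERSION=2023-09-15-preview
GEN_AI_API_KEY=<your-azure-openai-key>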
