From da89b508c9625ea5a1b84e97bdf8984654227b1e Mon Sep 17 00:00:00 2001 From: gritaro Date: Thu, 15 Feb 2024 02:07:12 +0700 Subject: [PATCH] - Add support for completion models configuration - Rollback from community to official gigachain library --- custom_components/gigachain/__init__.py | 23 ++++++--- custom_components/gigachain/config_flow.py | 55 ++++++++++------------ custom_components/gigachain/const.py | 42 +++++++++++++---- custom_components/gigachain/manifest.json | 4 +- test-model.py | 17 +++++++ 5 files changed, 91 insertions(+), 50 deletions(-) create mode 100644 test-model.py diff --git a/custom_components/gigachain/__init__.py b/custom_components/gigachain/__init__.py index 89a2348..33963b4 100644 --- a/custom_components/gigachain/__init__.py +++ b/custom_components/gigachain/__init__.py @@ -1,5 +1,4 @@ """The GigaChain integration.""" -from __future__ import annotations from homeassistant.components import conversation from homeassistant.config_entries import ConfigEntry from homeassistant.const import MATCH_ALL @@ -9,6 +8,7 @@ template, ) from homeassistant.components.conversation import AgentManager, agent + from typing import Literal from langchain_community.chat_models import GigaChat, ChatYandexGPT, ChatOpenAI from langchain.schema import AIMessage, HumanMessage, SystemMessage @@ -17,7 +17,7 @@ DOMAIN, CONF_ENGINE, CONF_TEMPERATURE, - DEFAULT_CONF_TEMPERATURE, + DEFAULT_TEMPERATURE, CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL, CONF_CHAT_MODEL, @@ -38,21 +38,32 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Initialize GigaChain.""" - temperature = entry.options.get(CONF_TEMPERATURE, DEFAULT_CONF_TEMPERATURE) + temperature = entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) + model = entry.options.get(CONF_CHAT_MODEL) engine = entry.data.get(CONF_ENGINE) or "gigachat" entry.async_on_unload(entry.add_update_listener(update_listener)) if 
engine == 'gigachat': client = GigaChat(temperature=temperature, - model='GigaChat:latest', + model=model, verbose=True, credentials=entry.data[CONF_API_KEY], verify_ssl_certs=False) elif engine == 'yandexgpt': - client = ChatYandexGPT(temperature=temperature, + if model == "YandexGPT": + model_url = "gpt://" + entry.data[CONF_FOLDER_ID] + "/yandexgpt/latest" + elif model == 'YandexGPT Lite': + model_url = "gpt://" + entry.data[CONF_FOLDER_ID] + "/yandexgpt-lite/latest" + elif model == 'Summary': + model_url = "gpt://" + entry.data[CONF_FOLDER_ID] + "/summarization/latest" + else: + model_url = "" + client = ChatYandexGPT( + model_uri=model_url, + temperature=temperature, api_key=entry.data[CONF_API_KEY], folder_id = entry.data[CONF_FOLDER_ID]) else: - client = ChatOpenAI(model="gpt-3.5-turbo", + client = ChatOpenAI(model=model, temperature=temperature, openai_api_key=entry.data[CONF_API_KEY]) hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client diff --git a/custom_components/gigachain/config_flow.py b/custom_components/gigachain/config_flow.py index cb51e97..71906ed 100644 --- a/custom_components/gigachain/config_flow.py +++ b/custom_components/gigachain/config_flow.py @@ -26,12 +26,13 @@ CONF_TEMPERATURE, CONF_ENGINE_OPTIONS, CONF_PROMPT, - CONF_MAX_TKNS, - DEFAULT_CONF_TEMPERATURE, - DEFAULT_CONF_MAX_TKNS, + CONF_MAX_TOKENS, + DEFAULT_TEMPERATURE, + DEFAULT_MODELS, + DEFAULT_MAX_TOKENS, DEFAULT_CHAT_MODEL, DEFAULT_PROMPT, - UNIQUE_ID, + UNIQUE_ID ) STEP_USER_SCHEMA = vol.Schema( @@ -42,7 +43,7 @@ } ) -STEP_GIGACHAT_SCHEMA = vol.Schema( +STEP_API_KEY_SCHEMA = vol.Schema( { vol.Required(CONF_API_KEY): str } @@ -53,16 +54,11 @@ vol.Required(CONF_FOLDER_ID): str } ) -STEP_OPENAI_SCHEMA = vol.Schema( - { - vol.Required(CONF_API_KEY): str - } -) ENGINE_SCHEMA = { - "gigachat": STEP_GIGACHAT_SCHEMA, + "gigachat": STEP_API_KEY_SCHEMA, "yandexgpt": STEP_YANDEXGPT_SCHEMA, - "openai": STEP_OPENAI_SCHEMA + "openai": STEP_API_KEY_SCHEMA } DEFAULT_OPTIONS = 
types.MappingProxyType( @@ -136,47 +132,44 @@ async def async_step_init( """Manage the options.""" if user_input is not None: return self.async_create_entry(title=self.config_entry.unique_id, data=user_input) - schema = common_config_option_schema(self.config_entry.options) + schema = common_config_option_schema(self.config_entry.unique_id, self.config_entry.options) return self.async_show_form( step_id="init", data_schema=vol.Schema(schema), ) -def common_config_option_schema(options: MappingProxyType[str, Any]) -> dict: +def common_config_option_schema(unique_id: str, options: MappingProxyType[str, Any]) -> dict: """Return a schema for GigaChain completion options.""" if not options: options = DEFAULT_OPTIONS return { + vol.Optional( + CONF_CHAT_MODEL, + description={ + "suggested_value": options.get(CONF_CHAT_MODEL), + }, default="none", + ): selector.SelectSelector( + selector.SelectSelectorConfig(options=DEFAULT_MODELS[unique_id]), + ), vol.Optional( CONF_PROMPT, description={"suggested_value": options[CONF_PROMPT]}, default=DEFAULT_PROMPT, ): TemplateSelector(), - vol.Optional( - CONF_CHAT_MODEL, - description={ - # New key in HA 2023.4 - "suggested_value": options.get(CONF_CHAT_MODEL, - DEFAULT_CHAT_MODEL) - }, - default=DEFAULT_CHAT_MODEL, - ): str, vol.Optional( CONF_TEMPERATURE, description={ - # New key in HA 2023.4 "suggested_value": options.get(CONF_TEMPERATURE, - DEFAULT_CONF_TEMPERATURE) + DEFAULT_TEMPERATURE) }, - default=DEFAULT_CONF_TEMPERATURE, + default=DEFAULT_TEMPERATURE, ): float, vol.Optional( - CONF_MAX_TKNS, + CONF_MAX_TOKENS, description={ - # New key in HA 2023.4 - "suggested_value": options.get(CONF_MAX_TKNS, - DEFAULT_CONF_MAX_TKNS) + "suggested_value": options.get(CONF_MAX_TOKENS, + DEFAULT_MAX_TOKENS) }, - default=DEFAULT_CONF_MAX_TKNS, + default=DEFAULT_MAX_TOKENS, ): int, } diff --git a/custom_components/gigachain/const.py b/custom_components/gigachain/const.py index 489c299..1601ea6 100644 --- 
a/custom_components/gigachain/const.py +++ b/custom_components/gigachain/const.py @@ -2,15 +2,38 @@ from homeassistant.helpers import selector DOMAIN = "gigachain" -CONF_ENGINE = "engine" -UNIQUE_ID = {"gigachat": "GigaChat", "yandexgpt": "YandexGPT", "openai": "OpenAI"} + +ID_GIGACHAT = "gigachat" +ID_YANDEX_GPT = "yandexgpt" +ID_OPENAI = "openai" +UNIQUE_ID_GIGACHAT = "GigaChat" +UNIQUE_ID_YANDEX_GPT = "YandexGPT" +UNIQUE_ID_OPENAI = "OpenAI" + +UNIQUE_ID = { + ID_GIGACHAT: UNIQUE_ID_GIGACHAT, + ID_YANDEX_GPT: UNIQUE_ID_YANDEX_GPT, + ID_OPENAI: UNIQUE_ID_OPENAI +} + CONF_ENGINE_OPTIONS = [ - selector.SelectOptionDict(value="gigachat", label="GigaChat"), - selector.SelectOptionDict(value="yandexgpt", label="YandexGPT"), - selector.SelectOptionDict(value="openai", label="OpenAI"), + selector.SelectOptionDict(value=ID_GIGACHAT, label=UNIQUE_ID_GIGACHAT), + selector.SelectOptionDict(value=ID_YANDEX_GPT, label=UNIQUE_ID_YANDEX_GPT), + selector.SelectOptionDict(value=ID_OPENAI, label=UNIQUE_ID_OPENAI), +] +DEFAULT_MODELS_GIGACHAT = [ + "GigaChat", "GigaChat:latest", "GigaChat-Plus", "GigaChat-Pro" ] +DEFAULT_MODELS_YANDEX_GPT = ["YandexGPT", "YandexGPT Lite", "Summary"] +DEFAULT_MODELS_OPENAI = ["gpt-3.5-turbo"] +DEFAULT_MODELS = { + UNIQUE_ID_GIGACHAT: DEFAULT_MODELS_GIGACHAT, + UNIQUE_ID_YANDEX_GPT: DEFAULT_MODELS_YANDEX_GPT, + UNIQUE_ID_OPENAI: DEFAULT_MODELS_OPENAI +} CONF_API_KEY = "api_key" CONF_FOLDER_ID = "folder_id" +CONF_ENGINE = "engine" CONF_PROMPT = "prompt" DEFAULT_PROMPT = """Ты HAL 9000, компьютер из цикла произведений «Космическая одиссея» Артура Кларка, обладающий способностью к самообучению. 
@@ -33,9 +56,8 @@ """ CONF_CHAT_MODEL = "model" -#GigaChat-Plus,GigaChat-Pro,GigaChat:latest -DEFAULT_CHAT_MODEL = "GigaChat" +DEFAULT_CHAT_MODEL = "" CONF_TEMPERATURE = "temperature" -DEFAULT_CONF_TEMPERATURE = "0.1" -CONF_MAX_TKNS = "max_tokens" -DEFAULT_CONF_MAX_TKNS = "250" +DEFAULT_TEMPERATURE = 0.1 +CONF_MAX_TOKENS = "max_tokens" +DEFAULT_MAX_TOKENS = 250 diff --git a/custom_components/gigachain/manifest.json b/custom_components/gigachain/manifest.json index 088144d..d49b9a8 100644 --- a/custom_components/gigachain/manifest.json +++ b/custom_components/gigachain/manifest.json @@ -10,9 +10,7 @@ "iot_class": "cloud_polling", "issue_tracker": "https://github.com/gritaro/gigachain/issues", "requirements": [ - "gigachat==0.1.16", - "langchain==0.1.7", - "gigachain-community==0.0.16", + "gigachain==0.1.4", "yandexcloud==0.259.0" ], "version": "0.1.2" diff --git a/test-model.py b/test-model.py new file mode 100644 index 0000000..4103abd --- /dev/null +++ b/test-model.py @@ -0,0 +1,17 @@ +from langchain.schema import HumanMessage, SystemMessage +from langchain_community.chat_models import ChatAnyscale + +chat = ChatAnyscale(model="meta-llama/Llama-2-70b-chat-hf", anyscale_api_key="") + +messages = [ + SystemMessage( + content="You are a helpful AI that shares everything you know." + ) +] + +while(True): + user_input = input("User: ") + messages.append(HumanMessage(content=user_input)) + res = chat(messages) + messages.append(res) + print("Bot: ", res.content)